Repository: onyx-dot-app/onyx
Branch: main
Commit: d4a96d70f32f
Files: 4225
Total size: 25.4 MB

Directory structure:
gitextract_2880u6lc/
├── .git-blame-ignore-revs
├── .github/
│   ├── CODEOWNERS
│   ├── actionlint.yml
│   ├── actions/
│   │   ├── build-backend-image/
│   │   │   └── action.yml
│   │   ├── build-integration-image/
│   │   │   └── action.yml
│   │   ├── build-model-server-image/
│   │   │   └── action.yml
│   │   ├── run-nightly-provider-chat-test/
│   │   │   └── action.yml
│   │   ├── setup-playwright/
│   │   │   └── action.yml
│   │   ├── setup-python-and-install-dependencies/
│   │   │   └── action.yml
│   │   └── slack-notify/
│   │       ├── action.yml
│   │       └── user-mappings.json
│   ├── dependabot.yml
│   ├── pull_request_template.md
│   ├── runs-on.yml
│   └── workflows/
│       ├── deployment.yml
│       ├── docker-tag-beta.yml
│       ├── docker-tag-latest.yml
│       ├── helm-chart-releases.yml
│       ├── merge-group.yml
│       ├── nightly-close-stale-issues.yml
│       ├── nightly-llm-provider-chat.yml
│       ├── post-merge-beta-cherry-pick.yml
│       ├── pr-database-tests.yml
│       ├── pr-desktop-build.yml
│       ├── pr-external-dependency-unit-tests.yml
│       ├── pr-golang-tests.yml
│       ├── pr-helm-chart-testing.yml
│       ├── pr-integration-tests.yml
│       ├── pr-jest-tests.yml
│       ├── pr-labeler.yml
│       ├── pr-linear-check.yml
│       ├── pr-playwright-tests.yml
│       ├── pr-python-checks.yml
│       ├── pr-python-connector-tests.yml
│       ├── pr-python-model-tests.yml
│       ├── pr-python-tests.yml
│       ├── pr-quality-checks.yml
│       ├── preview.yml
│       ├── release-cli.yml
│       ├── release-devtools.yml
│       ├── reusable-nightly-llm-provider-chat.yml
│       ├── sandbox-deployment.yml
│       ├── storybook-deploy.yml
│       ├── sync_foss.yml
│       ├── tag-nightly.yml
│       └── zizmor.yml
├── .gitignore
├── .greptile/
│   ├── config.json
│   ├── files.json
│   └── rules.md
├── .pre-commit-config.yaml
├── .prettierignore
├── .vscode/
│   ├── env.web_template.txt
│   ├── env_template.txt
│   ├── launch.json
│   └── tasks.template.jsonc
├── AGENTS.md
├── CONTRIBUTING.md
├── LICENSE
├── README.md
├── backend/
│   ├── .dockerignore
│   ├── .gitignore
│   ├── .trivyignore
│   ├── Dockerfile
│   ├── Dockerfile.model_server
│   ├── alembic/
│   │   ├── README.md
│   │   ├── env.py
│   │   ├── run_multitenant_migrations.py
│   │   ├── script.py.mako
│   │   └── versions/
│   │       ├── 01f8e6d95a33_populate_flow_mapping_data.py
│   │       ├── 027381bce97c_add_shortcut_option_for_users.py
│   │       ├── 03bf8be6b53a_rework_kg_config.py
│   │       ├── 03d085c5c38d_backfill_account_type.py
│   │       ├── 03d710ccf29c_add_permission_sync_attempt_tables.py
│   │       ├── 0568ccf46a6b_add_thread_specific_model_selection.py
│   │       ├── 05c07bf07c00_add_search_doc_relevance_details.py
│   │       ├── 07b98176f1de_code_interpreter_seed.py
│   │       ├── 0816326d83aa_add_federated_connector_tables.py
│   │       ├── 08a1eda20fe1_add_earliest_indexing_to_connector.py
│   │       ├── 09995b8811eb_add_theme_preference_to_user.py
│   │       ├── 0a2b51deb0b8_add_starter_prompts.py
│   │       ├── 0a98909f2757_enable_encrypted_fields.py
│   │       ├── 0bb4558f35df_add_scim_username_to_scim_user_mapping.py
│   │       ├── 0cd424f32b1d_user_file_data_preparation_and_backfill.py
│   │       ├── 0ebb1d516877_add_ccpair_deletion_failure_message.py
│   │       ├── 0f7ff6d75b57_add_index_to_index_attempt_time_created.py
│   │       ├── 114a638452db_add_default_app_mode_to_user.py
│   │       ├── 12635f6655b7_drive_canonical_ids.py
│   │       ├── 15326fcec57e_introduce_onyx_apis.py
│   │       ├── 16c37a30adf2_user_file_relationship_migration.py
│   │       ├── 173cae5bba26_port_config_store.py
│   │       ├── 175ea04c7087_add_user_preferences.py
│   │       ├── 177de57c21c9_display_custom_llm_models.py
│   │       ├── 18b5b2524446_add_is_clarification_to_chat_message.py
│   │       ├── 19c0ccb01687_migrate_to_contextual_rag_model.py
│   │       ├── 1a03d2c2856b_add_indexes_to_document__tag.py
│   │       ├── 1b10e1fda030_add_additional_data_to_notifications.py
│   │       ├── 1b8206b29c5d_add_user_delete_cascades.py
│   │       ├── 1d78c0ca7853_remove_voice_provider_deleted_column.py
│   │       ├── 1f2a3b4c5d6e_add_internet_search_and_content_providers.py
│   │       ├── 1f60f60c3401_embedding_model_search_settings.py
│   │       ├── 2020d417ec84_single_onyx_craft_migration.py
│   │       ├── 213fd978c6d8_notifications.py
│   │       ├── 238b84885828_add_foreign_key_to_user__external_user_.py
│   │       ├── 23957775e5f5_remove_feedback_foreignkey_constraint.py
│   │       ├── 25a5501dc766_group_permissions_phase1.py
│   │       ├── 2664261bfaab_add_cache_store_table.py
│   │       ├── 2666d766cb9b_google_oauth2.py
│   │       ├── 26b931506ecb_default_chosen_assistants_to_none.py
│   │       ├── 27c6ecc08586_permission_framework.py
│   │       ├── 27fb147a843f_add_timestamps_to_user_table.py
│   │       ├── 2955778aa44c_add_chunk_count_to_document.py
│   │       ├── 2a391f840e85_add_last_refreshed_at_mcp_server.py
│   │       ├── 2acdef638fc2_add_switchover_type_field.py
│   │       ├── 2b75d0a8ffcb_user_file_schema_cleanup.py
│   │       ├── 2b90f3af54b8_usage_limits.py
│   │       ├── 2c2430828bdf_add_unique_constraint_to_inputprompt_.py
│   │       ├── 2cdeff6d8c93_set_built_in_to_default.py
│   │       ├── 2d2304e27d8c_add_above_below_to_persona.py
│   │       ├── 2daa494a0851_add_group_sync_time.py
│   │       ├── 2f80c6a2550f_add_chat_session_specific_temperature_.py
│   │       ├── 2f95e36923e6_add_indexing_coordination.py
│   │       ├── 30c1d5744104_persona_datetime_aware.py
│   │       ├── 325975216eb3_add_icon_color_and_icon_shape_to_persona.py
│   │       ├── 33cb72ea4d80_single_tool_call_per_message.py
│   │       ├── 33ea50e88f24_foreign_key_input_prompts.py
│   │       ├── 351faebd379d_add_curator_fields.py
│   │       ├── 35e518e0ddf4_properly_cascade.py
│   │       ├── 35e6853a51d5_server_default_chosen_assistants.py
│   │       ├── 369644546676_add_composite_index_for_index_attempt_.py
│   │       ├── 36e9220ab794_update_kg_trigger_functions.py
│   │       ├── 3781a5eb12cb_add_chunk_stats_table.py
│   │       ├── 3879338f8ba1_add_tool_table.py
│   │       ├── 38eda64af7fe_add_chat_session_sharing.py
│   │       ├── 3934b1bc7b62_update_github_connector_repo_name_to_.py
│   │       ├── 3a7802814195_add_alternate_assistant_to_chat_message.py
│   │       ├── 3a78dba1080a_user_file_legacy_data_cleanup.py
│   │       ├── 3b25685ff73c_move_is_public_to_cc_pair.py
│   │       ├── 3bd4c84fe72f_improved_index.py
│   │       ├── 3c5e35aa9af0_polling_document_count.py
│   │       ├── 3c6531f32351_add_back_input_prompts.py
│   │       ├── 3c9a65f1207f_seed_exa_provider_from_env.py
│   │       ├── 3d1cca026fe8_add_oauth_config_and_user_tokens.py
│   │       ├── 3fc5d75723b3_add_doc_metadata_field_in_document_model.py
│   │       ├── 401c1ac29467_add_tables_for_ui_based_llm_.py
│   │       ├── 40926a4dab77_reset_userfile_document_id_migrated_.py
│   │       ├── 41fa44bef321_remove_default_prompt_shortcuts.py
│   │       ├── 43cbbb3f5e6a_rename_index_origin_to_index_recursively.py
│   │       ├── 44f856ae2a4a_add_cloud_embedding_model.py
│   │       ├── 4505fd7302e1_added_is_internet_to_dbdoc.py
│   │       ├── 465f78d9b7f9_larger_access_tokens_for_oauth.py
│   │       ├── 46625e4745d4_remove_native_enum.py
│   │       ├── 46b7a812670f_fix_user__external_user_group_id_fk.py
│   │       ├── 4738e4b3bae1_pg_file_store.py
│   │       ├── 473a1a7ca408_add_display_model_names_to_llm_provider.py
│   │       ├── 47433d30de82_create_indexattempt_table.py
│   │       ├── 475fcefe8826_add_name_to_api_key.py
│   │       ├── 4794bc13e484_update_prompt_length.py
│   │       ├── 47a07e1a38f1_fix_invalid_model_configurations_state.py
│   │       ├── 47e5bef3a1d7_add_persona_categories.py
│   │       ├── 48d14957fe80_add_support_for_custom_tools.py
│   │       ├── 495cb26ce93e_create_knowlege_graph_tables.py
│   │       ├── 4a1e4b1c89d2_add_indexing_to_userfilestatus.py
│   │       ├── 4a951134c801_moved_status_to_connector_credential_.py
│   │       ├── 4b08d97e175a_change_default_prune_freq.py
│   │       ├── 4cebcbc9b2ae_add_tab_index_to_tool_call.py
│   │       ├── 4d58345da04a_lowercase_user_emails.py
│   │       ├── 4ea2c93919c1_add_type_to_credentials.py
│   │       ├── 4ee1287bd26a_add_multiple_slack_bot_support.py
│   │       ├── 4f8a2b3c1d9e_add_open_url_tool.py
│   │       ├── 503883791c39_add_effective_permissions.py
│   │       ├── 505c488f6662_merge_default_assistants_into_unified.py
│   │       ├── 50b683a8295c_add_additional_retrieval_controls_to_.py
│   │       ├── 52a219fb5233_add_last_synced_and_last_modified_to_document_table.py
│   │       ├── 54a74a0417fc_danswerbot_onyxbot.py
│   │       ├── 55546a7967ee_assistant_rework.py
│   │       ├── 570282d33c49_track_onyxbot_explicitly.py
│   │       ├── 57122d037335_add_python_tool_on_default.py
│   │       ├── 57b53544726e_add_document_set_tables.py
│   │       ├── 5809c0787398_add_chat_sessions.py
│   │       ├── 58c50ef19f08_add_stale_column_to_user__external_user_.py
│   │       ├── 5ae8240accb3_add_research_agent_database_tables_and_.py
│   │       ├── 5b29123cd710_nullable_search_settings_for_historic_.py
│   │       ├── 5c3dca366b35_backend_driven_notification_details.py
│   │       ├── 5c448911b12f_add_content_type_to_userfile.py
│   │       ├── 5c7fdadae813_match_any_keywords_flag_for_standard_.py
│   │       ├── 5d12a446f5c0_add_api_version_and_deployment_name_to_.py
│   │       ├── 5e1c073d48a3_add_personal_access_token_table.py
│   │       ├── 5e6f7a8b9c0d_update_default_persona_prompt.py
│   │       ├── 5e84129c8be3_add_docs_indexed_column_to_index_.py
│   │       ├── 5f4b8568a221_add_removed_documents_to_index_attempt.py
│   │       ├── 5fc1f54cc252_hybrid_enum.py
│   │       ├── 61ff3651add4_add_permission_syncing.py
│   │       ├── 62c3a055a141_add_file_names_to_file_connector_config.py
│   │       ├── 631fd2504136_add_approx_chunk_count_in_vespa_to_.py
│   │       ├── 6436661d5b65_add_created_at_in_project_userfile.py
│   │       ├── 643a84a42a33_add_user_configured_names_to_llmprovider.py
│   │       ├── 64bd5677aeb6_add_image_input_support_to_model_config.py
│   │       ├── 65bc6e0f8500_remove_kg_subtype_from_db.py
│   │       ├── 6756efa39ada_id_uuid_for_chat_session.py
│   │       ├── 689433b0d8de_add_hook_and_hook_execution_log_tables.py
│   │       ├── 699221885109_nullify_default_task_prompt.py
│   │       ├── 6a804aeb4830_duplicated_no_harm_user_file_migration.py
│   │       ├── 6b3b4083c5aa_persona_cleanup_and_featured.py
│   │       ├── 6d387b3196c2_basic_auth.py
│   │       ├── 6d562f86c78b_remove_default_bot.py
│   │       ├── 6f4f86aef280_add_queries_and_is_web_fetch_to_.py
│   │       ├── 6fc7886d665d_make_categories_labels_and_many_to_many.py
│   │       ├── 703313b75876_add_tokenratelimit_tables.py
│   │       ├── 70f00c45c0f2_more_descriptive_filestore.py
│   │       ├── 7206234e012a_add_image_generation_config_table.py
│   │       ├── 72aa7de2e5cf_make_processing_mode_default_all_caps.py
│   │       ├── 72bdc9929a46_permission_auto_sync_framework.py
│   │       ├── 73e9983e5091_add_search_query_table.py
│   │       ├── 7477a5f5d728_added_model_defaults_for_users.py
│   │       ├── 7547d982db8f_chat_folders.py
│   │       ├── 7616121f6e97_add_enterprise_fields_to_scim_user_mapping.py
│   │       ├── 767f1c2a00eb_count_chat_tokens.py
│   │       ├── 76b60d407dfb_cc_pair_name_not_unique.py
│   │       ├── 776b3bbe9092_remove_remaining_enums.py
│   │       ├── 77d07dffae64_forcibly_remove_more_enum_types_from_.py
│   │       ├── 78dbe7e38469_task_tracking.py
│   │       ├── 78ebc66946a0_remove_reranking_from_search_settings.py
│   │       ├── 795b20b85b4b_add_llm_group_permissions_control.py
│   │       ├── 797089dfb4d2_persona_start_date.py
│   │       ├── 79acd316403a_add_api_key_table.py
│   │       ├── 7a70b7664e37_add_model_configuration_table.py
│   │       ├── 7aea705850d5_added_slack_auto_filter.py
│   │       ├── 7b9b952abdf6_update_entities.py
│   │       ├── 7bd55f264e1b_add_display_name_to_model_configuration.py
│   │       ├── 7cb492013621_code_interpreter_server_model.py
│   │       ├── 7cc3fcc116c1_user_file_uuid_primary_key_swap.py
│   │       ├── 7ccea01261f6_store_chat_retrieval_docs.py
│   │       ├── 7da0ae5ad583_add_description_to_persona.py
│   │       ├── 7da543f5672f_add_slackbotconfig_table.py
│   │       ├── 7e490836d179_nullify_default_system_prompt.py
│   │       ├── 7ed603b64d5a_add_mcp_server_and_connection_config_.py
│   │       ├── 7f726bad5367_slack_followup.py
│   │       ├── 7f99be1cb9f5_add_index_for_getting_documents_just_by_.py
│   │       ├── 800f48024ae9_add_id_to_connectorcredentialpair.py
│   │       ├── 80696cf850ae_add_chat_session_to_query_event.py
│   │       ├── 8188861f4e92_csv_to_tabular_chat_file_type.py
│   │       ├── 81c22b1e2e78_hierarchy_nodes_v1.py
│   │       ├── 8405ca81cc83_notifications_constraint.py
│   │       ├── 849b21c732f8_add_demo_data_enabled_to_build_session.py
│   │       ├── 87c52ec39f84_update_default_system_prompt.py
│   │       ├── 8818cf73fa1a_drop_include_citations.py
│   │       ├── 891cd83c87a8_add_is_visible_to_persona.py
│   │       ├── 8987770549c0_add_full_exception_stack_trace.py
│   │       ├── 8a87bd6ec550_associate_index_attempts_with_ccpair.py
│   │       ├── 8aabb57f3b49_restructure_document_indices.py
│   │       ├── 8b5ce697290e_add_discord_bot_tables.py
│   │       ├── 8e1ac4f39a9f_enable_contextual_retrieval.py
│   │       ├── 8e26726b7683_chat_context_addition.py
│   │       ├── 8f43500ee275_add_index.py
│   │       ├── 8ffcc2bcfc11_add_needs_persona_sync_to_user_file.py
│   │       ├── 904451035c9b_store_tool_details.py
│   │       ├── 904e5138fffb_tags.py
│   │       ├── 9087b548dd69_seed_default_image_gen_config.py
│   │       ├── 90b409d06e50_add_chat_compression_fields.py
│   │       ├── 90e3b9af7da4_tag_fix.py
│   │       ├── 91a0a4d62b14_milestone.py
│   │       ├── 91fd3b470d1a_remove_documentsource_from_tag.py
│   │       ├── 91ffac7e65b3_add_expiry_time.py
│   │       ├── 93560ba1b118_add_web_ui_option_to_slack_config.py
│   │       ├── 93a2e195e25c_add_voice_provider_and_user_voice_prefs.py
│   │       ├── 93c15d6a6fbb_add_chunk_error_and_vespa_count_columns_.py
│   │       ├── 949b4a92a401_remove_rt.py
│   │       ├── 94dc3d0236f8_make_document_set_description_optional.py
│   │       ├── 96a5702df6aa_mcp_tool_enabled.py
│   │       ├── 977e834c1427_seed_default_groups.py
│   │       ├── 97dbb53fa8c8_add_syncrecord.py
│   │       ├── 98a5008d8711_agent_tracking.py
│   │       ├── 9a0296d7421e_add_is_auto_mode_to_llm_provider.py
│   │       ├── 9aadf32dfeb4_add_user_files.py
│   │       ├── 9b66d3156fc6_user_file_schema_additions.py
│   │       ├── 9c00a2bccb83_chat_message_agentic.py
│   │       ├── 9c54986124c6_add_scim_tables.py
│   │       ├── 9cf5c00f72fe_add_creator_to_cc_pair.py
│   │       ├── 9d1543a37106_add_processing_duration_seconds_to_chat_.py
│   │       ├── 9d97fecfab7f_added_retrieved_docs_to_query_event.py
│   │       ├── 9drpiiw74ljy_add_config_to_federated_connector.py
│   │       ├── 9f696734098f_combine_search_and_chat.py
│   │       ├── a01bf2971c5d_update_default_tool_descriptions.py
│   │       ├── a1b2c3d4e5f6_add_license_table.py
│   │       ├── a1b2c3d4e5f7_drop_agent_search_metrics_table.py
│   │       ├── a2b3c4d5e6f7_remove_fast_default_model_name.py
│   │       ├── a3795dce87be_migration_confluence_to_be_explicit.py
│   │       ├── a3b8d9e2f1c4_make_scim_external_id_nullable.py
│   │       ├── a3bfd0d64902_add_chosen_assistants_to_user_table.py
│   │       ├── a3c1a7904cd0_remove_userfile_related_deprecated_.py
│   │       ├── a3f8b2c1d4e5_add_preferred_response_id_to_chat_message.py
│   │       ├── a4f23d6b71c8_add_llm_provider_persona_restrictions.py
│   │       ├── a570b80a5f20_usergroup_tables.py
│   │       ├── a6df6b88ef81_remove_recent_assistants.py
│   │       ├── a7688ab35c45_add_public_external_user_group_table.py
│   │       ├── a852cbe15577_new_chat_history.py
│   │       ├── a8c2065484e6_add_auto_scroll_to_user_model.py
│   │       ├── abbfec3a5ac5_merge_prompt_into_persona.py
│   │       ├── abe7378b8217_add_indexing_trigger_to_cc_pair.py
│   │       ├── ac5eaac849f9_add_last_pruned_to_connector_table.py
│   │       ├── acaab4ef4507_remove_inactive_ccpair_status_on_.py
│   │       ├── ae62505e3acc_add_saml_accounts.py
│   │       ├── aeda5f2df4f6_add_pinned_assistants.py
│   │       ├── b082fec533f0_make_last_attempt_status_nullable.py
│   │       ├── b156fa702355_chat_reworked.py
│   │       ├── b30353be4eec_add_mcp_auth_performer.py
│   │       ├── b329d00a9ea6_adding_assistant_specific_user_.py
│   │       ├── b388730a2899_nullable_preferences.py
│   │       ├── b4b7e1028dfd_grant_basic_to_existing_groups.py
│   │       ├── b4ef3ae0bf6e_add_user_oauth_token_to_slack_bot.py
│   │       ├── b51c6844d1df_seed_memory_tool.py
│   │       ├── b558f51620b4_pause_finished_user_file_connectors.py
│   │       ├── b5c4d7e8f9a1_add_hierarchy_node_cc_pair_table.py
│   │       ├── b728689f45b1_rename_persona_is_visible_to_is_listed_.py
│   │       ├── b72ed7a5db0e_remove_description_from_starter_messages.py
│   │       ├── b7a7eee5aa15_add_checkpointing_failure_handling.py
│   │       ├── b7bcc991d722_assign_users_to_default_groups.py
│   │       ├── b7c2b63c4a03_add_background_reindex_enabled_field.py
│   │       ├── b7ec9b5b505f_adjust_prompt_length.py
│   │       ├── b85f02ec1308_fix_file_type_migration.py
│   │       ├── b896bbd0d5a7_backfill_is_internet_data_to_false.py
│   │       ├── b8c9d0e1f2a3_drop_milestone_table.py
│   │       ├── ba98eba0f66a_add_support_for_litellm_proxy_in_.py
│   │       ├── baf71f781b9e_add_llm_model_version_override_to_.py
│   │       ├── bc9771dccadf_create_usage_reports_table.py
│   │       ├── bceb1e139447_add_base_url_to_cloudembeddingprovider.py
│   │       ├── bd2921608c3a_non_nullable_default_persona.py
│   │       ├── bd7c3bf8beba_migrate_agent_responses_to_research_.py
│   │       ├── be2ab2aa50ee_fix_capitalization.py
│   │       ├── be87a654d5af_persona_new_default_model_configuration_.py
│   │       ├── bf7a81109301_delete_input_prompts.py
│   │       ├── c0aab6edb6dd_delete_workspace.py
│   │       ├── c0c937d5c9e5_llm_provider_deprecate_fields.py
│   │       ├── c0fd6e4da83a_add_recent_assistants.py
│   │       ├── c18cdf4b497e_add_standard_answer_tables.py
│   │       ├── c1d2e3f4a5b6_add_deep_research_tool.py
│   │       ├── c5b692fa265c_add_index_attempt_errors_table.py
│   │       ├── c5eae4a75a1b_add_chat_message__standard_answer_table.py
│   │       ├── c7bf5721733e_add_has_been_indexed_to_.py
│   │       ├── c7e9f4a3b2d1_add_python_tool.py
│   │       ├── c7f2e1b4a9d3_add_sharing_scope_to_build_session.py
│   │       ├── c8a93a2af083_personalization_user_info.py
│   │       ├── c99d76fcd298_add_nullable_to_persona_id_in_chat_.py
│   │       ├── c9e2cd766c29_add_s3_file_store_table.py
│   │       ├── ca04500b9ee8_add_cascade_deletes_to_agent_tables.py
│   │       ├── cbc03e08d0f3_add_opensearch_migration_tables.py
│   │       ├── cec7ec36c505_kgentity_parent.py
│   │       ├── cf90764725d8_larger_refresh_tokens.py
│   │       ├── d09fc20a3c66_seed_builtin_tools.py
│   │       ├── d1b637d7050a_sync_exa_api_key_to_content_provider.py
│   │       ├── d25168c2beee_tool_name_consistency.py
│   │       ├── d3fd499c829c_add_file_reader_tool.py
│   │       ├── d5645c915d0e_remove_deletion_attempt_table.py
│   │       ├── d56ffa94ca32_add_file_content.py
│   │       ├── d5c86e2c6dc6_add_cascade_delete_to_search_query_user_.py
│   │       ├── d61e513bef0a_add_total_docs_for_index_attempt.py
│   │       ├── d7111c1238cd_remove_document_ids.py
│   │       ├── d716b0791ddd_combined_slack_id_fields.py
│   │       ├── d8cdfee5df80_add_skipped_to_userfilestatus.py
│   │       ├── d929f0c1c6af_feedback_feature.py
│   │       ├── d961aca62eb3_update_status_length.py
│   │       ├── d9ec13955951_remove__dim_suffix_from_model_name.py
│   │       ├── da42808081e3_migrate_jira_connectors_to_new_format.py
│   │       ├── da4c21c69164_chosen_assistants_changed_to_jsonb.py
│   │       ├── dab04867cd88_add_composite_index_to_document_by_.py
│   │       ├── dba7f71618f5_onyx_custom_tool_flow.py
│   │       ├── dbaa756c2ccf_embedding_models.py
│   │       ├── df0c7ad8a076_added_deletion_attempt_table.py
│   │       ├── df46c75b714e_add_default_vision_provider_to_llm_.py
│   │       ├── dfbe9e93d3c7_extended_role_for_non_web.py
│   │       ├── e0a68a81d434_add_chat_feedback.py
│   │       ├── e1392f05e840_added_input_prompts.py
│   │       ├── e209dc5a8156_added_prune_frequency.py
│   │       ├── e4334d5b33ba_add_deployment_name_to_llmprovider.py
│   │       ├── e50154680a5c_no_source_enum.py
│   │       ├── e6a4bbc13fe4_add_index_for_retrieving_latest_index_.py
│   │       ├── e7f8a9b0c1d2_create_anonymous_user.py
│   │       ├── e86866a9c78a_add_persona_to_chat_session.py
│   │       ├── e8f0d2a38171_add_status_to_mcp_server_and_make_auth_.py
│   │       ├── e91df4e935ef_private_personas_documentsets.py
│   │       ├── eaa3b5593925_add_default_slack_channel_config.py
│   │       ├── ec3ec2eabf7b_index_from_beginning.py
│   │       ├── ec85f2b3c544_remove_last_attempt_status_from_cc_pair.py
│   │       ├── ecab2b3f1a3b_add_overrides_to_the_chat_session.py
│   │       ├── ed9e44312505_add_icon_name_field.py
│   │       ├── ee3f4b47fad5_added_alternate_model_to_chat_message.py
│   │       ├── ef7da92f7213_add_files_to_chatmessage.py
│   │       ├── efb35676026c_standard_answer_match_regex_flag.py
│   │       ├── f11b408e39d3_force_lowercase_all_users.py
│   │       ├── f13db29f3101_add_composite_index_for_last_modified_.py
│   │       ├── f17bf3b0d9f1_embedding_provider_by_provider_type.py
│   │       ├── f1c6478c3fd8_add_pre_defined_feedback.py
│   │       ├── f1ca58b2f2ec_add_passthrough_auth_to_tool.py
│   │       ├── f220515df7b4_add_flow_mapping_table.py
│   │       ├── f32615f71aeb_add_custom_headers_to_tools.py
│   │       ├── f39c5794c10a_add_background_errors_table.py
│   │       ├── f5437cc136c5_delete_non_search_assistants.py
│   │       ├── f71470ba9274_add_prompt_length_limit.py
│   │       ├── f7505c5b0284_updated_constraints_for_ccpairs.py
│   │       ├── f7a894b06d02_non_nullbale_slack_bot_id_in_channel_.py
│   │       ├── f7ca3e2f45d9_migrate_no_auth_data_to_placeholder.py
│   │       ├── f7e58d357687_add_has_web_column_to_user.py
│   │       ├── f8a9b2c3d4e5_add_research_answer_purpose_to_chat_message.py
│   │       ├── f9b8c7d6e5a4_update_parent_question_id_foreign_key_to_research_agent_iteration.py
│   │       ├── fad14119fb92_delete_tags_with_wrong_enum.py
│   │       ├── fb80bdd256de_add_chat_background_to_user.py
│   │       ├── fcd135795f21_add_slack_bot_display_type.py
│   │       ├── febe9eaa0644_add_document_set_persona_relationship_.py
│   │       ├── fec3db967bf7_add_time_updated_to_usergroup_and_.py
│   │       ├── feead2911109_add_opensearch_tenant_migration_columns.py
│   │       └── ffc707a226b4_basic_document_metadata.py
│   ├── alembic.ini
│   ├── alembic_tenants/
│   │   ├── README.md
│   │   ├── __init__.py
│   │   ├── env.py
│   │   ├── script.py.mako
│   │   └── versions/
│   │       ├── 14a83a331951_create_usertenantmapping_table.py
│   │       ├── 34e3630c7f32_lowercase_multi_tenant_user_auth.py
│   │       ├── 3b45e0018bf1_add_new_available_tenant_table.py
│   │       ├── 3b9f09038764_add_read_only_kg_user.py
│   │       ├── a4f6ee863c47_mapping_for_anonymous_user_path.py
│   │       └── ac842f85f932_new_column_user_tenant_mapping.py
│   ├── assets/
│   │   └── .gitignore
│   ├── ee/
│   │   ├── LICENSE
│   │   ├── __init__.py
│   │   └── onyx/
│   │       ├── __init__.py
│   │       ├── access/
│   │       │   ├── access.py
│   │       │   └── hierarchy_access.py
│   │       ├── auth/
│   │       │   ├── __init__.py
│   │       │   └── users.py
│   │       ├── background/
│   │       │   ├── celery/
│   │       │   │   ├── apps/
│   │       │   │   │   ├── heavy.py
│   │       │   │   │   ├── light.py
│   │       │   │   │   ├── monitoring.py
│   │       │   │   │   └── primary.py
│   │       │   │   └── tasks/
│   │       │   │       ├── beat_schedule.py
│   │       │   │       ├── cleanup/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── cloud/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── doc_permission_syncing/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── external_group_syncing/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   ├── group_sync_utils.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── hooks/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── query_history/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── tenant_provisioning/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── ttl_management/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       ├── usage_reporting/
│   │       │   │       │   ├── __init__.py
│   │       │   │       │   └── tasks.py
│   │       │   │       └── vespa/
│   │       │   │           ├── __init__.py
│   │       │   │           └── tasks.py
│   │       │   ├── celery_utils.py
│   │       │   └── task_name_builders.py
│   │       ├── configs/
│   │       │   ├── __init__.py
│   │       │   ├── app_configs.py
│   │       │   └── license_enforcement_config.py
│   │       ├── connectors/
│   │       │   └── perm_sync_valid.py
│   │       ├── db/
│   │       │   ├── __init__.py
│   │       │   ├── analytics.py
│   │       │   ├── connector.py
│   │       │   ├── connector_credential_pair.py
│   │       │   ├── document.py
│   │       │   ├── document_set.py
│   │       │   ├── external_perm.py
│   │       │   ├── hierarchy.py
│   │       │   ├── license.py
│   │       │   ├── persona.py
│   │       │   ├── query_history.py
│   │       │   ├── saml.py
│   │       │   ├── scim.py
│   │       │   ├── search.py
│   │       │   ├── standard_answer.py
│   │       │   ├── token_limit.py
│   │       │   ├── usage_export.py
│   │       │   └── user_group.py
│   │       ├── document_index/
│   │       │   └── vespa/
│   │       │       └── app_config/
│   │       │           └── cloud-services.xml.jinja
│   │       ├── external_permissions/
│   │       │   ├── __init__.py
│   │       │   ├── confluence/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── constants.py
│   │       │   │   ├── doc_sync.py
│   │       │   │   ├── group_sync.py
│   │       │   │   ├── page_access.py
│   │       │   │   └── space_access.py
│   │       │   ├── github/
│   │       │   │   ├── doc_sync.py
│   │       │   │   ├── group_sync.py
│   │       │   │   └── utils.py
│   │       │   ├── gmail/
│   │       │   │   └── doc_sync.py
│   │       │   ├── google_drive/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── doc_sync.py
│   │       │   │   ├── folder_retrieval.py
│   │       │   │   ├── group_sync.py
│   │       │   │   ├── models.py
│   │       │   │   └── permission_retrieval.py
│   │       │   ├── jira/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── doc_sync.py
│   │       │   │   ├── group_sync.py
│   │       │   │   ├── models.py
│   │       │   │   └── page_access.py
│   │       │   ├── perm_sync_types.py
│   │       │   ├── post_query_censoring.py
│   │       │   ├── salesforce/
│   │       │   │   ├── postprocessing.py
│   │       │   │   └── utils.py
│   │       │   ├── sharepoint/
│   │       │   │   ├── doc_sync.py
│   │       │   │   ├── group_sync.py
│   │       │   │   └── permission_utils.py
│   │       │   ├── slack/
│   │       │   │   ├── channel_access.py
│   │       │   │   ├── doc_sync.py
│   │       │   │   ├── group_sync.py
│   │       │   │   └── utils.py
│   │       │   ├── sync_params.py
│   │       │   ├── teams/
│   │       │   │   └── doc_sync.py
│   │       │   └── utils.py
│   │       ├── feature_flags/
│   │       │   ├── __init__.py
│   │       │   ├── factory.py
│   │       │   └── posthog_provider.py
│   │       ├── hooks/
│   │       │   ├── __init__.py
│   │       │   └── executor.py
│   │       ├── main.py
│   │       ├── onyxbot/
│   │       │   └── slack/
│   │       │       └── handlers/
│   │       │           ├── __init__.py
│   │       │           └── handle_standard_answers.py
│   │       ├── prompts/
│   │       │   ├── __init__.py
│   │       │   ├── query_expansion.py
│   │       │   └── search_flow_classification.py
│   │       ├── search/
│   │       │   └── process_search_query.py
│   │       ├── secondary_llm_flows/
│   │       │   ├── __init__.py
│   │       │   ├── query_expansion.py
│   │       │   └── search_flow_classification.py
│   │       ├── server/
│   │       │   ├── __init__.py
│   │       │   ├── analytics/
│   │       │   │   └── api.py
│   │       │   ├── auth_check.py
│   │       │   ├── billing/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── api.py
│   │       │   │   ├── models.py
│   │       │   │   └── service.py
│   │       │   ├── documents/
│   │       │   │   └── cc_pair.py
│   │       │   ├── enterprise_settings/
│   │       │   │   ├── api.py
│   │       │   │   ├── models.py
│   │       │   │   └── store.py
│   │       │   ├── evals/
│   │       │   │   ├── __init__.py
│   │       │   │   └── api.py
│   │       │   ├── features/
│   │       │   │   ├── __init__.py
│   │       │   │   └── hooks/
│   │       │   │       ├── __init__.py
│   │       │   │       └── api.py
│   │       │   ├── license/
│   │       │   │   ├── api.py
│   │       │   │   └── models.py
│   │       │   ├── manage/
│   │       │   │   └── standard_answer.py
│   │       │   ├── middleware/
│   │       │   │   ├── license_enforcement.py
│   │       │   │   └── tenant_tracking.py
│   │       │   ├── oauth/
│   │       │   │   ├── api.py
│   │       │   │   ├── api_router.py
│   │       │   │   ├── confluence_cloud.py
│   │       │   │   ├── google_drive.py
│   │       │   │   └── slack.py
│   │       │   ├── query_and_chat/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── models.py
│   │       │   │   ├── query_backend.py
│   │       │   │   ├── search_backend.py
│   │       │   │   ├── streaming_models.py
│   │       │   │   └── token_limit.py
│   │       │   ├── query_history/
│   │       │   │   ├── api.py
│   │       │   │   └── models.py
│   │       │   ├── reporting/
│   │       │   │   ├── usage_export_api.py
│   │       │   │   ├── usage_export_generation.py
│   │       │   │   └── usage_export_models.py
│   │       │   ├── scim/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── api.py
│   │       │   │   ├── auth.py
│   │       │   │   ├── filtering.py
│   │       │   │   ├── models.py
│   │       │   │   ├── patch.py
│   │       │   │   ├── providers/
│   │       │   │   │   ├── __init__.py
│   │       │   │   │   ├── base.py
│   │       │   │   │   ├── entra.py
│   │       │   │   │   └── okta.py
│   │       │   │   └── schema_definitions.py
│   │       │   ├── seeding.py
│   │       │   ├── settings/
│   │       │   │   ├── __init__.py
│   │       │   │   └── api.py
│   │       │   ├── tenant_usage_limits.py
│   │       │   ├── tenants/
│   │       │   │   ├── __init__.py
│   │       │   │   ├── access.py
│   │       │   │   ├── admin_api.py
│   │       │   │   ├── anonymous_user_path.py
│   │       │   │   ├── anonymous_users_api.py
│   │       │   │   ├── api.py
│   │       │   │   ├── billing.py
│   │       │   │   ├── billing_api.py
│   │       │   │   ├── models.py
│   │       │   │   ├── product_gating.py
│   │       │   │   ├── provisioning.py
│   │       │   │   ├── proxy.py
│   │       │   │   ├── schema_management.py
│   │       │   │   ├── team_membership_api.py
│   │       │   │   ├── tenant_management_api.py
│   │       │   │   ├── user_invitations_api.py
│   │       │   │   └── user_mapping.py
│   │       │   ├── token_rate_limits/
│   │       │   │   └── api.py
│   │       │   ├── usage_limits.py
│   │       │   └── user_group/
│   │       │       ├── api.py
│   │       │       └── models.py
│   │       └── utils/
│   │           ├── __init__.py
│   │           ├── encryption.py
│   │           ├── license.py
│   │           ├── posthog_client.py
│   │           └── telemetry.py
│   ├── generated/
│   │   └── README.md
│   ├── keys/
│   │   └── license_public_key.pem
│   ├── model_server/
│   │   ├── __init__.py
│   │   ├── constants.py
│   │   ├── encoders.py
│   │   ├── legacy/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── custom_models.py
│   │   │   ├── onyx_torch_model.py
│   │   │   └── reranker.py
│   │   ├── main.py
│   │   ├── management_endpoints.py
│   │   └── utils.py
│   ├── onyx/
│   │   ├── __init__.py
│   │   ├── access/
│   │   │   ├── __init__.py
│   │   │   ├── access.py
│   │   │   ├── hierarchy_access.py
│   │   │   ├── models.py
│   │   │   └── utils.py
│   │   ├── auth/
│   │   │   ├── __init__.py
│   │   │   ├── anonymous_user.py
│   │   │   ├── api_key.py
│   │   │   ├── captcha.py
│   │   │   ├── constants.py
│   │   │   ├── disposable_email_validator.py
│   │   │   ├── email_utils.py
│   │   │   ├── invited_users.py
│   │   │   ├── jwt.py
│   │   │   ├── oauth_refresher.py
│   │   │   ├── oauth_token_manager.py
│   │   │   ├── pat.py
│   │   │   ├── permissions.py
│   │   │   ├── schemas.py
│   │   │   ├── users.py
│   │   │   └── utils.py
│   │   ├── background/
│   │   │   ├── README.md
│   │   │   ├── celery/
│   │   │   │   ├── apps/
│   │   │   │   │   ├── app_base.py
│   │   │   │   │   ├── beat.py
│   │   │   │   │   ├── client.py
│   │   │   │   │   ├── docfetching.py
│   │   │   │   │   ├── docprocessing.py
│   │   │   │   │   ├── heavy.py
│   │   │   │   │   ├── light.py
│   │   │   │   │   ├── monitoring.py
│   │   │   │   │   ├── primary.py
│   │   │   │   │   ├── task_formatters.py
│   │   │   │   │   └── user_file_processing.py
│   │   │   │   ├── celery_k8s_probe.py
│   │   │   │   ├── celery_redis.py
│   │   │   │   ├── celery_utils.py
│   │   │   │   ├── configs/
│   │   │   │   │   ├── base.py
│   │   │   │   │   ├── beat.py
│   │   │   │   │   ├── client.py
│   │   │   │   │   ├── docfetching.py
│   │   │   │   │   ├── docprocessing.py
│   │   │   │   │   ├── heavy.py
│   │   │   │   │   ├── light.py
│   │   │   │   │   ├── monitoring.py
│   │   │   │   │   ├── primary.py
│   │   │   │   │   └── user_file_processing.py
│   │   │   │   ├── memory_monitoring.py
│   │   │   │   ├── tasks/
│   │   │   │   │   ├── beat_schedule.py
│   │   │   │   │   ├── connector_deletion/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── docfetching/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── task_creation_utils.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── docprocessing/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── heartbeat.py
│   │   │   │   │   │   ├── tasks.py
│   │   │   │   │   │   └── utils.py
│   │   │   │   │   ├── evals/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── hierarchyfetching/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── llm_model_update/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── models.py
│   │   │   │   │   ├── monitoring/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── opensearch_migration/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   ├── constants.py
│   │   │   │   │   │   ├── tasks.py
│   │   │   │   │   │   └── transformer.py
│   │   │   │   │   ├── periodic/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── pruning/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── shared/
│   │   │   │   │   │   ├── RetryDocumentIndex.py
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   ├── user_file_processing/
│   │   │   │   │   │   ├── __init__.py
│   │   │   │   │   │   └── tasks.py
│   │   │   │   │   └── vespa/
│   │   │   │   │       ├── __init__.py
│   │   │   │   │       ├── document_sync.py
│   │   │   │   │       └── tasks.py
│   │   │   │   └── versioned_apps/
│   │   │   │       ├── beat.py
│   │   │   │       ├── client.py
│   │   │   │       ├── docfetching.py
│   │   │   │       ├── docprocessing.py
│   │   │   │       ├── heavy.py
│   │   │   │       ├── light.py
│   │   │   │       ├── monitoring.py
│   │   │   │       ├── primary.py
│   │   │   │       └── user_file_processing.py
│   │   │   ├── error_logging.py
│   │   │   ├── indexing/
│   │   │   │   ├── checkpointing_utils.py
│   │   │   │   ├── dask_utils.py
│   │   │   │   ├── index_attempt_utils.py
│   │   │   │   ├── job_client.py
│   │   │   │   ├── memory_tracer.py
│   │   │   │   ├── models.py
│   │   │   │   └── run_docfetching.py
│   │   │   ├── periodic_poller.py
│   │   │   └── task_utils.py
│   │   ├── cache/
│   │   │   ├── factory.py
│   │   │   ├── interface.py
│   │   │   ├── postgres_backend.py
│   │   │   └── redis_backend.py
│   │   ├── chat/
│   │   │   ├── COMPRESSION.md
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── chat_processing_checker.py
│   │   │   ├── chat_state.py
│   │   │   ├── chat_utils.py
│   │   │   ├── citation_processor.py
│   │   │   ├── citation_utils.py
│   │   │   ├── compression.py
│   │   │   ├── emitter.py
│   │   │   ├── llm_loop.py
│   │   │   ├── llm_step.py
│   │   │   ├── models.py
│   │   │   ├── process_message.py
│   │   │   ├── prompt_utils.py
│   │   │   ├── save_chat.py
│   │   │   ├── stop_signal_checker.py
│   │   │   └── tool_call_args_streaming.py
│   │   ├── configs/
│   │   │   ├── __init__.py
│   │   │   ├── agent_configs.py
│   │   │   ├── app_configs.py
│   │   │   ├── chat_configs.py
│   │   │   ├── constants.py
│   │   │   ├── embedding_configs.py
│   │   │   ├── kg_configs.py
│   │   │   ├── llm_configs.py
│   │   │   ├── model_configs.py
│   │   │   ├── onyxbot_configs.py
│   │   │   ├── research_configs.py
│   │   │   ├── saml_config/
│   │   │   │   └── template.settings.json
│   │   │   └── tool_configs.py
│   │   ├── connectors/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── airtable/
│   │   │   │   └── airtable_connector.py
│   │   │   ├── asana/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── asana_api.py
│   │   │   │   └── connector.py
│   │   │   ├── axero/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── bitbucket/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   └── utils.py
│   │   │   ├── blob/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── bookstack/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── client.py
│   │   │   │   └── connector.py
│   │   │   ├── canvas/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── access.py
│   │   │   │   ├── client.py
│   │   │   │   └── connector.py
│   │   │   ├── clickup/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── coda/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── confluence/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── access.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── models.py
│   │   │   │   ├── onyx_confluence.py
│   │   │   │   ├── user_profile_override.py
│   │   │   │   └── utils.py
│   │   │   ├── connector_runner.py
│   │   │   ├── credentials_provider.py
│   │   │   ├── cross_connector_utils/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── miscellaneous_utils.py
│   │   │   │   └── rate_limit_wrapper.py
│   │   │   ├── discord/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── discourse/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── document360/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   └── utils.py
│   │   │   ├── dropbox/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── drupal_wiki/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── models.py
│   │   │   │   └── utils.py
│   │   │   ├── egnyte/
│   │   │   │   └── connector.py
│   │   │   ├── exceptions.py
│   │   │   ├── factory.py
│   │   │   ├── file/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── fireflies/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── freshdesk/
│   │   │   │   ├── __init__,py
│   │   │   │   └── connector.py
│   │   │   ├── gitbook/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── github/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── models.py
│   │   │   │   ├── rate_limit_utils.py
│   │   │   │   └── utils.py
│   │   │   ├── gitlab/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── gmail/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── gong/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── google_drive/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── constants.py
│   │   │   │   ├── doc_conversion.py
│   │   │   │   ├── file_retrieval.py
│   │   │   │   ├── models.py
│   │   │   │   └── section_extraction.py
│   │   │   ├── google_site/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── google_utils/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── google_auth.py
│   │   │   │   ├── google_kv.py
│   │   │   │   ├── google_utils.py
│   │   │   │   ├── resources.py
│   │   │   │   └── shared_constants.py
│   │   │   ├── guru/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── highspot/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── client.py
│   │   │   │   ├── connector.py
│   │   │   │   └── utils.py
│   │   │   ├── hubspot/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   └── rate_limit.py
│   │   │   ├── imap/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   └── models.py
│   │   │   ├── interfaces.py
│   │   │   ├── jira/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── access.py
│   │   │   │   ├── connector.py
│   │   │   │   └── utils.py
│   │   │   ├── linear/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── loopio/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── mediawiki/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── family.py
│   │   │   │   └── wiki.py
│   │   │   ├── microsoft_graph_env.py
│   │   │   ├── mock_connector/
│   │   │   │   └── connector.py
│   │   │   ├── models.py
│   │   │   ├── notion/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── outline/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── client.py
│   │   │   │   └── connector.py
│   │   │   ├── productboard/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── registry.py
│   │   │   ├── requesttracker/
│   │   │   │   ├── .gitignore
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── salesforce/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── blacklist.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── doc_conversion.py
│   │   │   │   ├── onyx_salesforce.py
│   │   │   │   ├── salesforce_calls.py
│   │   │   │   ├── shelve_stuff/
│   │   │   │   │   ├── old_test_salesforce_shelves.py
│   │   │   │   │   ├── shelve_functions.py
│   │   │   │   │   ├── shelve_utils.py
│   │   │   │   │   └── test_salesforce_shelves.py
│   │   │   │   ├── sqlite_functions.py
│   │   │   │   └── utils.py
│   │   │   ├── sharepoint/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   └── connector_utils.py
│   │   │   ├── slab/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── slack/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── access.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── models.py
│   │   │   │   ├── onyx_retry_handler.py
│   │   │   │   ├── onyx_slack_web_client.py
│   │   │   │   └── utils.py
│   │   │   ├── teams/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── connector.py
│   │   │   │   ├── models.py
│   │   │   │   └── utils.py
│   │   │   ├── testrail/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── web/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── wikipedia/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── xenforo/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   ├── zendesk/
│   │   │   │   ├── __init__.py
│   │   │   │   └── connector.py
│   │   │   └── zulip/
│   │   │       ├── __init__.py
│   │   │       ├── connector.py
│   │   │       ├── schemas.py
│   │   │       └── utils.py
│   │   ├── context/
│   │   │   └── search/
│   │   │       ├── __init__.py
│   │   │       ├── enums.py
│   │   │       ├── federated/
│   │   │       │   ├── models.py
│   │   │       │   ├── slack_search.py
│   │   │       │   └── slack_search_utils.py
│   │   │       ├── models.py
│   │   │       ├── pipeline.py
│   │   │       ├── preprocessing/
│   │   │       │   └── access_filters.py
│   │   │       ├── retrieval/
│   │   │       │   └── search_runner.py
│   │   │       └── utils.py
│   │   ├── db/
│   │   │   ├── README.md
│   │   │   ├── __init__.py
│   │   │   ├── _deprecated/
│   │   │   │   └── pg_file_store.py
│   │   │   ├── api_key.py
│   │   │   ├── auth.py
│   │   │   ├── background_error.py
│   │   │   ├── chat.py
│   │   │   ├── chat_search.py
│   │   │   ├── chunk.py
│   │   │   ├── code_interpreter.py
│   │   │   ├── connector.py
│   │   │   ├── connector_credential_pair.py
│   │   │   ├── constants.py
│   │   │   ├── credentials.py
│   │   │   ├── dal.py
│   │   │   ├── deletion_attempt.py
│   │   │   ├── discord_bot.py
│   │   │   ├── document.py
│   │   │   ├── document_access.py
│   │   │   ├── document_set.py
│   │   │   ├── engine/
│   │   │   │   ├── __init__.py
│   │   │   │   ├── async_sql_engine.py
│   │   │   │   ├── connection_warmup.py
│   │   │   │   ├── iam_auth.py
│   │   │   │   ├── sql_engine.py
│   │   │   │   ├── tenant_utils.py
│   │   │   │   └── time_utils.py
│   │   │   ├── entities.py
│   │   │   ├── entity_type.py
│   │   │   ├── enums.py
│   │   │   ├── federated.py
│   │   │   ├── feedback.py
│   │   │   ├── file_content.py
│   │   │   ├── file_record.py
│   │   │   ├── hierarchy.py
│   │
│ ├── hook.py │ │ │ ├── image_generation.py │ │ │ ├── index_attempt.py │ │ │ ├── indexing_coordination.py │ │ │ ├── input_prompt.py │ │ │ ├── kg_config.py │ │ │ ├── kg_temp_view.py │ │ │ ├── llm.py │ │ │ ├── mcp.py │ │ │ ├── memory.py │ │ │ ├── models.py │ │ │ ├── notification.py │ │ │ ├── oauth_config.py │ │ │ ├── opensearch_migration.py │ │ │ ├── pat.py │ │ │ ├── permission_sync_attempt.py │ │ │ ├── permissions.py │ │ │ ├── persona.py │ │ │ ├── projects.py │ │ │ ├── pydantic_type.py │ │ │ ├── relationships.py │ │ │ ├── release_notes.py │ │ │ ├── rotate_encryption_key.py │ │ │ ├── saml.py │ │ │ ├── search_settings.py │ │ │ ├── seeding/ │ │ │ │ └── chat_history_seeding.py │ │ │ ├── slack_bot.py │ │ │ ├── slack_channel_config.py │ │ │ ├── swap_index.py │ │ │ ├── sync_record.py │ │ │ ├── tag.py │ │ │ ├── tasks.py │ │ │ ├── token_limit.py │ │ │ ├── tools.py │ │ │ ├── usage.py │ │ │ ├── user_file.py │ │ │ ├── user_preferences.py │ │ │ ├── users.py │ │ │ ├── utils.py │ │ │ ├── voice.py │ │ │ └── web_search.py │ │ ├── deep_research/ │ │ │ ├── __init__.py │ │ │ ├── dr_loop.py │ │ │ ├── dr_mock_tools.py │ │ │ ├── models.py │ │ │ └── utils.py │ │ ├── document_index/ │ │ │ ├── FILTER_SEMANTICS.md │ │ │ ├── __init__.py │ │ │ ├── chunk_content_enrichment.py │ │ │ ├── disabled.py │ │ │ ├── document_index_utils.py │ │ │ ├── factory.py │ │ │ ├── interfaces.py │ │ │ ├── interfaces_new.py │ │ │ ├── opensearch/ │ │ │ │ ├── README.md │ │ │ │ ├── client.py │ │ │ │ ├── cluster_settings.py │ │ │ │ ├── constants.py │ │ │ │ ├── opensearch_document_index.py │ │ │ │ ├── schema.py │ │ │ │ ├── search.py │ │ │ │ └── string_filtering.py │ │ │ ├── vespa/ │ │ │ │ ├── __init__.py │ │ │ │ ├── app_config/ │ │ │ │ │ ├── schemas/ │ │ │ │ │ │ └── danswer_chunk.sd.jinja │ │ │ │ │ ├── services.xml.jinja │ │ │ │ │ └── validation-overrides.xml.jinja │ │ │ │ ├── chunk_retrieval.py │ │ │ │ ├── deletion.py │ │ │ │ ├── index.py │ │ │ │ ├── indexing_utils.py │ │ │ │ ├── kg_interactions.py │ │ │ │ ├── 
shared_utils/ │ │ │ │ │ ├── utils.py │ │ │ │ │ └── vespa_request_builders.py │ │ │ │ └── vespa_document_index.py │ │ │ └── vespa_constants.py │ │ ├── error_handling/ │ │ │ ├── __init__.py │ │ │ ├── error_codes.py │ │ │ └── exceptions.py │ │ ├── evals/ │ │ │ ├── README.md │ │ │ ├── eval.py │ │ │ ├── eval_cli.py │ │ │ ├── models.py │ │ │ ├── one_off/ │ │ │ │ └── create_braintrust_dataset.py │ │ │ ├── provider.py │ │ │ └── providers/ │ │ │ ├── braintrust.py │ │ │ └── local.py │ │ ├── feature_flags/ │ │ │ ├── __init__.py │ │ │ ├── factory.py │ │ │ ├── feature_flags_keys.py │ │ │ ├── flags.py │ │ │ └── interface.py │ │ ├── federated_connectors/ │ │ │ ├── __init__.py │ │ │ ├── factory.py │ │ │ ├── federated_retrieval.py │ │ │ ├── interfaces.py │ │ │ ├── models.py │ │ │ ├── oauth_utils.py │ │ │ ├── registry.py │ │ │ └── slack/ │ │ │ ├── __init__.py │ │ │ ├── federated_connector.py │ │ │ └── models.py │ │ ├── file_processing/ │ │ │ ├── __init__.py │ │ │ ├── enums.py │ │ │ ├── extract_file_text.py │ │ │ ├── file_types.py │ │ │ ├── html_utils.py │ │ │ ├── image_summarization.py │ │ │ ├── image_utils.py │ │ │ ├── password_validation.py │ │ │ └── unstructured.py │ │ ├── file_store/ │ │ │ ├── README.md │ │ │ ├── constants.py │ │ │ ├── document_batch_storage.py │ │ │ ├── file_store.py │ │ │ ├── models.py │ │ │ ├── postgres_file_store.py │ │ │ ├── s3_key_utils.py │ │ │ └── utils.py │ │ ├── hooks/ │ │ │ ├── __init__.py │ │ │ ├── api_dependencies.py │ │ │ ├── executor.py │ │ │ ├── models.py │ │ │ ├── points/ │ │ │ │ ├── __init__.py │ │ │ │ ├── base.py │ │ │ │ ├── document_ingestion.py │ │ │ │ └── query_processing.py │ │ │ └── registry.py │ │ ├── httpx/ │ │ │ └── httpx_pool.py │ │ ├── image_gen/ │ │ │ ├── __init__.py │ │ │ ├── exceptions.py │ │ │ ├── factory.py │ │ │ ├── interfaces.py │ │ │ └── providers/ │ │ │ ├── azure_img_gen.py │ │ │ ├── openai_img_gen.py │ │ │ └── vertex_img_gen.py │ │ ├── indexing/ │ │ │ ├── __init__.py │ │ │ ├── adapters/ │ │ │ │ ├── 
document_indexing_adapter.py │ │ │ │ └── user_file_indexing_adapter.py │ │ │ ├── chunk_batch_store.py │ │ │ ├── chunker.py │ │ │ ├── content_classification.py │ │ │ ├── embedder.py │ │ │ ├── indexing_heartbeat.py │ │ │ ├── indexing_pipeline.py │ │ │ ├── models.py │ │ │ └── vector_db_insertion.py │ │ ├── key_value_store/ │ │ │ ├── __init__.py │ │ │ ├── factory.py │ │ │ ├── interface.py │ │ │ └── store.py │ │ ├── kg/ │ │ │ ├── clustering/ │ │ │ │ ├── clustering.py │ │ │ │ └── normalizations.py │ │ │ ├── extractions/ │ │ │ │ └── extraction_processing.py │ │ │ ├── models.py │ │ │ ├── resets/ │ │ │ │ ├── reset_index.py │ │ │ │ ├── reset_source.py │ │ │ │ └── reset_vespa.py │ │ │ ├── setup/ │ │ │ │ └── kg_default_entity_definitions.py │ │ │ ├── utils/ │ │ │ │ ├── embeddings.py │ │ │ │ ├── extraction_utils.py │ │ │ │ ├── formatting_utils.py │ │ │ │ └── lock_utils.py │ │ │ └── vespa/ │ │ │ └── vespa_interactions.py │ │ ├── llm/ │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── cost.py │ │ │ ├── factory.py │ │ │ ├── interfaces.py │ │ │ ├── litellm_singleton/ │ │ │ │ ├── __init__.py │ │ │ │ ├── config.py │ │ │ │ └── monkey_patches.py │ │ │ ├── model_metadata_enrichments.json │ │ │ ├── model_name_parser.py │ │ │ ├── model_response.py │ │ │ ├── models.py │ │ │ ├── multi_llm.py │ │ │ ├── override_models.py │ │ │ ├── prompt_cache/ │ │ │ │ ├── README.md │ │ │ │ ├── __init__.py │ │ │ │ ├── cache_manager.py │ │ │ │ ├── models.py │ │ │ │ ├── processor.py │ │ │ │ ├── providers/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── anthropic.py │ │ │ │ │ ├── base.py │ │ │ │ │ ├── factory.py │ │ │ │ │ ├── noop.py │ │ │ │ │ ├── openai.py │ │ │ │ │ └── vertex.py │ │ │ │ └── utils.py │ │ │ ├── request_context.py │ │ │ ├── utils.py │ │ │ └── well_known_providers/ │ │ │ ├── auto_update_models.py │ │ │ ├── auto_update_service.py │ │ │ ├── constants.py │ │ │ ├── llm_provider_options.py │ │ │ ├── models.py │ │ │ └── recommended-models.json │ │ ├── main.py │ │ ├── mcp_server/ │ │ │ ├── README.md │ │ │ 
├── api.py │ │ │ ├── auth.py │ │ │ ├── mcp.json.template │ │ │ ├── resources/ │ │ │ │ ├── __init__.py │ │ │ │ └── indexed_sources.py │ │ │ ├── tools/ │ │ │ │ ├── __init__.py │ │ │ │ └── search.py │ │ │ └── utils.py │ │ ├── mcp_server_main.py │ │ ├── natural_language_processing/ │ │ │ ├── __init__.py │ │ │ ├── constants.py │ │ │ ├── english_stopwords.py │ │ │ ├── exceptions.py │ │ │ ├── search_nlp_models.py │ │ │ └── utils.py │ │ ├── onyxbot/ │ │ │ ├── discord/ │ │ │ │ ├── DISCORD_MULTITENANT_README.md │ │ │ │ ├── api_client.py │ │ │ │ ├── cache.py │ │ │ │ ├── client.py │ │ │ │ ├── constants.py │ │ │ │ ├── exceptions.py │ │ │ │ ├── handle_commands.py │ │ │ │ ├── handle_message.py │ │ │ │ └── utils.py │ │ │ └── slack/ │ │ │ ├── blocks.py │ │ │ ├── config.py │ │ │ ├── constants.py │ │ │ ├── formatting.py │ │ │ ├── handlers/ │ │ │ │ ├── __init__.py │ │ │ │ ├── handle_buttons.py │ │ │ │ ├── handle_message.py │ │ │ │ ├── handle_regular_answer.py │ │ │ │ ├── handle_standard_answers.py │ │ │ │ └── utils.py │ │ │ ├── icons.py │ │ │ ├── listener.py │ │ │ ├── models.py │ │ │ └── utils.py │ │ ├── prompts/ │ │ │ ├── __init__.py │ │ │ ├── basic_memory.py │ │ │ ├── chat_prompts.py │ │ │ ├── chat_tools.py │ │ │ ├── compression_prompts.py │ │ │ ├── constants.py │ │ │ ├── contextual_retrieval.py │ │ │ ├── deep_research/ │ │ │ │ ├── __init__.py │ │ │ │ ├── dr_tool_prompts.py │ │ │ │ ├── orchestration_layer.py │ │ │ │ └── research_agent.py │ │ │ ├── federated_search.py │ │ │ ├── filter_extration.py │ │ │ ├── image_analysis.py │ │ │ ├── kg_prompts.py │ │ │ ├── prompt_template.py │ │ │ ├── prompt_utils.py │ │ │ ├── search_prompts.py │ │ │ ├── tool_prompts.py │ │ │ └── user_info.py │ │ ├── redis/ │ │ │ ├── iam_auth.py │ │ │ ├── redis_connector.py │ │ │ ├── redis_connector_delete.py │ │ │ ├── redis_connector_doc_perm_sync.py │ │ │ ├── redis_connector_ext_group_sync.py │ │ │ ├── redis_connector_index.py │ │ │ ├── redis_connector_prune.py │ │ │ ├── redis_connector_stop.py │ │ │ ├── 
redis_connector_utils.py │ │ │ ├── redis_document_set.py │ │ │ ├── redis_hierarchy.py │ │ │ ├── redis_object_helper.py │ │ │ ├── redis_pool.py │ │ │ ├── redis_usergroup.py │ │ │ └── redis_utils.py │ │ ├── secondary_llm_flows/ │ │ │ ├── __init__.py │ │ │ ├── chat_session_naming.py │ │ │ ├── document_filter.py │ │ │ ├── memory_update.py │ │ │ ├── query_expansion.py │ │ │ ├── source_filter.py │ │ │ └── time_filter.py │ │ ├── seeding/ │ │ │ └── __init__.py │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ ├── api_key/ │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── api_key_usage.py │ │ │ ├── auth_check.py │ │ │ ├── documents/ │ │ │ │ ├── __init__.py │ │ │ │ ├── cc_pair.py │ │ │ │ ├── connector.py │ │ │ │ ├── credential.py │ │ │ │ ├── document.py │ │ │ │ ├── document_utils.py │ │ │ │ ├── models.py │ │ │ │ ├── private_key_types.py │ │ │ │ └── standard_oauth.py │ │ │ ├── evals/ │ │ │ │ ├── __init__.py │ │ │ │ └── models.py │ │ │ ├── features/ │ │ │ │ ├── __init__.py │ │ │ │ ├── build/ │ │ │ │ │ ├── .gitignore │ │ │ │ │ ├── AGENTS.template.md │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── api/ │ │ │ │ │ │ ├── api.py │ │ │ │ │ │ ├── messages_api.py │ │ │ │ │ │ ├── models.py │ │ │ │ │ │ ├── packet_logger.py │ │ │ │ │ │ ├── packets.py │ │ │ │ │ │ ├── rate_limit.py │ │ │ │ │ │ ├── sessions_api.py │ │ │ │ │ │ ├── subscription_check.py │ │ │ │ │ │ ├── templates/ │ │ │ │ │ │ │ ├── webapp_hmr_fixer.js │ │ │ │ │ │ │ └── webapp_offline.html │ │ │ │ │ │ └── user_library.py │ │ │ │ │ ├── configs.py │ │ │ │ │ ├── db/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── build_session.py │ │ │ │ │ │ ├── rate_limit.py │ │ │ │ │ │ ├── sandbox.py │ │ │ │ │ │ └── user_library.py │ │ │ │ │ ├── indexing/ │ │ │ │ │ │ └── persistent_document_writer.py │ │ │ │ │ ├── s3/ │ │ │ │ │ │ └── s3_client.py │ │ │ │ │ ├── sandbox/ │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ ├── kubernetes/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── docker/ │ │ │ │ │ │ │ │ ├── Dockerfile │ │ 
│ │ │ │ │ │ ├── README.md │ │ │ │ │ │ │ │ ├── generate_agents_md.py │ │ │ │ │ │ │ │ ├── initial-requirements.txt │ │ │ │ │ │ │ │ ├── run-test.sh │ │ │ │ │ │ │ │ ├── skills/ │ │ │ │ │ │ │ │ │ ├── image-generation/ │ │ │ │ │ │ │ │ │ │ ├── SKILL.md │ │ │ │ │ │ │ │ │ │ └── scripts/ │ │ │ │ │ │ │ │ │ │ └── generate.py │ │ │ │ │ │ │ │ │ └── pptx/ │ │ │ │ │ │ │ │ │ ├── SKILL.md │ │ │ │ │ │ │ │ │ ├── editing.md │ │ │ │ │ │ │ │ │ ├── pptxgenjs.md │ │ │ │ │ │ │ │ │ └── scripts/ │ │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ │ ├── add_slide.py │ │ │ │ │ │ │ │ │ ├── clean.py │ │ │ │ │ │ │ │ │ ├── office/ │ │ │ │ │ │ │ │ │ │ ├── helpers/ │ │ │ │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ │ │ │ ├── merge_runs.py │ │ │ │ │ │ │ │ │ │ │ └── simplify_redlines.py │ │ │ │ │ │ │ │ │ │ ├── pack.py │ │ │ │ │ │ │ │ │ │ ├── schemas/ │ │ │ │ │ │ │ │ │ │ │ ├── ISO-IEC29500-4_2016/ │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-chart.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-chartDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-diagram.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-lockedCanvas.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-main.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-picture.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-spreadsheetDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── dml-wordprocessingDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── pml.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-additionalCharacteristics.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-bibliography.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-commonSimpleTypes.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-customXmlDataProperties.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-customXmlSchemaProperties.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-documentPropertiesCustom.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-documentPropertiesExtended.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-documentPropertiesVariantTypes.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-math.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── shared-relationshipReference.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── sml.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── vml-main.xsd │ │ │ │ │ │ │ │ │ │ │ │ 
├── vml-officeDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── vml-presentationDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── vml-spreadsheetDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── vml-wordprocessingDrawing.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── wml.xsd │ │ │ │ │ │ │ │ │ │ │ │ └── xml.xsd │ │ │ │ │ │ │ │ │ │ │ ├── ecma/ │ │ │ │ │ │ │ │ │ │ │ │ └── fouth-edition/ │ │ │ │ │ │ │ │ │ │ │ │ ├── opc-contentTypes.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── opc-coreProperties.xsd │ │ │ │ │ │ │ │ │ │ │ │ ├── opc-digSig.xsd │ │ │ │ │ │ │ │ │ │ │ │ └── opc-relationships.xsd │ │ │ │ │ │ │ │ │ │ │ ├── mce/ │ │ │ │ │ │ │ │ │ │ │ │ └── mc.xsd │ │ │ │ │ │ │ │ │ │ │ └── microsoft/ │ │ │ │ │ │ │ │ │ │ │ ├── wml-2010.xsd │ │ │ │ │ │ │ │ │ │ │ ├── wml-2012.xsd │ │ │ │ │ │ │ │ │ │ │ ├── wml-2018.xsd │ │ │ │ │ │ │ │ │ │ │ ├── wml-cex-2018.xsd │ │ │ │ │ │ │ │ │ │ │ ├── wml-cid-2016.xsd │ │ │ │ │ │ │ │ │ │ │ ├── wml-sdtdatahash-2020.xsd │ │ │ │ │ │ │ │ │ │ │ └── wml-symex-2015.xsd │ │ │ │ │ │ │ │ │ │ ├── soffice.py │ │ │ │ │ │ │ │ │ │ ├── unpack.py │ │ │ │ │ │ │ │ │ │ ├── validate.py │ │ │ │ │ │ │ │ │ │ └── validators/ │ │ │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ │ │ ├── base.py │ │ │ │ │ │ │ │ │ │ ├── docx.py │ │ │ │ │ │ │ │ │ │ ├── pptx.py │ │ │ │ │ │ │ │ │ │ └── redlining.py │ │ │ │ │ │ │ │ │ ├── preview.py │ │ │ │ │ │ │ │ │ └── thumbnail.py │ │ │ │ │ │ │ │ ├── templates/ │ │ │ │ │ │ │ │ │ └── outputs/ │ │ │ │ │ │ │ │ │ └── web/ │ │ │ │ │ │ │ │ │ ├── .gitignore │ │ │ │ │ │ │ │ │ ├── AGENTS.md │ │ │ │ │ │ │ │ │ ├── app/ │ │ │ │ │ │ │ │ │ │ ├── globals.css │ │ │ │ │ │ │ │ │ │ ├── layout.tsx │ │ │ │ │ │ │ │ │ │ ├── page.tsx │ │ │ │ │ │ │ │ │ │ └── site.webmanifest │ │ │ │ │ │ │ │ │ ├── components/ │ │ │ │ │ │ │ │ │ │ ├── component-example.tsx │ │ │ │ │ │ │ │ │ │ ├── example.tsx │ │ │ │ │ │ │ │ │ │ └── ui/ │ │ │ │ │ │ │ │ │ │ ├── accordion.tsx │ │ │ │ │ │ │ │ │ │ ├── alert-dialog.tsx │ │ │ │ │ │ │ │ │ │ ├── alert.tsx │ │ │ │ │ │ │ │ │ │ ├── aspect-ratio.tsx │ │ │ │ │ │ │ │ │ │ ├── avatar.tsx │ │ │ │ │ │ │ │ │ │ ├── 
badge.tsx │ │ │ │ │ │ │ │ │ │ ├── breadcrumb.tsx │ │ │ │ │ │ │ │ │ │ ├── button-group.tsx │ │ │ │ │ │ │ │ │ │ ├── button.tsx │ │ │ │ │ │ │ │ │ │ ├── calendar.tsx │ │ │ │ │ │ │ │ │ │ ├── card.tsx │ │ │ │ │ │ │ │ │ │ ├── carousel.tsx │ │ │ │ │ │ │ │ │ │ ├── chart.tsx │ │ │ │ │ │ │ │ │ │ ├── checkbox.tsx │ │ │ │ │ │ │ │ │ │ ├── collapsible.tsx │ │ │ │ │ │ │ │ │ │ ├── combobox.tsx │ │ │ │ │ │ │ │ │ │ ├── command.tsx │ │ │ │ │ │ │ │ │ │ ├── context-menu.tsx │ │ │ │ │ │ │ │ │ │ ├── dialog.tsx │ │ │ │ │ │ │ │ │ │ ├── drawer.tsx │ │ │ │ │ │ │ │ │ │ ├── dropdown-menu.tsx │ │ │ │ │ │ │ │ │ │ ├── empty.tsx │ │ │ │ │ │ │ │ │ │ ├── field.tsx │ │ │ │ │ │ │ │ │ │ ├── hover-card.tsx │ │ │ │ │ │ │ │ │ │ ├── input-group.tsx │ │ │ │ │ │ │ │ │ │ ├── input.tsx │ │ │ │ │ │ │ │ │ │ ├── item.tsx │ │ │ │ │ │ │ │ │ │ ├── kbd.tsx │ │ │ │ │ │ │ │ │ │ ├── label.tsx │ │ │ │ │ │ │ │ │ │ ├── menubar.tsx │ │ │ │ │ │ │ │ │ │ ├── native-select.tsx │ │ │ │ │ │ │ │ │ │ ├── navigation-menu.tsx │ │ │ │ │ │ │ │ │ │ ├── pagination.tsx │ │ │ │ │ │ │ │ │ │ ├── popover.tsx │ │ │ │ │ │ │ │ │ │ ├── progress.tsx │ │ │ │ │ │ │ │ │ │ ├── radio-group.tsx │ │ │ │ │ │ │ │ │ │ ├── resizable.tsx │ │ │ │ │ │ │ │ │ │ ├── scroll-area.tsx │ │ │ │ │ │ │ │ │ │ ├── select.tsx │ │ │ │ │ │ │ │ │ │ ├── separator.tsx │ │ │ │ │ │ │ │ │ │ ├── sheet.tsx │ │ │ │ │ │ │ │ │ │ ├── sidebar.tsx │ │ │ │ │ │ │ │ │ │ ├── skeleton.tsx │ │ │ │ │ │ │ │ │ │ ├── slider.tsx │ │ │ │ │ │ │ │ │ │ ├── sonner.tsx │ │ │ │ │ │ │ │ │ │ ├── spinner.tsx │ │ │ │ │ │ │ │ │ │ ├── switch.tsx │ │ │ │ │ │ │ │ │ │ ├── table.tsx │ │ │ │ │ │ │ │ │ │ ├── tabs.tsx │ │ │ │ │ │ │ │ │ │ ├── textarea.tsx │ │ │ │ │ │ │ │ │ │ ├── toggle-group.tsx │ │ │ │ │ │ │ │ │ │ ├── toggle.tsx │ │ │ │ │ │ │ │ │ │ └── tooltip.tsx │ │ │ │ │ │ │ │ │ ├── components.json │ │ │ │ │ │ │ │ │ ├── eslint.config.mjs │ │ │ │ │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ │ │ │ │ └── use-mobile.ts │ │ │ │ │ │ │ │ │ ├── lib/ │ │ │ │ │ │ │ │ │ │ └── utils.ts │ │ │ │ │ │ │ │ │ ├── next.config.ts │ │ │ │ │ │ │ │ │ 
├── package.json │ │ │ │ │ │ │ │ │ ├── postcss.config.mjs │ │ │ │ │ │ │ │ │ └── tsconfig.json │ │ │ │ │ │ │ │ └── test-job.yaml │ │ │ │ │ │ │ ├── internal/ │ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ │ └── acp_exec_client.py │ │ │ │ │ │ │ └── kubernetes_sandbox_manager.py │ │ │ │ │ │ ├── local/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── agent_client.py │ │ │ │ │ │ │ ├── local_sandbox_manager.py │ │ │ │ │ │ │ ├── process_manager.py │ │ │ │ │ │ │ ├── test_agent_client.py │ │ │ │ │ │ │ └── test_manager.py │ │ │ │ │ │ ├── manager/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ ├── directory_manager.py │ │ │ │ │ │ │ ├── snapshot_manager.py │ │ │ │ │ │ │ └── test_directory_manager.py │ │ │ │ │ │ ├── models.py │ │ │ │ │ │ ├── tasks/ │ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ │ └── tasks.py │ │ │ │ │ │ └── util/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── agent_instructions.py │ │ │ │ │ │ ├── build_venv_template.py │ │ │ │ │ │ ├── opencode_config.py │ │ │ │ │ │ └── persona_mapping.py │ │ │ │ │ ├── session/ │ │ │ │ │ │ ├── __init__.py │ │ │ │ │ │ ├── manager.py │ │ │ │ │ │ └── prompts.py │ │ │ │ │ └── utils.py │ │ │ │ ├── default_assistant/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── document_set/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── hierarchy/ │ │ │ │ │ ├── api.py │ │ │ │ │ ├── constants.py │ │ │ │ │ └── models.py │ │ │ │ ├── hooks/ │ │ │ │ │ └── __init__.py │ │ │ │ ├── input_prompt/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── mcp/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── notifications/ │ │ │ │ │ └── api.py │ │ │ │ ├── oauth_config/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── password/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── persona/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── api.py │ │ │ │ │ ├── constants.py │ │ │ │ │ └── models.py │ │ │ │ ├── projects/ │ │ │ │ │ ├── api.py │ │ │ │ │ ├── models.py │ │ │ │ │ └── 
projects_file_utils.py │ │ │ │ ├── release_notes/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── constants.py │ │ │ │ │ ├── models.py │ │ │ │ │ └── utils.py │ │ │ │ ├── tool/ │ │ │ │ │ ├── api.py │ │ │ │ │ ├── models.py │ │ │ │ │ └── tool_visibility.py │ │ │ │ ├── user_oauth_token/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── api.py │ │ │ │ └── web_search/ │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── federated/ │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── kg/ │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── manage/ │ │ │ │ ├── __init__.py │ │ │ │ ├── administrative.py │ │ │ │ ├── code_interpreter/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── discord_bot/ │ │ │ │ │ ├── api.py │ │ │ │ │ ├── models.py │ │ │ │ │ └── utils.py │ │ │ │ ├── embedding/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── get_state.py │ │ │ │ ├── image_generation/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── llm/ │ │ │ │ │ ├── api.py │ │ │ │ │ ├── models.py │ │ │ │ │ └── utils.py │ │ │ │ ├── models.py │ │ │ │ ├── opensearch_migration/ │ │ │ │ │ ├── api.py │ │ │ │ │ └── models.py │ │ │ │ ├── search_settings.py │ │ │ │ ├── slack_bot.py │ │ │ │ ├── users.py │ │ │ │ ├── validate_tokens.py │ │ │ │ ├── voice/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── api.py │ │ │ │ │ ├── models.py │ │ │ │ │ ├── user_api.py │ │ │ │ │ └── websocket_api.py │ │ │ │ └── web_search/ │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── metrics/ │ │ │ │ ├── __init__.py │ │ │ │ ├── celery_task_metrics.py │ │ │ │ ├── indexing_pipeline.py │ │ │ │ ├── indexing_pipeline_setup.py │ │ │ │ ├── indexing_task_metrics.py │ │ │ │ ├── metrics_server.py │ │ │ │ ├── opensearch_search.py │ │ │ │ ├── per_tenant.py │ │ │ │ ├── postgres_connection_pool.py │ │ │ │ ├── prometheus_setup.py │ │ │ │ └── slow_requests.py │ │ │ ├── middleware/ │ │ │ │ ├── latency_logging.py │ │ │ │ └── rate_limiting.py │ │ │ ├── models.py │ │ │ ├── onyx_api/ │ │ │ │ ├── __init__.py │ │ │ │ ├── ingestion.py │ │ │ 
│ └── models.py │ │ │ ├── pat/ │ │ │ │ ├── __init__.py │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── query_and_chat/ │ │ │ │ ├── __init__.py │ │ │ │ ├── chat_backend.py │ │ │ │ ├── chat_utils.py │ │ │ │ ├── models.py │ │ │ │ ├── placement.py │ │ │ │ ├── query_backend.py │ │ │ │ ├── session_loading.py │ │ │ │ ├── streaming_models.py │ │ │ │ └── token_limit.py │ │ │ ├── runtime/ │ │ │ │ └── onyx_runtime.py │ │ │ ├── saml.py │ │ │ ├── settings/ │ │ │ │ ├── api.py │ │ │ │ ├── models.py │ │ │ │ └── store.py │ │ │ ├── tenant_usage_limits.py │ │ │ ├── token_rate_limits/ │ │ │ │ ├── api.py │ │ │ │ └── models.py │ │ │ ├── usage_limits.py │ │ │ ├── utils.py │ │ │ └── utils_vector_db.py │ │ ├── setup.py │ │ ├── tools/ │ │ │ ├── built_in_tools.py │ │ │ ├── constants.py │ │ │ ├── fake_tools/ │ │ │ │ ├── __init__.py │ │ │ │ └── research_agent.py │ │ │ ├── interface.py │ │ │ ├── models.py │ │ │ ├── tool_constructor.py │ │ │ ├── tool_implementations/ │ │ │ │ ├── custom/ │ │ │ │ │ ├── base_tool_types.py │ │ │ │ │ ├── custom_tool.py │ │ │ │ │ └── openapi_parsing.py │ │ │ │ ├── file_reader/ │ │ │ │ │ └── file_reader_tool.py │ │ │ │ ├── images/ │ │ │ │ │ ├── image_generation_tool.py │ │ │ │ │ └── models.py │ │ │ │ ├── knowledge_graph/ │ │ │ │ │ └── knowledge_graph_tool.py │ │ │ │ ├── mcp/ │ │ │ │ │ ├── mcp_client.py │ │ │ │ │ └── mcp_tool.py │ │ │ │ ├── memory/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── memory_tool.py │ │ │ │ │ └── models.py │ │ │ │ ├── open_url/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── firecrawl.py │ │ │ │ │ ├── models.py │ │ │ │ │ ├── onyx_web_crawler.py │ │ │ │ │ ├── open_url_tool.py │ │ │ │ │ ├── snippet_matcher.py │ │ │ │ │ ├── url_normalization.py │ │ │ │ │ └── utils.py │ │ │ │ ├── python/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ ├── code_interpreter_client.py │ │ │ │ │ └── python_tool.py │ │ │ │ ├── search/ │ │ │ │ │ ├── constants.py │ │ │ │ │ ├── search_tool.py │ │ │ │ │ └── search_utils.py │ │ │ │ ├── search_like_tool_utils.py │ │ │ │ ├── utils.py │ │ │ │ └── 
web_search/ │ │ │ │ ├── clients/ │ │ │ │ │ ├── brave_client.py │ │ │ │ │ ├── exa_client.py │ │ │ │ │ ├── google_pse_client.py │ │ │ │ │ ├── searxng_client.py │ │ │ │ │ └── serper_client.py │ │ │ │ ├── models.py │ │ │ │ ├── providers.py │ │ │ │ ├── utils.py │ │ │ │ └── web_search_tool.py │ │ │ ├── tool_runner.py │ │ │ └── utils.py │ │ ├── tracing/ │ │ │ ├── braintrust_tracing_processor.py │ │ │ ├── framework/ │ │ │ │ ├── __init__.py │ │ │ │ ├── _error_tracing.py │ │ │ │ ├── create.py │ │ │ │ ├── processor_interface.py │ │ │ │ ├── provider.py │ │ │ │ ├── scope.py │ │ │ │ ├── setup.py │ │ │ │ ├── span_data.py │ │ │ │ ├── spans.py │ │ │ │ ├── traces.py │ │ │ │ └── util.py │ │ │ ├── langfuse_tracing_processor.py │ │ │ ├── llm_utils.py │ │ │ ├── masking.py │ │ │ └── setup.py │ │ ├── utils/ │ │ │ ├── __init__.py │ │ │ ├── b64.py │ │ │ ├── batching.py │ │ │ ├── callbacks.py │ │ │ ├── encryption.py │ │ │ ├── error_handling.py │ │ │ ├── errors.py │ │ │ ├── file.py │ │ │ ├── gpu_utils.py │ │ │ ├── headers.py │ │ │ ├── jsonriver/ │ │ │ │ ├── __init__.py │ │ │ │ ├── parse.py │ │ │ │ └── tokenize.py │ │ │ ├── logger.py │ │ │ ├── long_term_log.py │ │ │ ├── memory_logger.py │ │ │ ├── middleware.py │ │ │ ├── object_size_check.py │ │ │ ├── postgres_sanitization.py │ │ │ ├── pydantic_util.py │ │ │ ├── retry_wrapper.py │ │ │ ├── search_nlp_models_utils.py │ │ │ ├── sensitive.py │ │ │ ├── sitemap.py │ │ │ ├── special_types.py │ │ │ ├── subclasses.py │ │ │ ├── supervisord_watchdog.py │ │ │ ├── telemetry.py │ │ │ ├── tenant.py │ │ │ ├── text_processing.py │ │ │ ├── threadpool_concurrency.py │ │ │ ├── timing.py │ │ │ ├── url.py │ │ │ ├── variable_functionality.py │ │ │ └── web_content.py │ │ └── voice/ │ │ ├── __init__.py │ │ ├── factory.py │ │ ├── interface.py │ │ └── providers/ │ │ ├── __init__.py │ │ ├── azure.py │ │ ├── elevenlabs.py │ │ └── openai.py │ ├── pyproject.toml │ ├── pytest.ini │ ├── requirements/ │ │ ├── README.md │ │ ├── combined.txt │ │ ├── default.txt │ │ ├── dev.txt │ 
│ ├── ee.txt │ │ └── model_server.txt │ ├── scripts/ │ │ ├── __init__.py │ │ ├── add_connector_creation_script.py │ │ ├── api_inference_sample.py │ │ ├── celery_purge_queue.py │ │ ├── chat_feedback_dump.py │ │ ├── chat_history_seeding.py │ │ ├── chat_loadtest.py │ │ ├── debugging/ │ │ │ ├── debug_usage_limits.py │ │ │ ├── litellm/ │ │ │ │ ├── README │ │ │ │ ├── call_litellm.py │ │ │ │ ├── directly_hit_azure_api.py │ │ │ │ └── payload.json │ │ │ ├── onyx_db.py │ │ │ ├── onyx_list_tenants.py │ │ │ ├── onyx_redis.py │ │ │ ├── onyx_vespa_schemas.py │ │ │ └── opensearch/ │ │ │ ├── benchmark_retrieval.py │ │ │ ├── constants.py │ │ │ ├── embed_and_save.py │ │ │ ├── embedding_io.py │ │ │ ├── opensearch_debug.py │ │ │ └── query_hierarchy_debug.py │ │ ├── decrypt.py │ │ ├── dev_run_background_jobs.py │ │ ├── docker_memory_tracking.sh │ │ ├── force_delete_connector_by_id.py │ │ ├── get_wikidocs.py │ │ ├── hard_delete_chats.py │ │ ├── lib/ │ │ │ └── logger.py │ │ ├── make_foss_repo.sh │ │ ├── onyx_openapi_schema.py │ │ ├── orphan_doc_cleanup_script.py │ │ ├── query_time_check/ │ │ │ ├── seed_dummy_docs.py │ │ │ └── test_query_times.py │ │ ├── reencrypt_secrets.py │ │ ├── reset_indexes.py │ │ ├── reset_postgres.py │ │ ├── restart_containers.sh │ │ ├── resume_paused_connectors.py │ │ ├── run_industryrag_bench_questions.py │ │ ├── save_load_state.py │ │ ├── setup_craft_templates.sh │ │ ├── sources_selection_analysis.py │ │ ├── supervisord_entrypoint.sh │ │ ├── tenant_cleanup/ │ │ │ ├── QUICK_START_NO_BASTION.md │ │ │ ├── README.md │ │ │ ├── analyze_current_tenants.py │ │ │ ├── check_no_bastion_setup.py │ │ │ ├── cleanup_tenants.py │ │ │ ├── cleanup_utils.py │ │ │ ├── mark_connectors_for_deletion.py │ │ │ ├── no_bastion_analyze_tenants.py │ │ │ ├── no_bastion_cleanup_tenants.py │ │ │ ├── no_bastion_cleanup_utils.py │ │ │ ├── no_bastion_mark_connectors.py │ │ │ └── on_pod_scripts/ │ │ │ ├── check_documents_deleted.py │ │ │ ├── cleanup_tenant_schema.py │ │ │ ├── 
execute_connector_deletion.py │ │ │ ├── get_tenant_connectors.py │ │ │ ├── get_tenant_index_name.py │ │ │ ├── get_tenant_users.py │ │ │ └── understand_tenants.py │ │ ├── test-openapi-key.py │ │ ├── transform_openapi_for_docs.py │ │ └── upload_files_as_connectors.py │ ├── shared_configs/ │ │ ├── __init__.py │ │ ├── configs.py │ │ ├── contextvars.py │ │ ├── enums.py │ │ ├── model_server_models.py │ │ └── utils.py │ ├── slackbot_images/ │ │ └── README.md │ ├── supervisord.conf │ └── tests/ │ ├── README.md │ ├── __init__.py │ ├── api/ │ │ └── test_api.py │ ├── conftest.py │ ├── daily/ │ │ ├── conftest.py │ │ ├── connectors/ │ │ │ ├── airtable/ │ │ │ │ └── test_airtable_basic.py │ │ │ ├── bitbucket/ │ │ │ │ ├── conftest.py │ │ │ │ ├── test_bitbucket_checkpointed.py │ │ │ │ └── test_bitbucket_slim_connector.py │ │ │ ├── blob/ │ │ │ │ └── test_blob_connector.py │ │ │ ├── coda/ │ │ │ │ ├── README.md │ │ │ │ └── test_coda_connector.py │ │ │ ├── confluence/ │ │ │ │ ├── models.py │ │ │ │ ├── test_confluence_basic.py │ │ │ │ ├── test_confluence_permissions_basic.py │ │ │ │ └── test_confluence_user_email_overrides.py │ │ │ ├── conftest.py │ │ │ ├── discord/ │ │ │ │ └── test_discord_connector.py │ │ │ ├── file/ │ │ │ │ └── test_file_connector.py │ │ │ ├── fireflies/ │ │ │ │ ├── test_fireflies_connector.py │ │ │ │ └── test_fireflies_data.json │ │ │ ├── gitbook/ │ │ │ │ └── test_gitbook_connector.py │ │ │ ├── github/ │ │ │ │ └── test_github_basic.py │ │ │ ├── gitlab/ │ │ │ │ └── test_gitlab_basic.py │ │ │ ├── gmail/ │ │ │ │ ├── conftest.py │ │ │ │ └── test_gmail_connector.py │ │ │ ├── gong/ │ │ │ │ └── test_gong.py │ │ │ ├── google_drive/ │ │ │ │ ├── conftest.py │ │ │ │ ├── consts_and_utils.py │ │ │ │ ├── drive_id_mapping.json │ │ │ │ ├── test_admin_oauth.py │ │ │ │ ├── test_drive_perm_sync.py │ │ │ │ ├── test_link_visibility_filter.py │ │ │ │ ├── test_map_test_ids.py │ │ │ │ ├── test_sections.py │ │ │ │ ├── test_service_acct.py │ │ │ │ └── test_user_1_oauth.py │ │ │ ├── highspot/ 
│ │ │ │ ├── test_highspot_connector.py │ │ │ │ └── test_highspot_data.json │ │ │ ├── hubspot/ │ │ │ │ └── test_hubspot_connector.py │ │ │ ├── imap/ │ │ │ │ ├── models.py │ │ │ │ └── test_imap_connector.py │ │ │ ├── jira/ │ │ │ │ └── test_jira_basic.py │ │ │ ├── notion/ │ │ │ │ └── test_notion_connector.py │ │ │ ├── outline/ │ │ │ │ └── test_outline_connector.py │ │ │ ├── salesforce/ │ │ │ │ ├── test_salesforce_connector.py │ │ │ │ └── test_salesforce_data.json │ │ │ ├── sharepoint/ │ │ │ │ └── test_sharepoint_connector.py │ │ │ ├── slab/ │ │ │ │ ├── test_slab_connector.py │ │ │ │ └── test_slab_data.json │ │ │ ├── slack/ │ │ │ │ ├── conftest.py │ │ │ │ ├── test_slack_connector.py │ │ │ │ └── test_slack_perm_sync.py │ │ │ ├── teams/ │ │ │ │ ├── models.py │ │ │ │ └── test_teams_connector.py │ │ │ ├── utils.py │ │ │ ├── web/ │ │ │ │ └── test_web_connector.py │ │ │ └── zendesk/ │ │ │ ├── test_zendesk_connector.py │ │ │ └── test_zendesk_data.json │ │ ├── embedding/ │ │ │ └── test_embeddings.py │ │ └── llm/ │ │ └── test_bedrock.py │ ├── external_dependency_unit/ │ │ ├── answer/ │ │ │ ├── conftest.py │ │ │ ├── stream_test_assertions.py │ │ │ ├── stream_test_builder.py │ │ │ ├── stream_test_utils.py │ │ │ ├── test_answer_without_openai.py │ │ │ ├── test_current_datetime_replacement.py │ │ │ ├── test_stream_chat_message.py │ │ │ └── test_stream_chat_message_objects.py │ │ ├── background/ │ │ │ ├── test_periodic_task_claim.py │ │ │ └── test_startup_recovery.py │ │ ├── cache/ │ │ │ ├── conftest.py │ │ │ ├── test_cache_backend_parity.py │ │ │ ├── test_kv_store_cache_layer.py │ │ │ └── test_postgres_cache_backend.py │ │ ├── celery/ │ │ │ ├── test_docfetching_priority.py │ │ │ ├── test_docprocessing_priority.py │ │ │ ├── test_persona_file_sync.py │ │ │ ├── test_pruning_hierarchy_nodes.py │ │ │ ├── test_user_file_delete_queue.py │ │ │ ├── test_user_file_indexing_adapter.py │ │ │ └── test_user_file_processing_queue.py │ │ ├── chat/ │ │ │ └── test_user_reminder_message_type.py │ │ 
├── conftest.py │ │ ├── connectors/ │ │ │ ├── confluence/ │ │ │ │ ├── conftest.py │ │ │ │ └── test_confluence_group_sync.py │ │ │ ├── google_drive/ │ │ │ │ └── test_google_drive_group_sync.py │ │ │ └── jira/ │ │ │ ├── conftest.py │ │ │ ├── test_jira_doc_sync.py │ │ │ └── test_jira_group_sync.py │ │ ├── constants.py │ │ ├── craft/ │ │ │ ├── conftest.py │ │ │ ├── test_build_packet_storage.py │ │ │ ├── test_file_upload.py │ │ │ ├── test_kubernetes_sandbox.py │ │ │ └── test_persistent_document_writer.py │ │ ├── db/ │ │ │ ├── __init__.py │ │ │ ├── conftest.py │ │ │ ├── test_chat_session_eager_load.py │ │ │ ├── test_credential_sensitive_value.py │ │ │ ├── test_rotate_encryption_key.py │ │ │ ├── test_tag_race_condition.py │ │ │ └── test_user_account_type.py │ │ ├── discord_bot/ │ │ │ ├── conftest.py │ │ │ └── test_discord_events.py │ │ ├── document_index/ │ │ │ ├── conftest.py │ │ │ ├── test_document_index.py │ │ │ └── test_document_index_old.py │ │ ├── feature_flags/ │ │ │ ├── __init__.py │ │ │ └── test_feature_flag_provider_factory.py │ │ ├── file_store/ │ │ │ ├── test_file_store_non_mocked.py │ │ │ └── test_postgres_file_store_non_mocked.py │ │ ├── full_setup.py │ │ ├── hierarchy/ │ │ │ ├── __init__.py │ │ │ └── test_hierarchy_access_filter.py │ │ ├── llm/ │ │ │ ├── test_llm_provider.py │ │ │ ├── test_llm_provider_api_base.py │ │ │ ├── test_llm_provider_auto_mode.py │ │ │ ├── test_llm_provider_called.py │ │ │ ├── test_llm_provider_default_model_protection.py │ │ │ └── test_prompt_caching.py │ │ ├── mock_content_provider.py │ │ ├── mock_image_provider.py │ │ ├── mock_llm.py │ │ ├── mock_search_pipeline.py │ │ ├── mock_search_provider.py │ │ ├── opensearch/ │ │ │ ├── test_assistant_knowledge_filter.py │ │ │ └── test_opensearch_client.py │ │ ├── opensearch_migration/ │ │ │ └── test_opensearch_migration_tasks.py │ │ ├── permission_sync/ │ │ │ ├── test_doc_permission_sync_attempt.py │ │ │ └── test_external_group_permission_sync_attempt.py │ │ ├── search_settings/ │ │ │ └── 
test_search_settings.py │ │ ├── slack_bot/ │ │ │ ├── __init__.py │ │ │ ├── test_slack_bot_crud.py │ │ │ └── test_slack_bot_federated_search.py │ │ ├── tools/ │ │ │ ├── data/ │ │ │ │ └── financial-sample.xlsx │ │ │ ├── test_image_generation_tool.py │ │ │ ├── test_mcp_passthrough_oauth.py │ │ │ ├── test_memory_tool_integration.py │ │ │ ├── test_oauth_config_crud.py │ │ │ ├── test_oauth_token_manager.py │ │ │ ├── test_oauth_tool_integration.py │ │ │ ├── test_python_tool.py │ │ │ └── test_python_tool_server_enabled.py │ │ └── tracing/ │ │ ├── __init__.py │ │ └── test_llm_span_recording.py │ ├── integration/ │ │ ├── Dockerfile │ │ ├── README.md │ │ ├── __init__.py │ │ ├── common_utils/ │ │ │ ├── chat.py │ │ │ ├── config.py │ │ │ ├── constants.py │ │ │ ├── document_acl.py │ │ │ ├── managers/ │ │ │ │ ├── api_key.py │ │ │ │ ├── cc_pair.py │ │ │ │ ├── chat.py │ │ │ │ ├── connector.py │ │ │ │ ├── credential.py │ │ │ │ ├── discord_bot.py │ │ │ │ ├── document.py │ │ │ │ ├── document_search.py │ │ │ │ ├── document_set.py │ │ │ │ ├── file.py │ │ │ │ ├── image_generation.py │ │ │ │ ├── index_attempt.py │ │ │ │ ├── llm_provider.py │ │ │ │ ├── pat.py │ │ │ │ ├── persona.py │ │ │ │ ├── project.py │ │ │ │ ├── query_history.py │ │ │ │ ├── scim_client.py │ │ │ │ ├── scim_token.py │ │ │ │ ├── settings.py │ │ │ │ ├── tenant.py │ │ │ │ ├── tool.py │ │ │ │ ├── user.py │ │ │ │ └── user_group.py │ │ │ ├── reset.py │ │ │ ├── test_document_utils.py │ │ │ ├── test_file_utils.py │ │ │ ├── test_files/ │ │ │ │ └── three_images.docx │ │ │ ├── test_models.py │ │ │ ├── timeout.py │ │ │ └── vespa.py │ │ ├── conftest.py │ │ ├── connector_job_tests/ │ │ │ ├── github/ │ │ │ │ ├── conftest.py │ │ │ │ ├── test_github_permission_sync.py │ │ │ │ └── utils.py │ │ │ ├── google/ │ │ │ │ ├── google_drive_api_utils.py │ │ │ │ └── test_google_drive_permission_sync.py │ │ │ ├── jira/ │ │ │ │ ├── conftest.py │ │ │ │ └── test_jira_permission_sync_full.py │ │ │ ├── sharepoint/ │ │ │ │ ├── conftest.py │ │ │ │ └── 
test_sharepoint_permissions.py │ │ │ └── slack/ │ │ │ ├── conftest.py │ │ │ ├── slack_api_utils.py │ │ │ ├── test_permission_sync.py │ │ │ └── test_prune.py │ │ ├── mock_services/ │ │ │ ├── docker-compose.mock-it-services.yml │ │ │ ├── mcp_test_server/ │ │ │ │ ├── run_mcp_server_api_key.py │ │ │ │ ├── run_mcp_server_google_oauth.py │ │ │ │ ├── run_mcp_server_no_auth.py │ │ │ │ ├── run_mcp_server_oauth.py │ │ │ │ └── run_mcp_server_per_user_key.py │ │ │ └── mock_connector_server/ │ │ │ ├── Dockerfile │ │ │ └── main.py │ │ ├── multitenant_tests/ │ │ │ ├── discord_bot/ │ │ │ │ └── test_discord_bot_multitenant.py │ │ │ ├── invitation/ │ │ │ │ └── test_user_invitation.py │ │ │ ├── migrations/ │ │ │ │ └── test_run_multitenant_migrations.py │ │ │ ├── syncing/ │ │ │ │ └── test_search_permissions.py │ │ │ ├── tenants/ │ │ │ │ ├── test_tenant_creation.py │ │ │ │ └── test_tenant_provisioning_rollback.py │ │ │ └── test_get_schemas_needing_migration.py │ │ └── tests/ │ │ ├── anonymous_user/ │ │ │ └── test_anonymous_user.py │ │ ├── api_key/ │ │ │ └── test_api_key.py │ │ ├── auth/ │ │ │ └── test_saml_user_conversion.py │ │ ├── chat/ │ │ │ ├── test_chat_deletion.py │ │ │ └── test_chat_session_access.py │ │ ├── chat_retention/ │ │ │ └── test_chat_retention.py │ │ ├── code_interpreter/ │ │ │ ├── conftest.py │ │ │ └── test_code_interpreter_api.py │ │ ├── connector/ │ │ │ ├── test_connector_creation.py │ │ │ ├── test_connector_deletion.py │ │ │ └── test_last_indexed_time.py │ │ ├── discord_bot/ │ │ │ ├── test_discord_bot_api.py │ │ │ └── test_discord_bot_db.py │ │ ├── document_set/ │ │ │ └── test_syncing.py │ │ ├── image_generation/ │ │ │ ├── test_image_generation_config.py │ │ │ └── test_image_generation_tool_visibility.py │ │ ├── image_indexing/ │ │ │ └── test_indexing_images.py │ │ ├── index_attempt/ │ │ │ └── test_index_attempt_pagination.py │ │ ├── indexing/ │ │ │ ├── conftest.py │ │ │ ├── file_connector/ │ │ │ │ ├── test_file_connector_zip_metadata.py │ │ │ │ └── test_files/ │ │ 
│ │ ├── .onyx_metadata.json │ │ │ │ ├── sample1.txt │ │ │ │ └── sample2.txt │ │ │ ├── test_checkpointing.py │ │ │ ├── test_initial_permission_sync.py │ │ │ ├── test_polling.py │ │ │ └── test_repeated_error_state.py │ │ ├── ingestion/ │ │ │ └── test_ingestion_api.py │ │ ├── kg/ │ │ │ └── test_kg_api.py │ │ ├── llm_auto_update/ │ │ │ └── test_auto_llm_update.py │ │ ├── llm_provider/ │ │ │ ├── test_llm_provider.py │ │ │ ├── test_llm_provider_access_control.py │ │ │ └── test_llm_provider_persona_access.py │ │ ├── llm_workflows/ │ │ │ ├── test_mock_llm_tool_calls.py │ │ │ ├── test_nightly_provider_chat_workflow.py │ │ │ └── test_tool_policy_enforcement.py │ │ ├── mcp/ │ │ │ ├── test_mcp_client_no_auth_flow.py │ │ │ ├── test_mcp_server_auth.py │ │ │ └── test_mcp_server_search.py │ │ ├── migrations/ │ │ │ ├── conftest.py │ │ │ ├── test_alembic_main.py │ │ │ ├── test_alembic_tenants.py │ │ │ ├── test_assistant_consolidation_migration.py │ │ │ ├── test_migrations.py │ │ │ └── test_tool_seeding.py │ │ ├── no_vectordb/ │ │ │ ├── conftest.py │ │ │ ├── test_no_vectordb_chat.py │ │ │ ├── test_no_vectordb_endpoints.py │ │ │ └── test_no_vectordb_file_lifecycle.py │ │ ├── opensearch_migration/ │ │ │ └── test_opensearch_migration_api.py │ │ ├── pat/ │ │ │ └── test_pat_api.py │ │ ├── permissions/ │ │ │ ├── test_auth_permission_propagation.py │ │ │ ├── test_cc_pair_permissions.py │ │ │ ├── test_connector_permissions.py │ │ │ ├── test_credential_permissions.py │ │ │ ├── test_doc_set_permissions.py │ │ │ ├── test_file_connector_permissions.py │ │ │ ├── test_persona_permissions.py │ │ │ ├── test_user_file_permissions.py │ │ │ ├── test_user_role_permissions.py │ │ │ └── test_whole_curator_flow.py │ │ ├── personalization/ │ │ │ └── test_personalization_flow.py │ │ ├── personas/ │ │ │ ├── test_persona_categories.py │ │ │ ├── test_persona_creation.py │ │ │ ├── test_persona_file_context.py │ │ │ ├── test_persona_label_updates.py │ │ │ ├── test_persona_pagination.py │ │ │ └── 
test_unified_assistant.py │ │ ├── projects/ │ │ │ └── test_projects.py │ │ ├── pruning/ │ │ │ ├── test_pruning.py │ │ │ └── website/ │ │ │ ├── about.html │ │ │ ├── contact.html │ │ │ ├── courses.html │ │ │ ├── css/ │ │ │ │ ├── animate.css │ │ │ │ ├── custom-fonts.css │ │ │ │ ├── fancybox/ │ │ │ │ │ └── jquery.fancybox.css │ │ │ │ ├── font-awesome.css │ │ │ │ └── style.css │ │ │ ├── fonts/ │ │ │ │ └── fontawesome.otf │ │ │ ├── index.html │ │ │ ├── js/ │ │ │ │ ├── animate.js │ │ │ │ ├── custom.js │ │ │ │ ├── flexslider/ │ │ │ │ │ ├── jquery.flexslider.js │ │ │ │ │ └── setting.js │ │ │ │ ├── google-code-prettify/ │ │ │ │ │ ├── prettify.css │ │ │ │ │ └── prettify.js │ │ │ │ ├── jquery.easing.1.3.js │ │ │ │ ├── jquery.fancybox-media.js │ │ │ │ ├── jquery.fancybox.pack.js │ │ │ │ ├── jquery.flexslider.js │ │ │ │ ├── jquery.js │ │ │ │ ├── portfolio/ │ │ │ │ │ ├── jquery.quicksand.js │ │ │ │ │ └── setting.js │ │ │ │ ├── quicksand/ │ │ │ │ │ ├── jquery.quicksand.js │ │ │ │ │ └── setting.js │ │ │ │ └── validate.js │ │ │ ├── portfolio.html │ │ │ ├── pricing.html │ │ │ └── readme.txt │ │ ├── query_history/ │ │ │ ├── test_query_history.py │ │ │ ├── test_query_history_pagination.py │ │ │ ├── test_usage_reports.py │ │ │ └── utils.py │ │ ├── reporting/ │ │ │ └── test_usage_export_api.py │ │ ├── scim/ │ │ │ ├── test_scim_groups.py │ │ │ ├── test_scim_tokens.py │ │ │ └── test_scim_users.py │ │ ├── search_settings/ │ │ │ └── test_search_settings.py │ │ ├── streaming_endpoints/ │ │ │ ├── test_chat_file_attachment.py │ │ │ └── test_chat_stream.py │ │ ├── tags/ │ │ │ └── test_tags.py │ │ ├── tools/ │ │ │ ├── test_force_tool_use.py │ │ │ └── test_image_generation_streaming.py │ │ ├── usergroup/ │ │ │ ├── test_add_users_to_group.py │ │ │ ├── test_group_membership_updates_user_permissions.py │ │ │ ├── test_new_group_gets_basic_permission.py │ │ │ ├── test_user_group_deletion.py │ │ │ └── test_usergroup_syncing.py │ │ ├── users/ │ │ │ ├── test_default_group_assignment.py │ │ │ ├── 
test_password_signup_upgrade.py │ │ │ ├── test_reactivation_groups.py │ │ │ ├── test_seat_limit.py │ │ │ ├── test_slack_user_deactivation.py │ │ │ └── test_user_pagination.py │ │ └── web_search/ │ │ └── test_web_search_api.py │ ├── load_env_vars.py │ ├── regression/ │ │ ├── answer_quality/ │ │ │ ├── README.md │ │ │ ├── __init__.py │ │ │ ├── api_utils.py │ │ │ ├── cli_utils.py │ │ │ ├── file_uploader.py │ │ │ ├── launch_eval_env.py │ │ │ └── search_test_config.yaml.template │ │ └── search_quality/ │ │ ├── README.md │ │ ├── models.py │ │ ├── run_search_eval.py │ │ ├── test_queries.json.template │ │ └── utils.py │ └── unit/ │ ├── __init__.py │ ├── build/ │ │ └── test_rewrite_asset_paths.py │ ├── ee/ │ │ ├── conftest.py │ │ └── onyx/ │ │ ├── db/ │ │ │ ├── test_license.py │ │ │ └── test_user_group_rename.py │ │ ├── external_permissions/ │ │ │ ├── salesforce/ │ │ │ │ └── test_postprocessing.py │ │ │ └── sharepoint/ │ │ │ └── test_permission_utils.py │ │ ├── hooks/ │ │ │ ├── __init__.py │ │ │ └── test_executor.py │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ ├── billing/ │ │ │ │ ├── __init__.py │ │ │ │ ├── conftest.py │ │ │ │ ├── test_billing_api.py │ │ │ │ ├── test_billing_service.py │ │ │ │ └── test_proxy.py │ │ │ ├── features/ │ │ │ │ ├── __init__.py │ │ │ │ └── hooks/ │ │ │ │ ├── __init__.py │ │ │ │ └── test_api.py │ │ │ ├── license/ │ │ │ │ └── test_api.py │ │ │ ├── middleware/ │ │ │ │ └── test_license_enforcement.py │ │ │ ├── settings/ │ │ │ │ └── test_license_enforcement_settings.py │ │ │ └── tenants/ │ │ │ ├── test_billing_api.py │ │ │ ├── test_product_gating.py │ │ │ ├── test_proxy.py │ │ │ └── test_schema_management.py │ │ └── utils/ │ │ ├── test_encryption.py │ │ └── test_license_utils.py │ ├── federated_connector/ │ │ └── slack/ │ │ └── test_slack_federated_connnector.py │ ├── file_store/ │ │ ├── test_file_store.py │ │ └── test_postgres_file_store.py │ ├── model_server/ │ │ └── test_embedding.py │ ├── onyx/ │ │ ├── __init__.py │ │ ├── access/ │ │ │ └── 
test_user_file_access.py │ │ ├── auth/ │ │ │ ├── conftest.py │ │ │ ├── test_disposable_email_validator.py │ │ │ ├── test_email.py │ │ │ ├── test_is_same_origin.py │ │ │ ├── test_jwt_provisioning.py │ │ │ ├── test_oauth_refresher.py │ │ │ ├── test_oidc_pkce.py │ │ │ ├── test_permissions.py │ │ │ ├── test_single_tenant_jwt_strategy.py │ │ │ ├── test_user_create_schema.py │ │ │ ├── test_user_default_pins.py │ │ │ ├── test_user_registration.py │ │ │ ├── test_verify_auth_setting.py │ │ │ ├── test_verify_email_domain.py │ │ │ └── test_verify_email_invite.py │ │ ├── background/ │ │ │ └── celery/ │ │ │ ├── tasks/ │ │ │ │ ├── tenant_provisioning/ │ │ │ │ │ ├── __init__.py │ │ │ │ │ └── test_check_available_tenants.py │ │ │ │ ├── test_hierarchyfetching_queue.py │ │ │ │ ├── test_user_file_impl_redis_locking.py │ │ │ │ ├── test_user_file_processing_no_vectordb.py │ │ │ │ └── test_user_file_project_sync_queue.py │ │ │ └── test_celery_redis.py │ │ ├── chat/ │ │ │ ├── test_argument_delta_streaming.py │ │ │ ├── test_chat_utils.py │ │ │ ├── test_citation_processor.py │ │ │ ├── test_citation_utils.py │ │ │ ├── test_compression.py │ │ │ ├── test_context_files.py │ │ │ ├── test_emitter.py │ │ │ ├── test_llm_loop.py │ │ │ ├── test_llm_step.py │ │ │ ├── test_multi_model_streaming.py │ │ │ ├── test_multi_model_types.py │ │ │ ├── test_process_message.py │ │ │ ├── test_process_message_mock_llm.py │ │ │ ├── test_save_chat.py │ │ │ └── test_stop_signal_checker.py │ │ ├── connectors/ │ │ │ ├── airtable/ │ │ │ │ └── test_airtable_index_all.py │ │ │ ├── asana/ │ │ │ │ └── test_asana_connector.py │ │ │ ├── canvas/ │ │ │ │ └── test_canvas_connector.py │ │ │ ├── confluence/ │ │ │ │ ├── test_confluence_checkpointing.py │ │ │ │ ├── test_onyx_confluence.py │ │ │ │ └── test_rate_limit_handler.py │ │ │ ├── cross_connector_utils/ │ │ │ │ ├── test_html_utils.py │ │ │ │ ├── test_rate_limit.py │ │ │ │ └── test_table.html │ │ │ ├── discord/ │ │ │ │ └── test_discord_validation.py │ │ │ ├── github/ │ │ │ │ 
└── test_github_checkpointing.py │ │ │ ├── gmail/ │ │ │ │ ├── test_connector.py │ │ │ │ └── thread.json │ │ │ ├── google_utils/ │ │ │ │ └── test_rate_limit_detection.py │ │ │ ├── jira/ │ │ │ │ ├── conftest.py │ │ │ │ ├── test_jira_bulk_fetch.py │ │ │ │ ├── test_jira_checkpointing.py │ │ │ │ ├── test_jira_error_handling.py │ │ │ │ ├── test_jira_large_ticket_handling.py │ │ │ │ └── test_jira_permission_sync.py │ │ │ ├── mediawiki/ │ │ │ │ ├── __init__.py │ │ │ │ ├── test_mediawiki_family.py │ │ │ │ └── test_wiki.py │ │ │ ├── notion/ │ │ │ │ └── test_notion_datasource.py │ │ │ ├── salesforce/ │ │ │ │ ├── test_salesforce_custom_config.py │ │ │ │ ├── test_salesforce_sqlite.py │ │ │ │ └── test_yield_doc_batches.py │ │ │ ├── sharepoint/ │ │ │ │ ├── test_delta_checkpointing.py │ │ │ │ ├── test_denylist.py │ │ │ │ ├── test_drive_matching.py │ │ │ │ ├── test_fetch_site_pages.py │ │ │ │ ├── test_hierarchy_helpers.py │ │ │ │ ├── test_rest_client_context_caching.py │ │ │ │ └── test_url_parsing.py │ │ │ ├── slab/ │ │ │ │ └── test_slab_validation.py │ │ │ ├── slack/ │ │ │ │ └── test_message_filtering.py │ │ │ ├── teams/ │ │ │ │ └── test_collect_teams.py │ │ │ ├── test_connector_factory.py │ │ │ ├── test_document_metadata_coercion.py │ │ │ ├── test_microsoft_graph_env.py │ │ │ ├── utils.py │ │ │ └── zendesk/ │ │ │ ├── test_zendesk_checkpointing.py │ │ │ └── test_zendesk_rate_limit.py │ │ ├── context/ │ │ │ └── search/ │ │ │ └── federated/ │ │ │ ├── test_slack_query_construction.py │ │ │ └── test_slack_thread_context.py │ │ ├── db/ │ │ │ ├── __init__.py │ │ │ ├── conftest.py │ │ │ ├── test_assign_default_groups.py │ │ │ ├── test_chat_sessions.py │ │ │ ├── test_dal.py │ │ │ ├── test_delete_user.py │ │ │ ├── test_llm_sync.py │ │ │ ├── test_persona_display_priority.py │ │ │ ├── test_projects_upload_task_expiry.py │ │ │ ├── test_scim_dal.py │ │ │ ├── test_tools.py │ │ │ ├── test_usage.py │ │ │ └── test_voice.py │ │ ├── document_index/ │ │ │ ├── opensearch/ │ │ │ │ ├── 
test_get_doc_chunk_id.py │ │ │ │ └── test_opensearch_batch_flush.py │ │ │ ├── test_disabled_document_index.py │ │ │ └── vespa/ │ │ │ ├── shared_utils/ │ │ │ │ └── test_utils.py │ │ │ └── test_vespa_batch_flush.py │ │ ├── error_handling/ │ │ │ ├── __init__.py │ │ │ └── test_exceptions.py │ │ ├── federated_connectors/ │ │ │ ├── test_federated_connector_factory.py │ │ │ └── test_oauth_utils.py │ │ ├── file_processing/ │ │ │ ├── __init__.py │ │ │ ├── test_image_summarization_errors.py │ │ │ ├── test_image_summarization_litellm_errors.py │ │ │ ├── test_pdf.py │ │ │ └── test_xlsx_to_text.py │ │ ├── hooks/ │ │ │ ├── __init__.py │ │ │ ├── test_api_dependencies.py │ │ │ ├── test_base_spec.py │ │ │ ├── test_models.py │ │ │ ├── test_query_processing_spec.py │ │ │ └── test_registry.py │ │ ├── image_gen/ │ │ │ └── test_provider_building.py │ │ ├── indexing/ │ │ │ ├── conftest.py │ │ │ ├── test_censoring.py │ │ │ ├── test_chunker.py │ │ │ ├── test_embed_chunks_in_batches.py │ │ │ ├── test_embedder.py │ │ │ ├── test_indexing_pipeline.py │ │ │ ├── test_personas_in_chunks.py │ │ │ └── test_vespa.py │ │ ├── lazy_handling/ │ │ │ └── __init__.py │ │ ├── llm/ │ │ │ ├── conftest.py │ │ │ ├── test_bedrock_token_limit.py │ │ │ ├── test_factory.py │ │ │ ├── test_formatting_reenabled.py │ │ │ ├── test_litellm_monkey_patches.py │ │ │ ├── test_llm_provider_options.py │ │ │ ├── test_model_is_reasoning.py │ │ │ ├── test_model_map.py │ │ │ ├── test_model_name_parser.py │ │ │ ├── test_model_response.py │ │ │ ├── test_multi_llm.py │ │ │ ├── test_reasoning_effort_mapping.py │ │ │ ├── test_request_context.py │ │ │ ├── test_true_openai_model.py │ │ │ └── test_vision_model_selection_logging.py │ │ ├── natural_language_processing/ │ │ │ └── test_search_nlp_models.py │ │ ├── onyxbot/ │ │ │ ├── discord/ │ │ │ │ ├── conftest.py │ │ │ │ ├── test_api_client.py │ │ │ │ ├── test_cache_manager.py │ │ │ │ ├── test_context_builders.py │ │ │ │ ├── test_discord_utils.py │ │ │ │ ├── test_message_utils.py │ │ │ │ 
└── test_should_respond.py │ │ │ ├── test_handle_regular_answer.py │ │ │ ├── test_slack_blocks.py │ │ │ ├── test_slack_channel_config.py │ │ │ ├── test_slack_formatting.py │ │ │ └── test_slack_gating.py │ │ ├── prompts/ │ │ │ └── test_prompt_utils.py │ │ ├── redis_ca.pem │ │ ├── server/ │ │ │ ├── __init__.py │ │ │ ├── features/ │ │ │ │ ├── __init__.py │ │ │ │ ├── hierarchy/ │ │ │ │ │ └── test_user_access_info.py │ │ │ │ └── hooks/ │ │ │ │ └── __init__.py │ │ │ ├── manage/ │ │ │ │ ├── embedding/ │ │ │ │ │ └── test_embedding_api.py │ │ │ │ ├── llm/ │ │ │ │ │ ├── test_fetch_models_api.py │ │ │ │ │ └── test_llm_provider_utils.py │ │ │ │ ├── test_bulk_invite_limit.py │ │ │ │ └── voice/ │ │ │ │ └── test_voice_api_validation.py │ │ │ ├── scim/ │ │ │ │ ├── __init__.py │ │ │ │ ├── conftest.py │ │ │ │ ├── test_admin.py │ │ │ │ ├── test_auth.py │ │ │ │ ├── test_entra.py │ │ │ │ ├── test_filtering.py │ │ │ │ ├── test_group_endpoints.py │ │ │ │ ├── test_patch.py │ │ │ │ ├── test_providers.py │ │ │ │ └── test_user_endpoints.py │ │ │ ├── test_full_user_snapshot.py │ │ │ ├── test_pool_metrics.py │ │ │ ├── test_projects_file_utils.py │ │ │ ├── test_prometheus_instrumentation.py │ │ │ ├── test_settings_store.py │ │ │ └── test_upload_files.py │ │ ├── test_redis.py │ │ ├── test_startup_validation.py │ │ ├── tools/ │ │ │ ├── __init__.py │ │ │ ├── custom/ │ │ │ │ └── test_custom_tools.py │ │ │ ├── test_construct_tools_no_vectordb.py │ │ │ ├── test_file_reader_tool.py │ │ │ ├── test_no_vectordb.py │ │ │ ├── test_python_tool_availability.py │ │ │ ├── test_search_utils.py │ │ │ ├── test_tool_runner.py │ │ │ ├── test_tool_runner_chat_files.py │ │ │ ├── test_tool_utils.py │ │ │ └── tool_implementations/ │ │ │ ├── open_url/ │ │ │ │ ├── data/ │ │ │ │ │ └── test_snippet_finding_data.json │ │ │ │ ├── test_onyx_web_crawler.py │ │ │ │ ├── test_snippet_matcher.py │ │ │ │ └── test_url_normalization.py │ │ │ ├── python/ │ │ │ │ ├── __init__.py │ │ │ │ ├── test_code_interpreter_client.py │ │ │ │ └── 
test_python_tool_upload_cache.py │ │ │ └── websearch/ │ │ │ ├── data/ │ │ │ │ └── tartan.txt │ │ │ ├── test_brave_client.py │ │ │ ├── test_web_search_providers.py │ │ │ ├── test_web_search_tool_run.py │ │ │ └── test_websearch_utils.py │ │ ├── tracing/ │ │ │ ├── __init__.py │ │ │ └── test_tracing_setup.py │ │ ├── utils/ │ │ │ ├── test_gpu_utils.py │ │ │ ├── test_json_river.py │ │ │ ├── test_postgres_sanitization.py │ │ │ ├── test_sensitive.py │ │ │ ├── test_sensitive_typing.py │ │ │ ├── test_telemetry.py │ │ │ ├── test_threadpool_concurrency.py │ │ │ ├── test_threadpool_contextvars.py │ │ │ ├── test_url_ssrf.py │ │ │ ├── test_vespa_query.py │ │ │ └── test_vespa_tasks.py │ │ └── voice/ │ │ └── providers/ │ │ ├── test_azure_provider.py │ │ ├── test_azure_ssml.py │ │ ├── test_elevenlabs_provider.py │ │ └── test_openai_provider.py │ ├── scripts/ │ │ └── __init__.py │ ├── server/ │ │ └── metrics/ │ │ ├── test_celery_task_metrics.py │ │ ├── test_indexing_pipeline_collectors.py │ │ ├── test_indexing_pipeline_setup.py │ │ ├── test_indexing_task_metrics.py │ │ ├── test_metrics_server.py │ │ ├── test_opensearch_search_metrics.py │ │ └── test_worker_health.py │ └── tools/ │ ├── __init__.py │ └── test_memory_tool_packets.py ├── contributor_ip_assignment/ │ └── EE_Contributor_IP_Assignment_Agreement.md ├── ct.yaml ├── cubic.yaml ├── deployment/ │ ├── .gitignore │ ├── README.md │ ├── aws_ecs_fargate/ │ │ └── cloudformation/ │ │ ├── README.md │ │ ├── deploy.sh │ │ ├── onyx_acm_template.yaml │ │ ├── onyx_cluster_template.yaml │ │ ├── onyx_config.jsonl │ │ ├── onyx_efs_template.yaml │ │ ├── services/ │ │ │ ├── onyx_backend_api_server_service_template.yaml │ │ │ ├── onyx_backend_background_server_service_template.yaml │ │ │ ├── onyx_model_server_indexing_service_template.yaml │ │ │ ├── onyx_model_server_inference_service_template.yaml │ │ │ ├── onyx_nginx_service_template.yaml │ │ │ ├── onyx_postgres_service_template.yaml │ │ │ ├── onyx_redis_service_template.yaml │ │ │ ├── 
onyx_vespaengine_service_template.yaml │ │ │ └── onyx_web_server_service_template.yaml │ │ └── uninstall.sh │ ├── data/ │ │ └── nginx/ │ │ ├── app.conf.template │ │ ├── app.conf.template.no-letsencrypt │ │ ├── app.conf.template.prod │ │ ├── mcp.conf.inc.template │ │ ├── mcp_upstream.conf.inc.template │ │ └── run-nginx.sh │ ├── docker_compose/ │ │ ├── README.md │ │ ├── docker-compose.dev.yml │ │ ├── docker-compose.mcp-api-key-test.yml │ │ ├── docker-compose.mcp-oauth-test.yml │ │ ├── docker-compose.multitenant-dev.yml │ │ ├── docker-compose.onyx-lite.yml │ │ ├── docker-compose.prod-cloud.yml │ │ ├── docker-compose.prod-no-letsencrypt.yml │ │ ├── docker-compose.prod.yml │ │ ├── docker-compose.resources.yml │ │ ├── docker-compose.search-testing.yml │ │ ├── docker-compose.yml │ │ ├── env.nginx.template │ │ ├── env.prod.template │ │ ├── env.template │ │ ├── init-letsencrypt.sh │ │ ├── install.ps1 │ │ └── install.sh │ ├── helm/ │ │ ├── README.md │ │ └── charts/ │ │ └── onyx/ │ │ ├── .gitignore │ │ ├── .helmignore │ │ ├── Chart.yaml │ │ ├── ci/ │ │ │ └── ct-values.yaml │ │ ├── dashboards/ │ │ │ └── indexing-pipeline.json │ │ ├── templates/ │ │ │ ├── _helpers.tpl │ │ │ ├── api-deployment.yaml │ │ │ ├── api-hpa.yaml │ │ │ ├── api-scaledobject.yaml │ │ │ ├── api-service.yaml │ │ │ ├── auth-secrets.yaml │ │ │ ├── celery-beat.yaml │ │ │ ├── celery-worker-docfetching-hpa.yaml │ │ │ ├── celery-worker-docfetching-metrics-service.yaml │ │ │ ├── celery-worker-docfetching-scaledobject.yaml │ │ │ ├── celery-worker-docfetching.yaml │ │ │ ├── celery-worker-docprocessing-hpa.yaml │ │ │ ├── celery-worker-docprocessing-metrics-service.yaml │ │ │ ├── celery-worker-docprocessing-scaledobject.yaml │ │ │ ├── celery-worker-docprocessing.yaml │ │ │ ├── celery-worker-heavy-hpa.yaml │ │ │ ├── celery-worker-heavy-scaledobject.yaml │ │ │ ├── celery-worker-heavy.yaml │ │ │ ├── celery-worker-light-hpa.yaml │ │ │ ├── celery-worker-light-scaledobject.yaml │ │ │ ├── celery-worker-light.yaml │ │ │ ├── 
celery-worker-monitoring-hpa.yaml │ │ │ ├── celery-worker-monitoring-metrics-service.yaml │ │ │ ├── celery-worker-monitoring-scaledobject.yaml │ │ │ ├── celery-worker-monitoring.yaml │ │ │ ├── celery-worker-primary-hpa.yaml │ │ │ ├── celery-worker-primary-scaledobject.yaml │ │ │ ├── celery-worker-primary.yaml │ │ │ ├── celery-worker-servicemonitors.yaml │ │ │ ├── celery-worker-user-file-processing-hpa.yaml │ │ │ ├── celery-worker-user-file-processing-scaledobject.yaml │ │ │ ├── celery-worker-user-file-processing.yaml │ │ │ ├── configmap.yaml │ │ │ ├── discordbot.yaml │ │ │ ├── grafana-dashboards.yaml │ │ │ ├── indexing-model-deployment.yaml │ │ │ ├── indexing-model-service.yaml │ │ │ ├── inference-model-deployment.yaml │ │ │ ├── inference-model-service.yaml │ │ │ ├── ingress-api.yaml │ │ │ ├── ingress-mcp.yaml │ │ │ ├── ingress-webserver.yaml │ │ │ ├── lets-encrypt.yaml │ │ │ ├── mcp-server-deployment.yaml │ │ │ ├── mcp-server-service.yaml │ │ │ ├── nginx-conf.yaml │ │ │ ├── serviceaccount.yaml │ │ │ ├── slackbot.yaml │ │ │ ├── tests/ │ │ │ │ └── test-connection.yaml │ │ │ ├── tooling-pginto-configmap.yaml │ │ │ ├── webserver-deployment.yaml │ │ │ ├── webserver-hpa.yaml │ │ │ ├── webserver-scaledobject.yaml │ │ │ └── webserver-service.yaml │ │ ├── templates_disabled/ │ │ │ ├── background-deployment.yaml │ │ │ ├── background-hpa.yaml │ │ │ └── onyx-secret.yaml │ │ ├── values-lite.yaml │ │ └── values.yaml │ └── terraform/ │ └── modules/ │ └── aws/ │ ├── README.md │ ├── eks/ │ │ ├── main.tf │ │ ├── outputs.tf │ │ └── variables.tf │ ├── onyx/ │ │ ├── main.tf │ │ ├── outputs.tf │ │ ├── variables.tf │ │ └── versions.tf │ ├── opensearch/ │ │ ├── main.tf │ │ ├── outputs.tf │ │ └── variables.tf │ ├── postgres/ │ │ ├── main.tf │ │ ├── outputs.tf │ │ └── variables.tf │ ├── redis/ │ │ ├── main.tf │ │ ├── outputs.tf │ │ └── variables.tf │ ├── s3/ │ │ ├── main.tf │ │ └── variables.tf │ ├── vpc/ │ │ ├── main.tf │ │ ├── outputs.tf │ │ └── variables.tf │ └── waf/ │ ├── main.tf │ 
├── outputs.tf │ └── variables.tf ├── desktop/ │ ├── .gitignore │ ├── README.md │ ├── package.json │ ├── scripts/ │ │ └── generate-icons.sh │ ├── src/ │ │ ├── index.html │ │ └── titlebar.js │ └── src-tauri/ │ ├── Cargo.toml │ ├── build.rs │ ├── gen/ │ │ └── schemas/ │ │ ├── acl-manifests.json │ │ ├── capabilities.json │ │ ├── desktop-schema.json │ │ └── macOS-schema.json │ ├── icons/ │ │ ├── android/ │ │ │ ├── mipmap-anydpi-v26/ │ │ │ │ └── ic_launcher.xml │ │ │ └── values/ │ │ │ └── ic_launcher_background.xml │ │ └── icon.icns │ ├── src/ │ │ └── main.rs │ └── tauri.conf.json ├── docker-bake.hcl ├── docs/ │ └── METRICS.md ├── examples/ │ ├── assistants-api/ │ │ └── topics_analyzer.py │ └── widget/ │ ├── .eslintrc.json │ ├── .gitignore │ ├── README.md │ ├── next.config.mjs │ ├── package.json │ ├── postcss.config.mjs │ ├── src/ │ │ └── app/ │ │ ├── globals.css │ │ ├── layout.tsx │ │ ├── page.tsx │ │ └── widget/ │ │ └── Widget.tsx │ ├── tailwind.config.ts │ └── tsconfig.json ├── extensions/ │ └── chrome/ │ ├── LICENSE │ ├── README.md │ ├── manifest.json │ ├── service_worker.js │ └── src/ │ ├── pages/ │ │ ├── onyx_home.html │ │ ├── onyx_home.js │ │ ├── options.html │ │ ├── options.js │ │ ├── panel.html │ │ ├── panel.js │ │ ├── popup.html │ │ ├── popup.js │ │ ├── welcome.html │ │ └── welcome.js │ ├── styles/ │ │ ├── selection-icon.css │ │ └── shared.css │ └── utils/ │ ├── constants.js │ ├── content.js │ ├── error-modal.js │ ├── selection-icon.js │ └── storage.js ├── profiling/ │ └── grafana/ │ └── dashboards/ │ └── onyx/ │ └── opensearch-search-latency.json ├── pyproject.toml ├── web/ │ ├── .dockerignore │ ├── .eslintrc.json │ ├── .gitignore │ ├── .prettierignore │ ├── .prettierrc.json │ ├── .storybook/ │ │ ├── Introduction.mdx │ │ ├── README.md │ │ ├── main.ts │ │ ├── mocks/ │ │ │ ├── next-image.tsx │ │ │ ├── next-link.tsx │ │ │ └── next-navigation.tsx │ │ ├── preview-head.html │ │ └── preview.ts │ ├── @types/ │ │ ├── favicon-fetch.d.ts │ │ └── images.d.ts │ ├── 
AGENTS.md │ ├── Dockerfile │ ├── README.md │ ├── components.json │ ├── jest.config.js │ ├── lib/ │ │ └── opal/ │ │ ├── README.md │ │ ├── package.json │ │ ├── scripts/ │ │ │ ├── README.md │ │ │ ├── convert-svg.sh │ │ │ └── icon-template.js │ │ ├── src/ │ │ │ ├── components/ │ │ │ │ ├── README.md │ │ │ │ ├── buttons/ │ │ │ │ │ ├── Button/ │ │ │ │ │ │ └── Button.stories.tsx │ │ │ │ │ ├── button/ │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ └── components.tsx │ │ │ │ │ ├── chevron.css │ │ │ │ │ ├── chevron.tsx │ │ │ │ │ ├── filter-button/ │ │ │ │ │ │ ├── FilterButton.stories.tsx │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ └── components.tsx │ │ │ │ │ ├── icon-wrapper.tsx │ │ │ │ │ ├── line-item-button/ │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ └── components.tsx │ │ │ │ │ ├── open-button/ │ │ │ │ │ │ ├── OpenButton.stories.tsx │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ └── components.tsx │ │ │ │ │ ├── select-button/ │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ ├── components.tsx │ │ │ │ │ │ └── styles.css │ │ │ │ │ └── sidebar-tab/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── SidebarTab.stories.tsx │ │ │ │ │ └── components.tsx │ │ │ │ ├── cards/ │ │ │ │ │ ├── card/ │ │ │ │ │ │ ├── Card.stories.tsx │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ ├── components.tsx │ │ │ │ │ │ └── styles.css │ │ │ │ │ ├── empty-message-card/ │ │ │ │ │ │ ├── EmptyMessageCard.stories.tsx │ │ │ │ │ │ ├── README.md │ │ │ │ │ │ └── components.tsx │ │ │ │ │ └── select-card/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── SelectCard.stories.tsx │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── index.ts │ │ │ │ ├── pagination/ │ │ │ │ │ ├── Pagination.stories.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ └── components.tsx │ │ │ │ ├── table/ │ │ │ │ │ ├── ActionsContainer.tsx │ │ │ │ │ ├── ColumnSortabilityPopover.tsx │ │ │ │ │ ├── ColumnVisibilityPopover.tsx │ │ │ │ │ ├── DragOverlayRow.tsx │ │ │ │ │ ├── Footer.tsx │ │ │ │ │ ├── QualifierContainer.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ ├── Table.stories.tsx │ │ │ │ │ ├── TableBody.tsx │ │ │ │ 
│ ├── TableCell.tsx │ │ │ │ │ ├── TableElement.tsx │ │ │ │ │ ├── TableHead.tsx │ │ │ │ │ ├── TableHeader.tsx │ │ │ │ │ ├── TableQualifier.tsx │ │ │ │ │ ├── TableRow.tsx │ │ │ │ │ ├── TableSizeContext.tsx │ │ │ │ │ ├── columns.ts │ │ │ │ │ ├── components.tsx │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ ├── useColumnWidths.ts │ │ │ │ │ │ ├── useDataTable.ts │ │ │ │ │ │ └── useDraggableRows.ts │ │ │ │ │ ├── styles.css │ │ │ │ │ └── types.ts │ │ │ │ ├── tag/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── Tag.stories.tsx │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── text/ │ │ │ │ │ ├── InlineMarkdown.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ ├── Text.stories.tsx │ │ │ │ │ └── components.tsx │ │ │ │ └── tooltip.css │ │ │ ├── core/ │ │ │ │ ├── README.md │ │ │ │ ├── animations/ │ │ │ │ │ ├── Hoverable.stories.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── disabled/ │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── index.ts │ │ │ │ └── interactive/ │ │ │ │ ├── Interactive.stories.tsx │ │ │ │ ├── README.md │ │ │ │ ├── container/ │ │ │ │ │ ├── README.md │ │ │ │ │ └── components.tsx │ │ │ │ ├── foldable/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── shared.css │ │ │ │ ├── simple/ │ │ │ │ │ └── components.tsx │ │ │ │ ├── stateful/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── stateless/ │ │ │ │ │ ├── README.md │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ └── utils.ts │ │ │ ├── icons/ │ │ │ │ ├── DiscordMono.tsx │ │ │ │ ├── actions.tsx │ │ │ │ ├── activity-small.tsx │ │ │ │ ├── activity.tsx │ │ │ │ ├── add-lines.tsx │ │ │ │ ├── alert-circle.tsx │ │ │ │ ├── alert-triangle.tsx │ │ │ │ ├── arrow-down-dot.tsx │ │ │ │ ├── arrow-exchange.tsx │ │ │ │ ├── arrow-left-dot.tsx │ │ │ │ ├── arrow-left.tsx │ │ │ │ ├── arrow-right-circle.tsx │ │ │ │ ├── arrow-right-dot.tsx │ │ │ │ ├── arrow-right.tsx │ │ │ │ ├── arrow-up-circle.tsx │ 
│ │ │ ├── arrow-up-dot.tsx │ │ │ │ ├── arrow-up-down.tsx │ │ │ │ ├── arrow-up-right.tsx │ │ │ │ ├── arrow-up.tsx │ │ │ │ ├── arrow-wall-right.tsx │ │ │ │ ├── audio-eq-small.tsx │ │ │ │ ├── audio.tsx │ │ │ │ ├── aws.tsx │ │ │ │ ├── azure.tsx │ │ │ │ ├── bar-chart-small.tsx │ │ │ │ ├── bar-chart.tsx │ │ │ │ ├── bell.tsx │ │ │ │ ├── bifrost.tsx │ │ │ │ ├── blocks.tsx │ │ │ │ ├── book-open.tsx │ │ │ │ ├── bookmark.tsx │ │ │ │ ├── books-line-small.tsx │ │ │ │ ├── books-stack-small.tsx │ │ │ │ ├── bracket-curly.tsx │ │ │ │ ├── branch.tsx │ │ │ │ ├── bubble-text.tsx │ │ │ │ ├── calendar.tsx │ │ │ │ ├── check-circle.tsx │ │ │ │ ├── check-small.tsx │ │ │ │ ├── check-square.tsx │ │ │ │ ├── check.tsx │ │ │ │ ├── chevron-down-small.tsx │ │ │ │ ├── chevron-down.tsx │ │ │ │ ├── chevron-left.tsx │ │ │ │ ├── chevron-right.tsx │ │ │ │ ├── chevron-up-small.tsx │ │ │ │ ├── chevron-up.tsx │ │ │ │ ├── circle.tsx │ │ │ │ ├── claude.tsx │ │ │ │ ├── clipboard.tsx │ │ │ │ ├── clock-hands-small.tsx │ │ │ │ ├── clock.tsx │ │ │ │ ├── cloud.tsx │ │ │ │ ├── code.tsx │ │ │ │ ├── column.tsx │ │ │ │ ├── copy.tsx │ │ │ │ ├── corner-right-up-dot.tsx │ │ │ │ ├── cpu.tsx │ │ │ │ ├── credit-card.tsx │ │ │ │ ├── curate.tsx │ │ │ │ ├── dashboard.tsx │ │ │ │ ├── dev-kit.tsx │ │ │ │ ├── download-cloud.tsx │ │ │ │ ├── download.tsx │ │ │ │ ├── edit-big.tsx │ │ │ │ ├── edit.tsx │ │ │ │ ├── empty.tsx │ │ │ │ ├── expand.tsx │ │ │ │ ├── external-link.tsx │ │ │ │ ├── eye-closed.tsx │ │ │ │ ├── eye-off.tsx │ │ │ │ ├── eye.tsx │ │ │ │ ├── file-braces.tsx │ │ │ │ ├── file-broadcast.tsx │ │ │ │ ├── file-chart-pie.tsx │ │ │ │ ├── file-small.tsx │ │ │ │ ├── file-text.tsx │ │ │ │ ├── files.tsx │ │ │ │ ├── filter-plus.tsx │ │ │ │ ├── filter.tsx │ │ │ │ ├── fold.tsx │ │ │ │ ├── folder-in.tsx │ │ │ │ ├── folder-open.tsx │ │ │ │ ├── folder-partial-open.tsx │ │ │ │ ├── folder-plus.tsx │ │ │ │ ├── folder.tsx │ │ │ │ ├── gemini.tsx │ │ │ │ ├── globe.tsx │ │ │ │ ├── handle.tsx │ │ │ │ ├── hard-drive.tsx │ │ │ │ ├── 
hash-small.tsx │ │ │ │ ├── hash.tsx │ │ │ │ ├── headset-mic.tsx │ │ │ │ ├── history.tsx │ │ │ │ ├── hourglass.tsx │ │ │ │ ├── image-small.tsx │ │ │ │ ├── image.tsx │ │ │ │ ├── import-icon.tsx │ │ │ │ ├── index.ts │ │ │ │ ├── info-small.tsx │ │ │ │ ├── info.tsx │ │ │ │ ├── key.tsx │ │ │ │ ├── keystroke.tsx │ │ │ │ ├── lightbulb-simple.tsx │ │ │ │ ├── line-chart-up.tsx │ │ │ │ ├── link.tsx │ │ │ │ ├── linked-dots.tsx │ │ │ │ ├── litellm.tsx │ │ │ │ ├── lm-studio.tsx │ │ │ │ ├── loader.tsx │ │ │ │ ├── lock.tsx │ │ │ │ ├── log-out.tsx │ │ │ │ ├── maximize-2.tsx │ │ │ │ ├── mcp.tsx │ │ │ │ ├── menu.tsx │ │ │ │ ├── microphone-off.tsx │ │ │ │ ├── microphone.tsx │ │ │ │ ├── minus-circle.tsx │ │ │ │ ├── minus.tsx │ │ │ │ ├── moon.tsx │ │ │ │ ├── more-horizontal.tsx │ │ │ │ ├── music-small.tsx │ │ │ │ ├── network-graph.tsx │ │ │ │ ├── notification-bubble.tsx │ │ │ │ ├── ollama.tsx │ │ │ │ ├── onyx-logo-typed.tsx │ │ │ │ ├── onyx-logo.tsx │ │ │ │ ├── onyx-octagon.tsx │ │ │ │ ├── onyx-typed.tsx │ │ │ │ ├── openai.tsx │ │ │ │ ├── openrouter.tsx │ │ │ │ ├── organization.tsx │ │ │ │ ├── paint-brush.tsx │ │ │ │ ├── paperclip.tsx │ │ │ │ ├── pause-circle.tsx │ │ │ │ ├── pen-small.tsx │ │ │ │ ├── pencil-ruler.tsx │ │ │ │ ├── pie-chart.tsx │ │ │ │ ├── pin.tsx │ │ │ │ ├── pinned.tsx │ │ │ │ ├── play-circle.tsx │ │ │ │ ├── plug.tsx │ │ │ │ ├── plus-circle.tsx │ │ │ │ ├── plus.tsx │ │ │ │ ├── progress-bars.tsx │ │ │ │ ├── progress-circle.tsx │ │ │ │ ├── question-mark-small.tsx │ │ │ │ ├── quote-end.tsx │ │ │ │ ├── quote-start.tsx │ │ │ │ ├── refresh-cw.tsx │ │ │ │ ├── revert.tsx │ │ │ │ ├── search-menu.tsx │ │ │ │ ├── search-small.tsx │ │ │ │ ├── search.tsx │ │ │ │ ├── server.tsx │ │ │ │ ├── settings.tsx │ │ │ │ ├── share-webhook.tsx │ │ │ │ ├── share.tsx │ │ │ │ ├── shield.tsx │ │ │ │ ├── sidebar.tsx │ │ │ │ ├── slack.tsx │ │ │ │ ├── slash.tsx │ │ │ │ ├── sliders-small.tsx │ │ │ │ ├── sliders.tsx │ │ │ │ ├── sort-order.tsx │ │ │ │ ├── sort.tsx │ │ │ │ ├── sparkle.tsx │ │ │ │ ├── 
star-off.tsx │ │ │ │ ├── star.tsx │ │ │ │ ├── step1.tsx │ │ │ │ ├── step2.tsx │ │ │ │ ├── step3-end.tsx │ │ │ │ ├── step3.tsx │ │ │ │ ├── stop-circle.tsx │ │ │ │ ├── stop.tsx │ │ │ │ ├── sun.tsx │ │ │ │ ├── tag.tsx │ │ │ │ ├── terminal-small.tsx │ │ │ │ ├── terminal.tsx │ │ │ │ ├── text-lines-small.tsx │ │ │ │ ├── text-lines.tsx │ │ │ │ ├── thumbs-down.tsx │ │ │ │ ├── thumbs-up.tsx │ │ │ │ ├── trash.tsx │ │ │ │ ├── two-line-small.tsx │ │ │ │ ├── unplug.tsx │ │ │ │ ├── upload-cloud.tsx │ │ │ │ ├── user-check.tsx │ │ │ │ ├── user-edit.tsx │ │ │ │ ├── user-key.tsx │ │ │ │ ├── user-manage.tsx │ │ │ │ ├── user-minus.tsx │ │ │ │ ├── user-plus.tsx │ │ │ │ ├── user-shield.tsx │ │ │ │ ├── user-speaker.tsx │ │ │ │ ├── user-sync.tsx │ │ │ │ ├── user-x.tsx │ │ │ │ ├── user.tsx │ │ │ │ ├── users.tsx │ │ │ │ ├── volume-off.tsx │ │ │ │ ├── volume.tsx │ │ │ │ ├── wallet.tsx │ │ │ │ ├── workflow.tsx │ │ │ │ ├── x-circle.tsx │ │ │ │ ├── x-octagon.tsx │ │ │ │ ├── x.tsx │ │ │ │ ├── zoom-in.tsx │ │ │ │ └── zoom-out.tsx │ │ │ ├── illustrations/ │ │ │ │ ├── broken-key.tsx │ │ │ │ ├── connect.tsx │ │ │ │ ├── connected.tsx │ │ │ │ ├── disconnected.tsx │ │ │ │ ├── empty.tsx │ │ │ │ ├── end-of-line.tsx │ │ │ │ ├── index.ts │ │ │ │ ├── limit-alert.tsx │ │ │ │ ├── long-wait.tsx │ │ │ │ ├── no-access.tsx │ │ │ │ ├── no-result.tsx │ │ │ │ ├── not-found.tsx │ │ │ │ ├── overflow.tsx │ │ │ │ ├── plug-broken.tsx │ │ │ │ ├── timeout.tsx │ │ │ │ ├── un-plugged.tsx │ │ │ │ └── usage-alert.tsx │ │ │ ├── layouts/ │ │ │ │ ├── README.md │ │ │ │ ├── cards/ │ │ │ │ │ └── header-layout/ │ │ │ │ │ ├── CardHeaderLayout.stories.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ └── components.tsx │ │ │ │ ├── content/ │ │ │ │ │ ├── Content.stories.tsx │ │ │ │ │ ├── ContentLg.tsx │ │ │ │ │ ├── ContentMd.tsx │ │ │ │ │ ├── ContentSm.tsx │ │ │ │ │ ├── ContentXl.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ ├── components.tsx │ │ │ │ │ └── styles.css │ │ │ │ ├── content-action/ │ │ │ │ │ ├── ContentAction.stories.tsx │ │ │ │ │ ├── 
README.md │ │ │ │ │ └── components.tsx │ │ │ │ ├── illustration-content/ │ │ │ │ │ ├── IllustrationContent.stories.tsx │ │ │ │ │ ├── README.md │ │ │ │ │ └── components.tsx │ │ │ │ └── index.ts │ │ │ ├── shared.ts │ │ │ ├── types.ts │ │ │ └── utils.ts │ │ └── tsconfig.json │ ├── next.config.js │ ├── package.json │ ├── playwright.config.ts │ ├── postcss.config.js │ ├── public/ │ │ └── fonts/ │ │ └── KHTeka-Medium.otf │ ├── sentry.edge.config.ts │ ├── sentry.server.config.ts │ ├── src/ │ │ ├── app/ │ │ │ ├── PostHogPageView.tsx │ │ │ ├── admin/ │ │ │ │ ├── actions/ │ │ │ │ │ ├── edit/ │ │ │ │ │ │ └── [toolId]/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── edit-mcp/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── mcp/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── new/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── open-api/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── add-connector/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── agents/ │ │ │ │ │ ├── CollapsibleSection.tsx │ │ │ │ │ ├── interfaces.ts │ │ │ │ │ ├── lib.ts │ │ │ │ │ └── page.tsx │ │ │ │ ├── billing/ │ │ │ │ │ ├── BillingDetailsView.tsx │ │ │ │ │ ├── CheckoutView.tsx │ │ │ │ │ ├── LicenseActivationCard.tsx │ │ │ │ │ ├── PlansView.tsx │ │ │ │ │ ├── billing.css │ │ │ │ │ ├── page.test.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── bots/ │ │ │ │ │ ├── SlackBotCreationForm.tsx │ │ │ │ │ ├── SlackBotTable.tsx │ │ │ │ │ ├── SlackBotUpdateForm.tsx │ │ │ │ │ ├── SlackTokensForm.tsx │ │ │ │ │ ├── [bot-id]/ │ │ │ │ │ │ ├── SlackChannelConfigsTable.tsx │ │ │ │ │ │ ├── channels/ │ │ │ │ │ │ │ ├── SlackChannelConfigCreationForm.tsx │ │ │ │ │ │ │ ├── SlackChannelConfigFormFields.tsx │ │ │ │ │ │ │ ├── [id]/ │ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ │ └── new/ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ ├── hooks.ts │ │ │ │ │ │ ├── lib.ts │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── new/ │ │ │ │ │ │ ├── lib.ts │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── configuration/ │ │ │ │ │ ├── chat-preferences/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── 
code-interpreter/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── document-processing/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── image-generation/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── llm/ │ │ │ │ │ │ ├── ModelConfigurationField.tsx │ │ │ │ │ │ ├── ProviderIcon.tsx │ │ │ │ │ │ ├── page.tsx │ │ │ │ │ │ └── utils.ts │ │ │ │ │ ├── search/ │ │ │ │ │ │ ├── UpgradingPage.tsx │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── voice/ │ │ │ │ │ │ ├── VoiceProviderSetupModal.tsx │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── web-search/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── connector/ │ │ │ │ │ └── [ccPairId]/ │ │ │ │ │ ├── ConfigDisplay.tsx │ │ │ │ │ ├── DeletionErrorStatus.tsx │ │ │ │ │ ├── IndexAttemptErrorsModal.tsx │ │ │ │ │ ├── IndexAttemptsTable.tsx │ │ │ │ │ ├── InlineFileManagement.tsx │ │ │ │ │ ├── ReIndexModal.tsx │ │ │ │ │ ├── lib.ts │ │ │ │ │ ├── page.tsx │ │ │ │ │ ├── types.ts │ │ │ │ │ └── useStatusChange.tsx │ │ │ │ ├── connectors/ │ │ │ │ │ └── [connector]/ │ │ │ │ │ ├── AddConnectorPage.tsx │ │ │ │ │ ├── ConnectorWrapper.tsx │ │ │ │ │ ├── NavigationRow.tsx │ │ │ │ │ ├── auth/ │ │ │ │ │ │ └── callback/ │ │ │ │ │ │ └── route.ts │ │ │ │ │ ├── oauth/ │ │ │ │ │ │ ├── callback/ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ └── finalize/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── pages/ │ │ │ │ │ ├── Advanced.tsx │ │ │ │ │ ├── ConnectorInput/ │ │ │ │ │ │ ├── FileInput.tsx │ │ │ │ │ │ ├── ListInput.tsx │ │ │ │ │ │ ├── NumberInput.tsx │ │ │ │ │ │ └── SelectInput.tsx │ │ │ │ │ ├── DynamicConnectorCreationForm.tsx │ │ │ │ │ ├── FieldRendering.tsx │ │ │ │ │ ├── gdrive/ │ │ │ │ │ │ ├── Credential.tsx │ │ │ │ │ │ └── GoogleDrivePage.tsx │ │ │ │ │ ├── gmail/ │ │ │ │ │ │ ├── Credential.tsx │ │ │ │ │ │ └── GmailPage.tsx │ │ │ │ │ └── utils/ │ │ │ │ │ ├── files.ts │ │ │ │ │ ├── google_site.ts │ │ │ │ │ └── hooks.ts │ │ │ │ ├── debug/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── discord-bot/ │ │ │ │ │ ├── BotConfigCard.tsx │ │ │ │ │ ├── DiscordGuildsTable.tsx │ │ │ │ │ ├── [guild-id]/ │ │ │ │ │ │ ├── 
DiscordChannelsTable.tsx │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── hooks.ts │ │ │ │ │ ├── lib.ts │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── types.ts │ │ │ │ ├── document-index-migration/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── documents/ │ │ │ │ │ ├── ScoreEditor.tsx │ │ │ │ │ ├── explorer/ │ │ │ │ │ │ ├── DocumentExplorerPage.tsx │ │ │ │ │ │ ├── Explorer.tsx │ │ │ │ │ │ ├── lib.ts │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── feedback/ │ │ │ │ │ │ ├── DocumentFeedbackTable.tsx │ │ │ │ │ │ ├── constants.ts │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── lib.ts │ │ │ │ │ └── sets/ │ │ │ │ │ ├── DocumentSetCreationForm.tsx │ │ │ │ │ ├── [documentSetId]/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── hooks.tsx │ │ │ │ │ ├── lib.ts │ │ │ │ │ ├── new/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── embeddings/ │ │ │ │ │ ├── EmbeddingModelSelectionForm.tsx │ │ │ │ │ ├── RerankingFormPage.tsx │ │ │ │ │ ├── interfaces.ts │ │ │ │ │ ├── modals/ │ │ │ │ │ │ ├── AlreadyPickedModal.tsx │ │ │ │ │ │ ├── ChangeCredentialsModal.tsx │ │ │ │ │ │ ├── DeleteCredentialsModal.tsx │ │ │ │ │ │ ├── InstantSwitchConfirmModal.tsx │ │ │ │ │ │ ├── ModelSelectionModal.tsx │ │ │ │ │ │ ├── ProviderCreationModal.tsx │ │ │ │ │ │ └── SelectModelModal.tsx │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── pages/ │ │ │ │ │ ├── AdvancedEmbeddingFormPage.tsx │ │ │ │ │ ├── CloudEmbeddingPage.tsx │ │ │ │ │ ├── EmbeddingFormPage.tsx │ │ │ │ │ ├── OpenEmbeddingPage.tsx │ │ │ │ │ └── utils.ts │ │ │ │ ├── federated/ │ │ │ │ │ └── [id]/ │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── useFederatedConnector.ts │ │ │ │ ├── groups/ │ │ │ │ │ ├── [id]/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── create/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── groups2/ │ │ │ │ │ ├── [id]/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── create/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── hooks/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── indexing/ │ │ │ │ │ └── status/ │ │ │ │ │ ├── CCPairIndexingStatusTable.tsx │ │ │ │ │ ├── ConnectorRowSkeleton.tsx │ │ │ │ │ ├── 
FilterComponent.tsx │ │ │ │ │ ├── SearchAndFilterControls.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── kg/ │ │ │ │ │ ├── KGEntityTypes.tsx │ │ │ │ │ ├── interfaces.ts │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── utils.ts │ │ │ │ ├── layout.tsx │ │ │ │ ├── scim/ │ │ │ │ │ ├── ScimModal.tsx │ │ │ │ │ ├── ScimSyncCard.tsx │ │ │ │ │ ├── interfaces.ts │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── svc.ts │ │ │ │ ├── service-accounts/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── systeminfo/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── token-rate-limits/ │ │ │ │ │ ├── CreateRateLimitModal.tsx │ │ │ │ │ ├── TokenRateLimitTables.tsx │ │ │ │ │ ├── lib.ts │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── types.ts │ │ │ │ └── users/ │ │ │ │ └── page.tsx │ │ │ ├── anonymous/ │ │ │ │ └── [id]/ │ │ │ │ ├── AnonymousPage.tsx │ │ │ │ └── page.tsx │ │ │ ├── api/ │ │ │ │ ├── [...path]/ │ │ │ │ │ └── route.ts │ │ │ │ └── chat/ │ │ │ │ └── mcp/ │ │ │ │ └── oauth/ │ │ │ │ └── callback/ │ │ │ │ └── route.ts │ │ │ ├── app/ │ │ │ │ ├── agents/ │ │ │ │ │ ├── create/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── edit/ │ │ │ │ │ │ └── [id]/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── components/ │ │ │ │ │ ├── AgentDescription.tsx │ │ │ │ │ ├── AppPopup.tsx │ │ │ │ │ ├── WelcomeMessage.tsx │ │ │ │ │ ├── files/ │ │ │ │ │ │ ├── InputBarPreview.tsx │ │ │ │ │ │ └── images/ │ │ │ │ │ │ ├── FullImageModal.tsx │ │ │ │ │ │ ├── InMessageImage.tsx │ │ │ │ │ │ ├── InputBarPreviewImage.tsx │ │ │ │ │ │ └── utils.ts │ │ │ │ │ ├── folders/ │ │ │ │ │ │ ├── FolderDropdown.tsx │ │ │ │ │ │ └── interfaces.ts │ │ │ │ │ ├── modifiers/ │ │ │ │ │ │ └── SelectedDocuments.tsx │ │ │ │ │ ├── projects/ │ │ │ │ │ │ ├── ProjectChatSessionList.tsx │ │ │ │ │ │ ├── ProjectContextPanel.tsx │ │ │ │ │ │ └── project_utils.ts │ │ │ │ │ └── tools/ │ │ │ │ │ ├── GeneratingImageDisplay.tsx │ │ │ │ │ └── constants.ts │ │ │ │ ├── interfaces.ts │ │ │ │ ├── layout.tsx │ │ │ │ ├── message/ │ │ │ │ │ ├── BlinkingBar.tsx │ │ │ │ │ ├── CodeBlock.tsx │ │ │ │ │ ├── FileDisplay.tsx │ │ │ │ │ 
├── HumanMessage.tsx │ │ │ │ │ ├── MemoizedTextComponents.tsx │ │ │ │ │ ├── MessageSwitcher.tsx │ │ │ │ │ ├── Resubmit.tsx │ │ │ │ │ ├── codeUtils.test.ts │ │ │ │ │ ├── codeUtils.ts │ │ │ │ │ ├── copyingUtils.tsx │ │ │ │ │ ├── custom-code-styles.css │ │ │ │ │ ├── errorHelpers.tsx │ │ │ │ │ ├── hooks.ts │ │ │ │ │ ├── messageComponents/ │ │ │ │ │ │ ├── AgentMessage.tsx │ │ │ │ │ │ ├── CustomToolAuthCard.tsx │ │ │ │ │ │ ├── MessageToolbar.tsx │ │ │ │ │ │ ├── TTSButton.tsx │ │ │ │ │ │ ├── constants.ts │ │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ │ ├── useAuthErrors.ts │ │ │ │ │ │ │ ├── useMessageSwitching.ts │ │ │ │ │ │ │ └── usePacketAnimationAndCollapse.ts │ │ │ │ │ │ ├── interfaces.ts │ │ │ │ │ │ ├── markdownUtils.tsx │ │ │ │ │ │ ├── renderMessageComponent.tsx │ │ │ │ │ │ ├── renderers/ │ │ │ │ │ │ │ ├── CustomToolRenderer.tsx │ │ │ │ │ │ │ ├── ImageToolRenderer.tsx │ │ │ │ │ │ │ └── MessageTextRenderer.tsx │ │ │ │ │ │ ├── timeline/ │ │ │ │ │ │ │ ├── AgentTimeline.tsx │ │ │ │ │ │ │ ├── CollapsedStreamingContent.tsx │ │ │ │ │ │ │ ├── ExpandedTimelineContent.tsx │ │ │ │ │ │ │ ├── ParallelTimelineTabs.tsx │ │ │ │ │ │ │ ├── StepContainer.tsx │ │ │ │ │ │ │ ├── TimelineRendererComponent.tsx │ │ │ │ │ │ │ ├── TimelineStepComposer.tsx │ │ │ │ │ │ │ ├── headers/ │ │ │ │ │ │ │ │ ├── CompletedHeader.tsx │ │ │ │ │ │ │ │ ├── ParallelStreamingHeader.tsx │ │ │ │ │ │ │ │ ├── StoppedHeader.tsx │ │ │ │ │ │ │ │ └── StreamingHeader.tsx │ │ │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ │ │ ├── __tests__/ │ │ │ │ │ │ │ │ │ └── testHelpers.ts │ │ │ │ │ │ │ │ ├── packetProcessor.test.ts │ │ │ │ │ │ │ │ ├── packetProcessor.ts │ │ │ │ │ │ │ │ ├── usePacedTurnGroups.test.tsx │ │ │ │ │ │ │ │ ├── usePacedTurnGroups.ts │ │ │ │ │ │ │ │ ├── usePacketProcessor.test.tsx │ │ │ │ │ │ │ │ ├── usePacketProcessor.ts │ │ │ │ │ │ │ │ ├── useStreamingDuration.ts │ │ │ │ │ │ │ │ ├── useTimelineExpansion.ts │ │ │ │ │ │ │ │ ├── useTimelineHeader.ts │ │ │ │ │ │ │ │ ├── useTimelineMetrics.ts │ │ │ │ │ │ │ │ ├── useTimelineStepState.ts 
│ │ │ │ │ │ │ │ └── useTimelineUIState.ts │ │ │ │ │ │ │ ├── packetHelpers.ts │ │ │ │ │ │ │ ├── primitives/ │ │ │ │ │ │ │ │ ├── TimelineHeaderRow.tsx │ │ │ │ │ │ │ │ ├── TimelineIconColumn.tsx │ │ │ │ │ │ │ │ ├── TimelineRoot.tsx │ │ │ │ │ │ │ │ ├── TimelineRow.tsx │ │ │ │ │ │ │ │ ├── TimelineStepContent.tsx │ │ │ │ │ │ │ │ ├── TimelineSurface.tsx │ │ │ │ │ │ │ │ ├── TimelineTopSpacer.tsx │ │ │ │ │ │ │ │ └── tokens.ts │ │ │ │ │ │ │ ├── renderers/ │ │ │ │ │ │ │ │ ├── code/ │ │ │ │ │ │ │ │ │ └── PythonToolRenderer.tsx │ │ │ │ │ │ │ │ ├── deepresearch/ │ │ │ │ │ │ │ │ │ ├── DeepResearchPlanRenderer.tsx │ │ │ │ │ │ │ │ │ └── ResearchAgentRenderer.tsx │ │ │ │ │ │ │ │ ├── fetch/ │ │ │ │ │ │ │ │ │ ├── FetchToolRenderer.tsx │ │ │ │ │ │ │ │ │ └── fetchStateUtils.ts │ │ │ │ │ │ │ │ ├── filereader/ │ │ │ │ │ │ │ │ │ └── FileReaderToolRenderer.tsx │ │ │ │ │ │ │ │ ├── memory/ │ │ │ │ │ │ │ │ │ ├── MemoryToolRenderer.tsx │ │ │ │ │ │ │ │ │ └── memoryStateUtils.ts │ │ │ │ │ │ │ │ ├── reasoning/ │ │ │ │ │ │ │ │ │ └── ReasoningRenderer.tsx │ │ │ │ │ │ │ │ ├── search/ │ │ │ │ │ │ │ │ │ ├── InternalSearchToolRenderer.tsx │ │ │ │ │ │ │ │ │ ├── SearchChipList.tsx │ │ │ │ │ │ │ │ │ ├── WebSearchToolRenderer.tsx │ │ │ │ │ │ │ │ │ └── searchStateUtils.ts │ │ │ │ │ │ │ │ └── sharedMarkdownComponents.tsx │ │ │ │ │ │ │ └── transformers.ts │ │ │ │ │ │ ├── timing.ts │ │ │ │ │ │ └── toolDisplayHelpers.tsx │ │ │ │ │ └── thinkingBox/ │ │ │ │ │ └── ThinkingBox.css │ │ │ │ ├── page.tsx │ │ │ │ ├── projects/ │ │ │ │ │ └── projectsService.ts │ │ │ │ ├── services/ │ │ │ │ │ ├── actionUtils.ts │ │ │ │ │ ├── currentMessageFIFO.ts │ │ │ │ │ ├── fileUtils.ts │ │ │ │ │ ├── lib.tsx │ │ │ │ │ ├── messageTree.ts │ │ │ │ │ ├── packetUtils.test.ts │ │ │ │ │ ├── packetUtils.ts │ │ │ │ │ ├── searchParams.ts │ │ │ │ │ ├── streamingModels.ts │ │ │ │ │ └── thinkingTokens.ts │ │ │ │ ├── settings/ │ │ │ │ │ ├── accounts-access/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── chat-preferences/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ 
├── connectors/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── general/ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── layout.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── shared/ │ │ │ │ │ └── [chatId]/ │ │ │ │ │ ├── SharedChatDisplay.tsx │ │ │ │ │ └── page.tsx │ │ │ │ └── stores/ │ │ │ │ └── useChatSessionStore.ts │ │ │ ├── auth/ │ │ │ │ ├── create-account/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── error/ │ │ │ │ │ ├── AuthErrorContent.tsx │ │ │ │ │ ├── layout.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── forgot-password/ │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── utils.ts │ │ │ │ ├── impersonate/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── join/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── lib.ts │ │ │ │ ├── libSS.ts │ │ │ │ ├── login/ │ │ │ │ │ ├── EmailPasswordForm.test.tsx │ │ │ │ │ ├── EmailPasswordForm.tsx │ │ │ │ │ ├── LoginPage.tsx │ │ │ │ │ ├── LoginText.tsx │ │ │ │ │ ├── SignInButton.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── logout/ │ │ │ │ │ └── route.ts │ │ │ │ ├── oauth/ │ │ │ │ │ └── callback/ │ │ │ │ │ └── route.ts │ │ │ │ ├── oidc/ │ │ │ │ │ └── callback/ │ │ │ │ │ └── route.ts │ │ │ │ ├── reset-password/ │ │ │ │ │ └── page.tsx │ │ │ │ ├── saml/ │ │ │ │ │ └── callback/ │ │ │ │ │ └── route.ts │ │ │ │ ├── signup/ │ │ │ │ │ ├── ReferralSourceSelector.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── verify-email/ │ │ │ │ │ ├── Verify.tsx │ │ │ │ │ └── page.tsx │ │ │ │ └── waiting-on-verification/ │ │ │ │ ├── RequestNewVerificationEmail.tsx │ │ │ │ └── page.tsx │ │ │ ├── components/ │ │ │ │ └── nrf/ │ │ │ │ └── SettingsPanel.tsx │ │ │ ├── config/ │ │ │ │ └── timeRange.tsx │ │ │ ├── connector/ │ │ │ │ └── oauth/ │ │ │ │ └── callback/ │ │ │ │ └── [source]/ │ │ │ │ └── route.tsx │ │ │ ├── craft/ │ │ │ │ ├── README.md │ │ │ │ ├── components/ │ │ │ │ │ ├── BigButton.tsx │ │ │ │ │ ├── BuildLLMPopover.tsx │ │ │ │ │ ├── BuildMessageList.tsx │ │ │ │ │ ├── BuildWelcome.tsx │ │ │ │ │ ├── ChatPanel.tsx │ │ │ │ │ ├── ConnectDataBanner.tsx │ │ │ │ │ ├── ConnectorBannersRow.tsx │ │ │ │ │ ├── CraftingLoader.tsx │ │ │ │ │ ├── DiffView.tsx │ │ 
│ │ │ ├── FileBrowser.tsx │ │ │ │ │ ├── FilePreviewModal.tsx │ │ │ │ │ ├── InputBar.tsx │ │ │ │ │ ├── IntroBackground.tsx │ │ │ │ │ ├── IntroContent.tsx │ │ │ │ │ ├── OutputPanel.tsx │ │ │ │ │ ├── RawOutputBlock.tsx │ │ │ │ │ ├── SandboxStatusIndicator.tsx │ │ │ │ │ ├── ShareButton.tsx │ │ │ │ │ ├── SideBar.tsx │ │ │ │ │ ├── SuggestedPrompts.tsx │ │ │ │ │ ├── SuggestionBubbles.tsx │ │ │ │ │ ├── TextChunk.tsx │ │ │ │ │ ├── ThinkingCard.tsx │ │ │ │ │ ├── TodoListCard.tsx │ │ │ │ │ ├── ToggleWarningModal.tsx │ │ │ │ │ ├── ToolCallPill.tsx │ │ │ │ │ ├── TypewriterText.tsx │ │ │ │ │ ├── UpgradePlanModal.tsx │ │ │ │ │ ├── UserMessage.tsx │ │ │ │ │ ├── WorkingLine.tsx │ │ │ │ │ ├── WorkingPill.tsx │ │ │ │ │ └── output-panel/ │ │ │ │ │ ├── ArtifactsTab.tsx │ │ │ │ │ ├── FilePreviewContent.tsx │ │ │ │ │ ├── FilesTab.tsx │ │ │ │ │ ├── ImagePreview.tsx │ │ │ │ │ ├── MarkdownFilePreview.tsx │ │ │ │ │ ├── PdfPreview.tsx │ │ │ │ │ ├── PptxPreview.tsx │ │ │ │ │ ├── PreviewTab.tsx │ │ │ │ │ └── UrlBar.tsx │ │ │ │ ├── constants/ │ │ │ │ │ └── exampleBuildPrompts.ts │ │ │ │ ├── constants.ts │ │ │ │ ├── contexts/ │ │ │ │ │ ├── BuildContext.tsx │ │ │ │ │ └── UploadFilesContext.tsx │ │ │ │ ├── hooks/ │ │ │ │ │ ├── useBuildConnectors.ts │ │ │ │ │ ├── useBuildLlmSelection.ts │ │ │ │ │ ├── useBuildSessionController.ts │ │ │ │ │ ├── useBuildSessionStore.ts │ │ │ │ │ ├── useBuildStreaming.ts │ │ │ │ │ ├── usePreProvisionPolling.ts │ │ │ │ │ └── useUsageLimits.ts │ │ │ │ ├── layout.tsx │ │ │ │ ├── onboarding/ │ │ │ │ │ ├── BuildOnboardingProvider.tsx │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── BuildOnboardingModal.tsx │ │ │ │ │ │ ├── NoLlmProvidersModal.tsx │ │ │ │ │ │ ├── NotAllowedModal.tsx │ │ │ │ │ │ ├── OnboardingInfoPages.tsx │ │ │ │ │ │ ├── OnboardingLlmSetup.tsx │ │ │ │ │ │ └── OnboardingUserInfo.tsx │ │ │ │ │ ├── constants.ts │ │ │ │ │ ├── hooks/ │ │ │ │ │ │ └── useOnboardingModal.ts │ │ │ │ │ └── types.ts │ │ │ │ ├── page.tsx │ │ │ │ ├── services/ │ │ │ │ │ ├── apiServices.ts │ │ │ │ 
│ └── searchParams.ts │ │ │ │ ├── types/ │ │ │ │ │ ├── displayTypes.ts │ │ │ │ │ ├── streamingTypes.ts │ │ │ │ │ └── user-library.ts │ │ │ │ ├── utils/ │ │ │ │ │ ├── packetTypes.ts │ │ │ │ │ ├── parsePacket.ts │ │ │ │ │ ├── pathSanitizer.test.ts │ │ │ │ │ ├── pathSanitizer.ts │ │ │ │ │ └── streamItemHelpers.ts │ │ │ │ └── v1/ │ │ │ │ ├── configure/ │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── ComingSoonConnectors.tsx │ │ │ │ │ │ ├── ConfigureConnectorModal.tsx │ │ │ │ │ │ ├── ConfigureOverlays.tsx │ │ │ │ │ │ ├── ConnectorCard.tsx │ │ │ │ │ │ ├── ConnectorConfigStep.tsx │ │ │ │ │ │ ├── CreateCredentialInline.tsx │ │ │ │ │ │ ├── CredentialStep.tsx │ │ │ │ │ │ ├── DemoDataConfirmModal.tsx │ │ │ │ │ │ ├── RequestConnectorModal.tsx │ │ │ │ │ │ └── UserLibraryModal.tsx │ │ │ │ │ ├── page.tsx │ │ │ │ │ └── utils/ │ │ │ │ │ └── createBuildConnector.ts │ │ │ │ ├── constants.ts │ │ │ │ ├── layout.tsx │ │ │ │ └── page.tsx │ │ │ ├── css/ │ │ │ │ ├── attachment-button.css │ │ │ │ ├── button.css │ │ │ │ ├── card.css │ │ │ │ ├── code.css │ │ │ │ ├── color-swatch.css │ │ │ │ ├── colors.css │ │ │ │ ├── divider.css │ │ │ │ ├── general-layouts.css │ │ │ │ ├── inputs.css │ │ │ │ ├── knowledge-table.css │ │ │ │ ├── line-item.css │ │ │ │ ├── sizes.css │ │ │ │ ├── square-button.css │ │ │ │ ├── switch.css │ │ │ │ └── z-index.css │ │ │ ├── ee/ │ │ │ │ ├── EEFeatureRedirect.tsx │ │ │ │ ├── LICENSE │ │ │ │ ├── admin/ │ │ │ │ │ ├── billing/ │ │ │ │ │ │ ├── BillingAlerts.tsx │ │ │ │ │ │ ├── BillingInformationPage.tsx │ │ │ │ │ │ ├── InfoItem.tsx │ │ │ │ │ │ ├── SubscriptionSummary.tsx │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── groups/ │ │ │ │ │ │ ├── [id]/ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ ├── create/ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ └── page.tsx │ │ │ │ │ ├── layout.tsx │ │ │ │ │ ├── performance/ │ │ │ │ │ │ ├── custom-analytics/ │ │ │ │ │ │ │ ├── CustomAnalyticsUpdateForm.tsx │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ ├── lib.ts │ │ │ │ │ │ ├── query-history/ │ │ │ │ │ │ │ ├── 
FeedbackBadge.tsx │ │ │ │ │ │ │ ├── KickoffCSVExport.tsx │ │ │ │ │ │ │ ├── QueryHistoryTable.tsx │ │ │ │ │ │ │ ├── [id]/ │ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ │ ├── constants.ts │ │ │ │ │ │ │ ├── page.tsx │ │ │ │ │ │ │ ├── types.ts │ │ │ │ │ │ │ └── utils.ts │ │ │ │ │ │ └── usage/ │ │ │ │ │ │ ├── FeedbackChart.tsx │ │ │ │ │ │ ├── OnyxBotChart.tsx │ │ │ │ │ │ ├── PersonaMessagesChart.tsx │ │ │ │ │ │ ├── QueryPerformanceChart.tsx │ │ │ │ │ │ ├── UsageReports.tsx │ │ │ │ │ │ ├── page.tsx │ │ │ │ │ │ └── types.ts │ │ │ │ │ ├── standard-answer/ │ │ │ │ │ │ ├── StandardAnswerCreationForm.tsx │ │ │ │ │ │ ├── [id]/ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ ├── hooks.ts │ │ │ │ │ │ ├── lib.ts │ │ │ │ │ │ ├── new/ │ │ │ │ │ │ │ └── page.tsx │ │ │ │ │ │ └── page.tsx │ │ │ │ │ └── theme/ │ │ │ │ │ ├── AppearanceThemeSettings.tsx │ │ │ │ │ ├── Preview.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── agents/ │ │ │ │ │ └── stats/ │ │ │ │ │ └── [id]/ │ │ │ │ │ ├── AgentStats.tsx │ │ │ │ │ └── page.tsx │ │ │ │ └── layout.tsx │ │ │ ├── federated/ │ │ │ │ └── oauth/ │ │ │ │ └── callback/ │ │ │ │ └── page.tsx │ │ │ ├── global-error.tsx │ │ │ ├── globals.css │ │ │ ├── layout.tsx │ │ │ ├── mcp/ │ │ │ │ ├── [[...path]]/ │ │ │ │ │ └── route.ts │ │ │ │ └── oauth/ │ │ │ │ └── callback/ │ │ │ │ └── page.tsx │ │ │ ├── not-found.tsx │ │ │ ├── nrf/ │ │ │ │ ├── (main)/ │ │ │ │ │ ├── layout.tsx │ │ │ │ │ └── page.tsx │ │ │ │ ├── NRFChrome.tsx │ │ │ │ ├── NRFPage.tsx │ │ │ │ ├── layout.tsx │ │ │ │ └── side-panel/ │ │ │ │ ├── SidePanelHeader.tsx │ │ │ │ └── page.tsx │ │ │ ├── oauth-config/ │ │ │ │ └── callback/ │ │ │ │ └── page.tsx │ │ │ ├── page.tsx │ │ │ ├── providers.tsx │ │ │ └── web-vitals.tsx │ │ ├── ce.tsx │ │ ├── components/ │ │ │ ├── AdvancedOptionsToggle.tsx │ │ │ ├── AgentsMultiSelect.tsx │ │ │ ├── BasicClickable.tsx │ │ │ ├── Bubble.tsx │ │ │ ├── CollapsibleCard.tsx │ │ │ ├── ConnectorMultiSelect.tsx │ │ │ ├── DeleteButton.tsx │ │ │ ├── Dropdown.tsx │ │ │ ├── EditableStringFieldDisplay.tsx │ │ │ 
├── EditableValue.tsx │ │ │ ├── ErrorCallout.tsx │ │ │ ├── FederatedConnectorSelector.tsx │ │ │ ├── Field.tsx │ │ │ ├── FormErrorHelpers.tsx │ │ │ ├── GatedContentWrapper.tsx │ │ │ ├── GenericMultiSelect.tsx │ │ │ ├── GroupsMultiSelect.tsx │ │ │ ├── HoverPopup.tsx │ │ │ ├── IsPublicGroupSelector.tsx │ │ │ ├── Loading.tsx │ │ │ ├── MetadataBadge.tsx │ │ │ ├── MultiSelectDropdown.tsx │ │ │ ├── NonSelectableConnectors.tsx │ │ │ ├── OnyxInitializingLoader.tsx │ │ │ ├── PageSelector.tsx │ │ │ ├── RichTextSubtext.tsx │ │ │ ├── SSRAutoRefresh.tsx │ │ │ ├── SearchResultIcon.tsx │ │ │ ├── SourceIcon.tsx │ │ │ ├── SourceTile.tsx │ │ │ ├── Spinner.tsx │ │ │ ├── Status.tsx │ │ │ ├── WebResultIcon.tsx │ │ │ ├── admin/ │ │ │ │ ├── CardSection.tsx │ │ │ │ ├── ClientLayout.tsx │ │ │ │ ├── Layout.tsx │ │ │ │ ├── Title.tsx │ │ │ │ ├── connectors/ │ │ │ │ │ ├── AccessTypeForm.tsx │ │ │ │ │ ├── AccessTypeGroupSelector.tsx │ │ │ │ │ ├── AutoSyncOptions.tsx │ │ │ │ │ ├── BasicTable.tsx │ │ │ │ │ ├── ConnectorDocsLink.tsx │ │ │ │ │ ├── ConnectorTitle.tsx │ │ │ │ │ ├── CredentialForm.tsx │ │ │ │ │ ├── FileUpload.tsx │ │ │ │ │ └── types.ts │ │ │ │ ├── federated/ │ │ │ │ │ └── FederatedConnectorForm.tsx │ │ │ │ └── users/ │ │ │ │ ├── BulkAdd.tsx │ │ │ │ ├── CenteredPageSelector.tsx │ │ │ │ ├── InvitedUserTable.tsx │ │ │ │ ├── PendingUsersTable.tsx │ │ │ │ ├── ResetPasswordModal.tsx │ │ │ │ ├── SignedUpUserTable.tsx │ │ │ │ └── buttons/ │ │ │ │ ├── DeactivateUserButton.tsx │ │ │ │ ├── DeleteUserButton.tsx │ │ │ │ ├── InviteUserButton.tsx │ │ │ │ ├── LeaveOrganizationButton.tsx │ │ │ │ └── UserRoleDropdown.tsx │ │ │ ├── auth/ │ │ │ │ ├── AuthErrorDisplay.tsx │ │ │ │ └── AuthFlowContainer.tsx │ │ │ ├── chat/ │ │ │ │ ├── DynamicBottomSpacer.tsx │ │ │ │ ├── FederatedOAuthModal.tsx │ │ │ │ ├── MCPApiKeyModal.tsx │ │ │ │ ├── MinimalMarkdown.test.tsx │ │ │ │ ├── MinimalMarkdown.tsx │ │ │ │ ├── ProviderContext.tsx │ │ │ │ └── ScrollContainerContext.tsx │ │ │ ├── context/ │ │ │ │ ├── 
EmbeddingContext.tsx │ │ │ │ ├── FormContext.tsx │ │ │ │ ├── ModalContext.tsx │ │ │ │ └── NRFPreferencesContext.tsx │ │ │ ├── credentials/ │ │ │ │ ├── CredentialFields.tsx │ │ │ │ ├── CredentialSection.tsx │ │ │ │ ├── actions/ │ │ │ │ │ ├── CreateCredential.tsx │ │ │ │ │ ├── CreateStdOAuthCredential.tsx │ │ │ │ │ ├── CredentialFieldsRenderer.tsx │ │ │ │ │ ├── EditCredential.tsx │ │ │ │ │ └── ModifyCredential.tsx │ │ │ │ ├── lib.ts │ │ │ │ └── types.ts │ │ │ ├── dateRangeSelectors/ │ │ │ │ ├── AdminDateRangeSelector.tsx │ │ │ │ ├── SearchDateRangeSelector.tsx │ │ │ │ └── dateUtils.ts │ │ │ ├── dev/ │ │ │ │ ├── StatsOverlay.tsx │ │ │ │ └── StatsOverlayLoader.tsx │ │ │ ├── embedding/ │ │ │ │ ├── CustomEmbeddingModelForm.tsx │ │ │ │ ├── CustomModelForm.tsx │ │ │ │ ├── FailedReIndexAttempts.tsx │ │ │ │ ├── ModelSelector.tsx │ │ │ │ ├── ReindexingProgressTable.tsx │ │ │ │ └── interfaces.tsx │ │ │ ├── errorPages/ │ │ │ │ ├── AccessRestrictedPage.tsx │ │ │ │ ├── CloudErrorPage.tsx │ │ │ │ ├── ErrorPage.tsx │ │ │ │ └── ErrorPageLayout.tsx │ │ │ ├── filters/ │ │ │ │ ├── SourceSelector.tsx │ │ │ │ └── TimeRangeSelector.tsx │ │ │ ├── header/ │ │ │ │ ├── AnnouncementBanner.tsx │ │ │ │ └── HeaderTitle.tsx │ │ │ ├── icons/ │ │ │ │ ├── DynamicFaIcon.tsx │ │ │ │ ├── icons.test.tsx │ │ │ │ └── icons.tsx │ │ │ ├── llm/ │ │ │ │ └── LLMSelector.tsx │ │ │ ├── loading.css │ │ │ ├── modals/ │ │ │ │ ├── AddInstructionModal.tsx │ │ │ │ ├── ConfirmEntityModal.tsx │ │ │ │ ├── CreateProjectModal.tsx │ │ │ │ ├── EditPropertyModal.tsx │ │ │ │ ├── GenericConfirmModal.tsx │ │ │ │ ├── MoveCustomAgentChatModal.tsx │ │ │ │ ├── NewTeamModal.tsx │ │ │ │ ├── NoAgentModal.tsx │ │ │ │ ├── ProviderModal.tsx │ │ │ │ └── UserFilesModal.tsx │ │ │ ├── oauth/ │ │ │ │ └── OAuthCallbackPage.tsx │ │ │ ├── resizable/ │ │ │ │ └── constants.ts │ │ │ ├── search/ │ │ │ │ ├── DocumentDisplay.tsx │ │ │ │ ├── DocumentFeedbackBlock.tsx │ │ │ │ ├── DocumentUpdatedAtBadge.tsx │ │ │ │ ├── filtering/ │ │ │ │ │ └── 
FilterDropdown.tsx │ │ │ │ └── results/ │ │ │ │ ├── Citation.tsx │ │ │ │ └── ResponseSection.tsx │ │ │ ├── settings/ │ │ │ │ ├── lib.ts │ │ │ │ └── usePaidEnterpriseFeaturesEnabled.ts │ │ │ ├── sidebar/ │ │ │ │ ├── ChatSessionMorePopup.tsx │ │ │ │ └── types.ts │ │ │ ├── spinner.css │ │ │ ├── standardAnswers/ │ │ │ │ ├── StandardAnswerCategoryDropdown.tsx │ │ │ │ └── getStandardAnswerCategoriesIfEE.tsx │ │ │ ├── table/ │ │ │ │ ├── DragHandle.tsx │ │ │ │ ├── DraggableRow.tsx │ │ │ │ ├── DraggableTable.tsx │ │ │ │ └── interfaces.ts │ │ │ ├── theme/ │ │ │ │ └── ThemeProvider.tsx │ │ │ ├── tools/ │ │ │ │ ├── CSVContent.tsx │ │ │ │ ├── ExpandableContentWrapper.tsx │ │ │ │ └── parseCSV.test.ts │ │ │ ├── tooltip/ │ │ │ │ └── CustomTooltip.tsx │ │ │ ├── ui/ │ │ │ │ ├── RadioGroupItemField.tsx │ │ │ │ ├── accordion.tsx │ │ │ │ ├── alert.tsx │ │ │ │ ├── areaChart.tsx │ │ │ │ ├── badge.tsx │ │ │ │ ├── callout.tsx │ │ │ │ ├── card.tsx │ │ │ │ ├── dialog.tsx │ │ │ │ ├── dropdown-menu-with-tooltip.tsx │ │ │ │ ├── dropdown-menu.tsx │ │ │ │ ├── input.tsx │ │ │ │ ├── radio-group.tsx │ │ │ │ ├── scroll-area.tsx │ │ │ │ ├── select.tsx │ │ │ │ ├── slider.tsx │ │ │ │ ├── table.tsx │ │ │ │ ├── title.tsx │ │ │ │ └── tooltip.tsx │ │ │ └── voice/ │ │ │ └── Waveform.tsx │ │ ├── ee/ │ │ │ ├── LICENSE │ │ │ ├── hooks/ │ │ │ │ ├── useHookExecutionLogs.ts │ │ │ │ ├── useHookSpecs.ts │ │ │ │ └── useHooks.ts │ │ │ ├── lib/ │ │ │ │ └── search/ │ │ │ │ └── svc.ts │ │ │ ├── providers/ │ │ │ │ └── QueryControllerProvider.tsx │ │ │ ├── refresh-pages/ │ │ │ │ └── admin/ │ │ │ │ └── HooksPage/ │ │ │ │ ├── HookFormModal.tsx │ │ │ │ ├── HookLogsModal.tsx │ │ │ │ ├── HookStatusPopover.tsx │ │ │ │ ├── index.tsx │ │ │ │ ├── interfaces.ts │ │ │ │ └── svc.ts │ │ │ └── sections/ │ │ │ ├── SearchCard.tsx │ │ │ └── SearchUI.tsx │ │ ├── hooks/ │ │ │ ├── __tests__/ │ │ │ │ └── useShowOnboarding.test.tsx │ │ │ ├── appNavigation.ts │ │ │ ├── formHooks.ts │ │ │ ├── useAdminPersonas.ts │ │ │ ├── useAdminUsers.ts │ │ │ 
├── useAgentController.ts │ │ │ ├── useAgentPreferences.ts │ │ │ ├── useAgents.ts │ │ │ ├── useAppFocus.ts │ │ │ ├── useAuthTypeMetadata.ts │ │ │ ├── useAvailableTools.ts │ │ │ ├── useBillingInformation.ts │ │ │ ├── useBoundingBox.ts │ │ │ ├── useBrowserInfo.ts │ │ │ ├── useCCPairs.ts │ │ │ ├── useChatController.ts │ │ │ ├── useChatSessionController.ts │ │ │ ├── useChatSessions.ts │ │ │ ├── useClickOutside.ts │ │ │ ├── useCloudSubscription.ts │ │ │ ├── useCodeInterpreter.ts │ │ │ ├── useContainerCenter.ts │ │ │ ├── useContentSize.ts │ │ │ ├── useCurrentUser.ts │ │ │ ├── useDeepResearchToggle.ts │ │ │ ├── useFederatedOAuthStatus.ts │ │ │ ├── useFeedbackController.ts │ │ │ ├── useFilter.ts │ │ │ ├── useGroups.ts │ │ │ ├── useImageDropzone.ts │ │ │ ├── useIsDefaultAgent.ts │ │ │ ├── useKeyPress.ts │ │ │ ├── useLLMProviders.ts │ │ │ ├── useLicense.ts │ │ │ ├── useMcpServers.ts │ │ │ ├── useMcpServersForAgentEditor.ts │ │ │ ├── useMemoryManager.ts │ │ │ ├── useOnMount.ts │ │ │ ├── useOpenApiTools.ts │ │ │ ├── usePaginatedFetch.ts │ │ │ ├── usePromptShortcuts.ts │ │ │ ├── useScimToken.ts │ │ │ ├── useScreenSize.ts │ │ │ ├── useServerTools.ts │ │ │ ├── useSettings.test.ts │ │ │ ├── useSettings.ts │ │ │ ├── useShareableGroups.ts │ │ │ ├── useShareableUsers.ts │ │ │ ├── useShowOnboarding.ts │ │ │ ├── useTags.ts │ │ │ ├── useToast.ts │ │ │ ├── useTokenRefresh.ts │ │ │ ├── useUserCounts.ts │ │ │ ├── useUserPersonalization.ts │ │ │ ├── useUsers.ts │ │ │ ├── useVoicePlayback.ts │ │ │ ├── useVoiceProviders.ts │ │ │ ├── useVoiceRecorder.ts │ │ │ ├── useVoiceStatus.ts │ │ │ └── useWebSocket.ts │ │ ├── instrumentation-client.ts │ │ ├── instrumentation.ts │ │ ├── interfaces/ │ │ │ ├── llm.ts │ │ │ ├── onboarding.ts │ │ │ └── settings.ts │ │ ├── layouts/ │ │ │ ├── actions-layouts.tsx │ │ │ ├── app-layouts.tsx │ │ │ ├── expandable-card-layouts.tsx │ │ │ ├── general-layouts.tsx │ │ │ ├── input-layouts.tsx │ │ │ ├── settings-layouts.tsx │ │ │ └── table-layouts.tsx │ │ ├── lib/ │ │ │ ├── 
admin/ │ │ │ │ ├── users/ │ │ │ │ │ └── userMutationFetcher.ts │ │ │ │ └── voice/ │ │ │ │ └── svc.ts │ │ │ ├── admin-routes.ts │ │ │ ├── agents.ts │ │ │ ├── agentsSS.ts │ │ │ ├── analytics.ts │ │ │ ├── appSidebarSS.ts │ │ │ ├── auth/ │ │ │ │ ├── redirectValidation.ts │ │ │ │ └── requireAuth.ts │ │ │ ├── azureTargetUri.ts │ │ │ ├── billing/ │ │ │ │ ├── index.ts │ │ │ │ ├── interfaces.ts │ │ │ │ ├── svc.test.ts │ │ │ │ └── svc.ts │ │ │ ├── browserUtilities.tsx │ │ │ ├── build/ │ │ │ │ └── client.ts │ │ │ ├── ccPair.ts │ │ │ ├── chat/ │ │ │ │ ├── fetchAgentData.ts │ │ │ │ ├── fetchBackendChatSessionSS.ts │ │ │ │ ├── greetingMessages.ts │ │ │ │ └── svc.ts │ │ │ ├── clipboard.test.ts │ │ │ ├── clipboard.ts │ │ │ ├── connector.ts │ │ │ ├── connectors/ │ │ │ │ ├── AutoSyncOptionFields.tsx │ │ │ │ ├── connectors.tsx │ │ │ │ ├── credentials.ts │ │ │ │ ├── fileTypes.ts │ │ │ │ └── oauth.ts │ │ │ ├── constants/ │ │ │ │ └── chatBackgrounds.ts │ │ │ ├── constants.ts │ │ │ ├── contains.ts │ │ │ ├── credential.ts │ │ │ ├── dateUtils.ts │ │ │ ├── documentDeletion.ts │ │ │ ├── documentUtils.ts │ │ │ ├── download.ts │ │ │ ├── drag/ │ │ │ │ └── constants.ts │ │ │ ├── error.ts │ │ │ ├── extension/ │ │ │ │ ├── constants.ts │ │ │ │ └── utils.ts │ │ │ ├── fetchUtils.ts │ │ │ ├── fetcher.ts │ │ │ ├── fileConnector.ts │ │ │ ├── filters.ts │ │ │ ├── generated/ │ │ │ │ └── README.md │ │ │ ├── gmail.ts │ │ │ ├── googleConnector.ts │ │ │ ├── googleDrive.ts │ │ │ ├── headers/ │ │ │ │ └── fetchHeaderDataSS.ts │ │ │ ├── hierarchy/ │ │ │ │ ├── interfaces.ts │ │ │ │ └── svc.ts │ │ │ ├── hooks/ │ │ │ │ ├── useCaptcha.ts │ │ │ │ ├── useCustomAnalyticsEnabled.ts │ │ │ │ ├── useDocumentSets.ts │ │ │ │ ├── useForcedTools.ts │ │ │ │ ├── useLLMProviderOptions.ts │ │ │ │ ├── useLLMProviders.test.ts │ │ │ │ ├── useProjects.ts │ │ │ │ └── useToolOAuthStatus.ts │ │ │ ├── hooks.llmResolver.test.ts │ │ │ ├── hooks.ts │ │ │ ├── indexAttempt.ts │ │ │ ├── languages.test.ts │ │ │ ├── languages.ts │ │ │ ├── 
llmConfig/ │ │ │ │ ├── cache.ts │ │ │ │ ├── constants.ts │ │ │ │ ├── providers.ts │ │ │ │ ├── svc.ts │ │ │ │ ├── utils.ts │ │ │ │ └── visionLLM.ts │ │ │ ├── oauth/ │ │ │ │ └── api.ts │ │ │ ├── oauth_utils.ts │ │ │ ├── redirectSS.ts │ │ │ ├── search/ │ │ │ │ ├── interfaces.ts │ │ │ │ ├── streamingUtils.ts │ │ │ │ ├── utils.ts │ │ │ │ └── utilsSS.ts │ │ │ ├── sources.ts │ │ │ ├── streamingTTS.ts │ │ │ ├── swr-keys.ts │ │ │ ├── time.ts │ │ │ ├── tools/ │ │ │ │ ├── fetchTools.ts │ │ │ │ ├── interfaces.ts │ │ │ │ ├── mcpService.ts │ │ │ │ ├── mcpUtils.tsx │ │ │ │ └── openApiService.ts │ │ │ ├── types.ts │ │ │ ├── typingUtils.ts │ │ │ ├── updateSlackBotField.ts │ │ │ ├── urlBuilder.ts │ │ │ ├── user.test.ts │ │ │ ├── user.ts │ │ │ ├── userSS.ts │ │ │ ├── userSettings.ts │ │ │ ├── utils.test.ts │ │ │ ├── utils.ts │ │ │ ├── utilsSS.ts │ │ │ └── version.ts │ │ ├── providers/ │ │ │ ├── AppBackgroundProvider.tsx │ │ │ ├── AppProvider.tsx │ │ │ ├── AppSidebarProvider.tsx │ │ │ ├── CustomAnalyticsScript.tsx │ │ │ ├── DynamicMetadata.tsx │ │ │ ├── ProductGatingWrapper.tsx │ │ │ ├── ProjectsContext.tsx │ │ │ ├── QueryControllerProvider.tsx │ │ │ ├── SWRConfigProvider.tsx │ │ │ ├── SettingsProvider.tsx │ │ │ ├── ToastProvider.tsx │ │ │ ├── UserProvider.tsx │ │ │ ├── VoiceModeProvider.tsx │ │ │ └── __tests__/ │ │ │ └── ProjectsContext.test.tsx │ │ ├── proxy.ts │ │ ├── refresh-components/ │ │ │ ├── Attachment.stories.tsx │ │ │ ├── Attachment.tsx │ │ │ ├── Calendar.stories.tsx │ │ │ ├── Calendar.tsx │ │ │ ├── CharacterCount.stories.tsx │ │ │ ├── CharacterCount.tsx │ │ │ ├── Chip.stories.tsx │ │ │ ├── Chip.tsx │ │ │ ├── Code.stories.tsx │ │ │ ├── Code.tsx │ │ │ ├── Collapsible.stories.tsx │ │ │ ├── Collapsible.tsx │ │ │ ├── ColorSwatch.stories.tsx │ │ │ ├── ColorSwatch.tsx │ │ │ ├── ConnectionProviderIcon.stories.tsx │ │ │ ├── ConnectionProviderIcon.tsx │ │ │ ├── Divider.stories.tsx │ │ │ ├── Divider.tsx │ │ │ ├── EmptyMessage.stories.tsx │ │ │ ├── EmptyMessage.tsx │ │ │ ├── 
EnabledCount.stories.tsx │ │ │ ├── EnabledCount.tsx │ │ │ ├── FadingEdgeContainer.stories.tsx │ │ │ ├── FadingEdgeContainer.tsx │ │ │ ├── FrostedDiv.stories.tsx │ │ │ ├── FrostedDiv.tsx │ │ │ ├── InlineExternalLink.stories.tsx │ │ │ ├── InlineExternalLink.tsx │ │ │ ├── Logo.tsx │ │ │ ├── Modal.stories.tsx │ │ │ ├── Modal.tsx │ │ │ ├── OverflowDiv.stories.tsx │ │ │ ├── OverflowDiv.tsx │ │ │ ├── Popover.stories.tsx │ │ │ ├── Popover.tsx │ │ │ ├── PreviewImage.stories.tsx │ │ │ ├── PreviewImage.tsx │ │ │ ├── ScrollIndicatorDiv.stories.tsx │ │ │ ├── ScrollIndicatorDiv.tsx │ │ │ ├── Separator.stories.tsx │ │ │ ├── Separator.tsx │ │ │ ├── ShadowDiv.stories.tsx │ │ │ ├── ShadowDiv.tsx │ │ │ ├── SimpleCollapsible.stories.tsx │ │ │ ├── SimpleCollapsible.tsx │ │ │ ├── SimplePopover.stories.tsx │ │ │ ├── SimplePopover.tsx │ │ │ ├── SimpleTabs.stories.tsx │ │ │ ├── SimpleTabs.tsx │ │ │ ├── SimpleTooltip.stories.tsx │ │ │ ├── SimpleTooltip.tsx │ │ │ ├── Spacer.stories.tsx │ │ │ ├── Spacer.tsx │ │ │ ├── Tabs.stories.tsx │ │ │ ├── Tabs.tsx │ │ │ ├── TextSeparator.stories.tsx │ │ │ ├── TextSeparator.tsx │ │ │ ├── avatars/ │ │ │ │ ├── AgentAvatar.tsx │ │ │ │ ├── CustomAgentAvatar.stories.tsx │ │ │ │ ├── CustomAgentAvatar.tsx │ │ │ │ └── UserAvatar.tsx │ │ │ ├── buttons/ │ │ │ │ ├── AttachmentButton.stories.tsx │ │ │ │ ├── AttachmentButton.tsx │ │ │ │ ├── BackButton.stories.tsx │ │ │ │ ├── BackButton.tsx │ │ │ │ ├── Button.stories.tsx │ │ │ │ ├── Button.tsx │ │ │ │ ├── ButtonRenaming.stories.tsx │ │ │ │ ├── ButtonRenaming.tsx │ │ │ │ ├── CopyIconButton.stories.tsx │ │ │ │ ├── CopyIconButton.tsx │ │ │ │ ├── CreateButton.stories.tsx │ │ │ │ ├── CreateButton.tsx │ │ │ │ ├── IconButton.stories.tsx │ │ │ │ ├── IconButton.tsx │ │ │ │ ├── LineItem.stories.tsx │ │ │ │ ├── LineItem.tsx │ │ │ │ ├── SelectButton.stories.tsx │ │ │ │ ├── SelectButton.tsx │ │ │ │ ├── SquareButton.stories.tsx │ │ │ │ ├── SquareButton.tsx │ │ │ │ ├── Tag.stories.tsx │ │ │ │ ├── Tag.tsx │ │ │ │ └── source-tag/ │ │ │ 
│ ├── SourceTag.tsx │ │ │ │ ├── SourceTagDetailsCard.tsx │ │ │ │ ├── index.ts │ │ │ │ └── sourceTagUtils.ts │ │ │ ├── cards/ │ │ │ │ ├── Card.stories.tsx │ │ │ │ ├── Card.tsx │ │ │ │ └── index.ts │ │ │ ├── commandmenu/ │ │ │ │ ├── CommandMenu.stories.tsx │ │ │ │ ├── CommandMenu.test.tsx │ │ │ │ ├── CommandMenu.tsx │ │ │ │ └── types.ts │ │ │ ├── contexts/ │ │ │ │ └── ModalContext.tsx │ │ │ ├── form/ │ │ │ │ ├── CheckboxField.tsx │ │ │ │ ├── FieldContext.tsx │ │ │ │ ├── FormField.stories.tsx │ │ │ │ ├── FormField.tsx │ │ │ │ ├── FormikField.tsx │ │ │ │ ├── FormikFields.stories.tsx │ │ │ │ ├── InputDatePickerField.tsx │ │ │ │ ├── InputSelectField.tsx │ │ │ │ ├── InputTextAreaField.tsx │ │ │ │ ├── InputTypeInElementField.tsx │ │ │ │ ├── InputTypeInField.tsx │ │ │ │ ├── Label.stories.tsx │ │ │ │ ├── Label.tsx │ │ │ │ ├── LabeledCheckboxField.tsx │ │ │ │ ├── PasswordInputTypeInField.tsx │ │ │ │ ├── SwitchField.tsx │ │ │ │ └── types.ts │ │ │ ├── inputs/ │ │ │ │ ├── Checkbox.stories.tsx │ │ │ │ ├── Checkbox.test.tsx │ │ │ │ ├── Checkbox.tsx │ │ │ │ ├── InputAvatar.stories.tsx │ │ │ │ ├── InputAvatar.tsx │ │ │ │ ├── InputChipField.stories.tsx │ │ │ │ ├── InputChipField.tsx │ │ │ │ ├── InputComboBox/ │ │ │ │ │ ├── InputComboBox.stories.tsx │ │ │ │ │ ├── InputComboBox.test.tsx │ │ │ │ │ ├── InputComboBox.tsx │ │ │ │ │ ├── components/ │ │ │ │ │ │ ├── ComboBoxDropdown.tsx │ │ │ │ │ │ ├── OptionItem.tsx │ │ │ │ │ │ └── OptionsList.tsx │ │ │ │ │ ├── hooks.ts │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── types.ts │ │ │ │ │ └── utils/ │ │ │ │ │ ├── aria.ts │ │ │ │ │ └── validation.ts │ │ │ │ ├── InputDatePicker.stories.tsx │ │ │ │ ├── InputDatePicker.tsx │ │ │ │ ├── InputFile.stories.tsx │ │ │ │ ├── InputFile.tsx │ │ │ │ ├── InputImage.stories.tsx │ │ │ │ ├── InputImage.tsx │ │ │ │ ├── InputKeyValue.stories.tsx │ │ │ │ ├── InputKeyValue.tsx │ │ │ │ ├── InputNumber.stories.tsx │ │ │ │ ├── InputNumber.tsx │ │ │ │ ├── InputSearch.stories.tsx │ │ │ │ ├── InputSearch.tsx │ │ │ │ ├── 
InputSelect.stories.tsx │ │ │ │ ├── InputSelect.tsx │ │ │ │ ├── InputTextArea.stories.tsx │ │ │ │ ├── InputTextArea.tsx │ │ │ │ ├── InputTypeIn.stories.tsx │ │ │ │ ├── InputTypeIn.tsx │ │ │ │ ├── ListFieldInput.stories.tsx │ │ │ │ ├── ListFieldInput.tsx │ │ │ │ ├── PasswordInputTypeIn.stories.tsx │ │ │ │ ├── PasswordInputTypeIn.test.ts │ │ │ │ ├── PasswordInputTypeIn.tsx │ │ │ │ ├── Switch.stories.tsx │ │ │ │ ├── Switch.tsx │ │ │ │ └── styles.ts │ │ │ ├── layouts/ │ │ │ │ ├── ConfirmationModalLayout.stories.tsx │ │ │ │ └── ConfirmationModalLayout.tsx │ │ │ ├── loaders/ │ │ │ │ ├── SimpleLoader.stories.tsx │ │ │ │ └── SimpleLoader.tsx │ │ │ ├── messages/ │ │ │ │ ├── FieldMessage.stories.tsx │ │ │ │ ├── FieldMessage.tsx │ │ │ │ ├── InfoBlock.stories.tsx │ │ │ │ ├── InfoBlock.tsx │ │ │ │ ├── Message.stories.tsx │ │ │ │ └── Message.tsx │ │ │ ├── modals/ │ │ │ │ └── MemoriesModal.tsx │ │ │ ├── popovers/ │ │ │ │ ├── ActionsPopover/ │ │ │ │ │ ├── ActionLineItem.tsx │ │ │ │ │ ├── MCPLineItem.tsx │ │ │ │ │ ├── SwitchList.tsx │ │ │ │ │ └── index.tsx │ │ │ │ ├── FilePickerPopover.tsx │ │ │ │ ├── LLMPopover.test.tsx │ │ │ │ ├── LLMPopover.tsx │ │ │ │ └── interfaces.ts │ │ │ ├── skeletons/ │ │ │ │ ├── ChatSessionSkeleton.stories.tsx │ │ │ │ ├── ChatSessionSkeleton.tsx │ │ │ │ ├── SidebarTabSkeleton.stories.tsx │ │ │ │ └── SidebarTabSkeleton.tsx │ │ │ ├── texts/ │ │ │ │ ├── ExpandableTextDisplay.stories.tsx │ │ │ │ ├── ExpandableTextDisplay.tsx │ │ │ │ ├── Text.stories.tsx │ │ │ │ ├── Text.tsx │ │ │ │ ├── Truncated.stories.tsx │ │ │ │ └── Truncated.tsx │ │ │ └── tiles/ │ │ │ ├── ButtonTile.stories.tsx │ │ │ ├── ButtonTile.tsx │ │ │ ├── FileTile.stories.tsx │ │ │ └── FileTile.tsx │ │ ├── refresh-pages/ │ │ │ ├── AgentEditorPage.tsx │ │ │ ├── AgentsNavigationPage.tsx │ │ │ ├── AppPage.tsx │ │ │ ├── SettingsPage.tsx │ │ │ └── admin/ │ │ │ ├── AgentsPage/ │ │ │ │ ├── AgentRowActions.tsx │ │ │ │ ├── AgentsTable.tsx │ │ │ │ ├── interfaces.ts │ │ │ │ └── svc.ts │ │ │ ├── AgentsPage.tsx 
│ │ │ ├── ChatPreferencesPage.tsx │ │ │ ├── CodeInterpreterPage/ │ │ │ │ ├── index.tsx │ │ │ │ └── svc.ts │ │ │ ├── GroupsPage/ │ │ │ │ ├── CreateGroupPage.tsx │ │ │ │ ├── EditGroupPage.tsx │ │ │ │ ├── GroupCard.tsx │ │ │ │ ├── GroupsList.tsx │ │ │ │ ├── SharedGroupResources/ │ │ │ │ │ ├── ResourceContent.tsx │ │ │ │ │ ├── ResourcePopover.tsx │ │ │ │ │ ├── index.tsx │ │ │ │ │ └── interfaces.ts │ │ │ │ ├── TokenLimitSection.tsx │ │ │ │ ├── index.tsx │ │ │ │ ├── interfaces.ts │ │ │ │ ├── shared.tsx │ │ │ │ ├── svc.ts │ │ │ │ └── utils.ts │ │ │ ├── ImageGenerationPage/ │ │ │ │ ├── ImageGenerationContent.tsx │ │ │ │ ├── constants.ts │ │ │ │ ├── forms/ │ │ │ │ │ ├── AzureImageGenForm.tsx │ │ │ │ │ ├── ImageGenFormWrapper.tsx │ │ │ │ │ ├── OpenAIImageGenForm.tsx │ │ │ │ │ ├── VertexImageGenForm.tsx │ │ │ │ │ ├── getImageGenForm.tsx │ │ │ │ │ ├── index.ts │ │ │ │ │ └── types.ts │ │ │ │ ├── index.tsx │ │ │ │ └── svc.ts │ │ │ ├── LLMConfigurationPage.tsx │ │ │ ├── ServiceAccountsPage/ │ │ │ │ ├── ApiKeyFormModal.tsx │ │ │ │ ├── index.tsx │ │ │ │ ├── interfaces.ts │ │ │ │ └── svc.ts │ │ │ ├── UsersPage/ │ │ │ │ ├── EditUserModal.tsx │ │ │ │ ├── GroupsCell.tsx │ │ │ │ ├── InviteUsersModal.tsx │ │ │ │ ├── UserActionModals.tsx │ │ │ │ ├── UserFilters.tsx │ │ │ │ ├── UserRoleCell.tsx │ │ │ │ ├── UserRowActions.tsx │ │ │ │ ├── UsersSummary.tsx │ │ │ │ ├── UsersTable.tsx │ │ │ │ ├── index.tsx │ │ │ │ ├── interfaces.ts │ │ │ │ └── svc.ts │ │ │ ├── VoiceConfigurationPage.tsx │ │ │ └── WebSearchPage/ │ │ │ ├── WebProviderModalReducer.ts │ │ │ ├── WebProviderSetupModal.tsx │ │ │ ├── connectProviderFlow.ts │ │ │ ├── contentProviderUtils.ts │ │ │ ├── index.tsx │ │ │ ├── interfaces.ts │ │ │ ├── searchProviderUtils.ts │ │ │ └── svc.ts │ │ ├── sections/ │ │ │ ├── AppHealthBanner.tsx │ │ │ ├── Suggestions.tsx │ │ │ ├── actions/ │ │ │ │ ├── ActionCard.tsx │ │ │ │ ├── ActionCardContext.tsx │ │ │ │ ├── ActionCardHeader.tsx │ │ │ │ ├── Actions.tsx │ │ │ │ ├── MCPActionCard.tsx │ │ │ │ ├── 
MCPPageContent.tsx │ │ │ │ ├── OpenApiActionCard.tsx │ │ │ │ ├── OpenApiPageContent.tsx │ │ │ │ ├── PerUserAuthConfig.tsx │ │ │ │ ├── ToolItem.tsx │ │ │ │ ├── ToolsList.tsx │ │ │ │ ├── ToolsSection.tsx │ │ │ │ ├── modals/ │ │ │ │ │ ├── AddMCPServerModal.tsx │ │ │ │ │ ├── AddOpenAPIActionModal.tsx │ │ │ │ │ ├── DisconnectEntityModal.tsx │ │ │ │ │ ├── MCPAuthenticationModal.tsx │ │ │ │ │ └── OpenAPIAuthenticationModal.tsx │ │ │ │ └── skeleton/ │ │ │ │ ├── ActionCardSkeleton.tsx │ │ │ │ └── ToolItemSkeleton.tsx │ │ │ ├── admin/ │ │ │ │ ├── AdminListHeader.tsx │ │ │ │ └── ProviderCard.tsx │ │ │ ├── cards/ │ │ │ │ ├── AgentCard.tsx │ │ │ │ ├── DocumentSetCard.tsx │ │ │ │ ├── FileCard.tsx │ │ │ │ └── README.md │ │ │ ├── chat/ │ │ │ │ ├── ChatScrollContainer.tsx │ │ │ │ └── ChatUI.tsx │ │ │ ├── document-sidebar/ │ │ │ │ ├── ChatDocumentDisplay.tsx │ │ │ │ └── DocumentsSidebar.tsx │ │ │ ├── input/ │ │ │ │ ├── AppInputBar.tsx │ │ │ │ ├── MicrophoneButton.tsx │ │ │ │ └── SharedAppInputBar.tsx │ │ │ ├── knowledge/ │ │ │ │ ├── AgentKnowledgePane.tsx │ │ │ │ └── SourceHierarchyBrowser.tsx │ │ │ ├── modals/ │ │ │ │ ├── AgentViewerModal.tsx │ │ │ │ ├── FeedbackModal.tsx │ │ │ │ ├── NewTenantModal.tsx │ │ │ │ ├── PreviewModal/ │ │ │ │ │ ├── ExceptionTraceModal.tsx │ │ │ │ │ ├── FloatingFooter.tsx │ │ │ │ │ ├── PreviewModal.tsx │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── interfaces.ts │ │ │ │ │ └── variants/ │ │ │ │ │ ├── CodePreview.tsx │ │ │ │ │ ├── codeVariant.tsx │ │ │ │ │ ├── csvVariant.tsx │ │ │ │ │ ├── dataVariant.tsx │ │ │ │ │ ├── docxVariant.tsx │ │ │ │ │ ├── imageVariant.tsx │ │ │ │ │ ├── index.ts │ │ │ │ │ ├── markdownVariant.tsx │ │ │ │ │ ├── pdfVariant.tsx │ │ │ │ │ ├── shared.tsx │ │ │ │ │ ├── textVariant.tsx │ │ │ │ │ └── unsupportedVariant.tsx │ │ │ │ ├── ShareAgentModal.test.tsx │ │ │ │ ├── ShareAgentModal.tsx │ │ │ │ ├── ShareChatSessionModal.tsx │ │ │ │ └── llmConfig/ │ │ │ │ ├── AnthropicModal.tsx │ │ │ │ ├── AzureModal.tsx │ │ │ │ ├── BedrockModal.tsx │ │ │ │ ├── 
BifrostModal.tsx │ │ │ │ ├── CustomModal.test.tsx │ │ │ │ ├── CustomModal.tsx │ │ │ │ ├── LMStudioForm.tsx │ │ │ │ ├── LiteLLMProxyModal.tsx │ │ │ │ ├── OllamaModal.tsx │ │ │ │ ├── OpenAIModal.tsx │ │ │ │ ├── OpenRouterModal.tsx │ │ │ │ ├── VertexAIModal.tsx │ │ │ │ ├── getModal.tsx │ │ │ │ ├── shared.tsx │ │ │ │ ├── svc.ts │ │ │ │ └── utils.ts │ │ │ ├── onboarding/ │ │ │ │ ├── OnboardingFlow.tsx │ │ │ │ ├── __tests__/ │ │ │ │ │ └── onboardingReducer.test.ts │ │ │ │ ├── components/ │ │ │ │ │ ├── LLMProviderCard.tsx │ │ │ │ │ ├── NonAdminStep.tsx │ │ │ │ │ └── OnboardingHeader.tsx │ │ │ │ ├── constants.ts │ │ │ │ ├── forms/ │ │ │ │ │ └── getOnboardingForm.tsx │ │ │ │ ├── reducer.ts │ │ │ │ └── steps/ │ │ │ │ ├── FinalStep.tsx │ │ │ │ ├── LLMStep.tsx │ │ │ │ └── NameStep.tsx │ │ │ ├── settings/ │ │ │ │ └── Memories.tsx │ │ │ └── sidebar/ │ │ │ ├── AdminSidebar.tsx │ │ │ ├── AgentButton.tsx │ │ │ ├── AppSidebar.tsx │ │ │ ├── ChatButton.tsx │ │ │ ├── ChatSearchCommandMenu.tsx │ │ │ ├── CreateConnectorSidebar.tsx │ │ │ ├── NotificationsPopover.tsx │ │ │ ├── ProjectFolderButton.tsx │ │ │ ├── SidebarBody.tsx │ │ │ ├── SidebarSection.tsx │ │ │ ├── SidebarWrapper.tsx │ │ │ ├── StepSidebarWrapper.tsx │ │ │ ├── UpsertEmbeddingSidebar.tsx │ │ │ ├── UserAvatarPopover.tsx │ │ │ ├── chatSearchUtils.ts │ │ │ ├── constants.ts │ │ │ ├── sidebarUtils.ts │ │ │ └── useChatSearchOptimistic.ts │ │ └── types.ts │ ├── tailwind-themes/ │ │ └── tailwind.config.js │ ├── tailwind.config.js │ ├── tests/ │ │ ├── README.md │ │ ├── e2e/ │ │ │ ├── admin/ │ │ │ │ ├── admin_auth.setup.ts │ │ │ │ ├── admin_oauth_redirect_uri.spec.ts │ │ │ │ ├── admin_pages.spec.ts │ │ │ │ ├── code-interpreter/ │ │ │ │ │ └── code_interpreter.spec.ts │ │ │ │ ├── default-agent.spec.ts │ │ │ │ ├── disable_default_agent.spec.ts │ │ │ │ ├── discord-bot/ │ │ │ │ │ ├── admin-workflows.spec.ts │ │ │ │ │ ├── bot-config.spec.ts │ │ │ │ │ ├── channel-config.spec.ts │ │ │ │ │ ├── fixtures.ts │ │ │ │ │ └── guilds-list.spec.ts │ │ │ 
│ ├── ee_feature_redirect.spec.ts │ │ │ │ ├── groups/ │ │ │ │ │ ├── GroupsAdminPage.ts │ │ │ │ │ ├── fixtures.ts │ │ │ │ │ └── groups.spec.ts │ │ │ │ ├── image-generation/ │ │ │ │ │ ├── disconnect-provider.spec.ts │ │ │ │ │ └── image-generation-content.spec.ts │ │ │ │ ├── llm_provider_setup.spec.ts │ │ │ │ ├── oauth_config/ │ │ │ │ │ └── test_tool_oauth.spec.ts │ │ │ │ ├── scim/ │ │ │ │ │ ├── fixtures.ts │ │ │ │ │ └── scim.spec.ts │ │ │ │ ├── theme/ │ │ │ │ │ └── appearance_theme_settings.spec.ts │ │ │ │ ├── users/ │ │ │ │ │ ├── UsersAdminPage.ts │ │ │ │ │ ├── fixtures.ts │ │ │ │ │ └── users.spec.ts │ │ │ │ ├── voice/ │ │ │ │ │ └── disconnect-provider.spec.ts │ │ │ │ └── web-search/ │ │ │ │ ├── disconnect-provider.spec.ts │ │ │ │ ├── svc.ts │ │ │ │ ├── web_content_providers.spec.ts │ │ │ │ └── web_search_providers.spec.ts │ │ │ ├── agents/ │ │ │ │ ├── create_and_edit_agent.spec.ts │ │ │ │ ├── llm_provider_rbac.spec.ts │ │ │ │ └── user_file_attachment.spec.ts │ │ │ ├── auth/ │ │ │ │ ├── email_verification.spec.ts │ │ │ │ ├── login.spec.ts │ │ │ │ ├── password_managements.spec.ts │ │ │ │ ├── pat_management.spec.ts │ │ │ │ └── signup.spec.ts │ │ │ ├── chat/ │ │ │ │ ├── actions_popover.spec.ts │ │ │ │ ├── chat-search-command-menu.spec.ts │ │ │ │ ├── chat_message_rendering.spec.ts │ │ │ │ ├── chat_session_not_found.spec.ts │ │ │ │ ├── current_agent.spec.ts │ │ │ │ ├── default_agent.spec.ts │ │ │ │ ├── default_app_mode.spec.ts │ │ │ │ ├── file_preview_modal.spec.ts │ │ │ │ ├── input_focus_retention.spec.ts │ │ │ │ ├── live_agent.spec.ts │ │ │ │ ├── llm_ordering.spec.ts │ │ │ │ ├── llm_runtime_selection.spec.ts │ │ │ │ ├── message_edit_regenerate.spec.ts │ │ │ │ ├── message_feedback.spec.ts │ │ │ │ ├── project_files_visual_regression.spec.ts │ │ │ │ ├── scroll_behavior.spec.ts │ │ │ │ ├── share_chat.spec.ts │ │ │ │ └── welcome_page.spec.ts │ │ │ ├── connectors/ │ │ │ │ ├── federated_slack.spec.ts │ │ │ │ └── inlineFileManagement.spec.ts │ │ │ ├── constants.ts │ │ │ ├── 
fixtures/ │ │ │ │ ├── eeFeatures.ts │ │ │ │ ├── llmProvider.ts │ │ │ │ └── three_images.docx │ │ │ ├── global-setup.ts │ │ │ ├── mcp/ │ │ │ │ ├── default-agent-mcp.spec.ts │ │ │ │ └── mcp_oauth_flow.spec.ts │ │ │ ├── onboarding/ │ │ │ │ └── onboarding_flow.spec.ts │ │ │ ├── settings/ │ │ │ │ └── settings_pages.spec.ts │ │ │ └── utils/ │ │ │ ├── agentUtils.ts │ │ │ ├── auth.ts │ │ │ ├── chatActions.ts │ │ │ ├── chatStream.ts │ │ │ ├── dragUtils.ts │ │ │ ├── mcpServer.ts │ │ │ ├── onyxApiClient.ts │ │ │ ├── pageStateLogger.ts │ │ │ ├── theme.ts │ │ │ ├── tools.ts │ │ │ └── visualRegression.ts │ │ └── setup/ │ │ ├── fileMock.js │ │ ├── jest.setup.ts │ │ ├── llmProviderTestUtils.ts │ │ ├── mocks/ │ │ │ ├── README.md │ │ │ ├── components/ │ │ │ │ └── UserProvider.tsx │ │ │ └── cssMock.js │ │ └── test-utils.tsx │ ├── tsconfig.json │ ├── tsconfig.types.json │ └── types/ │ ├── assets.d.ts │ └── favicon-fetch.d.ts └── widget/ ├── .gitignore ├── README.md ├── index.html ├── package.json ├── src/ │ ├── assets/ │ │ └── logo.ts │ ├── config/ │ │ └── config.ts │ ├── index.ts │ ├── services/ │ │ ├── api-service.ts │ │ └── stream-parser.ts │ ├── styles/ │ │ ├── colors.ts │ │ ├── theme.ts │ │ └── widget-styles.ts │ ├── types/ │ │ ├── api-types.ts │ │ └── widget-types.ts │ ├── utils/ │ │ └── storage.ts │ └── widget.ts ├── tsconfig.json └── vite.config.ts ================================================ FILE CONTENTS ================================================ ================================================ FILE: .git-blame-ignore-revs ================================================ # Exclude these commits from git blame (e.g. mass reformatting). # These are ignored by GitHub automatically. 
# To enable this locally, run: # # git config blame.ignoreRevsFile .git-blame-ignore-revs 3134e5f840c12c8f32613ce520101a047c89dcc2 # refactor(whitespace): rm temporary react fragments (#7161) ed3f72bc75f3e3a9ae9e4d8cd38278f9c97e78b4 # refactor(whitespace): rm react fragment #7190 7b927e79c25f4ddfd18a067f489e122acd2c89de # chore(format): format files where `ruff` and `black` agree (#9339) ================================================ FILE: .github/CODEOWNERS ================================================ * @onyx-dot-app/onyx-core-team # Helm charts Owners /helm/ @justin-tahara # Web standards updates /web/STANDARDS.md @raunakab @Weves # Agent context files /CLAUDE.md @Weves /AGENTS.md @Weves # Beta cherry-pick workflow owners /.github/workflows/post-merge-beta-cherry-pick.yml @justin-tahara @jmelahman ================================================ FILE: .github/actionlint.yml ================================================ self-hosted-runner: # Labels of self-hosted runner in array of strings. labels: - extras=ecr-cache - extras=s3-cache - hdd=256 - runs-on - runner=1cpu-linux-arm64 - runner=1cpu-linux-x64 - runner=2cpu-linux-arm64 - runner=2cpu-linux-x64 - runner=4cpu-linux-arm64 - runner=4cpu-linux-x64 - runner=8cpu-linux-arm64 - runner=8cpu-linux-x64 - runner=16cpu-linux-arm64 - runner=16cpu-linux-x64 - ubuntu-slim # Currently in public preview - volume=40gb - volume=50gb # Configuration variables in array of strings defined in your repository or # organization. `null` means disabling configuration variables check. # Empty array means no configuration variable is allowed. config-variables: null # Configuration for file paths. The keys are glob patterns to match to file # paths relative to the repository root. The values are the configurations for # the file paths. Note that the path separator is always '/'. # The following configurations are available. # # "ignore" is an array of regular expression patterns. Matched error messages # are ignored. 
This is similar to the "-ignore" command line option. paths: # Glob pattern relative to the repository root for matching files. The path separator is always '/'. # This example configures any YAML file under the '.github/workflows/' directory. .github/workflows/**/*.{yml,yaml}: # TODO: These are real and should be fixed eventually. ignore: - 'shellcheck reported issue in this script: SC2038:.+' - 'shellcheck reported issue in this script: SC2046:.+' - 'shellcheck reported issue in this script: SC2086:.+' - 'shellcheck reported issue in this script: SC2193:.+' ================================================ FILE: .github/actions/build-backend-image/action.yml ================================================ name: "Build Backend Image" description: "Builds and pushes the backend Docker image with cache reuse" inputs: runs-on-ecr-cache: description: "ECR cache registry from runs-on/action" required: true ref-name: description: "Git ref name used for cache suffix fallback" required: true pr-number: description: "Optional PR number for cache suffix" required: false default: "" github-sha: description: "Commit SHA used for cache keys" required: true run-id: description: "GitHub run ID used in output image tag" required: true docker-username: description: "Docker Hub username" required: true docker-token: description: "Docker Hub token" required: true docker-no-cache: description: "Set to 'true' to disable docker build cache" required: false default: "false" runs: using: "composite" steps: - name: Format branch name for cache id: format-branch shell: bash env: PR_NUMBER: ${{ inputs.pr-number }} REF_NAME: ${{ inputs.ref-name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT" - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # 
ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ inputs.docker-username }} password: ${{ inputs.docker-token }} - name: Build and push Backend Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend file: ./backend/Dockerfile push: true tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-backend-${{ inputs.run-id }} cache-from: | type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }} type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache type=registry,ref=onyxdotapp/onyx-backend:latest cache-to: | type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ inputs.github-sha }},mode=max type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ inputs.runs-on-ecr-cache }}:backend-cache,mode=max no-cache: ${{ inputs.docker-no-cache == 'true' }} ================================================ FILE: .github/actions/build-integration-image/action.yml ================================================ name: "Build Integration Image" description: "Builds and pushes the integration test image with docker bake" inputs: runs-on-ecr-cache: description: "ECR cache registry from runs-on/action" required: true ref-name: description: "Git ref name used for cache suffix fallback" required: true pr-number: description: "Optional PR number for cache suffix" required: false default: "" github-sha: description: "Commit SHA used for cache keys" required: true run-id: description: "GitHub run ID used in output image tag" required: true docker-username: description: "Docker Hub username" required: true docker-token: 
description: "Docker Hub token" required: true runs: using: "composite" steps: - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ inputs.docker-username }} password: ${{ inputs.docker-token }} - name: Format branch name for cache id: format-branch shell: bash env: PR_NUMBER: ${{ inputs.pr-number }} REF_NAME: ${{ inputs.ref-name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT" - name: Build and push integration test image with Docker Bake shell: bash env: RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }} INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }} TAG: nightly-llm-it-${{ inputs.run-id }} CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }} HEAD_SHA: ${{ inputs.github-sha }} run: | docker buildx bake --push \ --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \ --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \ --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \ --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \ --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \ --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \ --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \ --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \ --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \ --set 
integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \ --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \ --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \ --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \ integration ================================================ FILE: .github/actions/build-model-server-image/action.yml ================================================ name: "Build Model Server Image" description: "Builds and pushes the model server Docker image with cache reuse" inputs: runs-on-ecr-cache: description: "ECR cache registry from runs-on/action" required: true ref-name: description: "Git ref name used for cache suffix fallback" required: true pr-number: description: "Optional PR number for cache suffix" required: false default: "" github-sha: description: "Commit SHA used for cache keys" required: true run-id: description: "GitHub run ID used in output image tag" required: true docker-username: description: "Docker Hub username" required: true docker-token: description: "Docker Hub token" required: true runs: using: "composite" steps: - name: Format branch name for cache id: format-branch shell: bash env: PR_NUMBER: ${{ inputs.pr-number }} REF_NAME: ${{ inputs.ref-name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> "$GITHUB_OUTPUT" - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ inputs.docker-username }} password: ${{ inputs.docker-token }} - name: Build and push 
Model Server Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend file: ./backend/Dockerfile.model_server push: true tags: ${{ inputs.runs-on-ecr-cache }}:nightly-llm-it-model-server-${{ inputs.run-id }} cache-from: | type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }} type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache type=registry,ref=onyxdotapp/onyx-model-server:latest cache-to: | type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ inputs.github-sha }},mode=max type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ inputs.runs-on-ecr-cache }}:model-server-cache,mode=max ================================================ FILE: .github/actions/run-nightly-provider-chat-test/action.yml ================================================ name: "Run Nightly Provider Chat Test" description: "Starts required compose services and runs nightly provider integration test" inputs: provider: description: "Provider slug for NIGHTLY_LLM_PROVIDER" required: true models: description: "Comma-separated model list for NIGHTLY_LLM_MODELS" required: true provider-api-key: description: "API key for NIGHTLY_LLM_API_KEY" required: false default: "" strict: description: "String true/false for NIGHTLY_LLM_STRICT" required: true api-base: description: "Optional NIGHTLY_LLM_API_BASE" required: false default: "" api-version: description: "Optional NIGHTLY_LLM_API_VERSION" required: false default: "" deployment-name: description: "Optional NIGHTLY_LLM_DEPLOYMENT_NAME" required: false default: "" custom-config-json: description: "Optional NIGHTLY_LLM_CUSTOM_CONFIG_JSON" required: false default: "" runs-on-ecr-cache: 
description: "ECR cache registry from runs-on/action" required: true run-id: description: "GitHub run ID used in image tags" required: true docker-username: description: "Docker Hub username" required: true docker-token: description: "Docker Hub token" required: true runs: using: "composite" steps: - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ inputs.docker-username }} password: ${{ inputs.docker-token }} - name: Create .env file for Docker Compose shell: bash env: ECR_CACHE: ${{ inputs.runs-on-ecr-cache }} RUN_ID: ${{ inputs.run-id }} run: | cat <<EOF2 > deployment/docker_compose/.env COMPOSE_PROFILES=s3-filestore ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true LICENSE_ENFORCEMENT_ENABLED=false AUTH_TYPE=basic POSTGRES_POOL_PRE_PING=true POSTGRES_USE_NULL_POOL=true REQUIRE_EMAIL_VERIFICATION=false DISABLE_TELEMETRY=true INTEGRATION_TESTS_MODE=true AUTO_LLM_UPDATE_INTERVAL_SECONDS=10 AWS_REGION_NAME=us-west-2 ONYX_BACKEND_IMAGE=${ECR_CACHE}:nightly-llm-it-backend-${RUN_ID} ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:nightly-llm-it-model-server-${RUN_ID} EOF2 - name: Start Docker containers shell: bash run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d --wait \ relational_db \ index \ cache \ minio \ api_server \ inference_model_server - name: Run nightly provider integration test uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3 env: MODELS: ${{ inputs.models }} NIGHTLY_LLM_PROVIDER: ${{ inputs.provider }} NIGHTLY_LLM_API_KEY: ${{ inputs.provider-api-key }} NIGHTLY_LLM_API_BASE: ${{ inputs.api-base }} NIGHTLY_LLM_API_VERSION: ${{ inputs.api-version }} NIGHTLY_LLM_DEPLOYMENT_NAME: ${{ inputs.deployment-name }} NIGHTLY_LLM_CUSTOM_CONFIG_JSON: ${{ inputs.custom-config-json }} NIGHTLY_LLM_STRICT: ${{ inputs.strict }} RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }} RUN_ID: ${{
inputs.run-id }} with: timeout_minutes: 20 max_attempts: 2 retry_wait_seconds: 10 command: | docker run --rm --network onyx_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ -e POSTGRES_USER=postgres \ -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ -e DB_READONLY_USER=db_readonly_user \ -e DB_READONLY_PASSWORD=password \ -e POSTGRES_POOL_PRE_PING=true \ -e POSTGRES_USE_NULL_POOL=true \ -e VESPA_HOST=index \ -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e TEST_WEB_HOSTNAME=test-runner \ -e AWS_REGION_NAME=us-west-2 \ -e NIGHTLY_LLM_PROVIDER="${NIGHTLY_LLM_PROVIDER}" \ -e NIGHTLY_LLM_MODELS="${MODELS}" \ -e NIGHTLY_LLM_API_KEY="${NIGHTLY_LLM_API_KEY}" \ -e NIGHTLY_LLM_API_BASE="${NIGHTLY_LLM_API_BASE}" \ -e NIGHTLY_LLM_API_VERSION="${NIGHTLY_LLM_API_VERSION}" \ -e NIGHTLY_LLM_DEPLOYMENT_NAME="${NIGHTLY_LLM_DEPLOYMENT_NAME}" \ -e NIGHTLY_LLM_CUSTOM_CONFIG_JSON="${NIGHTLY_LLM_CUSTOM_CONFIG_JSON}" \ -e NIGHTLY_LLM_STRICT="${NIGHTLY_LLM_STRICT}" \ ${RUNS_ON_ECR_CACHE}:nightly-llm-it-${RUN_ID} \ /app/tests/integration/tests/llm_workflows/test_nightly_provider_chat_workflow.py ================================================ FILE: .github/actions/setup-playwright/action.yml ================================================ name: "Setup Playwright" description: "Sets up Playwright and system deps (assumes Python and Playwright are installed)" runs: using: "composite" steps: - name: Cache playwright cache uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4 with: path: ~/.cache/ms-playwright key: ${{ runner.os }}-${{ runner.arch }}-playwright-${{ hashFiles('backend/requirements/default.txt') }} restore-keys: | ${{ runner.os }}-${{ runner.arch }}-playwright- - name: Install playwright shell: bash run: | playwright install chromium --with-deps ================================================ FILE: .github/actions/setup-python-and-install-dependencies/action.yml ================================================ 
name: "Setup Python and Install Dependencies" description: "Sets up Python with uv and installs deps" inputs: requirements: description: "Newline-separated list of requirement files to install (relative to repo root)" required: true runs: using: "composite" steps: - name: Compute requirements hash id: req-hash shell: bash env: REQUIREMENTS: ${{ inputs.requirements }} run: | # Hash the contents of the specified requirement files hash="" while IFS= read -r req; do if [ -n "$req" ] && [ -f "$req" ]; then hash="$hash$(sha256sum "$req")" fi done <<< "$REQUIREMENTS" echo "hash=$(echo "$hash" | sha256sum | cut -d' ' -f1)" >> "$GITHUB_OUTPUT" # NOTE: This comes before Setup uv since clean-ups run in reverse chronological order # such that Setup uv's prune-cache is able to prune the cache before we upload. - name: Cache uv cache directory uses: runs-on/cache@50350ad4242587b6c8c2baa2e740b1bc11285ff4 # ratchet:runs-on/cache@v4 with: path: ~/.cache/uv key: ${{ runner.os }}-uv-${{ steps.req-hash.outputs.hash }} restore-keys: | ${{ runner.os }}-uv- - name: Setup uv uses: astral-sh/setup-uv@ed21f2f24f8dd64503750218de024bcf64c7250a # ratchet:astral-sh/setup-uv@v7 with: version: "0.9.9" # TODO: Enable caching once there is a uv.lock file checked in. 
# with: # enable-cache: true - name: Setup Python uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # ratchet:actions/setup-python@v5 with: python-version: "3.11" - name: Create virtual environment shell: bash env: VENV_DIR: ${{ runner.temp }}/venv run: | # zizmor: ignore[github-env] uv venv "$VENV_DIR" # Validate path before adding to GITHUB_PATH to prevent code injection if [ -d "$VENV_DIR/bin" ]; then realpath "$VENV_DIR/bin" >> "$GITHUB_PATH" else echo "Error: $VENV_DIR/bin does not exist" exit 1 fi - name: Install Python dependencies with uv shell: bash env: REQUIREMENTS: ${{ inputs.requirements }} run: | # Build the uv pip install command with each requirement file as array elements cmd=("uv" "pip" "install") while IFS= read -r req; do # Skip empty lines if [ -n "$req" ]; then cmd+=("-r" "$req") fi done <<< "$REQUIREMENTS" echo "Running: ${cmd[*]}" "${cmd[@]}" ================================================ FILE: .github/actions/slack-notify/action.yml ================================================ name: "Slack Notify" description: "Sends a Slack notification for workflow events" inputs: webhook-url: description: "Slack webhook URL (can also use SLACK_WEBHOOK_URL env var)" required: false details: description: "Additional message body content" required: false failed-jobs: description: "Deprecated alias for details" required: false mention: description: "GitHub username to resolve to a Slack @-mention. Replaces {mention} in details." 
required: false title: description: "Title for the notification" required: false default: "🚨 Workflow Failed" ref-name: description: "Git ref name (tag/branch)" required: false runs: using: "composite" steps: - name: Send Slack notification shell: bash env: SLACK_WEBHOOK_URL: ${{ inputs.webhook-url }} DETAILS: ${{ inputs.details }} FAILED_JOBS: ${{ inputs.failed-jobs }} MENTION_USER: ${{ inputs.mention }} TITLE: ${{ inputs.title }} REF_NAME: ${{ inputs.ref-name }} REPO: ${{ github.repository }} WORKFLOW: ${{ github.workflow }} RUN_NUMBER: ${{ github.run_number }} RUN_ID: ${{ github.run_id }} SERVER_URL: ${{ github.server_url }} GITHUB_REF_NAME: ${{ github.ref_name }} run: | if [ -z "$SLACK_WEBHOOK_URL" ]; then echo "webhook-url input or SLACK_WEBHOOK_URL env var is not set, skipping notification" exit 0 fi # Build workflow URL WORKFLOW_URL="${SERVER_URL}/${REPO}/actions/runs/${RUN_ID}" # Use ref_name from input or fall back to github.ref_name if [ -z "$REF_NAME" ]; then REF_NAME="$GITHUB_REF_NAME" fi if [ -z "$DETAILS" ]; then DETAILS="$FAILED_JOBS" fi # Resolve {mention} placeholder if a GitHub username was provided. # Looks up the username in user-mappings.json (co-located with this action) # and replaces {mention} with <@SLACK_ID> for a Slack @-mention. # Falls back to the plain GitHub username if not found in the mapping. 
if [ -n "$MENTION_USER" ]; then MAPPINGS_FILE="${GITHUB_ACTION_PATH}/user-mappings.json" slack_id="$(jq -r --arg gh "$MENTION_USER" 'to_entries[] | select(.value | ascii_downcase == ($gh | ascii_downcase)) | .key' "$MAPPINGS_FILE" 2>/dev/null | head -1)" if [ -n "$slack_id" ]; then mention_text="<@${slack_id}>" else mention_text="${MENTION_USER}" fi DETAILS="${DETAILS//\{mention\}/$mention_text}" TITLE="${TITLE//\{mention\}/}" else DETAILS="${DETAILS//\{mention\}/}" TITLE="${TITLE//\{mention\}/}" fi normalize_multiline() { printf '%s' "$1" | awk 'BEGIN { ORS=""; first=1 } { if (!first) printf "\\n"; printf "%s", $0; first=0 }' } DETAILS="$(normalize_multiline "$DETAILS")" REF_NAME="$(normalize_multiline "$REF_NAME")" TITLE="$(normalize_multiline "$TITLE")" # Escape JSON special characters escape_json() { local input="$1" # Escape backslashes first (but preserve \n sequences) # Protect \n sequences temporarily input=$(printf '%s' "$input" | sed 's/\\n/\x01NL\x01/g') # Escape remaining backslashes input=$(printf '%s' "$input" | sed 's/\\/\\\\/g') # Restore \n sequences (single backslash, will be correct in JSON) input=$(printf '%s' "$input" | sed 's/\x01NL\x01/\\n/g') # Escape quotes printf '%s' "$input" | sed 's/"/\\"/g' } REF_NAME_ESC=$(escape_json "$REF_NAME") DETAILS_ESC=$(escape_json "$DETAILS") WORKFLOW_URL_ESC=$(escape_json "$WORKFLOW_URL") TITLE_ESC=$(escape_json "$TITLE") # Build JSON payload piece by piece # Note: DETAILS_ESC already contains \n sequences that should remain as \n in JSON PAYLOAD="{" PAYLOAD="${PAYLOAD}\"text\":\"${TITLE_ESC}\"," PAYLOAD="${PAYLOAD}\"blocks\":[{" PAYLOAD="${PAYLOAD}\"type\":\"header\"," PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"plain_text\",\"text\":\"${TITLE_ESC}\"}" PAYLOAD="${PAYLOAD}},{" PAYLOAD="${PAYLOAD}\"type\":\"section\"," PAYLOAD="${PAYLOAD}\"fields\":[" if [ -n "$REF_NAME" ]; then PAYLOAD="${PAYLOAD}{\"type\":\"mrkdwn\",\"text\":\"*Ref:*\\n${REF_NAME_ESC}\"}," fi 
PAYLOAD="${PAYLOAD}{\"type\":\"mrkdwn\",\"text\":\"*Run ID:*\\n#${RUN_NUMBER}\"}" PAYLOAD="${PAYLOAD}]" PAYLOAD="${PAYLOAD}}" if [ -n "$DETAILS" ]; then PAYLOAD="${PAYLOAD},{" PAYLOAD="${PAYLOAD}\"type\":\"section\"," PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"mrkdwn\",\"text\":\"${DETAILS_ESC}\"}" PAYLOAD="${PAYLOAD}}" fi PAYLOAD="${PAYLOAD},{" PAYLOAD="${PAYLOAD}\"type\":\"actions\"," PAYLOAD="${PAYLOAD}\"elements\":[{" PAYLOAD="${PAYLOAD}\"type\":\"button\"," PAYLOAD="${PAYLOAD}\"text\":{\"type\":\"plain_text\",\"text\":\"View Workflow Run\"}," PAYLOAD="${PAYLOAD}\"url\":\"${WORKFLOW_URL_ESC}\"" PAYLOAD="${PAYLOAD}}]" PAYLOAD="${PAYLOAD}}" PAYLOAD="${PAYLOAD}]" PAYLOAD="${PAYLOAD}}" curl -X POST -H 'Content-type: application/json' \ --data "$PAYLOAD" \ "$SLACK_WEBHOOK_URL" ================================================ FILE: .github/actions/slack-notify/user-mappings.json ================================================ { "U05SAGZPEA1": "yuhongsun96", "U05SAH6UGUD": "Weves", "U07PWEQB7A5": "evan-onyx", "U07V1SM68KF": "joachim-danswer", "U08JZ9N3QNN": "raunakab", "U08L24NCLJE": "Subash-Mohan", "U090B9M07B2": "wenxi-onyx", "U094RASDP0Q": "duo-onyx", "U096L8ZQ85B": "justin-tahara", "U09AHV8UBQX": "jessicasingh7", "U09KAL5T3C2": "nmgarza5", "U09KPGVQ70R": "acaprau", "U09QR8KTSJH": "rohoswagger", "U09RB4NTXA4": "jmelahman", "U0A6K9VCY6A": "Danelegend", "U0AGC4KH71A": "Bo-Onyx" } ================================================ FILE: .github/dependabot.yml ================================================ version: 2 updates: - package-ecosystem: "github-actions" directory: "/" schedule: interval: "weekly" cooldown: default-days: 7 open-pull-requests-limit: 3 assignees: - "jmelahman" labels: - "dependabot:actions" - package-ecosystem: "pip" directory: "/backend" schedule: interval: "weekly" cooldown: default-days: 7 open-pull-requests-limit: 3 assignees: - "jmelahman" labels: - "dependabot:python" ================================================ FILE: 
.github/pull_request_template.md ================================================ ## Description ## How Has This Been Tested? ## Additional Options - [ ] [Optional] Please cherry-pick this PR to the latest release version. - [ ] [Optional] Override Linear Check ================================================ FILE: .github/runs-on.yml ================================================ _extend: .github-private ================================================ FILE: .github/workflows/deployment.yml ================================================ name: Build and Push Docker Images on Tag on: push: tags: - "*" workflow_dispatch: # Set restrictive default permissions for all jobs. Jobs that need more permissions # should explicitly declare them. permissions: # Required for OIDC authentication with AWS id-token: write # zizmor: ignore[excessive-permissions] env: EDGE_TAG: ${{ startsWith(github.ref_name, 'nightly-latest') }} jobs: # Determine which components to build based on the tag determine-builds: # NOTE: GitHub-hosted runners have about 20s faster queue times and are preferred here.
runs-on: ubuntu-slim timeout-minutes: 90 outputs: build-desktop: ${{ steps.check.outputs.build-desktop }} build-web: ${{ steps.check.outputs.build-web }} build-web-cloud: ${{ steps.check.outputs.build-web-cloud }} build-backend: ${{ steps.check.outputs.build-backend }} build-backend-craft: ${{ steps.check.outputs.build-backend-craft }} build-model-server: ${{ steps.check.outputs.build-model-server }} is-cloud-tag: ${{ steps.check.outputs.is-cloud-tag }} is-beta: ${{ steps.check.outputs.is-beta }} is-beta-standalone: ${{ steps.check.outputs.is-beta-standalone }} is-latest: ${{ steps.check.outputs.is-latest }} is-test-run: ${{ steps.check.outputs.is-test-run }} sanitized-tag: ${{ steps.check.outputs.sanitized-tag }} short-sha: ${{ steps.check.outputs.short-sha }} steps: - name: Checkout (for git tags) uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false fetch-depth: 0 fetch-tags: true - name: Setup uv uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7 with: version: "0.9.9" enable-cache: false - name: Check which components to build and version info id: check env: EVENT_NAME: ${{ github.event_name }} run: | set -eo pipefail TAG="${GITHUB_REF_NAME}" # Sanitize tag name by replacing slashes with hyphens (for Docker tag compatibility) SANITIZED_TAG=$(echo "$TAG" | tr '/' '-') SHORT_SHA="${GITHUB_SHA::7}" # Initialize all flags to false IS_CLOUD=false IS_NIGHTLY=false IS_VERSION_TAG=false IS_STABLE=false IS_BETA=false IS_BETA_STANDALONE=false IS_LATEST=false IS_PROD_TAG=false IS_TEST_RUN=false BUILD_DESKTOP=false BUILD_WEB=false BUILD_WEB_CLOUD=false BUILD_BACKEND=true BUILD_BACKEND_CRAFT=false BUILD_MODEL_SERVER=true # Determine tag type based on pattern matching (do regex checks once) if [[ "$TAG" == *cloud* ]]; then IS_CLOUD=true fi if [[ "$TAG" == nightly* ]]; then IS_NIGHTLY=true fi if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+ ]]; then 
IS_VERSION_TAG=true fi if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then IS_STABLE=true fi if [[ "$TAG" =~ ^v[0-9]+\.[0-9]+\.[0-9]+-beta(\.[0-9]+)?$ ]]; then IS_BETA=true fi # Determine what to build based on tag type if [[ "$IS_CLOUD" == "true" ]]; then BUILD_WEB_CLOUD=true else BUILD_WEB=true # Only build desktop for semver tags (excluding beta) if [[ "$IS_VERSION_TAG" == "true" ]] && [[ "$IS_BETA" != "true" ]]; then BUILD_DESKTOP=true fi fi # Standalone beta check (for backend/model-server): beta version tags that are not cloud tags if [[ "$IS_BETA" == "true" ]] && [[ "$IS_CLOUD" != "true" ]]; then IS_BETA_STANDALONE=true fi # Determine if this tag should get the "latest" Docker tag. # Only the highest semver stable tag (vX.Y.Z exactly) gets "latest". if [[ "$IS_STABLE" == "true" ]]; then HIGHEST_STABLE=$(uv run --no-sync --with onyx-devtools ods latest-stable-tag) || { echo "::error::Failed to determine highest stable tag via 'ods latest-stable-tag'" exit 1 } if [[ "$TAG" == "$HIGHEST_STABLE" ]]; then IS_LATEST=true fi fi # Build craft-latest backend alongside the regular latest.
if [[ "$IS_LATEST" == "true" ]]; then BUILD_BACKEND_CRAFT=true fi # Determine if this is a production tag # Production tags are: version tags (v1.2.3*) or nightly tags if [[ "$IS_VERSION_TAG" == "true" ]] || [[ "$IS_NIGHTLY" == "true" ]]; then IS_PROD_TAG=true fi # Determine if this is a test run (workflow_dispatch on non-production ref) if [[ "$EVENT_NAME" == "workflow_dispatch" ]] && [[ "$IS_PROD_TAG" != "true" ]]; then IS_TEST_RUN=true fi { echo "build-desktop=$BUILD_DESKTOP" echo "build-web=$BUILD_WEB" echo "build-web-cloud=$BUILD_WEB_CLOUD" echo "build-backend=$BUILD_BACKEND" echo "build-backend-craft=$BUILD_BACKEND_CRAFT" echo "build-model-server=$BUILD_MODEL_SERVER" echo "is-cloud-tag=$IS_CLOUD" echo "is-beta=$IS_BETA" echo "is-beta-standalone=$IS_BETA_STANDALONE" echo "is-latest=$IS_LATEST" echo "is-test-run=$IS_TEST_RUN" echo "sanitized-tag=$SANITIZED_TAG" echo "short-sha=$SHORT_SHA" } >> "$GITHUB_OUTPUT" check-version-tag: runs-on: ubuntu-slim timeout-minutes: 10 if: ${{ !startsWith(github.ref_name, 'nightly-latest') && github.event_name != 'workflow_dispatch' }} steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false fetch-depth: 0 - name: Setup uv uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7 with: version: "0.9.9" # NOTE: This isn't caching much and zizmor suggests this could be poisoned, so disable. 
enable-cache: false - name: Validate tag is versioned correctly run: | uv run --no-sync --with release-tag tag --check notify-slack-on-tag-check-failure: needs: - check-version-tag if: always() && needs.check-version-tag.result == 'failure' && github.event_name != 'workflow_dispatch' runs-on: ubuntu-slim timeout-minutes: 10 environment: release steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Send Slack notification uses: ./.github/actions/slack-notify with: webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }} failed-jobs: "• check-version-tag" title: "🚨 Version Tag Check Failed" ref-name: ${{ github.ref_name }} # Create GitHub release first, before desktop builds start. # This ensures all desktop matrix jobs upload to the same release instead of # racing to create duplicate releases. create-release: needs: determine-builds if: needs.determine-builds.outputs.build-desktop == 'true' runs-on: ubuntu-slim timeout-minutes: 10 permissions: contents: write outputs: release-id: ${{ steps.create-release.outputs.id }} steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Determine release tag id: release-tag env: IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }} SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }} run: | if [ "${IS_TEST_RUN}" == "true" ]; then echo "tag=v0.0.0-dev+${SHORT_SHA}" >> "$GITHUB_OUTPUT" else echo "tag=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT" fi - name: Create GitHub Release id: create-release uses: softprops/action-gh-release@da05d552573ad5aba039eaac05058a918a7bf631 # ratchet:softprops/action-gh-release@v2 with: tag_name: ${{ steps.release-tag.outputs.tag }} name: ${{ steps.release-tag.outputs.tag }} body: "See the assets to download this version and install." 
draft: true prerelease: false env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} build-desktop: needs: - determine-builds - create-release if: needs.determine-builds.outputs.build-desktop == 'true' permissions: id-token: write contents: write actions: read strategy: fail-fast: false matrix: include: - platform: "macos-latest" # Build a universal image for macOS. args: "--target universal-apple-darwin" - platform: "ubuntu-24.04" args: "--bundles deb,rpm" - platform: "ubuntu-24.04-arm" # Only available in public repos. args: "--bundles deb,rpm" - platform: "windows-latest" args: "" runs-on: ${{ matrix.platform }} timeout-minutes: 90 environment: release steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2 with: # NOTE: persist-credentials is needed for tauri-action to upload assets to GitHub releases. persist-credentials: true # zizmor: ignore[artipacked] - name: Configure AWS credentials if: startsWith(matrix.platform, 'macos-') uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets if: startsWith(matrix.platform, 'macos-') uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | APPLE_ID, deploy/apple-id APPLE_PASSWORD, deploy/apple-password APPLE_CERTIFICATE, deploy/apple-certificate APPLE_CERTIFICATE_PASSWORD, deploy/apple-certificate-password KEYCHAIN_PASSWORD, deploy/keychain-password APPLE_TEAM_ID, deploy/apple-team-id parse-json-secrets: true - name: install dependencies (ubuntu only) if: startsWith(matrix.platform, 'ubuntu-') run: | sudo apt-get update sudo apt-get install -y \ build-essential \ libglib2.0-dev \ libgirepository1.0-dev \ libgtk-3-dev \ libjavascriptcoregtk-4.1-dev \ libwebkit2gtk-4.1-dev \ libayatana-appindicator3-dev \ gobject-introspection \ pkg-config \ curl \ xdg-utils - name: setup node uses: 
actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6.3.0 with: node-version: 24 package-manager-cache: false - name: install Rust stable uses: dtolnay/rust-toolchain@6d9817901c499d6b02debbb57edb38d33daa680b # zizmor: ignore[impostor-commit] with: # These targets are only used on macOS runners, so they are set conditionally to slightly speed up Windows and Linux builds. targets: ${{ matrix.platform == 'macos-latest' && 'aarch64-apple-darwin,x86_64-apple-darwin' || '' }} - name: install frontend dependencies working-directory: ./desktop run: npm install - name: Inject version (Unix) if: runner.os != 'Windows' working-directory: ./desktop env: SHORT_SHA: ${{ needs.determine-builds.outputs.short-sha }} IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }} run: | if [ "${IS_TEST_RUN}" == "true" ]; then VERSION="0.0.0-dev+${SHORT_SHA}" else VERSION="${GITHUB_REF_NAME#v}" fi echo "Injecting version: $VERSION" # Update Cargo.toml sed "s/^version = .*/version = \"$VERSION\"/" src-tauri/Cargo.toml > src-tauri/Cargo.toml.tmp mv src-tauri/Cargo.toml.tmp src-tauri/Cargo.toml # Update tauri.conf.json jq --arg v "$VERSION" '.version = $v' src-tauri/tauri.conf.json > src-tauri/tauri.conf.json.tmp mv src-tauri/tauri.conf.json.tmp src-tauri/tauri.conf.json # Update package.json jq --arg v "$VERSION" '.version = $v' package.json > package.json.tmp mv package.json.tmp package.json echo "Versions set to: $VERSION" - name: Inject version (Windows) if: runner.os == 'Windows' working-directory: ./desktop shell: pwsh env: IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }} run: | # Windows MSI requires numeric-only build metadata, so we skip the SHA suffix if ($env:IS_TEST_RUN -eq "true") { $VERSION = "0.0.0" } else { # Strip 'v' prefix and any pre-release suffix (e.g., -beta.13) for MSI compatibility $VERSION = "$env:GITHUB_REF_NAME" -replace '^v', '' -replace '-.*$', '' } Write-Host "Injecting version: $VERSION" # Update Cargo.toml $cargo =
Get-Content src-tauri/Cargo.toml -Raw $cargo = $cargo -replace '(?m)^version = .*', "version = `"$VERSION`"" Set-Content src-tauri/Cargo.toml $cargo -NoNewline # Update tauri.conf.json $json = Get-Content src-tauri/tauri.conf.json | ConvertFrom-Json $json.version = $VERSION $json | ConvertTo-Json -Depth 100 | Set-Content src-tauri/tauri.conf.json # Update package.json $pkg = Get-Content package.json | ConvertFrom-Json $pkg.version = $VERSION $pkg | ConvertTo-Json -Depth 100 | Set-Content package.json Write-Host "Versions set to: $VERSION" - name: Import Apple Developer Certificate if: startsWith(matrix.platform, 'macos-') run: | echo "$APPLE_CERTIFICATE" | base64 --decode > certificate.p12 security create-keychain -p "$KEYCHAIN_PASSWORD" build.keychain security default-keychain -s build.keychain security unlock-keychain -p "$KEYCHAIN_PASSWORD" build.keychain security set-keychain-settings -t 3600 -u build.keychain security import certificate.p12 -k build.keychain -P "$APPLE_CERTIFICATE_PASSWORD" -T /usr/bin/codesign security set-key-partition-list -S apple-tool:,apple:,codesign: -s -k "$KEYCHAIN_PASSWORD" build.keychain security find-identity -v -p codesigning build.keychain - name: Verify Certificate if: startsWith(matrix.platform, 'macos-') run: | CERT_INFO=$(security find-identity -v -p codesigning build.keychain | grep -E "(Developer ID Application|Apple Distribution|Apple Development)" | head -n 1) CERT_ID=$(echo "$CERT_INFO" | awk -F'"' '{print $2}') echo "CERT_ID=$CERT_ID" >> "$GITHUB_ENV" echo "Certificate imported."
- uses: tauri-apps/tauri-action@73fb865345c54760d875b94642314f8c0c894afa # ratchet:tauri-apps/tauri-action@action-v0.6.1 env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} APPLE_ID: ${{ env.APPLE_ID }} APPLE_PASSWORD: ${{ env.APPLE_PASSWORD }} APPLE_SIGNING_IDENTITY: ${{ env.CERT_ID }} APPLE_TEAM_ID: ${{ env.APPLE_TEAM_ID }} with: # Use the release created by the create-release job to avoid race conditions # when multiple matrix jobs try to create/update the same release simultaneously releaseId: ${{ needs.create-release.outputs.release-id }} assetNamePattern: "[name]_[arch][ext]" args: ${{ matrix.args }} build-web-amd64: needs: determine-builds if: needs.determine-builds.outputs.build-web == 'true' runs-on: - runs-on - runner=4cpu-linux-x64 - run-id=${{ github.run_id }}-web-amd64 - extras=ecr-cache timeout-minutes: 90 environment: release outputs: digest: ${{ steps.build.outputs.digest }} env: REGISTRY_IMAGE: onyxdotapp/onyx-web-server steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }} flavor: | latest=false - name: Set up Docker Buildx uses: 
docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ env.DOCKER_USERNAME }} password: ${{ env.DOCKER_TOKEN }} - name: Build and push AMD64 id: build uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./web file: ./web/Dockerfile platforms: linux/amd64 labels: ${{ steps.meta.outputs.labels }} build-args: | ONYX_VERSION=${{ github.ref_name }} NODE_OPTIONS=--max-old-space-size=8192 cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64 type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-amd64,mode=max outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} build-web-arm64: needs: determine-builds if: needs.determine-builds.outputs.build-web == 'true' runs-on: - runs-on - runner=4cpu-linux-arm64 - run-id=${{ github.run_id }}-web-arm64 - extras=ecr-cache timeout-minutes: 90 environment: release outputs: digest: ${{ steps.build.outputs.digest }} env: REGISTRY_IMAGE: onyxdotapp/onyx-web-server steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: 
aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }} flavor: | latest=false - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ env.DOCKER_USERNAME }} password: ${{ env.DOCKER_TOKEN }} - name: Build and push ARM64 id: build uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./web file: ./web/Dockerfile platforms: linux/arm64 labels: ${{ steps.meta.outputs.labels }} build-args: | ONYX_VERSION=${{ github.ref_name }} NODE_OPTIONS=--max-old-space-size=8192 cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64 type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-arm64,mode=max outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} merge-web: needs: - determine-builds - build-web-amd64 - build-web-arm64 runs-on: - runs-on - runner=2cpu-linux-x64 - run-id=${{ github.run_id }}-merge-web - extras=ecr-cache timeout-minutes: 90 environment: release env: REGISTRY_IMAGE: onyxdotapp/onyx-web-server steps: - uses: 
runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ env.DOCKER_USERNAME }} password: ${{ env.DOCKER_TOKEN }} - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }} flavor: | latest=false tags: | type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }} type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta == 'true' && 'beta' || '' }} - name: Create and push manifest env: IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && 
env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }} AMD64_DIGEST: ${{ needs.build-web-amd64.outputs.digest }} ARM64_DIGEST: ${{ needs.build-web-arm64.outputs.digest }} META_TAGS: ${{ steps.meta.outputs.tags }} run: | IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}" docker buildx imagetools create \ $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \ $IMAGES build-web-cloud-amd64: needs: determine-builds if: needs.determine-builds.outputs.build-web-cloud == 'true' runs-on: - runs-on - runner=4cpu-linux-x64 - run-id=${{ github.run_id }}-web-cloud-amd64 - extras=ecr-cache timeout-minutes: 90 environment: release outputs: digest: ${{ steps.build.outputs.digest }} env: REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }} flavor: | latest=false - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ 
env.DOCKER_USERNAME }} password: ${{ env.DOCKER_TOKEN }} - name: Build and push AMD64 id: build uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./web file: ./web/Dockerfile platforms: linux/amd64 labels: ${{ steps.meta.outputs.labels }} build-args: | ONYX_VERSION=${{ github.ref_name }} NEXT_PUBLIC_CLOUD_ENABLED=true NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }} NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }} NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }} NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }} NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }} NEXT_PUBLIC_GTM_ENABLED=true NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true NODE_OPTIONS=--max-old-space-size=8192 SENTRY_RELEASE=${{ github.sha }} secrets: | sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }} cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64 type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-amd64,mode=max outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} build-web-cloud-arm64: needs: determine-builds if: needs.determine-builds.outputs.build-web-cloud == 'true' runs-on: - runs-on - runner=4cpu-linux-arm64 - run-id=${{ github.run_id }}-web-cloud-arm64 - extras=ecr-cache timeout-minutes: 90 environment: release outputs: digest: ${{ steps.build.outputs.digest }} env: REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: 
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            NEXT_PUBLIC_CLOUD_ENABLED=true
            NEXT_PUBLIC_POSTHOG_KEY=${{ secrets.POSTHOG_KEY }}
            NEXT_PUBLIC_POSTHOG_HOST=${{ secrets.POSTHOG_HOST }}
            NEXT_PUBLIC_SENTRY_DSN=${{ secrets.SENTRY_DSN }}
            NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY=${{ secrets.STRIPE_PUBLISHABLE_KEY }}
            NEXT_PUBLIC_RECAPTCHA_SITE_KEY=${{ vars.NEXT_PUBLIC_RECAPTCHA_SITE_KEY }}
            NEXT_PUBLIC_GTM_ENABLED=true
            NEXT_PUBLIC_FORGOT_PASSWORD_ENABLED=true
            NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK=true
            NODE_OPTIONS=--max-old-space-size=8192
            SENTRY_RELEASE=${{ github.sha }}
          secrets: |
            sentry_auth_token=${{ secrets.SENTRY_AUTH_TOKEN }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:cloudweb-cache-arm64,mode=max
          outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

  merge-web-cloud:
    needs:
      - determine-builds
      - build-web-cloud-amd64
      - build-web-cloud-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-web-cloud
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-web-server-cloud
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false
          tags: |
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('web-cloud-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-web-cloud-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-web-cloud-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES

  build-backend-amd64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-backend == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-backend-amd64
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-amd64,mode=max
          outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

  build-backend-arm64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-backend == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-backend-arm64
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-arm64,mode=max
          outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

  merge-backend:
    needs:
      - determine-builds
      - build-backend-amd64
      - build-backend-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-backend
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false
          tags: |
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('backend-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-backend-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-backend-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES

  build-backend-craft-amd64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-backend-craft == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-backend-craft-amd64
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-backend
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            ENABLE_CRAFT=true
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-amd64,mode=max
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

  build-backend-craft-arm64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-backend-craft == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-backend-craft-arm64
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-backend
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
            ENABLE_CRAFT=true
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-craft-cache-arm64,mode=max
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }}

  merge-backend-craft:
    needs:
      - determine-builds
      - build-backend-craft-amd64
      - build-backend-craft-arm64
    if: needs.determine-builds.outputs.build-backend-craft == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-backend-craft
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-backend
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ env.REGISTRY_IMAGE }}
          flavor: |
            latest=false
          tags: |
            type=raw,value=craft-latest

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-backend-craft-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-backend-craft-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES

  build-model-server-amd64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-model-server == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-model-server-amd64
      - volume=40gb
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
        with:
          buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        env:
          DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/amd64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-amd64,mode=max
          outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
          provenance: false
          sbom: false

  build-model-server-arm64:
    needs: determine-builds
    if: needs.determine-builds.outputs.build-model-server == 'true'
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-model-server-arm64
      - volume=40gb
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3
        with:
          buildkitd-flags: ${{ vars.DOCKER_DEBUG == 'true' && '--debug' || '' }}

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6
        env:
          DEBUG: ${{ vars.DOCKER_DEBUG == 'true' && 1 || 0 }}
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server
          platforms: linux/arm64
          labels: ${{ steps.meta.outputs.labels }}
          build-args: |
            ONYX_VERSION=${{ github.ref_name }}
          cache-from: |
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:edge
            type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: |
            type=inline
            type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-arm64,mode=max
          outputs: type=image,name=${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
          no-cache: ${{ env.EDGE_TAG != 'true' && vars.MODEL_SERVER_NO_CACHE == 'true' }}
          provenance: false
          sbom: false

  merge-model-server:
    needs:
      - determine-builds
      - build-model-server-amd64
      - build-model-server-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-merge-model-server
      - extras=ecr-cache
    timeout-minutes: 90
    environment: release
    env:
      REGISTRY_IMAGE: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}

      - name: Docker meta
        id: meta
        uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0
        with:
          images: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          flavor: |
            latest=false
          tags: |
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run == 'true' && format('model-server-{0}', needs.determine-builds.outputs.sanitized-tag) || github.ref_name }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'latest' || '' }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-latest == 'true' && 'craft-latest' || '' }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && env.EDGE_TAG == 'true' && 'edge' || '' }}
            type=raw,value=${{ needs.determine-builds.outputs.is-test-run != 'true' && needs.determine-builds.outputs.is-beta-standalone == 'true' && 'beta' || '' }}

      - name: Create and push manifest
        env:
          IMAGE_REPO: ${{ needs.determine-builds.outputs.is-test-run == 'true' && env.RUNS_ON_ECR_CACHE || env.REGISTRY_IMAGE }}
          AMD64_DIGEST: ${{ needs.build-model-server-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.build-model-server-arm64.outputs.digest }}
          META_TAGS: ${{ steps.meta.outputs.tags }}
        run: |
          IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}"
          docker buildx imagetools create \
            $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \
            $IMAGES

  trivy-scan:
    needs:
      - determine-builds
      - merge-web
      - merge-web-cloud
      - merge-backend
      - merge-model-server
    if: >-
      always() && !cancelled() &&
      (needs.merge-web.result == 'success' ||
      needs.merge-web-cloud.result == 'success' ||
      needs.merge-backend.result == 'success' ||
      needs.merge-model-server.result == 'success')
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-trivy-scan-${{ matrix.component }}
      - extras=ecr-cache
    permissions:
      security-events: write # needed for SARIF uploads
    timeout-minutes: 10
    strategy:
      fail-fast: false
      matrix:
        include:
          - component: web
            registry-image: onyxdotapp/onyx-web-server
          - component: web-cloud
            registry-image: onyxdotapp/onyx-web-server-cloud
          - component: backend
            registry-image: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-backend-cloud' || 'onyxdotapp/onyx-backend' }}
            trivyignore: backend/.trivyignore
          - component: model-server
            registry-image: ${{ contains(github.ref_name, 'cloud') && 'onyxdotapp/onyx-model-server-cloud' || 'onyxdotapp/onyx-model-server' }}
    steps:
      - name: Check if this scan should run
        id: should-run
        run: |
          case "$COMPONENT" in
            web) RESULT="$MERGE_WEB" ;;
            web-cloud) RESULT="$MERGE_WEB_CLOUD" ;;
            backend) RESULT="$MERGE_BACKEND" ;;
            model-server) RESULT="$MERGE_MODEL_SERVER" ;;
          esac
          if [ "$RESULT" == "success" ]; then
            echo "run=true" >> "$GITHUB_OUTPUT"
          else
            echo "run=false" >> "$GITHUB_OUTPUT"
          fi
        env:
          COMPONENT: ${{ matrix.component }}
          MERGE_WEB: ${{ needs.merge-web.result }}
          MERGE_WEB_CLOUD: ${{ needs.merge-web-cloud.result }}
          MERGE_BACKEND: ${{ needs.merge-backend.result }}
          MERGE_MODEL_SERVER: ${{ needs.merge-model-server.result }}

      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
        if: steps.should-run.outputs.run == 'true'

      - name: Checkout
        if: steps.should-run.outputs.run == 'true' && matrix.trivyignore != ''
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Determine scan image
        if: steps.should-run.outputs.run == 'true'
        id: scan-image
        run: |
          if [ "$IS_TEST_RUN" == "true" ]; then
            echo "image=${RUNS_ON_ECR_CACHE}:${TAG_PREFIX}-${SANITIZED_TAG}" >> "$GITHUB_OUTPUT"
          else
            echo "image=docker.io/${REGISTRY_IMAGE}:${REF_NAME}" >> "$GITHUB_OUTPUT"
          fi
        env:
          IS_TEST_RUN: ${{ needs.determine-builds.outputs.is-test-run }}
          TAG_PREFIX: ${{ matrix.component }}
          SANITIZED_TAG: ${{ needs.determine-builds.outputs.sanitized-tag }}
          REGISTRY_IMAGE: ${{ matrix.registry-image }}
          REF_NAME: ${{ github.ref_name }}

      - name: Run Trivy vulnerability scanner
        if: steps.should-run.outputs.run == 'true'
        uses: aquasecurity/trivy-action@57a97c7e7821a5776cebc9bb87c984fa69cba8f1 # ratchet:aquasecurity/trivy-action@v0.35.0
        with:
          image-ref: ${{ steps.scan-image.outputs.image }}
          severity: CRITICAL,HIGH
          format: "sarif"
          output: "trivy-results.sarif"
          trivyignores: ${{ matrix.trivyignore }}
        env:
          TRIVY_USERNAME: ${{ secrets.DOCKER_USERNAME }}
          TRIVY_PASSWORD: ${{ secrets.DOCKER_TOKEN }}

      - name: Upload Trivy scan results to GitHub Security tab
        if: steps.should-run.outputs.run == 'true'
        uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab
        with:
          sarif_file: "trivy-results.sarif"

  notify-slack-on-failure:
    needs:
      - determine-builds
      - build-desktop
      - build-web-amd64
      - build-web-arm64
      - merge-web
      - build-web-cloud-amd64
      - build-web-cloud-arm64
      - merge-web-cloud
      - build-backend-amd64
      - build-backend-arm64
      - merge-backend
      - build-backend-craft-amd64
      - build-backend-craft-arm64
      - merge-backend-craft
      - build-model-server-amd64
      - build-model-server-arm64
      - merge-model-server
    if: >-
      always() &&
      (needs.build-desktop.result == 'failure' ||
      needs.build-web-amd64.result == 'failure' ||
      needs.build-web-arm64.result == 'failure' ||
      needs.merge-web.result == 'failure' ||
      needs.build-web-cloud-amd64.result == 'failure' ||
      needs.build-web-cloud-arm64.result == 'failure' ||
      needs.merge-web-cloud.result == 'failure' ||
      needs.build-backend-amd64.result == 'failure' ||
      needs.build-backend-arm64.result == 'failure' ||
      needs.merge-backend.result == 'failure' ||
      (needs.determine-builds.outputs.build-backend-craft == 'true' &&
      (needs.build-backend-craft-amd64.result == 'failure' ||
      needs.build-backend-craft-arm64.result == 'failure' ||
      needs.merge-backend-craft.result == 'failure')) ||
      needs.build-model-server-amd64.result == 'failure' ||
      needs.build-model-server-arm64.result == 'failure' ||
      needs.merge-model-server.result == 'failure') &&
      needs.determine-builds.outputs.is-test-run != 'true'
    # NOTE: GitHub-hosted runners have about 20s faster queue times and are preferred here.
    runs-on: ubuntu-slim
    timeout-minutes: 90
    environment: release
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Determine failed jobs
        id: failed-jobs
        shell: bash
        run: |
          FAILED_JOBS=""
          if [ "${NEEDS_BUILD_DESKTOP_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-desktop\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_WEB_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-web\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-web-cloud-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_WEB_CLOUD_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-web-cloud\\n"
          fi
          if [ "${NEEDS_BUILD_BACKEND_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_BACKEND_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_BACKEND_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-backend\\n"
          fi
          if [ "${NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-model-server-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-model-server-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_MODEL_SERVER_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-model-server\\n"
          fi
          if [ "${NEEDS_BUILD_BACKEND_CRAFT_AMD64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend-craft-amd64\\n"
          fi
          if [ "${NEEDS_BUILD_BACKEND_CRAFT_ARM64_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• build-backend-craft-arm64\\n"
          fi
          if [ "${NEEDS_MERGE_BACKEND_CRAFT_RESULT}" == "failure" ]; then
            FAILED_JOBS="${FAILED_JOBS}• merge-backend-craft\\n"
          fi
          # Remove trailing \n and set output
          FAILED_JOBS=$(printf '%s' "$FAILED_JOBS" | sed 's/\\n$//')
          echo "jobs=$FAILED_JOBS" >> "$GITHUB_OUTPUT"
        env:
          NEEDS_BUILD_DESKTOP_RESULT: ${{ needs.build-desktop.result }}
          NEEDS_BUILD_WEB_AMD64_RESULT: ${{ needs.build-web-amd64.result }}
          NEEDS_BUILD_WEB_ARM64_RESULT: ${{ needs.build-web-arm64.result }}
          NEEDS_MERGE_WEB_RESULT: ${{ needs.merge-web.result }}
          NEEDS_BUILD_WEB_CLOUD_AMD64_RESULT: ${{ needs.build-web-cloud-amd64.result }}
          NEEDS_BUILD_WEB_CLOUD_ARM64_RESULT: ${{ needs.build-web-cloud-arm64.result }}
          NEEDS_MERGE_WEB_CLOUD_RESULT: ${{ needs.merge-web-cloud.result }}
          NEEDS_BUILD_BACKEND_AMD64_RESULT: ${{ needs.build-backend-amd64.result }}
          NEEDS_BUILD_BACKEND_ARM64_RESULT: ${{ needs.build-backend-arm64.result }}
          NEEDS_MERGE_BACKEND_RESULT: ${{ needs.merge-backend.result }}
          NEEDS_BUILD_BACKEND_CRAFT_AMD64_RESULT: ${{ needs.build-backend-craft-amd64.result }}
          NEEDS_BUILD_BACKEND_CRAFT_ARM64_RESULT: ${{ needs.build-backend-craft-arm64.result }}
          NEEDS_MERGE_BACKEND_CRAFT_RESULT: ${{ needs.merge-backend-craft.result }}
          NEEDS_BUILD_MODEL_SERVER_AMD64_RESULT: ${{ needs.build-model-server-amd64.result }}
          NEEDS_BUILD_MODEL_SERVER_ARM64_RESULT: ${{ needs.build-model-server-arm64.result }}
          NEEDS_MERGE_MODEL_SERVER_RESULT: ${{ needs.merge-model-server.result }}

      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }}
          failed-jobs: ${{ steps.failed-jobs.outputs.jobs }}
          title: "🚨 Deployment Workflow Failed"
          ref-name: ${{ github.ref_name }}

================================================
FILE: .github/workflows/docker-tag-beta.yml
================================================
# This workflow is set up to be manually triggered via the GitHub Actions tab.
# Given a version, it will tag the web server, backend, and model server images as "beta".
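# Example manual trigger (a sketch, not part of the workflow; assumes the GitHub CLI
# `gh` is installed and authenticated for this repository, and that the version
# below is an existing published image tag):
#   gh workflow run docker-tag-beta.yml -f version=v1.0.0-beta.0
# For each image, the tagging steps below amount to a command such as:
#   docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:v1.0.0-beta.0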
name: Tag Beta Version

on:
  workflow_dispatch:
    inputs:
      version:
        description: "The version (e.g. v1.0.0-beta.0) to tag as beta"
        required: true

permissions:
  contents: read

jobs:
  tag:
    # See https://runs-on.com/runners/linux/
    # use a lower powered instance since this just does i/o to docker hub
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
    timeout-minutes: 45
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Enable Docker CLI experimental features
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:beta onyxdotapp/onyx-web-server:${VERSION}

      - name: Pull, Tag and Push API Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-backend:beta onyxdotapp/onyx-backend:${VERSION}

      - name: Pull, Tag and Push Model Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:beta onyxdotapp/onyx-model-server:${VERSION}

================================================
FILE: .github/workflows/docker-tag-latest.yml
================================================
# This workflow is set up to be manually triggered via the GitHub Actions tab.
# Given a version, it will tag the web server, backend, and model server images as "latest".
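# Example post-run check (a sketch, not part of the workflow; assumes Docker with the
# buildx plugin is available locally, and v0.0.1 stands in for the version you tagged):
# inspect both tags and confirm they list the same platform manifests.
#   docker buildx imagetools inspect onyxdotapp/onyx-backend:v0.0.1
#   docker buildx imagetools inspect onyxdotapp/onyx-backend:latest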
name: Tag Latest Version

on:
  workflow_dispatch:
    inputs:
      version:
        description: "The version (e.g. v0.0.1) to tag as latest"
        required: true

permissions:
  contents: read

jobs:
  tag:
    # See https://runs-on.com/runners/linux/
    # use a lower powered instance since this just does i/o to docker hub
    runs-on: [runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-tag"]
    timeout-minutes: 45
    steps:
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Enable Docker CLI experimental features
        run: echo "DOCKER_CLI_EXPERIMENTAL=enabled" >> $GITHUB_ENV

      - name: Pull, Tag and Push Web Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-web-server:latest onyxdotapp/onyx-web-server:${VERSION}

      - name: Pull, Tag and Push API Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${VERSION}

      - name: Pull, Tag and Push Model Server Image
        env:
          VERSION: ${{ github.event.inputs.version }}
        run: |
          docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${VERSION}

================================================
FILE: .github/workflows/helm-chart-releases.yml
================================================
name: Release Onyx Helm Charts

on:
  push:
    branches:
      - main

permissions: write-all

jobs:
  release:
    permissions:
      contents: write
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false

      - name: Install Helm CLI
        uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4
        with:
          version: v3.12.1

      - name: Add required Helm repositories
        run: |
          helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx
          helm repo add onyx-vespa https://onyx-dot-app.github.io/vespa-helm-charts
          helm repo add opensearch https://opensearch-project.github.io/helm-charts
          helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
          helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
          helm repo add minio https://charts.min.io/
          helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
          helm repo update

      - name: Build chart dependencies
        run: |
          set -euo pipefail
          for chart_dir in deployment/helm/charts/*; do
            if [ -f "$chart_dir/Chart.yaml" ]; then
              echo "Building dependencies for $chart_dir"
              helm dependency build "$chart_dir"
            fi
          done

      - name: Publish Helm charts to gh-pages
        # NOTE: HEAD of https://github.com/stefanprodan/helm-gh-pages/pull/43
        uses: stefanprodan/helm-gh-pages@ad32ad3b8720abfeaac83532fd1e9bdfca5bbe27 # zizmor: ignore[impostor-commit]
        with:
          token: ${{ secrets.GITHUB_TOKEN }}
          charts_dir: deployment/helm/charts
          branch: gh-pages
          commit_username: ${{ github.actor }}
          commit_email: ${{ github.actor }}@users.noreply.github.com

================================================
FILE: .github/workflows/merge-group.yml
================================================
name: Merge Group-Specific

on:
  merge_group:

permissions:
  contents: read

jobs:
  # This job immediately succeeds to satisfy branch protection rules on merge_group events.
  # There is a similarly named "required" job in pr-integration-tests.yml which runs the actual
  # integration tests. That job runs on both pull_request and merge_group events, and this job
  # exists solely to provide a fast-passing check with the same name for branch protection.
  # The actual tests remain enforced on presubmit (pull_request events).
  required:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Success
        run: echo "Success"

  # This job immediately succeeds to satisfy branch protection rules on merge_group events.
  # There is a similarly named "playwright-required" job in pr-playwright-tests.yml which runs
  # the actual playwright tests. That job runs on both pull_request and merge_group events, and
  # this job exists solely to provide a fast-passing check with the same name for branch protection.
  # The actual tests remain enforced on presubmit (pull_request events).
  playwright-required:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - name: Success
        run: echo "Success"

================================================
FILE: .github/workflows/nightly-close-stale-issues.yml
================================================
name: 'Nightly - Close stale issues and PRs'
on:
  schedule:
    - cron: '0 11 * * *' # Runs every day at 3 AM PST / 4 AM PDT / 11 AM UTC

permissions:
  # contents: write # only for delete-branch option
  issues: write
  pull-requests: write

jobs:
  stale:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - uses: actions/stale@997185467fa4f803885201cee163a9f38240193d # ratchet:actions/stale@v10
        with:
          stale-issue-message: 'This issue is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
          stale-pr-message: 'This PR is stale because it has been open 75 days with no activity. Remove stale label or comment or this will be closed in 15 days.'
          close-issue-message: 'This issue was closed because it has been stalled for 90 days with no activity.'
          close-pr-message: 'This PR was closed because it has been stalled for 90 days with no activity.'
          days-before-stale: 75
          # days-before-close: 90 # uncomment after we test stale behavior

================================================
FILE: .github/workflows/nightly-llm-provider-chat.yml
================================================
name: Nightly LLM Provider Chat Tests

concurrency:
  group: Nightly-LLM-Provider-Chat-${{ github.workflow }}-${{ github.ref_name }}
  cancel-in-progress: true

on:
  schedule:
    # Runs daily at 10:30 UTC (2:30 AM PST / 3:30 AM PDT)
    - cron: "30 10 * * *"
  workflow_dispatch:

permissions:
  contents: read

jobs:
  provider-chat-test:
    uses: ./.github/workflows/reusable-nightly-llm-provider-chat.yml
    secrets:
      AWS_OIDC_ROLE_ARN: ${{ secrets.AWS_OIDC_ROLE_ARN }}
    permissions:
      contents: read
      id-token: write
    with:
      openai_models: ${{ vars.NIGHTLY_LLM_OPENAI_MODELS }}
      anthropic_models: ${{ vars.NIGHTLY_LLM_ANTHROPIC_MODELS }}
      bedrock_models: ${{ vars.NIGHTLY_LLM_BEDROCK_MODELS }}
      vertex_ai_models: ${{ vars.NIGHTLY_LLM_VERTEX_AI_MODELS }}
      azure_models: ${{ vars.NIGHTLY_LLM_AZURE_MODELS }}
      azure_api_base: ${{ vars.NIGHTLY_LLM_AZURE_API_BASE }}
      ollama_models: ${{ vars.NIGHTLY_LLM_OLLAMA_MODELS }}
      openrouter_models: ${{ vars.NIGHTLY_LLM_OPENROUTER_MODELS }}
      strict: true

  notify-slack-on-failure:
    needs: [provider-chat-test]
    if: failure() && github.event_name == 'schedule'
    runs-on: ubuntu-slim
    environment: ci-protected
    timeout-minutes: 5
    steps:
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Send Slack notification
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.SLACK_WEBHOOK }}
          failed-jobs: provider-chat-test
          title: "🚨 Scheduled LLM Provider Chat Tests failed!"
          ref-name: ${{ github.ref_name }}

================================================
FILE: .github/workflows/post-merge-beta-cherry-pick.yml
================================================
name: Post-Merge Beta Cherry-Pick

on:
  pull_request_target:
    types:
      - closed

# SECURITY NOTE:
# This workflow intentionally uses pull_request_target so post-merge automation can
# use base-repo credentials. Do not check out PR head refs in this workflow
# (e.g. github.event.pull_request.head.sha). Only trusted base refs are allowed.
permissions:
  contents: read

jobs:
  resolve-cherry-pick-request:
    if: >-
      github.event.pull_request.merged == true &&
      github.event.pull_request.base.ref == 'main' &&
      github.event.pull_request.head.repo.full_name == github.repository
    outputs:
      should_cherrypick: ${{ steps.gate.outputs.should_cherrypick }}
      pr_number: ${{ steps.gate.outputs.pr_number }}
      merge_commit_sha: ${{ steps.gate.outputs.merge_commit_sha }}
      merged_by: ${{ steps.gate.outputs.merged_by }}
      gate_error: ${{ steps.gate.outputs.gate_error }}
    runs-on: ubuntu-latest
    timeout-minutes: 10
    steps:
      - name: Resolve merged PR and checkbox state
        id: gate
        env:
          GH_TOKEN: ${{ github.token }}
          PR_NUMBER: ${{ github.event.pull_request.number }}
          # SECURITY: keep PR body in env/plain-text handling; avoid directly
          # inlining github.event.pull_request.body into shell commands.
          PR_BODY: ${{ github.event.pull_request.body }}
          MERGE_COMMIT_SHA: ${{ github.event.pull_request.merge_commit_sha }}
          MERGED_BY: ${{ github.event.pull_request.merged_by.login }}
          # Explicit merger allowlist used because pull_request_target runs with
          # the default GITHUB_TOKEN, which cannot reliably read org/team
          # membership for this repository context.
ALLOWED_MERGERS: | acaprau bo-onyx danelegend duo-onyx evan-onyx jessicasingh7 jmelahman joachim-danswer justin-tahara nmgarza5 raunakab rohoswagger subash-mohan trial2onyx wenxi-onyx weves yuhongsun96 run: | echo "pr_number=${PR_NUMBER}" >> "$GITHUB_OUTPUT" echo "merged_by=${MERGED_BY}" >> "$GITHUB_OUTPUT" if ! echo "${PR_BODY}" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then echo "should_cherrypick=false" >> "$GITHUB_OUTPUT" echo "Cherry-pick checkbox not checked for PR #${PR_NUMBER}. Skipping." exit 0 fi # Keep should_cherrypick output before any possible exit 1 below so # notify-slack can still gate on this output even if this job fails. echo "should_cherrypick=true" >> "$GITHUB_OUTPUT" echo "Cherry-pick checkbox checked for PR #${PR_NUMBER}." if [ -z "${MERGE_COMMIT_SHA}" ] || [ "${MERGE_COMMIT_SHA}" = "null" ]; then echo "gate_error=missing-merge-commit-sha" >> "$GITHUB_OUTPUT" echo "::error::PR #${PR_NUMBER} requested cherry-pick, but merge_commit_sha is missing." exit 1 fi echo "merge_commit_sha=${MERGE_COMMIT_SHA}" >> "$GITHUB_OUTPUT" normalized_merged_by="$(printf '%s' "${MERGED_BY}" | tr '[:upper:]' '[:lower:]')" normalized_allowed_mergers="$(printf '%s\n' "${ALLOWED_MERGERS}" | tr '[:upper:]' '[:lower:]')" if ! printf '%s\n' "${normalized_allowed_mergers}" | grep -Fxq "${normalized_merged_by}"; then echo "gate_error=not-allowed-merger" >> "$GITHUB_OUTPUT" echo "::error::${MERGED_BY} is not in the explicit cherry-pick merger allowlist. Failing cherry-pick gate." 
exit 1 fi exit 0 cherry-pick-to-latest-release: needs: - resolve-cherry-pick-request if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' permissions: contents: write pull-requests: write outputs: cherry_pick_pr_url: ${{ steps.run_cherry_pick.outputs.pr_url }} cherry_pick_reason: ${{ steps.run_cherry_pick.outputs.reason }} cherry_pick_details: ${{ steps.run_cherry_pick.outputs.details }} runs-on: ubuntu-latest timeout-minutes: 45 steps: - name: Checkout repository # SECURITY: keep checkout pinned to trusted base branch; do not switch to PR head refs. uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: fetch-depth: 0 persist-credentials: true ref: main - name: Install the latest version of uv uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7 with: enable-cache: false version: "0.9.9" - name: Configure git identity run: | git config user.name "github-actions[bot]" git config user.email "github-actions[bot]@users.noreply.github.com" - name: Create cherry-pick PR to latest release id: run_cherry_pick env: GH_TOKEN: ${{ github.token }} GITHUB_TOKEN: ${{ github.token }} CHERRY_PICK_ASSIGNEE: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }} MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }} run: | output_file="$(mktemp)" set +e uv run --no-sync --with onyx-devtools ods cherry-pick "${MERGE_COMMIT_SHA}" --yes --no-verify 2>&1 | tee "$output_file" pipe_statuses=("${PIPESTATUS[@]}") exit_code="${pipe_statuses[0]}" tee_exit="${pipe_statuses[1]:-0}" set -e if [ "${tee_exit}" -ne 0 ]; then echo "status=failure" >> "$GITHUB_OUTPUT" echo "reason=output-capture-failed" >> "$GITHUB_OUTPUT" echo "::error::tee failed to capture cherry-pick output (exit ${tee_exit}); cannot classify result." 
exit 1 fi if [ "${exit_code}" -eq 0 ]; then pr_url="$(sed -n 's/^.*PR created successfully: \(https:\/\/github\.com\/[^[:space:]]\+\/pull\/[0-9]\+\).*$/\1/p' "$output_file" | tail -n 1)" echo "status=success" >> "$GITHUB_OUTPUT" if [ -n "${pr_url}" ]; then echo "pr_url=${pr_url}" >> "$GITHUB_OUTPUT" fi exit 0 fi echo "status=failure" >> "$GITHUB_OUTPUT" reason="command-failed" if grep -qiE "merge conflict during cherry-pick|CONFLICT|could not apply|cherry-pick in progress with staged changes" "$output_file"; then reason="merge-conflict" fi echo "reason=${reason}" >> "$GITHUB_OUTPUT" { echo "details<> "$GITHUB_OUTPUT" - name: Mark workflow as failed if cherry-pick failed if: steps.run_cherry_pick.outputs.status == 'failure' env: CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }} run: | echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})." exit 1 notify-slack-on-cherry-pick-success: needs: - resolve-cherry-pick-request - cherry-pick-to-latest-release if: needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && needs.resolve-cherry-pick-request.result == 'success' && needs.cherry-pick-to-latest-release.result == 'success' runs-on: ubuntu-slim environment: ci-protected timeout-minutes: 10 steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Fail if Slack webhook secret is missing env: CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }} run: | if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured." 
exit 1 fi - name: Build cherry-pick success summary id: success-summary env: SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }} MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }} CHERRY_PICK_PR_URL: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_pr_url }} run: | source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}" details="*Cherry-pick PR opened successfully.*\\n• author: {mention}\\n• source PR: ${source_pr_url}" if [ -n "${CHERRY_PICK_PR_URL}" ]; then details="${details}\\n• cherry-pick PR: ${CHERRY_PICK_PR_URL}" fi if [ -n "${MERGE_COMMIT_SHA}" ]; then details="${details}\\n• merge SHA: ${MERGE_COMMIT_SHA}" fi echo "details=${details}" >> "$GITHUB_OUTPUT" - name: Notify #cherry-pick-prs about cherry-pick success uses: ./.github/actions/slack-notify with: webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }} mention: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }} details: ${{ steps.success-summary.outputs.details }} title: "✅ Automated Cherry-Pick PR Opened" ref-name: ${{ github.event.pull_request.base.ref }} notify-slack-on-cherry-pick-failure: needs: - resolve-cherry-pick-request - cherry-pick-to-latest-release if: always() && needs.resolve-cherry-pick-request.outputs.should_cherrypick == 'true' && (needs.resolve-cherry-pick-request.result == 'failure' || needs.cherry-pick-to-latest-release.result == 'failure') runs-on: ubuntu-slim environment: ci-protected timeout-minutes: 10 steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Fail if Slack webhook secret is missing env: CHERRY_PICK_PRS_WEBHOOK: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }} run: | if [ -z "${CHERRY_PICK_PRS_WEBHOOK}" ]; then echo "::error::CHERRY_PICK_PRS_WEBHOOK is not configured." 
exit 1 fi - name: Build cherry-pick failure summary id: failure-summary env: SOURCE_PR_NUMBER: ${{ needs.resolve-cherry-pick-request.outputs.pr_number }} MERGE_COMMIT_SHA: ${{ needs.resolve-cherry-pick-request.outputs.merge_commit_sha }} GATE_ERROR: ${{ needs.resolve-cherry-pick-request.outputs.gate_error }} CHERRY_PICK_REASON: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_reason }} CHERRY_PICK_DETAILS: ${{ needs.cherry-pick-to-latest-release.outputs.cherry_pick_details }} run: | source_pr_url="https://github.com/${GITHUB_REPOSITORY}/pull/${SOURCE_PR_NUMBER}" reason_text="cherry-pick command failed" if [ "${GATE_ERROR}" = "missing-merge-commit-sha" ]; then reason_text="requested cherry-pick but merge commit SHA was missing" elif [ "${GATE_ERROR}" = "not-allowed-merger" ]; then reason_text="merger is not in the explicit cherry-pick allowlist" elif [ "${CHERRY_PICK_REASON}" = "output-capture-failed" ]; then reason_text="failed to capture cherry-pick output for classification" elif [ "${CHERRY_PICK_REASON}" = "merge-conflict" ]; then reason_text="merge conflict during cherry-pick" fi details_excerpt="$(printf '%s' "${CHERRY_PICK_DETAILS}" | tail -n 8 | tr '\n' ' ' | sed "s/[[:space:]]\\+/ /g" | sed "s/\"/'/g" | cut -c1-350)" if [ -n "${GATE_ERROR}" ]; then failed_job_label="resolve-cherry-pick-request" else failed_job_label="cherry-pick-to-latest-release" fi details="• author: {mention}\\n• ${failed_job_label}\\n• source PR: ${source_pr_url}\\n• reason: ${reason_text}" if [ -n "${MERGE_COMMIT_SHA}" ]; then details="${details}\\n• merge SHA: ${MERGE_COMMIT_SHA}" fi if [ -n "${details_excerpt}" ]; then details="${details}\\n• excerpt: ${details_excerpt}" fi echo "details=${details}" >> "$GITHUB_OUTPUT" - name: Notify #cherry-pick-prs about cherry-pick failure uses: ./.github/actions/slack-notify with: webhook-url: ${{ secrets.CHERRY_PICK_PRS_WEBHOOK }} mention: ${{ needs.resolve-cherry-pick-request.outputs.merged_by }} details: ${{ 
steps.failure-summary.outputs.details }} title: "🚨 Automated Cherry-Pick Failed" ref-name: ${{ github.event.pull_request.base.ref }} ================================================ FILE: .github/workflows/pr-database-tests.yml ================================================ name: Database Tests concurrency: group: Database-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: merge_group: pull_request: branches: - main - "release/**" push: tags: - "v*.*.*" permissions: contents: read jobs: database-tests: runs-on: - runs-on - runner=2cpu-linux-arm64 - "run-id=${{ github.run_id }}-database-tests" timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Setup Python and Install Dependencies uses: ./.github/actions/setup-python-and-install-dependencies with: requirements: | backend/requirements/default.txt backend/requirements/dev.txt - name: Generate OpenAPI schema and Python client shell: bash # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license env: LICENSE_ENFORCEMENT_ENABLED: "false" run: | ods openapi all # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Start Docker containers working-directory: ./deployment/docker_compose run: | docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d \ relational_db - name: Run Database Tests working-directory: ./backend run: pytest -m alembic tests/integration/tests/migrations/ 
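# Local reproduction of the job above (a sketch, not part of the workflow; assumes
# Docker is running, the backend default/dev requirements and the `ods` devtools CLI
# are installed in the active Python environment, and commands run from the repo root):
#   LICENSE_ENFORCEMENT_ENABLED=false ods openapi all
#   (cd deployment/docker_compose && docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d relational_db)
#   (cd backend && pytest -m alembic tests/integration/tests/migrations/)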
================================================
FILE: .github/workflows/pr-desktop-build.yml
================================================
name: Build Desktop App

concurrency:
  group: Build-Desktop-App-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"
    paths:
      - "desktop/**"
      - ".github/workflows/pr-desktop-build.yml"
  push:
    tags:
      - "v*.*.*"

permissions:
  contents: read

jobs:
  build-desktop:
    name: Build Desktop (${{ matrix.platform }})
    runs-on: ${{ matrix.os }}
    timeout-minutes: 60
    strategy:
      fail-fast: false
      matrix:
        include:
          - platform: linux
            os: ubuntu-latest
            target: x86_64-unknown-linux-gnu
            args: "--bundles deb,rpm"
          # TODO: Fix and enable the macOS build.
          #- platform: macos
          #  os: macos-latest
          #  target: universal-apple-darwin
          #  args: "--target universal-apple-darwin"
          # TODO: Fix and enable the Windows build.
          #- platform: windows
          #  os: windows-latest
          #  target: x86_64-pc-windows-msvc
          #  args: ""
    steps:
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
        with:
          persist-credentials: false

      - name: Setup node
        uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f
        with:
          node-version: 24
          cache: "npm" # zizmor: ignore[cache-poisoning]
          cache-dependency-path: ./desktop/package-lock.json

      - name: Setup Rust
        uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
        with:
          toolchain: stable
          targets: ${{ matrix.target }}

      - name: Cache Cargo registry and build
        uses: actions/cache@668228422ae6a00e4ad889ee87cd7109ec5666a7 # zizmor: ignore[cache-poisoning]
        with:
          path: |
            ~/.cargo/bin/
            ~/.cargo/registry/index/
            ~/.cargo/registry/cache/
            ~/.cargo/git/db/
            desktop/src-tauri/target/
          key: ${{ runner.os }}-cargo-${{ hashFiles('desktop/src-tauri/Cargo.lock') }}
          restore-keys: |
            ${{ runner.os }}-cargo-

      - name: Install Linux dependencies
        if: matrix.platform == 'linux'
        run: |
          sudo apt-get update
          sudo apt-get install -y \
            build-essential \
            libglib2.0-dev \
            libgirepository1.0-dev \
            libgtk-3-dev \
            libjavascriptcoregtk-4.1-dev \
            libwebkit2gtk-4.1-dev \
            libayatana-appindicator3-dev \
            gobject-introspection \
            pkg-config \
            curl \
            xdg-utils

      - name: Install npm dependencies
        working-directory: ./desktop
        run: npm ci

      - name: Build desktop app
        working-directory: ./desktop
        run: npx tauri build ${{ matrix.args }}
        env:
          TAURI_SIGNING_PRIVATE_KEY: ""
          TAURI_SIGNING_PRIVATE_KEY_PASSWORD: ""

      - name: Upload build artifacts
        if: always()
        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
        with:
          name: desktop-build-${{ matrix.platform }}-${{ github.run_id }}
          path: |
            desktop/src-tauri/target/release/bundle/
          retention-days: 7
          if-no-files-found: ignore

================================================
FILE: .github/workflows/pr-external-dependency-unit-tests.yml
================================================
name: External Dependency Unit Tests

concurrency:
  group: External-Dependency-Unit-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches: [main]
    paths:
      - "backend/**"
      - "pyproject.toml"
      - "uv.lock"
      - ".github/workflows/pr-external-dependency-unit-tests.yml"
      - ".github/actions/setup-python-and-install-dependencies/**"
      - ".github/actions/setup-playwright/**"
      - "deployment/docker_compose/docker-compose.yml"
      - "deployment/docker_compose/docker-compose.dev.yml"
  push:
    tags:
      - "v*.*.*"

permissions:
  contents: read

env:
  # AWS credentials for S3-specific test
  S3_AWS_ACCESS_KEY_ID_FOR_TEST: ${{ secrets.S3_AWS_ACCESS_KEY_ID }}
  S3_AWS_SECRET_ACCESS_KEY_FOR_TEST: ${{ secrets.S3_AWS_SECRET_ACCESS_KEY }}

  # MinIO
  S3_ENDPOINT_URL: "http://localhost:9004"
  S3_AWS_ACCESS_KEY_ID: "minioadmin"
  S3_AWS_SECRET_ACCESS_KEY: "minioadmin"

  # Confluence
  CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
  CONFLUENCE_TEST_PAGE_ID: ${{ secrets.CONFLUENCE_TEST_PAGE_ID }}
  CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
  CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }}

  # Jira
  JIRA_ADMIN_API_TOKEN: ${{ secrets.JIRA_ADMIN_API_TOKEN }}

  # LLMs
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }}
  VERTEX_CREDENTIALS: ${{ secrets.VERTEX_CREDENTIALS }}
  VERTEX_LOCATION: ${{ vars.VERTEX_LOCATION }}

  # Code Interpreter
  # TODO: debug why this is failing and enable
  CODE_INTERPRETER_BASE_URL: http://localhost:8000

jobs:
  discover-test-dirs:
    # NOTE: GitHub-hosted runners have about 20s faster queue times and are preferred here.
    runs-on: ubuntu-slim
    timeout-minutes: 45
    outputs:
      test-dirs: ${{ steps.set-matrix.outputs.test-dirs }}
    steps:
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Discover test directories
        id: set-matrix
        run: |
          # Find all subdirectories in backend/tests/external_dependency_unit
          dirs=$(find backend/tests/external_dependency_unit -mindepth 1 -maxdepth 1 -type d -exec basename {} \; | sort | jq -R -s -c 'split("\n")[:-1]')
          echo "test-dirs=$dirs" >> $GITHUB_OUTPUT

  external-dependency-unit-tests:
    needs: discover-test-dirs
    # Use larger runner with more resources for Vespa
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - ${{ format('run-id={0}-external-dependency-unit-tests-job-{1}', github.run_id, strategy['job-index']) }}
      - extras=s3-cache
    timeout-minutes: 45

    strategy:
      fail-fast: false
      matrix:
        test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }}

    env:
      PYTHONPATH: ./backend
      MODEL_SERVER_HOST: "disabled"
      DISABLE_TELEMETRY: "true"

    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/ee.txt

      - name: Setup Playwright
        uses: ./.github/actions/setup-playwright

      # needed for pulling Vespa, Redis, Postgres, and MinIO images
      # otherwise, we hit the "Unauthenticated users" limit
      # https://docs.docker.com/docker-hub/usage/
      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Create .env file for Docker Compose
        run: |
          cat <<EOF > deployment/docker_compose/.env
          COMPOSE_PROFILES=s3-filestore,opensearch-enabled
          DISABLE_TELEMETRY=true
          OPENSEARCH_FOR_ONYX_ENABLED=true
          EOF

      - name: Set up Standard Dependencies
        run: |
          cd deployment/docker_compose
          docker compose \
            -f docker-compose.yml \
            -f docker-compose.dev.yml \
            up -d \
            minio \
            relational_db \
            cache \
            index \
            opensearch \
            code-interpreter

      - name: Run migrations
        run: |
          cd backend
          # Run migrations to head
          alembic upgrade head
          alembic heads --verbose

      - name: Run Tests for ${{ matrix.test-dir }}
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        env:
          TEST_DIR: ${{ matrix.test-dir }}
        run: |
          py.test \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/external_dependency_unit/${TEST_DIR}

      - name: Collect Docker logs on failure
        if: failure()
        run: |
          mkdir -p docker-logs
          cd deployment/docker_compose
          # Get list of running containers
          containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
          # Collect logs from each container
          for container in $containers; do
            container_name=$(docker inspect --format='{{.Name}}' $container | sed 's/^\///')
            echo "Collecting logs from $container_name..."
            docker logs $container > ../../docker-logs/${container_name}.log 2>&1
          done
          cd ../..
          echo "Docker logs collected in docker-logs directory"

      - name: Upload Docker logs
        if: failure()
        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
        with:
          name: docker-logs-${{ matrix.test-dir }}
          path: docker-logs/
          retention-days: 7

================================================
FILE: .github/workflows/pr-golang-tests.yml
================================================
name: Golang Tests

concurrency:
  group: Golang-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"
  push:
    tags:
      - "v*.*.*"

permissions: {}

env:
  GO_VERSION: "1.26"

jobs:
  detect-modules:
    runs-on: ubuntu-latest
    timeout-minutes: 10
    outputs:
      modules: ${{ steps.set-modules.outputs.modules }}
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
        with:
          persist-credentials: false
      - id: set-modules
        run: echo "modules=$(find . -name 'go.mod' -exec dirname {} \; | jq -Rc '[.,inputs]')" >> "$GITHUB_OUTPUT"

  golang:
    needs: detect-modules
    runs-on: ubuntu-latest
    timeout-minutes: 10
    strategy:
      matrix:
        modules: ${{ fromJSON(needs.detect-modules.outputs.modules) }}
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - uses: actions/setup-go@4dc6199c7b1a012772edbd06daecab0f50c9053c # zizmor: ignore[cache-poisoning]
        with:
          go-version: ${{ env.GO_VERSION }}
          cache-dependency-path: "**/go.sum"
      - run: go mod tidy
        working-directory: ${{ matrix.modules }}
      - run: git diff --exit-code go.mod go.sum
        working-directory: ${{ matrix.modules }}
      - run: go test ./...
working-directory: ${{ matrix.modules }} ================================================ FILE: .github/workflows/pr-helm-chart-testing.yml ================================================ name: Helm - Lint and Test Charts concurrency: group: Helm-Lint-and-Test-Charts-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: merge_group: pull_request: branches: [main] push: tags: - "v*.*.*" workflow_dispatch: # Allows manual triggering permissions: contents: read jobs: helm-chart-check: # See https://runs-on.com/runners/linux/ runs-on: [ runs-on, runner=8cpu-linux-x64, hdd=256, "run-id=${{ github.run_id }}-helm-chart-check", ] timeout-minutes: 45 # fetch-depth 0 is required for helm/chart-testing-action steps: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: fetch-depth: 0 persist-credentials: false - name: Set up Helm uses: azure/setup-helm@1a275c3b69536ee54be43f2070a358922e12c8d4 # ratchet:azure/setup-helm@v4.3.1 with: version: v3.19.0 - name: Set up chart-testing uses: helm/chart-testing-action@2e2940618cb426dce2999631d543b53cdcfc8527 with: uv_version: "0.9.9" # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command... 
- name: Run chart-testing (list-changed) id: list-changed env: DEFAULT_BRANCH: ${{ github.event.repository.default_branch }} run: | echo "default_branch: ${DEFAULT_BRANCH}" changed=$(ct list-changed --remote origin --target-branch ${DEFAULT_BRANCH} --chart-dirs deployment/helm/charts) echo "list-changed output: $changed" if [[ -n "$changed" ]]; then echo "changed=true" >> "$GITHUB_OUTPUT" fi # uncomment to force run chart-testing # - name: Force run chart-testing (list-changed) # id: list-changed # run: echo "changed=true" >> $GITHUB_OUTPUT # lint all charts if any changes were detected - name: Run chart-testing (lint) if: steps.list-changed.outputs.changed == 'true' run: ct lint --config ct.yaml --all # the following would lint only changed charts, but linting isn't expensive # run: ct lint --config ct.yaml --target-branch ${{ github.event.repository.default_branch }} - name: Create kind cluster if: steps.list-changed.outputs.changed == 'true' uses: helm/kind-action@ef37e7f390d99f746eb8b610417061a60e82a6cc # ratchet:helm/kind-action@v1.14.0 - name: Pre-install cluster status check if: steps.list-changed.outputs.changed == 'true' run: | echo "=== Pre-install Cluster Status ===" kubectl get nodes -o wide kubectl get pods --all-namespaces kubectl get storageclass - name: Add Helm repositories and update if: steps.list-changed.outputs.changed == 'true' run: | echo "=== Adding Helm repositories ===" helm repo add ingress-nginx https://kubernetes.github.io/ingress-nginx helm repo add vespa https://onyx-dot-app.github.io/vespa-helm-charts helm repo add opensearch https://opensearch-project.github.io/helm-charts helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts helm repo add minio https://charts.min.io/ helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/ helm repo update - name: Install Redis operator if: steps.list-changed.outputs.changed == 'true' 
shell: bash run: | echo "=== Installing redis-operator CRDs ===" helm upgrade --install redis-operator ot-container-kit/redis-operator \ --namespace redis-operator --create-namespace --wait --timeout 300s - name: Pre-pull required images if: steps.list-changed.outputs.changed == 'true' run: | echo "=== Pre-pulling required images to avoid timeout ===" KIND_CLUSTER=$(kubectl config current-context | sed 's/kind-//') echo "Kind cluster: $KIND_CLUSTER" IMAGES=( "ghcr.io/cloudnative-pg/cloudnative-pg:1.27.0" "quay.io/opstree/redis:v7.0.15" "docker.io/onyxdotapp/onyx-web-server:latest" ) for image in "${IMAGES[@]}"; do echo "Pre-pulling $image" if docker pull "$image"; then kind load docker-image "$image" --name "$KIND_CLUSTER" || echo "Failed to load $image into kind" else echo "Failed to pull $image" fi done echo "=== Images loaded into Kind cluster ===" docker exec "$KIND_CLUSTER"-control-plane crictl images | grep -E "(cloudnative-pg|redis|onyx)" || echo "Some images may still be loading..." - name: Validate chart dependencies if: steps.list-changed.outputs.changed == 'true' run: | echo "=== Validating chart dependencies ===" cd deployment/helm/charts/onyx helm dependency update helm lint . 
--set auth.userauth.values.user_auth_secret=placeholder - name: Run chart-testing (install) with enhanced monitoring timeout-minutes: 25 if: steps.list-changed.outputs.changed == 'true' run: | echo "=== Starting chart installation with monitoring ===" # Function to monitor cluster state monitor_cluster() { while true; do echo "=== Cluster Status Check at $(date) ===" # Only show non-running pods to reduce noise NON_RUNNING_PODS=$(kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded --no-headers 2>/dev/null | wc -l) if [ "$NON_RUNNING_PODS" -gt 0 ]; then echo "Non-running pods:" kubectl get pods --all-namespaces --field-selector=status.phase!=Running,status.phase!=Succeeded else echo "All pods running successfully" fi # Only show recent events if there are issues RECENT_EVENTS=$(kubectl get events --sort-by=.lastTimestamp --all-namespaces --field-selector=type!=Normal 2>/dev/null | tail -5) if [ -n "$RECENT_EVENTS" ]; then echo "Recent warnings/errors:" echo "$RECENT_EVENTS" fi sleep 60 done } # Start monitoring in background monitor_cluster & MONITOR_PID=$! # Set up cleanup cleanup() { echo "=== Cleaning up monitoring process ===" kill $MONITOR_PID 2>/dev/null || true echo "=== Final cluster state ===" kubectl get pods --all-namespaces kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -20 } # Trap cleanup on exit trap cleanup EXIT # Run the actual installation with detailed logging # Note that opensearch.enabled is true whereas others in this install # are false. There is some work that needs to be done to get this # entire step working in CI, enabling opensearch here is a small step # in that direction. If this is causing issues, disabling it in this # step should be ok in the short term. 
echo "=== Starting ct install ===" set +e ct install --all \ --helm-extra-set-args="\ --set=nginx.enabled=false \ --set=minio.enabled=false \ --set=vespa.enabled=false \ --set=opensearch.enabled=true \ --set=auth.opensearch.enabled=true \ --set=auth.userauth.values.user_auth_secret=test-secret \ --set=slackbot.enabled=false \ --set=postgresql.enabled=true \ --set=postgresql.cluster.storage.storageClass=standard \ --set=redis.enabled=true \ --set=redis.storageSpec.volumeClaimTemplate.spec.storageClassName=standard \ --set=webserver.replicaCount=1 \ --set=api.replicaCount=0 \ --set=inferenceCapability.replicaCount=0 \ --set=indexCapability.replicaCount=0 \ --set=celery_beat.replicaCount=0 \ --set=celery_worker_heavy.replicaCount=0 \ --set=celery_worker_docfetching.replicaCount=0 \ --set=celery_worker_docprocessing.replicaCount=0 \ --set=celery_worker_light.replicaCount=0 \ --set=celery_worker_monitoring.replicaCount=0 \ --set=celery_worker_primary.replicaCount=0 \ --set=celery_worker_user_file_processing.replicaCount=0 \ --set=celery_worker_user_files_indexing.replicaCount=0" \ --helm-extra-args="--timeout 900s --debug" \ --debug --config ct.yaml CT_EXIT=$? set -e if [[ $CT_EXIT -ne 0 ]]; then echo "ct install failed with exit code $CT_EXIT" exit $CT_EXIT else echo "=== Installation completed successfully ===" fi kubectl get pods --all-namespaces - name: Post-install verification if: steps.list-changed.outputs.changed == 'true' run: | echo "=== Post-install verification ===" if ! kubectl cluster-info >/dev/null 2>&1; then echo "ERROR: Kubernetes cluster is not reachable after install" exit 1 fi kubectl get pods --all-namespaces kubectl get services --all-namespaces # Only show issues if they exist kubectl describe pods --all-namespaces | grep -A 5 -B 2 "Failed\|Error\|Warning" || echo "No pod issues found" - name: Cleanup on failure if: failure() && steps.list-changed.outputs.changed == 'true' run: | echo "=== Cleanup on failure ===" if ! 
kubectl cluster-info >/dev/null 2>&1; then echo "Skipping failure cleanup: Kubernetes cluster is not reachable" exit 0 fi echo "=== Final cluster state ===" kubectl get pods --all-namespaces kubectl get events --all-namespaces --sort-by=.lastTimestamp | tail -10 echo "=== Pod descriptions for debugging ===" kubectl describe pods --all-namespaces | grep -A 10 -B 3 "Failed\|Error\|Warning\|Pending" || echo "No problematic pods found" echo "=== Recent logs for debugging ===" kubectl logs --all-namespaces --tail=50 | grep -i "error\|timeout\|failed\|pull" || echo "No error logs found" echo "=== Helm releases ===" helm list --all-namespaces # the following would install only changed charts, but we only have one chart so # don't worry about that for now # run: ct install --target-branch ${{ github.event.repository.default_branch }} ================================================ FILE: .github/workflows/pr-integration-tests.yml ================================================ name: Run Integration Tests v2 concurrency: group: Run-Integration-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: merge_group: pull_request: branches: - main - "release/**" push: tags: - "v*.*.*" permissions: contents: read env: # Test Environment Variables OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} SLACK_BOT_TOKEN_TEST_SPACE: ${{ secrets.SLACK_BOT_TOKEN_TEST_SPACE }} CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }} CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }} CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }} CONFLUENCE_ACCESS_TOKEN_SCOPED: ${{ secrets.CONFLUENCE_ACCESS_TOKEN_SCOPED }} JIRA_BASE_URL: ${{ secrets.JIRA_BASE_URL }} JIRA_USER_EMAIL: ${{ secrets.JIRA_USER_EMAIL }} JIRA_API_TOKEN: ${{ secrets.JIRA_API_TOKEN }} JIRA_API_TOKEN_SCOPED: ${{ secrets.JIRA_API_TOKEN_SCOPED }} PERM_SYNC_SHAREPOINT_CLIENT_ID: ${{ 
secrets.PERM_SYNC_SHAREPOINT_CLIENT_ID }} PERM_SYNC_SHAREPOINT_PRIVATE_KEY: ${{ secrets.PERM_SYNC_SHAREPOINT_PRIVATE_KEY }} PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD: ${{ secrets.PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD }} PERM_SYNC_SHAREPOINT_DIRECTORY_ID: ${{ secrets.PERM_SYNC_SHAREPOINT_DIRECTORY_ID }} EXA_API_KEY: ${{ secrets.EXA_API_KEY }} GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN }} GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC: ${{ secrets.ONYX_GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC }} GITHUB_ADMIN_EMAIL: ${{ secrets.ONYX_GITHUB_ADMIN_EMAIL }} GITHUB_TEST_USER_1_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_1_EMAIL }} GITHUB_TEST_USER_2_EMAIL: ${{ secrets.ONYX_GITHUB_TEST_USER_2_EMAIL }} jobs: discover-test-dirs: # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here. runs-on: ubuntu-slim timeout-minutes: 45 outputs: test-dirs: ${{ steps.set-matrix.outputs.test-dirs }} editions: ${{ steps.set-editions.outputs.editions }} steps: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Discover test directories id: set-matrix run: | # Find all leaf-level directories in both test directories tests_dirs=$(find backend/tests/integration/tests -mindepth 1 -maxdepth 1 -type d ! -name "__pycache__" ! -name "mcp" ! -name "no_vectordb" -exec basename {} \; | sort) connector_dirs=$(find backend/tests/integration/connector_job_tests -mindepth 1 -maxdepth 1 -type d ! 
-name "__pycache__" -exec basename {} \; | sort) # Create JSON array with directory info all_dirs="" for dir in $tests_dirs; do all_dirs="$all_dirs{\"path\":\"tests/$dir\",\"name\":\"tests-$dir\"}," done for dir in $connector_dirs; do all_dirs="$all_dirs{\"path\":\"connector_job_tests/$dir\",\"name\":\"connector-$dir\"}," done # Remove trailing comma and wrap in array all_dirs="[${all_dirs%,}]" echo "test-dirs=$all_dirs" >> $GITHUB_OUTPUT - name: Determine editions to test id: set-editions run: | # On PRs, only run EE tests. On merge_group and tags, run both EE and MIT. if [ "${{ github.event_name }}" = "pull_request" ]; then echo 'editions=["ee"]' >> $GITHUB_OUTPUT else echo 'editions=["ee","mit"]' >> $GITHUB_OUTPUT fi build-backend-image: runs-on: [ runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Format branch name for cache id: format-branch env: PR_NUMBER: ${{ github.event.pull_request.number }} REF_NAME: ${{ github.ref_name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 # needed for pulling Vespa, Redis, Postgres, and Minio images # otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ 
secrets.DOCKER_TOKEN }} - name: Build and push Backend Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend file: ./backend/Dockerfile push: true tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-backend-test-${{ github.run_id }} cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache type=registry,ref=onyxdotapp/onyx-backend:latest cache-to: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} build-model-server-image: runs-on: [ runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Format branch name for cache id: format-branch env: PR_NUMBER: ${{ github.event.pull_request.number }} REF_NAME: ${{ github.ref_name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 # needed for pulling 
Vespa, Redis, Postgres, and Minio images # otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build and push Model Server Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend file: ./backend/Dockerfile.model_server push: true tags: ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-model-server-test-${{ github.run_id }} cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache type=registry,ref=onyxdotapp/onyx-model-server:latest cache-to: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max build-integration-image: runs-on: [ runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-build-integration-image", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 # needed for pulling 
openapitools/openapi-generator-cli # otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Format branch name for cache id: format-branch env: PR_NUMBER: ${{ github.event.pull_request.number }} REF_NAME: ${{ github.ref_name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT - name: Build and push integration test image with Docker Bake env: INTEGRATION_REPOSITORY: ${{ env.RUNS_ON_ECR_CACHE }} TAG: integration-test-${{ github.run_id }} CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }} HEAD_SHA: ${{ github.event.pull_request.head.sha || github.sha }} run: | docker buildx bake --push \ --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA} \ --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX} \ --set backend.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache \ --set backend.cache-from=type=registry,ref=onyxdotapp/onyx-backend:latest \ --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${HEAD_SHA},mode=max \ --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache-${CACHE_SUFFIX},mode=max \ --set backend.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:backend-cache,mode=max \ --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA} \ --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX} \ --set integration.cache-from=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache \ --set 
integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${HEAD_SHA},mode=max \ --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache-${CACHE_SUFFIX},mode=max \ --set integration.cache-to=type=registry,ref=${RUNS_ON_ECR_CACHE}:integration-cache,mode=max \ integration integration-tests: needs: [ discover-test-dirs, build-backend-image, build-model-server-image, build-integration-image, ] runs-on: - runs-on - runner=4cpu-linux-arm64 - ${{ format('run-id={0}-integration-tests-{1}-job-{2}', github.run_id, matrix.edition, strategy['job-index']) }} - extras=ecr-cache timeout-minutes: 45 strategy: fail-fast: false matrix: test-dir: ${{ fromJson(needs.discover-test-dirs.outputs.test-dirs) }} edition: ${{ fromJson(needs.discover-test-dirs.outputs.editions) }} steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false # needed for pulling Vespa, Redis, Postgres, and Minio images # otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} # NOTE: Use pre-ping/null pool to reduce flakiness due to dropped connections # NOTE: don't need web server for integration tests - name: Create .env file for Docker Compose env: ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }} RUN_ID: ${{ github.run_id }} EDITION: ${{ matrix.edition }} run: | # Base config shared by both editions cat <<EOF > deployment/docker_compose/.env COMPOSE_PROFILES=s3-filestore OPENSEARCH_FOR_ONYX_ENABLED=false AUTH_TYPE=basic POSTGRES_POOL_PRE_PING=true POSTGRES_USE_NULL_POOL=true REQUIRE_EMAIL_VERIFICATION=false DISABLE_TELEMETRY=true
ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} INTEGRATION_TESTS_MODE=true MCP_SERVER_ENABLED=true AUTO_LLM_UPDATE_INTERVAL_SECONDS=10 EOF # EE-only config if [ "$EDITION" = "ee" ]; then cat <<EOF >> deployment/docker_compose/.env ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license LICENSE_ENFORCEMENT_ENABLED=false CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS=0.001 EOF fi - name: Start Docker containers run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.dev.yml up \ relational_db \ index \ cache \ minio \ api_server \ inference_model_server \ indexing_model_server \ background \ -d id: start_docker - name: Wait for services to be ready run: | echo "Starting wait-for-service script..." wait_for_service() { local url=$1 local label=$2 local timeout=${3:-300} # default 5 minutes local start_time start_time=$(date +%s) while true; do local current_time current_time=$(date +%s) local elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. ${label} did not become ready in $timeout seconds." exit 1 fi local response response=$(curl -s -o /dev/null -w "%{http_code}" "$url" || echo "curl_error") if [ "$response" = "200" ]; then echo "${label} is ready!" break elif [ "$response" = "curl_error" ]; then echo "Curl encountered an error while checking ${label}. Retrying in 5 seconds..." else echo "${label} not ready yet (HTTP status $response). Retrying in 5 seconds..." fi sleep 5 done } wait_for_service "http://localhost:8080/health" "API server" echo "Finished waiting for services."
- name: Start Mock Services run: | cd backend/tests/integration/mock_services docker compose -f docker-compose.mock-it-services.yml \ -p mock-it-services-stack up -d - name: Run Integration Tests (${{ matrix.edition }}) for ${{ matrix.test-dir.name }} uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3 with: timeout_minutes: 20 max_attempts: 3 retry_wait_seconds: 10 command: | echo "Running ${{ matrix.edition }} integration tests for ${{ matrix.test-dir.path }}..." docker run --rm --network onyx_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ -e POSTGRES_USER=postgres \ -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ -e DB_READONLY_USER=db_readonly_user \ -e DB_READONLY_PASSWORD=password \ -e POSTGRES_POOL_PRE_PING=true \ -e POSTGRES_USE_NULL_POOL=true \ -e VESPA_HOST=index \ -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \ -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ -e EXA_API_KEY=${EXA_API_KEY} \ -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \ -e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \ -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \ -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \ -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \ -e CONFLUENCE_ACCESS_TOKEN_SCOPED=${CONFLUENCE_ACCESS_TOKEN_SCOPED} \ -e JIRA_BASE_URL=${JIRA_BASE_URL} \ -e JIRA_USER_EMAIL=${JIRA_USER_EMAIL} \ -e JIRA_API_TOKEN=${JIRA_API_TOKEN} \ -e JIRA_API_TOKEN_SCOPED=${JIRA_API_TOKEN_SCOPED} \ -e PERM_SYNC_SHAREPOINT_CLIENT_ID=${PERM_SYNC_SHAREPOINT_CLIENT_ID} \ -e PERM_SYNC_SHAREPOINT_PRIVATE_KEY="${PERM_SYNC_SHAREPOINT_PRIVATE_KEY}" \ -e PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD=${PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD} \ -e PERM_SYNC_SHAREPOINT_DIRECTORY_ID=${PERM_SYNC_SHAREPOINT_DIRECTORY_ID} \ -e GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN} \ -e 
GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC=${GITHUB_PERMISSION_SYNC_TEST_ACCESS_TOKEN_CLASSIC} \ -e GITHUB_ADMIN_EMAIL=${GITHUB_ADMIN_EMAIL} \ -e GITHUB_TEST_USER_1_EMAIL=${GITHUB_TEST_USER_1_EMAIL} \ -e GITHUB_TEST_USER_2_EMAIL=${GITHUB_TEST_USER_2_EMAIL} \ -e TEST_WEB_HOSTNAME=test-runner \ -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \ -e MOCK_CONNECTOR_SERVER_PORT=8001 \ -e ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=${{ matrix.edition == 'ee' && 'true' || 'false' }} \ ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \ /app/tests/integration/${{ matrix.test-dir.path }} # ------------------------------------------------------------ # Always gather logs BEFORE "down": - name: Dump API server logs if: always() run: | cd deployment/docker_compose docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true - name: Dump all-container logs (optional) if: always() run: | cd deployment/docker_compose docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true - name: Upload logs if: always() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: docker-all-logs-${{ matrix.edition }}-${{ matrix.test-dir.name }} path: ${{ github.workspace }}/docker-compose.log # ------------------------------------------------------------ onyx-lite-tests: needs: [build-backend-image, build-integration-image] runs-on: [ runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-onyx-lite-tests", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} 
- name: Create .env file for Onyx Lite Docker Compose env: ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }} RUN_ID: ${{ github.run_id }} run: | cat <<EOF > deployment/docker_compose/.env ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true LICENSE_ENFORCEMENT_ENABLED=false AUTH_TYPE=basic POSTGRES_POOL_PRE_PING=true POSTGRES_USE_NULL_POOL=true REQUIRE_EMAIL_VERIFICATION=false DISABLE_TELEMETRY=true ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} INTEGRATION_TESTS_MODE=true EOF # Start only the services needed for Onyx Lite (Postgres + API server) - name: Start Docker containers (onyx-lite) run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up \ relational_db \ api_server \ -d id: start_docker_onyx_lite - name: Wait for services to be ready run: | echo "Starting wait-for-service script (onyx-lite)..." start_time=$(date +%s) timeout=300 while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. Service did not become ready in $timeout seconds." exit 1 fi response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error") if [ "$response" = "200" ]; then echo "API server is ready!" break elif [ "$response" = "curl_error" ]; then echo "Curl encountered an error; retrying..." else echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..." fi sleep 5 done - name: Run Onyx Lite Integration Tests uses: nick-fields/retry@ce71cc2ab81d554ebbe88c79ab5975992d79ba08 # ratchet:nick-fields/retry@v3 with: timeout_minutes: 20 max_attempts: 3 retry_wait_seconds: 10 command: | echo "Running onyx-lite integration tests..."
docker run --rm --network onyx_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ -e POSTGRES_USER=postgres \ -e POSTGRES_PASSWORD=password \ -e POSTGRES_DB=postgres \ -e DB_READONLY_USER=db_readonly_user \ -e DB_READONLY_PASSWORD=password \ -e POSTGRES_POOL_PRE_PING=true \ -e POSTGRES_USE_NULL_POOL=true \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ -e TEST_WEB_HOSTNAME=test-runner \ ${{ env.RUNS_ON_ECR_CACHE }}:integration-test-${{ github.run_id }} \ /app/tests/integration/tests/no_vectordb - name: Dump API server logs (onyx-lite) if: always() run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \ logs --no-color api_server > $GITHUB_WORKSPACE/api_server_onyx_lite.log || true - name: Dump all-container logs (onyx-lite) if: always() run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml \ logs --no-color > $GITHUB_WORKSPACE/docker-compose-onyx-lite.log || true - name: Upload logs (onyx-lite) if: always() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: docker-all-logs-onyx-lite path: ${{ github.workspace }}/docker-compose-onyx-lite.log - name: Stop Docker containers (onyx-lite) if: always() run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml down -v multitenant-tests: needs: [build-backend-image, build-model-server-image, build-integration-image] runs-on: [ runs-on, runner=8cpu-linux-arm64, "run-id=${{ github.run_id }}-multitenant-tests", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Login to Docker Hub uses: 
docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Start Docker containers for multi-tenant tests env: ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }} RUN_ID: ${{ github.run_id }} run: | cd deployment/docker_compose ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \ LICENSE_ENFORCEMENT_ENABLED=false \ MULTI_TENANT=true \ AUTH_TYPE=cloud \ REQUIRE_EMAIL_VERIFICATION=false \ DISABLE_TELEMETRY=true \ OPENAI_DEFAULT_API_KEY=${OPENAI_API_KEY} \ ONYX_BACKEND_IMAGE=${ECR_CACHE}:integration-test-backend-test-${RUN_ID} \ ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:integration-test-model-server-test-${RUN_ID} \ DEV_MODE=true \ OPENSEARCH_FOR_ONYX_ENABLED=false \ docker compose -f docker-compose.multitenant-dev.yml up \ relational_db \ index \ cache \ minio \ api_server \ inference_model_server \ indexing_model_server \ background \ -d id: start_docker_multi_tenant - name: Wait for service to be ready (multi-tenant) run: | echo "Starting wait-for-service script for multi-tenant..." docker logs -f onyx-api_server-1 & start_time=$(date +%s) timeout=300 while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. Service did not become ready in 5 minutes." exit 1 fi response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error") if [ "$response" = "200" ]; then echo "Service is ready!" break elif [ "$response" = "curl_error" ]; then echo "Curl encountered an error; retrying..." else echo "Service not ready yet (HTTP $response). Retrying in 5 seconds..." fi sleep 5 done echo "Finished waiting for service." - name: Run Multi-Tenant Integration Tests env: ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }} RUN_ID: ${{ github.run_id }} run: | echo "Running multi-tenant integration tests..." 
docker run --rm --network onyx_default \ --name test-runner \ -e POSTGRES_HOST=relational_db \ -e POSTGRES_USER=postgres \ -e POSTGRES_PASSWORD=password \ -e DB_READONLY_USER=db_readonly_user \ -e DB_READONLY_PASSWORD=password \ -e POSTGRES_DB=postgres \ -e POSTGRES_USE_NULL_POOL=true \ -e VESPA_HOST=index \ -e ENABLE_OPENSEARCH_INDEXING_FOR_ONYX=false \ -e REDIS_HOST=cache \ -e API_SERVER_HOST=api_server \ -e OPENAI_API_KEY=${OPENAI_API_KEY} \ -e EXA_API_KEY=${EXA_API_KEY} \ -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \ -e SLACK_BOT_TOKEN_TEST_SPACE=${SLACK_BOT_TOKEN_TEST_SPACE} \ -e TEST_WEB_HOSTNAME=test-runner \ -e AUTH_TYPE=cloud \ -e MULTI_TENANT=true \ -e SKIP_RESET=true \ -e REQUIRE_EMAIL_VERIFICATION=false \ -e DISABLE_TELEMETRY=true \ -e DEV_MODE=true \ ${ECR_CACHE}:integration-test-${RUN_ID} \ /app/tests/integration/multitenant_tests - name: Dump API server logs (multi-tenant) if: always() run: | cd deployment/docker_compose docker compose -f docker-compose.multitenant-dev.yml logs --no-color api_server > $GITHUB_WORKSPACE/api_server_multitenant.log || true - name: Dump all-container logs (multi-tenant) if: always() run: | cd deployment/docker_compose docker compose -f docker-compose.multitenant-dev.yml logs --no-color > $GITHUB_WORKSPACE/docker-compose-multitenant.log || true - name: Upload logs (multi-tenant) if: always() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: docker-all-logs-multitenant path: ${{ github.workspace }}/docker-compose-multitenant.log - name: Stop multi-tenant Docker containers if: always() run: | cd deployment/docker_compose docker compose -f docker-compose.multitenant-dev.yml down -v required: # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here. 
runs-on: ubuntu-slim timeout-minutes: 45 needs: [integration-tests, onyx-lite-tests, multitenant-tests] if: ${{ always() }} steps: - name: Check job status if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }} run: exit 1 ================================================ FILE: .github/workflows/pr-jest-tests.yml ================================================ name: Run Jest Tests concurrency: group: Run-Jest-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: merge_group: pull_request: branches: - main - "release/**" push: tags: - "v*.*.*" permissions: contents: read jobs: jest-tests: name: Jest Tests runs-on: ubuntu-latest timeout-minutes: 45 steps: - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Setup node uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4 with: node-version: 22 cache: "npm" # zizmor: ignore[cache-poisoning] test-only workflow; no deploy artifacts cache-dependency-path: ./web/package-lock.json - name: Install node dependencies working-directory: ./web run: npm ci - name: Run Jest tests working-directory: ./web run: npm test -- --ci --coverage --maxWorkers=50% - name: Upload coverage reports if: always() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: jest-coverage-${{ github.run_id }} path: ./web/coverage retention-days: 7 ================================================ FILE: .github/workflows/pr-labeler.yml ================================================ name: PR Labeler on: pull_request: branches: - main types: - opened - reopened - synchronize - edited permissions: contents: read jobs: validate_pr_title: runs-on: ubuntu-latest timeout-minutes: 45 steps: - name: Check PR title for Conventional 
Commits env: PR_TITLE: ${{ github.event.pull_request.title }} run: | echo "PR Title: $PR_TITLE" if [[ ! "$PR_TITLE" =~ ^(feat|fix|docs|test|ci|refactor|perf|chore|revert|build)(\(.+\))?:\ .+ ]]; then echo "::error::❌ Your PR title does not follow the Conventional Commits format. This check ensures that all pull requests use clear, consistent titles that help automate changelogs and improve project history. Please update your PR title to follow the Conventional Commits style. Here is a link to a blog explaining the reason why we've included the Conventional Commits style into our PR titles: https://xfuture-blog.com/working-with-conventional-commits **Here are some examples of valid PR titles:** - feat: add user authentication - fix(login): handle null password error - docs(readme): update installation instructions" exit 1 fi ================================================ FILE: .github/workflows/pr-linear-check.yml ================================================ name: Ensure PR references Linear concurrency: group: Ensure-PR-references-Linear-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: pull_request: types: [opened, edited, reopened, synchronize] permissions: contents: read jobs: linear-check: runs-on: ubuntu-latest timeout-minutes: 45 steps: - name: Check PR body for Linear link or override env: PR_BODY: ${{ github.event.pull_request.body }} run: | # Looking for "https://linear.app" in the body if echo "$PR_BODY" | grep -qE "https://linear\.app"; then echo "Found a Linear link. Check passed." exit 0 fi # Looking for a checked override: "[x] Override Linear Check" if echo "$PR_BODY" | grep -q "\[x\].*Override Linear Check"; then echo "Override box is checked. Check passed." exit 0 fi # Otherwise, fail the run echo "No Linear link or override found in the PR description." 
exit 1 ================================================ FILE: .github/workflows/pr-playwright-tests.yml ================================================ name: Run Playwright Tests concurrency: group: Run-Playwright-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: merge_group: pull_request: branches: - main - "release/**" push: tags: - "v*.*.*" # TODO: Remove this if we enable merge-queues for release branches. branches: - "release/**" permissions: contents: read env: # Test Environment Variables OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }} SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }} GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }} EXA_API_KEY: ${{ secrets.EXA_API_KEY }} FIRECRAWL_API_KEY: ${{ secrets.FIRECRAWL_API_KEY }} GOOGLE_PSE_API_KEY: ${{ secrets.GOOGLE_PSE_API_KEY }} GOOGLE_PSE_SEARCH_ENGINE_ID: ${{ secrets.GOOGLE_PSE_SEARCH_ENGINE_ID }} # for federated slack tests SLACK_CLIENT_ID: ${{ secrets.SLACK_CLIENT_ID }} SLACK_CLIENT_SECRET: ${{ secrets.SLACK_CLIENT_SECRET }} # for MCP Oauth tests MCP_OAUTH_CLIENT_ID: ${{ secrets.MCP_OAUTH_CLIENT_ID }} MCP_OAUTH_CLIENT_SECRET: ${{ secrets.MCP_OAUTH_CLIENT_SECRET }} MCP_OAUTH_ISSUER: ${{ secrets.MCP_OAUTH_ISSUER }} MCP_OAUTH_JWKS_URI: ${{ secrets.MCP_OAUTH_JWKS_URI }} MCP_OAUTH_USERNAME: ${{ vars.MCP_OAUTH_USERNAME }} MCP_OAUTH_PASSWORD: ${{ secrets.MCP_OAUTH_PASSWORD }} # for MCP API Key tests MCP_API_KEY: test-api-key-12345 MCP_API_KEY_TEST_PORT: 8005 MCP_API_KEY_TEST_URL: http://host.docker.internal:8005/mcp MCP_API_KEY_SERVER_HOST: 0.0.0.0 MCP_API_KEY_SERVER_PUBLIC_HOST: host.docker.internal MOCK_LLM_RESPONSE: true MCP_TEST_SERVER_PORT: 8004 MCP_TEST_SERVER_URL: http://host.docker.internal:8004/mcp MCP_TEST_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp MCP_TEST_SERVER_BIND_HOST: 0.0.0.0 MCP_TEST_SERVER_PUBLIC_HOST: host.docker.internal MCP_SERVER_HOST: 0.0.0.0 MCP_SERVER_PUBLIC_HOST: host.docker.internal 
MCP_SERVER_PUBLIC_URL: http://host.docker.internal:8004/mcp # Visual regression S3 bucket (shared across all jobs) PLAYWRIGHT_S3_BUCKET: onyx-playwright-artifacts jobs: build-web-image: runs-on: [ runs-on, runner=4cpu-linux-arm64, "run-id=${{ github.run_id }}-build-web-image", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Format branch name for cache id: format-branch env: PR_NUMBER: ${{ github.event.pull_request.number }} REF_NAME: ${{ github.ref_name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build and push Web Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./web file: ./web/Dockerfile platforms: linux/arm64 tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-web-${{ github.run_id }} push: true cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ github.event.pull_request.head.sha || github.sha }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ 
env.RUNS_ON_ECR_CACHE }}:web-cache type=registry,ref=onyxdotapp/onyx-web-server:latest cache-to: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:web-cache,mode=max no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} build-backend-image: runs-on: [ runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-backend-image", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Format branch name for cache id: format-branch env: PR_NUMBER: ${{ github.event.pull_request.number }} REF_NAME: ${{ github.ref_name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build and push Backend Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend file: ./backend/Dockerfile platforms: linux/arm64 tags: ${{ env.RUNS_ON_ECR_CACHE 
}}:playwright-test-backend-${{ github.run_id }} push: true cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache type=registry,ref=onyxdotapp/onyx-backend:latest cache-to: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:backend-cache,mode=max no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} build-model-server-image: runs-on: [ runs-on, runner=1cpu-linux-arm64, "run-id=${{ github.run_id }}-build-model-server-image", "extras=ecr-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Format branch name for cache id: format-branch env: PR_NUMBER: ${{ github.event.pull_request.number }} REF_NAME: ${{ github.ref_name }} run: | if [ -n "${PR_NUMBER}" ]; then CACHE_SUFFIX="${PR_NUMBER}" else # shellcheck disable=SC2001 CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g') fi echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ 
secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Build and push Model Server Docker image uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend file: ./backend/Dockerfile.model_server platforms: linux/arm64 tags: ${{ env.RUNS_ON_ECR_CACHE }}:playwright-test-model-server-${{ github.run_id }} push: true cache-from: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }} type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache type=registry,ref=onyxdotapp/onyx-model-server:latest cache-to: | type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' }} playwright-tests: needs: [build-web-image, build-backend-image, build-model-server-image] name: Playwright Tests (${{ matrix.project }}) permissions: id-token: write # Required for OIDC-based AWS credential exchange (S3 access) contents: read runs-on: - runs-on - runner=8cpu-linux-arm64 - "run-id=${{ github.run_id }}-playwright-tests-${{ matrix.project }}" - "extras=ecr-cache" - volume=50gb timeout-minutes: 45 strategy: fail-fast: false matrix: project: [admin, exclusive] steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Setup node # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts uses: 
actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4 with: node-version: 22 cache: "npm" # zizmor: ignore[cache-poisoning] cache-dependency-path: ./web/package-lock.json - name: Install node dependencies working-directory: ./web run: npm ci - name: Cache playwright cache # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4 with: path: ~/.cache/ms-playwright key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }} restore-keys: | ${{ runner.os }}-playwright-npm- - name: Install playwright browsers working-directory: ./web run: npx playwright install --with-deps - name: Create .env file for Docker Compose env: OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }} EXA_API_KEY_VALUE: ${{ env.EXA_API_KEY }} ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }} RUN_ID: ${{ github.run_id }} run: | cat <<EOF > deployment/docker_compose/.env COMPOSE_PROFILES=s3-filestore ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license LICENSE_ENFORCEMENT_ENABLED=false AUTH_TYPE=basic INTEGRATION_TESTS_MODE=true GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE} EXA_API_KEY=${EXA_API_KEY_VALUE} REQUIRE_EMAIL_VERIFICATION=false DISABLE_TELEMETRY=true ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID} ONYX_MODEL_SERVER_IMAGE=${ECR_CACHE}:playwright-test-model-server-${RUN_ID} ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID} EOF # needed for pulling Vespa, Redis, Postgres, and Minio images # otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Start Docker containers run: | cd
deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.mcp-oauth-test.yml -f docker-compose.mcp-api-key-test.yml up -d id: start_docker - name: Wait for service to be ready run: | echo "Starting wait-for-service script..." docker logs -f onyx-api_server-1 & start_time=$(date +%s) timeout=300 # 5 minutes in seconds while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. Service did not become ready in 5 minutes." exit 1 fi # Use curl with error handling to ignore specific exit code 56 response=$(curl -s -o /dev/null -w "%{http_code}" http://localhost:8080/health || echo "curl_error") if [ "$response" = "200" ]; then echo "Service is ready!" break elif [ "$response" = "curl_error" ]; then echo "Curl encountered an error, possibly exit code 56. Continuing to retry..." else echo "Service not ready yet (HTTP status $response). Retrying in 5 seconds..." fi sleep 5 done echo "Finished waiting for service." - name: Wait for MCP OAuth mock server run: | echo "Waiting for MCP OAuth mock server on port ${MCP_TEST_SERVER_PORT:-8004}..." start_time=$(date +%s) timeout=120 while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. MCP OAuth mock server did not become ready in ${timeout}s." exit 1 fi if curl -sf "http://localhost:${MCP_TEST_SERVER_PORT:-8004}/healthz" > /dev/null; then echo "MCP OAuth mock server is ready!" break fi sleep 3 done - name: Wait for MCP API Key mock server run: | echo "Waiting for MCP API Key mock server on port ${MCP_API_KEY_TEST_PORT:-8005}..." start_time=$(date +%s) timeout=120 while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. MCP API Key mock server did not become ready in ${timeout}s." 
exit 1 fi if curl -sf "http://localhost:${MCP_API_KEY_TEST_PORT:-8005}/healthz" > /dev/null; then echo "MCP API Key mock server is ready!" break fi sleep 3 done - name: Wait for web server to be ready run: | echo "Waiting for web server on port 3000..." start_time=$(date +%s) timeout=120 while true; do current_time=$(date +%s) elapsed_time=$((current_time - start_time)) if [ $elapsed_time -ge $timeout ]; then echo "Timeout reached. Web server did not become ready in ${timeout}s." exit 1 fi if curl -sf "http://localhost:3000/api/health" > /dev/null 2>&1 || \ curl -sf "http://localhost:3000/" > /dev/null 2>&1; then echo "Web server is ready!" break fi echo "Web server not ready yet. Retrying in 3 seconds..." sleep 3 done - name: Run Playwright tests working-directory: ./web env: PROJECT: ${{ matrix.project }} run: | npx playwright test --project ${PROJECT} - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f if: always() with: # Includes test results and trace.zip files name: playwright-test-results-${{ matrix.project }}-${{ github.run_id }} path: ./web/output/playwright/ retention-days: 30 - name: Upload screenshots uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f if: always() with: name: playwright-screenshots-${{ matrix.project }}-${{ github.run_id }} path: ./web/output/screenshots/ retention-days: 30 # --- Visual Regression Diff --- - name: Configure AWS credentials if: always() uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Install the latest version of uv if: always() uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7 with: enable-cache: false version: "0.9.9" - name: Determine baseline revision if: always() id: baseline-rev env: EVENT_NAME: ${{ github.event_name }} BASE_REF: ${{ github.event.pull_request.base.ref }} MERGE_GROUP_BASE_REF: ${{ 
github.event.merge_group.base_ref }} GH_REF: ${{ github.ref }} REF_NAME: ${{ github.ref_name }} run: | if [ "${EVENT_NAME}" = "pull_request" ]; then # PRs compare against the base branch (e.g. main, release/2.5) echo "rev=${BASE_REF}" >> "$GITHUB_OUTPUT" elif [ "${EVENT_NAME}" = "merge_group" ]; then # Merge queue compares against the target branch (e.g. refs/heads/main -> main) echo "rev=${MERGE_GROUP_BASE_REF#refs/heads/}" >> "$GITHUB_OUTPUT" elif [[ "${GH_REF}" == refs/tags/* ]]; then # Tag builds compare against the tag name echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT" else # Push builds (main, release/*) compare against the branch name echo "rev=${REF_NAME}" >> "$GITHUB_OUTPUT" fi - name: Generate screenshot diff report if: always() env: PROJECT: ${{ matrix.project }} PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }} BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }} run: | uv run --no-sync --with onyx-devtools ods screenshot-diff compare \ --project "${PROJECT}" \ --rev "${BASELINE_REV}" - name: Upload visual diff report to S3 if: always() env: PROJECT: ${{ matrix.project }} PR_NUMBER: ${{ github.event.pull_request.number }} RUN_ID: ${{ github.run_id }} run: | SUMMARY_FILE="web/output/screenshot-diff/${PROJECT}/summary.json" if [ ! -f "${SUMMARY_FILE}" ]; then echo "No summary file found — skipping S3 upload." exit 0 fi HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}") if [ "${HAS_DIFF}" != "true" ]; then echo "No visual differences for ${PROJECT} — skipping S3 upload." 
exit 0 fi aws s3 sync "web/output/screenshot-diff/${PROJECT}/" \ "s3://${PLAYWRIGHT_S3_BUCKET}/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/" - name: Upload visual diff summary uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f if: always() with: name: screenshot-diff-summary-${{ matrix.project }} path: ./web/output/screenshot-diff/${{ matrix.project }}/summary.json if-no-files-found: ignore retention-days: 5 - name: Upload visual diff report artifact uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f if: always() with: name: screenshot-diff-report-${{ matrix.project }}-${{ github.run_id }} path: ./web/output/screenshot-diff/${{ matrix.project }}/ if-no-files-found: ignore retention-days: 30 - name: Update S3 baselines if: >- success() && ( github.ref == 'refs/heads/main' || startsWith(github.ref, 'refs/heads/release/') || startsWith(github.ref, 'refs/tags/v') || ( github.event_name == 'merge_group' && ( github.event.merge_group.base_ref == 'refs/heads/main' || startsWith(github.event.merge_group.base_ref, 'refs/heads/release/') ) ) ) env: PROJECT: ${{ matrix.project }} PLAYWRIGHT_S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }} BASELINE_REV: ${{ steps.baseline-rev.outputs.rev }} run: | if [ -d "web/output/screenshots/" ] && [ "$(ls -A web/output/screenshots/)" ]; then uv run --no-sync --with onyx-devtools ods screenshot-diff upload-baselines \ --project "${PROJECT}" \ --rev "${BASELINE_REV}" \ --delete else echo "No screenshots to upload for ${PROJECT} — skipping baseline update." 
fi # save before stopping the containers so the logs can be captured - name: Save Docker logs if: success() || failure() env: WORKSPACE: ${{ github.workspace }} run: | cd deployment/docker_compose docker compose logs > docker-compose.log mv docker-compose.log ${WORKSPACE}/docker-compose.log - name: Upload logs if: success() || failure() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: docker-logs-${{ matrix.project }}-${{ github.run_id }} path: ${{ github.workspace }}/docker-compose.log playwright-tests-lite: needs: [build-web-image, build-backend-image] name: Playwright Tests (lite) runs-on: - runs-on - runner=4cpu-linux-arm64 - "run-id=${{ github.run_id }}-playwright-tests-lite" - "extras=ecr-cache" timeout-minutes: 30 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Setup node # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4 with: node-version: 22 cache: "npm" # zizmor: ignore[cache-poisoning] cache-dependency-path: ./web/package-lock.json - name: Install node dependencies working-directory: ./web run: npm ci - name: Cache playwright cache # zizmor: ignore[cache-poisoning] ephemeral runners; no release artifacts uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4 with: path: ~/.cache/ms-playwright key: ${{ runner.os }}-playwright-npm-${{ hashFiles('web/package-lock.json') }} restore-keys: | ${{ runner.os }}-playwright-npm- - name: Install playwright browsers working-directory: ./web run: npx playwright install --with-deps - name: Create .env file for Docker Compose env: OPENAI_API_KEY_VALUE: ${{ env.OPENAI_API_KEY }} ECR_CACHE: ${{ env.RUNS_ON_ECR_CACHE }} RUN_ID: ${{ 
github.run_id }} run: | cat <<EOF > deployment/docker_compose/.env ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true LICENSE_ENFORCEMENT_ENABLED=false AUTH_TYPE=basic INTEGRATION_TESTS_MODE=true GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE} MOCK_LLM_RESPONSE=true REQUIRE_EMAIL_VERIFICATION=false DISABLE_TELEMETRY=true ONYX_BACKEND_IMAGE=${ECR_CACHE}:playwright-test-backend-${RUN_ID} ONYX_WEB_SERVER_IMAGE=${ECR_CACHE}:playwright-test-web-${RUN_ID} EOF # needed for pulling external images otherwise, we hit the "Unauthenticated users" limit # https://docs.docker.com/docker-hub/usage/ - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ secrets.DOCKER_USERNAME }} password: ${{ secrets.DOCKER_TOKEN }} - name: Start Docker containers (lite) run: | cd deployment/docker_compose docker compose -f docker-compose.yml -f docker-compose.onyx-lite.yml -f docker-compose.dev.yml up -d id: start_docker - name: Run Playwright tests (lite) working-directory: ./web run: npx playwright test --project lite - uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f if: always() with: name: playwright-test-results-lite-${{ github.run_id }} path: ./web/output/playwright/ retention-days: 30 - name: Save Docker logs if: success() || failure() env: WORKSPACE: ${{ github.workspace }} run: | cd deployment/docker_compose docker compose logs > docker-compose.log mv docker-compose.log ${WORKSPACE}/docker-compose.log - name: Upload logs if: success() || failure() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: docker-logs-lite-${{ github.run_id }} path: ${{ github.workspace }}/docker-compose.log # Post a single combined visual regression comment after all matrix jobs finish visual-regression-comment: needs: [playwright-tests] if: >- always() && github.event_name == 'pull_request' && needs.playwright-tests.result != 'cancelled' runs-on: ubuntu-slim timeout-minutes: 5
permissions: pull-requests: write steps: - name: Download visual diff summaries uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 with: pattern: screenshot-diff-summary-* path: summaries/ - name: Post combined PR comment env: GH_TOKEN: ${{ github.token }} PR_NUMBER: ${{ github.event.pull_request.number }} RUN_ID: ${{ github.run_id }} REPO: ${{ github.repository }} S3_BUCKET: ${{ env.PLAYWRIGHT_S3_BUCKET }} run: | MARKER="" # Build the markdown table from all summary files TABLE_HEADER="| Project | Changed | Added | Removed | Unchanged | Report |" TABLE_DIVIDER="|---------|---------|-------|---------|-----------|--------|" TABLE_ROWS="" HAS_ANY_SUMMARY=false for SUMMARY_DIR in summaries/screenshot-diff-summary-*/; do SUMMARY_FILE="${SUMMARY_DIR}summary.json" if [ ! -f "${SUMMARY_FILE}" ]; then continue fi HAS_ANY_SUMMARY=true PROJECT=$(jq -r '.project' "${SUMMARY_FILE}") CHANGED=$(jq -r '.changed' "${SUMMARY_FILE}") ADDED=$(jq -r '.added' "${SUMMARY_FILE}") REMOVED=$(jq -r '.removed' "${SUMMARY_FILE}") UNCHANGED=$(jq -r '.unchanged' "${SUMMARY_FILE}") TOTAL=$(jq -r '.total' "${SUMMARY_FILE}") HAS_DIFF=$(jq -r '.has_differences' "${SUMMARY_FILE}") if [ "${TOTAL}" = "0" ]; then REPORT_LINK="_No screenshots_" elif [ "${HAS_DIFF}" = "true" ]; then REPORT_URL="https://${S3_BUCKET}.s3.us-east-2.amazonaws.com/reports/pr-${PR_NUMBER}/${RUN_ID}/${PROJECT}/index.html" REPORT_LINK="[View Report](${REPORT_URL})" else REPORT_LINK="✅ No changes" fi TABLE_ROWS="${TABLE_ROWS}| \`${PROJECT}\` | ${CHANGED} | ${ADDED} | ${REMOVED} | ${UNCHANGED} | ${REPORT_LINK} |\n" done if [ "${HAS_ANY_SUMMARY}" = "false" ]; then echo "No visual diff summaries found — skipping PR comment." 
exit 0 fi BODY=$(printf '%s\n' \ "${MARKER}" \ "### 🖼️ Visual Regression Report" \ "" \ "${TABLE_HEADER}" \ "${TABLE_DIVIDER}" \ "$(printf '%b' "${TABLE_ROWS}")") # Upsert: find existing comment with the marker, or create a new one EXISTING_COMMENT_ID=$(gh api \ "repos/${REPO}/issues/${PR_NUMBER}/comments" \ --jq ".[] | select(.body | startswith(\"${MARKER}\")) | .id" \ 2>/dev/null | head -1) if [ -n "${EXISTING_COMMENT_ID}" ]; then gh api \ --method PATCH \ "repos/${REPO}/issues/comments/${EXISTING_COMMENT_ID}" \ -f body="${BODY}" else gh api \ --method POST \ "repos/${REPO}/issues/${PR_NUMBER}/comments" \ -f body="${BODY}" fi playwright-required: # NOTE: Github-hosted runners have about 20s faster queue times and are preferred here. runs-on: ubuntu-slim timeout-minutes: 45 needs: [playwright-tests, playwright-tests-lite] if: ${{ always() }} steps: - name: Check job status if: ${{ contains(needs.*.result, 'failure') || contains(needs.*.result, 'cancelled') || contains(needs.*.result, 'skipped') }} run: exit 1 ================================================ FILE: .github/workflows/pr-python-checks.yml ================================================ name: Python Checks concurrency: group: Python-Checks-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }} cancel-in-progress: true on: merge_group: pull_request: branches: - main - "release/**" push: tags: - "v*.*.*" permissions: contents: read jobs: mypy-check: # See https://runs-on.com/runners/linux/ # Note: Mypy seems quite optimized for x64 compared to arm64. # Similarly, mypy is single-threaded and incremental, so 2cpu is sufficient. 
runs-on: [ runs-on, runner=2cpu-linux-x64, "run-id=${{ github.run_id }}-mypy-check", "extras=s3-cache", ] timeout-minutes: 45 steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Setup Python and Install Dependencies uses: ./.github/actions/setup-python-and-install-dependencies with: requirements: | backend/requirements/default.txt backend/requirements/dev.txt backend/requirements/model_server.txt backend/requirements/ee.txt - name: Generate OpenAPI schema and Python client shell: bash # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license env: LICENSE_ENFORCEMENT_ENABLED: "false" run: | ods openapi all - name: Cache mypy cache if: ${{ vars.DISABLE_MYPY_CACHE != 'true' }} uses: runs-on/cache@a5f51d6f3fece787d03b7b4e981c82538a0654ed # ratchet:runs-on/cache@v4 with: path: .mypy_cache key: mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}-${{ hashFiles('**/*.py', '**/*.pyi', 'pyproject.toml') }} restore-keys: | mypy-${{ runner.os }}-${{ github.base_ref || github.event.merge_group.base_ref || 'main' }}- mypy-${{ runner.os }}- - name: Run MyPy env: MYPY_FORCE_COLOR: 1 TERM: xterm-256color run: mypy . 

================================================
FILE: .github/workflows/pr-python-connector-tests.yml
================================================
name: Connector Tests

concurrency:
  group: Connector-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches: [main]
    paths:
      - "backend/**"
      - "pyproject.toml"
      - "uv.lock"
      - ".github/workflows/pr-python-connector-tests.yml"
      - ".github/actions/setup-python-and-install-dependencies/**"
      - ".github/actions/setup-playwright/**"
  push:
    tags:
      - "v*.*.*"
  schedule:
    # This cron expression runs the job daily at 16:00 UTC (9am PT)
    - cron: "0 16 * * *"

permissions:
  id-token: write # Required for OIDC-based AWS credential exchange
  contents: read

env:
  PYTHONPATH: ./backend
  DISABLE_TELEMETRY: "true"
  R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS: ${{ vars.R2_ACCOUNT_ID_DAILY_CONNECTOR_TESTS }}
  CONFLUENCE_TEST_SPACE_URL: ${{ vars.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_TEST_SPACE: ${{ vars.CONFLUENCE_TEST_SPACE }}
  CONFLUENCE_USER_NAME: ${{ vars.CONFLUENCE_USER_NAME }}
  SF_USERNAME: ${{ vars.SF_USERNAME }}
  IMAP_HOST: ${{ vars.IMAP_HOST }}
  IMAP_USERNAME: ${{ vars.IMAP_USERNAME }}
  IMAP_MAILBOXES: ${{ vars.IMAP_MAILBOXES }}
  AIRTABLE_TEST_BASE_ID: ${{ vars.AIRTABLE_TEST_BASE_ID }}
  AIRTABLE_TEST_TABLE_ID: ${{ vars.AIRTABLE_TEST_TABLE_ID }}
  AIRTABLE_TEST_TABLE_NAME: ${{ vars.AIRTABLE_TEST_TABLE_NAME }}
  SHAREPOINT_CLIENT_ID: ${{ vars.SHAREPOINT_CLIENT_ID }}
  SHAREPOINT_CLIENT_DIRECTORY_ID: ${{ vars.SHAREPOINT_CLIENT_DIRECTORY_ID }}
  SHAREPOINT_SITE: ${{ vars.SHAREPOINT_SITE }}
  BITBUCKET_EMAIL: ${{ vars.BITBUCKET_EMAIL }}

jobs:
  connectors-check:
    # See https://runs-on.com/runners/linux/
    runs-on:
      [
        runs-on,
        runner=8cpu-linux-x64,
        "run-id=${{ github.run_id }}-connectors-check",
        "extras=s3-cache",
      ]
    timeout-minutes: 45
    environment: ci-protected

    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt

      - name: Setup Playwright
        uses: ./.github/actions/setup-playwright

      - name: Detect Connector changes
        id: changes
        uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # ratchet:dorny/paths-filter@v3
        with:
          filters: |
            hubspot:
              - 'backend/onyx/connectors/hubspot/**'
              - 'backend/tests/daily/connectors/hubspot/**'
              - 'uv.lock'
            salesforce:
              - 'backend/onyx/connectors/salesforce/**'
              - 'backend/tests/daily/connectors/salesforce/**'
              - 'uv.lock'
            github:
              - 'backend/onyx/connectors/github/**'
              - 'backend/tests/daily/connectors/github/**'
              - 'uv.lock'
            file_processing:
              - 'backend/onyx/file_processing/**'
              - 'uv.lock'

      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v4
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2

      - name: Get connector test secrets from AWS Secrets Manager
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2
        with:
          parse-json-secrets: false
          secret-ids: |
            AWS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/aws-access-key-id
            AWS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/aws-secret-access-key
            R2_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/r2-access-key-id
            R2_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/r2-secret-access-key
            GCS_ACCESS_KEY_ID_DAILY_CONNECTOR_TESTS, test/gcs-access-key-id
            GCS_SECRET_ACCESS_KEY_DAILY_CONNECTOR_TESTS, test/gcs-secret-access-key
            CONFLUENCE_ACCESS_TOKEN, test/confluence-access-token
            CONFLUENCE_ACCESS_TOKEN_SCOPED, test/confluence-access-token-scoped
            JIRA_BASE_URL, test/jira-base-url
            JIRA_USER_EMAIL, test/jira-user-email
            JIRA_API_TOKEN, test/jira-api-token
            JIRA_API_TOKEN_SCOPED, test/jira-api-token-scoped
            GONG_ACCESS_KEY, test/gong-access-key
            GONG_ACCESS_KEY_SECRET, test/gong-access-key-secret
            GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR, test/google-drive-service-account-json
            GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR_TEST_USER_1, test/google-drive-oauth-creds-test-user-1
            GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR, test/google-drive-oauth-creds
            GOOGLE_GMAIL_SERVICE_ACCOUNT_JSON_STR, test/google-gmail-service-account-json
            GOOGLE_GMAIL_OAUTH_CREDENTIALS_JSON_STR, test/google-gmail-oauth-creds
            SLAB_BOT_TOKEN, test/slab-bot-token
            ZENDESK_SUBDOMAIN, test/zendesk-subdomain
            ZENDESK_EMAIL, test/zendesk-email
            ZENDESK_TOKEN, test/zendesk-token
            SF_PASSWORD, test/sf-password
            SF_SECURITY_TOKEN, test/sf-security-token
            HUBSPOT_ACCESS_TOKEN, test/hubspot-access-token
            IMAP_PASSWORD, test/imap-password
            AIRTABLE_ACCESS_TOKEN, test/airtable-access-token
            SHAREPOINT_CLIENT_SECRET, test/sharepoint-client-secret
            PERM_SYNC_SHAREPOINT_CLIENT_ID, test/perm-sync-sharepoint-client-id
            PERM_SYNC_SHAREPOINT_PRIVATE_KEY, test/perm-sync-sharepoint-private-key
            PERM_SYNC_SHAREPOINT_CERTIFICATE_PASSWORD, test/perm-sync-sharepoint-cert-password
            PERM_SYNC_SHAREPOINT_DIRECTORY_ID, test/perm-sync-sharepoint-directory-id
            ACCESS_TOKEN_GITHUB, test/github-access-token
            GITLAB_ACCESS_TOKEN, test/gitlab-access-token
            GITBOOK_SPACE_ID, test/gitbook-space-id
            GITBOOK_API_KEY, test/gitbook-api-key
            NOTION_INTEGRATION_TOKEN, test/notion-integration-token
            HIGHSPOT_KEY, test/highspot-key
            HIGHSPOT_SECRET, test/highspot-secret
            SLACK_BOT_TOKEN, test/slack-bot-token
            DISCORD_CONNECTOR_BOT_TOKEN, test/discord-bot-token
            TEAMS_APPLICATION_ID, test/teams-application-id
            TEAMS_DIRECTORY_ID, test/teams-directory-id
            TEAMS_SECRET, test/teams-secret
            BITBUCKET_WORKSPACE, test/bitbucket-workspace
            BITBUCKET_API_TOKEN, test/bitbucket-api-token
            FIREFLIES_API_KEY, test/fireflies-api-key

      - name: Run Tests (excluding HubSpot, Salesforce, GitHub, and Coda)
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test \
            -n 8 \
            --dist loadfile \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/daily/connectors \
            --ignore backend/tests/daily/connectors/hubspot \
            --ignore backend/tests/daily/connectors/salesforce \
            --ignore backend/tests/daily/connectors/github \
            --ignore backend/tests/daily/connectors/coda

      - name: Run HubSpot Connector Tests
        if: ${{ github.event_name == 'schedule' || steps.changes.outputs.hubspot == 'true' || steps.changes.outputs.file_processing == 'true' }}
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test \
            -n 8 \
            --dist loadfile \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/daily/connectors/hubspot

      - name: Run Salesforce Connector Tests
        if: ${{ github.event_name == 'schedule' || steps.changes.outputs.salesforce == 'true' || steps.changes.outputs.file_processing == 'true' }}
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test \
            -n 8 \
            --dist loadfile \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/daily/connectors/salesforce

      - name: Run GitHub Connector Tests
        if: ${{ github.event_name == 'schedule' || steps.changes.outputs.github == 'true' || steps.changes.outputs.file_processing == 'true' }}
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: |
          py.test \
            -n 8 \
            --dist loadfile \
            --durations=8 \
            -o junit_family=xunit2 \
            -xv \
            --ff \
            backend/tests/daily/connectors/github

      - name: Alert on Failure
        if: failure() && github.event_name == 'schedule'
        env:
          SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
          REPO: ${{ github.repository }}
          RUN_ID: ${{ github.run_id }}
        run: |
          curl -X POST \
            -H 'Content-type: application/json' \
            --data "{\"text\":\"Scheduled Connector Tests failed! Check the run at: https://github.com/${REPO}/actions/runs/${RUN_ID}\"}" \
            $SLACK_WEBHOOK

================================================
FILE: .github/workflows/pr-python-model-tests.yml
================================================
name: Model Server Tests

on:
  schedule:
    # This cron expression runs the job daily at 16:00 UTC (9am PT)
    - cron: "0 16 * * *"
  workflow_dispatch:

permissions:
  contents: read

env:
  # Bedrock
  AWS_ACCESS_KEY_ID: ${{ secrets.AWS_ACCESS_KEY_ID }}
  AWS_SECRET_ACCESS_KEY: ${{ secrets.AWS_SECRET_ACCESS_KEY }}
  AWS_REGION_NAME: ${{ vars.AWS_REGION_NAME }}

  # API keys for testing
  COHERE_API_KEY: ${{ secrets.COHERE_API_KEY }}
  LITELLM_API_KEY: ${{ secrets.LITELLM_API_KEY }}
  LITELLM_API_URL: ${{ secrets.LITELLM_API_URL }}
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  AZURE_API_KEY: ${{ secrets.AZURE_API_KEY }}
  AZURE_API_URL: ${{ vars.AZURE_API_URL }}

jobs:
  model-check:
    # See https://runs-on.com/runners/linux/
    runs-on:
      - runs-on
      - runner=4cpu-linux-arm64
      - "run-id=${{ github.run_id }}-model-check"
      - "extras=ecr-cache"
    environment: ci-protected
    timeout-minutes: 45

    env:
      PYTHONPATH: ./backend

    steps:
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt

      - name: Format branch name for cache
        id: format-branch
        env:
          PR_NUMBER: ${{ github.event.pull_request.number }}
          REF_NAME: ${{ github.ref_name }}
        run: |
          if [ -n "${PR_NUMBER}" ]; then
            CACHE_SUFFIX="${PR_NUMBER}"
          else
            # shellcheck disable=SC2001
            CACHE_SUFFIX=$(echo "${REF_NAME}" | sed 's/[^A-Za-z0-9._-]/-/g')
          fi
          echo "cache-suffix=${CACHE_SUFFIX}" >> $GITHUB_OUTPUT

      - name: Login to Docker Hub
        uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f

      - name: Build and load
        uses: docker/bake-action@82490499d2e5613fcead7e128237ef0b0ea210f7 # ratchet:docker/bake-action@v7.0.0
        env:
          TAG: model-server-${{ github.run_id }}
        with:
          load: true
          targets: model-server
          set: |
            model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }}
            model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }}
            model-server.cache-from=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache
            model-server.cache-from=type=registry,ref=onyxdotapp/onyx-model-server:latest
            model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ github.event.pull_request.head.sha || github.sha }},mode=max
            model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache-${{ steps.format-branch.outputs.cache-suffix }},mode=max
            model-server.cache-to=type=registry,ref=${{ env.RUNS_ON_ECR_CACHE }}:model-server-cache,mode=max

      - name: Start Docker containers
        id: start_docker
        env:
          IMAGE_TAG: model-server-${{ github.run_id }}
        run: |
          cd deployment/docker_compose
          docker compose \
            -f docker-compose.yml \
            -f docker-compose.dev.yml \
            up -d --wait \
            inference_model_server

      - name: Run Tests
        run: |
          py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/llm
          py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/embedding

      - name: Alert on Failure
        if: failure() && github.event_name == 'schedule'
        uses: ./.github/actions/slack-notify
        with:
          webhook-url: ${{ secrets.SLACK_WEBHOOK }}
          failed-jobs: model-check
          title: "🚨 Scheduled Model Tests failed!"
          ref-name: ${{ github.ref_name }}

      - name: Dump all-container logs (optional)
        if: always()
        run: |
          cd deployment/docker_compose
          docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true

      - name: Upload logs
        if: always()
        uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f
        with:
          name: docker-all-logs
          path: ${{ github.workspace }}/docker-compose.log

================================================
FILE: .github/workflows/pr-python-tests.yml
================================================
name: Python Unit Tests

concurrency:
  group: Python-Unit-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - 'release/**'
  push:
    tags:
      - "v*.*.*"

permissions:
  contents: read

jobs:
  backend-check:
    # See https://runs-on.com/runners/linux/
    runs-on: [runs-on, runner=2cpu-linux-arm64, "run-id=${{ github.run_id }}-backend-check"]
    timeout-minutes: 45

    env:
      PYTHONPATH: ./backend
      REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
      DISABLE_TELEMETRY: "true"
      # TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
      LICENSE_ENFORCEMENT_ENABLED: "false"

    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false

      - name: Setup Python and Install Dependencies
        uses: ./.github/actions/setup-python-and-install-dependencies
        with:
          requirements: |
            backend/requirements/default.txt
            backend/requirements/dev.txt
            backend/requirements/model_server.txt
            backend/requirements/ee.txt

      - name: Run Tests
        shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
        run: py.test -o junit_family=xunit2 -xv --ff backend/tests/unit

================================================
FILE: .github/workflows/pr-quality-checks.yml
================================================
name: Quality Checks PR

concurrency:
  group: Quality-Checks-PR-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request: null
  push:
    branches:
      - main
    tags:
      - "v*.*.*"

permissions:
  contents: read

jobs:
  quality-checks:
    runs-on: ubuntu-latest
    timeout-minutes: 45
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          fetch-depth: 0
          persist-credentials: false

      - uses: actions/setup-python@a309ff8b426b58ec0e2a45f0f869d46889d02405 # ratchet:actions/setup-python@v6
        with:
          python-version: "3.11"

      - name: Setup Terraform
        uses: hashicorp/setup-terraform@5e8dbf3c6d9deaf4193ca7a8fb23f2ac83bb6c85 # ratchet:hashicorp/setup-terraform@v4.0.0

      - name: Setup node
        uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v6
        with:
          # zizmor: ignore[cache-poisoning]
          node-version: 22
          cache: "npm"
          cache-dependency-path: ./web/package-lock.json

      - name: Install node dependencies
        working-directory: ./web
        run: npm ci

      - uses: j178/prek-action@0bb87d7f00b0c99306c8bcb8b8beba1eb581c037 # ratchet:j178/prek-action@v1
        with:
          prek-version: '0.3.4'
          extra-args: ${{ github.event_name == 'pull_request' && format('--from-ref {0} --to-ref {1}', github.event.pull_request.base.sha, github.event.pull_request.head.sha) || github.event_name == 'merge_group' && format('--from-ref {0} --to-ref {1}', github.event.merge_group.base_sha, github.event.merge_group.head_sha) || github.ref_name == 'main' && '--all-files' || '' }}

      - name: Check Actions
        uses: giner/check-actions@28d366c7cbbe235f9624a88aa31a628167eee28c # ratchet:giner/check-actions@v1.0.1
        with:
          check_permissions: false
          check_versions: false

================================================
FILE: .github/workflows/preview.yml
================================================
name: Preview Deployment

env:
  VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }}
  VERCEL_PROJECT_ID: ${{ secrets.VERCEL_PROJECT_ID }}
  VERCEL_CLI: vercel@50.14.1

on:
  push:
    branches-ignore:
      - main
    paths:
      - "web/**"

permissions:
  contents: read
  pull-requests: write

jobs:
  Deploy-Preview:
    runs-on: ubuntu-latest
    timeout-minutes: 30
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd
        with:
          persist-credentials: false

      - name: Setup node
        uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4
        with:
          node-version: 22
          cache: "npm"
          cache-dependency-path: ./web/package-lock.json

      - name: Pull Vercel Environment Information
        run: npx --yes ${{ env.VERCEL_CLI }} pull --yes --environment=preview --token=${{ secrets.VERCEL_TOKEN }}

      - name: Build Project Artifacts
        run: npx --yes ${{ env.VERCEL_CLI }} build --token=${{ secrets.VERCEL_TOKEN }}

      - name: Deploy Project Artifacts to Vercel
        id: deploy
        run: |
          DEPLOYMENT_URL=$(npx --yes ${{ env.VERCEL_CLI }} deploy --prebuilt --token=${{ secrets.VERCEL_TOKEN }})
          echo "url=$DEPLOYMENT_URL" >> "$GITHUB_OUTPUT"

      - name: Update PR comment with deployment URL
        if: always() && steps.deploy.outputs.url
        env:
          GH_TOKEN: ${{ github.token }}
          DEPLOYMENT_URL: ${{ steps.deploy.outputs.url }}
        run: |
          # Find the PR for this branch
          PR_NUMBER=$(gh pr list --head "$GITHUB_REF_NAME" --json number --jq '.[0].number')
          if [ -z "$PR_NUMBER" ]; then
            echo "No open PR found for branch $GITHUB_REF_NAME, skipping comment."
            exit 0
          fi

          COMMENT_MARKER=""
          COMMENT_BODY="$COMMENT_MARKER
          **Preview Deployment**

          | Status | Preview | Commit | Updated |
          | --- | --- | --- | --- |
          | ✅ | $DEPLOYMENT_URL | \`${GITHUB_SHA::7}\` | $(date -u '+%Y-%m-%d %H:%M:%S UTC') |"

          # Find existing comment by marker
          EXISTING_COMMENT_ID=$(gh api "repos/$GITHUB_REPOSITORY/issues/$PR_NUMBER/comments" \
            --jq ".[] | select(.body | startswith(\"$COMMENT_MARKER\")) | .id" | head -1)

          if [ -n "$EXISTING_COMMENT_ID" ]; then
            gh api "repos/$GITHUB_REPOSITORY/issues/comments/$EXISTING_COMMENT_ID" \
              --method PATCH --field body="$COMMENT_BODY"
          else
            gh pr comment "$PR_NUMBER" --body "$COMMENT_BODY"
          fi

================================================
FILE: .github/workflows/release-cli.yml
================================================
name: Release CLI

on:
  push:
    tags:
      - "cli/v*.*.*"

jobs:
  pypi:
    runs-on: ubuntu-latest
    environment:
      name: release-cli
    permissions:
      id-token: write
    timeout-minutes: 10
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"
      - run: |
          for goos in linux windows darwin; do
            for goarch in amd64 arm64; do
              GOOS="$goos" GOARCH="$goarch" uv build --wheel
            done
          done
        working-directory: cli
      - run: uv publish
        working-directory: cli

  docker-amd64:
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-cli-amd64
      - extras=ecr-cache
    environment: deploy
    permissions:
      id-token: write
    timeout-minutes: 30
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-cli
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
      - name: Login to Docker Hub
        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}
      - name: Build and push AMD64
        id: build
        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
        with:
          context: ./cli
          file: ./cli/Dockerfile
          platforms: linux/amd64
          cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: type=inline
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

  docker-arm64:
    runs-on:
      - runs-on
      - runner=2cpu-linux-arm64
      - run-id=${{ github.run_id }}-cli-arm64
      - extras=ecr-cache
    environment: deploy
    permissions:
      id-token: write
    timeout-minutes: 30
    outputs:
      digest: ${{ steps.build.outputs.digest }}
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-cli
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
      - name: Login to Docker Hub
        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}
      - name: Build and push ARM64
        id: build
        uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
        with:
          context: ./cli
          file: ./cli/Dockerfile
          platforms: linux/arm64
          cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
          cache-to: type=inline
          outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true

  merge-docker:
    needs:
      - docker-amd64
      - docker-arm64
    runs-on:
      - runs-on
      - runner=2cpu-linux-x64
      - run-id=${{ github.run_id }}-cli-merge
    environment: deploy
    permissions:
      id-token: write
    timeout-minutes: 10
    env:
      REGISTRY_IMAGE: onyxdotapp/onyx-cli
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
        with:
          secret-ids: |
            DOCKER_USERNAME, deploy/docker-username
            DOCKER_TOKEN, deploy/docker-token
          parse-json-secrets: true
      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
      - name: Login to Docker Hub
        uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
        with:
          username: ${{ env.DOCKER_USERNAME }}
          password: ${{ env.DOCKER_TOKEN }}
      - name: Create and push manifest
        env:
          AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
          ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
          TAG: ${{ github.ref_name }}
        run: |
          SANITIZED_TAG="${TAG#cli/}"
          IMAGES=(
            "${REGISTRY_IMAGE}@${AMD64_DIGEST}"
            "${REGISTRY_IMAGE}@${ARM64_DIGEST}"
          )

          if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
            docker buildx imagetools create \
              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
              -t "${REGISTRY_IMAGE}:latest" \
              "${IMAGES[@]}"
          else
            docker buildx imagetools create \
              -t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
              "${IMAGES[@]}"
          fi

================================================
FILE: .github/workflows/release-devtools.yml
================================================
name: Release Devtools

on:
  push:
    tags:
      - "ods/v*.*.*"

jobs:
  pypi:
    runs-on: ubuntu-latest
    environment:
      name: release-devtools
    permissions:
      id-token: write
    timeout-minutes: 10
    strategy:
      matrix:
        os-arch:
          - { goos: "linux", goarch: "amd64" }
          - { goos: "linux", goarch: "arm64" }
          - { goos: "windows", goarch: "amd64" }
          - { goos: "windows", goarch: "arm64" }
          - { goos: "darwin", goarch: "amd64" }
          - { goos: "darwin", goarch: "arm64" }
    steps:
      - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7
        with:
          enable-cache: false
          version: "0.9.9"
      - run: |
          GOOS="${{ matrix.os-arch.goos }}" \
          GOARCH="${{ matrix.os-arch.goarch }}" \
          uv build --wheel
        working-directory: tools/ods
      - run: uv publish
        working-directory: tools/ods

================================================
FILE: .github/workflows/reusable-nightly-llm-provider-chat.yml
================================================
name: Reusable Nightly LLM Provider Chat Tests

on:
  workflow_call:
    inputs:
      openai_models:
        description: "Comma-separated models for openai"
        required: false
        default: ""
        type: string
      anthropic_models:
        description: "Comma-separated models for anthropic"
        required: false
        default: ""
        type: string
      bedrock_models:
        description: "Comma-separated models for bedrock"
        required: false
        default: ""
        type: string
      vertex_ai_models:
        description: "Comma-separated models for vertex_ai"
        required: false
        default: ""
        type: string
      azure_models:
        description: "Comma-separated models for azure"
        required: false
        default: ""
        type: string
      ollama_models:
        description: "Comma-separated models for ollama_chat"
        required: false
        default: ""
        type: string
      openrouter_models:
        description: "Comma-separated models for openrouter"
        required: false
        default: ""
        type: string
      azure_api_base:
        description: "API base for azure provider"
        required: false
        default: ""
        type: string
      strict:
        description: "Default NIGHTLY_LLM_STRICT passed to tests"
        required: false
        default: true
        type: boolean
    secrets:
      AWS_OIDC_ROLE_ARN:
        description: "AWS role ARN for OIDC auth"
        required: true

permissions:
  contents: read
  id-token: write

jobs:
  build-backend-image:
    runs-on:
      [
        runs-on,
        runner=1cpu-linux-arm64,
        "run-id=${{ github.run_id }}-build-backend-image",
        "extras=ecr-cache",
      ]
    timeout-minutes: 45
    environment: ci-protected
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, test/docker-username
            DOCKER_TOKEN, test/docker-token
      - name: Build backend image
        uses: ./.github/actions/build-backend-image
        with:
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          ref-name: ${{ github.ref_name }}
          pr-number: ${{ github.event.pull_request.number }}
          github-sha: ${{ github.sha }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ env.DOCKER_USERNAME }}
          docker-token: ${{ env.DOCKER_TOKEN }}
          docker-no-cache: ${{ vars.DOCKER_NO_CACHE == 'true' && 'true' || 'false' }}

  build-model-server-image:
    runs-on:
      [
        runs-on,
        runner=1cpu-linux-arm64,
        "run-id=${{ github.run_id }}-build-model-server-image",
        "extras=ecr-cache",
      ]
    timeout-minutes: 45
    environment: ci-protected
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, test/docker-username
            DOCKER_TOKEN, test/docker-token
      - name: Build model server image
        uses: ./.github/actions/build-model-server-image
        with:
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          ref-name: ${{ github.ref_name }}
          pr-number: ${{ github.event.pull_request.number }}
          github-sha: ${{ github.sha }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ env.DOCKER_USERNAME }}
          docker-token: ${{ env.DOCKER_TOKEN }}

  build-integration-image:
    runs-on:
      [
        runs-on,
        runner=2cpu-linux-arm64,
        "run-id=${{ github.run_id }}-build-integration-image",
        "extras=ecr-cache",
      ]
    timeout-minutes: 45
    environment: ci-protected
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          secret-ids: |
            DOCKER_USERNAME, test/docker-username
            DOCKER_TOKEN, test/docker-token
      - name: Build integration image
        uses: ./.github/actions/build-integration-image
        with:
          runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }}
          ref-name: ${{ github.ref_name }}
          pr-number: ${{ github.event.pull_request.number }}
          github-sha: ${{ github.sha }}
          run-id: ${{ github.run_id }}
          docker-username: ${{ env.DOCKER_USERNAME }}
          docker-token: ${{ env.DOCKER_TOKEN }}

  provider-chat-test:
    needs:
      [
        build-backend-image,
        build-model-server-image,
        build-integration-image,
      ]
    strategy:
      fail-fast: false
      matrix:
        include:
          - provider: openai
            models: ${{ inputs.openai_models }}
            api_key_env: OPENAI_API_KEY
            custom_config_env: ""
            api_base: ""
            api_version: ""
            deployment_name: ""
            required: true
          - provider: anthropic
            models: ${{ inputs.anthropic_models }}
            api_key_env: ANTHROPIC_API_KEY
            custom_config_env: ""
            api_base: ""
            api_version: ""
            deployment_name: ""
            required: true
          - provider: bedrock
            models: ${{ inputs.bedrock_models }}
            api_key_env: BEDROCK_API_KEY
            custom_config_env: ""
            api_base: ""
            api_version: ""
            deployment_name: ""
            required: false
          - provider: vertex_ai
            models: ${{ inputs.vertex_ai_models }}
            api_key_env: ""
            custom_config_env: NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON
            api_base: ""
            api_version: ""
            deployment_name: ""
            required: false
          - provider: azure
            models: ${{ inputs.azure_models }}
            api_key_env: AZURE_API_KEY
            custom_config_env: ""
            api_base: ${{ inputs.azure_api_base }}
            api_version: "2025-04-01-preview"
            deployment_name: ""
            required: false
          - provider: ollama_chat
            models: ${{ inputs.ollama_models }}
            api_key_env: OLLAMA_API_KEY
            custom_config_env: ""
            api_base: "https://ollama.com"
            api_version: ""
            deployment_name: ""
            required: false
          - provider: openrouter
            models: ${{ inputs.openrouter_models }}
            api_key_env: OPENROUTER_API_KEY
            custom_config_env: ""
            api_base: "https://openrouter.ai/api/v1"
            api_version: ""
            deployment_name: ""
            required: false
    runs-on:
      - runs-on
      - runner=4cpu-linux-arm64
      - "run-id=${{ github.run_id }}-nightly-${{ matrix.provider }}-provider-chat-test"
      - extras=ecr-cache
    timeout-minutes: 45
    environment: ci-protected
    steps:
      - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
      - name: Checkout code
        uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
        with:
          persist-credentials: false
      - name: Configure AWS credentials
        uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7
        with:
          role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
          aws-region: us-east-2
      - name: Get AWS Secrets
        uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802
        with:
          # Keep JSON values unparsed so vertex custom config is passed as raw JSON.
parse-json-secrets: false secret-ids: | DOCKER_USERNAME, test/docker-username DOCKER_TOKEN, test/docker-token OPENAI_API_KEY, test/openai-api-key ANTHROPIC_API_KEY, test/anthropic-api-key BEDROCK_API_KEY, test/bedrock-api-key NIGHTLY_LLM_VERTEX_AI_CUSTOM_CONFIG_JSON, test/nightly-llm-vertex-ai-custom-config-json AZURE_API_KEY, test/azure-api-key OLLAMA_API_KEY, test/ollama-api-key OPENROUTER_API_KEY, test/openrouter-api-key - name: Run nightly provider chat test uses: ./.github/actions/run-nightly-provider-chat-test with: provider: ${{ matrix.provider }} models: ${{ matrix.models }} provider-api-key: ${{ matrix.api_key_env && env[matrix.api_key_env] || '' }} strict: ${{ inputs.strict && 'true' || 'false' }} api-base: ${{ matrix.api_base }} api-version: ${{ matrix.api_version }} deployment-name: ${{ matrix.deployment_name }} custom-config-json: ${{ matrix.custom_config_env && env[matrix.custom_config_env] || '' }} runs-on-ecr-cache: ${{ env.RUNS_ON_ECR_CACHE }} run-id: ${{ github.run_id }} docker-username: ${{ env.DOCKER_USERNAME }} docker-token: ${{ env.DOCKER_TOKEN }} - name: Dump API server logs if: always() run: | cd deployment/docker_compose docker compose logs --no-color api_server > $GITHUB_WORKSPACE/api_server.log || true - name: Dump all-container logs if: always() run: | cd deployment/docker_compose docker compose logs --no-color > $GITHUB_WORKSPACE/docker-compose.log || true - name: Upload logs if: always() uses: actions/upload-artifact@bbbca2ddaa5d8feaa63e36b76fdaad77386f024f with: name: docker-all-logs-nightly-${{ matrix.provider }}-llm-provider path: | ${{ github.workspace }}/api_server.log ${{ github.workspace }}/docker-compose.log - name: Stop Docker containers if: always() run: | cd deployment/docker_compose docker compose down -v ================================================ FILE: .github/workflows/sandbox-deployment.yml ================================================ name: Build and Push Sandbox Image on Tag on: push: tags: - 
"experimental-cc4a.*" # Restrictive defaults; jobs declare what they need. permissions: {} jobs: check-sandbox-changes: runs-on: ubuntu-slim timeout-minutes: 10 permissions: contents: read outputs: sandbox-changed: ${{ steps.check.outputs.sandbox-changed }} new-version: ${{ steps.version.outputs.new-version }} steps: - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false fetch-depth: 0 - name: Check for sandbox-relevant file changes id: check run: | # Get the previous tag to diff against CURRENT_TAG="${GITHUB_REF_NAME}" PREVIOUS_TAG=$(git tag --sort=-creatordate | grep '^experimental-cc4a\.' | grep -v "^${CURRENT_TAG}$" | head -n 1) if [ -z "$PREVIOUS_TAG" ]; then echo "No previous experimental-cc4a tag found, building unconditionally" echo "sandbox-changed=true" >> "$GITHUB_OUTPUT" exit 0 fi echo "Comparing ${PREVIOUS_TAG}..${CURRENT_TAG}" # Check if any sandbox-relevant files changed SANDBOX_PATHS=( "backend/onyx/server/features/build/sandbox/" ) CHANGED=false for path in "${SANDBOX_PATHS[@]}"; do if git diff --name-only "${PREVIOUS_TAG}..${CURRENT_TAG}" -- "$path" | grep -q .; then echo "Changes detected in: $path" CHANGED=true break fi done echo "sandbox-changed=$CHANGED" >> "$GITHUB_OUTPUT" - name: Determine new sandbox version id: version if: steps.check.outputs.sandbox-changed == 'true' run: | # Query Docker Hub for the latest versioned tag LATEST_TAG=$(curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags?page_size=100" \ | jq -r '.results[].name' \ | grep -E '^v[0-9]+\.[0-9]+\.[0-9]+$' \ | sort -V \ | tail -n 1) if [ -z "$LATEST_TAG" ]; then echo "No existing version tags found on Docker Hub, starting at 0.1.1" NEW_VERSION="0.1.1" else CURRENT_VERSION="${LATEST_TAG#v}" echo "Latest version on Docker Hub: $CURRENT_VERSION" # Increment patch version MAJOR=$(echo "$CURRENT_VERSION" | cut -d. -f1) MINOR=$(echo "$CURRENT_VERSION" | cut -d. 
-f2) PATCH=$(echo "$CURRENT_VERSION" | cut -d. -f3) NEW_PATCH=$((PATCH + 1)) NEW_VERSION="${MAJOR}.${MINOR}.${NEW_PATCH}" fi echo "New version: $NEW_VERSION" echo "new-version=$NEW_VERSION" >> "$GITHUB_OUTPUT" build-sandbox-amd64: needs: check-sandbox-changes if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true' runs-on: - runs-on - runner=4cpu-linux-x64 - run-id=${{ github.run_id }}-sandbox-amd64 - extras=ecr-cache timeout-minutes: 90 environment: release permissions: contents: read id-token: write outputs: digest: ${{ steps.build.outputs.digest }} env: REGISTRY_IMAGE: onyxdotapp/sandbox steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ env.REGISTRY_IMAGE }} flavor: | latest=false - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ env.DOCKER_USERNAME }} password: ${{ env.DOCKER_TOKEN }} - name: Build and push AMD64 id: build uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: 
./backend/onyx/server/features/build/sandbox/kubernetes/docker file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile platforms: linux/amd64 labels: ${{ steps.meta.outputs.labels }} cache-from: | type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true build-sandbox-arm64: needs: check-sandbox-changes if: needs.check-sandbox-changes.outputs.sandbox-changed == 'true' runs-on: - runs-on - runner=4cpu-linux-arm64 - run-id=${{ github.run_id }}-sandbox-arm64 - extras=ecr-cache timeout-minutes: 90 environment: release permissions: contents: read id-token: write outputs: digest: ${{ steps.build.outputs.digest }} env: REGISTRY_IMAGE: onyxdotapp/sandbox steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Checkout uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: persist-credentials: false - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ env.REGISTRY_IMAGE }} flavor: | latest=false - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ env.DOCKER_USERNAME }} password: ${{ 
env.DOCKER_TOKEN }} - name: Build and push ARM64 id: build uses: docker/build-push-action@263435318d21b8e681c14492fe198d362a7d2c83 # ratchet:docker/build-push-action@v6 with: context: ./backend/onyx/server/features/build/sandbox/kubernetes/docker file: ./backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile platforms: linux/arm64 labels: ${{ steps.meta.outputs.labels }} cache-from: | type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest cache-to: | type=inline outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true merge-sandbox: needs: - check-sandbox-changes - build-sandbox-amd64 - build-sandbox-arm64 runs-on: - runs-on - runner=2cpu-linux-x64 - run-id=${{ github.run_id }}-merge-sandbox - extras=ecr-cache timeout-minutes: 30 environment: release permissions: id-token: write env: REGISTRY_IMAGE: onyxdotapp/sandbox steps: - uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2 - name: Configure AWS credentials uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 with: role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }} aws-region: us-east-2 - name: Get AWS Secrets uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 with: secret-ids: | DOCKER_USERNAME, deploy/docker-username DOCKER_TOKEN, deploy/docker-token parse-json-secrets: true - name: Set up Docker Buildx uses: docker/setup-buildx-action@8d2750c68a42422c14e847fe6c8ac0403b4cbd6f # ratchet:docker/setup-buildx-action@v3 - name: Login to Docker Hub uses: docker/login-action@c94ce9fb468520275223c153574b00df6fe4bcc9 # ratchet:docker/login-action@v3 with: username: ${{ env.DOCKER_USERNAME }} password: ${{ env.DOCKER_TOKEN }} - name: Docker meta id: meta uses: docker/metadata-action@030e881283bb7a6894de51c315a6bfe6a94e05cf # ratchet:docker/metadata-action@v6.0.0 with: images: ${{ env.REGISTRY_IMAGE }} flavor: | latest=false tags: | type=raw,value=v${{ 
needs.check-sandbox-changes.outputs.new-version }} type=raw,value=latest - name: Create and push manifest env: IMAGE_REPO: ${{ env.REGISTRY_IMAGE }} AMD64_DIGEST: ${{ needs.build-sandbox-amd64.outputs.digest }} ARM64_DIGEST: ${{ needs.build-sandbox-arm64.outputs.digest }} META_TAGS: ${{ steps.meta.outputs.tags }} run: | IMAGES="${IMAGE_REPO}@${AMD64_DIGEST} ${IMAGE_REPO}@${ARM64_DIGEST}" docker buildx imagetools create \ $(printf '%s\n' "${META_TAGS}" | xargs -I {} echo -t {}) \ $IMAGES ================================================ FILE: .github/workflows/storybook-deploy.yml ================================================ name: Storybook Deploy env: VERCEL_ORG_ID: ${{ secrets.VERCEL_ORG_ID }} VERCEL_PROJECT_ID: prj_sG49mVsA25UsxIPhN2pmBJlikJZM VERCEL_CLI: vercel@50.14.1 VERCEL_TOKEN: ${{ secrets.VERCEL_TOKEN }} concurrency: group: storybook-deploy-production cancel-in-progress: true on: workflow_dispatch: push: branches: - main paths: - "web/lib/opal/**" - "web/src/refresh-components/**" - "web/.storybook/**" - "web/package.json" - "web/package-lock.json" permissions: contents: read jobs: Deploy-Storybook: runs-on: ubuntu-latest environment: ci-protected timeout-minutes: 30 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4 with: persist-credentials: false - name: Setup node uses: actions/setup-node@53b83947a5a98c8d113130e565377fae1a50d02f # ratchet:actions/setup-node@v4 with: node-version: 22 cache: "npm" cache-dependency-path: ./web/package-lock.json - name: Install dependencies working-directory: web run: npm ci - name: Build Storybook working-directory: web run: npm run storybook:build - name: Deploy to Vercel (Production) working-directory: web run: npx --yes "$VERCEL_CLI" deploy storybook-static/ --prod --yes --token="$VERCEL_TOKEN" notify-slack-on-failure: needs: Deploy-Storybook if: always() && needs.Deploy-Storybook.result == 'failure' runs-on: ubuntu-latest environment: ci-protected timeout-minutes: 
10 steps: - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v4 with: persist-credentials: false sparse-checkout: .github/actions/slack-notify - name: Send Slack notification uses: ./.github/actions/slack-notify with: webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }} failed-jobs: "• Deploy-Storybook" title: "🚨 Storybook Deploy Failed" ================================================ FILE: .github/workflows/sync_foss.yml ================================================ name: Sync FOSS Repo on: schedule: # Run daily at 3am PT (11am UTC during PST) - cron: '0 11 * * *' workflow_dispatch: jobs: sync-foss: runs-on: ubuntu-latest environment: ci-protected timeout-minutes: 45 permissions: contents: read steps: - name: Checkout main Onyx repo uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: fetch-depth: 0 persist-credentials: false - name: Install git-filter-repo run: | sudo apt-get update && sudo apt-get install -y git-filter-repo - name: Configure SSH for deploy key env: FOSS_REPO_DEPLOY_KEY: ${{ secrets.FOSS_REPO_DEPLOY_KEY }} run: | mkdir -p ~/.ssh echo "$FOSS_REPO_DEPLOY_KEY" > ~/.ssh/id_ed25519 chmod 600 ~/.ssh/id_ed25519 ssh-keyscan github.com >> ~/.ssh/known_hosts - name: Set Git config run: | git config --global user.name "onyx-bot" git config --global user.email "bot@onyx.app" - name: Build FOSS version run: bash backend/scripts/make_foss_repo.sh - name: Push to FOSS repo env: FOSS_REPO_URL: git@github.com:onyx-dot-app/onyx-foss.git run: | cd /tmp/foss_repo git remote add public "$FOSS_REPO_URL" git push --force public main ================================================ FILE: .github/workflows/tag-nightly.yml ================================================ name: Nightly Tag Push on: schedule: - cron: "0 10 * * *" # Runs every day at 2 AM PST / 3 AM PDT / 10 AM UTC workflow_dispatch: permissions: contents: write # Allows pushing tags to the repository jobs: 
create-and-push-tag: runs-on: ubuntu-slim environment: ci-protected timeout-minutes: 45 steps: # actions using GITHUB_TOKEN cannot trigger another workflow, but we do want this to trigger docker pushes # see https://github.com/orgs/community/discussions/27028#discussioncomment-3254367 for the workaround we # implement here which needs an actual user's deploy key - name: Checkout code uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6 with: ssh-key: "${{ secrets.DEPLOY_KEY }}" persist-credentials: true - name: Set up Git user run: | git config user.name "Onyx Bot [bot]" git config user.email "onyx-bot[bot]@onyx.app" - name: Check for existing nightly tag id: check_tag run: | if git tag --points-at HEAD --list "nightly-latest*" | grep -q .; then echo "A tag starting with 'nightly-latest' already exists on HEAD." echo "tag_exists=true" >> "$GITHUB_OUTPUT" else echo "No tag starting with 'nightly-latest' exists on HEAD." echo "tag_exists=false" >> "$GITHUB_OUTPUT" fi # don't tag again if HEAD already has a nightly-latest tag on it - name: Create Nightly Tag if: steps.check_tag.outputs.tag_exists == 'false' run: | TAG_NAME="nightly-latest-$(date +'%Y%m%d')" echo "Creating tag: $TAG_NAME" git tag "$TAG_NAME" - name: Push Tag if: steps.check_tag.outputs.tag_exists == 'false' run: | TAG_NAME="nightly-latest-$(date +'%Y%m%d')" git push origin "$TAG_NAME" - name: Send Slack notification if: failure() uses: ./.github/actions/slack-notify with: webhook-url: ${{ secrets.MONITOR_DEPLOYMENTS_WEBHOOK }} title: "🚨 Nightly Tag Push Failed" ref-name: ${{ github.ref_name }} failed-jobs: "create-and-push-tag" ================================================ FILE: .github/workflows/zizmor.yml ================================================ name: Run Zizmor on: push: branches: ["main"] pull_request: branches: ["**"] paths: - ".github/**" permissions: {} jobs: zizmor: name: zizmor runs-on: ubuntu-slim timeout-minutes: 45
permissions: security-events: write # needed for SARIF uploads steps: - name: Checkout repository uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6.0.2 with: persist-credentials: false - name: Install pinned version of uv uses: astral-sh/setup-uv@37802adc94f370d6bfd71619e3f0bf239e1f3b78 # ratchet:astral-sh/setup-uv@v7 with: enable-cache: false version: "0.9.9" - name: Run zizmor run: uv run --no-sync --with zizmor zizmor --format=sarif . > results.sarif env: GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} - name: Upload SARIF file uses: github/codeql-action/upload-sarif@ba454b8ab46733eb6145342877cd148270bb77ab # ratchet:github/codeql-action/upload-sarif@codeql-bundle-v2.23.5 with: sarif_file: results.sarif category: zizmor ================================================ FILE: .gitignore ================================================ # editors .vscode/* !/.vscode/env_template.txt !/.vscode/env.web_template.txt !/.vscode/launch.json !/.vscode/tasks.template.jsonc .zed .cursor !/.cursor/mcp.json !/.cursor/skills/ # macos .DS_Store # python .venv .mypy_cache .idea # testing /web/test-results/ backend/onyx/agent_search/main/test_data.json backend/tests/regression/answer_quality/test_data.json backend/tests/regression/search_quality/eval-* backend/tests/regression/search_quality/search_eval_config.yaml backend/tests/regression/search_quality/*.json backend/onyx/evals/data/ backend/onyx/evals/one_off/*.json *.log *.csv # secret files .env jira_test_env settings.json # others /deployment/data/nginx/app.conf /deployment/data/nginx/mcp.conf.inc /deployment/data/nginx/mcp_upstream.conf.inc *.sw?
/backend/tests/regression/answer_quality/search_test_config.yaml *.egg-info # Local .terraform directories **/.terraform/* # Local .tfstate files *.tfstate *.tfstate.* # Local .terraform.lock.hcl file .terraform.lock.hcl node_modules # MCP configs .playwright-mcp # plans plans/ ================================================ FILE: .greptile/config.json ================================================ { "labels": [], "comment": "", "fixWithAI": true, "hideFooter": false, "strictness": 3, "statusCheck": true, "commentTypes": [ "logic", "syntax", "style" ], "instructions": "", "disabledLabels": [], "excludeAuthors": [ "dependabot[bot]", "renovate[bot]" ], "ignoreKeywords": "", "ignorePatterns": "", "includeAuthors": [], "summarySection": { "included": true, "collapsible": false, "defaultOpen": false }, "excludeBranches": [], "fileChangeLimit": 300, "includeBranches": [], "includeKeywords": "", "triggerOnUpdates": true, "updateExistingSummaryComment": true, "updateSummaryOnly": false, "issuesTableSection": { "included": true, "collapsible": false, "defaultOpen": false }, "statusCommentsEnabled": true, "confidenceScoreSection": { "included": true, "collapsible": false }, "sequenceDiagramSection": { "included": true, "collapsible": false, "defaultOpen": false }, "shouldUpdateDescription": false, "rules": [ { "scope": ["web/**"], "rule": "In Onyx's Next.js app, the `app/ee/admin/` directory is a filesystem convention for Enterprise Edition route overrides — it does NOT add an `/ee/` prefix to the URL. Both `app/admin/groups/page.tsx` and `app/ee/admin/groups/page.tsx` serve the same URL `/admin/groups`. Hardcoded `/admin/...` paths in router.push() calls are correct and do NOT break EE deployments. Do not flag hardcoded admin paths as bugs." }, { "scope": ["web/**"], "rule": "In Onyx, each API key creates a unique user row in the database with a unique `user_id` (UUID). There is a 1:1 mapping between API keys and their backing user records. 
Multiple API keys do NOT share the same `user_id`. Do not flag potential duplicate row IDs when using `user_id` from API key descriptors." }, { "scope": ["backend/**/*.py"], "rule": "Never raise HTTPException directly in business code. Use `raise OnyxError(OnyxErrorCode.XXX, \"message\")` from `onyx.error_handling.exceptions`. A global FastAPI exception handler converts OnyxError into structured JSON responses with {\"error_code\": \"...\", \"detail\": \"...\"}. Error codes are defined in `onyx.error_handling.error_codes.OnyxErrorCode`. For upstream errors with dynamic HTTP status codes, use `status_code_override`: `raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status)`." } ] } ================================================ FILE: .greptile/files.json ================================================ [ { "scope": [], "path": "contributing_guides/best_practices.md", "description": "Best practices for contributing to the codebase" }, { "scope": ["web/**"], "path": "web/AGENTS.md", "description": "Frontend coding standards for the web directory" }, { "scope": ["web/**"], "path": "web/tests/README.md", "description": "Frontend testing guide and conventions" }, { "scope": ["web/**"], "path": "web/CLAUDE.md", "description": "Single source of truth for frontend coding standards" }, { "scope": ["web/**"], "path": "web/lib/opal/README.md", "description": "Opal component library usage guide" }, { "scope": ["backend/**"], "path": "backend/tests/README.md", "description": "Backend testing guide covering all 4 test types, fixtures, and conventions" }, { "scope": ["backend/onyx/connectors/**"], "path": "backend/onyx/connectors/README.md", "description": "Connector development guide covering design, interfaces, and required changes" }, { "scope": [], "path": "CLAUDE.md", "description": "Project instructions and coding standards" }, { "scope": [], "path": "backend/alembic/README.md", "description": "Migration guidance, including multi-tenant 
migration behavior" }, { "scope": [], "path": "deployment/helm/charts/onyx/values-lite.yaml", "description": "Lite deployment Helm values and service assumptions" }, { "scope": [], "path": "deployment/docker_compose/docker-compose.onyx-lite.yml", "description": "Lite deployment Docker Compose overlay and disabled service behavior" } ] ================================================ FILE: .greptile/rules.md ================================================ # Greptile Review Rules ## Type Annotations Use explicit type annotations for variables to enhance code clarity, especially when moving type hints around in the code. ## Best Practices Use the "Engineering Best Practices" section of `CONTRIBUTING.md` as core review context. Prefer consistency with existing patterns, fix issues in code you touch, avoid tacking new features onto muddy interfaces, fail loudly instead of silently swallowing errors, keep code strictly typed, preserve clear state boundaries, remove duplicate or dead logic, break up overly long functions, avoid hidden import-time side effects, respect module boundaries, and favor correctness-by-construction over relying on callers to use an API correctly. ## TODOs Whenever a TODO is added, there must always be an associated name or ticket with that TODO in the style of `TODO(name): ...` or `TODO(1234): ...` ## Debugging Code Remove temporary debugging code before merging to production, especially tenant-specific debugging logs. ## Hardcoded Booleans When hardcoding a boolean variable to a constant value, remove the variable entirely and clean up all places where it's used rather than just setting it to a constant. ## Multi-tenant vs Single-tenant Code changes must consider both multi-tenant and single-tenant deployments. In multi-tenant mode, preserve tenant isolation, ensure tenant context is propagated correctly, and avoid assumptions that only hold for a single shared schema or globally shared state. 
In single-tenant mode, avoid introducing unnecessary tenant-specific requirements or cloud-only control-plane dependencies. ## Nginx Routing — New Backend Routes Whenever a new backend route is added that does NOT start with `/api`, it must also be explicitly added to ALL nginx configs: - `deployment/helm/charts/onyx/templates/nginx-conf.yaml` (Helm/k8s) - `deployment/data/nginx/app.conf.template` (docker-compose dev) - `deployment/data/nginx/app.conf.template.prod` (docker-compose prod) - `deployment/data/nginx/app.conf.template.no-letsencrypt` (docker-compose no-letsencrypt) Routes not starting with `/api` are not caught by the existing `^/(api|openapi\.json)` location block and will fall through to `location /`, which proxies to the Next.js web server and returns an HTML 404. The new location block must be placed before the `/api` block. Examples of routes that need this treatment: `/scim`, `/mcp`. ## Full vs Lite Deployments Code changes must consider both regular Onyx deployments and Onyx lite deployments. Lite deployments disable the vector DB, Redis, model servers, and background workers by default, use PostgreSQL-backed cache/auth/file storage, and rely on the API server to handle background work. Do not assume those services are available unless the code path is explicitly limited to full deployments. ## SWR Cache Keys — Always Use SWR_KEYS Registry All `useSWR()` calls and `mutate()` calls in the frontend must reference the centralized `SWR_KEYS` registry in `web/src/lib/swr-keys.ts` instead of inline endpoint strings or local string constants. Never write `useSWR("/api/some/endpoint", ...)` or `mutate("/api/some/endpoint")` — always use the corresponding `SWR_KEYS.someEndpoint` constant. If the endpoint does not yet exist in the registry, add it there first. This applies to all variants of an endpoint (e.g. query-string variants like `?get_editable=true` must also be registered as their own key). 
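A minimal sketch of the registry pattern this rule describes, for illustration only. The key names (`someEndpoint`, `someEndpointEditable`) and endpoint strings are hypothetical placeholders; the real `web/src/lib/swr-keys.ts` may be organized differently.

```typescript
// Hypothetical sketch of a centralized SWR key registry. The actual
// web/src/lib/swr-keys.ts may use different names and structure.
export const SWR_KEYS = {
  // Placeholder endpoint, used here for illustration only.
  someEndpoint: "/api/some/endpoint",
  // Query-string variants are registered as their own keys.
  someEndpointEditable: "/api/some/endpoint?get_editable=true",
} as const;

// Union of every registered key string, useful for typing helpers.
export type SwrKey = (typeof SWR_KEYS)[keyof typeof SWR_KEYS];

// Correct usage in a component (requires the `swr` package):
//   const { data } = useSWR(SWR_KEYS.someEndpoint, fetcher);
//   await mutate(SWR_KEYS.someEndpointEditable);
//
// Avoid inline strings such as:
//   useSWR("/api/some/endpoint", fetcher);
```

Because `useSWR()` and `mutate()` read the same constant, a revalidation call cannot silently drift from the key the data was originally fetched under.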
================================================ FILE: .pre-commit-config.yaml ================================================ default_install_hook_types: - pre-commit - post-checkout - post-merge - post-rewrite repos: - repo: https://github.com/astral-sh/uv-pre-commit # From: https://github.com/astral-sh/uv-pre-commit/pull/53/commits/d30b4298e4fb63ce8609e29acdbcf4c9018a483c rev: d30b4298e4fb63ce8609e29acdbcf4c9018a483c hooks: - id: uv-sync args: ["--locked", "--all-extras"] - id: uv-lock - id: uv-export name: uv-export default.txt args: [ "--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "backend", "-o", "backend/requirements/default.txt", ] files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$ - id: uv-export name: uv-export dev.txt args: [ "--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "dev", "-o", "backend/requirements/dev.txt", ] files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$ - id: uv-export name: uv-export ee.txt args: [ "--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "ee", "-o", "backend/requirements/ee.txt", ] files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$ - id: uv-export name: uv-export model_server.txt args: [ "--no-emit-project", "--no-default-groups", "--no-hashes", "--extra", "model_server", "-o", "backend/requirements/model_server.txt", ] files: ^(pyproject\.toml|uv\.lock|backend/requirements/.*\.txt)$ - id: uv-run name: Check lazy imports args: ["--active", "--with=onyx-devtools", "ods", "check-lazy-imports"] pass_filenames: true files: ^backend/(?!\.venv/|scripts/).*\.py$ # NOTE: This takes ~6s on a single, large module which is prohibitively slow. 
# - id: uv-run # name: mypy # args: ["--all-extras", "mypy"] # pass_filenames: true # files: ^backend/.*\.py$ - repo: https://github.com/pre-commit/pre-commit-hooks rev: 3e8a8703264a2f4a69428a0aa4dcb512790b2c8c # frozen: v6.0.0 hooks: - id: check-added-large-files name: Check for added large files args: ["--maxkb=1500"] - repo: https://github.com/rhysd/actionlint rev: a443f344ff32813837fa49f7aa6cbc478d770e62 # frozen: v1.7.9 hooks: - id: actionlint - repo: https://github.com/psf/black rev: 8a737e727ac5ab2f1d4cf5876720ed276dc8dc4b # frozen: 25.1.0 hooks: - id: black language_version: python3.11 # this is a fork which keeps compatibility with black - repo: https://github.com/wimglenn/reorder-python-imports-black rev: f55cd27f90f0cf0ee775002c2383ce1c7820013d # frozen: v3.14.0 hooks: - id: reorder-python-imports args: ["--py311-plus", "--application-directories=backend/"] # need to ignore alembic files, since reorder-python-imports gets confused # and thinks that alembic is a local package since there is a folder # in the backend directory called `alembic` exclude: ^backend/alembic/ # These settings will remove unused imports with side effects # Note: The repo currently does not and should not have imports with side effects - repo: https://github.com/PyCQA/autoflake rev: 0544741e2b4a22b472d9d93e37d4ea9153820bb1 # frozen: v2.3.1 hooks: - id: autoflake args: [ "--remove-all-unused-imports", "--remove-unused-variables", "--in-place", "--recursive", ] - repo: https://github.com/golangci/golangci-lint rev: 5d1e709b7be35cb2025444e19de266b056b7b7ee # frozen: v2.10.1 hooks: - id: golangci-lint language_version: "1.26.1" entry: bash -c "find . -name go.mod -not -path './.venv/*' -print0 | xargs -0 -I{} bash -c 'cd \"$(dirname {})\" && golangci-lint run ./...'" - repo: https://github.com/astral-sh/ruff-pre-commit # Ruff version. 
rev: 971923581912ef60a6b70dbf0c3e9a39563c9d47 # frozen: v0.11.4 hooks: - id: ruff - repo: https://github.com/pre-commit/mirrors-prettier rev: ffb6a759a979008c0e6dff86e39f4745a2d9eac4 # frozen: v3.1.0 hooks: - id: prettier types_or: [html, css, javascript, ts, tsx] language_version: system - repo: https://github.com/sirwart/ripsecrets rev: 7d94620933e79b8acaa0cd9e60e9864b07673d86 # frozen: v0.1.11 hooks: - id: ripsecrets args: - --additional-pattern - ^sk-[A-Za-z0-9_\-]{20,}$ - repo: local hooks: - id: terraform-fmt name: terraform fmt entry: terraform fmt -recursive language: system pass_filenames: false files: \.tf$ - id: npm-install name: npm install description: "Automatically run 'npm install' after a checkout, pull or rebase" language: system entry: bash -c 'cd web && npm install --no-save' pass_filenames: false files: ^web/package(-lock)?\.json$ stages: [post-checkout, post-merge, post-rewrite] - id: npm-install-check name: npm install --package-lock-only description: "Check the 'web/package-lock.json' is updated" language: system entry: bash -c 'cd web && npm install --package-lock-only' pass_filenames: false files: ^web/package(-lock)?\.json$ # Uses tsgo (TypeScript's native Go compiler) for ~10x faster type checking. # This is a preview package - if it breaks: # 1. Try updating: cd web && npm update @typescript/native-preview # 2. Or fallback to tsc: replace 'tsgo' with 'tsc' below - id: typescript-check name: TypeScript type check entry: bash -c 'cd web && npx tsgo --noEmit --project tsconfig.types.json' language: system pass_filenames: false files: ^web/.*\.(ts|tsx)$ ================================================ FILE: .prettierignore ================================================ backend/tests/integration/tests/pruning/website ================================================ FILE: .vscode/env.web_template.txt ================================================ # Copy this file to .env.web in the .vscode folder. 
# Fill in the values as needed # Web Server specific environment variables # Minimal set needed for Next.js dev server # Auth AUTH_TYPE=basic DEV_MODE=true # Enable the full set of Danswer Enterprise Edition features. # NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you # are using this for local testing/development). ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=false # Enable Onyx Craft ENABLE_CRAFT=true ================================================ FILE: .vscode/env_template.txt ================================================ # Copy this file to .env in the .vscode folder. # Fill in the values as needed; it is recommended to set the # GEN_AI_API_KEY value to avoid having to set up an LLM in the UI. # Also check out onyx/backend/scripts/restart_containers.sh for a script to # restart the containers which Onyx relies on outside of VSCode/Cursor # processes. AUTH_TYPE=basic # Recommended for basic auth - used for signing password reset and verification tokens # Generate a secure value with: openssl rand -hex 32 USER_AUTH_SECRET="" DEV_MODE=true # Always keep these on for Dev. # Logs model prompts, reasoning, and answers to stdout. LOG_ONYX_MODEL_INTERACTIONS=False # More verbose logging LOG_LEVEL=debug # Useful if you want to toggle auth on/off (google_oauth/OIDC specifically). OAUTH_CLIENT_ID= OAUTH_CLIENT_SECRET= OPENID_CONFIG_URL= SAML_CONF_DIR=//onyx/backend/ee/onyx/configs/saml_config # Generally not useful for dev; we usually don't want to set up an SMTP server # for dev. REQUIRE_EMAIL_VERIFICATION=False # Set these so that if you wipe the DB, you don't have to go through the UI # setup every time. GEN_AI_API_KEY= OPENAI_API_KEY= # If answer quality isn't important for dev, use gpt-4o-mini since it's cheaper. GEN_AI_MODEL_VERSION=gpt-4o # Python stuff PYTHONPATH=../backend PYTHONUNBUFFERED=1 # Enable the full set of Danswer Enterprise Edition features.
# NOTE: DO NOT ENABLE THIS UNLESS YOU HAVE A PAID ENTERPRISE LICENSE (or if you # are using this for local testing/development). ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=False # S3 File Store Configuration (MinIO for local development) S3_ENDPOINT_URL=http://localhost:9004 S3_FILE_STORE_BUCKET_NAME=onyx-file-store-bucket S3_AWS_ACCESS_KEY_ID=minioadmin S3_AWS_SECRET_ACCESS_KEY=minioadmin # Show extra/uncommon connectors. SHOW_EXTRA_CONNECTORS=True # Local langsmith tracing LANGSMITH_TRACING="true" LANGSMITH_ENDPOINT="https://api.smith.langchain.com" LANGSMITH_API_KEY= LANGSMITH_PROJECT= # Local Confluence OAuth testing # OAUTH_CONFLUENCE_CLOUD_CLIENT_ID= # OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET= # NEXT_PUBLIC_TEST_ENV=True # OpenSearch # Arbitrary password is fine for local development. OPENSEARCH_INITIAL_ADMIN_PASSWORD= ================================================ FILE: .vscode/launch.json ================================================ { // Use IntelliSense to learn about possible attributes. // Hover to view descriptions of existing attributes. 
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 "version": "0.2.0", "compounds": [ { // Dummy entry used to label the group "name": "--- Compound ---", "configurations": ["--- Individual ---"], "presentation": { "group": "1" } }, { "name": "Run All Onyx Services", "configurations": [ "Web Server", "Model Server", "API Server", "MCP Server", "Slack Bot", "Celery primary", "Celery light", "Celery heavy", "Celery docfetching", "Celery docprocessing", "Celery user_file_processing", "Celery beat" ], "presentation": { "group": "1" } }, { "name": "Web / Model / API", "configurations": ["Web Server", "Model Server", "API Server"], "presentation": { "group": "1" } }, { "name": "Celery", "configurations": [ "Celery primary", "Celery light", "Celery heavy", "Celery kg_processing", "Celery monitoring", "Celery user_file_processing", "Celery docfetching", "Celery docprocessing", "Celery beat" ], "presentation": { "group": "1" }, "stopAll": true } ], "configurations": [ { // Dummy entry used to label the group "name": "--- Individual ---", "type": "node", "request": "launch", "presentation": { "group": "2", "order": 0 } }, { "name": "Web Server", "type": "node", "request": "launch", "cwd": "${workspaceRoot}/web", "runtimeExecutable": "npm", "envFile": "${workspaceFolder}/.vscode/.env.web", "runtimeArgs": ["run", "dev"], "presentation": { "group": "2" }, "console": "integratedTerminal", "consoleTitle": "Web Server Console" }, { "name": "Model Server", "consoleName": "Model Server", "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" }, "args": ["model_server.main:app", "--reload", "--port", "9000"], "presentation": { "group": "2" }, "consoleTitle": "Model Server Console" }, { "name": "API Server", "consoleName": "API Server", "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": 
"${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" }, "args": ["onyx.main:app", "--reload", "--port", "8080"], "presentation": { "group": "2" }, "consoleTitle": "API Server Console", "justMyCode": false }, { "name": "Slack Bot", "consoleName": "Slack Bot", "type": "debugpy", "request": "launch", "program": "onyx/onyxbot/slack/listener.py", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "presentation": { "group": "2" }, "consoleTitle": "Slack Bot Console" }, { "name": "Discord Bot", "consoleName": "Discord Bot", "type": "debugpy", "request": "launch", "program": "onyx/onyxbot/discord/client.py", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "presentation": { "group": "2" }, "consoleTitle": "Discord Bot Console" }, { "name": "MCP Server", "consoleName": "MCP Server", "type": "debugpy", "request": "launch", "module": "uvicorn", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "MCP_SERVER_ENABLED": "true", "MCP_SERVER_PORT": "8090", "MCP_SERVER_CORS_ORIGINS": "http://localhost:*", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1" }, "args": [ "onyx.mcp_server.api:mcp_app", "--reload", "--port", "8090", "--timeout-graceful-shutdown", "0" ], "presentation": { "group": "2" }, "consoleTitle": "MCP Server Console" }, { "name": "Celery primary", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
}, "args": [ "-A", "onyx.background.celery.versioned_apps.primary", "worker", "--pool=threads", "--concurrency=4", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=primary@%n", "-Q", "celery" ], "presentation": { "group": "2" }, "consoleTitle": "Celery primary Console" }, { "name": "Celery light", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ "-A", "onyx.background.celery.versioned_apps.light", "worker", "--pool=threads", "--concurrency=64", "--prefetch-multiplier=8", "--loglevel=INFO", "--hostname=light@%n", "-Q", "vespa_metadata_sync,connector_deletion,doc_permissions_upsert,index_attempt_cleanup,opensearch_migration" ], "presentation": { "group": "2" }, "consoleTitle": "Celery light Console" }, { "name": "Celery heavy", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ "-A", "onyx.background.celery.versioned_apps.heavy", "worker", "--pool=threads", "--concurrency=4", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=heavy@%n", "-Q", "connector_pruning,connector_doc_permissions_sync,connector_external_group_sync,csv_generation" ], "presentation": { "group": "2" }, "consoleTitle": "Celery heavy Console", "justMyCode": false }, { "name": "Celery kg_processing", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
}, "args": [ "-A", "onyx.background.celery.versioned_apps.kg_processing", "worker", "--pool=threads", "--concurrency=2", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=kg_processing@%n", "-Q", "kg_processing" ], "presentation": { "group": "2" }, "consoleTitle": "Celery kg_processing Console" }, { "name": "Celery monitoring", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ "-A", "onyx.background.celery.versioned_apps.monitoring", "worker", "--pool=threads", "--concurrency=1", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=monitoring@%n", "-Q", "monitoring" ], "presentation": { "group": "2" }, "consoleTitle": "Celery monitoring Console" }, { "name": "Celery user_file_processing", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ "-A", "onyx.background.celery.versioned_apps.user_file_processing", "worker", "--pool=threads", "--concurrency=2", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=user_file_processing@%n", "-Q", "user_file_processing,user_file_project_sync,user_file_delete" ], "presentation": { "group": "2" }, "consoleTitle": "Celery user_file_processing Console", "justMyCode": false }, { "name": "Celery docfetching", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
}, "args": [ "-A", "onyx.background.celery.versioned_apps.docfetching", "worker", "--pool=threads", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=docfetching@%n", "-Q", "connector_doc_fetching" ], "presentation": { "group": "2" }, "consoleTitle": "Celery docfetching Console", "justMyCode": false }, { "name": "Celery docprocessing", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "ENABLE_MULTIPASS_INDEXING": "false", "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ "-A", "onyx.background.celery.versioned_apps.docprocessing", "worker", "--pool=threads", "--prefetch-multiplier=1", "--loglevel=INFO", "--hostname=docprocessing@%n", "-Q", "docprocessing" ], "presentation": { "group": "2" }, "consoleTitle": "Celery docprocessing Console", "justMyCode": false }, { "name": "Celery beat", "type": "debugpy", "request": "launch", "module": "celery", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": [ "-A", "onyx.background.celery.versioned_apps.beat", "beat", "--loglevel=INFO" ], "presentation": { "group": "2" }, "consoleTitle": "Celery beat Console" }, { "name": "Pytest", "consoleName": "Pytest", "type": "debugpy", "request": "launch", "module": "pytest", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
}, "args": [ "-v" // Specify a specific module/test to run or provide nothing to run all tests // "tests/unit/onyx/llm/answering/test_prune_and_merge.py" ], "presentation": { "group": "2" }, "consoleTitle": "Pytest Console" }, { // Dummy entry used to label the group "name": "--- Tasks ---", "type": "node", "request": "launch", "presentation": { "group": "3", "order": 0 } }, { "name": "Clear and Restart External Volumes and Containers", "type": "node", "request": "launch", "runtimeExecutable": "bash", "runtimeArgs": [ "${workspaceFolder}/backend/scripts/restart_containers.sh" ], "cwd": "${workspaceFolder}", "console": "integratedTerminal", "presentation": { "group": "3" } }, { "name": "Eval CLI", "type": "debugpy", "request": "launch", "program": "${workspaceFolder}/backend/onyx/evals/eval_cli.py", "cwd": "${workspaceFolder}/backend", "console": "integratedTerminal", "justMyCode": false, "envFile": "${workspaceFolder}/.vscode/.env", "presentation": { "group": "3" }, "env": { "LOG_LEVEL": "INFO", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "args": ["--verbose"], "consoleTitle": "Eval CLI Console" }, { // Celery jobs launched through a single background script (legacy) // Recommend using the "Celery" compound launch instead. "name": "Background Jobs", "consoleName": "Background Jobs", "type": "debugpy", "request": "launch", "program": "scripts/dev_run_background_jobs.py", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "LOG_LEVEL": "DEBUG", "PYTHONUNBUFFERED": "1", "PYTHONPATH": "."
} }, { "name": "Install Python Requirements", "type": "node", "request": "launch", "runtimeExecutable": "uv", "runtimeArgs": [ "sync", "--all-extras" ], "cwd": "${workspaceFolder}", "console": "integratedTerminal", "presentation": { "group": "3" } }, { "name": "Build Sandbox Templates", "type": "debugpy", "request": "launch", "module": "onyx.server.features.build.sandbox.build_templates", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.vscode/.env", "env": { "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." }, "console": "integratedTerminal", "presentation": { "group": "3" }, "consoleTitle": "Build Sandbox Templates" }, { // Dummy entry used to label the group "name": "--- Database ---", "type": "node", "request": "launch", "presentation": { "group": "4", "order": 0 } }, { "name": "Restore seeded database dump", "type": "node", "request": "launch", "runtimeExecutable": "uv", "runtimeArgs": [ "run", "--with", "onyx-devtools", "ods", "db", "restore", "--fetch-seeded", "--yes" ], "cwd": "${workspaceFolder}", "console": "integratedTerminal", "presentation": { "group": "4" } }, { "name": "Clean restore seeded database dump (destructive)", "type": "node", "request": "launch", "runtimeExecutable": "uv", "runtimeArgs": [ "run", "--with", "onyx-devtools", "ods", "db", "restore", "--fetch-seeded", "--clean", "--yes" ], "cwd": "${workspaceFolder}", "console": "integratedTerminal", "presentation": { "group": "4" } }, { "name": "Create database snapshot", "type": "node", "request": "launch", "runtimeExecutable": "uv", "runtimeArgs": [ "run", "--with", "onyx-devtools", "ods", "db", "dump", "backup.dump" ], "cwd": "${workspaceFolder}", "console": "integratedTerminal", "presentation": { "group": "4" } }, { "name": "Clean restore database snapshot (destructive)", "type": "node", "request": "launch", "runtimeExecutable": "uv", "runtimeArgs": [ "run", "--with", "onyx-devtools", "ods", "db", "restore", "--clean", "--yes", "backup.dump" ], "cwd": "${workspaceFolder}", 
"console": "integratedTerminal", "presentation": { "group": "4" } }, { "name": "Upgrade database to head revision", "type": "node", "request": "launch", "runtimeExecutable": "uv", "runtimeArgs": [ "run", "--with", "onyx-devtools", "ods", "db", "upgrade" ], "cwd": "${workspaceFolder}", "console": "integratedTerminal", "presentation": { "group": "4" } }, { // script to generate the openapi schema "name": "Onyx OpenAPI Schema Generator", "type": "debugpy", "request": "launch", "program": "backend/scripts/onyx_openapi_schema.py", "cwd": "${workspaceFolder}", "envFile": "${workspaceFolder}/.env", "env": { "PYTHONUNBUFFERED": "1", "PYTHONPATH": "backend" }, "args": ["--filename", "backend/generated/openapi.json", "--generate-python-client"] }, { // script to debug multi tenant db issues "name": "Onyx DB Manager (Top Chunks)", "type": "debugpy", "request": "launch", "program": "scripts/debugging/onyx_db.py", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.env", "env": { "PYTHONUNBUFFERED": "1", "PYTHONPATH": "." 
}, "args": [ "--password", "your_password_here", "--port", "5433", "--report", "top-chunks", "--filename", "generated/tenants_by_num_docs.csv" ] }, { "name": "Debug React Web App in Chrome", "type": "chrome", "request": "launch", "url": "http://localhost:3000", "webRoot": "${workspaceFolder}/web" } ] } ================================================ FILE: .vscode/tasks.template.jsonc ================================================ { "version": "2.0.0", "tasks": [ { "type": "austin", "label": "Profile celery beat", "envFile": "${workspaceFolder}/.env", "options": { "cwd": "${workspaceFolder}/backend" }, "command": [ "sudo", "-E" ], "args": [ "celery", "-A", "onyx.background.celery.versioned_apps.beat", "beat", "--loglevel=INFO" ] }, { "type": "shell", "label": "Generate Onyx OpenAPI Python client", "cwd": "${workspaceFolder}/backend", "envFile": "${workspaceFolder}/.env", "options": { "cwd": "${workspaceFolder}/backend" }, "command": [ "openapi-generator" ], "args": [ "generate", "-i", "generated/openapi.json", "-g", "python", "-o", "generated/onyx_openapi_client", "--package-name", "onyx_openapi_client", ] }, { "type": "shell", "label": "Generate Typescript Fetch client (openapi-generator)", "envFile": "${workspaceFolder}/.env", "options": { "cwd": "${workspaceFolder}" }, "command": [ "openapi-generator" ], "args": [ "generate", "-i", "backend/generated/openapi.json", "-g", "typescript-fetch", "-o", "${workspaceFolder}/web/src/lib/generated/onyx_api", "--additional-properties=disallowAdditionalPropertiesIfNotPresent=false,legacyDiscriminatorBehavior=false,supportsES6=true", ] }, { "type": "shell", "label": "Generate TypeScript Client (openapi-ts)", "envFile": "${workspaceFolder}/.env", "options": { "cwd": "${workspaceFolder}/web" }, "command": [ "npx" ], "args": [ "openapi-typescript", "../backend/generated/openapi.json", "--output", "./src/lib/generated/onyx-schema.ts", ] }, { "type": "shell", "label": "Generate TypeScript Client (orval)", "envFile": 
"${workspaceFolder}/.env", "options": { "cwd": "${workspaceFolder}/web" }, "command": [ "npx" ], "args": [ "orval", "--config", "orval.config.js", ] } ] } ================================================ FILE: AGENTS.md ================================================ # PROJECT KNOWLEDGE BASE This file provides guidance to AI agents when working with code in this repository. ## KEY NOTES - If you run into missing Python dependency errors, try activating the Python venv with `source .venv/bin/activate` before running your command. - To make tests work, check the `.env` file at the root of the project to find an OpenAI key. - If using `playwright` to explore the frontend, you can usually log in with username `a@example.com` and password `a`. The app can be accessed at `http://localhost:3000`. - You should assume that all Onyx services are running. To verify, check the `backend/log` directory to make sure logs are coming out of the relevant service. - To connect to the Postgres database, use: `docker exec -it onyx-relational_db-1 psql -U postgres -c ""` - When making calls to the backend, always go through the frontend. For example, call `http://localhost:3000/api/persona`, not `http://localhost:8080/api/persona`. - Put ALL db operations under the `backend/onyx/db` / `backend/ee/onyx/db` directories. Don't run queries outside of those directories. ## Project Overview **Onyx** (formerly Danswer) is an open-source Gen-AI and Enterprise Search platform that connects to company documents, apps, and people. It features a modular architecture with both Community Edition (MIT licensed) and Enterprise Edition offerings. ### Background Workers (Celery) Onyx uses Celery for asynchronous task processing with multiple specialized workers: #### Worker Types 1.
**Primary Worker** (`celery_app.py`) - Coordinates core background tasks and system-wide operations - Handles connector management, document sync, pruning, and periodic checks - Runs with a concurrency of 4 threads - Tasks: connector deletion, vespa sync, pruning, LLM model updates, user file sync 2. **Docfetching Worker** (`docfetching`) - Fetches documents from external data sources (connectors) - Spawns docprocessing tasks for each document batch - Implements watchdog monitoring for stuck connectors - Configurable concurrency (default from env) 3. **Docprocessing Worker** (`docprocessing`) - Processes fetched documents through the indexing pipeline: - Upserts documents to PostgreSQL - Chunks documents and adds contextual information - Embeds chunks via the model server - Writes chunks to the Vespa vector database - Updates document metadata - Configurable concurrency (default from env) 4. **Light Worker** (`light`) - Handles lightweight, fast operations - Tasks: vespa operations, document permissions sync, external group sync - Higher concurrency for quick tasks 5. **Heavy Worker** (`heavy`) - Handles resource-intensive operations - Primary task: document pruning operations - Runs with a concurrency of 4 threads 6. **KG Processing Worker** (`kg_processing`) - Handles Knowledge Graph processing and clustering - Builds relationships between documents - Runs clustering algorithms - Configurable concurrency 7. **Monitoring Worker** (`monitoring`) - System health monitoring and metrics collection - Monitors Celery queues, process memory, and system status - Single thread (monitoring doesn't need parallelism) - Cloud-specific monitoring tasks 8. **User File Processing Worker** (`user_file_processing`) - Processes user-uploaded files - Handles user file indexing and project synchronization - Configurable concurrency 9.
**Beat Worker** (`beat`) - Celery's scheduler for periodic tasks - Uses DynamicTenantScheduler for multi-tenant support - Schedules tasks like: - Indexing checks (every 15 seconds) - Connector deletion checks (every 20 seconds) - Vespa sync checks (every 20 seconds) - Pruning checks (every 20 seconds) - KG processing (every 60 seconds) - Monitoring tasks (every 5 minutes) - Cleanup tasks (hourly) #### Key Features - **Thread-based Workers**: All workers use thread pools (not processes) for stability - **Tenant Awareness**: Multi-tenant support with per-tenant task isolation. There is a middleware layer that automatically finds the appropriate tenant ID when sending tasks via Celery Beat. - **Task Prioritization**: High, Medium, Low priority queues - **Monitoring**: Built-in heartbeat and liveness checking - **Failure Handling**: Automatic retry and failure recovery mechanisms - **Redis Coordination**: Inter-process communication via Redis - **PostgreSQL State**: Task state and metadata stored in PostgreSQL #### Important Notes **Defining Tasks**: - Always use `@shared_task` rather than `@celery_app` - Put tasks under `background/celery/tasks/` or `ee/background/celery/tasks` - Never enqueue a task without an expiration. Always supply `expires=` when sending tasks, either from the beat schedule or directly from another task. It should never be acceptable to submit code which enqueues tasks without an expiration, as doing so can lead to unbounded task queue growth. **Defining APIs**: When creating new FastAPI APIs, do NOT use the `response_model` field. Instead, annotate the function's return type. **Testing Updates**: If you make any updates to a celery worker and you want to test these changes, you will need to ask me to restart the celery worker. There is no mechanism that auto-restarts workers on code changes. **Task Time Limits**: Since all tasks are executed in thread pools, Celery's time limit features are silently disabled and won't work.
Timeout logic must be implemented within the task itself. ### Code Quality ```bash # Install and run pre-commit hooks pre-commit install pre-commit run --all-files ``` NOTE: Always make sure everything is strictly typed (both in Python and TypeScript). ## Architecture Overview ### Technology Stack - **Backend**: Python 3.11, FastAPI, SQLAlchemy, Alembic, Celery - **Frontend**: Next.js 15+, React 18, TypeScript, Tailwind CSS - **Database**: PostgreSQL with Redis caching - **Search**: Vespa vector database - **Auth**: OAuth2, SAML, multi-provider support - **AI/ML**: LangChain, LiteLLM, multiple embedding models ### Directory Structure ``` backend/ ├── onyx/ │ ├── auth/ # Authentication & authorization │ ├── chat/ # Chat functionality & LLM interactions │ ├── connectors/ # Data source connectors │ ├── db/ # Database models & operations │ ├── document_index/ # Vespa integration │ ├── federated_connectors/ # External search connectors │ ├── llm/ # LLM provider integrations │ └── server/ # API endpoints & routers ├── ee/ # Enterprise Edition features ├── alembic/ # Database migrations └── tests/ # Test suites web/ ├── src/app/ # Next.js app router pages ├── src/components/ # Reusable React components └── src/lib/ # Utilities & business logic ``` ## Frontend Standards Frontend standards for the `web/` and `desktop/` projects live in `web/AGENTS.md`. ## Database & Migrations ### Running Migrations ```bash # Standard migrations alembic upgrade head # Multi-tenant (Enterprise) alembic -n schema_private upgrade head ``` ### Creating Migrations ```bash # Create migration alembic revision -m "description" # Multi-tenant migration alembic -n schema_private revision -m "description" ``` Write the migration by hand in the file that Alembic generates when you run the above command. ## Testing Strategy First, you must activate the virtual environment with `source .venv/bin/activate`.
There are 4 main types of tests within Onyx: ### Unit Tests These should not assume any Onyx/external services are available to be called. Interactions with the outside world should be mocked using `unittest.mock`. Generally, only write these for complex, isolated modules (e.g. `citation_processing.py`). To run them: ```bash pytest -xv backend/tests/unit ``` ### External Dependency Unit Tests These tests assume that all external dependencies of Onyx are available and callable (e.g. Postgres, Redis, MinIO/S3, and Vespa are running, OpenAI can be called, and requests to the internet are fine). However, the Onyx containers themselves are not running; these tests call the function under test directly. We can also mock components/calls at will. The goal of these tests is to minimize mocking while leaving some flexibility to mock things that are flaky, need strictly controlled behavior, or need to have their internal behavior validated (e.g. verifying that a function is called with certain args, which would be impossible with proper integration tests). A great example of this type of test is `backend/tests/external_dependency_unit/connectors/confluence/test_confluence_group_sync.py`. To run them: ```bash python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit ``` ### Integration Tests Standard integration tests. Every test in `backend/tests/integration` runs against a real Onyx deployment. We cannot mock anything in these tests. Prefer writing integration tests (or External Dependency Unit Tests if mocking/internal verification is necessary) over any other type of test. Tests are parallelized at a directory level. When writing integration tests, check the root `conftest.py` for useful fixtures and the `backend/tests/integration/common_utils` directory for utilities. If one exists, prefer calling the appropriate Manager class in the utils over calling the APIs directly with a library like `requests`.
Prefer using fixtures rather than calling the utilities directly (e.g. do NOT create admin users with `admin_user = UserManager.create(name="admin_user")`; instead, use the `admin_user` fixture). A great example of this type of test is `backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py`. To run them: ```bash python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration ``` ### Playwright (E2E) Tests These tests are an even more complete version of the Integration Tests mentioned above: all Onyx services are running, _including_ the Web Server. Use these tests for anything that requires significant frontend <-> backend coordination. Tests are located at `web/tests/e2e` and are written in TypeScript. To run them: ```bash npx playwright test ``` For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`. ## Logs When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright), you can access logs via the `backend/log/_debug.log` file. All Onyx services (api_server, web_server, celery_X) write their logs to this file. ## Security Considerations - Never commit API keys or secrets to the repository - Use encrypted credential storage for connector credentials - Follow RBAC patterns for new features - Implement proper input validation with Pydantic models - Use parameterized queries to prevent SQL injection ## AI/LLM Integration - Multiple LLM providers supported via LiteLLM - Configurable models per feature (chat, search, embeddings) - Streaming support for real-time responses - Token management and rate limiting - Custom prompts and agent actions ## Creating a Plan When creating a plan in the `plans` directory, make sure to include at least these elements: **Issues to Address** What the change is meant to do. **Important Notes** Things you come across in your research that are important to the implementation.
**Implementation strategy** How you are going to make the changes happen. High level approach. **Tests** What unit (use rarely), external dependency unit, integration, and playwright tests you plan to write to verify the correct behavior. Don't overtest. Usually, a given change only needs one type of test. Do NOT include these: _Timeline_, _Rollback plan_ This is a minimal list - feel free to include more. Do NOT write code as part of your plan. Keep it high level. You can reference certain files or functions though. Before writing your plan, make sure to do research. Explore the relevant sections in the codebase. ## Error Handling **Always raise `OnyxError` from `onyx.error_handling.exceptions` instead of `HTTPException`. Never hardcode status codes or use `starlette.status` / `fastapi.status` constants directly.** A global FastAPI exception handler converts `OnyxError` into a JSON response with the standard `{"error_code": "...", "detail": "..."}` shape. This eliminates boilerplate and keeps error handling consistent across the entire backend. ```python from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError # ✅ Good raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found") # ✅ Good — no extra message needed raise OnyxError(OnyxErrorCode.UNAUTHENTICATED) # ✅ Good — upstream service with dynamic status code raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status) # ❌ Bad — using HTTPException directly raise HTTPException(status_code=404, detail="Session not found") # ❌ Bad — starlette constant raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Access denied") ``` Available error codes are defined in `backend/onyx/error_handling/error_codes.py`. If a new error category is needed, add it there first — do not invent ad-hoc codes. 
**Upstream service errors:** When forwarding errors from an upstream service where the HTTP status code is dynamic (comes from the upstream response), use `status_code_override`: ```python raise OnyxError(OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code) ``` ## Best Practices In addition to the other content in this file, best practices for contributing to the codebase can be found in the "Engineering Best Practices" section of `CONTRIBUTING.md`. Understand its contents and follow them. ================================================ FILE: CONTRIBUTING.md ================================================ # Contributing to Onyx Hey there! We are so excited that you're interested in Onyx. ## Table of Contents - [Contribution Opportunities](#contribution-opportunities) - [Contribution Process](#contribution-process) - [Development Setup](#development-setup) - [Prerequisites](#prerequisites) - [Backend: Python Requirements](#backend-python-requirements) - [Frontend: Node Dependencies](#frontend-node-dependencies) - [Formatting and Linting](#formatting-and-linting) - [Running the Application](#running-the-application) - [VSCode Debugger (Recommended)](#vscode-debugger-recommended) - [Manually Running for Development](#manually-running-for-development) - [Running in Docker](#running-in-docker) - [macOS-Specific Notes](#macos-specific-notes) - [Engineering Best Practices](#engineering-best-practices) - [Principles and Collaboration](#principles-and-collaboration) - [Style and Maintainability](#style-and-maintainability) - [Performance and Correctness](#performance-and-correctness) - [Repository Conventions](#repository-conventions) - [Release Process](#release-process) - [Getting Help](#getting-help) - [Enterprise Edition Contributions](#enterprise-edition-contributions) --- ## Contribution Opportunities The [GitHub Issues](https://github.com/onyx-dot-app/onyx/issues) page is a great place to look for and share contribution ideas. 
If you have your own feature that you would like to build, please create an issue and community members can provide feedback and upvote if they feel a common need. --- ## Contribution Process To contribute, please follow the ["fork and pull request"](https://docs.github.com/en/get-started/quickstart/contributing-to-projects) workflow. ### 1. Get the feature or enhancement approved Create a GitHub issue and see if there are upvotes. If you feel the feature is sufficiently value-additive and you would like approval to contribute it to the repo, tag [Yuhong](https://github.com/yuhongsun96) to review. If you do not get a response within a week, feel free to email yuhong@onyx.app and include the issue in the message. Not all small features and enhancements will be accepted as there is a balance between feature richness and bloat. We strive to provide the best user experience possible so we have to be intentional about what we include in the app. ### 2. Get the design approved The Onyx team will either provide a design doc and PRD for the feature or request one from you, the contributor. The scope and detail of the design will depend on the individual feature. ### 3. IP attribution for EE contributions If you are contributing features to Onyx Enterprise Edition, you are required to sign the [IP Assignment Agreement](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.md). ### 4. Review and testing Your features must pass all tests and all comments must be addressed prior to merging. ### Implicit agreements If we approve an issue, we are promising you the following: - Your work will receive timely attention and we will put aside other important items to ensure you are not blocked. - You will receive necessary coaching on eng quality, system design, etc. to ensure the feature is completed well. 
- The Onyx team will pull resources and bandwidth from design, PM, and engineering to ensure that you have all the resources to build the feature to the quality required for merging.

Because this is a large investment from our team, we ask that you:

- Thoroughly read all the requirements in the design docs and the engineering best practices, and try to minimize overhead for the Onyx team.
- Complete the feature in a timely manner to reduce context switching and an ongoing resource pull from the Onyx team.

---

## Development Setup

Onyx, being a fully functional app, relies on some external software, specifically:

- [Postgres](https://www.postgresql.org/) (Relational DB)
- [OpenSearch](https://opensearch.org/) (Vector DB/Search Engine)
- [Redis](https://redis.io/) (Cache)
- [MinIO](https://min.io/) (File Store)
- [Nginx](https://nginx.org/) (Not needed for development flows generally)

> **Note:**
> This guide provides instructions to build and run Onyx locally from source with Docker containers providing the above external software.
> We believe this combination is easier for development purposes. If you prefer to use pre-built container images, see [Running in Docker](#running-in-docker) below.

### Prerequisites

- **Python 3.11** — If using a lower version, modifications will have to be made to the code. Higher versions may have library compatibility issues.
- **Docker** — Required for running external services (Postgres, OpenSearch, Redis, MinIO).
- **Node.js v22** — We recommend using [nvm](https://github.com/nvm-sh/nvm) to manage Node installations.

### Backend: Python Requirements

We use [uv](https://docs.astral.sh/uv/) and recommend creating a [virtual environment](https://docs.astral.sh/uv/pip/environments/#using-a-virtual-environment).
```bash
uv venv .venv --python 3.11
source .venv/bin/activate
```

_For Windows, activate the virtual environment using Command Prompt:_

```bat
.venv\Scripts\activate
```

If using PowerShell, the command differs slightly:

```powershell
.venv\Scripts\Activate.ps1
```

Install the required Python dependencies:

```bash
uv sync --all-extras
```

Install Playwright for Python (headless browser required by the Web Connector):

```bash
uv run playwright install
```

### Frontend: Node Dependencies

```bash
nvm install 22 && nvm use 22
node -v # verify your active version
```

Navigate to `onyx/web` and run:

```bash
npm i
```

### Formatting and Linting

#### Backend

Set up pre-commit hooks (black / reorder-python-imports):

```bash
uv run pre-commit install
```

We also use `mypy` for static type checking. Onyx is fully type-annotated, and we want to keep it that way! To run the mypy checks manually:

```bash
uv run mypy . # from onyx/backend
```

#### Frontend

We use `prettier` for formatting. The desired version will be installed via `npm i` from the `onyx/web` directory. To run the formatter:

```bash
npx prettier --write . # from onyx/web
```

Pre-commit will also run prettier automatically on files you've recently touched. If any files are reformatted, your commit will fail; re-stage your changes and commit again.

---

## Running the Application

### VSCode Debugger (Recommended)

We highly recommend using VSCode's debugger for development.

#### Initial Setup

1. Copy `.vscode/env_template.txt` to `.vscode/.env`
2. Fill in the necessary environment variables in `.vscode/.env`

#### Using the Debugger

Before starting, make sure the Docker daemon is running.

1. Open the Debug view in VSCode (Cmd+Shift+D on macOS)
2. From the dropdown at the top, select "Clear and Restart External Volumes and Containers" and press the green play button
3. From the dropdown at the top, select "Run All Onyx Services" and press the green play button
4. Navigate to http://localhost:3000 in your browser to start using the app
5. Set breakpoints by clicking to the left of line numbers to help debug while the app is running
6. Use the debug toolbar to step through code, inspect variables, etc.

> **Note:** "Clear and Restart External Volumes and Containers" will reset your Postgres and OpenSearch (relational-db and index). Only run this if you are okay with wiping your data.

**Features:**

- Hot reload is enabled for the web server and API servers
- Python debugging is configured with debugpy
- Environment variables are loaded from `.vscode/.env`
- Console output is organized in the integrated terminal with labeled tabs

### Manually Running for Development

#### Docker containers for external software

You will need Docker installed to run these containers.

Navigate to `onyx/deployment/docker_compose`, then start up Postgres/OpenSearch/Redis/MinIO with:

```bash
docker compose -f docker-compose.yml -f docker-compose.dev.yml up -d index relational_db cache minio
```

(index refers to OpenSearch, relational_db refers to Postgres, and cache refers to Redis)

#### Running Onyx locally

To start the frontend, navigate to `onyx/web` and run:

```bash
npm run dev
```

Next, start the model server, which runs the local NLP models. Navigate to `onyx/backend` and run:

```bash
uvicorn model_server.main:app --reload --port 9000
```

_For Windows (for compatibility with both PowerShell and Command Prompt):_

```bash
powershell -Command "uvicorn model_server.main:app --reload --port 9000"
```

The first time you run Onyx, you will need to run the DB migrations for Postgres. After that, this is no longer required unless the DB models change.

Navigate to `onyx/backend` and, with the venv active, run:

```bash
alembic upgrade head
```

Next, start the task queue, which orchestrates the background jobs.
Still in `onyx/backend`, run:

```bash
python ./scripts/dev_run_background_jobs.py
```

To run the backend API server, again from `onyx/backend`, run:

```bash
AUTH_TYPE=basic uvicorn onyx.main:app --reload --port 8080
```

_For Windows (for compatibility with both PowerShell and Command Prompt):_

```bash
powershell -Command "
    $env:AUTH_TYPE='basic'
    uvicorn onyx.main:app --reload --port 8080
"
```

> **Note:** If you need finer logging, add the additional environment variable `LOG_LEVEL=DEBUG` to the relevant services.

#### Wrapping up

You should now have 4 servers running:

- Web server
- Backend API
- Model server
- Background jobs

Now, visit http://localhost:3000 in your browser. You should see the Onyx onboarding wizard where you can connect your external LLM provider to Onyx.

You've successfully set up a local Onyx instance!

### Running in Docker

You can run the full Onyx application stack from pre-built images including all external software dependencies.

Navigate to `onyx/deployment/docker_compose` and run:

```bash
docker compose up -d
```

After Docker pulls and starts these containers, navigate to http://localhost:3000 to use Onyx.

If you want to make changes to Onyx and run those changes in Docker, you can also build a local version of the Onyx container images that incorporates your changes:

```bash
docker compose up -d --build
```

---

## macOS-Specific Notes

### Setting up Python

Ensure [Homebrew](https://brew.sh/) is already set up, then install Python 3.11:

```bash
brew install python@3.11
```

Add Python 3.11 to your path by adding the following line to `~/.zshrc`:

```
export PATH="$(brew --prefix)/opt/python@3.11/libexec/bin:$PATH"
```

> **Note:** You will need to open a new terminal for the path change above to take effect.

### Setting up Docker

On macOS, you will need to install [Docker Desktop](https://www.docker.com/products/docker-desktop/) and ensure it is running before continuing with the docker commands.
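The setup steps above assume a running Docker daemon. As an optional pre-flight check, here is a minimal Python sketch. It is not part of the Onyx tooling, and the helper name `docker_daemon_is_running` is hypothetical. The sketch assumes the `docker` CLI is on your `PATH` and treats a non-zero exit status from `docker info` as "daemon not reachable".

```python
import shutil
import subprocess


def docker_daemon_is_running(timeout_seconds: float = 10.0) -> bool:
    """Return True if the Docker CLI is installed and can reach a running daemon.

    Hypothetical helper for local setup checks; not part of the Onyx codebase.
    """
    docker_executable = shutil.which("docker")
    if docker_executable is None:
        # The Docker CLI is not installed or not on PATH.
        return False
    try:
        # `docker info` queries the daemon and exits non-zero if it cannot connect.
        completed_process = subprocess.run(
            [docker_executable, "info"],
            capture_output=True,
            timeout=timeout_seconds,
            check=False,
        )
    except subprocess.TimeoutExpired:
        # A hung daemon (e.g. Docker Desktop still starting up) counts as not ready.
        return False
    return completed_process.returncode == 0
```

You might call `docker_daemon_is_running()` from a local setup script and stop early when it returns `False`. Returning a boolean leaves it to the caller to decide how loudly to fail.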
### Formatting and Linting

macOS will likely require you to remove some quarantine attributes on some of the hooks for them to execute properly. After installing pre-commit, run the following command:

```bash
sudo xattr -r -d com.apple.quarantine ~/.cache/pre-commit
```

---

## Engineering Best Practices

> These are also the practices we adhere to internally as a team. We love to build in the open and to uplevel our community and each other by being transparent.

### Principles and Collaboration

- **Use 1-way vs 2-way doors.** For 2-way doors, move faster and iterate. For 1-way doors, be more deliberate.
- **Consistency > being "right."** Prefer consistent patterns across the codebase. If something is truly bad, fix it everywhere.
- **Fix what you touch (selectively).**
  - Don't feel obligated to fix every best-practice issue you notice.
  - Don't introduce new bad practices.
  - If your change touches code that violates best practices, fix it as part of the change.
- **Don't tack features on.** When adding functionality, restructure logically as needed to avoid muddying interfaces and accumulating tech debt.

### Style and Maintainability

#### Comments and readability

Add clear comments:

- At logical boundaries (e.g., interfaces) so the reader doesn't need to dig 10 layers deeper.
- Wherever assumptions are made or something non-obvious/unexpected is done.
- For complicated flows/functions.
- Wherever it saves time (e.g., nontrivial regex patterns).

#### Errors and exceptions

- **Fail loudly** rather than silently skipping work.
  - Example: raise and let exceptions propagate instead of silently dropping a document.
- **Don't overuse `try/except`.**
  - Put `try/except` at the correct logical level.
  - Do not mask exceptions unless it is clearly appropriate.

#### Typing

- Everything should be **as strictly typed as possible**.
- Use `cast` for annoying/loose-typed interfaces (e.g., results of `run_functions_tuples_in_parallel`).
  - Only `cast` when the type checker sees `Any` or types are too loose.
- Prefer types that are easy to read.
  - Avoid dense types like `dict[tuple[str, str], list[list[float]]]`.
  - Prefer domain models, e.g.:
    - `EmbeddingModel(provider_name, model_name)` as a Pydantic model
    - `dict[EmbeddingModel, list[EmbeddingVector]]`

#### State, objects, and boundaries

- Keep **clear logical boundaries** for state containers and objects.
- A **config** object should never contain things like a `db_session`.
- Avoid state containers that are overly nested, or huge + flat (use judgment).
- Prefer **composition and functional style** over inheritance/OOP.
- Prefer **no mutation** unless there's a strong reason.
- State objects should be **intentional and explicit**, ideally nonmutating.
- Use interfaces/objects to create clear separation of responsibility.
- Prefer simplicity when added complexity brings no clear gain.
  - Avoid overcomplicated mechanisms like semaphores.
- Prefer **hash maps (dicts)** over tree structures unless there's a strong reason.

#### Naming

- Name variables carefully and intentionally.
- Prefer long, explicit names when undecided.
- Avoid single-character variables except for small, self-contained utilities (or not at all).
- Keep the same object/name consistent through the call stack and within functions when reasonable.
  - Good: `for token in tokens:`
  - Bad: `for msg in tokens:` (if iterating tokens)
- Function names should bias toward **long + descriptive** for codebase search.
  - IntelliSense can miss call sites; search works best with unique names.

#### Correctness by construction

- Prefer self-contained correctness — don't rely on callers to "use it right" if you can make misuse hard.
- Avoid redundancies: if a function takes an arg, it shouldn't also take a state object that contains that same arg.
- No dead code (unless there's a very good reason).
- No commented-out code in main or feature branches (unless there's a very good reason).
- No duplicate logic:
  - Don't copy/paste into branches when shared logic can live above the conditional.
  - If you're afraid to touch the original, you don't understand it well enough.
  - LLMs often create subtle duplicate logic — review carefully and remove it.
  - Avoid "nearly identical" objects that make it unclear which one to use.
- Avoid extremely long functions with chained logic:
  - Encapsulate steps into helpers for readability, even if not reused.
  - "Pythonic" multi-step expressions are OK in moderation; don't trade clarity for cleverness.

### Performance and Correctness

- Avoid holding resources for extended periods (DB sessions, locks/semaphores).
- Validate objects on creation and right before use.
- Connector code (data to Onyx documents):
  - Any in-memory structure that can grow without bound based on input must be periodically size-checked.
  - If a connector is OOMing (often shows up as "missing celery tasks"), this is a top thing to check retroactively.
- Async and event loops:
  - Never introduce new async/event loop Python code, and make existing async code synchronous where it makes sense.
  - Writing async code without fully understanding it and without a concrete reason to do so is likely to introduce bugs without adding any meaningful performance gains.

### Repository Conventions

#### Where code lives

- Pydantic + data models: `models.py` files.
- DB interface functions (excluding lazy loading): `db/` directory.
- LLM prompts: `prompts/` directory, roughly mirroring the code layout that uses them.
- API routes: `server/` directory.

#### Pydantic and modeling

- Prefer **Pydantic** over dataclasses.
- If absolutely required, use `arbitrary_types_allowed`.

#### Data conventions

- Prefer explicit `None` over sentinel empty strings (usually; depends on intent).
- Prefer explicit identifiers: use string enums instead of integer codes.
- Avoid magic numbers (co-location is good when necessary).
**Always avoid magic strings.**

#### Logging

- Log messages where they are created.
- Don't propagate log messages around just to log them elsewhere.

#### Encapsulation

- Don't use private attributes/methods/properties from other classes/modules.
- "Private" is private — respect that boundary.

#### SQLAlchemy guidance

- Lazy loading is often bad at scale, especially across multiple list relationships.
- Be careful when accessing SQLAlchemy object attributes:
  - It can help avoid redundant DB queries,
  - but it can also fail if accessed outside an active session,
  - and lazy loading can add hidden DB dependencies to otherwise "simple" functions.
- Reference: https://www.reddit.com/r/SQLAlchemy/comments/138f248/joinedload_vs_selectinload/

#### Trunk-based development and feature flags

- **PRs should contain no more than 500 lines of real change.**
- **Merge to main frequently.** Avoid long-lived feature branches — they create merge conflicts and integration pain.
- **Use feature flags for incremental rollout.**
  - Large features should be merged in small, shippable increments behind a flag.
  - This allows continuous integration without exposing incomplete functionality.
- **Keep flags short-lived.** Once a feature is fully rolled out, remove the flag and dead code paths promptly.
- **Flag at the right level.** Prefer flagging at API/UI entry points rather than deep in business logic.
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.

#### Miscellaneous

- Any TODOs you add in the code must be accompanied by either the name/username of the owner of that TODO, or an issue number for an issue referencing that piece of work.
- Avoid module-level logic that runs on import, which leads to import-time side effects. Essentially every piece of meaningful logic should exist within some function that has to be explicitly invoked. Acceptable exceptions may include loading environment variables or setting up loggers.
  - If you find yourself needing something like this, you may want that logic to exist in a file dedicated to manual execution (contains `if __name__ == "__main__":`) which should not be imported by anything else.
- Do not conflate Python scripts you intend to run from the command line (contains `if __name__ == "__main__":`) with modules you intend to import from elsewhere. If for some unlikely reason they have to be the same file, any logic specific to executing the file (including imports) should be contained in the `if __name__ == "__main__":` block.
  - Generally these executable files exist in `backend/scripts/`.

---

## Release Process

Onyx loosely follows the SemVer versioning standard. A set of Docker containers will be pushed automatically to DockerHub with every tag. You can see the containers [here](https://hub.docker.com/search?q=onyx%2F).

---

## Getting Help

We have support channels and generally interesting discussions on our [Discord](https://discord.gg/4NA5SbzrWb).

See you there!

---

## Enterprise Edition Contributions

If you are contributing features to Onyx Enterprise Edition (code under any `ee/` directory), you are required to sign the [IP Assignment Agreement](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.md) ([PDF version](contributor_ip_assignment/EE_Contributor_IP_Assignment_Agreement.pdf)).

================================================
FILE: LICENSE
================================================
Copyright (c) 2023-present DanswerAI, Inc.

Portions of this software are licensed as follows:

- All content that resides under "ee" directories of this repository is licensed under the Onyx Enterprise License. Each ee directory contains an identical copy of this license at its root:
  - backend/ee/LICENSE
  - web/src/app/ee/LICENSE
  - web/src/ee/LICENSE
- All third party components incorporated into the Onyx Software are licensed under the original license provided by the owner of the applicable component.
- Content outside of the above mentioned directories or restrictions above is available under the "MIT Expat" license as defined below.

Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

================================================
FILE: README.md
================================================



# Onyx - The Open Source AI Platform

**[Onyx](https://www.onyx.app/?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)** is the application layer for LLMs - bringing a feature-rich interface that can be easily hosted by anyone. Onyx enables LLMs through advanced capabilities like RAG, web search, code execution, file creation, deep research and more.

Connect your applications with 50+ indexing-based connectors provided out of the box or via MCP.

> [!TIP]
> Deploy with a single command:
> ```
> curl -fsSL https://onyx.app/install_onyx.sh | bash
> ```

![Onyx Chat Silent Demo](https://github.com/onyx-dot-app/onyx/releases/download/v3.0.0/Onyx.gif)

---

## ⭐ Features

- **🔍 Agentic RAG:** Get best-in-class search and answer quality based on hybrid index + AI Agents for information retrieval
  - Benchmark to release soon!
- **🔬 Deep Research:** Get in-depth reports with a multi-step research flow.
  - Top of [leaderboard](https://github.com/onyx-dot-app/onyx_deep_research_bench) as of Feb 2026.
- **🤖 Custom Agents:** Build AI Agents with unique instructions, knowledge, and actions.
- **🌍 Web Search:** Browse the web to get up-to-date information.
  - Supports Serper, Google PSE, Brave, SearXNG, and others.
  - Comes with an in-house web crawler and support for Firecrawl/Exa.
- **📄 Artifacts:** Generate documents, graphics, and other downloadable artifacts.
- **▶️ Actions & MCP:** Let Onyx agents interact with external applications, with flexible auth options.
- **💻 Code Execution:** Execute code in a sandbox to analyze data, render graphs, or modify files.
- **🎙️ Voice Mode:** Chat with Onyx via text-to-speech and speech-to-text.
- **🎨 Image Generation:** Generate images based on user prompts.

Onyx supports all major LLM providers, both self-hosted (like Ollama, LiteLLM, vLLM, etc.) and proprietary (like Anthropic, OpenAI, Gemini, etc.).
To learn more, check out our [docs](https://docs.onyx.app/welcome?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)!

---

## 🚀 Deployment Modes

> Onyx supports deployments in Docker, Kubernetes, Helm/Terraform and provides guides for major cloud providers.
> Detailed deployment guides can be found [here](https://docs.onyx.app/deployment/overview).

Onyx supports two separate deployment options: standard and lite.

#### Onyx Lite

The Lite mode can be thought of as a lightweight Chat UI. It requires fewer resources (under 1GB memory) and runs a less complex stack. It is great for users who want to test out Onyx quickly or for teams who are only interested in the Chat UI and Agents functionalities.

#### Standard Onyx

The complete feature set of Onyx, recommended for serious users and larger teams. Additional components not included in Lite mode:

- Vector + Keyword index for RAG.
- Background containers to run job queues and workers for syncing knowledge from connectors.
- AI model inference servers to run deep learning models used during indexing and inference.
- Performance optimizations for large-scale use via an in-memory cache (Redis) and blob store (MinIO).

> [!TIP]
> **To try Onyx for free without deploying, visit [Onyx Cloud](https://cloud.onyx.app/signup?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme)**.

---

## 🏢 Onyx for Enterprise

Onyx is built for teams of all sizes, from individual users to the largest global enterprises:

- 👥 Collaboration: Share chats and agents with other members of your organization.
- 🔐 Single Sign On: SSO via Google OAuth, OIDC, or SAML. Group syncing and user provisioning via SCIM.
- 🛡️ Role-Based Access Control: RBAC for sensitive resources like access to agents, actions, etc.
- 📊 Analytics: Usage graphs broken down by teams, LLMs, or agents.
- 🕵️ Query History: Audit usage to ensure safe adoption of AI in your organization.
- 💻 Custom code: Run custom code to remove PII, reject sensitive queries, or run custom analysis.
- 🎨 Whitelabeling: Customize the look and feel of Onyx with custom naming, icons, banners, and more.

## 📚 Licensing

There are two editions of Onyx:

- Onyx Community Edition (CE) is available freely under the MIT license and covers all of the core features for Chat, RAG, Agents, and Actions.
- Onyx Enterprise Edition (EE) includes extra features that are primarily useful for larger organizations. For feature details, check out [our website](https://www.onyx.app/pricing?utm_source=onyx_repo&utm_medium=github&utm_campaign=readme).

## 👪 Community

Join our open source community on **[Discord](https://discord.gg/TDJ59cGV2X)**!

## 💡 Contributing

Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

================================================
FILE: backend/.dockerignore
================================================
**/__pycache__
venv/
env/
*.egg-info
.cache
.git/
.svn/
.vscode/
.idea/
*.log
log/
.env
secrets.yaml
build/
dist/
.coverage
htmlcov/
model_server/legacy/

# Craft: demo_data directory should be unzipped at container startup, not copied
**/demo_data/
# Craft: templates/outputs/venv is created at container startup
**/templates/outputs/venv

================================================
FILE: backend/.gitignore
================================================
__pycache__/
.mypy_cache
.idea/
site_crawls/
.ipynb_checkpoints/
api_keys.py
*ipynb
.env*
vespa-app.zip
dynamic_config_storage/
celerybeat-schedule*
onyx/connectors/salesforce/data/
.test.env
/generated

================================================
FILE: backend/.trivyignore
================================================
# https://github.com/madler/zlib/issues/868
# Pulled in with base Debian image, it's part of the contrib folder but unused
# zlib1g is fine
# Will be gone with Debian image upgrade
# No impact in our settings
CVE-2023-45853

# krb5 related, worst case is denial of service by resource exhaustion
# Accept the risk
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462
CVE-2024-26458
CVE-2024-26461
CVE-2024-26462

# Specific to Firefox which we do not use
# No impact in our settings
CVE-2024-0743

# bind9 related, worst case is denial of service by CPU resource exhaustion
# Accept the risk
CVE-2023-50387
CVE-2023-50868
CVE-2023-50387
CVE-2023-50868

# libexpat1, XML parsing resource exhaustion
# We don't parse any user provided XMLs
# No impact in our settings
CVE-2023-52425
CVE-2024-28757

# libharfbuzz0b, O(n^2) growth, worst case is denial of service
# Accept the risk
CVE-2023-25193

================================================
FILE: backend/Dockerfile
================================================
FROM python:3.11.7-slim-bookworm

LABEL com.danswer.maintainer="founders@onyx.app"
LABEL com.danswer.description="This image is the backend container of Onyx which \
contains code for both the Community and Enterprise editions of Onyx. If you do not \
have a contract or agreement with DanswerAI, you are not permitted to use the Enterprise \
Edition features outside of personal development or testing purposes. Please reach out to \
founders@onyx.app for more information. Please visit https://github.com/onyx-dot-app/onyx"

# Build argument for Craft support (disabled by default)
# Use --build-arg ENABLE_CRAFT=true to include Node.js and opencode CLI
ARG ENABLE_CRAFT=false

# DO_NOT_TRACK is used to disable telemetry for Unstructured
ENV DANSWER_RUNNING_IN_DOCKER="true" \
    DO_NOT_TRACK="true" \
    PLAYWRIGHT_BROWSERS_PATH="/app/.cache/ms-playwright"

# Create non-root user for security best practices
RUN groupadd -g 1001 onyx && \
    useradd -u 1001 -g onyx -m -s /bin/bash onyx && \
    mkdir -p /var/log/onyx && \
    chmod 755 /var/log/onyx && \
    chown onyx:onyx /var/log/onyx

COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/

# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
# curl included just for users' convenience
# zip for Vespa step further down
# ca-certificates for HTTPS
RUN apt-get update && \
    apt-get install -y \
        cmake \
        curl \
        zip \
        ca-certificates \
        libgnutls30 \
        libblkid1 \
        libmount1 \
        libsmartcols1 \
        libuuid1 \
        libxmlsec1-dev \
        pkg-config \
        gcc \
        nano \
        vim \
        # Install procps so kubernetes exec sessions can use ps aux for debugging
        procps \
        libjemalloc2 \
        && \
    rm -rf /var/lib/apt/lists/* && \
    apt-get clean

# Conditionally install Node.js 20 for Craft (required for Next.js)
# Only installed when ENABLE_CRAFT=true
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
        echo "Installing Node.js 20 for Craft support..." && \
        curl -fsSL https://deb.nodesource.com/setup_20.x | bash - && \
        apt-get install -y nodejs && \
        rm -rf /var/lib/apt/lists/*; \
    fi

# Conditionally install opencode CLI for Craft agent functionality
# Only installed when ENABLE_CRAFT=true
# TODO: download a specific, versioned release of the opencode CLI
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
        echo "Installing opencode CLI for Craft support..." && \
        curl -fsSL https://opencode.ai/install | bash; \
    fi
ENV PATH="/root/.opencode/bin:${PATH}"

# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
        -r /tmp/requirements.txt \
        -r /tmp/ee-requirements.txt && \
    pip uninstall -y py && \
    playwright install chromium && \
    playwright install-deps chromium && \
    chown -R onyx:onyx /app && \
    ln -s /usr/local/bin/supervisord /usr/bin/supervisord && \
    # Cleanup for CVEs and size reduction
    # https://github.com/tornadoweb/tornado/issues/3107
    # xserver-common and xvfb included by playwright installation but not needed after
    # perl-base is part of the base Python Debian image but not needed for Onyx functionality
    # perl-base could only be removed with --allow-remove-essential
    apt-get update && \
    apt-get remove -y --allow-remove-essential \
        perl-base \
        xserver-common \
        xvfb \
        cmake \
        libldap-2.5-0 \
        libxmlsec1-dev \
        pkg-config \
        gcc && \
    # Install here to avoid some packages being cleaned up above
    apt-get install -y \
        libxmlsec1-openssl \
        # Install postgresql-client for easy manual tests
        postgresql-client && \
    apt-get autoremove -y && \
    rm -rf /var/lib/apt/lists/* && \
    rm -rf ~/.cache/uv /tmp/*.txt && \
    rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key

# Pre-downloading models for setups with limited egress
RUN python -c "from tokenizers import Tokenizer; \
Tokenizer.from_pretrained('nomic-ai/nomic-embed-text-v1')"

# Pre-downloading NLTK for setups with limited egress
RUN python -c "import nltk; \
nltk.download('stopwords', quiet=True); \
nltk.download('punkt_tab', quiet=True);"
# nltk.download('wordnet', quiet=True); introduce this back if lemmatization is needed

# Pre-downloading tiktoken for setups with limited egress
RUN python -c "import tiktoken; \
tiktoken.get_encoding('cl100k_base')"

# Set up application files
WORKDIR /app

# Enterprise Version Files
COPY --chown=onyx:onyx ./ee /app/ee
COPY supervisord.conf /etc/supervisor/conf.d/supervisord.conf

# Set up application files
COPY --chown=onyx:onyx ./onyx /app/onyx
COPY --chown=onyx:onyx ./shared_configs /app/shared_configs
COPY --chown=onyx:onyx ./alembic /app/alembic
COPY --chown=onyx:onyx ./alembic_tenants /app/alembic_tenants
COPY --chown=onyx:onyx ./alembic.ini /app/alembic.ini
COPY supervisord.conf /usr/etc/supervisord.conf
COPY --chown=onyx:onyx ./static /app/static
COPY --chown=onyx:onyx ./keys /app/keys

# Escape hatch scripts
COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh

# Run Craft template setup at build time when ENABLE_CRAFT=true
# This pre-bakes demo data, Python venv, and npm dependencies into the image
RUN if [ "$ENABLE_CRAFT" = "true" ]; then \
        echo "Running Craft template setup at build time..." && \
        ENABLE_CRAFT=true /app/scripts/setup_craft_templates.sh; \
    fi

# Set Craft template paths to the in-image locations
# These match the paths where setup_craft_templates.sh creates the templates
ENV OUTPUTS_TEMPLATE_PATH=/app/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs
ENV VENV_TEMPLATE_PATH=/app/onyx/server/features/build/sandbox/kubernetes/docker/templates/venv

# Put logo in assets
COPY --chown=onyx:onyx ./assets /app/assets

ENV PYTHONPATH=/app

# Default ONYX_VERSION, typically overridden during builds by GitHub Actions.
ARG ONYX_VERSION=0.0.0-dev
ENV ONYX_VERSION=${ONYX_VERSION}

# Use jemalloc instead of glibc malloc to reduce memory fragmentation
# in long-running Python processes (API server, Celery workers).
# The soname is architecture-independent; the dynamic linker resolves
# the correct path from standard library directories.
# Placed after all RUN steps so build-time processes are unaffected.
ENV LD_PRELOAD=libjemalloc.so.2

# Default command which does nothing
# This container is used by api server and background which specify their own CMD
CMD ["tail", "-f", "/dev/null"]

================================================
FILE: backend/Dockerfile.model_server
================================================
# Base stage with dependencies
FROM python:3.11.7-slim-bookworm AS base

ENV DANSWER_RUNNING_IN_DOCKER="true" \
    HF_HOME=/app/.cache/huggingface

COPY --from=ghcr.io/astral-sh/uv:0.9.9 /uv /uvx /bin/

RUN mkdir -p /app/.cache/huggingface

COPY ./requirements/model_server.txt /tmp/requirements.txt
RUN uv pip install --system --no-cache-dir --upgrade \
        -r /tmp/requirements.txt && \
    rm -rf ~/.cache/uv /tmp/*.txt

# Stage for downloading embedding models
FROM base AS embedding-models
RUN python -c "from huggingface_hub import snapshot_download; \
snapshot_download('nomic-ai/nomic-embed-text-v1');"

# Initialize SentenceTransformer to cache the custom architecture
RUN python -c "from sentence_transformers import SentenceTransformer; \
SentenceTransformer(model_name_or_path='nomic-ai/nomic-embed-text-v1', trust_remote_code=True);"

# Final stage - combine all downloads
FROM base AS final

LABEL com.danswer.maintainer="founders@onyx.app"
LABEL com.danswer.description="This image is for the Onyx model server which runs all of the \
AI models for Onyx. This container and all the code is MIT Licensed and free for all to use. \
You can find it at https://hub.docker.com/r/onyx/onyx-model-server. For more details, \
visit https://github.com/onyx-dot-app/onyx."
# Create non-root user for security best practices RUN groupadd -g 1001 onyx && \ useradd -u 1001 -g onyx -m -s /bin/bash onyx && \ mkdir -p /var/log/onyx && \ chmod 755 /var/log/onyx && \ chown onyx:onyx /var/log/onyx # In case the user has volumes mounted to /app/.cache/huggingface that they've downloaded while # running Onyx, move the current contents of the cache folder to a temporary location to ensure # it's preserved in order to combine with the user's cache contents COPY --chown=onyx:onyx --from=embedding-models /app/.cache/huggingface /app/.cache/temp_huggingface WORKDIR /app # Utils used by model server COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py COPY ./onyx/utils/tenant.py /app/onyx/utils/tenant.py # Place to fetch version information COPY ./onyx/__init__.py /app/onyx/__init__.py # Shared between Onyx Backend and Model Server COPY ./shared_configs /app/shared_configs # Model Server main code COPY ./model_server /app/model_server ENV PYTHONPATH=/app # Default ONYX_VERSION, typically overridden during builds by GitHub Actions. ARG ONYX_VERSION=0.0.0-dev ENV ONYX_VERSION=${ONYX_VERSION} CMD ["uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000"] ================================================ FILE: backend/alembic/README.md ================================================ # Alembic DB Migrations These files are for creating/updating the tables in the Relational DB (Postgres). Onyx migrations use a generic single-database configuration with an async dbapi. ## To generate new migrations: From onyx/backend, run: `alembic revision -m "<description>"` Note: you cannot use the `--autogenerate` flag as the automatic schema parsing does not work. Manually populate the upgrade and downgrade in your new migration.
More info can be found here: https://alembic.sqlalchemy.org/en/latest/autogenerate.html ## Running migrations To run all un-applied migrations: `alembic upgrade head` To undo migrations: `alembic downgrade -X` where X is the number of migrations you want to undo from the current state ### Multi-tenant migrations For multi-tenant deployments, you can use additional options: **Upgrade all tenants:** ```bash alembic -x upgrade_all_tenants=true upgrade head ``` **Upgrade specific schemas:** ```bash # Single schema alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012 upgrade head # Multiple schemas (comma-separated) alembic -x schemas=tenant_12345678-1234-1234-1234-123456789012,public,another_tenant upgrade head ``` **Upgrade tenants within an alphabetical range:** ```bash # Upgrade tenants 100-200 when sorted alphabetically (positions 100 to 200) alembic -x upgrade_all_tenants=true -x tenant_range_start=100 -x tenant_range_end=200 upgrade head # Upgrade tenants starting from position 1000 alphabetically alembic -x upgrade_all_tenants=true -x tenant_range_start=1000 upgrade head # Upgrade first 500 tenants alphabetically alembic -x upgrade_all_tenants=true -x tenant_range_end=500 upgrade head ``` **Continue on error (for batch operations):** ```bash alembic -x upgrade_all_tenants=true -x continue=true upgrade head ``` The tenant range filtering works by: 1. Sorting tenant IDs alphabetically 2. Using 1-based position numbers (1st, 2nd, 3rd tenant, etc.) 3. Filtering to the specified range of positions 4. 
Non-tenant schemas (like 'public') are always included ================================================ FILE: backend/alembic/env.py ================================================ from typing import Any, Literal from onyx.db.engine.iam_auth import get_iam_auth_token from onyx.configs.app_configs import USE_IAM_AUTH from onyx.configs.app_configs import POSTGRES_HOST from onyx.configs.app_configs import POSTGRES_PORT from onyx.configs.app_configs import POSTGRES_USER from onyx.configs.app_configs import AWS_REGION_NAME from onyx.db.engine.sql_engine import build_connection_string from onyx.db.engine.tenant_utils import get_all_tenant_ids from sqlalchemy import event from sqlalchemy import pool from sqlalchemy import text from sqlalchemy.engine.base import Connection import os import ssl import asyncio import logging from logging.config import fileConfig from alembic import context from sqlalchemy.ext.asyncio import create_async_engine from sqlalchemy.sql.schema import SchemaItem from onyx.configs.constants import SSL_CERT_FILE from shared_configs.configs import ( MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA, TENANT_ID_PREFIX, ) from onyx.db.models import Base from celery.backends.database.session import ResultModelBase # type: ignore from onyx.db.engine.sql_engine import SqlEngine # Make sure in alembic.ini [logger_root] level=INFO is set or most logging will be # hidden! 
(defaults to level=WARN) # Alembic Config object config = context.config if config.config_file_name is not None and config.attributes.get( "configure_logger", True ): # disable_existing_loggers=False prevents breaking pytest's caplog fixture # See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues fileConfig(config.config_file_name, disable_existing_loggers=False) target_metadata = [Base.metadata, ResultModelBase.metadata] EXCLUDE_TABLES = {"kombu_queue", "kombu_message"} logger = logging.getLogger(__name__) ssl_context: ssl.SSLContext | None = None if USE_IAM_AUTH: if not os.path.exists(SSL_CERT_FILE): raise FileNotFoundError(f"Expected {SSL_CERT_FILE} when USE_IAM_AUTH is true.") ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE) def include_object( object: SchemaItem, # noqa: ARG001 name: str | None, type_: Literal[ "schema", "table", "column", "index", "unique_constraint", "foreign_key_constraint", ], reflected: bool, # noqa: ARG001 compare_to: SchemaItem | None, # noqa: ARG001 ) -> bool: if type_ == "table" and name in EXCLUDE_TABLES: return False return True def filter_tenants_by_range( tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None ) -> list[str]: """ Filter tenant IDs by alphabetical position range. Args: tenant_ids: List of tenant IDs to filter start_range: Starting position in alphabetically sorted list (1-based, inclusive) end_range: Ending position in alphabetically sorted list (1-based, inclusive) Returns: Filtered list of tenant IDs in their original order """ if start_range is None and end_range is None: return tenant_ids # Separate tenant IDs from non-tenant schemas tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)] non_tenant_schemas = [ tid for tid in tenant_ids if not tid.startswith(TENANT_ID_PREFIX) ] # Sort tenant schemas alphabetically. # NOTE: can cause missed schemas if a schema is created in between workers # fetching of all tenant IDs. 
We accept this risk for now. Just re-running # the migration will fix the issue. sorted_tenant_schemas = sorted(tenant_schemas) # Convert the 1-based, inclusive positions into a 0-based, end-exclusive slice start_idx = start_range - 1 if start_range is not None else 0 end_idx = end_range if end_range is not None else len(sorted_tenant_schemas) # Ensure indices are within bounds start_idx = max(0, start_idx) end_idx = min(len(sorted_tenant_schemas), end_idx) # Get the filtered tenant schemas filtered_tenant_schemas = sorted_tenant_schemas[start_idx:end_idx] # Combine with non-tenant schemas and preserve original order filtered_tenants = [] for tenant_id in tenant_ids: if tenant_id in filtered_tenant_schemas or tenant_id in non_tenant_schemas: filtered_tenants.append(tenant_id) return filtered_tenants def get_schema_options() -> ( tuple[bool, bool, bool, int | None, int | None, list[str] | None] ): x_args_raw = context.get_x_argument() x_args = {} for arg in x_args_raw: if "=" in arg: key, value = arg.split("=", 1) x_args[key.strip()] = value.strip() else: raise ValueError(f"Invalid argument: {arg}") create_schema = x_args.get("create_schema", "true").lower() == "true" upgrade_all_tenants = x_args.get("upgrade_all_tenants", "false").lower() == "true" # continue on error with individual tenant # only applies to online migrations continue_on_error = x_args.get("continue", "false").lower() == "true" # Tenant range filtering tenant_range_start = None tenant_range_end = None if "tenant_range_start" in x_args: try: tenant_range_start = int(x_args["tenant_range_start"]) except ValueError: raise ValueError( f"Invalid tenant_range_start value: {x_args['tenant_range_start']}. Must be an integer." ) if "tenant_range_end" in x_args: try: tenant_range_end = int(x_args["tenant_range_end"]) except ValueError: raise ValueError( f"Invalid tenant_range_end value: {x_args['tenant_range_end']}. Must be an integer."
) # Validate range if tenant_range_start is not None and tenant_range_end is not None: if tenant_range_start > tenant_range_end: raise ValueError( f"tenant_range_start ({tenant_range_start}) cannot be greater than tenant_range_end ({tenant_range_end})" ) # Specific schema names filtering (replaces both schema_name and the old tenant_ids approach) schemas = None if "schemas" in x_args: schema_names_str = x_args["schemas"].strip() if schema_names_str: # Split by comma and strip whitespace schemas = [ name.strip() for name in schema_names_str.split(",") if name.strip() ] if schemas: logger.info(f"Specific schema names specified: {schemas}") # Validate that only one method is used at a time range_filtering = tenant_range_start is not None or tenant_range_end is not None specific_filtering = schemas is not None and len(schemas) > 0 if range_filtering and specific_filtering: raise ValueError( "Cannot use both tenant range filtering (tenant_range_start/tenant_range_end) " "and specific schema filtering (schemas) at the same time. " "Please use only one filtering method." ) if upgrade_all_tenants and specific_filtering: raise ValueError( "Cannot use both upgrade_all_tenants=true and schemas at the same time. " "Use either upgrade_all_tenants=true for all tenants, or schemas for specific schemas." ) # If any filtering parameters are specified, we're not doing the default single schema migration if range_filtering: upgrade_all_tenants = True # Validate multi-tenant requirements if MULTI_TENANT and not upgrade_all_tenants and not specific_filtering: raise ValueError( "In multi-tenant mode, you must specify either upgrade_all_tenants=true " "or provide schemas. Cannot run default migration." 
) return ( create_schema, upgrade_all_tenants, continue_on_error, tenant_range_start, tenant_range_end, schemas, ) def do_run_migrations( connection: Connection, schema_name: str, create_schema: bool ) -> None: if create_schema: connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema_name}"')) connection.execute(text(f'SET search_path TO "{schema_name}"')) context.configure( connection=connection, target_metadata=target_metadata, # type: ignore include_object=include_object, version_table_schema=schema_name, include_schemas=True, compare_type=True, compare_server_default=True, script_location=config.get_main_option("script_location"), ) with context.begin_transaction(): context.run_migrations() def provide_iam_token_for_alembic( dialect: Any, # noqa: ARG001 conn_rec: Any, # noqa: ARG001 cargs: Any, # noqa: ARG001 cparams: Any, ) -> None: if USE_IAM_AUTH: # Database connection settings region = AWS_REGION_NAME host = POSTGRES_HOST port = POSTGRES_PORT user = POSTGRES_USER # Get IAM authentication token token = get_iam_auth_token(host, port, user, region) # For Alembic / SQLAlchemy in this context, set SSL and password cparams["password"] = token cparams["ssl"] = ssl_context async def run_async_migrations() -> None: ( create_schema, upgrade_all_tenants, continue_on_error, tenant_range_start, tenant_range_end, schemas, ) = get_schema_options() if not schemas and not MULTI_TENANT: schemas = [POSTGRES_DEFAULT_SCHEMA] # without init_engine, subsequent engine calls fail hard intentionally SqlEngine.init_engine(pool_size=20, max_overflow=5) engine = create_async_engine( build_connection_string(), poolclass=pool.NullPool, ) if USE_IAM_AUTH: @event.listens_for(engine.sync_engine, "do_connect") def event_provide_iam_token_for_alembic( dialect: Any, conn_rec: Any, cargs: Any, cparams: Any ) -> None: provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams) if schemas: # Use specific schema names directly without fetching all tenants logger.info(f"Migrating 
specific schema names: {schemas}") i_schema = 0 num_schemas = len(schemas) for schema in schemas: i_schema += 1 logger.info( f"Migrating schema: index={i_schema} num_schemas={num_schemas} schema={schema}" ) try: async with engine.connect() as connection: await connection.run_sync( do_run_migrations, schema_name=schema, create_schema=create_schema, ) await connection.commit() except Exception as e: logger.error(f"Error migrating schema {schema}: {e}") if not continue_on_error: logger.error("--continue=true is not set, raising exception!") raise logger.warning("--continue=true is set, continuing to next schema.") elif upgrade_all_tenants: tenant_schemas = get_all_tenant_ids() filtered_tenant_schemas = filter_tenants_by_range( tenant_schemas, tenant_range_start, tenant_range_end ) if tenant_range_start is not None or tenant_range_end is not None: logger.info( f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}" ) logger.info( f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}" ) i_tenant = 0 num_tenants = len(filtered_tenant_schemas) for schema in filtered_tenant_schemas: i_tenant += 1 logger.info( f"Migrating schema: index={i_tenant} num_tenants={num_tenants} schema={schema}" ) try: async with engine.connect() as connection: await connection.run_sync( do_run_migrations, schema_name=schema, create_schema=create_schema, ) await connection.commit() except Exception as e: logger.error(f"Error migrating schema {schema}: {e}") if not continue_on_error: logger.error("--continue=true is not set, raising exception!") raise logger.warning("--continue=true is set, continuing to next schema.") else: # This should not happen in the new design since we require either # upgrade_all_tenants=true or schemas in multi-tenant mode # and for non-multi-tenant mode, we should use schemas with the default schema raise ValueError( "No migration target specified. 
Use either upgrade_all_tenants=true for all tenants or schemas for specific schemas." ) await engine.dispose() def run_migrations_offline() -> None: """ NOTE(rkuo): This generates a sql script that can be used to migrate the database ... instead of migrating the db live via an open connection Not clear on when this would be used by us or if it even works. If it is offline, then why are there calls to the db engine? This doesn't really get used when we migrate in the cloud.""" logger.info("run_migrations_offline starting.") # without init_engine, subsequent engine calls fail hard intentionally SqlEngine.init_engine(pool_size=20, max_overflow=5) ( create_schema, upgrade_all_tenants, continue_on_error, tenant_range_start, tenant_range_end, schemas, ) = get_schema_options() url = build_connection_string() if schemas: # Use specific schema names directly without fetching all tenants logger.info(f"Migrating specific schema names: {schemas}") for schema in schemas: logger.info(f"Migrating schema: {schema}") context.configure( url=url, target_metadata=target_metadata, # type: ignore literal_binds=True, include_object=include_object, version_table_schema=schema, include_schemas=True, script_location=config.get_main_option("script_location"), dialect_opts={"paramstyle": "named"}, ) with context.begin_transaction(): context.run_migrations() elif upgrade_all_tenants: engine = create_async_engine(url) if USE_IAM_AUTH: @event.listens_for(engine.sync_engine, "do_connect") def event_provide_iam_token_for_alembic_offline( dialect: Any, conn_rec: Any, cargs: Any, cparams: Any ) -> None: provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams) tenant_schemas = get_all_tenant_ids() engine.sync_engine.dispose() filtered_tenant_schemas = filter_tenants_by_range( tenant_schemas, tenant_range_start, tenant_range_end ) if tenant_range_start is not None or tenant_range_end is not None: logger.info( f"Filtering tenants by range: start={tenant_range_start}, end={tenant_range_end}" ) 
logger.info( f"Total tenants: {len(tenant_schemas)}, Filtered tenants: {len(filtered_tenant_schemas)}" ) for schema in filtered_tenant_schemas: logger.info(f"Migrating schema: {schema}") context.configure( url=url, target_metadata=target_metadata, # type: ignore literal_binds=True, include_object=include_object, version_table_schema=schema, include_schemas=True, script_location=config.get_main_option("script_location"), dialect_opts={"paramstyle": "named"}, ) with context.begin_transaction(): context.run_migrations() else: # This should not happen in the new design raise ValueError( "No migration target specified. Use either upgrade_all_tenants=true for all tenants or schemas for specific schemas." ) def run_migrations_online() -> None: """Run migrations in 'online' mode. Supports pytest-alembic by checking for a pre-configured connection in context.config.attributes["connection"]. If present, uses that connection/engine directly instead of creating a new async engine. """ # Check if pytest-alembic is providing a connection/engine connectable = context.config.attributes.get("connection", None) if connectable is not None: # pytest-alembic is providing an engine - use it directly logger.debug("run_migrations_online starting (pytest-alembic mode).") # For pytest-alembic, we use the default schema (public) schema_name = context.config.attributes.get( "schema_name", POSTGRES_DEFAULT_SCHEMA ) # pytest-alembic passes an Engine, we need to get a connection from it with connectable.connect() as connection: # Set search path for the schema connection.execute(text(f'SET search_path TO "{schema_name}"')) context.configure( connection=connection, target_metadata=target_metadata, # type: ignore include_object=include_object, version_table_schema=schema_name, include_schemas=True, compare_type=True, compare_server_default=True, script_location=config.get_main_option("script_location"), ) with context.begin_transaction(): context.run_migrations() # Commit the transaction to ensure 
changes are visible to next migration connection.commit() else: # Normal operation - use async migrations logger.info("run_migrations_online starting.") asyncio.run(run_async_migrations()) if context.is_offline_mode(): run_migrations_offline() else: run_migrations_online() ================================================ FILE: backend/alembic/run_multitenant_migrations.py ================================================ #!/usr/bin/env python3 """Parallel Alembic Migration Runner Upgrades tenant schemas to head in batched, parallel alembic subprocesses. Each subprocess handles a batch of schemas (via ``-x schemas=a,b,c``), reducing per-process overhead compared to one-schema-per-process. Usage examples:: # defaults: 6 workers, 50 schemas/batch python alembic/run_multitenant_migrations.py # custom settings python alembic/run_multitenant_migrations.py -j 8 -b 100 """ from __future__ import annotations import argparse import subprocess import sys import threading import time from concurrent.futures import ThreadPoolExecutor, as_completed from typing import NamedTuple from alembic.config import Config from alembic.script import ScriptDirectory from onyx.db.engine.sql_engine import SqlEngine from onyx.db.engine.tenant_utils import get_all_tenant_ids from onyx.db.engine.tenant_utils import get_schemas_needing_migration from shared_configs.configs import TENANT_ID_PREFIX # --------------------------------------------------------------------------- # Data types # --------------------------------------------------------------------------- class Args(NamedTuple): jobs: int batch_size: int class BatchResult(NamedTuple): schemas: list[str] success: bool output: str elapsed_sec: float # --------------------------------------------------------------------------- # Core functions # --------------------------------------------------------------------------- def run_alembic_for_batch(schemas: list[str]) -> BatchResult: """Run ``alembic upgrade head`` for a batch of schemas in one 
subprocess. If the batch fails, it is automatically retried with ``-x continue=true`` so that the remaining schemas in the batch still get migrated. The retry output (which contains alembic's per-schema error messages) is returned for diagnosis. """ csv = ",".join(schemas) base_cmd = ["alembic", "-x", f"schemas={csv}"] start = time.monotonic() result = subprocess.run( [*base_cmd, "upgrade", "head"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, ) if result.returncode == 0: elapsed = time.monotonic() - start return BatchResult(schemas, True, result.stdout or "", elapsed) # At least one schema failed. Print the initial error output, then # re-run with continue=true so the remaining schemas still get migrated. if result.stdout: print(f"Initial error output:\n{result.stdout}", file=sys.stderr, flush=True) print( f"Batch failed (exit {result.returncode}), retrying with 'continue=true'...", file=sys.stderr, flush=True, ) retry = subprocess.run( [*base_cmd, "-x", "continue=true", "upgrade", "head"], stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True, ) elapsed = time.monotonic() - start return BatchResult(schemas, False, retry.stdout or "", elapsed) def get_head_revision() -> str | None: """Get the head revision from the alembic script directory.""" alembic_cfg = Config("alembic.ini") script = ScriptDirectory.from_config(alembic_cfg) return script.get_current_head() def run_migrations_parallel( schemas: list[str], max_workers: int, batch_size: int, ) -> bool: """Chunk *schemas* into batches and run them in parallel. A background monitor thread prints a status line every 60 s listing which batches are still in-flight, making it easy to spot hung tenants. 
""" batches = [schemas[i : i + batch_size] for i in range(0, len(schemas), batch_size)] total_batches = len(batches) print( f"{len(schemas)} schemas in {total_batches} batch(es) with {max_workers} workers (batch size: {batch_size})...", flush=True, ) all_success = True # Thread-safe tracking of in-flight batches for the monitor thread. in_flight: dict[int, list[str]] = {} prev_in_flight: set[int] = set() lock = threading.Lock() stop_event = threading.Event() def _monitor() -> None: """Print a status line every 60 s listing batches still in-flight. Only prints batches that were also present in the previous tick, making it easy to spot batches that are stuck. """ nonlocal prev_in_flight while not stop_event.wait(60): with lock: if not in_flight: prev_in_flight = set() continue current = set(in_flight) stuck = current & prev_in_flight prev_in_flight = current if not stuck: continue schemas = [s for idx in sorted(stuck) for s in in_flight[idx]] print( f"⏳ batch(es) still running since last check " f"({', '.join(str(i + 1) for i in sorted(stuck))}): " + ", ".join(schemas), flush=True, ) monitor_thread = threading.Thread(target=_monitor, daemon=True) monitor_thread.start() try: with ThreadPoolExecutor(max_workers=max_workers) as executor: def _run(batch_idx: int, batch: list[str]) -> BatchResult: with lock: in_flight[batch_idx] = batch print( f"Batch {batch_idx + 1}/{total_batches} started ({len(batch)} schemas): {', '.join(batch)}", flush=True, ) result = run_alembic_for_batch(batch) with lock: in_flight.pop(batch_idx, None) return result future_to_idx = { executor.submit(_run, i, b): i for i, b in enumerate(batches) } for future in as_completed(future_to_idx): batch_idx = future_to_idx[future] try: result = future.result() status = "✓" if result.success else "✗" print( f"Batch {batch_idx + 1}/{total_batches} " f"{status} {len(result.schemas)} schemas " f"in {result.elapsed_sec:.1f}s", flush=True, ) if not result.success: # Print last 20 lines of retry output for 
diagnosis tail = result.output.strip().splitlines()[-20:] for line in tail: print(f" {line}", flush=True) all_success = False except Exception as e: print( f"Batch {batch_idx + 1}/{total_batches} ✗ exception: {e}", flush=True, ) all_success = False finally: stop_event.set() monitor_thread.join(timeout=2) return all_success # --------------------------------------------------------------------------- # CLI # --------------------------------------------------------------------------- def parse_args() -> Args: parser = argparse.ArgumentParser( description="Run alembic migrations for all tenant schemas in parallel" ) parser.add_argument( "-j", "--jobs", type=int, default=6, metavar="N", help="Number of parallel alembic processes (default: 6)", ) parser.add_argument( "-b", "--batch-size", type=int, default=50, metavar="N", help="Schemas per alembic process (default: 50)", ) args = parser.parse_args() if args.jobs < 1: parser.error("--jobs must be >= 1") if args.batch_size < 1: parser.error("--batch-size must be >= 1") return Args(jobs=args.jobs, batch_size=args.batch_size) def main() -> int: args = parse_args() head_rev = get_head_revision() if head_rev is None: print("Could not determine head revision.", file=sys.stderr) return 1 with SqlEngine.scoped_engine(pool_size=5, max_overflow=2): tenant_ids = get_all_tenant_ids() tenant_schemas = [tid for tid in tenant_ids if tid.startswith(TENANT_ID_PREFIX)] if not tenant_schemas: print( "No tenant schemas found. Is MULTI_TENANT=true set?", file=sys.stderr, ) return 1 schemas_to_migrate = get_schemas_needing_migration(tenant_schemas, head_rev) if not schemas_to_migrate: print( f"All {len(tenant_schemas)} tenants are already at head revision ({head_rev})." ) return 0 print( f"{len(schemas_to_migrate)}/{len(tenant_schemas)} tenants need migration (head: {head_rev})." 
) success = run_migrations_parallel( schemas_to_migrate, max_workers=args.jobs, batch_size=args.batch_size, ) print(f"\n{'All migrations successful' if success else 'Some migrations failed'}") return 0 if success else 1 if __name__ == "__main__": raise SystemExit(main()) ================================================ FILE: backend/alembic/script.py.mako ================================================ """${message} Revision ID: ${up_revision} Revises: ${down_revision | comma,n} Create Date: ${create_date} """ from alembic import op import sqlalchemy as sa ${imports if imports else ""} # revision identifiers, used by Alembic. revision = ${repr(up_revision)} down_revision = ${repr(down_revision)} branch_labels = ${repr(branch_labels)} depends_on = ${repr(depends_on)} def upgrade() -> None: ${upgrades if upgrades else "pass"} def downgrade() -> None: ${downgrades if downgrades else "pass"} ================================================ FILE: backend/alembic/versions/01f8e6d95a33_populate_flow_mapping_data.py ================================================ """Populate flow mapping data Revision ID: 01f8e6d95a33 Revises: d5c86e2c6dc6 Create Date: 2026-01-31 17:37:10.485558 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "01f8e6d95a33" down_revision = "d5c86e2c6dc6" branch_labels = None depends_on = None def upgrade() -> None: # Add each model config to the conversation flow, setting the global default if it exists # Exclude models that are part of ImageGenerationConfig op.execute( """ INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id) SELECT 'CHAT' AS llm_model_flow_type, COALESCE( (lp.is_default_provider IS TRUE AND lp.default_model_name = mc.name), FALSE ) AS is_default, mc.id AS model_configuration_id FROM model_configuration mc LEFT JOIN llm_provider lp ON lp.id = mc.llm_provider_id WHERE NOT EXISTS ( SELECT 1 FROM image_generation_config igc WHERE igc.model_configuration_id = mc.id ); """ ) # Add models with supports_image_input to the vision flow op.execute( """ INSERT INTO llm_model_flow (llm_model_flow_type, is_default, model_configuration_id) SELECT 'VISION' AS llm_model_flow_type, COALESCE( (lp.is_default_vision_provider IS TRUE AND lp.default_vision_model = mc.name), FALSE ) AS is_default, mc.id AS model_configuration_id FROM model_configuration mc LEFT JOIN llm_provider lp ON lp.id = mc.llm_provider_id WHERE mc.supports_image_input IS TRUE; """ ) def downgrade() -> None: # Populate vision defaults from model_flow op.execute( """ UPDATE llm_provider AS lp SET is_default_vision_provider = TRUE, default_vision_model = mc.name FROM llm_model_flow mf JOIN model_configuration mc ON mc.id = mf.model_configuration_id WHERE mf.llm_model_flow_type = 'VISION' AND mf.is_default = TRUE AND mc.llm_provider_id = lp.id; """ ) # Populate conversation defaults from model_flow op.execute( """ UPDATE llm_provider AS lp SET is_default_provider = TRUE, default_model_name = mc.name FROM llm_model_flow mf JOIN model_configuration mc ON mc.id = mf.model_configuration_id WHERE mf.llm_model_flow_type = 'CHAT' AND mf.is_default = TRUE AND mc.llm_provider_id = lp.id; """ ) # For providers that have conversation flow mappings but aren't the default, # we 
still need a default_model_name (it was NOT NULL originally) # Pick the first visible model or any model for that provider op.execute( """ UPDATE llm_provider AS lp SET default_model_name = ( SELECT mc.name FROM model_configuration mc JOIN llm_model_flow mf ON mf.model_configuration_id = mc.id WHERE mc.llm_provider_id = lp.id AND mf.llm_model_flow_type = 'CHAT' ORDER BY mc.is_visible DESC, mc.id ASC LIMIT 1 ) WHERE lp.default_model_name IS NULL; """ ) # Delete all model_flow entries (reverse the inserts from upgrade) op.execute("DELETE FROM llm_model_flow;") ================================================ FILE: backend/alembic/versions/027381bce97c_add_shortcut_option_for_users.py ================================================ """add shortcut option for users Revision ID: 027381bce97c Revises: 6fc7886d665d Create Date: 2025-01-14 12:14:00.814390 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "027381bce97c" down_revision = "6fc7886d665d" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column( "shortcut_enabled", sa.Boolean(), nullable=False, server_default="false" ), ) def downgrade() -> None: op.drop_column("user", "shortcut_enabled") ================================================ FILE: backend/alembic/versions/03bf8be6b53a_rework_kg_config.py ================================================ """rework-kg-config Revision ID: 03bf8be6b53a Revises: 65bc6e0f8500 Create Date: 2025-06-16 10:52:34.815335 """ import json from datetime import datetime from datetime import timedelta from sqlalchemy.dialects import postgresql from sqlalchemy import text from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "03bf8be6b53a" down_revision = "65bc6e0f8500" branch_labels = None depends_on = None def upgrade() -> None: # get current config current_configs = ( op.get_bind() .execute(text("SELECT kg_variable_name, kg_variable_values FROM kg_config")) .all() ) current_config_dict = { config.kg_variable_name: ( config.kg_variable_values[0] if config.kg_variable_name not in ("KG_VENDOR_DOMAINS", "KG_IGNORE_EMAIL_DOMAINS") else config.kg_variable_values ) for config in current_configs if config.kg_variable_values } # not using the KGConfigSettings model here in case it changes in the future kg_config_settings = json.dumps( { "KG_EXPOSED": current_config_dict.get("KG_EXPOSED", False), "KG_ENABLED": current_config_dict.get("KG_ENABLED", False), "KG_VENDOR": current_config_dict.get("KG_VENDOR", None), "KG_VENDOR_DOMAINS": current_config_dict.get("KG_VENDOR_DOMAINS", []), "KG_IGNORE_EMAIL_DOMAINS": current_config_dict.get( "KG_IGNORE_EMAIL_DOMAINS", [] ), "KG_COVERAGE_START": current_config_dict.get( "KG_COVERAGE_START", (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"), ), "KG_MAX_COVERAGE_DAYS": current_config_dict.get("KG_MAX_COVERAGE_DAYS", 90), "KG_MAX_PARENT_RECURSION_DEPTH": current_config_dict.get( "KG_MAX_PARENT_RECURSION_DEPTH", 2 ), "KG_BETA_PERSONA_ID": current_config_dict.get("KG_BETA_PERSONA_ID", None), } ) op.execute( f"INSERT INTO key_value_store (key, value) VALUES ('kg_config', '{kg_config_settings}')" ) # drop kg config table op.drop_table("kg_config") def downgrade() -> None: # get current config current_config_dict = { "KG_EXPOSED": False, "KG_ENABLED": False, "KG_VENDOR": [], "KG_VENDOR_DOMAINS": [], "KG_IGNORE_EMAIL_DOMAINS": [], "KG_COVERAGE_START": (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d"), "KG_MAX_COVERAGE_DAYS": 90, "KG_MAX_PARENT_RECURSION_DEPTH": 2, } current_configs = ( op.get_bind() .execute(text("SELECT value FROM key_value_store WHERE key = 'kg_config'")) .one_or_none() ) if current_configs is not None: 
current_config_dict.update(current_configs[0]) insert_values = [ { "kg_variable_name": name, "kg_variable_values": ( [str(val).lower() if isinstance(val, bool) else str(val)] if not isinstance(val, list) else val ), } for name, val in current_config_dict.items() ] op.create_table( "kg_config", sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True), sa.Column("kg_variable_name", sa.String(), nullable=False, index=True), sa.Column("kg_variable_values", postgresql.ARRAY(sa.String()), nullable=False), sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"), ) op.bulk_insert( sa.table( "kg_config", sa.column("kg_variable_name", sa.String), sa.column("kg_variable_values", postgresql.ARRAY(sa.String)), ), insert_values, ) op.execute("DELETE FROM key_value_store WHERE key = 'kg_config'") ================================================ FILE: backend/alembic/versions/03d085c5c38d_backfill_account_type.py ================================================ """backfill_account_type Revision ID: 03d085c5c38d Revises: 977e834c1427 Create Date: 2026-03-25 16:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "03d085c5c38d" down_revision = "977e834c1427" branch_labels = None depends_on = None _STANDARD = "STANDARD" _BOT = "BOT" _EXT_PERM_USER = "EXT_PERM_USER" _SERVICE_ACCOUNT = "SERVICE_ACCOUNT" _ANONYMOUS = "ANONYMOUS" # Well-known anonymous user UUID ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002" # Email pattern for API key virtual users API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%" # Reflect the table structure for use in DML user_table = sa.table( "user", sa.column("id", sa.Uuid), sa.column("email", sa.String), sa.column("role", sa.String), sa.column("account_type", sa.String), ) def upgrade() -> None: # ------------------------------------------------------------------ # Step 1: Backfill account_type from role. 
# Order matters — most-specific matches first so the final catch-all # only touches rows that haven't been classified yet. # ------------------------------------------------------------------ # 1a. API key virtual users → SERVICE_ACCOUNT op.execute( sa.update(user_table) .where( user_table.c.email.ilike(API_KEY_EMAIL_PATTERN), user_table.c.account_type.is_(None), ) .values(account_type=_SERVICE_ACCOUNT) ) # 1b. Anonymous user → ANONYMOUS op.execute( sa.update(user_table) .where( user_table.c.id == ANONYMOUS_USER_ID, user_table.c.account_type.is_(None), ) .values(account_type=_ANONYMOUS) ) # 1c. SLACK_USER role → BOT op.execute( sa.update(user_table) .where( user_table.c.role == "SLACK_USER", user_table.c.account_type.is_(None), ) .values(account_type=_BOT) ) # 1d. EXT_PERM_USER role → EXT_PERM_USER op.execute( sa.update(user_table) .where( user_table.c.role == "EXT_PERM_USER", user_table.c.account_type.is_(None), ) .values(account_type=_EXT_PERM_USER) ) # 1e. Everything else → STANDARD op.execute( sa.update(user_table) .where(user_table.c.account_type.is_(None)) .values(account_type=_STANDARD) ) # ------------------------------------------------------------------ # Step 2: Set account_type to NOT NULL now that every row is filled. # ------------------------------------------------------------------ op.alter_column( "user", "account_type", nullable=False, server_default="STANDARD", ) def downgrade() -> None: op.alter_column("user", "account_type", nullable=True, server_default=None) op.execute(sa.update(user_table).values(account_type=None)) ================================================ FILE: backend/alembic/versions/03d710ccf29c_add_permission_sync_attempt_tables.py ================================================ """add permission sync attempt tables Revision ID: 03d710ccf29c Revises: 96a5702df6aa Create Date: 2025-09-11 13:30:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "03d710ccf29c" # Generate a new unique ID down_revision = "96a5702df6aa" branch_labels = None depends_on = None def upgrade() -> None: # Create the permission sync status enum permission_sync_status_enum = sa.Enum( "not_started", "in_progress", "success", "canceled", "failed", "completed_with_errors", name="permissionsyncstatus", native_enum=False, ) # Create doc_permission_sync_attempt table op.create_table( "doc_permission_sync_attempt", sa.Column("id", sa.Integer(), nullable=False), sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False), sa.Column("status", permission_sync_status_enum, nullable=False), sa.Column("total_docs_synced", sa.Integer(), nullable=True), sa.Column("docs_with_permission_errors", sa.Integer(), nullable=True), sa.Column("error_message", sa.Text(), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("time_started", sa.DateTime(timezone=True), nullable=True), sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True), sa.ForeignKeyConstraint( ["connector_credential_pair_id"], ["connector_credential_pair.id"], ), sa.PrimaryKeyConstraint("id"), ) # Create indexes for doc_permission_sync_attempt op.create_index( "ix_doc_permission_sync_attempt_time_created", "doc_permission_sync_attempt", ["time_created"], unique=False, ) op.create_index( "ix_permission_sync_attempt_latest_for_cc_pair", "doc_permission_sync_attempt", ["connector_credential_pair_id", "time_created"], unique=False, ) op.create_index( "ix_permission_sync_attempt_status_time", "doc_permission_sync_attempt", ["status", sa.text("time_finished DESC")], unique=False, ) # Create external_group_permission_sync_attempt table # connector_credential_pair_id is nullable - group syncs can be global (e.g., Confluence) op.create_table( "external_group_permission_sync_attempt", sa.Column("id", sa.Integer(), nullable=False), sa.Column("connector_credential_pair_id", sa.Integer(), 
nullable=True), sa.Column("status", permission_sync_status_enum, nullable=False), sa.Column("total_users_processed", sa.Integer(), nullable=True), sa.Column("total_groups_processed", sa.Integer(), nullable=True), sa.Column("total_group_memberships_synced", sa.Integer(), nullable=True), sa.Column("error_message", sa.Text(), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("time_started", sa.DateTime(timezone=True), nullable=True), sa.Column("time_finished", sa.DateTime(timezone=True), nullable=True), sa.ForeignKeyConstraint( ["connector_credential_pair_id"], ["connector_credential_pair.id"], ), sa.PrimaryKeyConstraint("id"), ) # Create indexes for external_group_permission_sync_attempt op.create_index( "ix_external_group_permission_sync_attempt_time_created", "external_group_permission_sync_attempt", ["time_created"], unique=False, ) op.create_index( "ix_group_sync_attempt_cc_pair_time", "external_group_permission_sync_attempt", ["connector_credential_pair_id", "time_created"], unique=False, ) op.create_index( "ix_group_sync_attempt_status_time", "external_group_permission_sync_attempt", ["status", sa.text("time_finished DESC")], unique=False, ) def downgrade() -> None: # Drop indexes op.drop_index( "ix_group_sync_attempt_status_time", table_name="external_group_permission_sync_attempt", ) op.drop_index( "ix_group_sync_attempt_cc_pair_time", table_name="external_group_permission_sync_attempt", ) op.drop_index( "ix_external_group_permission_sync_attempt_time_created", table_name="external_group_permission_sync_attempt", ) op.drop_index( "ix_permission_sync_attempt_status_time", table_name="doc_permission_sync_attempt", ) op.drop_index( "ix_permission_sync_attempt_latest_for_cc_pair", table_name="doc_permission_sync_attempt", ) op.drop_index( "ix_doc_permission_sync_attempt_time_created", table_name="doc_permission_sync_attempt", ) # Drop tables 
op.drop_table("external_group_permission_sync_attempt") op.drop_table("doc_permission_sync_attempt") ================================================ FILE: backend/alembic/versions/0568ccf46a6b_add_thread_specific_model_selection.py ================================================ """Add thread specific model selection Revision ID: 0568ccf46a6b Revises: e209dc5a8156 Create Date: 2024-06-19 14:25:36.376046 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "0568ccf46a6b" down_revision = "e209dc5a8156" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_session", sa.Column("current_alternate_model", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("chat_session", "current_alternate_model") ================================================ FILE: backend/alembic/versions/05c07bf07c00_add_search_doc_relevance_details.py ================================================ """add search doc relevance details Revision ID: 05c07bf07c00 Revises: b896bbd0d5a7 Create Date: 2024-07-10 17:48:15.886653 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "05c07bf07c00" down_revision = "b896bbd0d5a7" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "search_doc", sa.Column("is_relevant", sa.Boolean(), nullable=True), ) op.add_column( "search_doc", sa.Column("relevance_explanation", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("search_doc", "relevance_explanation") op.drop_column("search_doc", "is_relevant") ================================================ FILE: backend/alembic/versions/07b98176f1de_code_interpreter_seed.py ================================================ """code interpreter seed Revision ID: 07b98176f1de Revises: 7cb492013621 Create Date: 2026-02-23 15:55:07.606784 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "07b98176f1de" down_revision = "7cb492013621" branch_labels = None depends_on = None def upgrade() -> None: # Seed the single instance of code_interpreter_server # NOTE: There should only exist at most and at minimum 1 code_interpreter_server row op.execute( sa.text("INSERT INTO code_interpreter_server (server_enabled) VALUES (true)") ) def downgrade() -> None: op.execute(sa.text("DELETE FROM code_interpreter_server")) ================================================ FILE: backend/alembic/versions/0816326d83aa_add_federated_connector_tables.py ================================================ """add federated connector tables Revision ID: 0816326d83aa Revises: 12635f6655b7 Create Date: 2025-06-29 14:09:45.109518 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "0816326d83aa" down_revision = "12635f6655b7" branch_labels = None depends_on = None def upgrade() -> None: # Create federated_connector table op.create_table( "federated_connector", sa.Column("id", sa.Integer(), nullable=False), sa.Column("source", sa.String(), nullable=False), sa.Column("credentials", sa.LargeBinary(), nullable=False), sa.PrimaryKeyConstraint("id"), ) # Create federated_connector_oauth_token table op.create_table( "federated_connector_oauth_token", sa.Column("id", sa.Integer(), nullable=False), sa.Column("federated_connector_id", sa.Integer(), nullable=False), sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("token", sa.LargeBinary(), nullable=False), sa.Column("expires_at", sa.DateTime(), nullable=True), sa.ForeignKeyConstraint( ["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE" ), sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) # Create federated_connector__document_set table op.create_table( "federated_connector__document_set", sa.Column("id", sa.Integer(), nullable=False), sa.Column("federated_connector_id", sa.Integer(), nullable=False), sa.Column("document_set_id", sa.Integer(), nullable=False), sa.Column("entities", postgresql.JSONB(), nullable=False), sa.ForeignKeyConstraint( ["federated_connector_id"], ["federated_connector.id"], ondelete="CASCADE" ), sa.ForeignKeyConstraint( ["document_set_id"], ["document_set.id"], ondelete="CASCADE" ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint( "federated_connector_id", "document_set_id", name="uq_federated_connector_document_set", ), ) def downgrade() -> None: # Drop tables in reverse order due to foreign key dependencies op.drop_table("federated_connector__document_set") op.drop_table("federated_connector_oauth_token") op.drop_table("federated_connector") ================================================ FILE: 
backend/alembic/versions/08a1eda20fe1_add_earliest_indexing_to_connector.py ================================================ """add_indexing_start_to_connector Revision ID: 08a1eda20fe1 Revises: 8a87bd6ec550 Create Date: 2024-07-23 11:12:39.462397 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "08a1eda20fe1" down_revision = "8a87bd6ec550" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "connector", sa.Column("indexing_start", sa.DateTime(), nullable=True) ) def downgrade() -> None: op.drop_column("connector", "indexing_start") ================================================ FILE: backend/alembic/versions/09995b8811eb_add_theme_preference_to_user.py ================================================ """add theme_preference to user Revision ID: 09995b8811eb Revises: 3d1cca026fe8 Create Date: 2025-10-24 08:58:50.246949 """ from alembic import op import sqlalchemy as sa from onyx.db.enums import ThemePreference # revision identifiers, used by Alembic. revision = "09995b8811eb" down_revision = "3d1cca026fe8" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column( "theme_preference", sa.Enum(ThemePreference, native_enum=False), nullable=True, ), ) def downgrade() -> None: op.drop_column("user", "theme_preference") ================================================ FILE: backend/alembic/versions/0a2b51deb0b8_add_starter_prompts.py ================================================ """Add starter prompts Revision ID: 0a2b51deb0b8 Revises: 5f4b8568a221 Create Date: 2024-03-02 23:23:49.960309 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic.
revision = "0a2b51deb0b8" down_revision = "5f4b8568a221" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "persona", sa.Column( "starter_messages", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) def downgrade() -> None: op.drop_column("persona", "starter_messages") ================================================ FILE: backend/alembic/versions/0a98909f2757_enable_encrypted_fields.py ================================================ """Enable Encrypted Fields Revision ID: 0a98909f2757 Revises: 570282d33c49 Create Date: 2024-05-05 19:30:34.317972 """ from alembic import op import sqlalchemy as sa from sqlalchemy.sql import table from sqlalchemy.dialects import postgresql import json from onyx.utils.encryption import encrypt_string_to_bytes # revision identifiers, used by Alembic. revision = "0a98909f2757" down_revision = "570282d33c49" branch_labels: None = None depends_on: None = None def upgrade() -> None: connection = op.get_bind() op.alter_column("key_value_store", "value", nullable=True) op.add_column( "key_value_store", sa.Column( "encrypted_value", sa.LargeBinary, nullable=True, ), ) # Need a temporary column to translate the JSONB to binary op.add_column("credential", sa.Column("temp_column", sa.LargeBinary())) creds_table = table( "credential", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "credential_json", postgresql.JSONB(astext_type=sa.Text()), nullable=False, ), sa.Column( "temp_column", sa.LargeBinary(), nullable=False, ), ) results = connection.execute(sa.select(creds_table)) # This uses the MIT encrypt which does not actually encrypt the credentials # In other words, this upgrade does not apply the encryption. 
Porting existing sensitive data # and key rotation currently is not supported and will come out in the future for row_id, creds, _ in results: creds_binary = encrypt_string_to_bytes(json.dumps(creds)) connection.execute( creds_table.update() .where(creds_table.c.id == row_id) .values(temp_column=creds_binary) ) op.drop_column("credential", "credential_json") op.alter_column("credential", "temp_column", new_column_name="credential_json") op.add_column("llm_provider", sa.Column("temp_column", sa.LargeBinary())) llm_table = table( "llm_provider", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "api_key", sa.String(), nullable=False, ), sa.Column( "temp_column", sa.LargeBinary(), nullable=False, ), ) results = connection.execute(sa.select(llm_table)) for row_id, api_key, _ in results: llm_key = encrypt_string_to_bytes(api_key) connection.execute( llm_table.update() .where(llm_table.c.id == row_id) .values(temp_column=llm_key) ) op.drop_column("llm_provider", "api_key") op.alter_column("llm_provider", "temp_column", new_column_name="api_key") def downgrade() -> None: # Some information loss but this is ok. Should not allow decryption via downgrade.
op.drop_column("credential", "credential_json") op.drop_column("llm_provider", "api_key") op.add_column("llm_provider", sa.Column("api_key", sa.String())) op.add_column( "credential", sa.Column("credential_json", postgresql.JSONB(astext_type=sa.Text())), ) op.execute("DELETE FROM key_value_store WHERE value IS NULL") op.alter_column("key_value_store", "value", nullable=False) op.drop_column("key_value_store", "encrypted_value") ================================================ FILE: backend/alembic/versions/0bb4558f35df_add_scim_username_to_scim_user_mapping.py ================================================ """add scim_username to scim_user_mapping Revision ID: 0bb4558f35df Revises: 631fd2504136 Create Date: 2026-02-20 10:45:30.340188 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "0bb4558f35df" down_revision = "631fd2504136" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "scim_user_mapping", sa.Column("scim_username", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("scim_user_mapping", "scim_username") ================================================ FILE: backend/alembic/versions/0cd424f32b1d_user_file_data_preparation_and_backfill.py ================================================ """Migration 2: User file data preparation and backfill Revision ID: 0cd424f32b1d Revises: 9b66d3156fc6 Create Date: 2025-09-22 09:44:42.727034 This migration populates the new columns added in migration 1. It prepares data for the UUID transition and relationship migration. """ from alembic import op import sqlalchemy as sa from sqlalchemy import text import logging logger = logging.getLogger("alembic.runtime.migration") # revision identifiers, used by Alembic. 
revision = "0cd424f32b1d" down_revision = "9b66d3156fc6" branch_labels = None depends_on = None def upgrade() -> None: """Populate new columns with data.""" bind = op.get_bind() inspector = sa.inspect(bind) # === Step 1: Populate user_file.new_id === user_file_columns = [col["name"] for col in inspector.get_columns("user_file")] has_new_id = "new_id" in user_file_columns if has_new_id: logger.info("Populating user_file.new_id with UUIDs...") # Count rows needing UUIDs null_count = bind.execute( text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL") ).scalar_one() if null_count > 0: logger.info(f"Generating UUIDs for {null_count} user_file records...") # Populate in batches to avoid long locks batch_size = 10000 total_updated = 0 while True: result = bind.execute( text( """ UPDATE user_file SET new_id = gen_random_uuid() WHERE new_id IS NULL AND id IN ( SELECT id FROM user_file WHERE new_id IS NULL LIMIT :batch_size ) """ ), {"batch_size": batch_size}, ) updated = result.rowcount total_updated += updated if updated < batch_size: break logger.info(f" Updated {total_updated}/{null_count} records...") logger.info(f"Generated UUIDs for {total_updated} user_file records") # Verify all records have UUIDs remaining_null = bind.execute( text("SELECT COUNT(*) FROM user_file WHERE new_id IS NULL") ).scalar_one() if remaining_null > 0: raise Exception( f"Failed to populate all user_file.new_id values ({remaining_null} NULL)" ) # Lock down the column op.alter_column("user_file", "new_id", nullable=False) op.alter_column("user_file", "new_id", server_default=None) logger.info("Locked down user_file.new_id column") # === Step 2: Populate persona__user_file.user_file_id_uuid === persona_user_file_columns = [ col["name"] for col in inspector.get_columns("persona__user_file") ] if has_new_id and "user_file_id_uuid" in persona_user_file_columns: logger.info("Populating persona__user_file.user_file_id_uuid...") # Count rows needing update null_count = bind.execute( text( """ 
SELECT COUNT(*) FROM persona__user_file WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL """ ) ).scalar_one() if null_count > 0: logger.info(f"Updating {null_count} persona__user_file records...") # Update in batches batch_size = 10000 total_updated = 0 while True: result = bind.execute( text( """ UPDATE persona__user_file p SET user_file_id_uuid = uf.new_id FROM user_file uf WHERE p.user_file_id = uf.id AND p.user_file_id_uuid IS NULL AND p.persona_id IN ( SELECT persona_id FROM persona__user_file WHERE user_file_id_uuid IS NULL LIMIT :batch_size ) """ ), {"batch_size": batch_size}, ) updated = result.rowcount total_updated += updated if updated < batch_size: break logger.info(f" Updated {total_updated}/{null_count} records...") logger.info(f"Updated {total_updated} persona__user_file records") # Verify all records are populated remaining_null = bind.execute( text( """ SELECT COUNT(*) FROM persona__user_file WHERE user_file_id IS NOT NULL AND user_file_id_uuid IS NULL """ ) ).scalar_one() if remaining_null > 0: raise Exception( f"Failed to populate all persona__user_file.user_file_id_uuid values ({remaining_null} NULL)" ) op.alter_column("persona__user_file", "user_file_id_uuid", nullable=False) logger.info("Locked down persona__user_file.user_file_id_uuid column") # === Step 3: Create user_project records from chat_folder === if "chat_folder" in inspector.get_table_names(): logger.info("Creating user_project records from chat_folder...") result = bind.execute( text( """ INSERT INTO user_project (user_id, name) SELECT cf.user_id, cf.name FROM chat_folder cf WHERE NOT EXISTS ( SELECT 1 FROM user_project up WHERE up.user_id = cf.user_id AND up.name = cf.name ) """ ) ) logger.info(f"Created {result.rowcount} user_project records from chat_folder") # === Step 4: Populate chat_session.project_id === chat_session_columns = [ col["name"] for col in inspector.get_columns("chat_session") ] if "folder_id" in chat_session_columns and "project_id" in 
chat_session_columns: logger.info("Populating chat_session.project_id...") # Count sessions needing update null_count = bind.execute( text( """ SELECT COUNT(*) FROM chat_session WHERE project_id IS NULL AND folder_id IS NOT NULL """ ) ).scalar_one() if null_count > 0: logger.info(f"Updating {null_count} chat_session records...") result = bind.execute( text( """ UPDATE chat_session cs SET project_id = up.id FROM chat_folder cf JOIN user_project up ON up.user_id = cf.user_id AND up.name = cf.name WHERE cs.folder_id = cf.id AND cs.project_id IS NULL """ ) ) logger.info(f"Updated {result.rowcount} chat_session records") # Verify all records are populated remaining_null = bind.execute( text( """ SELECT COUNT(*) FROM chat_session WHERE project_id IS NULL AND folder_id IS NOT NULL """ ) ).scalar_one() if remaining_null > 0: logger.warning( f"Warning: {remaining_null} chat_session records could not be mapped to projects" ) # === Step 5: Update plaintext FileRecord IDs/display names to UUID scheme === # Prior to UUID migration, plaintext cache files were stored with file_id like 'plain_text_'. # After migration, we use 'plaintext_' (note the name change to 'plaintext_'). # This step remaps existing FileRecord rows to the new naming while preserving object_key/bucket.
logger.info("Updating plaintext FileRecord ids and display names to UUID scheme...") # Count legacy plaintext records that can be mapped to UUID user_file ids count_query = text( """ SELECT COUNT(*) FROM file_record fr JOIN user_file uf ON fr.file_id = CONCAT('plaintext_', uf.id::text) WHERE LOWER(fr.file_origin::text) = 'plaintext_cache' """ ) legacy_count = bind.execute(count_query).scalar_one() if legacy_count and legacy_count > 0: logger.info(f"Found {legacy_count} legacy plaintext file records to update") # Update display_name first for readability (safe regardless of rename) bind.execute( text( """ UPDATE file_record fr SET display_name = CONCAT('Plaintext for user file ', uf.new_id::text) FROM user_file uf WHERE LOWER(fr.file_origin::text) = 'plaintext_cache' AND fr.file_id = CONCAT('plaintext_', uf.id::text) """ ) ) # Remap file_id from 'plaintext_' + integer user_file.id to 'plaintext_' + user_file.new_id (UUID) # Use a single UPDATE ... FROM user_file with an exact file_id match, so only rows that # align to existing user_file ids are renamed (avoids renaming unrelated rows) result = bind.execute( text( """ UPDATE file_record fr SET file_id = CONCAT('plaintext_', uf.new_id::text) FROM user_file uf WHERE LOWER(fr.file_origin::text) = 'plaintext_cache' AND fr.file_id = CONCAT('plaintext_', uf.id::text) """ ) ) logger.info( f"Updated {result.rowcount} plaintext file_record ids to UUID scheme" ) # === Step 6: Ensure document_id_migrated default TRUE and backfill existing FALSE === # New records should default to migrated=True so the migration task won't run for them. # Existing rows that had a legacy document_id should be marked as not migrated to be processed.
# Backfill existing records: if document_id is not null, set to FALSE bind.execute( text( """ UPDATE user_file SET document_id_migrated = FALSE WHERE document_id IS NOT NULL """ ) ) # === Step 7: Backfill user_file.status from index_attempt === logger.info("Backfilling user_file.status from index_attempt...") # Update user_file status based on latest index attempt # Using CTEs instead of temp tables for asyncpg compatibility result = bind.execute( text( """ WITH latest_attempt AS ( SELECT DISTINCT ON (ia.connector_credential_pair_id) ia.connector_credential_pair_id, ia.status FROM index_attempt ia ORDER BY ia.connector_credential_pair_id, ia.time_updated DESC ), uf_to_ccp AS ( SELECT DISTINCT uf.id AS uf_id, ccp.id AS cc_pair_id FROM user_file uf JOIN document_by_connector_credential_pair dcc ON dcc.id = REPLACE(uf.document_id, 'USER_FILE_CONNECTOR__', 'FILE_CONNECTOR__') JOIN connector_credential_pair ccp ON ccp.connector_id = dcc.connector_id AND ccp.credential_id = dcc.credential_id ) UPDATE user_file uf SET status = CASE WHEN la.status IN ('NOT_STARTED', 'IN_PROGRESS') THEN 'PROCESSING' WHEN la.status = 'SUCCESS' THEN 'COMPLETED' ELSE 'FAILED' END FROM uf_to_ccp ufc LEFT JOIN latest_attempt la ON la.connector_credential_pair_id = ufc.cc_pair_id WHERE uf.id = ufc.uf_id AND uf.status = 'PROCESSING' """ ) ) logger.info(f"Updated status for {result.rowcount} user_file records") logger.info("Migration 2 (data preparation) completed successfully") def downgrade() -> None: """Reset populated data to allow clean downgrade of schema.""" bind = op.get_bind() inspector = sa.inspect(bind) logger.info("Starting downgrade of data preparation...") # Reset user_file columns to allow nulls before data removal if "user_file" in inspector.get_table_names(): columns = [col["name"] for col in inspector.get_columns("user_file")] if "new_id" in columns: op.alter_column( "user_file", "new_id", nullable=True, server_default=sa.text("gen_random_uuid()"), ) # Optionally clear the data # 
bind.execute(text("UPDATE user_file SET new_id = NULL")) logger.info("Reset user_file.new_id to nullable") # Reset persona__user_file.user_file_id_uuid if "persona__user_file" in inspector.get_table_names(): columns = [col["name"] for col in inspector.get_columns("persona__user_file")] if "user_file_id_uuid" in columns: op.alter_column("persona__user_file", "user_file_id_uuid", nullable=True) # Optionally clear the data # bind.execute(text("UPDATE persona__user_file SET user_file_id_uuid = NULL")) logger.info("Reset persona__user_file.user_file_id_uuid to nullable") # Note: We don't delete user_project records or reset chat_session.project_id # as these might be in use and can be handled by the schema downgrade # Reset user_file.status to default if "user_file" in inspector.get_table_names(): columns = [col["name"] for col in inspector.get_columns("user_file")] if "status" in columns: bind.execute(text("UPDATE user_file SET status = 'PROCESSING'")) logger.info("Reset user_file.status to default") logger.info("Downgrade completed successfully") ================================================ FILE: backend/alembic/versions/0ebb1d516877_add_ccpair_deletion_failure_message.py ================================================ """add ccpair deletion failure message Revision ID: 0ebb1d516877 Revises: 52a219fb5233 Create Date: 2024-09-10 15:03:48.233926 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "0ebb1d516877" down_revision = "52a219fb5233" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "connector_credential_pair", sa.Column("deletion_failure_message", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("connector_credential_pair", "deletion_failure_message") ================================================ FILE: backend/alembic/versions/0f7ff6d75b57_add_index_to_index_attempt_time_created.py ================================================ """add index to index_attempt.time_created Revision ID: 0f7ff6d75b57 Revises: fec3db967bf7 Create Date: 2025-01-10 14:01:14.067144 """ from alembic import op # revision identifiers, used by Alembic. revision = "0f7ff6d75b57" down_revision = "fec3db967bf7" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_index( op.f("ix_index_attempt_status"), "index_attempt", ["status"], unique=False, ) op.create_index( op.f("ix_index_attempt_time_created"), "index_attempt", ["time_created"], unique=False, ) def downgrade() -> None: op.drop_index(op.f("ix_index_attempt_time_created"), table_name="index_attempt") op.drop_index(op.f("ix_index_attempt_status"), table_name="index_attempt") ================================================ FILE: backend/alembic/versions/114a638452db_add_default_app_mode_to_user.py ================================================ """add default_app_mode to user Revision ID: 114a638452db Revises: feead2911109 Create Date: 2026-02-09 18:57:08.274640 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "114a638452db" down_revision = "feead2911109" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column( "default_app_mode", sa.String(), nullable=False, server_default="CHAT", ), ) def downgrade() -> None: op.drop_column("user", "default_app_mode") ================================================ FILE: backend/alembic/versions/12635f6655b7_drive_canonical_ids.py ================================================ """drive-canonical-ids Revision ID: 12635f6655b7 Revises: 58c50ef19f08 Create Date: 2025-06-20 14:44:54.241159 """ from alembic import op import sqlalchemy as sa from urllib.parse import urlparse, urlunparse from httpx import HTTPStatusError import httpx from onyx.db.search_settings import SearchSettings from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client from onyx.document_index.vespa.shared_utils.utils import ( replace_invalid_doc_id_characters, ) from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.utils.logger import setup_logger import os logger = setup_logger() # revision identifiers, used by Alembic. 
revision = "12635f6655b7" down_revision = "58c50ef19f08" branch_labels = None depends_on = None SKIP_CANON_DRIVE_IDS = os.environ.get("SKIP_CANON_DRIVE_IDS", "true").lower() == "true" def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]: result = op.get_bind().execute( sa.text( """ SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1 """ ) ) search_settings_fetch = result.fetchall() search_settings = ( SearchSettings(**search_settings_fetch[0]._asdict()) if search_settings_fetch else None ) result2 = op.get_bind().execute( sa.text( """ SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1 """ ) ) search_settings_future_fetch = result2.fetchall() search_settings_future = ( SearchSettings(**search_settings_future_fetch[0]._asdict()) if search_settings_future_fetch else None ) if not isinstance(search_settings, SearchSettings): raise RuntimeError( "current search settings is of type " + str(type(search_settings)) ) if ( not isinstance(search_settings_future, SearchSettings) and search_settings_future is not None ): raise RuntimeError( "future search settings is of type " + str(type(search_settings_future)) ) return search_settings, search_settings_future def normalize_google_drive_url(url: str) -> str: """Remove query parameters from Google Drive URLs to create canonical document IDs. 
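Example (illustrative only; the file ID "abc" is a made-up placeholder). Besides dropping the query string, a trailing "edit", "view", or "preview" path segment is also removed:

    >>> normalize_google_drive_url("https://docs.google.com/document/d/abc/edit?usp=sharing")
    'https://docs.google.com/document/d/abc'
    >>> normalize_google_drive_url("https://drive.google.com/file/d/abc/view?usp=drive_link")
    'https://drive.google.com/file/d/abc'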
NOTE: copied from drive doc_conversion.py """ parsed_url = urlparse(url) parsed_url = parsed_url._replace(query="") spl_path = parsed_url.path.split("/") if spl_path and (spl_path[-1] in ["edit", "view", "preview"]): spl_path.pop() parsed_url = parsed_url._replace(path="/".join(spl_path)) # Remove query parameters and reconstruct URL return urlunparse(parsed_url) def get_google_drive_documents_from_database() -> list[dict]: """Get all Google Drive documents from the database.""" bind = op.get_bind() result = bind.execute( sa.text( """ SELECT d.id FROM document d JOIN document_by_connector_credential_pair dcc ON d.id = dcc.id JOIN connector_credential_pair cc ON dcc.connector_id = cc.connector_id AND dcc.credential_id = cc.credential_id JOIN connector c ON cc.connector_id = c.id WHERE c.source = 'GOOGLE_DRIVE' """ ) ) documents = [] for row in result: documents.append({"document_id": row.id}) return documents def update_document_id_in_database( old_doc_id: str, new_doc_id: str, index_name: str ) -> None: """Update document IDs in all relevant database tables using copy-and-swap approach.""" bind = op.get_bind() # print(f"Updating database tables for document {old_doc_id} -> {new_doc_id}") # Check if new document ID already exists result = bind.execute( sa.text("SELECT COUNT(*) FROM document WHERE id = :new_id"), {"new_id": new_doc_id}, ) row = result.fetchone() if row and row[0] > 0: # print(f"Document with ID {new_doc_id} already exists, deleting old one") delete_document_from_db(old_doc_id, index_name) return # Step 1: Create a new document row with the new ID (copy all fields from old row) # Use a conservative approach to handle columns that might not exist in all installations try: bind.execute( sa.text( """ INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id, link, doc_updated_at, primary_owners, secondary_owners, external_user_emails, external_user_group_ids, is_public, chunk_count, last_modified, last_synced, kg_stage, kg_processing_time) 
SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id, link, doc_updated_at, primary_owners, secondary_owners, external_user_emails, external_user_group_ids, is_public, chunk_count, last_modified, last_synced, kg_stage, kg_processing_time FROM document WHERE id = :old_id """ ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated database tables for document {old_doc_id} -> {new_doc_id}") except Exception as e: # If the full INSERT fails, try a more basic version with only core columns logger.warning(f"Full INSERT failed, trying basic version: {e}") bind.execute( sa.text( """ INSERT INTO document (id, from_ingestion_api, boost, hidden, semantic_id, link, doc_updated_at, primary_owners, secondary_owners) SELECT :new_id, from_ingestion_api, boost, hidden, semantic_id, link, doc_updated_at, primary_owners, secondary_owners FROM document WHERE id = :old_id """ ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # Step 2: Update all foreign key references to point to the new ID # Update document_by_connector_credential_pair table bind.execute( sa.text( "UPDATE document_by_connector_credential_pair SET id = :new_id WHERE id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated document_by_connector_credential_pair table for document {old_doc_id} -> {new_doc_id}") # Update search_doc table (stores search results for chat replay) # This is critical for agent functionality bind.execute( sa.text( "UPDATE search_doc SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated search_doc table for document {old_doc_id} -> {new_doc_id}") # Update document_retrieval_feedback table (user feedback on documents) bind.execute( sa.text( "UPDATE document_retrieval_feedback SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated document_retrieval_feedback 
table for document {old_doc_id} -> {new_doc_id}") # Update document__tag table (document-tag relationships) bind.execute( sa.text( "UPDATE document__tag SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated document__tag table for document {old_doc_id} -> {new_doc_id}") # Update user_file table (user uploaded files linked to documents) bind.execute( sa.text( "UPDATE user_file SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated user_file table for document {old_doc_id} -> {new_doc_id}") # Update KG and chunk_stats tables (these may not exist in all installations) try: # Update kg_entity table bind.execute( sa.text( "UPDATE kg_entity SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated kg_entity table for document {old_doc_id} -> {new_doc_id}") # Update kg_entity_extraction_staging table bind.execute( sa.text( "UPDATE kg_entity_extraction_staging SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated kg_entity_extraction_staging table for document {old_doc_id} -> {new_doc_id}") # Update kg_relationship table bind.execute( sa.text( "UPDATE kg_relationship SET source_document = :new_id WHERE source_document = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated kg_relationship table for document {old_doc_id} -> {new_doc_id}") # Update kg_relationship_extraction_staging table bind.execute( sa.text( "UPDATE kg_relationship_extraction_staging SET source_document = :new_id WHERE source_document = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated kg_relationship_extraction_staging table for document {old_doc_id} -> {new_doc_id}") # Update chunk_stats table 
bind.execute( sa.text( "UPDATE chunk_stats SET document_id = :new_id WHERE document_id = :old_id" ), {"new_id": new_doc_id, "old_id": old_doc_id}, ) # print(f"Successfully updated chunk_stats table for document {old_doc_id} -> {new_doc_id}") # Update chunk_stats ID field which includes document_id bind.execute( sa.text( """ UPDATE chunk_stats SET id = REPLACE(id, :old_id, :new_id) WHERE id LIKE :old_id_pattern """ ), { "new_id": new_doc_id, "old_id": old_doc_id, "old_id_pattern": f"{old_doc_id}__%", }, ) # print(f"Successfully updated chunk_stats ID field for document {old_doc_id} -> {new_doc_id}") except Exception as e: logger.warning(f"Some KG/chunk tables may not exist or failed to update: {e}") # Step 3: Delete the old document row (this should now be safe since all FKs point to new row) bind.execute( sa.text("DELETE FROM document WHERE id = :old_id"), {"old_id": old_doc_id} ) # print(f"Successfully deleted document {old_doc_id} from database") def _visit_chunks( *, http_client: httpx.Client, index_name: str, selection: str, continuation: str | None = None, ) -> tuple[list[dict], str | None]: """Helper that calls the /document/v1 visit API once and returns (docs, next_token).""" # Use the same URL as the document API, but with visit-specific params base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) params: dict[str, str] = { "selection": selection, "wantedDocumentCount": "1000", } if continuation: params["continuation"] = continuation # print(f"Visiting chunks for selection '{selection}' with params {params}") resp = http_client.get(base_url, params=params, timeout=None) # print(f"Visited chunks for document {selection}") resp.raise_for_status() payload = resp.json() return payload.get("documents", []), payload.get("continuation") def delete_document_chunks_from_vespa(index_name: str, doc_id: str) -> None: """Delete all chunks for *doc_id* from Vespa using continuation-token paging (no offset).""" total_deleted = 0 # Use exact match instead of 
contains - Document Selector Language doesn't support contains selection = f'{index_name}.document_id=="{doc_id}"' with get_vespa_http_client() as http_client: continuation: str | None = None while True: docs, continuation = _visit_chunks( http_client=http_client, index_name=index_name, selection=selection, continuation=continuation, ) if not docs: break for doc in docs: vespa_full_id = doc.get("id") if not vespa_full_id: continue vespa_doc_uuid = vespa_full_id.split("::")[-1] delete_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}" try: resp = http_client.delete(delete_url) resp.raise_for_status() total_deleted += 1 except Exception as e: print(f"Failed to delete chunk {vespa_doc_uuid}: {e}") if not continuation: break def update_document_id_in_vespa( index_name: str, old_doc_id: str, new_doc_id: str ) -> None: """Update all chunks' document_id field from *old_doc_id* to *new_doc_id* using continuation paging.""" clean_new_doc_id = replace_invalid_doc_id_characters(new_doc_id) # Use exact match instead of contains - Document Selector Language doesn't support contains selection = f'{index_name}.document_id=="{old_doc_id}"' with get_vespa_http_client() as http_client: continuation: str | None = None while True: # print(f"Visiting chunks for document {old_doc_id} -> {new_doc_id}") docs, continuation = _visit_chunks( http_client=http_client, index_name=index_name, selection=selection, continuation=continuation, ) if not docs: break for doc in docs: vespa_full_id = doc.get("id") if not vespa_full_id: continue vespa_doc_uuid = vespa_full_id.split("::")[-1] vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_doc_uuid}" update_request = { "fields": {"document_id": {"assign": clean_new_doc_id}} } try: resp = http_client.put(vespa_url, json=update_request) resp.raise_for_status() except Exception as e: print(f"Failed to update chunk {vespa_doc_uuid}: {e}") raise if not continuation: break def 
delete_document_from_db(current_doc_id: str, index_name: str) -> None: # Delete all foreign key references first, then delete the document try: bind = op.get_bind() # Delete from agent-related tables first (order matters due to foreign keys) # Delete from agent__sub_query__search_doc first since it references search_doc bind.execute( sa.text( """ DELETE FROM agent__sub_query__search_doc WHERE search_doc_id IN ( SELECT id FROM search_doc WHERE document_id = :doc_id ) """ ), {"doc_id": current_doc_id}, ) # Delete from chat_message__search_doc bind.execute( sa.text( """ DELETE FROM chat_message__search_doc WHERE search_doc_id IN ( SELECT id FROM search_doc WHERE document_id = :doc_id ) """ ), {"doc_id": current_doc_id}, ) # Now we can safely delete from search_doc bind.execute( sa.text("DELETE FROM search_doc WHERE document_id = :doc_id"), {"doc_id": current_doc_id}, ) # Delete from document_by_connector_credential_pair bind.execute( sa.text( "DELETE FROM document_by_connector_credential_pair WHERE id = :doc_id" ), {"doc_id": current_doc_id}, ) # Delete from other tables that reference this document bind.execute( sa.text( "DELETE FROM document_retrieval_feedback WHERE document_id = :doc_id" ), {"doc_id": current_doc_id}, ) bind.execute( sa.text("DELETE FROM document__tag WHERE document_id = :doc_id"), {"doc_id": current_doc_id}, ) bind.execute( sa.text("DELETE FROM user_file WHERE document_id = :doc_id"), {"doc_id": current_doc_id}, ) # Delete from KG tables if they exist try: bind.execute( sa.text("DELETE FROM kg_entity WHERE document_id = :doc_id"), {"doc_id": current_doc_id}, ) bind.execute( sa.text( "DELETE FROM kg_entity_extraction_staging WHERE document_id = :doc_id" ), {"doc_id": current_doc_id}, ) bind.execute( sa.text("DELETE FROM kg_relationship WHERE source_document = :doc_id"), {"doc_id": current_doc_id}, ) bind.execute( sa.text( "DELETE FROM kg_relationship_extraction_staging WHERE source_document = :doc_id" ), {"doc_id": current_doc_id}, ) bind.execute( 
sa.text("DELETE FROM chunk_stats WHERE document_id = :doc_id"), {"doc_id": current_doc_id}, ) bind.execute( sa.text("DELETE FROM chunk_stats WHERE id LIKE :doc_id_pattern"), {"doc_id_pattern": f"{current_doc_id}__%"}, ) except Exception as e: logger.warning( f"Some KG/chunk tables may not exist or failed to delete from: {e}" ) # Finally delete the document itself bind.execute( sa.text("DELETE FROM document WHERE id = :doc_id"), {"doc_id": current_doc_id}, ) # Delete the document's chunks from Vespa delete_document_chunks_from_vespa(index_name, current_doc_id) except Exception as e: print(f"Failed to delete duplicate document {current_doc_id}: {e}") # Continue with other documents instead of failing the entire migration def upgrade() -> None: if SKIP_CANON_DRIVE_IDS: return current_search_settings, _ = active_search_settings() # Get the index name if hasattr(current_search_settings, "index_name"): index_name = current_search_settings.index_name else: # Fall back to the default index name if the search settings don't provide one index_name = "danswer_index" # Get all Google Drive documents from the database (this is faster and more reliable) gdrive_documents = get_google_drive_documents_from_database() if not gdrive_documents: return # Track normalized document IDs to detect duplicates all_normalized_doc_ids = set() updated_count = 0 for doc_info in gdrive_documents: current_doc_id = doc_info["document_id"] normalized_doc_id = normalize_google_drive_url(current_doc_id) print(f"Processing document {current_doc_id} -> {normalized_doc_id}") # Check for duplicates if normalized_doc_id in all_normalized_doc_ids: # print(f"Deleting duplicate document {current_doc_id}") delete_document_from_db(current_doc_id, index_name) continue all_normalized_doc_ids.add(normalized_doc_id) # Skip documents whose ID is already canonical (no query string and no trailing edit/view/preview segment) if current_doc_id == normalized_doc_id: # print(f"Skipping document {current_doc_id} -> 
{normalized_doc_id} because it already has no query parameters") continue
try: # Update both database and Vespa in order # Database first to ensure consistency update_document_id_in_database( current_doc_id, normalized_doc_id, index_name ) # Vespa chunks still store the original document ID, so select them by exact match on current_doc_id update_document_id_in_vespa(index_name, current_doc_id, normalized_doc_id) updated_count += 1 # print(f"Finished updating document {current_doc_id} -> {normalized_doc_id}") except Exception as e: print(f"Failed to update document {current_doc_id}: {e}") if isinstance(e, HTTPStatusError): print(f"HTTPStatusError: {e}") print(f"Response: {e.response.text}") print(f"Status: {e.response.status_code}") print(f"Headers: {e.response.headers}") print(f"Request: {e.request.url}") print(f"Request headers: {e.request.headers}") # Note: Rollback is complex with the copy-and-swap approach since the old document is already deleted # In case of failure, manual intervention may be required # Continue with other documents instead of failing the entire migration continue logger.info(f"Migration complete. Updated {updated_count} Google Drive documents") def downgrade() -> None: # This is a one-way migration, so there is no downgrade. # It wouldn't make sense to store the extra query parameters # and duplicate documents to allow a reversal. pass ================================================ FILE: backend/alembic/versions/15326fcec57e_introduce_onyx_apis.py ================================================ """Introduce Onyx APIs Revision ID: 15326fcec57e Revises: 77d07dffae64 Create Date: 2023-11-11 20:51:24.228999 """ from alembic import op import sqlalchemy as sa from onyx.configs.constants import DocumentSource # revision identifiers, used by Alembic.
revision = "15326fcec57e" down_revision = "77d07dffae64" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.alter_column("credential", "is_admin", new_column_name="admin_public") op.add_column( "document", sa.Column("from_ingestion_api", sa.Boolean(), nullable=True), ) op.alter_column( "connector", "source", type_=sa.String(length=50), existing_type=sa.Enum(DocumentSource, native_enum=False), existing_nullable=False, ) def downgrade() -> None: op.drop_column("document", "from_ingestion_api") op.alter_column("credential", "admin_public", new_column_name="is_admin") ================================================ FILE: backend/alembic/versions/16c37a30adf2_user_file_relationship_migration.py ================================================ """Migration 3: User file relationship migration Revision ID: 16c37a30adf2 Revises: 0cd424f32b1d Create Date: 2025-09-22 09:47:34.175596 This migration converts folder-based relationships to project-based relationships. It migrates persona__user_folder to persona__user_file and populates project__user_file. """ from alembic import op import sqlalchemy as sa from sqlalchemy import text import logging logger = logging.getLogger("alembic.runtime.migration") # revision identifiers, used by Alembic. revision = "16c37a30adf2" down_revision = "0cd424f32b1d" branch_labels = None depends_on = None def upgrade() -> None: """Migrate folder-based relationships to project-based relationships.""" bind = op.get_bind() inspector = sa.inspect(bind) # === Step 1: Migrate persona__user_folder to persona__user_file === table_names = inspector.get_table_names() if "persona__user_folder" in table_names and "user_file" in table_names: user_file_columns = [col["name"] for col in inspector.get_columns("user_file")] has_new_id = "new_id" in user_file_columns if has_new_id and "folder_id" in user_file_columns: logger.info( "Migrating persona__user_folder relationships to persona__user_file..." 
) # Count relationships to migrate (asyncpg-compatible) count_query = text( """ SELECT COUNT(*) FROM ( SELECT DISTINCT puf.persona_id, uf.id FROM persona__user_folder puf JOIN user_file uf ON uf.folder_id = puf.user_folder_id WHERE NOT EXISTS ( SELECT 1 FROM persona__user_file p2 WHERE p2.persona_id = puf.persona_id AND p2.user_file_id = uf.id ) ) AS distinct_pairs """ ) to_migrate = bind.execute(count_query).scalar_one() if to_migrate > 0: logger.info(f"Creating {to_migrate} persona-file relationships...") # Migrate in batches to avoid memory issues batch_size = 10000 total_inserted = 0 while True: # Insert batch directly using subquery (asyncpg compatible) result = bind.execute( text( """ INSERT INTO persona__user_file (persona_id, user_file_id, user_file_id_uuid) SELECT DISTINCT puf.persona_id, uf.id as file_id, uf.new_id FROM persona__user_folder puf JOIN user_file uf ON uf.folder_id = puf.user_folder_id WHERE NOT EXISTS ( SELECT 1 FROM persona__user_file p2 WHERE p2.persona_id = puf.persona_id AND p2.user_file_id = uf.id ) LIMIT :batch_size """ ), {"batch_size": batch_size}, ) inserted = result.rowcount total_inserted += inserted if inserted < batch_size: break logger.info( f" Migrated {total_inserted}/{to_migrate} relationships..." 
) logger.info( f"Created {total_inserted} persona__user_file relationships" ) # === Step 2: Add foreign key for chat_session.project_id === chat_session_fks = inspector.get_foreign_keys("chat_session") fk_exists = any( fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks ) if not fk_exists: logger.info("Adding foreign key constraint for chat_session.project_id...") op.create_foreign_key( "fk_chat_session_project_id", "chat_session", "user_project", ["project_id"], ["id"], ) logger.info("Added foreign key constraint") # === Step 3: Populate project__user_file from user_file.folder_id === user_file_columns = [col["name"] for col in inspector.get_columns("user_file")] has_new_id = "new_id" in user_file_columns if has_new_id and "folder_id" in user_file_columns: logger.info("Populating project__user_file from folder relationships...") # Count relationships to create count_query = text( """ SELECT COUNT(*) FROM user_file uf WHERE uf.folder_id IS NOT NULL AND NOT EXISTS ( SELECT 1 FROM project__user_file puf WHERE puf.project_id = uf.folder_id AND puf.user_file_id = uf.new_id ) """ ) to_create = bind.execute(count_query).scalar_one() if to_create > 0: logger.info(f"Creating {to_create} project-file relationships...") # Insert in batches batch_size = 10000 total_inserted = 0 while True: result = bind.execute( text( """ INSERT INTO project__user_file (project_id, user_file_id) SELECT uf.folder_id, uf.new_id FROM user_file uf WHERE uf.folder_id IS NOT NULL AND NOT EXISTS ( SELECT 1 FROM project__user_file puf WHERE puf.project_id = uf.folder_id AND puf.user_file_id = uf.new_id ) LIMIT :batch_size ON CONFLICT (project_id, user_file_id) DO NOTHING """ ), {"batch_size": batch_size}, ) inserted = result.rowcount total_inserted += inserted if inserted < batch_size: break logger.info(f" Created {total_inserted}/{to_create} relationships...") logger.info(f"Created {total_inserted} project__user_file relationships") # === Step 4: Create index on 
chat_session.project_id === try: indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")] except Exception: indexes = [] if "ix_chat_session_project_id" not in indexes: logger.info("Creating index on chat_session.project_id...") op.create_index( "ix_chat_session_project_id", "chat_session", ["project_id"], unique=False ) logger.info("Created index") logger.info("Migration 3 (relationship migration) completed successfully") def downgrade() -> None: """Remove migrated relationships and constraints.""" bind = op.get_bind() inspector = sa.inspect(bind) logger.info("Starting downgrade of relationship migration...") # Drop index on chat_session.project_id try: indexes = [ix.get("name") for ix in inspector.get_indexes("chat_session")] if "ix_chat_session_project_id" in indexes: op.drop_index("ix_chat_session_project_id", "chat_session") logger.info("Dropped index on chat_session.project_id") except Exception: pass # Drop foreign key constraint try: chat_session_fks = inspector.get_foreign_keys("chat_session") fk_exists = any( fk["name"] == "fk_chat_session_project_id" for fk in chat_session_fks ) if fk_exists: op.drop_constraint( "fk_chat_session_project_id", "chat_session", type_="foreignkey" ) logger.info("Dropped foreign key constraint on chat_session.project_id") except Exception: pass # Clear project__user_file relationships (but keep the table for migration 1 to handle) if "project__user_file" in inspector.get_table_names(): result = bind.execute(text("DELETE FROM project__user_file")) logger.info(f"Cleared {result.rowcount} records from project__user_file") # Remove migrated persona__user_file relationships # Only remove those that came from folder relationships if all( table in inspector.get_table_names() for table in ["persona__user_file", "persona__user_folder", "user_file"] ): user_file_columns = [col["name"] for col in inspector.get_columns("user_file")] if "folder_id" in user_file_columns: result = bind.execute( text( """ DELETE FROM 
persona__user_file puf WHERE EXISTS ( SELECT 1 FROM user_file uf JOIN persona__user_folder puf2 ON puf2.user_folder_id = uf.folder_id WHERE puf.persona_id = puf2.persona_id AND puf.user_file_id = uf.id ) """ ) ) logger.info( f"Removed {result.rowcount} migrated persona__user_file relationships" ) logger.info("Downgrade completed successfully") ================================================ FILE: backend/alembic/versions/173cae5bba26_port_config_store.py ================================================ """Port Config Store Revision ID: 173cae5bba26 Revises: e50154680a5c Create Date: 2024-03-19 15:30:44.425436 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "173cae5bba26" down_revision = "e50154680a5c" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "key_value_store", sa.Column("key", sa.String(), nullable=False), sa.Column("value", postgresql.JSONB(astext_type=sa.Text()), nullable=False), sa.PrimaryKeyConstraint("key"), ) def downgrade() -> None: op.drop_table("key_value_store") ================================================ FILE: backend/alembic/versions/175ea04c7087_add_user_preferences.py ================================================ """add_user_preferences Revision ID: 175ea04c7087 Revises: d56ffa94ca32 Create Date: 2026-02-04 18:16:24.830873 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "175ea04c7087" down_revision = "d56ffa94ca32" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column("user_preferences", sa.Text(), nullable=True), ) def downgrade() -> None: op.drop_column("user", "user_preferences") ================================================ FILE: backend/alembic/versions/177de57c21c9_display_custom_llm_models.py ================================================ """display custom llm models Revision ID: 177de57c21c9 Revises: 4ee1287bd26a Create Date: 2024-11-21 11:49:04.488677 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql from sqlalchemy import and_ revision = "177de57c21c9" down_revision = "4ee1287bd26a" branch_labels = None depends_on = None def upgrade() -> None: conn = op.get_bind() llm_provider = sa.table( "llm_provider", sa.column("id", sa.Integer), sa.column("provider", sa.String), sa.column("model_names", postgresql.ARRAY(sa.String)), sa.column("display_model_names", postgresql.ARRAY(sa.String)), ) excluded_providers = ["openai", "bedrock", "anthropic", "azure"] providers_to_update = sa.select( llm_provider.c.id, llm_provider.c.model_names, llm_provider.c.display_model_names, ).where( and_( ~llm_provider.c.provider.in_(excluded_providers), llm_provider.c.model_names.isnot(None), ) ) results = conn.execute(providers_to_update).fetchall() for provider_id, model_names, display_model_names in results: if display_model_names is None: display_model_names = [] combined_model_names = list(set(display_model_names + model_names)) update_stmt = ( llm_provider.update() .where(llm_provider.c.id == provider_id) .values(display_model_names=combined_model_names) ) conn.execute(update_stmt) def downgrade() -> None: pass ================================================ FILE: backend/alembic/versions/18b5b2524446_add_is_clarification_to_chat_message.py ================================================ """add is_clarification to
chat_message Revision ID: 18b5b2524446 Revises: 87c52ec39f84 Create Date: 2025-01-16 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "18b5b2524446" down_revision = "87c52ec39f84" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "chat_message", sa.Column( "is_clarification", sa.Boolean(), nullable=False, server_default="false" ), ) def downgrade() -> None: op.drop_column("chat_message", "is_clarification") ================================================ FILE: backend/alembic/versions/19c0ccb01687_migrate_to_contextual_rag_model.py ================================================ """Migrate to contextual rag model Revision ID: 19c0ccb01687 Revises: 9c54986124c6 Create Date: 2026-02-12 11:21:41.798037 """ import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. revision = "19c0ccb01687" down_revision = "9c54986124c6" branch_labels = None depends_on = None def upgrade() -> None: # Widen the column to fit 'CONTEXTUAL_RAG' (14 chars); was varchar(10) # when the table was created with only CHAT/VISION values. op.alter_column( "llm_model_flow", "llm_model_flow_type", type_=sa.String(length=20), existing_type=sa.String(length=10), existing_nullable=False, ) # For every search_settings row that has contextual rag configured, # create an llm_model_flow entry. is_default is TRUE if the row # belongs to the PRESENT search settings, FALSE otherwise.
op.execute( """ INSERT INTO llm_model_flow (llm_model_flow_type, model_configuration_id, is_default) SELECT DISTINCT 'CONTEXTUAL_RAG', mc.id, (ss.status = 'PRESENT') FROM search_settings ss JOIN llm_provider lp ON lp.name = ss.contextual_rag_llm_provider JOIN model_configuration mc ON mc.llm_provider_id = lp.id AND mc.name = ss.contextual_rag_llm_name WHERE ss.enable_contextual_rag = TRUE AND ss.contextual_rag_llm_name IS NOT NULL AND ss.contextual_rag_llm_provider IS NOT NULL ON CONFLICT (llm_model_flow_type, model_configuration_id) DO UPDATE SET is_default = EXCLUDED.is_default WHERE EXCLUDED.is_default = TRUE """ ) def downgrade() -> None: op.execute( """ DELETE FROM llm_model_flow WHERE llm_model_flow_type = 'CONTEXTUAL_RAG' """ ) op.alter_column( "llm_model_flow", "llm_model_flow_type", type_=sa.String(length=10), existing_type=sa.String(length=20), existing_nullable=False, ) ================================================ FILE: backend/alembic/versions/1a03d2c2856b_add_indexes_to_document__tag.py ================================================ """Add indexes to document__tag Revision ID: 1a03d2c2856b Revises: 9c00a2bccb83 Create Date: 2025-02-18 10:45:13.957807 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "1a03d2c2856b" down_revision = "9c00a2bccb83" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_index( op.f("ix_document__tag_tag_id"), "document__tag", ["tag_id"], unique=False, ) def downgrade() -> None: op.drop_index(op.f("ix_document__tag_tag_id"), table_name="document__tag") ================================================ FILE: backend/alembic/versions/1b10e1fda030_add_additional_data_to_notifications.py ================================================ """add additional data to notifications Revision ID: 1b10e1fda030 Revises: 6756efa39ada Create Date: 2024-10-15 19:26:44.071259 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "1b10e1fda030" down_revision = "6756efa39ada" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "notification", sa.Column("additional_data", postgresql.JSONB(), nullable=True) ) def downgrade() -> None: op.drop_column("notification", "additional_data") ================================================ FILE: backend/alembic/versions/1b8206b29c5d_add_user_delete_cascades.py ================================================ """add_user_delete_cascades Revision ID: 1b8206b29c5d Revises: 35e6853a51d5 Create Date: 2024-09-18 11:48:59.418726 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "1b8206b29c5d" down_revision = "35e6853a51d5" branch_labels = None depends_on = None def upgrade() -> None: op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey") op.create_foreign_key( "credential_user_id_fkey", "credential", "user", ["user_id"], ["id"], ondelete="CASCADE", ) op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey") op.create_foreign_key( "chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"], ondelete="CASCADE", ) op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey") op.create_foreign_key( "chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"], ondelete="CASCADE", ) op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey") op.create_foreign_key( "prompt_user_id_fkey", "prompt", "user", ["user_id"], ["id"], ondelete="CASCADE" ) op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey") op.create_foreign_key( "notification_user_id_fkey", "notification", "user", ["user_id"], ["id"], ondelete="CASCADE", ) op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey") op.create_foreign_key( "inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"], ondelete="CASCADE", ) def downgrade() -> None: op.drop_constraint("credential_user_id_fkey", "credential", type_="foreignkey") op.create_foreign_key( "credential_user_id_fkey", "credential", "user", ["user_id"], ["id"] ) op.drop_constraint("chat_session_user_id_fkey", "chat_session", type_="foreignkey") op.create_foreign_key( "chat_session_user_id_fkey", "chat_session", "user", ["user_id"], ["id"] ) op.drop_constraint("chat_folder_user_id_fkey", "chat_folder", type_="foreignkey") op.create_foreign_key( "chat_folder_user_id_fkey", "chat_folder", "user", ["user_id"], ["id"] ) op.drop_constraint("prompt_user_id_fkey", "prompt", type_="foreignkey") op.create_foreign_key("prompt_user_id_fkey", "prompt", "user", 
["user_id"], ["id"]) op.drop_constraint("notification_user_id_fkey", "notification", type_="foreignkey") op.create_foreign_key( "notification_user_id_fkey", "notification", "user", ["user_id"], ["id"] ) op.drop_constraint("inputprompt_user_id_fkey", "inputprompt", type_="foreignkey") op.create_foreign_key( "inputprompt_user_id_fkey", "inputprompt", "user", ["user_id"], ["id"] ) ================================================ FILE: backend/alembic/versions/1d78c0ca7853_remove_voice_provider_deleted_column.py ================================================ """remove voice_provider deleted column Revision ID: 1d78c0ca7853 Revises: a3f8b2c1d4e5 Create Date: 2026-03-26 11:30:53.883127 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "1d78c0ca7853" down_revision = "a3f8b2c1d4e5" branch_labels = None depends_on = None def upgrade() -> None: # Hard-delete any soft-deleted rows before dropping the column op.execute("DELETE FROM voice_provider WHERE deleted = true") op.drop_column("voice_provider", "deleted") def downgrade() -> None: op.add_column( "voice_provider", sa.Column( "deleted", sa.Boolean(), nullable=False, server_default=sa.text("false"), ), ) ================================================ FILE: backend/alembic/versions/1f2a3b4c5d6e_add_internet_search_and_content_providers.py ================================================ """add internet search and content provider tables Revision ID: 1f2a3b4c5d6e Revises: 9drpiiw74ljy Create Date: 2025-11-10 19:45:00.000000 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "1f2a3b4c5d6e" down_revision = "9drpiiw74ljy" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "internet_search_provider", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("name", sa.String(), nullable=False, unique=True), sa.Column("provider_type", sa.String(), nullable=False), sa.Column("api_key", sa.LargeBinary(), nullable=True), sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column( "is_active", sa.Boolean(), nullable=False, server_default=sa.text("false") ), sa.Column( "time_created", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()"), ), sa.Column( "time_updated", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()"), ), ) op.create_index( "ix_internet_search_provider_is_active", "internet_search_provider", ["is_active"], ) op.create_table( "internet_content_provider", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("name", sa.String(), nullable=False, unique=True), sa.Column("provider_type", sa.String(), nullable=False), sa.Column("api_key", sa.LargeBinary(), nullable=True), sa.Column("config", postgresql.JSONB(astext_type=sa.Text()), nullable=True), sa.Column( "is_active", sa.Boolean(), nullable=False, server_default=sa.text("false") ), sa.Column( "time_created", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()"), ), sa.Column( "time_updated", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("now()"), ), ) op.create_index( "ix_internet_content_provider_is_active", "internet_content_provider", ["is_active"], ) def downgrade() -> None: op.drop_index( "ix_internet_content_provider_is_active", table_name="internet_content_provider" ) op.drop_table("internet_content_provider") op.drop_index( "ix_internet_search_provider_is_active", table_name="internet_search_provider" ) op.drop_table("internet_search_provider") ================================================ FILE: 
backend/alembic/versions/1f60f60c3401_embedding_model_search_settings.py ================================================ """embedding model -> search settings Revision ID: 1f60f60c3401 Revises: f17bf3b0d9f1 Create Date: 2024-08-25 12:39:51.731632 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "1f60f60c3401" down_revision = "f17bf3b0d9f1" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.drop_constraint( "index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey" ) # Rename the table op.rename_table("embedding_model", "search_settings") # Add new columns op.add_column( "search_settings", sa.Column( "multipass_indexing", sa.Boolean(), nullable=False, server_default="false" ), ) op.add_column( "search_settings", sa.Column( "multilingual_expansion", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}", ), ) op.add_column( "search_settings", sa.Column( "disable_rerank_for_streaming", sa.Boolean(), nullable=False, server_default="false", ), ) op.add_column( "search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column( "num_rerank", sa.Integer(), nullable=False, server_default=str(20), ), ) # Add the new column as nullable initially op.add_column( "index_attempt", sa.Column("search_settings_id", sa.Integer(), nullable=True) ) # Populate the new column with data from the existing embedding_model_id op.execute("UPDATE index_attempt SET search_settings_id = embedding_model_id") # Create the foreign key constraint op.create_foreign_key( "fk_index_attempt_search_settings", "index_attempt", "search_settings", ["search_settings_id"], ["id"], ) # Make the new column non-nullable 
op.alter_column("index_attempt", "search_settings_id", nullable=False) # Drop the old embedding_model_id column op.drop_column("index_attempt", "embedding_model_id") def downgrade() -> None: # Add back the embedding_model_id column op.add_column( "index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True) ) # Populate the old column with data from search_settings_id op.execute("UPDATE index_attempt SET embedding_model_id = search_settings_id") # Make the old column non-nullable op.alter_column("index_attempt", "embedding_model_id", nullable=False) # Drop the foreign key constraint op.drop_constraint( "fk_index_attempt_search_settings", "index_attempt", type_="foreignkey" ) # Drop the new search_settings_id column op.drop_column("index_attempt", "search_settings_id") # Rename the table back op.rename_table("search_settings", "embedding_model") # Remove added columns op.drop_column("embedding_model", "num_rerank") op.drop_column("embedding_model", "rerank_api_key") op.drop_column("embedding_model", "rerank_provider_type") op.drop_column("embedding_model", "rerank_model_name") op.drop_column("embedding_model", "disable_rerank_for_streaming") op.drop_column("embedding_model", "multilingual_expansion") op.drop_column("embedding_model", "multipass_indexing") op.create_foreign_key( "index_attempt__embedding_model_fk", "index_attempt", "embedding_model", ["embedding_model_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/2020d417ec84_single_onyx_craft_migration.py ================================================ """single onyx craft migration Consolidates all buildmode/onyx craft tables into a single migration. 
Tables created: - build_session: User build sessions with status tracking - sandbox: User-owned containerized environments (one per user) - artifact: Build output files (web apps, documents, images) - snapshot: Sandbox filesystem snapshots - build_message: Conversation messages for build sessions Existing table modified: - connector_credential_pair: Added processing_mode column Revision ID: 2020d417ec84 Revises: 41fa44bef321 Create Date: 2026-01-26 14:43:54.641405 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "2020d417ec84" down_revision = "41fa44bef321" branch_labels = None depends_on = None def upgrade() -> None: # ========================================================================== # ENUMS # ========================================================================== # Build session status enum build_session_status_enum = sa.Enum( "active", "idle", name="buildsessionstatus", native_enum=False, ) # Sandbox status enum sandbox_status_enum = sa.Enum( "provisioning", "running", "idle", "sleeping", "terminated", "failed", name="sandboxstatus", native_enum=False, ) # Artifact type enum artifact_type_enum = sa.Enum( "web_app", "pptx", "docx", "markdown", "excel", "image", name="artifacttype", native_enum=False, ) # ========================================================================== # BUILD_SESSION TABLE # ========================================================================== op.create_table( "build_session", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column( "user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=True, ), sa.Column("name", sa.String(), nullable=True), sa.Column( "status", build_session_status_enum, nullable=False, server_default="active", ), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column( 
"last_activity_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("nextjs_port", sa.Integer(), nullable=True), sa.PrimaryKeyConstraint("id"), ) op.create_index( "ix_build_session_user_created", "build_session", ["user_id", sa.text("created_at DESC")], unique=False, ) op.create_index( "ix_build_session_status", "build_session", ["status"], unique=False, ) # ========================================================================== # SANDBOX TABLE (user-owned, one per user) # ========================================================================== op.create_table( "sandbox", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column( "user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("user.id", ondelete="CASCADE"), nullable=False, ), sa.Column("container_id", sa.String(), nullable=True), sa.Column( "status", sandbox_status_enum, nullable=False, server_default="provisioning", ), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("last_heartbeat", sa.DateTime(timezone=True), nullable=True), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("user_id", name="sandbox_user_id_key"), ) op.create_index( "ix_sandbox_status", "sandbox", ["status"], unique=False, ) op.create_index( "ix_sandbox_container_id", "sandbox", ["container_id"], unique=False, ) # ========================================================================== # ARTIFACT TABLE # ========================================================================== op.create_table( "artifact", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column( "session_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("build_session.id", ondelete="CASCADE"), nullable=False, ), sa.Column("type", artifact_type_enum, nullable=False), sa.Column("path", sa.String(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column( "created_at", sa.DateTime(timezone=True), 
server_default=sa.text("now()"), nullable=False, ), sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) op.create_index( "ix_artifact_session_created", "artifact", ["session_id", sa.text("created_at DESC")], unique=False, ) op.create_index( "ix_artifact_type", "artifact", ["type"], unique=False, ) # ========================================================================== # SNAPSHOT TABLE # ========================================================================== op.create_table( "snapshot", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column( "session_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("build_session.id", ondelete="CASCADE"), nullable=False, ), sa.Column("storage_path", sa.String(), nullable=False), sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) op.create_index( "ix_snapshot_session_created", "snapshot", ["session_id", sa.text("created_at DESC")], unique=False, ) # ========================================================================== # BUILD_MESSAGE TABLE # ========================================================================== op.create_table( "build_message", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column( "session_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("build_session.id", ondelete="CASCADE"), nullable=False, ), sa.Column( "turn_index", sa.Integer(), nullable=False, ), sa.Column( "type", sa.Enum( "SYSTEM", "USER", "ASSISTANT", "DANSWER", name="messagetype", create_type=False, native_enum=False, ), nullable=False, ), sa.Column( "message_metadata", postgresql.JSONB(), nullable=False, ), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) 
op.create_index( "ix_build_message_session_turn", "build_message", ["session_id", "turn_index", sa.text("created_at ASC")], unique=False, ) # ========================================================================== # CONNECTOR_CREDENTIAL_PAIR MODIFICATION # ========================================================================== op.add_column( "connector_credential_pair", sa.Column( "processing_mode", sa.String(), nullable=False, server_default="regular", ), ) def downgrade() -> None: # ========================================================================== # CONNECTOR_CREDENTIAL_PAIR MODIFICATION # ========================================================================== op.drop_column("connector_credential_pair", "processing_mode") # ========================================================================== # BUILD_MESSAGE TABLE # ========================================================================== op.drop_index("ix_build_message_session_turn", table_name="build_message") op.drop_table("build_message") # ========================================================================== # SNAPSHOT TABLE # ========================================================================== op.drop_index("ix_snapshot_session_created", table_name="snapshot") op.drop_table("snapshot") # ========================================================================== # ARTIFACT TABLE # ========================================================================== op.drop_index("ix_artifact_type", table_name="artifact") op.drop_index("ix_artifact_session_created", table_name="artifact") op.drop_table("artifact") sa.Enum(name="artifacttype").drop(op.get_bind(), checkfirst=True) # ========================================================================== # SANDBOX TABLE # ========================================================================== op.drop_index("ix_sandbox_container_id", table_name="sandbox") op.drop_index("ix_sandbox_status", table_name="sandbox") 
op.drop_table("sandbox") sa.Enum(name="sandboxstatus").drop(op.get_bind(), checkfirst=True) # ========================================================================== # BUILD_SESSION TABLE # ========================================================================== op.drop_index("ix_build_session_status", table_name="build_session") op.drop_index("ix_build_session_user_created", table_name="build_session") op.drop_table("build_session") sa.Enum(name="buildsessionstatus").drop(op.get_bind(), checkfirst=True) ================================================ FILE: backend/alembic/versions/213fd978c6d8_notifications.py ================================================ """notifications Revision ID: 213fd978c6d8 Revises: 5fc1f54cc252 Create Date: 2024-08-10 11:13:36.070790 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "213fd978c6d8" down_revision = "5fc1f54cc252" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "notification", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "notif_type", sa.String(), nullable=False, ), sa.Column( "user_id", sa.UUID(), nullable=True, ), sa.Column("dismissed", sa.Boolean(), nullable=False), sa.Column("last_shown", sa.DateTime(timezone=True), nullable=False), sa.Column("first_shown", sa.DateTime(timezone=True), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("notification") ================================================ FILE: backend/alembic/versions/238b84885828_add_foreign_key_to_user__external_user_.py ================================================ """Add foreign key to user__external_user_group_id Revision ID: 238b84885828 Revises: a7688ab35c45 Create Date: 2025-05-19 17:15:33.424584 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "238b84885828" down_revision = "a7688ab35c45" branch_labels = None depends_on = None def upgrade() -> None: # First, clean up any entries that don't have a valid cc_pair_id op.execute( """ DELETE FROM user__external_user_group_id WHERE cc_pair_id NOT IN (SELECT id FROM connector_credential_pair) """ ) # Add foreign key constraint with cascade delete op.create_foreign_key( "fk_user__external_user_group_id_cc_pair_id", "user__external_user_group_id", "connector_credential_pair", ["cc_pair_id"], ["id"], ondelete="CASCADE", ) def downgrade() -> None: # Drop the foreign key constraint op.drop_constraint( "fk_user__external_user_group_id_cc_pair_id", "user__external_user_group_id", type_="foreignkey", ) ================================================ FILE: backend/alembic/versions/23957775e5f5_remove_feedback_foreignkey_constraint.py ================================================ """remove-feedback-foreignkey-constraint Revision ID: 23957775e5f5 Revises: bc9771dccadf Create Date: 2024-06-27 16:04:51.480437 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "23957775e5f5" down_revision = "bc9771dccadf" branch_labels = None depends_on = None def upgrade() -> None: op.drop_constraint( "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" ) op.create_foreign_key( "chat_feedback__chat_message_fk", "chat_feedback", "chat_message", ["chat_message_id"], ["id"], ondelete="SET NULL", ) op.alter_column( "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True ) op.drop_constraint( "document_retrieval_feedback__chat_message_fk", "document_retrieval_feedback", type_="foreignkey", ) op.create_foreign_key( "document_retrieval_feedback__chat_message_fk", "document_retrieval_feedback", "chat_message", ["chat_message_id"], ["id"], ondelete="SET NULL", ) op.alter_column( "document_retrieval_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=True, ) def downgrade() -> None: op.alter_column( "chat_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False ) op.drop_constraint( "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey" ) op.create_foreign_key( "chat_feedback__chat_message_fk", "chat_feedback", "chat_message", ["chat_message_id"], ["id"], ) op.alter_column( "document_retrieval_feedback", "chat_message_id", existing_type=sa.Integer(), nullable=False, ) op.drop_constraint( "document_retrieval_feedback__chat_message_fk", "document_retrieval_feedback", type_="foreignkey", ) op.create_foreign_key( "document_retrieval_feedback__chat_message_fk", "document_retrieval_feedback", "chat_message", ["chat_message_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/25a5501dc766_group_permissions_phase1.py ================================================ """group_permissions_phase1 Revision ID: 25a5501dc766 Revises: b728689f45b1 Create Date: 2026-03-23 11:41:25.557442 """ from alembic import op import fastapi_users_db_sqlalchemy import sqlalchemy as sa from onyx.db.enums import AccountType from 
onyx.db.enums import GrantSource from onyx.db.enums import Permission # revision identifiers, used by Alembic. revision = "25a5501dc766" down_revision = "b728689f45b1" branch_labels = None depends_on = None def upgrade() -> None: # 1. Add account_type column to user table (nullable for now). # TODO(subash): backfill account_type for existing rows and add NOT NULL. op.add_column( "user", sa.Column( "account_type", sa.Enum(AccountType, native_enum=False), nullable=True, ), ) # 2. Add is_default column to user_group table op.add_column( "user_group", sa.Column( "is_default", sa.Boolean(), nullable=False, server_default=sa.false(), ), ) # 3. Create permission_grant table op.create_table( "permission_grant", sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), sa.Column("group_id", sa.Integer(), nullable=False), sa.Column( "permission", sa.Enum(Permission, native_enum=False), nullable=False, ), sa.Column( "grant_source", sa.Enum(GrantSource, native_enum=False), nullable=False, ), sa.Column( "granted_by", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column( "granted_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.Column( "is_deleted", sa.Boolean(), nullable=False, server_default=sa.false(), ), sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint( ["group_id"], ["user_group.id"], ondelete="CASCADE", ), sa.ForeignKeyConstraint( ["granted_by"], ["user.id"], ondelete="SET NULL", ), sa.UniqueConstraint( "group_id", "permission", name="uq_permission_grant_group_permission" ), ) # 4. 
Index on user__user_group(user_id) — existing composite PK # has user_group_id as leading column; user-filtered queries need this op.create_index( "ix_user__user_group_user_id", "user__user_group", ["user_id"], ) def downgrade() -> None: op.drop_index("ix_user__user_group_user_id", table_name="user__user_group") op.drop_table("permission_grant") op.drop_column("user_group", "is_default") op.drop_column("user", "account_type") ================================================ FILE: backend/alembic/versions/2664261bfaab_add_cache_store_table.py ================================================ """add cache_store table Revision ID: 2664261bfaab Revises: 4a1e4b1c89d2 Create Date: 2026-02-27 00:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "2664261bfaab" down_revision = "4a1e4b1c89d2" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "cache_store", sa.Column("key", sa.String(), nullable=False), sa.Column("value", sa.LargeBinary(), nullable=True), sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True), sa.PrimaryKeyConstraint("key"), ) op.create_index( "ix_cache_store_expires", "cache_store", ["expires_at"], postgresql_where=sa.text("expires_at IS NOT NULL"), ) def downgrade() -> None: op.drop_index("ix_cache_store_expires", table_name="cache_store") op.drop_table("cache_store") ================================================ FILE: backend/alembic/versions/2666d766cb9b_google_oauth2.py ================================================ """Google OAuth2 Revision ID: 2666d766cb9b Revises: 6d387b3196c2 Create Date: 2023-05-05 15:49:35.716016 """ import fastapi_users_db_sqlalchemy import sqlalchemy as sa from alembic import op # revision identifiers, used by Alembic. 
revision = "2666d766cb9b" down_revision = "6d387b3196c2" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "oauth_account", sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False, ), sa.Column("oauth_name", sa.String(length=100), nullable=False), sa.Column("access_token", sa.String(length=1024), nullable=False), sa.Column("expires_at", sa.Integer(), nullable=True), sa.Column("refresh_token", sa.String(length=1024), nullable=True), sa.Column("account_id", sa.String(length=320), nullable=False), sa.Column("account_email", sa.String(length=320), nullable=False), sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"), sa.PrimaryKeyConstraint("id"), ) op.create_index( op.f("ix_oauth_account_account_id"), "oauth_account", ["account_id"], unique=False, ) op.create_index( op.f("ix_oauth_account_oauth_name"), "oauth_account", ["oauth_name"], unique=False, ) def downgrade() -> None: op.drop_index(op.f("ix_oauth_account_oauth_name"), table_name="oauth_account") op.drop_index(op.f("ix_oauth_account_account_id"), table_name="oauth_account") op.drop_table("oauth_account") ================================================ FILE: backend/alembic/versions/26b931506ecb_default_chosen_assistants_to_none.py ================================================ """default chosen assistants to none Revision ID: 26b931506ecb Revises: 2daa494a0851 Create Date: 2024-11-12 13:23:29.858995 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "26b931506ecb" down_revision = "2daa494a0851" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column("chosen_assistants_new", postgresql.JSONB(), nullable=True) ) op.execute( """ UPDATE "user" SET chosen_assistants_new = CASE WHEN chosen_assistants = '[-2, -1, 0]' THEN NULL ELSE chosen_assistants END """ ) op.drop_column("user", "chosen_assistants") op.alter_column( "user", "chosen_assistants_new", new_column_name="chosen_assistants" ) def downgrade() -> None: op.add_column( "user", sa.Column( "chosen_assistants_old", postgresql.JSONB(), nullable=False, server_default="[-2, -1, 0]", ), ) op.execute( """ UPDATE "user" SET chosen_assistants_old = CASE WHEN chosen_assistants IS NULL THEN '[-2, -1, 0]'::jsonb ELSE chosen_assistants END """ ) op.drop_column("user", "chosen_assistants") op.alter_column( "user", "chosen_assistants_old", new_column_name="chosen_assistants" ) ================================================ FILE: backend/alembic/versions/27c6ecc08586_permission_framework.py ================================================ """Permission Framework Revision ID: 27c6ecc08586 Revises: 2666d766cb9b Create Date: 2023-05-24 18:45:17.244495 """ import fastapi_users_db_sqlalchemy import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "27c6ecc08586" down_revision = "2666d766cb9b" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.execute("TRUNCATE TABLE index_attempt") op.create_table( "connector", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column( "source", sa.Enum( "SLACK", "WEB", "GOOGLE_DRIVE", "GITHUB", "CONFLUENCE", name="documentsource", native_enum=False, ), nullable=False, ), sa.Column( "input_type", sa.Enum( "LOAD_STATE", "POLL", "EVENT", name="inputtype", native_enum=False, ), nullable=True, ), sa.Column( "connector_specific_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False, ), sa.Column("refresh_freq", sa.Integer(), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("disabled", sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint("id"), ) op.create_table( "credential", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "credential_json", postgresql.JSONB(astext_type=sa.Text()), nullable=False, ), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column("public_doc", sa.Boolean(), nullable=False), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "connector_credential_pair", sa.Column("connector_id", sa.Integer(), nullable=False), sa.Column("credential_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["connector_id"], ["connector.id"], ), sa.ForeignKeyConstraint( ["credential_id"], ["credential.id"], ), sa.PrimaryKeyConstraint("connector_id", "credential_id"), ) 
op.add_column( "index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True), ) op.add_column( "index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True), ) op.create_foreign_key( "fk_index_attempt_credential_id", "index_attempt", "credential", ["credential_id"], ["id"], ) op.create_foreign_key( "fk_index_attempt_connector_id", "index_attempt", "connector", ["connector_id"], ["id"], ) op.drop_column("index_attempt", "connector_specific_config") op.drop_column("index_attempt", "source") op.drop_column("index_attempt", "input_type") def downgrade() -> None: op.execute("TRUNCATE TABLE index_attempt") conn = op.get_bind() inspector = sa.inspect(conn) existing_columns = {col["name"] for col in inspector.get_columns("index_attempt")} if "input_type" not in existing_columns: op.add_column( "index_attempt", sa.Column("input_type", sa.VARCHAR(), autoincrement=False, nullable=False), ) if "source" not in existing_columns: op.add_column( "index_attempt", sa.Column("source", sa.VARCHAR(), autoincrement=False, nullable=False), ) if "connector_specific_config" not in existing_columns: op.add_column( "index_attempt", sa.Column( "connector_specific_config", postgresql.JSONB(astext_type=sa.Text()), autoincrement=False, nullable=False, ), ) # Check if the constraint exists before dropping constraints = inspector.get_foreign_keys("index_attempt") if any( constraint["name"] == "fk_index_attempt_credential_id" for constraint in constraints ): op.drop_constraint( "fk_index_attempt_credential_id", "index_attempt", type_="foreignkey" ) if any( constraint["name"] == "fk_index_attempt_connector_id" for constraint in constraints ): op.drop_constraint( "fk_index_attempt_connector_id", "index_attempt", type_="foreignkey" ) if "credential_id" in existing_columns: op.drop_column("index_attempt", "credential_id") if "connector_id" in existing_columns: op.drop_column("index_attempt", "connector_id") op.execute("DROP TABLE IF EXISTS connector_credential_pair CASCADE") 
op.execute("DROP TABLE IF EXISTS credential CASCADE") op.execute("DROP TABLE IF EXISTS connector CASCADE") ================================================ FILE: backend/alembic/versions/27fb147a843f_add_timestamps_to_user_table.py ================================================ """add timestamps to user table Revision ID: 27fb147a843f Revises: b5c4d7e8f9a1 Create Date: 2026-03-08 17:18:40.828644 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "27fb147a843f" down_revision = "b5c4d7e8f9a1" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), ) op.add_column( "user", sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), ) def downgrade() -> None: op.drop_column("user", "updated_at") op.drop_column("user", "created_at") ================================================ FILE: backend/alembic/versions/2955778aa44c_add_chunk_count_to_document.py ================================================ """add chunk count to document Revision ID: 2955778aa44c Revises: c0aab6edb6dd Create Date: 2025-01-04 11:39:43.268612 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "2955778aa44c" down_revision = "c0aab6edb6dd" branch_labels = None depends_on = None def upgrade() -> None: op.add_column("document", sa.Column("chunk_count", sa.Integer(), nullable=True)) def downgrade() -> None: op.drop_column("document", "chunk_count") ================================================ FILE: backend/alembic/versions/2a391f840e85_add_last_refreshed_at_mcp_server.py ================================================ """add last refreshed at mcp server Revision ID: 2a391f840e85 Revises: 4cebcbc9b2ae Create Date: 2025-12-06 15:19:59.766066 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "2a391f840e85" down_revision = "4cebcbc9b2ae" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "mcp_server", sa.Column("last_refreshed_at", sa.DateTime(timezone=True), nullable=True), ) def downgrade() -> None: op.drop_column("mcp_server", "last_refreshed_at") ================================================ FILE: backend/alembic/versions/2acdef638fc2_add_switchover_type_field.py ================================================ """add switchover_type field and remove background_reindex_enabled Revision ID: 2acdef638fc2 Revises: a4f23d6b71c8 Create Date: 2025-01-XX XX:XX:XX.XXXXXX """ from alembic import op import sqlalchemy as sa from onyx.db.enums import SwitchoverType # revision identifiers, used by Alembic.
revision = "2acdef638fc2" down_revision = "a4f23d6b71c8" branch_labels = None depends_on = None def upgrade() -> None: # Add switchover_type column with default value of REINDEX op.add_column( "search_settings", sa.Column( "switchover_type", sa.Enum(SwitchoverType, native_enum=False), nullable=False, server_default=SwitchoverType.REINDEX.value, ), ) # Migrate existing data: set switchover_type based on background_reindex_enabled # REINDEX where background_reindex_enabled=True, INSTANT where False op.execute( """ UPDATE search_settings SET switchover_type = CASE WHEN background_reindex_enabled = true THEN 'REINDEX' ELSE 'INSTANT' END """ ) # Remove the background_reindex_enabled column (replaced by switchover_type) op.drop_column("search_settings", "background_reindex_enabled") def downgrade() -> None: # Re-add the background_reindex_enabled column with default value of True op.add_column( "search_settings", sa.Column( "background_reindex_enabled", sa.Boolean(), nullable=False, server_default="true", ), ) # Set background_reindex_enabled based on switchover_type op.execute( """ UPDATE search_settings SET background_reindex_enabled = CASE WHEN switchover_type = 'INSTANT' THEN false ELSE true END """ ) # Remove the switchover_type column op.drop_column("search_settings", "switchover_type") ================================================ FILE: backend/alembic/versions/2b75d0a8ffcb_user_file_schema_cleanup.py ================================================ """Migration 6: User file schema cleanup Revision ID: 2b75d0a8ffcb Revises: 3a78dba1080a Create Date: 2025-09-22 10:09:26.375377 This migration removes legacy columns and tables after data migration is complete. It should only be run after verifying all data has been successfully migrated. """ from alembic import op import sqlalchemy as sa from sqlalchemy import text import logging import fastapi_users_db_sqlalchemy logger = logging.getLogger("alembic.runtime.migration") # revision identifiers, used by Alembic. 
revision = "2b75d0a8ffcb" down_revision = "3a78dba1080a" branch_labels = None depends_on = None def upgrade() -> None: """Remove legacy columns and tables.""" bind = op.get_bind() inspector = sa.inspect(bind) logger.info("Starting schema cleanup...") # === Step 1: Verify data migration is complete === logger.info("Verifying data migration completion...") # Check if any chat sessions still have folder_id references chat_session_columns = [ col["name"] for col in inspector.get_columns("chat_session") ] if "folder_id" in chat_session_columns: orphaned_count = bind.execute( text( """ SELECT COUNT(*) FROM chat_session WHERE folder_id IS NOT NULL AND project_id IS NULL """ ) ).scalar_one() if orphaned_count > 0: logger.warning( f"WARNING: {orphaned_count} chat_session records still have folder_id without project_id. Proceeding anyway." ) # === Step 2: Drop chat_session.folder_id === if "folder_id" in chat_session_columns: logger.info("Dropping chat_session.folder_id...") # Drop foreign key constraint first op.execute( "ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_chat_folder_fk" ) op.execute( "ALTER TABLE chat_session DROP CONSTRAINT IF EXISTS chat_session_folder_fk" ) # Drop the column op.drop_column("chat_session", "folder_id") logger.info("Dropped chat_session.folder_id") # === Step 3: Drop persona__user_folder table === if "persona__user_folder" in inspector.get_table_names(): logger.info("Dropping persona__user_folder table...") # Check for any remaining data remaining = bind.execute( text("SELECT COUNT(*) FROM persona__user_folder") ).scalar_one() if remaining > 0: logger.warning( f"WARNING: Dropping persona__user_folder with {remaining} records" ) op.drop_table("persona__user_folder") logger.info("Dropped persona__user_folder table") # === Step 4: Drop chat_folder table === if "chat_folder" in inspector.get_table_names(): logger.info("Dropping chat_folder table...") # Check for any remaining data remaining = bind.execute(text("SELECT COUNT(*) 
FROM chat_folder")).scalar_one() if remaining > 0: logger.warning(f"WARNING: Dropping chat_folder with {remaining} records") op.drop_table("chat_folder") logger.info("Dropped chat_folder table") # === Step 5: Drop user_file legacy columns === user_file_columns = [col["name"] for col in inspector.get_columns("user_file")] # Drop folder_id if "folder_id" in user_file_columns: logger.info("Dropping user_file.folder_id...") op.drop_column("user_file", "folder_id") logger.info("Dropped user_file.folder_id") # Drop cc_pair_id (already handled in migration 5, but check again to be safe) if "cc_pair_id" in user_file_columns: logger.info("Dropping user_file.cc_pair_id...") # Drop any remaining foreign key constraints; note that this loop drops every # foreign key on user_file, not only the one on cc_pair_id bind.execute( text( """ DO $$ DECLARE r RECORD; BEGIN FOR r IN ( SELECT conname FROM pg_constraint c JOIN pg_class t ON c.conrelid = t.oid WHERE c.contype = 'f' AND t.relname = 'user_file' AND EXISTS ( SELECT 1 FROM pg_attribute a WHERE a.attrelid = t.oid AND a.attname = 'cc_pair_id' ) ) LOOP EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname); END LOOP; END$$; """ ) ) op.drop_column("user_file", "cc_pair_id") logger.info("Dropped user_file.cc_pair_id") # === Step 6: Clean up any remaining constraints === logger.info("Cleaning up remaining constraints...") # Drop any unique constraints on removed columns op.execute( "ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_cc_pair_id_key" ) logger.info("Migration 6 (schema cleanup) completed successfully") logger.info("Legacy schema has been fully removed") def downgrade() -> None: """Recreate dropped columns and tables (structure only, no data).""" bind = op.get_bind() inspector = sa.inspect(bind) logger.warning("Downgrading schema cleanup - recreating structure only, no data!") # Recreate user_file columns if "user_file" in inspector.get_table_names(): columns = [col["name"] for col in inspector.get_columns("user_file")] if "cc_pair_id" not in columns: op.add_column( "user_file",
sa.Column("cc_pair_id", sa.Integer(), nullable=True) ) if "folder_id" not in columns: op.add_column( "user_file", sa.Column("folder_id", sa.Integer(), nullable=True) ) # Recreate persona__user_folder table if "persona__user_folder" not in inspector.get_table_names(): op.create_table( "persona__user_folder", sa.Column("persona_id", sa.Integer(), nullable=False), sa.Column("user_folder_id", sa.Integer(), nullable=False), sa.PrimaryKeyConstraint("persona_id", "user_folder_id"), sa.ForeignKeyConstraint(["persona_id"], ["persona.id"]), sa.ForeignKeyConstraint(["user_folder_id"], ["user_project.id"]), ) # Recreate chat_folder table and related structures if "chat_folder" not in inspector.get_table_names(): op.create_table( "chat_folder", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column("name", sa.String(), nullable=True), sa.Column("display_priority", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], name="chat_folder_user_id_fkey", ), sa.PrimaryKeyConstraint("id"), ) # Add folder_id back to chat_session if "chat_session" in inspector.get_table_names(): columns = [col["name"] for col in inspector.get_columns("chat_session")] if "folder_id" not in columns: op.add_column( "chat_session", sa.Column("folder_id", sa.Integer(), nullable=True) ) # Add foreign key if chat_folder exists if "chat_folder" in inspector.get_table_names(): op.create_foreign_key( "chat_session_chat_folder_fk", "chat_session", "chat_folder", ["folder_id"], ["id"], ) logger.info("Downgrade completed - structure recreated but data is lost") ================================================ FILE: backend/alembic/versions/2b90f3af54b8_usage_limits.py ================================================ """usage_limits Revision ID: 2b90f3af54b8 Revises: 9a0296d7421e Create Date: 2026-01-03 16:55:30.449692 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by 
Alembic. revision = "2b90f3af54b8" down_revision = "9a0296d7421e" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "tenant_usage", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "window_start", sa.DateTime(timezone=True), nullable=False, index=True ), sa.Column("llm_cost_cents", sa.Float(), nullable=False, server_default="0.0"), sa.Column("chunks_indexed", sa.Integer(), nullable=False, server_default="0"), sa.Column("api_calls", sa.Integer(), nullable=False, server_default="0"), sa.Column( "non_streaming_api_calls", sa.Integer(), nullable=False, server_default="0" ), sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=True, ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("window_start", name="uq_tenant_usage_window"), ) def downgrade() -> None: op.drop_index("ix_tenant_usage_window_start", table_name="tenant_usage") op.drop_table("tenant_usage") ================================================ FILE: backend/alembic/versions/2c2430828bdf_add_unique_constraint_to_inputprompt_.py ================================================ """add_unique_constraint_to_inputprompt_prompt_user_id Revision ID: 2c2430828bdf Revises: fb80bdd256de Create Date: 2026-01-20 16:01:54.314805 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "2c2430828bdf" down_revision = "fb80bdd256de" branch_labels = None depends_on = None def upgrade() -> None: # Create unique constraint on (prompt, user_id) for user-owned prompts # This ensures each user can only have one shortcut with a given name op.create_unique_constraint( "uq_inputprompt_prompt_user_id", "inputprompt", ["prompt", "user_id"], ) # Create partial unique index for public prompts (where user_id IS NULL) # PostgreSQL unique constraints don't enforce uniqueness for NULL values, # so we need a partial index to ensure public prompt names are also unique op.execute( """ CREATE UNIQUE INDEX uq_inputprompt_prompt_public ON inputprompt (prompt) WHERE user_id IS NULL """ ) def downgrade() -> None: op.execute("DROP INDEX IF EXISTS uq_inputprompt_prompt_public") op.drop_constraint("uq_inputprompt_prompt_user_id", "inputprompt", type_="unique") ================================================ FILE: backend/alembic/versions/2cdeff6d8c93_set_built_in_to_default.py ================================================ """set built in to default Revision ID: 2cdeff6d8c93 Revises: f5437cc136c5 Create Date: 2025-02-11 14:57:51.308775 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "2cdeff6d8c93" down_revision = "f5437cc136c5" branch_labels = None depends_on = None def upgrade() -> None: # Prior to this migration / point in the codebase history, # built in personas were implicitly treated as default personas (with no option to change this) # This migration makes that explicit op.execute( """ UPDATE persona SET is_default_persona = TRUE WHERE builtin_persona = TRUE """ ) def downgrade() -> None: pass ================================================ FILE: backend/alembic/versions/2d2304e27d8c_add_above_below_to_persona.py ================================================ """Add Above Below to Persona Revision ID: 2d2304e27d8c Revises: 4b08d97e175a Create Date: 2024-08-21 19:15:15.762948 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "2d2304e27d8c" down_revision = "4b08d97e175a" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("persona", sa.Column("chunks_above", sa.Integer(), nullable=True)) op.add_column("persona", sa.Column("chunks_below", sa.Integer(), nullable=True)) op.execute( "UPDATE persona SET chunks_above = 1, chunks_below = 1 WHERE chunks_above IS NULL AND chunks_below IS NULL" ) op.alter_column("persona", "chunks_above", nullable=False) op.alter_column("persona", "chunks_below", nullable=False) def downgrade() -> None: op.drop_column("persona", "chunks_below") op.drop_column("persona", "chunks_above") ================================================ FILE: backend/alembic/versions/2daa494a0851_add_group_sync_time.py ================================================ """add-group-sync-time Revision ID: 2daa494a0851 Revises: c0fd6e4da83a Create Date: 2024-11-11 10:57:22.991157 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "2daa494a0851" down_revision = "c0fd6e4da83a" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "connector_credential_pair", sa.Column( "last_time_external_group_sync", sa.DateTime(timezone=True), nullable=True, ), ) def downgrade() -> None: op.drop_column("connector_credential_pair", "last_time_external_group_sync") ================================================ FILE: backend/alembic/versions/2f80c6a2550f_add_chat_session_specific_temperature_.py ================================================ """add chat session specific temperature override Revision ID: 2f80c6a2550f Revises: 33ea50e88f24 Create Date: 2025-01-31 10:30:27.289646 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "2f80c6a2550f" down_revision = "33ea50e88f24" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True) ) op.add_column( "user", sa.Column( "temperature_override_enabled", sa.Boolean(), nullable=False, server_default=sa.false(), ), ) def downgrade() -> None: op.drop_column("chat_session", "temperature_override") op.drop_column("user", "temperature_override_enabled") ================================================ FILE: backend/alembic/versions/2f95e36923e6_add_indexing_coordination.py ================================================ """add_indexing_coordination Revision ID: 2f95e36923e6 Revises: 0816326d83aa Create Date: 2025-07-10 16:17:57.762182 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "2f95e36923e6" down_revision = "0816326d83aa" branch_labels = None depends_on = None def upgrade() -> None: # Add database-based coordination fields (replacing Redis fencing) op.add_column( "index_attempt", sa.Column("celery_task_id", sa.String(), nullable=True) ) op.add_column( "index_attempt", sa.Column( "cancellation_requested", sa.Boolean(), nullable=False, server_default="false", ), ) # Add batch coordination fields (replacing FileStore state) op.add_column( "index_attempt", sa.Column("total_batches", sa.Integer(), nullable=True) ) op.add_column( "index_attempt", sa.Column( "completed_batches", sa.Integer(), nullable=False, server_default="0" ), ) op.add_column( "index_attempt", sa.Column( "total_failures_batch_level", sa.Integer(), nullable=False, server_default="0", ), ) op.add_column( "index_attempt", sa.Column("total_chunks", sa.Integer(), nullable=False, server_default="0"), ) # Progress tracking for stall detection op.add_column( "index_attempt", sa.Column("last_progress_time", sa.DateTime(timezone=True), nullable=True), ) op.add_column( "index_attempt", sa.Column( "last_batches_completed_count", sa.Integer(), nullable=False, server_default="0", ), ) # Heartbeat tracking for worker liveness detection op.add_column( "index_attempt", sa.Column( "heartbeat_counter", sa.Integer(), nullable=False, server_default="0" ), ) op.add_column( "index_attempt", sa.Column( "last_heartbeat_value", sa.Integer(), nullable=False, server_default="0" ), ) op.add_column( "index_attempt", sa.Column("last_heartbeat_time", sa.DateTime(timezone=True), nullable=True), ) # Add index for coordination queries op.create_index( "ix_index_attempt_active_coordination", "index_attempt", ["connector_credential_pair_id", "search_settings_id", "status"], ) def downgrade() -> None: # Remove the new index op.drop_index("ix_index_attempt_active_coordination", table_name="index_attempt") # Remove the new columns op.drop_column("index_attempt", "last_batches_completed_count") 
op.drop_column("index_attempt", "last_progress_time") op.drop_column("index_attempt", "last_heartbeat_time") op.drop_column("index_attempt", "last_heartbeat_value") op.drop_column("index_attempt", "heartbeat_counter") op.drop_column("index_attempt", "total_chunks") op.drop_column("index_attempt", "total_failures_batch_level") op.drop_column("index_attempt", "completed_batches") op.drop_column("index_attempt", "total_batches") op.drop_column("index_attempt", "cancellation_requested") op.drop_column("index_attempt", "celery_task_id") ================================================ FILE: backend/alembic/versions/30c1d5744104_persona_datetime_aware.py ================================================ """Persona Datetime Aware Revision ID: 30c1d5744104 Revises: 7f99be1cb9f5 Create Date: 2023-10-16 23:21:01.283424 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "30c1d5744104" down_revision = "7f99be1cb9f5" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("persona", sa.Column("datetime_aware", sa.Boolean(), nullable=True)) op.execute("UPDATE persona SET datetime_aware = TRUE") op.alter_column("persona", "datetime_aware", nullable=False) op.create_index( "_default_persona_name_idx", "persona", ["name"], unique=True, postgresql_where=sa.text("default_persona = true"), ) def downgrade() -> None: op.drop_index( "_default_persona_name_idx", table_name="persona", postgresql_where=sa.text("default_persona = true"), ) op.drop_column("persona", "datetime_aware") ================================================ FILE: backend/alembic/versions/325975216eb3_add_icon_color_and_icon_shape_to_persona.py ================================================ """Add icon_color and icon_shape to Persona Revision ID: 325975216eb3 Revises: 91ffac7e65b3 Create Date: 2024-07-24 21:29:31.784562 """ import random from alembic import op import sqlalchemy as sa from sqlalchemy.sql import table, column, select # 
revision identifiers, used by Alembic. revision = "325975216eb3" down_revision = "91ffac7e65b3" branch_labels: None = None depends_on: None = None colorOptions = [ "#FF6FBF", "#6FB1FF", "#B76FFF", "#FFB56F", "#6FFF8D", "#FF6F6F", "#6FFFFF", ] # Function to generate a random shape ensuring at least 3 of the middle 4 squares are filled def generate_random_shape() -> int: center_squares = [12, 10, 6, 14, 13, 11, 7, 15] center_fill = random.choice(center_squares) remaining_squares = [i for i in range(16) if not (center_fill & (1 << i))] random.shuffle(remaining_squares) for i in range(10 - bin(center_fill).count("1")): center_fill |= 1 << remaining_squares[i] return center_fill def upgrade() -> None: op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True)) op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True)) op.add_column("persona", sa.Column("uploaded_image_id", sa.String(), nullable=True)) persona = table( "persona", column("id", sa.Integer), column("icon_color", sa.String), column("icon_shape", sa.Integer), ) conn = op.get_bind() personas = conn.execute(select(persona.c.id)) for persona_id in personas: random_color = random.choice(colorOptions) random_shape = generate_random_shape() conn.execute( persona.update() .where(persona.c.id == persona_id[0]) .values(icon_color=random_color, icon_shape=random_shape) ) def downgrade() -> None: op.drop_column("persona", "icon_shape") op.drop_column("persona", "uploaded_image_id") op.drop_column("persona", "icon_color") ================================================ FILE: backend/alembic/versions/33cb72ea4d80_single_tool_call_per_message.py ================================================ """single tool call per message Revision ID: 33cb72ea4d80 Revises: 5b29123cd710 Create Date: 2024-11-01 12:51:01.535003 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "33cb72ea4d80" down_revision = "5b29123cd710" branch_labels = None depends_on = None def upgrade() -> None: # Step 1: Delete extraneous ToolCall entries # Keep only the ToolCall with the smallest 'id' for each non-NULL 'message_id'; # rows whose message_id is NULL are deleted as well op.execute( sa.text( """ DELETE FROM tool_call WHERE id NOT IN ( SELECT MIN(id) FROM tool_call WHERE message_id IS NOT NULL GROUP BY message_id ); """ ) ) # Step 2: Add a unique constraint on message_id op.create_unique_constraint( constraint_name="uq_tool_call_message_id", table_name="tool_call", columns=["message_id"], ) def downgrade() -> None: # Drop the unique constraint on message_id (deleted rows are not restored) op.drop_constraint( constraint_name="uq_tool_call_message_id", table_name="tool_call", type_="unique", ) ================================================ FILE: backend/alembic/versions/33ea50e88f24_foreign_key_input_prompts.py ================================================ """foreign key input prompts Revision ID: 33ea50e88f24 Revises: a6df6b88ef81 Create Date: 2025-01-29 10:54:22.141765 """ from alembic import op # revision identifiers, used by Alembic.
revision = "33ea50e88f24" down_revision = "a6df6b88ef81" branch_labels = None depends_on = None def upgrade() -> None: # Safely drop the constraints if they exist op.execute( """ ALTER TABLE inputprompt__user DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey """ ) op.execute( """ ALTER TABLE inputprompt__user DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey """ ) # Recreate with ON DELETE CASCADE op.create_foreign_key( "inputprompt__user_input_prompt_id_fkey", "inputprompt__user", "inputprompt", ["input_prompt_id"], ["id"], ondelete="CASCADE", ) op.create_foreign_key( "inputprompt__user_user_id_fkey", "inputprompt__user", "user", ["user_id"], ["id"], ondelete="CASCADE", ) def downgrade() -> None: # Drop the new FKs that use ON DELETE CASCADE op.drop_constraint( "inputprompt__user_input_prompt_id_fkey", "inputprompt__user", type_="foreignkey", ) op.drop_constraint( "inputprompt__user_user_id_fkey", "inputprompt__user", type_="foreignkey", ) # Recreate them without cascading op.create_foreign_key( "inputprompt__user_input_prompt_id_fkey", "inputprompt__user", "inputprompt", ["input_prompt_id"], ["id"], ) op.create_foreign_key( "inputprompt__user_user_id_fkey", "inputprompt__user", "user", ["user_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/351faebd379d_add_curator_fields.py ================================================ """Add curator fields Revision ID: 351faebd379d Revises: ee3f4b47fad5 Create Date: 2024-08-15 22:37:08.397052 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "351faebd379d" down_revision = "ee3f4b47fad5" branch_labels: None = None depends_on: None = None def upgrade() -> None: # Add is_curator column to User__UserGroup table op.add_column( "user__user_group", sa.Column("is_curator", sa.Boolean(), nullable=False, server_default="false"), ) # Use batch mode to modify the enum type with op.batch_alter_table("user", schema=None) as batch_op: batch_op.alter_column( # type: ignore[attr-defined] "role", type_=sa.Enum( "BASIC", "ADMIN", "CURATOR", "GLOBAL_CURATOR", name="userrole", native_enum=False, ), existing_type=sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False), existing_nullable=False, ) # Create the association table op.create_table( "credential__user_group", sa.Column("credential_id", sa.Integer(), nullable=False), sa.Column("user_group_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["credential_id"], ["credential.id"], ), sa.ForeignKeyConstraint( ["user_group_id"], ["user_group.id"], ), sa.PrimaryKeyConstraint("credential_id", "user_group_id"), ) op.add_column( "credential", sa.Column( "curator_public", sa.Boolean(), nullable=False, server_default="false" ), ) def downgrade() -> None: # Update existing records to ensure they fit within the BASIC/ADMIN roles op.execute( "UPDATE \"user\" SET role = 'ADMIN' WHERE role IN ('CURATOR', 'GLOBAL_CURATOR')" ) # Remove is_curator column from User__UserGroup table op.drop_column("user__user_group", "is_curator") with op.batch_alter_table("user", schema=None) as batch_op: batch_op.alter_column( # type: ignore[attr-defined] "role", type_=sa.Enum( "BASIC", "ADMIN", name="userrole", native_enum=False, length=20 ), existing_type=sa.Enum( "BASIC", "ADMIN", "CURATOR", "GLOBAL_CURATOR", name="userrole", native_enum=False, ), existing_nullable=False, ) # Drop the association table op.drop_table("credential__user_group") op.drop_column("credential", "curator_public") ================================================ FILE: 
backend/alembic/versions/35e518e0ddf4_properly_cascade.py ================================================ """properly_cascade Revision ID: 35e518e0ddf4 Revises: 91a0a4d62b14 Create Date: 2024-09-20 21:24:04.891018 """ from alembic import op # revision identifiers, used by Alembic. revision = "35e518e0ddf4" down_revision = "91a0a4d62b14" branch_labels = None depends_on = None def upgrade() -> None: # Update chat_message foreign key constraint op.drop_constraint( "chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey" ) op.create_foreign_key( "chat_message_chat_session_id_fkey", "chat_message", "chat_session", ["chat_session_id"], ["id"], ondelete="CASCADE", ) # Update chat_message__search_doc foreign key constraints op.drop_constraint( "chat_message__search_doc_chat_message_id_fkey", "chat_message__search_doc", type_="foreignkey", ) op.drop_constraint( "chat_message__search_doc_search_doc_id_fkey", "chat_message__search_doc", type_="foreignkey", ) op.create_foreign_key( "chat_message__search_doc_chat_message_id_fkey", "chat_message__search_doc", "chat_message", ["chat_message_id"], ["id"], ondelete="CASCADE", ) op.create_foreign_key( "chat_message__search_doc_search_doc_id_fkey", "chat_message__search_doc", "search_doc", ["search_doc_id"], ["id"], ondelete="CASCADE", ) # Add CASCADE delete for tool_call foreign key op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey") op.create_foreign_key( "tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"], ondelete="CASCADE", ) def downgrade() -> None: # Revert chat_message foreign key constraint op.drop_constraint( "chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey" ) op.create_foreign_key( "chat_message_chat_session_id_fkey", "chat_message", "chat_session", ["chat_session_id"], ["id"], ) # Revert chat_message__search_doc foreign key constraints op.drop_constraint( "chat_message__search_doc_chat_message_id_fkey", 
"chat_message__search_doc", type_="foreignkey", ) op.drop_constraint( "chat_message__search_doc_search_doc_id_fkey", "chat_message__search_doc", type_="foreignkey", ) op.create_foreign_key( "chat_message__search_doc_chat_message_id_fkey", "chat_message__search_doc", "chat_message", ["chat_message_id"], ["id"], ) op.create_foreign_key( "chat_message__search_doc_search_doc_id_fkey", "chat_message__search_doc", "search_doc", ["search_doc_id"], ["id"], ) # Revert tool_call foreign key constraint op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey") op.create_foreign_key( "tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/35e6853a51d5_server_default_chosen_assistants.py ================================================ """server default chosen assistants Revision ID: 35e6853a51d5 Revises: c99d76fcd298 Create Date: 2024-09-13 13:20:32.885317 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "35e6853a51d5" down_revision = "c99d76fcd298" branch_labels = None depends_on = None DEFAULT_ASSISTANTS = [-2, -1, 0] def upgrade() -> None: # Step 1: Backfill NULL values # Existing users with no ordered assistant list get one made up of # the visible assistants they can access, ordered by display_priority # and then id.
op.execute( """ UPDATE "user" u SET chosen_assistants = ( SELECT jsonb_agg( p.id ORDER BY COALESCE(p.display_priority, 2147483647) ASC, p.id ASC ) FROM persona p LEFT JOIN persona__user pu ON p.id = pu.persona_id AND pu.user_id = u.id WHERE p.is_visible = true AND (p.is_public = true OR pu.user_id IS NOT NULL) ) WHERE chosen_assistants IS NULL OR chosen_assistants = 'null' OR jsonb_typeof(chosen_assistants) = 'null' OR (jsonb_typeof(chosen_assistants) = 'string' AND chosen_assistants = '"null"') """ ) # Step 2: Alter the column to make it non-nullable op.alter_column( "user", "chosen_assistants", type_=postgresql.JSONB(astext_type=sa.Text()), nullable=False, server_default=sa.text(f"'{DEFAULT_ASSISTANTS}'::jsonb"), ) def downgrade() -> None: op.alter_column( "user", "chosen_assistants", type_=postgresql.JSONB(astext_type=sa.Text()), nullable=True, server_default=None, ) ================================================ FILE: backend/alembic/versions/369644546676_add_composite_index_for_index_attempt_.py ================================================ """add composite index for index attempt time updated Revision ID: 369644546676 Revises: 2955778aa44c Create Date: 2025-01-08 15:38:17.224380 """ from alembic import op from sqlalchemy import text # revision identifiers, used by Alembic. 
revision = "369644546676" down_revision = "2955778aa44c" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_index( "ix_index_attempt_ccpair_search_settings_time_updated", "index_attempt", [ "connector_credential_pair_id", "search_settings_id", text("time_updated DESC"), ], unique=False, ) def downgrade() -> None: op.drop_index( "ix_index_attempt_ccpair_search_settings_time_updated", table_name="index_attempt", ) ================================================ FILE: backend/alembic/versions/36e9220ab794_update_kg_trigger_functions.py ================================================ """update_kg_trigger_functions Revision ID: 36e9220ab794 Revises: c9e2cd766c29 Create Date: 2025-06-22 17:33:25.833733 """ from alembic import op from sqlalchemy.orm import Session from sqlalchemy import text from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # revision identifiers, used by Alembic. revision = "36e9220ab794" down_revision = "c9e2cd766c29" branch_labels = None depends_on = None def _get_tenant_contextvar(session: Session) -> str: """Get the current schema for the migration""" current_tenant = session.execute(text("SELECT current_schema()")).scalar() if isinstance(current_tenant, str): return current_tenant else: raise ValueError("Current tenant is not a string") def upgrade() -> None: bind = op.get_bind() session = Session(bind=bind) # Create kg_entity trigger to update kg_entity.name and its trigrams tenant_id = _get_tenant_contextvar(session) alphanum_pattern = r"[^a-z0-9]+" truncate_length = 1000 function = "update_kg_entity_name" op.execute( text( f""" CREATE OR REPLACE FUNCTION "{tenant_id}".{function}() RETURNS TRIGGER AS $$ DECLARE name text; cleaned_name text; BEGIN -- Set name to semantic_id if document_id is not NULL IF NEW.document_id IS NOT NULL THEN SELECT lower(semantic_id) INTO name FROM "{tenant_id}".document WHERE id = NEW.document_id; ELSE name = lower(NEW.name); END IF; -- Clean name and truncate if too long 
                cleaned_name = regexp_replace(
                    name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;

                -- Set name and name trigrams
                NEW.name = name;
                NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name);
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".kg_entity')
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
            BEFORE INSERT OR UPDATE OF name
            ON "{tenant_id}".kg_entity
            FOR EACH ROW
            EXECUTE FUNCTION "{tenant_id}".{function}();
        """
    )

    # Create document trigger to propagate semantic_id changes to kg_entity.name and its trigrams
    function = "update_kg_entity_name_from_doc"
    op.execute(
        text(
            f"""
            CREATE OR REPLACE FUNCTION "{tenant_id}".{function}()
            RETURNS TRIGGER AS $$
            DECLARE
                doc_name text;
                cleaned_name text;
            BEGIN
                doc_name = lower(NEW.semantic_id);

                -- Clean name and truncate if too long
                cleaned_name = regexp_replace(
                    doc_name,
                    '{alphanum_pattern}', '', 'g'
                );
                IF length(cleaned_name) > {truncate_length} THEN
                    cleaned_name = left(cleaned_name, {truncate_length});
                END IF;

                -- Set name and name trigrams for all entities referencing this document
                UPDATE "{tenant_id}".kg_entity
                SET
                    name = doc_name,
                    name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name)
                WHERE document_id = NEW.id;
                RETURN NEW;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
    )
    trigger = f"{function}_trigger"
    op.execute(f'DROP TRIGGER IF EXISTS {trigger} ON "{tenant_id}".document')
    op.execute(
        f"""
        CREATE TRIGGER {trigger}
            AFTER UPDATE OF semantic_id
            ON "{tenant_id}".document
            FOR EACH ROW
            EXECUTE FUNCTION "{tenant_id}".{function}();
        """
    )


def downgrade() -> None:
    pass

================================================
FILE: backend/alembic/versions/3781a5eb12cb_add_chunk_stats_table.py
================================================
"""add chunk stats table

Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "chunk_stats",
        sa.Column("id", sa.String(), primary_key=True, index=True),
        sa.Column(
            "document_id",
            sa.String(),
            sa.ForeignKey("document.id"),
            nullable=False,
            index=True,
        ),
        sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
        sa.Column("information_content_boost", sa.Float(), nullable=True),
        sa.Column(
            "last_modified",
            sa.DateTime(timezone=True),
            nullable=False,
            index=True,
            server_default=sa.func.now(),
        ),
        sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
        sa.UniqueConstraint(
            "document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
        ),
    )

    op.create_index(
        "ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
    )


def downgrade() -> None:
    op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
    op.drop_table("chunk_stats")

================================================
FILE: backend/alembic/versions/3879338f8ba1_add_tool_table.py
================================================
"""Add tool table

Revision ID: 3879338f8ba1
Revises: f1c6478c3fd8
Create Date: 2024-05-11 16:11:23.718084

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "3879338f8ba1"
down_revision = "f1c6478c3fd8"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "tool",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("description", sa.Text(), nullable=True),
        sa.Column("in_code_tool_id", sa.String(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "persona__tool",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("tool_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(
            ["tool_id"],
            ["tool.id"],
        ),
        sa.PrimaryKeyConstraint("persona_id", "tool_id"),
    )


def downgrade() -> None:
    op.drop_table("persona__tool")
    op.drop_table("tool")

================================================
FILE: backend/alembic/versions/38eda64af7fe_add_chat_session_sharing.py
================================================
"""Add chat session sharing

Revision ID: 38eda64af7fe
Revises: 776b3bbe9092
Create Date: 2024-03-27 19:41:29.073594

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "38eda64af7fe"
down_revision = "776b3bbe9092"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_session",
        sa.Column(
            "shared_status",
            sa.Enum(
                "PUBLIC",
                "PRIVATE",
                name="chatsessionsharedstatus",
                native_enum=False,
            ),
            nullable=True,
        ),
    )
    op.execute("UPDATE chat_session SET shared_status='PRIVATE'")
    op.alter_column(
        "chat_session",
        "shared_status",
        nullable=False,
    )


def downgrade() -> None:
    op.drop_column("chat_session", "shared_status")

================================================
FILE: backend/alembic/versions/3934b1bc7b62_update_github_connector_repo_name_to_.py
================================================
"""Update GitHub connector repo_name to repositories

Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962

"""

from alembic import op
import sqlalchemy as sa
import json
import logging

# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None

logger = logging.getLogger("alembic.runtime.migration")


def upgrade() -> None:
    # Get all GitHub connectors
    conn = op.get_bind()

    # First get all GitHub connectors
    github_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'GITHUB'
            """
        )
    ).fetchall()

    # Update each connector's config
    updated_count = 0
    for connector_id, config in github_connectors:
        try:
            if not config:
                logger.warning(f"Connector {connector_id} has no config, skipping")
                continue

            # Parse the config if it's a string
            if isinstance(config, str):
                config = json.loads(config)

            if "repo_name" not in config:
                continue

            # Create new config with repositories instead of repo_name
            new_config = dict(config)
            repo_name_value = new_config.pop("repo_name")
            new_config["repositories"] = repo_name_value

            # Update the connector with the new config
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {"connector_id": connector_id, "new_config": json.dumps(new_config)},
            )
            updated_count += 1
        except Exception as e:
            logger.error(f"Error updating connector {connector_id}: {str(e)}")


def downgrade() -> None:
    # Get all GitHub connectors
    conn = op.get_bind()

    logger.debug(
        "Starting rollback of GitHub connectors from repositories to repo_name"
    )

    github_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'GITHUB'
            """
        )
    ).fetchall()

    logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")

    # Revert each GitHub connector to use repo_name instead of repositories
    reverted_count = 0
    for connector_id, config in github_connectors:
        try:
            if not config:
                continue

            # Parse the config if it's a string
            if isinstance(config, str):
                config = json.loads(config)

            if "repositories" not in config:
                continue

            # Create new config with repo_name instead of repositories
            new_config = dict(config)
            repositories_value = new_config.pop("repositories")
            new_config["repo_name"] = repositories_value

            # Update the connector with the new config
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {"new_config": json.dumps(new_config), "connector_id": connector_id},
            )
            reverted_count += 1
        except Exception as e:
            logger.error(f"Error reverting connector {connector_id}: {str(e)}")

================================================
FILE: backend/alembic/versions/3a7802814195_add_alternate_assistant_to_chat_message.py
================================================
"""add alternate assistant to chat message

Revision ID: 3a7802814195
Revises: 23957775e5f5
Create Date: 2024-06-05 11:18:49.966333

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "3a7802814195"
down_revision = "23957775e5f5"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True)
    )
    op.create_foreign_key(
        "fk_chat_message_persona",
        "chat_message",
        "persona",
        ["alternate_assistant_id"],
        ["id"],
    )


def downgrade() -> None:
    op.drop_constraint("fk_chat_message_persona", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "alternate_assistant_id")

================================================
FILE: backend/alembic/versions/3a78dba1080a_user_file_legacy_data_cleanup.py
================================================
"""Migration 5: User file legacy data cleanup

Revision ID: 3a78dba1080a
Revises: 7cc3fcc116c1
Create Date: 2025-09-22 10:04:27.986294

This migration removes legacy user-file documents and connector_credential_pairs.
It performs bulk deletions of obsolete data after the UUID migration.
"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql as psql
from sqlalchemy import text
import logging
from typing import List
import uuid

logger = logging.getLogger("alembic.runtime.migration")

# revision identifiers, used by Alembic.
revision = "3a78dba1080a"
down_revision = "7cc3fcc116c1"
branch_labels = None
depends_on = None


def batch_delete(
    bind: sa.engine.Connection,
    table_name: str,
    id_column: str,
    ids: List[str | int | uuid.UUID],
    batch_size: int = 1000,
    id_type: str = "int",
) -> int:
    """Delete records in batches to avoid memory issues and timeouts."""
    total_count = len(ids)
    if total_count == 0:
        return 0

    logger.info(
        f"Starting batch deletion of {total_count} records from {table_name}..."
    )

    # Determine appropriate ARRAY type
    if id_type == "uuid":
        array_type = psql.ARRAY(psql.UUID(as_uuid=True))
    elif id_type == "int":
        array_type = psql.ARRAY(sa.Integer())
    else:
        array_type = psql.ARRAY(sa.String())

    total_deleted = 0
    failed_batches = []

    for i in range(0, total_count, batch_size):
        batch_ids = ids[i : i + batch_size]
        try:
            stmt = text(
                f"DELETE FROM {table_name} WHERE {id_column} = ANY(:ids)"
            ).bindparams(sa.bindparam("ids", value=batch_ids, type_=array_type))
            result = bind.execute(stmt)
            total_deleted += result.rowcount

            # Log progress every 10 batches or at completion
            batch_num = (i // batch_size) + 1
            if batch_num % 10 == 0 or i + batch_size >= total_count:
                logger.info(
                    f" Deleted {min(i + batch_size, total_count)}/{total_count} records "
                    f"({total_deleted} actual) from {table_name}"
                )
        except Exception as e:
            logger.error(f"Failed to delete batch {(i // batch_size) + 1}: {e}")
            failed_batches.append((i, min(i + batch_size, total_count)))

    if failed_batches:
        logger.warning(
            f"Failed to delete {len(failed_batches)} batches from {table_name}. Total deleted: {total_deleted}/{total_count}"
        )
        # Fail the migration to avoid silently succeeding on partial cleanup
        raise RuntimeError(
            f"Batch deletion failed for {table_name}: "
            f"{len(failed_batches)} failed batches out of "
            f"{(total_count + batch_size - 1) // batch_size}."
        )

    return total_deleted


def upgrade() -> None:
    """Remove legacy user-file documents and connector_credential_pairs."""
    bind = op.get_bind()
    inspector = sa.inspect(bind)

    logger.info("Starting legacy data cleanup...")

    # === Step 1: Identify and delete user-file documents ===
    logger.info("Identifying user-file documents to delete...")

    # Get document IDs to delete
    doc_rows = bind.execute(
        text(
            """
            SELECT DISTINCT dcc.id AS document_id
            FROM document_by_connector_credential_pair dcc
            JOIN connector_credential_pair u
                ON u.connector_id = dcc.connector_id
                AND u.credential_id = dcc.credential_id
            WHERE u.is_user_file IS TRUE
            """
        )
    ).fetchall()

    doc_ids = [r[0] for r in doc_rows]

    if doc_ids:
        logger.info(f"Found {len(doc_ids)} user-file documents to delete")

        # Delete dependent rows first
        tables_to_clean = [
            ("document_retrieval_feedback", "document_id"),
            ("document__tag", "document_id"),
            ("chunk_stats", "document_id"),
        ]

        for table_name, column_name in tables_to_clean:
            if table_name in inspector.get_table_names():
                # document_id is a string in these tables
                deleted = batch_delete(
                    bind, table_name, column_name, doc_ids, id_type="str"
                )
                logger.info(f"Deleted {deleted} records from {table_name}")

        # Delete document_by_connector_credential_pair entries
        deleted = batch_delete(
            bind, "document_by_connector_credential_pair", "id", doc_ids, id_type="str"
        )
        logger.info(f"Deleted {deleted} document_by_connector_credential_pair records")

        # Delete documents themselves
        deleted = batch_delete(bind, "document", "id", doc_ids, id_type="str")
        logger.info(f"Deleted {deleted} document records")
    else:
        logger.info("No user-file documents found to delete")

    # === Step 2: Clean up user-file connector_credential_pairs ===
    logger.info("Cleaning up user-file connector_credential_pairs...")

    # Get cc_pair IDs
    cc_pair_rows = bind.execute(
        text(
            """
            SELECT id AS cc_pair_id
            FROM connector_credential_pair
            WHERE is_user_file IS TRUE
            """
        )
    ).fetchall()

    cc_pair_ids = [r[0] for r in cc_pair_rows]

    if cc_pair_ids:
        logger.info(
            f"Found {len(cc_pair_ids)} user-file connector_credential_pairs to clean up"
        )

        # Delete related records
        # Clean child tables first to satisfy foreign key constraints,
        # then the parent tables
        tables_to_clean = [
            ("index_attempt_errors", "connector_credential_pair_id"),
            ("index_attempt", "connector_credential_pair_id"),
            ("background_error", "cc_pair_id"),
            ("document_set__connector_credential_pair", "connector_credential_pair_id"),
            ("user_group__connector_credential_pair", "cc_pair_id"),
        ]

        for table_name, column_name in tables_to_clean:
            if table_name in inspector.get_table_names():
                deleted = batch_delete(
                    bind, table_name, column_name, cc_pair_ids, id_type="int"
                )
                logger.info(f"Deleted {deleted} records from {table_name}")

    # === Step 3: Identify connectors and credentials to delete ===
    logger.info("Identifying orphaned connectors and credentials...")

    # Get connectors used only by user-file cc_pairs
    connector_rows = bind.execute(
        text(
            """
            SELECT DISTINCT ccp.connector_id
            FROM connector_credential_pair ccp
            WHERE ccp.is_user_file IS TRUE
            AND ccp.connector_id != 0  -- Exclude system default
            AND NOT EXISTS (
                SELECT 1
                FROM connector_credential_pair c2
                WHERE c2.connector_id = ccp.connector_id
                AND c2.is_user_file IS NOT TRUE
            )
            """
        )
    ).fetchall()

    userfile_only_connector_ids = [r[0] for r in connector_rows]

    # Get credentials used only by user-file cc_pairs
    credential_rows = bind.execute(
        text(
            """
            SELECT DISTINCT ccp.credential_id
            FROM connector_credential_pair ccp
            WHERE ccp.is_user_file IS TRUE
            AND ccp.credential_id != 0  -- Exclude public/default
            AND NOT EXISTS (
                SELECT 1
                FROM connector_credential_pair c2
                WHERE c2.credential_id = ccp.credential_id
                AND c2.is_user_file IS NOT TRUE
            )
            """
        )
    ).fetchall()

    userfile_only_credential_ids = [r[0] for r in credential_rows]

    # === Step 4: Delete the cc_pairs themselves ===
    if cc_pair_ids:
        # Remove FK dependency from user_file first
        bind.execute(
            text(
                """
                DO $$
                DECLARE r RECORD;
                BEGIN
                    FOR r IN (
                        SELECT conname
                        FROM
                        pg_constraint c
                        JOIN pg_class t ON c.conrelid = t.oid
                        JOIN pg_class ft ON c.confrelid = ft.oid
                        WHERE c.contype = 'f'
                        AND t.relname = 'user_file'
                        AND ft.relname = 'connector_credential_pair'
                    ) LOOP
                        EXECUTE format('ALTER TABLE user_file DROP CONSTRAINT IF EXISTS %I', r.conname);
                    END LOOP;
                END$$;
                """
            )
        )

        # Delete cc_pairs
        deleted = batch_delete(
            bind, "connector_credential_pair", "id", cc_pair_ids, id_type="int"
        )
        logger.info(f"Deleted {deleted} connector_credential_pair records")

    # === Step 5: Delete orphaned connectors ===
    if userfile_only_connector_ids:
        deleted = batch_delete(
            bind, "connector", "id", userfile_only_connector_ids, id_type="int"
        )
        logger.info(f"Deleted {deleted} orphaned connector records")

    # === Step 6: Delete orphaned credentials ===
    if userfile_only_credential_ids:
        # Clean up credential__user_group mappings first
        deleted = batch_delete(
            bind,
            "credential__user_group",
            "credential_id",
            userfile_only_credential_ids,
            id_type="int",
        )
        logger.info(f"Deleted {deleted} credential__user_group records")

        # Delete credentials
        deleted = batch_delete(
            bind, "credential", "id", userfile_only_credential_ids, id_type="int"
        )
        logger.info(f"Deleted {deleted} orphaned credential records")

    logger.info("Migration 5 (legacy data cleanup) completed successfully")


def downgrade() -> None:
    """Cannot restore deleted data - requires backup restoration."""

    logger.error("CRITICAL: Downgrading data cleanup cannot restore deleted data!")
    logger.error("Data restoration requires backup files or database backup.")

    # raise NotImplementedError(
    #     "Downgrade of legacy data cleanup is not supported. "
    #     "Deleted data must be restored from backups."
    # )

================================================
FILE: backend/alembic/versions/3b25685ff73c_move_is_public_to_cc_pair.py
================================================
"""Move is_public to cc_pair

Revision ID: 3b25685ff73c
Revises: e0a68a81d434
Create Date: 2023-10-05 18:47:09.582849

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "3b25685ff73c"
down_revision = "e0a68a81d434"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column("is_public", sa.Boolean(), nullable=True),
    )
    # fill in is_public for existing rows
    op.execute(
        "UPDATE connector_credential_pair SET is_public = true WHERE is_public IS NULL"
    )
    op.alter_column("connector_credential_pair", "is_public", nullable=False)

    op.add_column(
        "credential",
        sa.Column("is_admin", sa.Boolean(), nullable=True),
    )
    op.execute("UPDATE credential SET is_admin = true WHERE is_admin IS NULL")
    op.alter_column("credential", "is_admin", nullable=False)

    op.drop_column("credential", "public_doc")


def downgrade() -> None:
    op.add_column(
        "credential",
        sa.Column("public_doc", sa.Boolean(), nullable=True),
    )
    # setting public_doc to false for all existing rows to be safe
    # NOTE: this is likely not the correct state of the world but it's the best we can do
    op.execute("UPDATE credential SET public_doc = false WHERE public_doc IS NULL")
    op.alter_column("credential", "public_doc", nullable=False)

    op.drop_column("connector_credential_pair", "is_public")
    op.drop_column("credential", "is_admin")

================================================
FILE: backend/alembic/versions/3bd4c84fe72f_improved_index.py
================================================
"""improved index

Revision ID: 3bd4c84fe72f
Revises: 8f43500ee275
Create Date: 2025-02-26 13:07:56.217791

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
down_revision = "8f43500ee275"
branch_labels = None
depends_on = None

# NOTE:
# This migration addresses issues with the previous migration (8f43500ee275) which caused
# an outage by creating an index without using CONCURRENTLY. This migration:
#
# 1. Creates more efficient full-text search capabilities using tsvector columns and GIN indexes
# 2. Adds indexes to both chat_message and chat_session tables for comprehensive search
# 3. Note: CONCURRENTLY was removed due to operational issues


def upgrade() -> None:
    # First, drop any existing indexes to avoid conflicts
    op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
    op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")
    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

    # Drop existing columns if they exist
    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")

    # Create a GIN index for full-text search on chat_message.message
    op.execute(
        """
        ALTER TABLE chat_message
        ADD COLUMN message_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
        """
    )

    op.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_chat_message_tsv
        ON chat_message
        USING GIN (message_tsv)
        """
    )

    # Also add a stored tsvector column for chat_session.description
    op.execute(
        """
        ALTER TABLE chat_session
        ADD COLUMN description_tsv tsvector
        GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
        """
    )

    op.execute(
        """
        CREATE INDEX IF NOT EXISTS idx_chat_session_desc_tsv
        ON chat_session
        USING GIN (description_tsv)
        """
    )


def downgrade() -> None:
    # Drop the indexes first
    op.execute("DROP INDEX IF EXISTS idx_chat_message_tsv;")
    op.execute("DROP INDEX IF EXISTS idx_chat_session_desc_tsv;")

    # Then drop the columns
    op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv;")
    op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS description_tsv;")

    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

================================================
FILE: backend/alembic/versions/3c5e35aa9af0_polling_document_count.py
================================================
"""Polling Document Count

Revision ID: 3c5e35aa9af0
Revises: 27c6ecc08586
Create Date: 2023-06-14 23:45:51.760440

"""

import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "3c5e35aa9af0"
down_revision = "27c6ecc08586"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_successful_index_time",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
    )
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_attempt_status",
            sa.Enum(
                "NOT_STARTED",
                "IN_PROGRESS",
                "SUCCESS",
                "FAILED",
                name="indexingstatus",
                native_enum=False,
            ),
            nullable=False,
        ),
    )
    op.add_column(
        "connector_credential_pair",
        sa.Column("total_docs_indexed", sa.Integer(), nullable=False),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "total_docs_indexed")
    op.drop_column("connector_credential_pair", "last_attempt_status")
    op.drop_column("connector_credential_pair", "last_successful_index_time")

================================================
FILE: backend/alembic/versions/3c6531f32351_add_back_input_prompts.py
================================================
"""add back input prompts

Revision ID: 3c6531f32351
Revises: aeda5f2df4f6
Create Date: 2025-01-13 12:49:51.705235

"""

from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy

# revision identifiers, used by Alembic.
revision = "3c6531f32351"
down_revision = "aeda5f2df4f6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "inputprompt",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("prompt", sa.String(), nullable=False),
        sa.Column("content", sa.String(), nullable=False),
        sa.Column("active", sa.Boolean(), nullable=False),
        sa.Column("is_public", sa.Boolean(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )

    op.create_table(
        "inputprompt__user",
        sa.Column("input_prompt_id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
        ),
        sa.Column("disabled", sa.Boolean(), nullable=False, default=False),
        sa.ForeignKeyConstraint(
            ["input_prompt_id"],
            ["inputprompt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("input_prompt_id", "user_id"),
    )


def downgrade() -> None:
    op.drop_table("inputprompt__user")
    op.drop_table("inputprompt")

================================================
FILE: backend/alembic/versions/3c9a65f1207f_seed_exa_provider_from_env.py
================================================
"""seed_exa_provider_from_env

Revision ID: 3c9a65f1207f
Revises: 1f2a3b4c5d6e
Create Date: 2025-11-20 19:18:00.000000

"""

from __future__ import annotations

import os

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

from dotenv import load_dotenv, find_dotenv

from onyx.utils.encryption import encrypt_string_to_bytes

revision = "3c9a65f1207f"
down_revision = "1f2a3b4c5d6e"
branch_labels = None
depends_on = None

EXA_PROVIDER_NAME = "Exa"


def _get_internet_search_table(metadata: sa.MetaData) -> sa.Table:
    return sa.Table(
        "internet_search_provider",
        metadata,
        sa.Column("id", sa.Integer, primary_key=True),
        sa.Column("name", sa.String),
        sa.Column("provider_type", sa.String),
        sa.Column("api_key", sa.LargeBinary),
        sa.Column("config", postgresql.JSONB),
        sa.Column("is_active", sa.Boolean),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.text("now()"),
        ),
    )


def upgrade() -> None:
    load_dotenv(find_dotenv())

    exa_api_key = os.environ.get("EXA_API_KEY")
    if not exa_api_key:
        return

    bind = op.get_bind()
    metadata = sa.MetaData()
    table = _get_internet_search_table(metadata)

    existing = bind.execute(
        sa.select(table.c.id).where(table.c.name == EXA_PROVIDER_NAME)
    ).first()
    if existing:
        return

    encrypted_key = encrypt_string_to_bytes(exa_api_key)

    has_active_provider = bind.execute(
        sa.select(table.c.id).where(table.c.is_active.is_(True))
    ).first()

    bind.execute(
        table.insert().values(
            name=EXA_PROVIDER_NAME,
            provider_type="exa",
            api_key=encrypted_key,
            config=None,
            is_active=not bool(has_active_provider),
        )
    )


def downgrade() -> None:
    return

================================================
FILE: backend/alembic/versions/3d1cca026fe8_add_oauth_config_and_user_tokens.py
================================================
"""add_oauth_config_and_user_tokens

Revision ID: 3d1cca026fe8
Revises: c8a93a2af083
Create Date: 2025-10-21 13:27:34.274721

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "3d1cca026fe8"
down_revision = "c8a93a2af083"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create oauth_config table
    op.create_table(
        "oauth_config",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("authorization_url", sa.Text(), nullable=False),
        sa.Column("token_url", sa.Text(), nullable=False),
        sa.Column("client_id", sa.LargeBinary(), nullable=False),
        sa.Column("client_secret", sa.LargeBinary(), nullable=False),
        sa.Column("scopes", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
        sa.Column(
            "additional_params",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )

    # Create oauth_user_token table
    op.create_table(
        "oauth_user_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("oauth_config_id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column("token_data", sa.LargeBinary(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["oauth_config_id"], ["oauth_config.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("oauth_config_id", "user_id", name="uq_oauth_user_token"),
    )

    # Create index on user_id for efficient user-based token lookups
    # Note: unique constraint on (oauth_config_id, user_id) already creates
    # an index for config-based lookups
    op.create_index(
        "ix_oauth_user_token_user_id",
        "oauth_user_token",
        ["user_id"],
    )

    # Add oauth_config_id column to tool table
    op.add_column("tool", sa.Column("oauth_config_id", sa.Integer(), nullable=True))

    # Create foreign key from tool to oauth_config
    op.create_foreign_key(
        "tool_oauth_config_fk",
        "tool",
        "oauth_config",
        ["oauth_config_id"],
        ["id"],
        ondelete="SET NULL",
    )


def downgrade() -> None:
    # Drop foreign key from tool to oauth_config
    op.drop_constraint("tool_oauth_config_fk", "tool", type_="foreignkey")

    # Drop oauth_config_id column from tool table
    op.drop_column("tool", "oauth_config_id")

    # Drop index on user_id
    op.drop_index("ix_oauth_user_token_user_id", table_name="oauth_user_token")

    # Drop oauth_user_token table (will cascade delete tokens)
    op.drop_table("oauth_user_token")

    # Drop oauth_config table
    op.drop_table("oauth_config")

================================================
FILE: backend/alembic/versions/3fc5d75723b3_add_doc_metadata_field_in_document_model.py
================================================
"""add_doc_metadata_field_in_document_model

Revision ID: 3fc5d75723b3
Revises: 2f95e36923e6
Create Date: 2025-07-28 18:45:37.985406

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "3fc5d75723b3"
down_revision = "2f95e36923e6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "document",
        sa.Column(
            "doc_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True
        ),
    )


def downgrade() -> None:
    op.drop_column("document", "doc_metadata")

================================================
FILE: backend/alembic/versions/401c1ac29467_add_tables_for_ui_based_llm_.py
================================================
"""Add tables for UI-based LLM configuration

Revision ID: 401c1ac29467
Revises: 703313b75876
Create Date: 2024-04-13 18:07:29.153817

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "401c1ac29467"
down_revision = "703313b75876"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "llm_provider",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("api_key", sa.String(), nullable=True),
        sa.Column("api_base", sa.String(), nullable=True),
        sa.Column("api_version", sa.String(), nullable=True),
        sa.Column(
            "custom_config",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
        sa.Column("default_model_name", sa.String(), nullable=False),
        sa.Column("fast_default_model_name", sa.String(), nullable=True),
        sa.Column("is_default_provider", sa.Boolean(), unique=True, nullable=True),
        sa.Column("model_names", postgresql.ARRAY(sa.String()), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )

    op.add_column(
        "persona",
        sa.Column("llm_model_provider_override", sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("persona", "llm_model_provider_override")

    op.drop_table("llm_provider")

================================================
FILE: backend/alembic/versions/40926a4dab77_reset_userfile_document_id_migrated_.py
================================================
"""reset userfile document_id_migrated field

Revision ID: 40926a4dab77
Revises: 64bd5677aeb6
Create Date: 2025-10-06 16:10:32.898668

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "40926a4dab77"
down_revision = "64bd5677aeb6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Set all existing records to not migrated
    op.execute(
        "UPDATE user_file SET document_id_migrated = FALSE WHERE document_id_migrated IS DISTINCT FROM FALSE;"
    )


def downgrade() -> None:
    # No-op
    pass

================================================
FILE: backend/alembic/versions/41fa44bef321_remove_default_prompt_shortcuts.py
================================================
"""remove default prompt shortcuts

Revision ID: 41fa44bef321
Revises: 2c2430828bdf
Create Date: 2025-01-21

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "41fa44bef321"
down_revision = "2c2430828bdf"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Delete any user associations for the default prompts first (foreign key constraint)
    op.execute(
        "DELETE FROM inputprompt__user WHERE input_prompt_id IN (SELECT id FROM inputprompt WHERE id < 0)"
    )
    # Delete the pre-seeded default prompt shortcuts (they have negative IDs)
    op.execute("DELETE FROM inputprompt WHERE id < 0")


def downgrade() -> None:
    # We don't restore the default prompts on downgrade
    pass

================================================
FILE: backend/alembic/versions/43cbbb3f5e6a_rename_index_origin_to_index_recursively.py
================================================
"""Rename index_origin to index_recursively

Revision ID: 1d6ad76d1f37
Revises: e1392f05e840
Create Date: 2024-08-01 12:38:54.466081

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "1d6ad76d1f37"
down_revision = "e1392f05e840"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE connector
        SET connector_specific_config = jsonb_set(
            connector_specific_config,
            '{index_recursively}',
            'true'::jsonb
        ) - 'index_origin'
        WHERE connector_specific_config ? 'index_origin'
        """
    )


def downgrade() -> None:
    op.execute(
        """
        UPDATE connector
        SET connector_specific_config = jsonb_set(
            connector_specific_config,
            '{index_origin}',
            connector_specific_config->'index_recursively'
        ) - 'index_recursively'
        WHERE connector_specific_config ? 'index_recursively'
        """
    )

================================================
FILE: backend/alembic/versions/44f856ae2a4a_add_cloud_embedding_model.py
================================================
"""add cloud embedding model and update embedding_model

Revision ID: 44f856ae2a4a
Revises: d716b0791ddd
Create Date: 2024-06-28 20:01:05.927647

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "44f856ae2a4a"
down_revision = "d716b0791ddd"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Create embedding_provider table
    op.create_table(
        "embedding_provider",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("default_model_id", sa.Integer(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )

    # Add cloud_provider_id to embedding_model table
    op.add_column(
        "embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True)
    )

    # Add foreign key constraints
    op.create_foreign_key(
        "fk_embedding_model_cloud_provider",
        "embedding_model",
        "embedding_provider",
        ["cloud_provider_id"],
        ["id"],
    )
    op.create_foreign_key(
        "fk_embedding_provider_default_model",
        "embedding_provider",
        "embedding_model",
        ["default_model_id"],
        ["id"],
    )


def downgrade() -> None:
    # Remove foreign key constraints
    op.drop_constraint(
        "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey"
    )
    op.drop_constraint(
        "fk_embedding_provider_default_model", "embedding_provider", type_="foreignkey"
    )

    # Remove cloud_provider_id column
    op.drop_column("embedding_model", "cloud_provider_id")

    # Drop embedding_provider table
    op.drop_table("embedding_provider")

================================================
FILE: backend/alembic/versions/4505fd7302e1_added_is_internet_to_dbdoc.py
================================================
"""added is_internet to DBDoc

Revision ID: 4505fd7302e1
Revises: c18cdf4b497e
Create Date: 2024-06-18 20:46:09.095034

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "4505fd7302e1"
down_revision = "c18cdf4b497e"


def upgrade() -> None:
    op.add_column("search_doc", sa.Column("is_internet", sa.Boolean(), nullable=True))
    op.add_column("tool", sa.Column("display_name", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("tool", "display_name")
    op.drop_column("search_doc", "is_internet")

================================================
FILE: backend/alembic/versions/465f78d9b7f9_larger_access_tokens_for_oauth.py
================================================
"""Larger Access Tokens for OAUTH

Revision ID: 465f78d9b7f9
Revises: 3c5e35aa9af0
Create Date: 2023-07-18 17:33:40.365034

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "465f78d9b7f9"
down_revision = "3c5e35aa9af0"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.alter_column("oauth_account", "access_token", type_=sa.Text())


def downgrade() -> None:
    op.alter_column("oauth_account", "access_token", type_=sa.String(length=1024))

================================================
FILE: backend/alembic/versions/46625e4745d4_remove_native_enum.py
================================================
"""Remove Native Enum

Revision ID: 46625e4745d4
Revises: 9d97fecfab7f
Create Date: 2023-10-27 11:38:33.803145

"""

from alembic import op
from sqlalchemy import String

# revision identifiers, used by Alembic.
revision = "46625e4745d4" down_revision = "9d97fecfab7f" branch_labels: None = None depends_on: None = None def upgrade() -> None: # At this point, we directly changed some previous migrations, # https://github.com/onyx-dot-app/onyx/pull/637 # Due to using Postgres native Enums, it caused some complications for first time users. # To remove those complications, all Enums are only handled application side moving forward. # This migration exists to ensure that existing users don't run into upgrade issues. op.alter_column("index_attempt", "status", type_=String) op.alter_column("connector_credential_pair", "last_attempt_status", type_=String) op.execute("DROP TYPE IF EXISTS indexingstatus") def downgrade() -> None: # We don't want Native Enums, do nothing pass ================================================ FILE: backend/alembic/versions/46b7a812670f_fix_user__external_user_group_id_fk.py ================================================ """fix_user__external_user_group_id_fk Revision ID: 46b7a812670f Revises: f32615f71aeb Create Date: 2024-09-23 12:58:03.894038 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "46b7a812670f" down_revision = "f32615f71aeb" branch_labels = None depends_on = None def upgrade() -> None: # Drop the existing primary key op.drop_constraint( "user__external_user_group_id_pkey", "user__external_user_group_id", type_="primary", ) # Add the new composite primary key op.create_primary_key( "user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id", "external_user_group_id", "cc_pair_id"], ) def downgrade() -> None: # Drop the composite primary key op.drop_constraint( "user__external_user_group_id_pkey", "user__external_user_group_id", type_="primary", ) # Delete all entries from the table op.execute("DELETE FROM user__external_user_group_id") # Recreate the original primary key on user_id op.create_primary_key( "user__external_user_group_id_pkey", "user__external_user_group_id", ["user_id"] ) ================================================ FILE: backend/alembic/versions/4738e4b3bae1_pg_file_store.py ================================================ """PG File Store Revision ID: 4738e4b3bae1 Revises: e91df4e935ef Create Date: 2024-03-20 18:53:32.461518 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "4738e4b3bae1" down_revision = "e91df4e935ef" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "file_store", sa.Column("file_name", sa.String(), nullable=False), sa.Column("lobj_oid", sa.Integer(), nullable=False), sa.PrimaryKeyConstraint("file_name"), ) def downgrade() -> None: op.drop_table("file_store") ================================================ FILE: backend/alembic/versions/473a1a7ca408_add_display_model_names_to_llm_provider.py ================================================ """Add display_model_names to llm_provider Revision ID: 473a1a7ca408 Revises: 325975216eb3 Create Date: 2024-07-25 14:31:02.002917 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "473a1a7ca408" down_revision = "325975216eb3" branch_labels: None = None depends_on: None = None default_models_by_provider = { "openai": ["gpt-4", "gpt-4o", "gpt-4o-mini"], "bedrock": [ "meta.llama3-1-70b-instruct-v1:0", "meta.llama3-1-8b-instruct-v1:0", "anthropic.claude-3-opus-20240229-v1:0", "mistral.mistral-large-2402-v1:0", "anthropic.claude-3-5-sonnet-20240620-v1:0", ], "anthropic": ["claude-3-opus-20240229", "claude-3-5-sonnet-20240620"], } def upgrade() -> None: op.add_column( "llm_provider", sa.Column("display_model_names", postgresql.ARRAY(sa.String()), nullable=True), ) connection = op.get_bind() for provider, models in default_models_by_provider.items(): connection.execute( sa.text( "UPDATE llm_provider SET display_model_names = :models WHERE provider = :provider" ), {"models": models, "provider": provider}, ) def downgrade() -> None: op.drop_column("llm_provider", "display_model_names") ================================================ FILE: backend/alembic/versions/47433d30de82_create_indexattempt_table.py ================================================ """Create IndexAttempt table Revision ID: 47433d30de82 Revises: Create Date: 2023-05-04 
00:55:32.971991 """ import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "47433d30de82" down_revision: None = None branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "index_attempt", sa.Column("id", sa.Integer(), nullable=False), # String type since python enum will change often sa.Column( "source", sa.String(), nullable=False, ), # String type to easily accommodate new ways of pulling # in documents sa.Column( "input_type", sa.String(), nullable=False, ), sa.Column( "connector_specific_config", postgresql.JSONB(), nullable=False, ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True, ), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), server_onupdate=sa.text("now()"), # type: ignore nullable=True, ), sa.Column( "status", sa.Enum( "NOT_STARTED", "IN_PROGRESS", "SUCCESS", "FAILED", name="indexingstatus", native_enum=False, ), nullable=False, ), sa.Column("document_ids", postgresql.ARRAY(sa.String()), nullable=True), sa.Column("error_msg", sa.String(), nullable=True), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("index_attempt") ================================================ FILE: backend/alembic/versions/475fcefe8826_add_name_to_api_key.py ================================================ """Add name to api_key Revision ID: 475fcefe8826 Revises: ecab2b3f1a3b Create Date: 2024-04-11 11:05:18.414438 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "475fcefe8826" down_revision = "ecab2b3f1a3b" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("api_key", sa.Column("name", sa.String(), nullable=True)) def downgrade() -> None: op.drop_column("api_key", "name") ================================================ FILE: backend/alembic/versions/4794bc13e484_update_prompt_length.py ================================================ """update prompt length Revision ID: 4794bc13e484 Revises: f7505c5b0284 Create Date: 2025-04-02 11:26:36.180328 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "4794bc13e484" down_revision = "f7505c5b0284" branch_labels = None depends_on = None def upgrade() -> None: op.alter_column( "prompt", "system_prompt", existing_type=sa.TEXT(), type_=sa.String(length=5000000), existing_nullable=False, ) op.alter_column( "prompt", "task_prompt", existing_type=sa.TEXT(), type_=sa.String(length=5000000), existing_nullable=False, ) def downgrade() -> None: op.alter_column( "prompt", "system_prompt", existing_type=sa.String(length=5000000), type_=sa.TEXT(), existing_nullable=False, ) op.alter_column( "prompt", "task_prompt", existing_type=sa.String(length=5000000), type_=sa.TEXT(), existing_nullable=False, ) ================================================ FILE: backend/alembic/versions/47a07e1a38f1_fix_invalid_model_configurations_state.py ================================================ """Fix invalid model-configurations state Revision ID: 47a07e1a38f1 Revises: 7a70b7664e37 Create Date: 2025-04-23 15:39:43.159504 """ from alembic import op from pydantic import BaseModel, ConfigDict import sqlalchemy as sa from sqlalchemy.dialects import postgresql from onyx.llm.well_known_providers.llm_provider_options import ( fetch_model_names_for_provider_as_set, fetch_visible_model_names_for_provider_as_set, ) # revision identifiers, used by Alembic. 
revision = "47a07e1a38f1" down_revision = "7a70b7664e37" branch_labels = None depends_on = None class _SimpleModelConfiguration(BaseModel): # Configure model to read from attributes model_config = ConfigDict(from_attributes=True) id: int llm_provider_id: int name: str is_visible: bool max_input_tokens: int | None def upgrade() -> None: llm_provider_table = sa.sql.table( "llm_provider", sa.column("id", sa.Integer), sa.column("provider", sa.String), sa.column("model_names", postgresql.ARRAY(sa.String)), sa.column("display_model_names", postgresql.ARRAY(sa.String)), sa.column("default_model_name", sa.String), sa.column("fast_default_model_name", sa.String), ) model_configuration_table = sa.sql.table( "model_configuration", sa.column("id", sa.Integer), sa.column("llm_provider_id", sa.Integer), sa.column("name", sa.String), sa.column("is_visible", sa.Boolean), sa.column("max_input_tokens", sa.Integer), ) connection = op.get_bind() llm_providers = connection.execute( sa.select( llm_provider_table.c.id, llm_provider_table.c.provider, ) ).fetchall() for llm_provider in llm_providers: llm_provider_id, provider_name = llm_provider default_models = fetch_model_names_for_provider_as_set(provider_name) display_models = fetch_visible_model_names_for_provider_as_set( provider_name=provider_name ) # if `fetch_model_names_for_provider_as_set` returns `None`, then # that means that `provider_name` is not a well-known llm provider. if not default_models: continue if not display_models: raise RuntimeError( "If `default_models` is non-None, `display_models` must be non-None too." 
) model_configurations = [ _SimpleModelConfiguration.model_validate(model_configuration) for model_configuration in connection.execute( sa.select( model_configuration_table.c.id, model_configuration_table.c.llm_provider_id, model_configuration_table.c.name, model_configuration_table.c.is_visible, model_configuration_table.c.max_input_tokens, ).where(model_configuration_table.c.llm_provider_id == llm_provider_id) ).fetchall() ] if model_configurations: at_least_one_is_visible = any( [ model_configuration.is_visible for model_configuration in model_configurations ] ) # If there is at least one model which is public, this is a valid state. # Therefore, don't touch it and move on to the next one. if at_least_one_is_visible: continue existing_visible_model_names: set[str] = set( [ model_configuration.name for model_configuration in model_configurations if model_configuration.is_visible ] ) difference = display_models.difference(existing_visible_model_names) for model_name in difference: if not model_name: continue insert_statement = postgresql.insert(model_configuration_table).values( llm_provider_id=llm_provider_id, name=model_name, is_visible=True, max_input_tokens=None, ) connection.execute( insert_statement.on_conflict_do_update( index_elements=["llm_provider_id", "name"], set_={"is_visible": insert_statement.excluded.is_visible}, ) ) else: for model_name in default_models: connection.execute( model_configuration_table.insert().values( llm_provider_id=llm_provider_id, name=model_name, is_visible=model_name in display_models, max_input_tokens=None, ) ) def downgrade() -> None: pass ================================================ FILE: backend/alembic/versions/47e5bef3a1d7_add_persona_categories.py ================================================ """add persona categories Revision ID: 47e5bef3a1d7 Revises: dfbe9e93d3c7 Create Date: 2024-11-05 18:55:02.221064 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "47e5bef3a1d7" down_revision = "dfbe9e93d3c7" branch_labels = None depends_on = None def upgrade() -> None: # Create the persona_category table op.create_table( "persona_category", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column("description", sa.String(), nullable=True), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name"), ) # Add category_id to persona table op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True)) op.create_foreign_key( "fk_persona_category", "persona", "persona_category", ["category_id"], ["id"], ondelete="SET NULL", ) def downgrade() -> None: # Drop the foreign key under the same name it was created with in upgrade() op.drop_constraint("fk_persona_category", "persona", type_="foreignkey") op.drop_column("persona", "category_id") op.drop_table("persona_category") ================================================ FILE: backend/alembic/versions/48d14957fe80_add_support_for_custom_tools.py ================================================ """Add support for custom tools Revision ID: 48d14957fe80 Revises: b85f02ec1308 Create Date: 2024-06-09 14:58:19.946509 """ from alembic import op import fastapi_users_db_sqlalchemy import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic.
revision = "48d14957fe80" down_revision = "b85f02ec1308" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "tool", sa.Column( "openapi_schema", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) op.add_column( "tool", sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), ) op.create_foreign_key("tool_user_fk", "tool", "user", ["user_id"], ["id"]) op.create_table( "tool_call", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("tool_id", sa.Integer(), nullable=False), sa.Column("tool_name", sa.String(), nullable=False), sa.Column( "tool_arguments", postgresql.JSONB(astext_type=sa.Text()), nullable=False ), sa.Column( "tool_result", postgresql.JSONB(astext_type=sa.Text()), nullable=False ), sa.Column( "message_id", sa.Integer(), sa.ForeignKey("chat_message.id"), nullable=False ), ) def downgrade() -> None: op.drop_table("tool_call") op.drop_constraint("tool_user_fk", "tool", type_="foreignkey") op.drop_column("tool", "user_id") op.drop_column("tool", "openapi_schema") ================================================ FILE: backend/alembic/versions/495cb26ce93e_create_knowlege_graph_tables.py ================================================ """create knowledge graph tables Revision ID: 495cb26ce93e Revises: ca04500b9ee8 Create Date: 2025-03-19 08:51:14.341989 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql from sqlalchemy import text from datetime import datetime, timedelta from onyx.configs.app_configs import DB_READONLY_USER from onyx.configs.app_configs import DB_READONLY_PASSWORD from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # revision identifiers, used by Alembic. revision = "495cb26ce93e" down_revision = "ca04500b9ee8" branch_labels = None depends_on = None def upgrade() -> None: # Create a new permission-less user to be later used for knowledge graph queries. 
# The user will later get temporary read privileges for a specific view that will be # ad hoc generated specific to a knowledge graph query. # # Note: in order for the migration to run, the DB_READONLY_USER and DB_READONLY_PASSWORD # environment variables MUST be set. Otherwise, an exception will be raised. if not MULTI_TENANT: # Enable pg_trgm extension if not already enabled op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm") # Create read-only db user here only in single tenant mode. For multi-tenant mode, # the user is created in the alembic_tenants migration. if not (DB_READONLY_USER and DB_READONLY_PASSWORD): raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set") op.execute( text( f""" DO $$ BEGIN -- Check if the read-only user already exists IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN -- Create the read-only user with the specified password EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}'); -- First revoke all privileges to ensure a clean slate EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}'); -- Grant only the CONNECT privilege to allow the user to connect to the database -- but not perform any operations without additional specific grants EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}'); END IF; END $$; """ ) ) # Grant usage on current schema to readonly user op.execute( text( f""" DO $$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN EXECUTE format('GRANT USAGE ON SCHEMA %I TO %I', current_schema(), '{DB_READONLY_USER}'); END IF; END $$; """ ) ) op.execute("DROP TABLE IF EXISTS kg_config CASCADE") op.create_table( "kg_config", sa.Column("id", sa.Integer(), primary_key=True, nullable=False, index=True), sa.Column("kg_variable_name", sa.String(), nullable=False, index=True), sa.Column("kg_variable_values", 
postgresql.ARRAY(sa.String()), nullable=False), sa.UniqueConstraint("kg_variable_name", name="uq_kg_config_variable_name"), ) # Insert initial data into kg_config table op.bulk_insert( sa.table( "kg_config", sa.column("kg_variable_name", sa.String), sa.column("kg_variable_values", postgresql.ARRAY(sa.String)), ), [ {"kg_variable_name": "KG_EXPOSED", "kg_variable_values": ["false"]}, {"kg_variable_name": "KG_ENABLED", "kg_variable_values": ["false"]}, {"kg_variable_name": "KG_VENDOR", "kg_variable_values": []}, {"kg_variable_name": "KG_VENDOR_DOMAINS", "kg_variable_values": []}, {"kg_variable_name": "KG_IGNORE_EMAIL_DOMAINS", "kg_variable_values": []}, { "kg_variable_name": "KG_EXTRACTION_IN_PROGRESS", "kg_variable_values": ["false"], }, { "kg_variable_name": "KG_CLUSTERING_IN_PROGRESS", "kg_variable_values": ["false"], }, { "kg_variable_name": "KG_COVERAGE_START", "kg_variable_values": [ (datetime.now() - timedelta(days=90)).strftime("%Y-%m-%d") ], }, {"kg_variable_name": "KG_MAX_COVERAGE_DAYS", "kg_variable_values": ["90"]}, { "kg_variable_name": "KG_MAX_PARENT_RECURSION_DEPTH", "kg_variable_values": ["2"], }, ], ) op.execute("DROP TABLE IF EXISTS kg_entity_type CASCADE") op.create_table( "kg_entity_type", sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True), sa.Column("description", sa.String(), nullable=True), sa.Column("grounding", sa.String(), nullable=False), sa.Column( "attributes", postgresql.JSONB, nullable=False, server_default="{}", ), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), sa.Column("active", sa.Boolean(), nullable=False, default=False), sa.Column("deep_extraction", sa.Boolean(), nullable=False, default=False), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), onupdate=sa.text("now()"), ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), sa.Column("grounded_source_name", sa.String(), nullable=True), 
sa.Column("entity_values", postgresql.ARRAY(sa.String()), nullable=True), sa.Column( "clustering", postgresql.JSONB, nullable=False, server_default="{}", ), ) op.execute("DROP TABLE IF EXISTS kg_relationship_type CASCADE") # Create KGRelationshipType table op.create_table( "kg_relationship_type", sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True), sa.Column("name", sa.String(), nullable=False, index=True), sa.Column( "source_entity_type_id_name", sa.String(), nullable=False, index=True ), sa.Column( "target_entity_type_id_name", sa.String(), nullable=False, index=True ), sa.Column("definition", sa.Boolean(), nullable=False, default=False), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), sa.Column("type", sa.String(), nullable=False, index=True), sa.Column("active", sa.Boolean(), nullable=False, default=True), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), onupdate=sa.text("now()"), ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), sa.Column( "clustering", postgresql.JSONB, nullable=False, server_default="{}", ), sa.ForeignKeyConstraint( ["source_entity_type_id_name"], ["kg_entity_type.id_name"] ), sa.ForeignKeyConstraint( ["target_entity_type_id_name"], ["kg_entity_type.id_name"] ), ) op.execute("DROP TABLE IF EXISTS kg_relationship_type_extraction_staging CASCADE") # Create KGRelationshipTypeExtractionStaging table op.create_table( "kg_relationship_type_extraction_staging", sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True), sa.Column("name", sa.String(), nullable=False, index=True), sa.Column( "source_entity_type_id_name", sa.String(), nullable=False, index=True ), sa.Column( "target_entity_type_id_name", sa.String(), nullable=False, index=True ), sa.Column("definition", sa.Boolean(), nullable=False, default=False), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), 
sa.Column("type", sa.String(), nullable=False, index=True), sa.Column("active", sa.Boolean(), nullable=False, default=True), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), sa.Column( "clustering", postgresql.JSONB, nullable=False, server_default="{}", ), sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"), sa.ForeignKeyConstraint( ["source_entity_type_id_name"], ["kg_entity_type.id_name"] ), sa.ForeignKeyConstraint( ["target_entity_type_id_name"], ["kg_entity_type.id_name"] ), ) op.execute("DROP TABLE IF EXISTS kg_entity CASCADE") # Create KGEntity table op.create_table( "kg_entity", sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True), sa.Column("name", sa.String(), nullable=False, index=True), sa.Column("entity_class", sa.String(), nullable=True, index=True), sa.Column("entity_subtype", sa.String(), nullable=True, index=True), sa.Column("entity_key", sa.String(), nullable=True, index=True), sa.Column("name_trigrams", postgresql.ARRAY(sa.String(3)), nullable=True), sa.Column("document_id", sa.String(), nullable=True, index=True), sa.Column( "alternative_names", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}", ), sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True), sa.Column("description", sa.String(), nullable=True), sa.Column( "keywords", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}", ), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), sa.Column( "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}" ), sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"), sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"), sa.Column("event_time", sa.DateTime(timezone=True), nullable=True), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), onupdate=sa.text("now()"), ), sa.Column( 
"time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]), sa.ForeignKeyConstraint(["document_id"], ["document.id"]), sa.UniqueConstraint( "name", "entity_type_id_name", "document_id", name="uq_kg_entity_name_type_doc", ), ) op.create_index("ix_entity_type_acl", "kg_entity", ["entity_type_id_name", "acl"]) op.create_index( "ix_entity_name_search", "kg_entity", ["name", "entity_type_id_name"] ) op.execute("DROP TABLE IF EXISTS kg_entity_extraction_staging CASCADE") # Create KGEntityExtractionStaging table op.create_table( "kg_entity_extraction_staging", sa.Column("id_name", sa.String(), primary_key=True, nullable=False, index=True), sa.Column("name", sa.String(), nullable=False, index=True), sa.Column("document_id", sa.String(), nullable=True, index=True), sa.Column( "alternative_names", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}", ), sa.Column("entity_type_id_name", sa.String(), nullable=False, index=True), sa.Column("description", sa.String(), nullable=True), sa.Column( "keywords", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}", ), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), sa.Column( "acl", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}" ), sa.Column("boosts", postgresql.JSONB, nullable=False, server_default="{}"), sa.Column("attributes", postgresql.JSONB, nullable=False, server_default="{}"), sa.Column("transferred_id_name", sa.String(), nullable=True, default=None), sa.Column("entity_class", sa.String(), nullable=True, index=True), sa.Column("entity_key", sa.String(), nullable=True, index=True), sa.Column("entity_subtype", sa.String(), nullable=True, index=True), sa.Column("parent_key", sa.String(), nullable=True, index=True), sa.Column("event_time", sa.DateTime(timezone=True), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), 
server_default=sa.text("now()") ), sa.ForeignKeyConstraint(["entity_type_id_name"], ["kg_entity_type.id_name"]), sa.ForeignKeyConstraint(["document_id"], ["document.id"]), ) op.create_index( "ix_entity_extraction_staging_acl", "kg_entity_extraction_staging", ["entity_type_id_name", "acl"], ) op.create_index( "ix_entity_extraction_staging_name_search", "kg_entity_extraction_staging", ["name", "entity_type_id_name"], ) op.execute("DROP TABLE IF EXISTS kg_relationship CASCADE") # Create KGRelationship table op.create_table( "kg_relationship", sa.Column("id_name", sa.String(), nullable=False, index=True), sa.Column("source_node", sa.String(), nullable=False, index=True), sa.Column("target_node", sa.String(), nullable=False, index=True), sa.Column("source_node_type", sa.String(), nullable=False, index=True), sa.Column("target_node_type", sa.String(), nullable=False, index=True), sa.Column("source_document", sa.String(), nullable=True, index=True), sa.Column("type", sa.String(), nullable=False, index=True), sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), onupdate=sa.text("now()"), ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), sa.ForeignKeyConstraint(["source_node"], ["kg_entity.id_name"]), sa.ForeignKeyConstraint(["target_node"], ["kg_entity.id_name"]), sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]), sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]), sa.ForeignKeyConstraint(["source_document"], ["document.id"]), sa.ForeignKeyConstraint( ["relationship_type_id_name"], ["kg_relationship_type.id_name"] ), sa.UniqueConstraint( "source_node", "target_node", "type", name="uq_kg_relationship_source_target_type", ), sa.PrimaryKeyConstraint("id_name", "source_document"), ) op.create_index( 
"ix_kg_relationship_nodes", "kg_relationship", ["source_node", "target_node"] ) op.execute("DROP TABLE IF EXISTS kg_relationship_extraction_staging CASCADE") # Create KGRelationshipExtractionStaging table op.create_table( "kg_relationship_extraction_staging", sa.Column("id_name", sa.String(), nullable=False, index=True), sa.Column("source_node", sa.String(), nullable=False, index=True), sa.Column("target_node", sa.String(), nullable=False, index=True), sa.Column("source_node_type", sa.String(), nullable=False, index=True), sa.Column("target_node_type", sa.String(), nullable=False, index=True), sa.Column("source_document", sa.String(), nullable=True, index=True), sa.Column("type", sa.String(), nullable=False, index=True), sa.Column("relationship_type_id_name", sa.String(), nullable=False, index=True), sa.Column("occurrences", sa.Integer(), server_default="1", nullable=False), sa.Column("transferred", sa.Boolean(), nullable=False, server_default="false"), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), sa.ForeignKeyConstraint( ["source_node"], ["kg_entity_extraction_staging.id_name"] ), sa.ForeignKeyConstraint( ["target_node"], ["kg_entity_extraction_staging.id_name"] ), sa.ForeignKeyConstraint(["source_node_type"], ["kg_entity_type.id_name"]), sa.ForeignKeyConstraint(["target_node_type"], ["kg_entity_type.id_name"]), sa.ForeignKeyConstraint(["source_document"], ["document.id"]), sa.ForeignKeyConstraint( ["relationship_type_id_name"], ["kg_relationship_type_extraction_staging.id_name"], ), sa.UniqueConstraint( "source_node", "target_node", "type", name="uq_kg_relationship_extraction_staging_source_target_type", ), sa.PrimaryKeyConstraint("id_name", "source_document"), ) op.create_index( "ix_kg_relationship_extraction_staging_nodes", "kg_relationship_extraction_staging", ["source_node", "target_node"], ) op.execute("DROP TABLE IF EXISTS kg_term CASCADE") # Create KGTerm table op.create_table( "kg_term", sa.Column("id_term", 
sa.String(), primary_key=True, nullable=False, index=True), sa.Column( "entity_types", postgresql.ARRAY(sa.String()), nullable=False, server_default="{}", ), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), onupdate=sa.text("now()"), ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()") ), ) op.create_index("ix_search_term_entities", "kg_term", ["entity_types"]) op.create_index("ix_search_term_term", "kg_term", ["id_term"]) op.add_column( "document", sa.Column("kg_stage", sa.String(), nullable=True, index=True), ) op.add_column( "document", sa.Column("kg_processing_time", sa.DateTime(timezone=True), nullable=True), ) op.add_column( "connector", sa.Column( "kg_processing_enabled", sa.Boolean(), nullable=True, server_default="false", ), ) op.add_column( "connector", sa.Column( "kg_coverage_days", sa.Integer(), nullable=True, server_default=None, ), ) # Create GIN index for clustering and normalization op.execute( "CREATE INDEX IF NOT EXISTS idx_kg_entity_clustering_trigrams " f"ON kg_entity USING GIN (name {POSTGRES_DEFAULT_SCHEMA}.gin_trgm_ops)" ) op.execute( "CREATE INDEX IF NOT EXISTS idx_kg_entity_normalization_trigrams ON kg_entity USING GIN (name_trigrams)" ) # Create kg_entity trigger to update kg_entity.name and its trigrams alphanum_pattern = r"[^a-z0-9]+" truncate_length = 1000 function = "update_kg_entity_name" op.execute( text( f""" CREATE OR REPLACE FUNCTION {function}() RETURNS TRIGGER AS $$ DECLARE name text; cleaned_name text; BEGIN -- Set name to semantic_id if document_id is not NULL IF NEW.document_id IS NOT NULL THEN SELECT lower(semantic_id) INTO name FROM document WHERE id = NEW.document_id; ELSE name = lower(NEW.name); END IF; -- Clean name and truncate if too long cleaned_name = regexp_replace( name, '{alphanum_pattern}', '', 'g' ); IF length(cleaned_name) > {truncate_length} THEN cleaned_name = left(cleaned_name, {truncate_length}); END IF; -- Set name and name 
trigrams NEW.name = name; NEW.name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name); RETURN NEW; END; $$ LANGUAGE plpgsql; """ ) ) trigger = f"{function}_trigger" op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON kg_entity") op.execute( f""" CREATE TRIGGER {trigger} BEFORE INSERT OR UPDATE OF name ON kg_entity FOR EACH ROW EXECUTE FUNCTION {function}(); """ ) # Create document trigger to propagate semantic_id changes to kg_entity.name and its trigrams function = "update_kg_entity_name_from_doc" op.execute( text( f""" CREATE OR REPLACE FUNCTION {function}() RETURNS TRIGGER AS $$ DECLARE doc_name text; cleaned_name text; BEGIN doc_name = lower(NEW.semantic_id); -- Clean name and truncate if too long cleaned_name = regexp_replace( doc_name, '{alphanum_pattern}', '', 'g' ); IF length(cleaned_name) > {truncate_length} THEN cleaned_name = left(cleaned_name, {truncate_length}); END IF; -- Set name and name trigrams for all entities referencing this document UPDATE kg_entity SET name = doc_name, name_trigrams = {POSTGRES_DEFAULT_SCHEMA}.show_trgm(cleaned_name) WHERE document_id = NEW.id; RETURN NEW; END; $$ LANGUAGE plpgsql; """ ) ) trigger = f"{function}_trigger" op.execute(f"DROP TRIGGER IF EXISTS {trigger} ON document") op.execute( f""" CREATE TRIGGER {trigger} AFTER UPDATE OF semantic_id ON document FOR EACH ROW EXECUTE FUNCTION {function}(); """ ) def downgrade() -> None: # Drop all kg_relationships_with_access* and allowed_docs* views op.execute( """ DO $$ DECLARE view_name text; BEGIN FOR view_name IN SELECT c.relname FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = 'v' AND n.nspname = current_schema() AND c.relname LIKE 'kg_relationships_with_access%' LOOP EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name); END LOOP; END $$; """ ) op.execute( """ DO $$ DECLARE view_name text; BEGIN FOR view_name IN SELECT c.relname FROM pg_catalog.pg_class c JOIN pg_catalog.pg_namespace n ON n.oid = c.relnamespace WHERE c.relkind = 'v' AND n.nspname =
current_schema() AND c.relname LIKE 'allowed_docs%' LOOP EXECUTE 'DROP VIEW IF EXISTS ' || quote_ident(view_name); END LOOP; END $$; """ ) for table, function in ( ("kg_entity", "update_kg_entity_name"), ("document", "update_kg_entity_name_from_doc"), ): op.execute(f"DROP TRIGGER IF EXISTS {function}_trigger ON {table}") op.execute(f"DROP FUNCTION IF EXISTS {function}()") # Drop index op.execute("DROP INDEX IF EXISTS idx_kg_entity_clustering_trigrams") op.execute("DROP INDEX IF EXISTS idx_kg_entity_normalization_trigrams") # Drop tables in reverse order of creation to handle dependencies op.drop_table("kg_term") op.drop_table("kg_relationship") op.drop_table("kg_entity") op.drop_table("kg_relationship_type") op.drop_table("kg_relationship_extraction_staging") op.drop_table("kg_relationship_type_extraction_staging") op.drop_table("kg_entity_extraction_staging") op.drop_table("kg_entity_type") op.drop_column("connector", "kg_processing_enabled") op.drop_column("connector", "kg_coverage_days") op.drop_column("document", "kg_stage") op.drop_column("document", "kg_processing_time") op.drop_table("kg_config") # Revoke usage on current schema for the readonly user op.execute( text( f""" DO $$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN EXECUTE format('REVOKE ALL ON SCHEMA %I FROM %I', current_schema(), '{DB_READONLY_USER}'); END IF; END $$; """ ) ) if not MULTI_TENANT: # Drop read-only db user here only in single tenant mode. For multi-tenant mode, # the user is dropped in the alembic_tenants migration. 
op.execute( text( f""" DO $$ BEGIN IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN -- First revoke all privileges from the database EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}'); -- Then drop the user EXECUTE format('DROP USER %I', '{DB_READONLY_USER}'); END IF; END $$; """ ) ) op.execute(text("DROP EXTENSION IF EXISTS pg_trgm")) ================================================ FILE: backend/alembic/versions/4a1e4b1c89d2_add_indexing_to_userfilestatus.py ================================================ """Add INDEXING to UserFileStatus Revision ID: 4a1e4b1c89d2 Revises: 6b3b4083c5aa Create Date: 2026-02-28 00:00:00.000000 """ import sqlalchemy as sa from alembic import op revision = "4a1e4b1c89d2" down_revision = "6b3b4083c5aa" branch_labels = None depends_on = None TABLE = "user_file" COLUMN = "status" CONSTRAINT_NAME = "ck_user_file_status" OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING") NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING") def _drop_status_check_constraint() -> None: """Drop the existing CHECK constraint on user_file.status. The constraint name is auto-generated by SQLAlchemy and unknown, so we look it up via the inspector. 
""" inspector = sa.inspect(op.get_bind()) for constraint in inspector.get_check_constraints(TABLE): if COLUMN in constraint.get("sqltext", ""): constraint_name = constraint["name"] if constraint_name is not None: op.drop_constraint(constraint_name, TABLE, type_="check") def upgrade() -> None: _drop_status_check_constraint() in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES) op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})") def downgrade() -> None: op.execute( f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'" ) op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check") in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES) op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})") ================================================ FILE: backend/alembic/versions/4a951134c801_moved_status_to_connector_credential_.py ================================================ """Moved status to connector credential pair Revision ID: 4a951134c801 Revises: 7477a5f5d728 Create Date: 2024-08-10 19:20:34.527559 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "4a951134c801" down_revision = "7477a5f5d728" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "connector_credential_pair", sa.Column( "status", sa.Enum( "ACTIVE", "PAUSED", "DELETING", name="connectorcredentialpairstatus", native_enum=False, ), nullable=True, ), ) # Update status of connector_credential_pair based on connector's disabled status op.execute( """ UPDATE connector_credential_pair SET status = CASE WHEN ( SELECT disabled FROM connector WHERE connector.id = connector_credential_pair.connector_id ) = FALSE THEN 'ACTIVE' ELSE 'PAUSED' END """ ) # Make the status column not nullable after setting values op.alter_column("connector_credential_pair", "status", nullable=False) op.drop_column("connector", "disabled") def downgrade() -> None: op.add_column( "connector", sa.Column("disabled", sa.BOOLEAN(), autoincrement=False, nullable=True), ) # Update disabled status of connector based on connector_credential_pair's status op.execute( """ UPDATE connector SET disabled = CASE WHEN EXISTS ( SELECT 1 FROM connector_credential_pair WHERE connector_credential_pair.connector_id = connector.id AND connector_credential_pair.status = 'ACTIVE' ) THEN FALSE ELSE TRUE END """ ) # Make the disabled column not nullable after setting values op.alter_column("connector", "disabled", nullable=False) op.drop_column("connector_credential_pair", "status") ================================================ FILE: backend/alembic/versions/4b08d97e175a_change_default_prune_freq.py ================================================ """change default prune_freq Revision ID: 4b08d97e175a Revises: d9ec13955951 Create Date: 2024-08-20 15:28:52.993827 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "4b08d97e175a" down_revision = "d9ec13955951" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.execute( """ UPDATE connector SET prune_freq = 2592000 WHERE prune_freq = 86400 """ ) def downgrade() -> None: op.execute( """ UPDATE connector SET prune_freq = 86400 WHERE prune_freq = 2592000 """ ) ================================================ FILE: backend/alembic/versions/4cebcbc9b2ae_add_tab_index_to_tool_call.py ================================================ """add tab_index to tool_call Revision ID: 4cebcbc9b2ae Revises: a1b2c3d4e5f6 Create Date: 2025-12-16 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "4cebcbc9b2ae" down_revision = "a1b2c3d4e5f6" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "tool_call", sa.Column("tab_index", sa.Integer(), nullable=False, server_default="0"), ) def downgrade() -> None: op.drop_column("tool_call", "tab_index") ================================================ FILE: backend/alembic/versions/4d58345da04a_lowercase_user_emails.py ================================================ """lowercase_user_emails Revision ID: 4d58345da04a Revises: f1ca58b2f2ec Create Date: 2025-01-29 07:48:46.784041 """ import logging from typing import cast from alembic import op from sqlalchemy.exc import IntegrityError from sqlalchemy.sql import text # revision identifiers, used by Alembic. 
revision = "4d58345da04a" down_revision = "f1ca58b2f2ec" branch_labels = None depends_on = None logger = logging.getLogger("alembic.runtime.migration") def upgrade() -> None: """Conflicts on lowercasing will result in the uppercased email getting a unique integer suffix when converted to lowercase.""" connection = op.get_bind() # Fetch all user emails that are not already lowercase user_emails = connection.execute( text('SELECT id, email FROM "user" WHERE email != LOWER(email)') ).fetchall() for user_id, email in user_emails: email = cast(str, email) username, domain = email.rsplit("@", 1) new_email = f"{username.lower()}@{domain.lower()}" attempt = 1 while True: try: # Try updating the email connection.execute( text('UPDATE "user" SET email = :new_email WHERE id = :user_id'), {"new_email": new_email, "user_id": user_id}, ) break # Success, exit loop except IntegrityError: next_email = f"{username.lower()}_{attempt}@{domain.lower()}" # Email conflict occurred, append `_1`, `_2`, etc., to the username logger.warning( f"Conflict while lowercasing email: old_email={email} conflicting_email={new_email} next_email={next_email}" ) new_email = next_email attempt += 1 def downgrade() -> None: # Cannot restore original case of emails pass ================================================ FILE: backend/alembic/versions/4ea2c93919c1_add_type_to_credentials.py ================================================ """Add type to credentials Revision ID: 4ea2c93919c1 Revises: 473a1a7ca408 Create Date: 2024-07-18 13:07:13.655895 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "4ea2c93919c1" down_revision = "473a1a7ca408" branch_labels: None = None depends_on: None = None def upgrade() -> None: # Add the new 'source' column to the 'credential' table op.add_column( "credential", sa.Column( "source", sa.String(length=100), # Use String instead of Enum nullable=True, # Initially allow NULL values ), ) op.add_column( "credential", sa.Column( "name", sa.String(), nullable=True, ), ) # Create a temporary table that maps each credential to a single connector source. # This is needed because a credential can be associated with multiple connectors, # but we want to assign a single source to each credential. # We use DISTINCT ON to ensure we only get one row per credential_id. op.execute( """ CREATE TEMPORARY TABLE temp_connector_credential AS SELECT DISTINCT ON (cc.credential_id) cc.credential_id, c.source AS connector_source FROM connector_credential_pair cc JOIN connector c ON cc.connector_id = c.id """ ) # Update the 'source' column in the 'credential' table op.execute( """ UPDATE credential cred SET source = COALESCE( (SELECT connector_source FROM temp_connector_credential temp WHERE cred.id = temp.credential_id), 'NOT_APPLICABLE' ) """ ) # Drop the temporary table to avoid conflicts if migration runs again # (e.g., during upgrade -> downgrade -> upgrade cycles in tests) op.execute("DROP TABLE IF EXISTS temp_connector_credential") # If no exception was raised, alter the column op.alter_column("credential", "source", nullable=True) # TODO modify # # ### end Alembic commands ### def downgrade() -> None: op.drop_column("credential", "source") op.drop_column("credential", "name") ================================================ FILE: backend/alembic/versions/4ee1287bd26a_add_multiple_slack_bot_support.py ================================================ """add_multiple_slack_bot_support Revision ID: 4ee1287bd26a Revises: 47e5bef3a1d7 Create Date: 2024-11-06 13:15:53.302644 """ from typing import cast from alembic import op import 
sqlalchemy as sa from sqlalchemy.orm import Session from onyx.key_value_store.factory import get_kv_store from onyx.db.models import SlackBot from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "4ee1287bd26a" down_revision = "47e5bef3a1d7" branch_labels: None = None depends_on: None = None def upgrade() -> None: # Create new slack_bot table op.create_table( "slack_bot", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column("enabled", sa.Boolean(), nullable=False, server_default="true"), sa.Column("bot_token", sa.LargeBinary(), nullable=False), sa.Column("app_token", sa.LargeBinary(), nullable=False), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("bot_token"), sa.UniqueConstraint("app_token"), ) # # Create new slack_channel_config table op.create_table( "slack_channel_config", sa.Column("id", sa.Integer(), nullable=False), sa.Column("slack_bot_id", sa.Integer(), nullable=True), sa.Column("persona_id", sa.Integer(), nullable=True), sa.Column("channel_config", postgresql.JSONB(), nullable=False), sa.Column("response_type", sa.String(), nullable=False), sa.Column( "enable_auto_filters", sa.Boolean(), nullable=False, server_default="false" ), sa.ForeignKeyConstraint( ["slack_bot_id"], ["slack_bot.id"], ), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.PrimaryKeyConstraint("id"), ) # Handle existing Slack bot tokens first bot_token = None app_token = None first_row_id = None try: tokens = cast(dict, get_kv_store().load("slack_bot_tokens_config_key")) except Exception: tokens = {} bot_token = tokens.get("bot_token") app_token = tokens.get("app_token") if bot_token and app_token: session = Session(bind=op.get_bind()) new_slack_bot = SlackBot( name="Slack Bot (Migrated)", enabled=True, bot_token=bot_token, app_token=app_token, ) session.add(new_slack_bot) session.commit() first_row_id = new_slack_bot.id # Create a default bot if none exists # This is in case 
there are no slack tokens but there are channels configured op.execute( sa.text( """ INSERT INTO slack_bot (name, enabled, bot_token, app_token) SELECT 'Default Bot', true, '', '' WHERE NOT EXISTS (SELECT 1 FROM slack_bot) RETURNING id; """ ) ) # Get the bot ID to use (either from existing migration or newly created) bot_id_query = sa.text( """ SELECT COALESCE( :first_row_id, (SELECT id FROM slack_bot ORDER BY id ASC LIMIT 1) ) as bot_id; """ ) result = op.get_bind().execute(bot_id_query, {"first_row_id": first_row_id}) bot_id = result.scalar() # CTE (Common Table Expression) that transforms the old slack_bot_config table data # This splits up the channel_names into their own rows channel_names_cte = """ WITH channel_names AS ( SELECT sbc.id as config_id, sbc.persona_id, sbc.response_type, sbc.enable_auto_filters, jsonb_array_elements_text(sbc.channel_config->'channel_names') as channel_name, sbc.channel_config->>'respond_tag_only' as respond_tag_only, sbc.channel_config->>'respond_to_bots' as respond_to_bots, sbc.channel_config->'respond_member_group_list' as respond_member_group_list, sbc.channel_config->'answer_filters' as answer_filters, sbc.channel_config->'follow_up_tags' as follow_up_tags FROM slack_bot_config sbc ) """ # Insert the channel names into the new slack_channel_config table insert_statement = """ INSERT INTO slack_channel_config ( slack_bot_id, persona_id, channel_config, response_type, enable_auto_filters ) SELECT :bot_id, channel_name.persona_id, jsonb_build_object( 'channel_name', channel_name.channel_name, 'respond_tag_only', COALESCE((channel_name.respond_tag_only)::boolean, false), 'respond_to_bots', COALESCE((channel_name.respond_to_bots)::boolean, false), 'respond_member_group_list', COALESCE(channel_name.respond_member_group_list, '[]'::jsonb), 'answer_filters', COALESCE(channel_name.answer_filters, '[]'::jsonb), 'follow_up_tags', COALESCE(channel_name.follow_up_tags, '[]'::jsonb) ), channel_name.response_type, 
channel_name.enable_auto_filters FROM channel_names channel_name; """ op.execute(sa.text(channel_names_cte + insert_statement).bindparams(bot_id=bot_id)) # Clean up old tokens if they existed try: if bot_token and app_token: get_kv_store().delete("slack_bot_tokens_config_key") except Exception: pass # Rename the table op.rename_table( "slack_bot_config__standard_answer_category", "slack_channel_config__standard_answer_category", ) # Rename the column op.alter_column( "slack_channel_config__standard_answer_category", "slack_bot_config_id", new_column_name="slack_channel_config_id", ) # Drop the table with CASCADE to handle dependent objects op.execute("DROP TABLE slack_bot_config CASCADE") def downgrade() -> None: # Recreate the old slack_bot_config table op.create_table( "slack_bot_config", sa.Column("id", sa.Integer(), nullable=False), sa.Column("persona_id", sa.Integer(), nullable=True), sa.Column("channel_config", postgresql.JSONB(), nullable=False), sa.Column("response_type", sa.String(), nullable=False), sa.Column("enable_auto_filters", sa.Boolean(), nullable=False), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.PrimaryKeyConstraint("id"), ) # Migrate data back to the old format # Group by persona_id to combine channel names back into arrays op.execute( sa.text( """ INSERT INTO slack_bot_config ( persona_id, channel_config, response_type, enable_auto_filters ) SELECT DISTINCT ON (persona_id) persona_id, jsonb_build_object( 'channel_names', ( SELECT jsonb_agg(c.channel_config->>'channel_name') FROM slack_channel_config c WHERE c.persona_id = scc.persona_id ), 'respond_tag_only', (channel_config->>'respond_tag_only')::boolean, 'respond_to_bots', (channel_config->>'respond_to_bots')::boolean, 'respond_member_group_list', channel_config->'respond_member_group_list', 'answer_filters', channel_config->'answer_filters', 'follow_up_tags', channel_config->'follow_up_tags' ), response_type, enable_auto_filters FROM slack_channel_config scc WHERE 
persona_id IS NOT NULL; """ ) ) # Rename the table back op.rename_table( "slack_channel_config__standard_answer_category", "slack_bot_config__standard_answer_category", ) # Rename the column back op.alter_column( "slack_bot_config__standard_answer_category", "slack_channel_config_id", new_column_name="slack_bot_config_id", ) # Try to save the first bot's tokens back to KV store try: first_bot = ( op.get_bind() .execute( sa.text( "SELECT bot_token, app_token FROM slack_bot ORDER BY id LIMIT 1" ) ) .first() ) if first_bot and first_bot.bot_token and first_bot.app_token: tokens = { "bot_token": first_bot.bot_token, "app_token": first_bot.app_token, } get_kv_store().store("slack_bot_tokens_config_key", tokens) except Exception: pass # Drop the new tables in reverse order op.drop_table("slack_channel_config") op.drop_table("slack_bot") ================================================ FILE: backend/alembic/versions/4f8a2b3c1d9e_add_open_url_tool.py ================================================ """add_open_url_tool Revision ID: 4f8a2b3c1d9e Revises: a852cbe15577 Create Date: 2025-11-24 12:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "4f8a2b3c1d9e" down_revision = "a852cbe15577" branch_labels = None depends_on = None OPEN_URL_TOOL = { "name": "OpenURLTool", "display_name": "Open URL", "description": ( "The Open URL Action allows the agent to fetch and read contents of web pages." 
), "in_code_tool_id": "OpenURLTool", "enabled": True, } def upgrade() -> None: conn = op.get_bind() # Check if tool already exists existing = conn.execute( sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"), {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]}, ).fetchone() if existing: tool_id = existing[0] # Update existing tool conn.execute( sa.text( """ UPDATE tool SET name = :name, display_name = :display_name, description = :description WHERE in_code_tool_id = :in_code_tool_id """ ), OPEN_URL_TOOL, ) else: # Insert new tool conn.execute( sa.text( """ INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled) VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled) """ ), OPEN_URL_TOOL, ) # Get the newly inserted tool's id result = conn.execute( sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"), {"in_code_tool_id": OPEN_URL_TOOL["in_code_tool_id"]}, ).fetchone() tool_id = result[0] # type: ignore # Associate the tool with all existing personas # Get all persona IDs persona_ids = conn.execute(sa.text("SELECT id FROM persona")).fetchall() for (persona_id,) in persona_ids: # Check if association already exists exists = conn.execute( sa.text( """ SELECT 1 FROM persona__tool WHERE persona_id = :persona_id AND tool_id = :tool_id """ ), {"persona_id": persona_id, "tool_id": tool_id}, ).fetchone() if not exists: conn.execute( sa.text( """ INSERT INTO persona__tool (persona_id, tool_id) VALUES (:persona_id, :tool_id) """ ), {"persona_id": persona_id, "tool_id": tool_id}, ) def downgrade() -> None: # We don't remove the tool on downgrade since it's fine to have it around. # If we upgrade again, it will be a no-op. 
pass ================================================ FILE: backend/alembic/versions/503883791c39_add_effective_permissions.py ================================================ """add_effective_permissions Adds a JSONB column `effective_permissions` to the user table to store directly granted permissions (e.g. ["admin"] or ["basic"]). Implied permissions are expanded at read time, not stored. Backfill: joins user__user_group → permission_grant to collect each user's granted permissions into a JSON array. Users without group memberships keep the default []. Revision ID: 503883791c39 Revises: b4b7e1028dfd Create Date: 2026-03-30 14:49:22.261748 """ from collections.abc import Sequence from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "503883791c39" down_revision = "b4b7e1028dfd" branch_labels: str | None = None depends_on: str | Sequence[str] | None = None user_table = sa.table( "user", sa.column("id", sa.Uuid), sa.column("effective_permissions", postgresql.JSONB), ) user_user_group = sa.table( "user__user_group", sa.column("user_id", sa.Uuid), sa.column("user_group_id", sa.Integer), ) permission_grant = sa.table( "permission_grant", sa.column("group_id", sa.Integer), sa.column("permission", sa.String), sa.column("is_deleted", sa.Boolean), ) def upgrade() -> None: op.add_column( "user", sa.Column( "effective_permissions", postgresql.JSONB(), nullable=False, server_default=sa.text("'[]'::jsonb"), ), ) conn = op.get_bind() # Deduplicated permissions per user deduped = ( sa.select( user_user_group.c.user_id, permission_grant.c.permission, ) .select_from( user_user_group.join( permission_grant, sa.and_( permission_grant.c.group_id == user_user_group.c.user_group_id, permission_grant.c.is_deleted == sa.false(), ), ) ) .distinct() .subquery("deduped") ) # Aggregate into JSONB array per user (order is not guaranteed; # consumers read this as a set so ordering does not matter) 
perms_per_user = ( sa.select( deduped.c.user_id, sa.func.jsonb_agg( deduped.c.permission, type_=postgresql.JSONB, ).label("perms"), ) .group_by(deduped.c.user_id) .subquery("sub") ) conn.execute( user_table.update() .where(user_table.c.id == perms_per_user.c.user_id) .values(effective_permissions=perms_per_user.c.perms) ) def downgrade() -> None: op.drop_column("user", "effective_permissions") ================================================ FILE: backend/alembic/versions/505c488f6662_merge_default_assistants_into_unified.py ================================================ """merge_default_assistants_into_unified Revision ID: 505c488f6662 Revises: d09fc20a3c66 Create Date: 2025-09-09 19:00:56.816626 """ import json from typing import Any from typing import NamedTuple from uuid import UUID from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "505c488f6662" down_revision = "d09fc20a3c66" branch_labels = None depends_on = None # Constants for the unified assistant UNIFIED_ASSISTANT_NAME = "Assistant" UNIFIED_ASSISTANT_DESCRIPTION = ( "Your AI assistant with search, web browsing, and image generation capabilities." ) UNIFIED_ASSISTANT_NUM_CHUNKS = 25 UNIFIED_ASSISTANT_DISPLAY_PRIORITY = 0 UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION = True UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER = False UNIFIED_ASSISTANT_RECENCY_BIAS = "AUTO" # NOTE: needs to be capitalized UNIFIED_ASSISTANT_CHUNKS_ABOVE = 0 UNIFIED_ASSISTANT_CHUNKS_BELOW = 0 UNIFIED_ASSISTANT_DATETIME_AWARE = True # NOTE: tool specific prompts are handled on the fly and automatically injected # into the prompt before passing to the LLM. DEFAULT_SYSTEM_PROMPT = """ You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the \ user's intent, ask clarifying questions when needed, think step-by-step through complex problems, \ provide clear and accurate answers, and proactively anticipate helpful follow-up information. 
Always \ prioritize being truthful, nuanced, insightful, and efficient. The current date is [[CURRENT_DATETIME]] You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make \ your responses more readable and engaging. You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, \ symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline. For code you prefer to use Markdown and specify the language. You can use Markdown horizontal rules (---) to separate sections of your responses. You can use Markdown tables to format your responses for data, lists, and other structured information. """.strip() INSERT_DICT: dict[str, Any] = { "name": UNIFIED_ASSISTANT_NAME, "description": UNIFIED_ASSISTANT_DESCRIPTION, "system_prompt": DEFAULT_SYSTEM_PROMPT, "num_chunks": UNIFIED_ASSISTANT_NUM_CHUNKS, "display_priority": UNIFIED_ASSISTANT_DISPLAY_PRIORITY, "llm_filter_extraction": UNIFIED_ASSISTANT_LLM_FILTER_EXTRACTION, "llm_relevance_filter": UNIFIED_ASSISTANT_LLM_RELEVANCE_FILTER, "recency_bias": UNIFIED_ASSISTANT_RECENCY_BIAS, "chunks_above": UNIFIED_ASSISTANT_CHUNKS_ABOVE, "chunks_below": UNIFIED_ASSISTANT_CHUNKS_BELOW, "datetime_aware": UNIFIED_ASSISTANT_DATETIME_AWARE, } GENERAL_ASSISTANT_ID = -1 ART_ASSISTANT_ID = -3 class UserRow(NamedTuple): """Typed representation of user row from database query.""" id: UUID chosen_assistants: list[int] | None visible_assistants: list[int] | None hidden_assistants: list[int] | None pinned_assistants: list[int] | None def upgrade() -> None: conn = op.get_bind() # Step 1: Create or update the unified assistant (ID 0) search_assistant = conn.execute( sa.text("SELECT * FROM persona WHERE id = 0") ).fetchone() if search_assistant: # Update existing Search assistant to be the unified assistant conn.execute( sa.text( """ UPDATE persona SET name = :name, description = :description, system_prompt = :system_prompt, 
num_chunks = :num_chunks, is_default_persona = true, is_visible = true, deleted = false, display_priority = :display_priority, llm_filter_extraction = :llm_filter_extraction, llm_relevance_filter = :llm_relevance_filter, recency_bias = :recency_bias, chunks_above = :chunks_above, chunks_below = :chunks_below, datetime_aware = :datetime_aware, starter_messages = null WHERE id = 0 """ ), INSERT_DICT, ) else: # Create new unified assistant with ID 0 conn.execute( sa.text( """ INSERT INTO persona ( id, name, description, system_prompt, num_chunks, is_default_persona, is_visible, deleted, display_priority, llm_filter_extraction, llm_relevance_filter, recency_bias, chunks_above, chunks_below, datetime_aware, starter_messages, builtin_persona ) VALUES ( 0, :name, :description, :system_prompt, :num_chunks, true, true, false, :display_priority, :llm_filter_extraction, :llm_relevance_filter, :recency_bias, :chunks_above, :chunks_below, :datetime_aware, null, true ) """ ), INSERT_DICT, ) # Step 2: Mark ALL builtin assistants as deleted (except the unified assistant ID 0) conn.execute( sa.text( """ UPDATE persona SET deleted = true, is_visible = false, is_default_persona = false WHERE builtin_persona = true AND id != 0 """ ) ) # Step 3: Add all built-in tools to the unified assistant # First, get the tool IDs for SearchTool, ImageGenerationTool, and WebSearchTool search_tool = conn.execute( sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'SearchTool'") ).fetchone() if not search_tool: raise ValueError( "SearchTool not found in database. Ensure tools migration has run first." ) image_gen_tool = conn.execute( sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'ImageGenerationTool'") ).fetchone() if not image_gen_tool: raise ValueError( "ImageGenerationTool not found in database. Ensure tools migration has run first." 
        )

    # WebSearchTool is optional - may not be configured
    web_search_tool = conn.execute(
        sa.text("SELECT id FROM tool WHERE in_code_tool_id = 'WebSearchTool'")
    ).fetchone()

    # Clear existing tool associations for persona 0
    conn.execute(sa.text("DELETE FROM persona__tool WHERE persona_id = 0"))

    # Add tools to the unified assistant
    conn.execute(
        sa.text(
            """
            INSERT INTO persona__tool (persona_id, tool_id)
            VALUES (0, :tool_id)
            ON CONFLICT DO NOTHING
            """
        ),
        {"tool_id": search_tool[0]},
    )

    conn.execute(
        sa.text(
            """
            INSERT INTO persona__tool (persona_id, tool_id)
            VALUES (0, :tool_id)
            ON CONFLICT DO NOTHING
            """
        ),
        {"tool_id": image_gen_tool[0]},
    )

    if web_search_tool:
        conn.execute(
            sa.text(
                """
                INSERT INTO persona__tool (persona_id, tool_id)
                VALUES (0, :tool_id)
                ON CONFLICT DO NOTHING
                """
            ),
            {"tool_id": web_search_tool[0]},
        )

    # Step 4: Migrate existing chat sessions from all builtin assistants to unified assistant
    conn.execute(
        sa.text(
            """
            UPDATE chat_session
            SET persona_id = 0
            WHERE persona_id IN (
                SELECT id FROM persona WHERE builtin_persona = true AND id != 0
            )
            """
        )
    )

    # Step 5: Migrate user preferences - remove references to all builtin assistants
    # First, get all builtin assistant IDs (except 0)
    builtin_assistants_result = conn.execute(
        sa.text(
            """
            SELECT id FROM persona
            WHERE builtin_persona = true AND id != 0
            """
        )
    ).fetchall()
    builtin_assistant_ids = [row[0] for row in builtin_assistants_result]

    # Get all users with preferences
    users_result = conn.execute(
        sa.text(
            """
            SELECT id, chosen_assistants, visible_assistants,
                   hidden_assistants, pinned_assistants
            FROM "user"
            """
        )
    ).fetchall()

    for user_row in users_result:
        user = UserRow(*user_row)
        user_id: UUID = user.id
        updates: dict[str, Any] = {}

        # Remove all builtin assistants from chosen_assistants
        if user.chosen_assistants:
            new_chosen: list[int] = [
                assistant_id
                for assistant_id in user.chosen_assistants
                if assistant_id not in builtin_assistant_ids
            ]
            if new_chosen != user.chosen_assistants:
                updates["chosen_assistants"] = json.dumps(new_chosen)

        # Remove all builtin assistants from visible_assistants
        if user.visible_assistants:
            new_visible: list[int] = [
                assistant_id
                for assistant_id in user.visible_assistants
                if assistant_id not in builtin_assistant_ids
            ]
            if new_visible != user.visible_assistants:
                updates["visible_assistants"] = json.dumps(new_visible)

        # Add all builtin assistants to hidden_assistants
        if user.hidden_assistants:
            new_hidden: list[int] = list(user.hidden_assistants)
            for old_id in builtin_assistant_ids:
                if old_id not in new_hidden:
                    new_hidden.append(old_id)
            if new_hidden != user.hidden_assistants:
                updates["hidden_assistants"] = json.dumps(new_hidden)
        else:
            updates["hidden_assistants"] = json.dumps(builtin_assistant_ids)

        # Remove all builtin assistants from pinned_assistants
        if user.pinned_assistants:
            new_pinned: list[int] = [
                assistant_id
                for assistant_id in user.pinned_assistants
                if assistant_id not in builtin_assistant_ids
            ]
            if new_pinned != user.pinned_assistants:
                updates["pinned_assistants"] = json.dumps(new_pinned)

        # Apply updates if any
        if updates:
            set_clause = ", ".join([f"{k} = :{k}" for k in updates.keys()])
            updates["user_id"] = str(user_id)  # Convert UUID to string for SQL
            conn.execute(
                sa.text(f'UPDATE "user" SET {set_clause} WHERE id = :user_id'),
                updates,
            )


def downgrade() -> None:
    conn = op.get_bind()

    # Only restore General (ID -1) and Art (ID -3) assistants
    # Step 1: Keep Search assistant (ID 0) as default but restore original state
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET is_default_persona = true,
                is_visible = true,
                deleted = false
            WHERE id = 0
            """
        )
    )

    # Step 2: Restore General assistant (ID -1)
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET deleted = false,
                is_visible = true,
                is_default_persona = true
            WHERE id = :general_assistant_id
            """
        ),
        {"general_assistant_id": GENERAL_ASSISTANT_ID},
    )

    # Step 3: Restore Art assistant (ID -3)
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET deleted = false,
                is_visible = true,
                is_default_persona = true
            WHERE id = :art_assistant_id
            """
        ),
        {"art_assistant_id": ART_ASSISTANT_ID},
    )

    # Note: We don't restore the original tool associations, names, or descriptions
    # as those would require more complex logic to determine original state.
    # We also cannot restore original chat session persona_ids as we don't
    # have the original mappings.
    # Other builtin assistants remain deleted as per the requirement.

================================================
FILE: backend/alembic/versions/50b683a8295c_add_additional_retrieval_controls_to_.py
================================================
"""Add additional retrieval controls to Persona

Revision ID: 50b683a8295c
Revises: 7da0ae5ad583
Create Date: 2023-11-27 17:23:29.668422

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "50b683a8295c"
down_revision = "7da0ae5ad583"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column("persona", sa.Column("num_chunks", sa.Integer(), nullable=True))
    op.add_column(
        "persona",
        sa.Column("apply_llm_relevance_filter", sa.Boolean(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("persona", "apply_llm_relevance_filter")
    op.drop_column("persona", "num_chunks")

================================================
FILE: backend/alembic/versions/52a219fb5233_add_last_synced_and_last_modified_to_document_table.py
================================================
"""Add last synced and last modified to document table

Revision ID: 52a219fb5233
Revises: f7e58d357687
Create Date: 2024-08-28 17:40:46.077470

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.sql import func

# revision identifiers, used by Alembic.
revision = "52a219fb5233"
down_revision = "f7e58d357687"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # last modified represents the last time anything needing syncing to vespa changed
    # including row metadata and the document itself.
    # This obviously does not include the last_synced column.
    op.add_column(
        "document",
        sa.Column(
            "last_modified",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=func.now(),
        ),
    )

    # last synced represents the last time this document was synced to Vespa
    op.add_column(
        "document",
        sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True),
    )

    # Set last_synced to the same value as last_modified for existing rows
    op.execute(
        """
        UPDATE document
        SET last_synced = last_modified
        """
    )

    op.create_index(
        op.f("ix_document_last_modified"),
        "document",
        ["last_modified"],
        unique=False,
    )

    op.create_index(
        op.f("ix_document_last_synced"),
        "document",
        ["last_synced"],
        unique=False,
    )


def downgrade() -> None:
    op.drop_index(op.f("ix_document_last_synced"), table_name="document")
    op.drop_index(op.f("ix_document_last_modified"), table_name="document")
    op.drop_column("document", "last_synced")
    op.drop_column("document", "last_modified")

================================================
FILE: backend/alembic/versions/54a74a0417fc_danswerbot_onyxbot.py
================================================
"""danswerbot -> onyxbot

Revision ID: 54a74a0417fc
Revises: 94dc3d0236f8
Create Date: 2024-12-11 18:05:05.490737

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "54a74a0417fc"
down_revision = "94dc3d0236f8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column("chat_session", "danswerbot_flow", new_column_name="onyxbot_flow")


def downgrade() -> None:
    op.alter_column("chat_session", "onyxbot_flow", new_column_name="danswerbot_flow")

================================================
FILE: backend/alembic/versions/55546a7967ee_assistant_rework.py
================================================
"""assistant_rework

Revision ID: 55546a7967ee
Revises: 61ff3651add4
Create Date: 2024-09-18 17:00:23.755399

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "55546a7967ee"
down_revision = "61ff3651add4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Reworking persona and user tables for new assistant features
    # keep track of user's chosen assistants separate from their `ordering`
    op.add_column("persona", sa.Column("builtin_persona", sa.Boolean(), nullable=True))
    op.execute("UPDATE persona SET builtin_persona = default_persona")
    op.alter_column("persona", "builtin_persona", nullable=False)
    op.drop_index("_default_persona_name_idx", table_name="persona")
    op.create_index(
        "_builtin_persona_name_idx",
        "persona",
        ["name"],
        unique=True,
        postgresql_where=sa.text("builtin_persona = true"),
    )

    op.add_column(
        "user", sa.Column("visible_assistants", postgresql.JSONB(), nullable=True)
    )
    op.add_column(
        "user", sa.Column("hidden_assistants", postgresql.JSONB(), nullable=True)
    )
    op.execute(
        "UPDATE \"user\" SET visible_assistants = '[]'::jsonb, hidden_assistants = '[]'::jsonb"
    )

    op.alter_column(
        "user",
        "visible_assistants",
        nullable=False,
        server_default=sa.text("'[]'::jsonb"),
    )
    op.alter_column(
        "user",
        "hidden_assistants",
        nullable=False,
        server_default=sa.text("'[]'::jsonb"),
    )
    op.drop_column("persona", "default_persona")
    op.add_column(
        "persona", sa.Column("is_default_persona", sa.Boolean(), nullable=True)
    )


def downgrade() -> None:
    # Reverting changes made in upgrade
    op.drop_column("user", "hidden_assistants")
    op.drop_column("user", "visible_assistants")
    op.drop_index("_builtin_persona_name_idx", table_name="persona")

    op.drop_column("persona", "is_default_persona")
    op.add_column("persona", sa.Column("default_persona", sa.Boolean(), nullable=True))
    op.execute("UPDATE persona SET default_persona = builtin_persona")
    op.alter_column("persona", "default_persona", nullable=False)
    op.drop_column("persona", "builtin_persona")
    op.create_index(
        "_default_persona_name_idx",
        "persona",
        ["name"],
        unique=True,
        postgresql_where=sa.text("default_persona = true"),
    )

================================================
FILE: backend/alembic/versions/570282d33c49_track_onyxbot_explicitly.py
================================================
"""Track Onyxbot Explicitly

Revision ID: 570282d33c49
Revises: 7547d982db8f
Create Date: 2024-05-04 17:49:28.568109

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "570282d33c49"
down_revision = "7547d982db8f"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("danswerbot_flow", sa.Boolean(), nullable=True)
    )
    op.execute("UPDATE chat_session SET danswerbot_flow = one_shot")
    op.alter_column("chat_session", "danswerbot_flow", nullable=False)


def downgrade() -> None:
    op.drop_column("chat_session", "danswerbot_flow")

================================================
FILE: backend/alembic/versions/57122d037335_add_python_tool_on_default.py
================================================
"""add python tool on default

Revision ID: 57122d037335
Revises: c0c937d5c9e5
Create Date: 2026-02-27 10:10:40.124925

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "57122d037335"
down_revision = "c0c937d5c9e5"
branch_labels = None
depends_on = None

PYTHON_TOOL_NAME = "python"


def upgrade() -> None:
    conn = op.get_bind()

    # Look up the PythonTool id
    result = conn.execute(
        sa.text("SELECT id FROM tool WHERE name = :name"),
        {"name": PYTHON_TOOL_NAME},
    ).fetchone()

    if not result:
        return

    tool_id = result[0]

    # Attach to the default persona (id=0) if not already attached
    conn.execute(
        sa.text(
            """
            INSERT INTO persona__tool (persona_id, tool_id)
            VALUES (0, :tool_id)
            ON CONFLICT DO NOTHING
            """
        ),
        {"tool_id": tool_id},
    )


def downgrade() -> None:
    conn = op.get_bind()

    result = conn.execute(
        sa.text("SELECT id FROM tool WHERE name = :name"),
        {"name": PYTHON_TOOL_NAME},
    ).fetchone()

    if not result:
        return

    conn.execute(
        sa.text(
            """
            DELETE FROM persona__tool
            WHERE persona_id = 0 AND tool_id = :tool_id
            """
        ),
        {"tool_id": result[0]},
    )

================================================
FILE: backend/alembic/versions/57b53544726e_add_document_set_tables.py
================================================
"""Add document set tables

Revision ID: 57b53544726e
Revises: 800f48024ae9
Create Date: 2023-09-20 16:59:39.097177

"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "57b53544726e"
down_revision = "800f48024ae9"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "document_set",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column("is_up_to_date", sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )
    op.create_table(
        "document_set__connector_credential_pair",
        sa.Column("document_set_id", sa.Integer(), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("is_current", sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
        ),
        sa.ForeignKeyConstraint(
            ["document_set_id"],
            ["document_set.id"],
        ),
        sa.PrimaryKeyConstraint(
            "document_set_id", "connector_credential_pair_id", "is_current"
        ),
    )


def downgrade() -> None:
    op.drop_table("document_set__connector_credential_pair")
    op.drop_table("document_set")

================================================
FILE: backend/alembic/versions/5809c0787398_add_chat_sessions.py
================================================
"""Add Chat Sessions

Revision ID: 5809c0787398
Revises: d929f0c1c6af
Create Date: 2023-09-04 15:29:44.002164

"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "5809c0787398"
down_revision = "d929f0c1c6af"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "chat_session",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column("description", sa.Text(), nullable=False),
        sa.Column("deleted", sa.Boolean(), nullable=False),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "chat_message",
        sa.Column("chat_session_id", sa.Integer(), nullable=False),
        sa.Column("message_number", sa.Integer(), nullable=False),
        sa.Column("edit_number", sa.Integer(), nullable=False),
        sa.Column("parent_edit_number", sa.Integer(), nullable=True),
        sa.Column("latest", sa.Boolean(), nullable=False),
        sa.Column("message", sa.Text(), nullable=False),
        sa.Column(
            "message_type",
            sa.Enum(
                "SYSTEM",
                "USER",
                "ASSISTANT",
                "DANSWER",
                name="messagetype",
                native_enum=False,
            ),
            nullable=False,
        ),
        sa.Column(
            "time_sent",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["chat_session_id"],
            ["chat_session.id"],
        ),
        sa.PrimaryKeyConstraint("chat_session_id", "message_number", "edit_number"),
    )


def downgrade() -> None:
    op.drop_table("chat_message")
    op.drop_table("chat_session")

================================================
FILE: backend/alembic/versions/58c50ef19f08_add_stale_column_to_user__external_user_.py
================================================
"""add stale column to external user group tables

Revision ID: 58c50ef19f08
Revises: 7b9b952abdf6
Create Date: 2025-06-25 14:08:14.162380

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "58c50ef19f08"
down_revision = "7b9b952abdf6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add the stale column with default value False to user__external_user_group_id
    op.add_column(
        "user__external_user_group_id",
        sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
    )

    # Create index for efficient querying of stale rows by cc_pair_id
    op.create_index(
        "ix_user__external_user_group_id_cc_pair_id_stale",
        "user__external_user_group_id",
        ["cc_pair_id", "stale"],
        unique=False,
    )

    # Create index for efficient querying of all stale rows
    op.create_index(
        "ix_user__external_user_group_id_stale",
        "user__external_user_group_id",
        ["stale"],
        unique=False,
    )

    # Add the stale column with default value False to public_external_user_group
    op.add_column(
        "public_external_user_group",
        sa.Column("stale", sa.Boolean(), nullable=False, server_default="false"),
    )

    # Create index for efficient querying of stale rows by cc_pair_id
    op.create_index(
        "ix_public_external_user_group_cc_pair_id_stale",
        "public_external_user_group",
        ["cc_pair_id", "stale"],
        unique=False,
    )

    # Create index for efficient querying of all stale rows
    op.create_index(
        "ix_public_external_user_group_stale",
        "public_external_user_group",
        ["stale"],
        unique=False,
    )


def downgrade() -> None:
    # Drop the indices for public_external_user_group first
    op.drop_index(
        "ix_public_external_user_group_stale", table_name="public_external_user_group"
    )
    op.drop_index(
        "ix_public_external_user_group_cc_pair_id_stale",
        table_name="public_external_user_group",
    )

    # Drop the stale column from public_external_user_group
    op.drop_column("public_external_user_group", "stale")

    # Drop the indices for user__external_user_group_id
    op.drop_index(
        "ix_user__external_user_group_id_stale",
        table_name="user__external_user_group_id",
    )
    op.drop_index(
        "ix_user__external_user_group_id_cc_pair_id_stale",
        table_name="user__external_user_group_id",
    )

    # Drop the stale column from user__external_user_group_id
    op.drop_column("user__external_user_group_id", "stale")

================================================
FILE: backend/alembic/versions/5ae8240accb3_add_research_agent_database_tables_and_.py
================================================
"""add research agent database tables and chat message research fields

Revision ID: 5ae8240accb3
Revises: b558f51620b4
Create Date: 2025-08-06 14:29:24.691388

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "5ae8240accb3"
down_revision = "b558f51620b4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add research_type and research_plan columns to chat_message table
    op.add_column(
        "chat_message",
        sa.Column("research_type", sa.String(), nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column("research_plan", postgresql.JSONB(), nullable=True),
    )

    # Create research_agent_iteration table
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "primary_question_id",
            "iteration_nr",
            name="_research_agent_iteration_unique_constraint",
        ),
    )

    # Create research_agent_iteration_sub_step table
    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column(
            "primary_question_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column(
            "parent_question_id",
            sa.Integer(),
            sa.ForeignKey("research_agent_iteration_sub_step.id", ondelete="CASCADE"),
            nullable=True,
        ),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column(
            "sub_step_tool_id",
            sa.Integer(),
            sa.ForeignKey("tool.id"),
            nullable=True,
        ),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=True),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["primary_question_id", "iteration_nr"],
            [
                "research_agent_iteration.primary_question_id",
                "research_agent_iteration.iteration_nr",
            ],
            ondelete="CASCADE",
        ),
    )


def downgrade() -> None:
    # Drop tables in reverse order
    op.drop_table("research_agent_iteration_sub_step")
    op.drop_table("research_agent_iteration")

    # Remove columns from chat_message table
    op.drop_column("chat_message", "research_plan")
    op.drop_column("chat_message", "research_type")

================================================
FILE: backend/alembic/versions/5b29123cd710_nullable_search_settings_for_historic_.py
================================================
"""nullable search settings for historic index attempts

Revision ID: 5b29123cd710
Revises: 949b4a92a401
Create Date: 2024-10-30 19:37:59.630704

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "5b29123cd710"
down_revision = "949b4a92a401"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing foreign key constraint
    op.drop_constraint(
        "fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
    )

    # Modify the column to be nullable
    op.alter_column(
        "index_attempt", "search_settings_id", existing_type=sa.INTEGER(), nullable=True
    )

    # Add back the foreign key with ON DELETE SET NULL
    op.create_foreign_key(
        "fk_index_attempt_search_settings",
        "index_attempt",
        "search_settings",
        ["search_settings_id"],
        ["id"],
        ondelete="SET NULL",
    )


def downgrade() -> None:
    # Warning: This will delete all index attempts that don't have search settings
    op.execute(
        """
        DELETE FROM index_attempt
        WHERE search_settings_id IS NULL
        """
    )

    # Drop foreign key constraint
    op.drop_constraint(
        "fk_index_attempt_search_settings", "index_attempt", type_="foreignkey"
    )

    # Modify the column to be not nullable
    op.alter_column(
        "index_attempt",
        "search_settings_id",
        existing_type=sa.INTEGER(),
        nullable=False,
    )

    # Add back the foreign key without ON DELETE SET NULL
    op.create_foreign_key(
        "fk_index_attempt_search_settings",
        "index_attempt",
        "search_settings",
        ["search_settings_id"],
        ["id"],
    )

================================================
FILE: backend/alembic/versions/5c3dca366b35_backend_driven_notification_details.py
================================================
"""backend driven notification details

Revision ID: 5c3dca366b35
Revises: 9087b548dd69
Create Date: 2026-01-06 16:03:11.413724

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5c3dca366b35"
down_revision = "9087b548dd69"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "notification",
        sa.Column(
            "title", sa.String(), nullable=False, server_default="New Notification"
        ),
    )
    op.add_column(
        "notification",
        sa.Column("description", sa.String(), nullable=True, server_default=""),
    )


def downgrade() -> None:
    op.drop_column("notification", "title")
    op.drop_column("notification", "description")

================================================
FILE: backend/alembic/versions/5c448911b12f_add_content_type_to_userfile.py
================================================
"""Add content type to UserFile

Revision ID: 5c448911b12f
Revises: 47a07e1a38f1
Create Date: 2025-04-25 16:59:48.182672

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5c448911b12f"
down_revision = "47a07e1a38f1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column("user_file", sa.Column("content_type", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("user_file", "content_type")

================================================
FILE: backend/alembic/versions/5c7fdadae813_match_any_keywords_flag_for_standard_.py
================================================
"""match_any_keywords flag for standard answers

Revision ID: 5c7fdadae813
Revises: efb35676026c
Create Date: 2024-09-13 18:52:59.256478

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5c7fdadae813"
down_revision = "efb35676026c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.add_column(
        "standard_answer",
        sa.Column(
            "match_any_keywords",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )
    # ### end Alembic commands ###


def downgrade() -> None:
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("standard_answer", "match_any_keywords")
    # ### end Alembic commands ###

================================================
FILE: backend/alembic/versions/5d12a446f5c0_add_api_version_and_deployment_name_to_.py
================================================
"""add api_version and deployment_name to search settings

Revision ID: 5d12a446f5c0
Revises: e4334d5b33ba
Create Date: 2024-10-08 15:56:07.975636

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5d12a446f5c0"
down_revision = "e4334d5b33ba"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "embedding_provider", sa.Column("api_version", sa.String(), nullable=True)
    )
    op.add_column(
        "embedding_provider", sa.Column("deployment_name", sa.String(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("embedding_provider", "deployment_name")
    op.drop_column("embedding_provider", "api_version")

================================================
FILE: backend/alembic/versions/5e1c073d48a3_add_personal_access_token_table.py
================================================
"""add_personal_access_token_table

Revision ID: 5e1c073d48a3
Revises: 09995b8811eb
Create Date: 2025-10-30 17:30:24.308521

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "5e1c073d48a3"
down_revision = "09995b8811eb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create personal_access_token table
    op.create_table(
        "personal_access_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("hashed_token", sa.String(length=64), nullable=False),
        sa.Column("token_display", sa.String(), nullable=False),
        sa.Column(
            "user_id",
            postgresql.UUID(as_uuid=True),
            nullable=False,
        ),
        sa.Column(
            "expires_at",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "last_used_at",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
        sa.Column(
            "is_revoked",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("hashed_token"),
    )

    # Create indexes
    op.create_index(
        "ix_personal_access_token_expires_at",
        "personal_access_token",
        ["expires_at"],
        unique=False,
    )
    op.create_index(
        "ix_pat_user_created",
        "personal_access_token",
        ["user_id", sa.text("created_at DESC")],
        unique=False,
    )


def downgrade() -> None:
    # Drop indexes first
    op.drop_index("ix_pat_user_created", table_name="personal_access_token")
    op.drop_index(
        "ix_personal_access_token_expires_at", table_name="personal_access_token"
    )

    # Drop table
    op.drop_table("personal_access_token")

================================================
FILE: backend/alembic/versions/5e6f7a8b9c0d_update_default_persona_prompt.py
================================================
"""update_default_persona_prompt

Revision ID: 5e6f7a8b9c0d
Revises: 4f8a2b3c1d9e
Create Date: 2025-11-30 12:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5e6f7a8b9c0d"
down_revision = "4f8a2b3c1d9e"
branch_labels = None
depends_on = None


DEFAULT_PERSONA_ID = 0

# ruff: noqa: E501, W605 start
DEFAULT_SYSTEM_PROMPT = """
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.

The current date is [[CURRENT_DATETIME]].{citation_reminder_or_empty}

# Response Style
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
For code you prefer to use Markdown and specify the language.
You can use horizontal rules (---) to separate sections of your responses.
You can use Markdown tables to format your responses for data, lists, and other structured information.
""".lstrip()
# ruff: noqa: E501, W605 end


def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET system_prompt = :system_prompt
            WHERE id = :persona_id
            """
        ),
        {"system_prompt": DEFAULT_SYSTEM_PROMPT, "persona_id": DEFAULT_PERSONA_ID},
    )


def downgrade() -> None:
    # We don't revert the system prompt on downgrade since we don't know
    # what the previous value was. The new prompt is a reasonable default.
    pass

================================================
FILE: backend/alembic/versions/5e84129c8be3_add_docs_indexed_column_to_index_.py
================================================
"""Add docs_indexed_column + time_started to index_attempt table

Revision ID: 5e84129c8be3
Revises: e6a4bbc13fe4
Create Date: 2023-08-10 21:43:09.069523

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5e84129c8be3"
down_revision = "e6a4bbc13fe4"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "index_attempt",
        sa.Column("num_docs_indexed", sa.Integer()),
    )
    op.add_column(
        "index_attempt",
        sa.Column(
            "time_started",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("index_attempt", "time_started")
    op.drop_column("index_attempt", "num_docs_indexed")

================================================
FILE: backend/alembic/versions/5f4b8568a221_add_removed_documents_to_index_attempt.py
================================================
"""add removed documents to index_attempt

Revision ID: 5f4b8568a221
Revises: dbaa756c2ccf
Create Date: 2024-02-16 15:02:03.319907

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "5f4b8568a221"
down_revision = "8987770549c0"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "index_attempt",
        sa.Column("docs_removed_from_index", sa.Integer()),
    )
    op.execute("UPDATE index_attempt SET docs_removed_from_index = 0")


def downgrade() -> None:
    op.drop_column("index_attempt", "docs_removed_from_index")

================================================
FILE: backend/alembic/versions/5fc1f54cc252_hybrid_enum.py
================================================
"""hybrid-enum

Revision ID: 5fc1f54cc252
Revises: 1d6ad76d1f37
Create Date: 2024-08-06 15:35:40.278485

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "5fc1f54cc252"
down_revision = "1d6ad76d1f37"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_column("persona", "search_type")


def downgrade() -> None:
    op.add_column("persona", sa.Column("search_type", sa.String(), nullable=True))
    op.execute("UPDATE persona SET search_type = 'SEMANTIC'")
    op.alter_column("persona", "search_type", nullable=False)


================================================
FILE: backend/alembic/versions/61ff3651add4_add_permission_syncing.py
================================================
"""Add Permission Syncing

Revision ID: 61ff3651add4
Revises: 1b8206b29c5d
Create Date: 2024-09-05 13:57:11.770413

"""

import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "61ff3651add4"
down_revision = "1b8206b29c5d"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Admin user who set up connectors will lose access to the docs temporarily
    # only way currently to give back access is to rerun from beginning
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "access_type",
            sa.String(),
            nullable=True,
        ),
    )
    op.execute(
        "UPDATE connector_credential_pair SET access_type = 'PUBLIC' WHERE is_public = true"
    )
    op.execute(
        "UPDATE connector_credential_pair SET access_type = 'PRIVATE' WHERE is_public = false"
    )
    op.alter_column("connector_credential_pair", "access_type", nullable=False)
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "auto_sync_options",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )
    op.add_column(
        "connector_credential_pair",
        sa.Column("last_time_perm_sync", sa.DateTime(timezone=True), nullable=True),
    )
    op.drop_column("connector_credential_pair", "is_public")
    op.add_column(
        "document",
        sa.Column("external_user_emails", postgresql.ARRAY(sa.String()), nullable=True),
    )
    op.add_column(
        "document",
        sa.Column(
            "external_user_group_ids", postgresql.ARRAY(sa.String()), nullable=True
        ),
    )
    op.add_column(
        "document",
        sa.Column("is_public", sa.Boolean(), nullable=True),
    )
    op.create_table(
        "user__external_user_group_id",
        sa.Column(
            "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False
        ),
        sa.Column("external_user_group_id", sa.String(), nullable=False),
        sa.Column("cc_pair_id", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("user_id"),
    )
    op.drop_column("external_permission", "user_id")
    op.drop_column("email_to_external_user_cache", "user_id")
    op.drop_table("permission_sync_run")
    op.drop_table("external_permission")
    op.drop_table("email_to_external_user_cache")


def downgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column("is_public", sa.BOOLEAN(), nullable=True),
    )
    op.execute(
        "UPDATE connector_credential_pair SET is_public = (access_type = 'PUBLIC')"
    )
    op.alter_column("connector_credential_pair", "is_public", nullable=False)
    op.drop_column("connector_credential_pair", "auto_sync_options")
    op.drop_column("connector_credential_pair", "access_type")
    op.drop_column("connector_credential_pair", "last_time_perm_sync")
    op.drop_column("document", "external_user_emails")
    op.drop_column("document", "external_user_group_ids")
    op.drop_column("document", "is_public")
    op.drop_table("user__external_user_group_id")
    # Drop the enum type at the end of the downgrade
    op.create_table(
        "permission_sync_run",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "source_type",
            sa.String(),
            nullable=False,
        ),
        sa.Column("update_type", sa.String(), nullable=False),
        sa.Column("cc_pair_id", sa.Integer(), nullable=True),
        sa.Column(
            "status",
            sa.String(),
            nullable=False,
        ),
        sa.Column("error_msg", sa.Text(), nullable=True),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["cc_pair_id"],
            ["connector_credential_pair.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "external_permission",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.UUID(), nullable=True),
        sa.Column("user_email", sa.String(), nullable=False),
        sa.Column(
            "source_type",
            sa.String(),
            nullable=False,
        ),
        sa.Column("external_permission_group", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "email_to_external_user_cache",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_user_id", sa.String(), nullable=False),
        sa.Column("user_id", sa.UUID(), nullable=True),
        sa.Column("user_email", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )


================================================
FILE: backend/alembic/versions/62c3a055a141_add_file_names_to_file_connector_config.py
================================================
"""add file names to file connector config

Revision ID: 62c3a055a141
Revises: 3fc5d75723b3
Create Date: 2025-07-30 17:01:24.417551

"""

from alembic import op
import sqlalchemy as sa
import json
import os
import logging

# revision identifiers, used by Alembic.
revision = "62c3a055a141"
down_revision = "3fc5d75723b3"
branch_labels = None
depends_on = None

SKIP_FILE_NAME_MIGRATION = (
    os.environ.get("SKIP_FILE_NAME_MIGRATION", "true").lower() == "true"
)

logger = logging.getLogger("alembic.runtime.migration")


def upgrade() -> None:
    if SKIP_FILE_NAME_MIGRATION:
        logger.info(
            "Skipping file name migration. Hint: set SKIP_FILE_NAME_MIGRATION=false to run this migration"
        )
        return
    logger.info("Running file name migration")
    # Get connection
    conn = op.get_bind()
    # Get all FILE connectors with their configs
    file_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'FILE'
            """
        )
    ).fetchall()
    for connector_id, config in file_connectors:
        # Parse config if it's a string
        if isinstance(config, str):
            config = json.loads(config)
        # Get file_locations list
        file_locations = config.get("file_locations", [])
        # Get display names for each file_id
        file_names = []
        for file_id in file_locations:
            result = conn.execute(
                sa.text(
                    """
                    SELECT display_name
                    FROM file_record
                    WHERE file_id = :file_id
                    """
                ),
                {"file_id": file_id},
            ).fetchone()
            if result:
                file_names.append(result[0])
            else:
                file_names.append(file_id)  # Should not happen
        # Add file_names to config
        new_config = dict(config)
        new_config["file_names"] = file_names
        # Update the connector
        conn.execute(
            sa.text(
                """
                UPDATE connector
                SET connector_specific_config = :new_config
                WHERE id = :connector_id
                """
            ),
            {"connector_id": connector_id, "new_config": json.dumps(new_config)},
        )


def downgrade() -> None:
    # Get connection
    conn = op.get_bind()
    # Remove file_names from all FILE connectors
    file_connectors = conn.execute(
        sa.text(
            """
            SELECT id, connector_specific_config
            FROM connector
            WHERE source = 'FILE'
            """
        )
    ).fetchall()
    for connector_id, config in file_connectors:
        # Parse config if it's a string
        if isinstance(config, str):
            config = json.loads(config)
        # Remove file_names if it exists
        if "file_names" in config:
            new_config = dict(config)
            del new_config["file_names"]
            # Update the connector
            conn.execute(
                sa.text(
                    """
                    UPDATE connector
                    SET connector_specific_config = :new_config
                    WHERE id = :connector_id
                    """
                ),
                {
                    "connector_id": connector_id,
                    "new_config": json.dumps(new_config),
                },
            )


================================================
FILE: backend/alembic/versions/631fd2504136_add_approx_chunk_count_in_vespa_to_.py
================================================
"""add approx_chunk_count_in_vespa to opensearch tenant migration

Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column(
            "approx_chunk_count_in_vespa",
            sa.Integer(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")


================================================
FILE: backend/alembic/versions/6436661d5b65_add_created_at_in_project_userfile.py
================================================
"""add_created_at_in_project_userfile

Revision ID: 6436661d5b65
Revises: c7e9f4a3b2d1
Create Date: 2025-11-24 11:50:24.536052

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "6436661d5b65"
down_revision = "c7e9f4a3b2d1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add created_at column to project__user_file table
    op.add_column(
        "project__user_file",
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
    )
    # Add composite index on (project_id, created_at DESC)
    op.create_index(
        "ix_project__user_file_project_id_created_at",
        "project__user_file",
        ["project_id", sa.text("created_at DESC")],
    )


def downgrade() -> None:
    # Remove composite index on (project_id, created_at)
    op.drop_index(
        "ix_project__user_file_project_id_created_at", table_name="project__user_file"
    )
    # Remove created_at column from project__user_file table
    op.drop_column("project__user_file", "created_at")


================================================
FILE: backend/alembic/versions/643a84a42a33_add_user_configured_names_to_llmprovider.py
================================================
"""Add user-configured names to LLMProvider

Revision ID: 643a84a42a33
Revises: 0a98909f2757
Create Date: 2024-05-07 14:54:55.493100

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "643a84a42a33"
down_revision = "0a98909f2757"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column("llm_provider", sa.Column("provider", sa.String(), nullable=True))
    # move "name" -> "provider" to match the new schema
    op.execute("UPDATE llm_provider SET provider = name")
    # pretty up display name
    op.execute("UPDATE llm_provider SET name = 'OpenAI' WHERE name = 'openai'")
    op.execute("UPDATE llm_provider SET name = 'Anthropic' WHERE name = 'anthropic'")
    op.execute("UPDATE llm_provider SET name = 'Azure OpenAI' WHERE name = 'azure'")
    op.execute("UPDATE llm_provider SET name = 'AWS Bedrock' WHERE name = 'bedrock'")
    # update personas to use the new provider names
    op.execute(
        "UPDATE persona SET llm_model_provider_override = 'OpenAI' WHERE llm_model_provider_override = 'openai'"
    )
    op.execute(
        "UPDATE persona SET llm_model_provider_override = 'Anthropic' WHERE llm_model_provider_override = 'anthropic'"
    )
    op.execute(
        "UPDATE persona SET llm_model_provider_override = 'Azure OpenAI' WHERE llm_model_provider_override = 'azure'"
    )
    op.execute(
        "UPDATE persona SET llm_model_provider_override = 'AWS Bedrock' WHERE llm_model_provider_override = 'bedrock'"
    )


def downgrade() -> None:
    op.execute("UPDATE llm_provider SET name = provider")
    op.drop_column("llm_provider", "provider")


================================================
FILE: backend/alembic/versions/64bd5677aeb6_add_image_input_support_to_model_config.py
================================================
"""Add image input support to model config

Revision ID: 64bd5677aeb6
Revises: b30353be4eec
Create Date: 2025-09-28 15:48:12.003612

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "64bd5677aeb6"
down_revision = "b30353be4eec"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "model_configuration",
        sa.Column("supports_image_input", sa.Boolean(), nullable=True),
    )
    # Seems to be left over from when model visibility was introduced and a nullable field.
    # Set any null is_visible values to False
    connection = op.get_bind()
    connection.execute(
        sa.text(
            "UPDATE model_configuration SET is_visible = false WHERE is_visible IS NULL"
        )
    )


def downgrade() -> None:
    op.drop_column("model_configuration", "supports_image_input")


================================================
FILE: backend/alembic/versions/65bc6e0f8500_remove_kg_subtype_from_db.py
================================================
"""remove kg subtype from db

Revision ID: 65bc6e0f8500
Revises: cec7ec36c505
Create Date: 2025-06-13 10:04:27.705976

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "65bc6e0f8500"
down_revision = "cec7ec36c505"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("kg_entity", "entity_class")
    op.drop_column("kg_entity", "entity_subtype")
    op.drop_column("kg_entity_extraction_staging", "entity_class")
    op.drop_column("kg_entity_extraction_staging", "entity_subtype")


def downgrade() -> None:
    op.add_column(
        "kg_entity_extraction_staging",
        sa.Column("entity_subtype", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "kg_entity_extraction_staging",
        sa.Column("entity_class", sa.String(), nullable=True, index=True),
    )
    op.add_column(
        "kg_entity", sa.Column("entity_subtype", sa.String(), nullable=True, index=True)
    )
    op.add_column(
        "kg_entity", sa.Column("entity_class", sa.String(), nullable=True, index=True)
    )


================================================
FILE: backend/alembic/versions/6756efa39ada_id_uuid_for_chat_session.py
================================================
"""Migrate chat_session and chat_message tables to use UUID primary keys

Revision ID: 6756efa39ada
Revises: 5d12a446f5c0
Create Date: 2024-10-15 17:47:44.108537

"""

from alembic import op
import sqlalchemy as sa

revision = "6756efa39ada"
down_revision = "5d12a446f5c0"
branch_labels = None
depends_on = None

"""
This script:
1. Adds UUID columns to chat_session and chat_message
2. Populates new columns with UUIDs
3. Updates foreign key relationships
4. Removes old integer ID columns

Note: Downgrade will assign new integer IDs, not restore original ones.
"""


def upgrade() -> None:
    op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto;")
    op.add_column(
        "chat_session",
        sa.Column(
            "new_id",
            sa.UUID(as_uuid=True),
            server_default=sa.text("gen_random_uuid()"),
            nullable=False,
        ),
    )
    op.execute("UPDATE chat_session SET new_id = gen_random_uuid();")
    op.add_column(
        "chat_message",
        sa.Column("new_chat_session_id", sa.UUID(as_uuid=True), nullable=True),
    )
    op.execute(
        """
        UPDATE chat_message
        SET new_chat_session_id = cs.new_id
        FROM chat_session cs
        WHERE chat_message.chat_session_id = cs.id;
        """
    )
    op.drop_constraint(
        "chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
    )
    op.drop_column("chat_message", "chat_session_id")
    op.alter_column(
        "chat_message", "new_chat_session_id", new_column_name="chat_session_id"
    )
    op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
    op.drop_column("chat_session", "id")
    op.alter_column("chat_session", "new_id", new_column_name="id")
    op.create_primary_key("chat_session_pkey", "chat_session", ["id"])
    op.create_foreign_key(
        "chat_message_chat_session_id_fkey",
        "chat_message",
        "chat_session",
        ["chat_session_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    op.drop_constraint(
        "chat_message_chat_session_id_fkey", "chat_message", type_="foreignkey"
    )
    op.add_column(
        "chat_session",
        sa.Column("old_id", sa.Integer, autoincrement=True, nullable=True),
    )
    op.execute("CREATE SEQUENCE chat_session_old_id_seq OWNED BY chat_session.old_id;")
    op.execute(
        "ALTER TABLE chat_session ALTER COLUMN old_id SET DEFAULT nextval('chat_session_old_id_seq');"
    )
    op.execute(
        "UPDATE chat_session SET old_id = nextval('chat_session_old_id_seq') WHERE old_id IS NULL;"
    )
    op.alter_column("chat_session", "old_id", nullable=False)
    op.drop_constraint("chat_session_pkey", "chat_session", type_="primary")
    op.create_primary_key("chat_session_pkey", "chat_session", ["old_id"])
    op.add_column(
        "chat_message",
        sa.Column("old_chat_session_id", sa.Integer, nullable=True),
    )
    op.execute(
        """
        UPDATE chat_message
        SET old_chat_session_id = cs.old_id
        FROM chat_session cs
        WHERE chat_message.chat_session_id = cs.id;
        """
    )
    op.drop_column("chat_message", "chat_session_id")
    op.alter_column(
        "chat_message", "old_chat_session_id", new_column_name="chat_session_id"
    )
    op.create_foreign_key(
        "chat_message_chat_session_id_fkey",
        "chat_message",
        "chat_session",
        ["chat_session_id"],
        ["old_id"],
        ondelete="CASCADE",
    )
    op.drop_column("chat_session", "id")
    op.alter_column("chat_session", "old_id", new_column_name="id")
    op.alter_column(
        "chat_session",
        "id",
        type_=sa.Integer(),
        existing_type=sa.Integer(),
        existing_nullable=False,
        existing_server_default=False,
    )
    # Rename the sequence
    op.execute("ALTER SEQUENCE chat_session_old_id_seq RENAME TO chat_session_id_seq;")
    # Update the default value to use the renamed sequence
    op.alter_column(
        "chat_session",
        "id",
        server_default=sa.text("nextval('chat_session_id_seq'::regclass)"),
    )


================================================
FILE: backend/alembic/versions/689433b0d8de_add_hook_and_hook_execution_log_tables.py
================================================
"""add_hook_and_hook_execution_log_tables

Revision ID: 689433b0d8de
Revises: 93a2e195e25c
Create Date: 2026-03-13 11:25:06.547474

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import UUID as PGUUID

# revision identifiers, used by Alembic.
revision = "689433b0d8de"
down_revision = "93a2e195e25c"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "hook",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column(
            "hook_point",
            sa.Enum("document_ingestion", "query_processing", native_enum=False),
            nullable=False,
        ),
        sa.Column("endpoint_url", sa.Text(), nullable=True),
        sa.Column("api_key", sa.LargeBinary(), nullable=True),
        sa.Column("is_reachable", sa.Boolean(), nullable=True),
        sa.Column(
            "fail_strategy",
            sa.Enum("hard", "soft", native_enum=False),
            nullable=False,
        ),
        sa.Column("timeout_seconds", sa.Float(), nullable=False),
        sa.Column(
            "is_active", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column(
            "deleted", sa.Boolean(), nullable=False, server_default=sa.text("false")
        ),
        sa.Column("creator_id", PGUUID(as_uuid=True), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["creator_id"], ["user.id"], ondelete="SET NULL"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(
        "ix_hook_one_non_deleted_per_point",
        "hook",
        ["hook_point"],
        unique=True,
        postgresql_where=sa.text("deleted = false"),
    )
    op.create_table(
        "hook_execution_log",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("hook_id", sa.Integer(), nullable=False),
        sa.Column(
            "is_success",
            sa.Boolean(),
            nullable=False,
        ),
        sa.Column("error_message", sa.Text(), nullable=True),
        sa.Column("status_code", sa.Integer(), nullable=True),
        sa.Column("duration_ms", sa.Integer(), nullable=True),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["hook_id"], ["hook.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index("ix_hook_execution_log_hook_id", "hook_execution_log", ["hook_id"])
    op.create_index(
        "ix_hook_execution_log_created_at", "hook_execution_log", ["created_at"]
    )


def downgrade() -> None:
    op.drop_index("ix_hook_execution_log_created_at", table_name="hook_execution_log")
    op.drop_index("ix_hook_execution_log_hook_id", table_name="hook_execution_log")
    op.drop_table("hook_execution_log")
    op.drop_index("ix_hook_one_non_deleted_per_point", table_name="hook")
    op.drop_table("hook")


================================================
FILE: backend/alembic/versions/699221885109_nullify_default_task_prompt.py
================================================
"""nullify_default_task_prompt

Revision ID: 699221885109
Revises: 7e490836d179
Create Date: 2025-12-30 10:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "699221885109"
down_revision = "7e490836d179"
branch_labels = None
depends_on = None

DEFAULT_PERSONA_ID = 0


def upgrade() -> None:
    # Make task_prompt column nullable
    # Note: The model had nullable=True but the DB column was NOT NULL until this point
    op.alter_column(
        "persona",
        "task_prompt",
        nullable=True,
    )
    # Set task_prompt to NULL for the default persona
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET task_prompt = NULL
            WHERE id = :persona_id
            """
        ),
        {"persona_id": DEFAULT_PERSONA_ID},
    )


def downgrade() -> None:
    # Restore task_prompt to empty string for the default persona
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET task_prompt = ''
            WHERE id = :persona_id AND task_prompt IS NULL
            """
        ),
        {"persona_id": DEFAULT_PERSONA_ID},
    )
    # Set any remaining NULL task_prompts to empty string before making non-nullable
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET task_prompt = ''
            WHERE task_prompt IS NULL
            """
        )
    )
    # Revert task_prompt column to not nullable
    op.alter_column(
        "persona",
        "task_prompt",
        nullable=False,
    )


================================================
FILE: backend/alembic/versions/6a804aeb4830_duplicated_no_harm_user_file_migration.py
================================================
"""duplicated no-harm user file migration

Revision ID: 6a804aeb4830
Revises: 8e1ac4f39a9f
Create Date: 2025-04-01 07:26:10.539362

"""

# revision identifiers, used by Alembic.
revision = "6a804aeb4830"
down_revision = "8e1ac4f39a9f"
branch_labels = None
depends_on = None

# Leaving this around only because some people might be on this migration
# originally was a duplicate of the user files migration


def upgrade() -> None:
    pass


def downgrade() -> None:
    pass


================================================
FILE: backend/alembic/versions/6b3b4083c5aa_persona_cleanup_and_featured.py
================================================
"""persona cleanup and featured

Revision ID: 6b3b4083c5aa
Revises: 57122d037335
Create Date: 2026-02-26 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "6b3b4083c5aa"
down_revision = "57122d037335"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add featured column with nullable=True first
    op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
    # Migrate data from is_default_persona to featured
    op.execute("UPDATE persona SET featured = is_default_persona")
    # Make featured non-nullable with default=False
    op.alter_column(
        "persona",
        "featured",
        existing_type=sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    )
    # Drop is_default_persona column
    op.drop_column("persona", "is_default_persona")
    # Drop unused columns
    op.drop_column("persona", "num_chunks")
    op.drop_column("persona", "chunks_above")
    op.drop_column("persona", "chunks_below")
    op.drop_column("persona", "llm_relevance_filter")
    op.drop_column("persona", "llm_filter_extraction")
    op.drop_column("persona", "recency_bias")


def downgrade() -> None:
    # Add back recency_bias column
    op.add_column(
        "persona",
        sa.Column(
            "recency_bias",
            sa.VARCHAR(),
            nullable=False,
            server_default="base_decay",
        ),
    )
    # Add back llm_filter_extraction column
    op.add_column(
        "persona",
        sa.Column(
            "llm_filter_extraction",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )
    # Add back llm_relevance_filter column
    op.add_column(
        "persona",
        sa.Column(
            "llm_relevance_filter",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )
    # Add back chunks_below column
    op.add_column(
        "persona",
        sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
    )
    # Add back chunks_above column
    op.add_column(
        "persona",
        sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
    )
    # Add back num_chunks column
    op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
    # Add back is_default_persona column
    op.add_column(
        "persona",
        sa.Column(
            "is_default_persona",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )
    # Migrate data from featured to is_default_persona
    op.execute("UPDATE persona SET is_default_persona = featured")
    # Drop featured column
    op.drop_column("persona", "featured")


================================================
FILE: backend/alembic/versions/6d387b3196c2_basic_auth.py
================================================
"""Basic Auth

Revision ID: 6d387b3196c2
Revises: 47433d30de82
Create Date: 2023-05-05 14:40:10.242502

"""

import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "6d387b3196c2"
down_revision = "47433d30de82"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "user",
        sa.Column("id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False),
        sa.Column("email", sa.String(length=320), nullable=False),
        sa.Column("hashed_password", sa.String(length=1024), nullable=False),
        sa.Column("is_active", sa.Boolean(), nullable=False),
        sa.Column("is_superuser", sa.Boolean(), nullable=False),
        sa.Column("is_verified", sa.Boolean(), nullable=False),
        sa.Column(
            "role",
            sa.Enum("BASIC", "ADMIN", name="userrole", native_enum=False),
            default="BASIC",
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_index(op.f("ix_user_email"), "user", ["email"], unique=True)
    op.create_table(
        "accesstoken",
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column("token", sa.String(length=43), nullable=False),
        sa.Column(
            "created_at",
            fastapi_users_db_sqlalchemy.generics.TIMESTAMPAware(timezone=True),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="cascade"),
        sa.PrimaryKeyConstraint("token"),
    )
    op.create_index(
        op.f("ix_accesstoken_created_at"),
        "accesstoken",
        ["created_at"],
        unique=False,
    )
    op.alter_column(
        "index_attempt",
        "time_created",
        existing_type=postgresql.TIMESTAMP(timezone=True),
        nullable=False,
        existing_server_default=sa.text("now()"),  # type: ignore
    )
    op.alter_column(
        "index_attempt",
        "time_updated",
        existing_type=postgresql.TIMESTAMP(timezone=True),
        nullable=False,
    )


def downgrade() -> None:
    op.alter_column(
        "index_attempt",
        "time_updated",
        existing_type=postgresql.TIMESTAMP(timezone=True),
        nullable=True,
    )
    op.alter_column(
        "index_attempt",
        "time_created",
        existing_type=postgresql.TIMESTAMP(timezone=True),
        nullable=True,
        existing_server_default=sa.text("now()"),  # type: ignore
    )
    op.drop_index(op.f("ix_accesstoken_created_at"), table_name="accesstoken")
    op.drop_table("accesstoken")
    op.drop_index(op.f("ix_user_email"), table_name="user")
    op.drop_table("user")


================================================
FILE: backend/alembic/versions/6d562f86c78b_remove_default_bot.py
================================================
"""remove default bot

Revision ID: 6d562f86c78b
Revises: 177de57c21c9
Create Date: 2024-11-22 11:51:29.331336

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "6d562f86c78b"
down_revision = "177de57c21c9"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute(
        sa.text(
            """
            DELETE FROM slack_bot
            WHERE name = 'Default Bot'
            AND bot_token = ''
            AND app_token = ''
            AND NOT EXISTS (
                SELECT 1 FROM slack_channel_config
                WHERE slack_channel_config.slack_bot_id = slack_bot.id
            )
            """
        )
    )


def downgrade() -> None:
    op.execute(
        sa.text(
            """
            INSERT INTO slack_bot (name, enabled, bot_token, app_token)
            SELECT 'Default Bot', true, '', ''
            WHERE NOT EXISTS (SELECT 1 FROM slack_bot)
            RETURNING id;
            """
        )
    )


================================================
FILE: backend/alembic/versions/6f4f86aef280_add_queries_and_is_web_fetch_to_.py
================================================
"""add queries and is web fetch to iteration answer

Revision ID: 6f4f86aef280
Revises: 03d710ccf29c
Create Date: 2025-10-14 18:08:30.920123

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "6f4f86aef280" down_revision = "03d710ccf29c" branch_labels = None depends_on = None def upgrade() -> None: # Add is_web_fetch column op.add_column( "research_agent_iteration_sub_step", sa.Column("is_web_fetch", sa.Boolean(), nullable=True), ) # Add queries column op.add_column( "research_agent_iteration_sub_step", sa.Column("queries", postgresql.JSONB(), nullable=True), ) def downgrade() -> None: op.drop_column("research_agent_iteration_sub_step", "queries") op.drop_column("research_agent_iteration_sub_step", "is_web_fetch") ================================================ FILE: backend/alembic/versions/6fc7886d665d_make_categories_labels_and_many_to_many.py ================================================ """make categories labels and many to many Revision ID: 6fc7886d665d Revises: 3c6531f32351 Create Date: 2025-01-13 18:12:18.029112 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "6fc7886d665d" down_revision = "3c6531f32351" branch_labels = None depends_on = None def upgrade() -> None: # Rename persona_category table to persona_label op.rename_table("persona_category", "persona_label") # Create the new association table op.create_table( "persona__persona_label", sa.Column("persona_id", sa.Integer(), nullable=False), sa.Column("persona_label_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.ForeignKeyConstraint( ["persona_label_id"], ["persona_label.id"], ondelete="CASCADE", ), sa.PrimaryKeyConstraint("persona_id", "persona_label_id"), ) # Copy existing relationships to the new table op.execute( """ INSERT INTO persona__persona_label (persona_id, persona_label_id) SELECT id, category_id FROM persona WHERE category_id IS NOT NULL """ ) # Remove the old category_id column from persona table op.drop_column("persona", "category_id") def downgrade() -> None: # Rename persona_label table back to persona_category op.rename_table("persona_label", 
"persona_category") # Add back the category_id column to persona table op.add_column("persona", sa.Column("category_id", sa.Integer(), nullable=True)) op.create_foreign_key( "persona_category_id_fkey", "persona", "persona_category", ["category_id"], ["id"], ) # Copy the first label relationship back to the persona table op.execute( """ UPDATE persona SET category_id = ( SELECT persona_label_id FROM persona__persona_label WHERE persona__persona_label.persona_id = persona.id LIMIT 1 ) """ ) # Drop the association table op.drop_table("persona__persona_label") ================================================ FILE: backend/alembic/versions/703313b75876_add_tokenratelimit_tables.py ================================================ """Add TokenRateLimit Tables Revision ID: 703313b75876 Revises: fad14119fb92 Create Date: 2024-04-15 01:36:02.952809 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "703313b75876" down_revision = "fad14119fb92" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "token_rate_limit", sa.Column("id", sa.Integer(), nullable=False), sa.Column("enabled", sa.Boolean(), nullable=False), sa.Column("token_budget", sa.Integer(), nullable=False), sa.Column("period_hours", sa.Integer(), nullable=False), sa.Column( "scope", sa.String(length=10), nullable=False, ), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "token_rate_limit__user_group", sa.Column("rate_limit_id", sa.Integer(), nullable=False), sa.Column("user_group_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["rate_limit_id"], ["token_rate_limit.id"], ), sa.ForeignKeyConstraint( ["user_group_id"], ["user_group.id"], ), sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"), ) # NOTE: rate limit settings used to be stored in the "token_budget_settings" key in the # KeyValueStore. 
This will now be lost. The KV store works differently than it used to # so the migration is fairly complicated and likely not worth it to support (pretty much # nobody will have it set) def downgrade() -> None: op.drop_table("token_rate_limit__user_group") op.drop_table("token_rate_limit") ================================================ FILE: backend/alembic/versions/70f00c45c0f2_more_descriptive_filestore.py ================================================ """More Descriptive Filestore Revision ID: 70f00c45c0f2 Revises: 3879338f8ba1 Create Date: 2024-05-17 17:51:41.926893 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "70f00c45c0f2" down_revision = "3879338f8ba1" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("file_store", sa.Column("display_name", sa.String(), nullable=True)) op.add_column( "file_store", sa.Column( "file_origin", sa.String(), nullable=False, server_default="connector", # Default to connector ), ) op.add_column( "file_store", sa.Column( "file_type", sa.String(), nullable=False, server_default="text/plain" ), ) op.add_column( "file_store", sa.Column( "file_metadata", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) op.execute( """ UPDATE file_store SET file_origin = CASE WHEN file_name LIKE 'chat__%' THEN 'chat_upload' ELSE 'connector' END, file_name = CASE WHEN file_name LIKE 'chat__%' THEN SUBSTR(file_name, 7) ELSE file_name END, file_type = CASE WHEN file_name LIKE 'chat__%' THEN 'image/png' ELSE 'text/plain' END """ ) def downgrade() -> None: op.drop_column("file_store", "file_metadata") op.drop_column("file_store", "file_type") op.drop_column("file_store", "file_origin") op.drop_column("file_store", "display_name") ================================================ FILE: backend/alembic/versions/7206234e012a_add_image_generation_config_table.py ================================================ 
"""add image generation config table Revision ID: 7206234e012a Revises: 699221885109 Create Date: 2025-12-21 00:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "7206234e012a" down_revision = "699221885109" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "image_generation_config", sa.Column("image_provider_id", sa.String(), primary_key=True), sa.Column("model_configuration_id", sa.Integer(), nullable=False), sa.Column("is_default", sa.Boolean(), nullable=False), sa.ForeignKeyConstraint( ["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE", ), ) op.create_index( "ix_image_generation_config_is_default", "image_generation_config", ["is_default"], unique=False, ) op.create_index( "ix_image_generation_config_model_configuration_id", "image_generation_config", ["model_configuration_id"], unique=False, ) def downgrade() -> None: op.drop_index( "ix_image_generation_config_model_configuration_id", table_name="image_generation_config", ) op.drop_index( "ix_image_generation_config_is_default", table_name="image_generation_config" ) op.drop_table("image_generation_config") ================================================ FILE: backend/alembic/versions/72aa7de2e5cf_make_processing_mode_default_all_caps.py ================================================ """make processing mode default all caps Revision ID: 72aa7de2e5cf Revises: 2020d417ec84 Create Date: 2026-01-26 18:58:47.705253 This migration fixes the ProcessingMode enum value mismatch: - SQLAlchemy's Enum with native_enum=False uses enum member NAMES as valid values - The original migration stored lowercase VALUES ('regular', 'file_system') - This converts existing data to uppercase NAMES ('REGULAR', 'FILE_SYSTEM') - Also drops any spurious native PostgreSQL enum type that may have been auto-created """ from alembic import op # revision identifiers, used by Alembic. 
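# Illustrative sketch only (not part of this migration's logic). It assumes a hypothetical
# replica of ProcessingMode with member NAMES "REGULAR"/"FILE_SYSTEM" and lowercase VALUES
# "regular"/"file_system", as the docstring above describes. By default SQLAlchemy's Enum
# type persists member names, so rows holding the lowercase values must be rewritten; passing
# values_callable=lambda cls: [m.value for m in cls] to sa.Enum would persist values instead.
# The names _ProcessingModeSketch and _value_to_name_updates are hypothetical.
import enum


class _ProcessingModeSketch(str, enum.Enum):
    REGULAR = "regular"
    FILE_SYSTEM = "file_system"


def _value_to_name_updates(
    enum_cls: type[enum.Enum], table: str, column: str
) -> list[str]:
    """Build one UPDATE per member, rewriting its stored VALUE to its NAME.

    Interpolation is only acceptable here because the table, column, names, and
    values are trusted identifiers and literals defined in code, never user input.
    """
    return [
        f"UPDATE {table} SET {column} = '{member.name}' "
        f"WHERE {column} = '{member.value}'"
        for member in enum_cls
        if member.name != member.value
    ]


# For the hypothetical replica, this yields the same two statements that upgrade() runs:
# _value_to_name_updates(_ProcessingModeSketch, "connector_credential_pair", "processing_mode")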
revision = "72aa7de2e5cf" down_revision = "2020d417ec84" branch_labels = None depends_on = None def upgrade() -> None: # Convert existing lowercase values to uppercase to match enum member names op.execute( "UPDATE connector_credential_pair SET processing_mode = 'REGULAR' WHERE processing_mode = 'regular'" ) op.execute( "UPDATE connector_credential_pair SET processing_mode = 'FILE_SYSTEM' WHERE processing_mode = 'file_system'" ) # Update the server default to use uppercase op.alter_column( "connector_credential_pair", "processing_mode", server_default="REGULAR", ) def downgrade() -> None: # The prior state was broken, so we intentionally do not revert to it pass ================================================ FILE: backend/alembic/versions/72bdc9929a46_permission_auto_sync_framework.py ================================================ """Permission Auto Sync Framework Revision ID: 72bdc9929a46 Revises: 475fcefe8826 Create Date: 2024-04-14 21:15:28.659634 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "72bdc9929a46" down_revision = "475fcefe8826" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "email_to_external_user_cache", sa.Column("id", sa.Integer(), nullable=False), sa.Column("external_user_id", sa.String(), nullable=False), sa.Column("user_id", sa.UUID(), nullable=True), sa.Column("user_email", sa.String(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "external_permission", sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", sa.UUID(), nullable=True), sa.Column("user_email", sa.String(), nullable=False), sa.Column( "source_type", sa.String(), nullable=False, ), sa.Column("external_permission_group", sa.String(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "permission_sync_run", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "source_type", sa.String(), nullable=False, ), sa.Column("update_type", sa.String(), nullable=False), sa.Column("cc_pair_id", sa.Integer(), nullable=True), sa.Column( "status", sa.String(), nullable=False, ), sa.Column("error_msg", sa.Text(), nullable=True), sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["cc_pair_id"], ["connector_credential_pair.id"], ), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("permission_sync_run") op.drop_table("external_permission") op.drop_table("email_to_external_user_cache") ================================================ FILE: backend/alembic/versions/73e9983e5091_add_search_query_table.py ================================================ """add_search_query_table Revision ID: 73e9983e5091 Revises: d1b637d7050a Create Date: 2026-01-14 14:16:52.837489 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, 
used by Alembic. revision = "73e9983e5091" down_revision = "d1b637d7050a" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "search_query", sa.Column("id", postgresql.UUID(as_uuid=True), primary_key=True), sa.Column( "user_id", postgresql.UUID(as_uuid=True), sa.ForeignKey("user.id"), nullable=False, ), sa.Column("query", sa.String(), nullable=False), sa.Column("query_expansions", postgresql.ARRAY(sa.String()), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.func.now(), ), ) op.create_index("ix_search_query_user_id", "search_query", ["user_id"]) op.create_index("ix_search_query_created_at", "search_query", ["created_at"]) def downgrade() -> None: op.drop_index("ix_search_query_created_at", table_name="search_query") op.drop_index("ix_search_query_user_id", table_name="search_query") op.drop_table("search_query") ================================================ FILE: backend/alembic/versions/7477a5f5d728_added_model_defaults_for_users.py ================================================ """Added model defaults for users Revision ID: 7477a5f5d728 Revises: 213fd978c6d8 Create Date: 2024-08-04 19:00:04.512634 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "7477a5f5d728" down_revision = "213fd978c6d8" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("user", sa.Column("default_model", sa.Text(), nullable=True)) def downgrade() -> None: op.drop_column("user", "default_model") ================================================ FILE: backend/alembic/versions/7547d982db8f_chat_folders.py ================================================ """Chat Folders Revision ID: 7547d982db8f Revises: ef7da92f7213 Create Date: 2024-05-02 15:18:56.573347 """ from alembic import op import sqlalchemy as sa import fastapi_users_db_sqlalchemy # revision identifiers, used by Alembic. 
revision = "7547d982db8f" down_revision = "ef7da92f7213" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "chat_folder", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column("name", sa.String(), nullable=True), sa.Column("display_priority", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.add_column("chat_session", sa.Column("folder_id", sa.Integer(), nullable=True)) op.create_foreign_key( "chat_session_chat_folder_fk", "chat_session", "chat_folder", ["folder_id"], ["id"], ) def downgrade() -> None: bind = op.get_bind() inspector = sa.inspect(bind) if "chat_session" in inspector.get_table_names(): chat_session_fks = { fk.get("name") for fk in inspector.get_foreign_keys("chat_session") } if "chat_session_chat_folder_fk" in chat_session_fks: op.drop_constraint( "chat_session_chat_folder_fk", "chat_session", type_="foreignkey" ) chat_session_columns = { col["name"] for col in inspector.get_columns("chat_session") } if "folder_id" in chat_session_columns: op.drop_column("chat_session", "folder_id") if "chat_folder" in inspector.get_table_names(): op.drop_table("chat_folder") ================================================ FILE: backend/alembic/versions/7616121f6e97_add_enterprise_fields_to_scim_user_mapping.py ================================================ """add enterprise and name fields to scim_user_mapping Revision ID: 7616121f6e97 Revises: 07b98176f1de Create Date: 2026-02-23 12:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "7616121f6e97" down_revision = "07b98176f1de" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "scim_user_mapping", sa.Column("department", sa.String(), nullable=True), ) op.add_column( "scim_user_mapping", sa.Column("manager", sa.String(), nullable=True), ) op.add_column( "scim_user_mapping", sa.Column("given_name", sa.String(), nullable=True), ) op.add_column( "scim_user_mapping", sa.Column("family_name", sa.String(), nullable=True), ) op.add_column( "scim_user_mapping", sa.Column("scim_emails_json", sa.Text(), nullable=True), ) def downgrade() -> None: op.drop_column("scim_user_mapping", "scim_emails_json") op.drop_column("scim_user_mapping", "family_name") op.drop_column("scim_user_mapping", "given_name") op.drop_column("scim_user_mapping", "manager") op.drop_column("scim_user_mapping", "department") ================================================ FILE: backend/alembic/versions/767f1c2a00eb_count_chat_tokens.py ================================================ """Count Chat Tokens Revision ID: 767f1c2a00eb Revises: dba7f71618f5 Create Date: 2023-09-21 10:03:21.509899 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "767f1c2a00eb" down_revision = "dba7f71618f5" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_message", sa.Column("token_count", sa.Integer(), nullable=False) ) def downgrade() -> None: op.drop_column("chat_message", "token_count") ================================================ FILE: backend/alembic/versions/76b60d407dfb_cc_pair_name_not_unique.py ================================================ """CC-Pair Name not Unique Revision ID: 76b60d407dfb Revises: b156fa702355 Create Date: 2023-12-22 21:42:10.018804 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "76b60d407dfb" down_revision = "b156fa702355" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.execute("DELETE FROM connector_credential_pair WHERE name IS NULL") op.drop_constraint( "connector_credential_pair__name__key", "connector_credential_pair", type_="unique", ) op.alter_column( "connector_credential_pair", "name", existing_type=sa.String(), nullable=False ) def downgrade() -> None: op.create_unique_constraint( "connector_credential_pair__name__key", "connector_credential_pair", ["name"] ) op.alter_column( "connector_credential_pair", "name", existing_type=sa.String(), nullable=True ) ================================================ FILE: backend/alembic/versions/776b3bbe9092_remove_remaining_enums.py ================================================ """Remove Remaining Enums Revision ID: 776b3bbe9092 Revises: 4738e4b3bae1 Create Date: 2024-03-22 21:34:27.629444 """ from alembic import op import sqlalchemy as sa from onyx.db.models import IndexModelStatus from onyx.context.search.enums import RecencyBiasSetting, SearchType # revision identifiers, used by Alembic. 
revision = "776b3bbe9092" down_revision = "4738e4b3bae1" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.alter_column( "persona", "search_type", type_=sa.String, existing_type=sa.Enum(SearchType, native_enum=False), existing_nullable=False, ) op.alter_column( "persona", "recency_bias", type_=sa.String, existing_type=sa.Enum(RecencyBiasSetting, native_enum=False), existing_nullable=False, ) # Because the indexmodelstatus enum does not have a mapping to a string type # we need this workaround instead of directly changing the type op.add_column("embedding_model", sa.Column("temp_status", sa.String)) op.execute("UPDATE embedding_model SET temp_status = status::text") op.drop_column("embedding_model", "status") op.alter_column("embedding_model", "temp_status", new_column_name="status") op.execute("DROP TYPE IF EXISTS searchtype") op.execute("DROP TYPE IF EXISTS recencybiassetting") op.execute("DROP TYPE IF EXISTS indexmodelstatus") def downgrade() -> None: op.alter_column( "persona", "search_type", type_=sa.Enum(SearchType, native_enum=False), existing_type=sa.String(length=50), existing_nullable=False, ) op.alter_column( "persona", "recency_bias", type_=sa.Enum(RecencyBiasSetting, native_enum=False), existing_type=sa.String(length=50), existing_nullable=False, ) op.alter_column( "embedding_model", "status", type_=sa.Enum(IndexModelStatus, native_enum=False), existing_type=sa.String(length=50), existing_nullable=False, ) ================================================ FILE: backend/alembic/versions/77d07dffae64_forcibly_remove_more_enum_types_from_.py ================================================ """forcibly remove more enum types from postgres Revision ID: 77d07dffae64 Revises: d61e513bef0a Create Date: 2023-11-01 12:33:01.999617 """ from alembic import op from sqlalchemy import String # revision identifiers, used by Alembic. 
revision = "77d07dffae64" down_revision = "d61e513bef0a" branch_labels: None = None depends_on: None = None def upgrade() -> None: # In a PR: # https://github.com/onyx-dot-app/onyx/pull/397/files#diff-f05fb341f6373790b91852579631b64ca7645797a190837156a282b67e5b19c2 # we directly changed some previous migrations. This caused some users to have native enums # while others did not, which caused issues when adding new values to these enums. # This migration converts these columns to plain strings and drops the enum types so that nobody uses native enums. op.alter_column("query_event", "selected_search_flow", type_=String) op.alter_column("query_event", "feedback", type_=String) op.alter_column("document_retrieval_feedback", "feedback", type_=String) op.execute("DROP TYPE IF EXISTS searchtype") op.execute("DROP TYPE IF EXISTS qafeedbacktype") op.execute("DROP TYPE IF EXISTS searchfeedbacktype") def downgrade() -> None: # Native enums are intentionally not restored, so there is nothing to do pass ================================================ FILE: backend/alembic/versions/78dbe7e38469_task_tracking.py ================================================ """Task Tracking Revision ID: 78dbe7e38469 Revises: 7ccea01261f6 Create Date: 2023-10-15 23:40:50.593262 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "78dbe7e38469" down_revision = "7ccea01261f6" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "task_queue_jobs", sa.Column("id", sa.Integer(), nullable=False), sa.Column("task_id", sa.String(), nullable=False), sa.Column("task_name", sa.String(), nullable=False), sa.Column( "status", sa.Enum( "PENDING", "STARTED", "SUCCESS", "FAILURE", name="taskstatus", native_enum=False, ), nullable=False, ), sa.Column("start_time", sa.DateTime(timezone=True), nullable=True), sa.Column( "register_time", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("task_queue_jobs") ================================================ FILE: backend/alembic/versions/78ebc66946a0_remove_reranking_from_search_settings.py ================================================ """remove reranking from search_settings Revision ID: 78ebc66946a0 Revises: 849b21c732f8 Create Date: 2026-01-28 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "78ebc66946a0" down_revision = "849b21c732f8" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.drop_column("search_settings", "disable_rerank_for_streaming") op.drop_column("search_settings", "rerank_model_name") op.drop_column("search_settings", "rerank_provider_type") op.drop_column("search_settings", "rerank_api_key") op.drop_column("search_settings", "rerank_api_url") op.drop_column("search_settings", "num_rerank") def downgrade() -> None: op.add_column( "search_settings", sa.Column( "disable_rerank_for_streaming", sa.Boolean(), nullable=False, server_default="false", ), ) op.add_column( "search_settings", sa.Column("rerank_model_name", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column("rerank_provider_type", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column("rerank_api_key", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True) ) op.add_column( "search_settings", sa.Column( "num_rerank", sa.Integer(), nullable=False, server_default=str(20), ), ) ================================================ FILE: backend/alembic/versions/795b20b85b4b_add_llm_group_permissions_control.py ================================================ """add_llm_group_permissions_control Revision ID: 795b20b85b4b Revises: 05c07bf07c00 Create Date: 2024-07-19 11:54:35.701558 """ from alembic import op import sqlalchemy as sa revision = "795b20b85b4b" down_revision = "05c07bf07c00" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "llm_provider__user_group", sa.Column("llm_provider_id", sa.Integer(), nullable=False), sa.Column("user_group_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["llm_provider_id"], ["llm_provider.id"], ), sa.ForeignKeyConstraint( ["user_group_id"], ["user_group.id"], ), sa.PrimaryKeyConstraint("llm_provider_id", "user_group_id"), ) op.add_column( "llm_provider", 
sa.Column("is_public", sa.Boolean(), nullable=False, server_default="true"), ) def downgrade() -> None: op.drop_table("llm_provider__user_group") op.drop_column("llm_provider", "is_public") ================================================ FILE: backend/alembic/versions/797089dfb4d2_persona_start_date.py ================================================ """persona_start_date Revision ID: 797089dfb4d2 Revises: 55546a7967ee Create Date: 2024-09-11 14:51:49.785835 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "797089dfb4d2" down_revision = "55546a7967ee" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "persona", sa.Column("search_start_date", sa.DateTime(timezone=True), nullable=True), ) def downgrade() -> None: op.drop_column("persona", "search_start_date") ================================================ FILE: backend/alembic/versions/79acd316403a_add_api_key_table.py ================================================ """Add api_key table Revision ID: 79acd316403a Revises: 904e5138fffb Create Date: 2024-01-11 17:56:37.934381 """ from alembic import op import fastapi_users_db_sqlalchemy import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "79acd316403a" down_revision = "904e5138fffb" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "api_key", sa.Column("id", sa.Integer(), nullable=False), sa.Column("hashed_api_key", sa.String(), nullable=False), sa.Column("api_key_display", sa.String(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False, ), sa.Column( "owner_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("api_key_display"), sa.UniqueConstraint("hashed_api_key"), ) def downgrade() -> None: op.drop_table("api_key") ================================================ FILE: backend/alembic/versions/7a70b7664e37_add_model_configuration_table.py ================================================ """Add model-configuration table Revision ID: 7a70b7664e37 Revises: d961aca62eb3 Create Date: 2025-04-10 15:00:35.984669 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql from onyx.llm.well_known_providers.llm_provider_options import ( fetch_model_names_for_provider_as_set, fetch_visible_model_names_for_provider_as_set, ) # revision identifiers, used by Alembic. revision = "7a70b7664e37" down_revision = "d961aca62eb3" branch_labels = None depends_on = None def _resolve( provider_name: str, model_names: list[str] | None, display_model_names: list[str] | None, default_model_name: str, fast_default_model_name: str | None, ) -> set[tuple[str, bool]]: models = set(model_names) if model_names else None display_models = set(display_model_names) if display_model_names else None # If both are defined, we need to make sure that `model_names` is a superset of `display_model_names`. 
if models and display_models: models = display_models.union(models) # If only `model_names` is defined, then: # - If default-model-names are available for the `provider_name`, then set `display_model_names` to it # and set `model_names` to the union of those default-model-names with `model_names`. # - If no default-model-names are available, then set `display_models` to `models`. # # This preserves the invariant that `display_models` is a subset of `models`. elif models and not display_models: visible_default_models = fetch_visible_model_names_for_provider_as_set( provider_name=provider_name ) if visible_default_models: display_models = set(visible_default_models) models = display_models.union(models) else: display_models = set(models) # If only the `display_model_names` are defined, then set `models` to the union of `display_model_names` # and the default-model-names for that provider. # # This will also preserve the invariant that `display_models` is a subset of `models`. elif not models and display_models: default_models = fetch_model_names_for_provider_as_set( provider_name=provider_name ) if default_models: models = display_models.union(default_models) else: models = set(display_models) # If neither is defined, then set `models` and `display_models` to the default-model-names for the given provider. # # This will also preserve the invariant that `display_models` is a subset of `models`. else: default_models = fetch_model_names_for_provider_as_set( provider_name=provider_name ) visible_default_models = fetch_visible_model_names_for_provider_as_set( provider_name=provider_name ) if default_models: if not visible_default_models: raise RuntimeError( "If `default_models` is non-empty, `visible_default_models` must be non-empty too." ) models = default_models display_models = visible_default_models # This is not a well-known llm-provider; we can't provide any model suggestions.
# Therefore, we set both to the empty set and continue else: models = set() display_models = set() # It is possible that `default_model_name` is not in `models` and is not in `display_models`. # It is also possible that `fast_default_model_name` is not in `models` and is not in `display_models`. models.add(default_model_name) if fast_default_model_name: models.add(fast_default_model_name) display_models.add(default_model_name) if fast_default_model_name: display_models.add(fast_default_model_name) return {(model, model in display_models) for model in models} def upgrade() -> None: op.create_table( "model_configuration", sa.Column("id", sa.Integer(), nullable=False), sa.Column("llm_provider_id", sa.Integer(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column("is_visible", sa.Boolean(), nullable=False), sa.Column("max_input_tokens", sa.Integer(), nullable=True), sa.ForeignKeyConstraint( ["llm_provider_id"], ["llm_provider.id"], ondelete="CASCADE" ), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("llm_provider_id", "name"), ) # Create temporary sqlalchemy references to tables for data migration llm_provider_table = sa.sql.table( "llm_provider", sa.column("id", sa.Integer), sa.column("provider", sa.String), sa.column("model_names", postgresql.ARRAY(sa.String)), sa.column("display_model_names", postgresql.ARRAY(sa.String)), sa.column("default_model_name", sa.String), sa.column("fast_default_model_name", sa.String), ) model_configuration_table = sa.sql.table( "model_configuration", sa.column("id", sa.Integer), sa.column("llm_provider_id", sa.Integer), sa.column("name", sa.String), sa.column("is_visible", sa.Boolean), sa.column("max_input_tokens", sa.Integer), ) connection = op.get_bind() llm_providers = connection.execute( sa.select( llm_provider_table.c.id, llm_provider_table.c.provider, llm_provider_table.c.model_names, llm_provider_table.c.display_model_names, llm_provider_table.c.default_model_name,
llm_provider_table.c.fast_default_model_name, ) ).fetchall() for llm_provider in llm_providers: provider_id = llm_provider[0] provider_name = llm_provider[1] model_names = llm_provider[2] display_model_names = llm_provider[3] default_model_name = llm_provider[4] fast_default_model_name = llm_provider[5] model_configurations = _resolve( provider_name=provider_name, model_names=model_names, display_model_names=display_model_names, default_model_name=default_model_name, fast_default_model_name=fast_default_model_name, ) for model_name, is_visible in model_configurations: connection.execute( model_configuration_table.insert().values( llm_provider_id=provider_id, name=model_name, is_visible=is_visible, max_input_tokens=None, ) ) op.drop_column("llm_provider", "model_names") op.drop_column("llm_provider", "display_model_names") def downgrade() -> None: llm_provider = sa.table( "llm_provider", sa.column("id", sa.Integer), sa.column("model_names", postgresql.ARRAY(sa.String)), sa.column("display_model_names", postgresql.ARRAY(sa.String)), ) model_configuration = sa.table( "model_configuration", sa.column("id", sa.Integer), sa.column("llm_provider_id", sa.Integer), sa.column("name", sa.String), sa.column("is_visible", sa.Boolean), sa.column("max_input_tokens", sa.Integer), ) op.add_column( "llm_provider", sa.Column( "model_names", postgresql.ARRAY(sa.VARCHAR()), autoincrement=False, nullable=True, ), ) op.add_column( "llm_provider", sa.Column( "display_model_names", postgresql.ARRAY(sa.VARCHAR()), autoincrement=False, nullable=True, ), ) connection = op.get_bind() provider_ids = connection.execute(sa.select(llm_provider.c.id)).fetchall() for (provider_id,) in provider_ids: # Get all models for this provider models = connection.execute( sa.select( model_configuration.c.name, model_configuration.c.is_visible ).where(model_configuration.c.llm_provider_id == provider_id) ).fetchall() all_models = [model[0] for model in models] visible_models = [model[0] for model in models if 
model[1]] # Update provider with arrays op.execute( llm_provider.update() .where(llm_provider.c.id == provider_id) .values(model_names=all_models, display_model_names=visible_models) ) op.drop_table("model_configuration") ================================================ FILE: backend/alembic/versions/7aea705850d5_added_slack_auto_filter.py ================================================ """added slack_auto_filter Revision ID: 7aea705850d5 Revises: 4505fd7302e1 Create Date: 2024-07-10 11:01:23.581015 """ from alembic import op import sqlalchemy as sa revision = "7aea705850d5" down_revision = "4505fd7302e1" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "slack_bot_config", sa.Column("enable_auto_filters", sa.Boolean(), nullable=True), ) op.execute( "UPDATE slack_bot_config SET enable_auto_filters = FALSE WHERE enable_auto_filters IS NULL" ) op.alter_column( "slack_bot_config", "enable_auto_filters", existing_type=sa.Boolean(), nullable=False, server_default=sa.false(), ) def downgrade() -> None: op.drop_column("slack_bot_config", "enable_auto_filters") ================================================ FILE: backend/alembic/versions/7b9b952abdf6_update_entities.py ================================================ """update-entities Revision ID: 7b9b952abdf6 Revises: 36e9220ab794 Create Date: 2025-06-23 20:24:08.139201 """ import json from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
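# Illustrative sketch only (not part of this migration's logic): the per-row attribute rewrite
# that upgrade() and downgrade() below apply to kg_entity_type, extracted into pure functions
# so it can be checked without a database. The function names are hypothetical; the keys
# ("metadata_attributes", "metadata_attribute_conversion", "name", "keep") come from the
# migration. Unlike the migration, which mutates the fetched dict in place, these return copies.
from typing import Any


def _upgrade_attributes(
    attributes: dict[str, Any], conversion: dict[str, Any]
) -> dict[str, Any]:
    """Drop the legacy metadata_attributes map and attach the new conversion table."""
    updated = dict(attributes)
    updated.pop("metadata_attributes", None)
    updated["metadata_attribute_conversion"] = conversion
    return updated


def _downgrade_attributes(attributes: dict[str, Any]) -> dict[str, Any]:
    """Rebuild the legacy {source_attribute: new_name} map from the conversion table."""
    restored = dict(attributes)
    conversion = restored.pop("metadata_attribute_conversion", {})
    # Only attributes with keep=True survive, so the round trip is lossy: entries such as
    # JIRA's "reporter_email" (keep=False) are not restored on downgrade.
    restored["metadata_attributes"] = {
        attr: prop["name"] for attr, prop in conversion.items() if prop["keep"]
    }
    return restored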
revision = "7b9b952abdf6" down_revision = "36e9220ab794" branch_labels = None depends_on = None def upgrade() -> None: conn = op.get_bind() # new entity type metadata_attribute_conversion new_entity_type_conversion = { "LINEAR": { "team": {"name": "team", "keep": True, "implication_property": None}, "state": {"name": "state", "keep": True, "implication_property": None}, "priority": { "name": "priority", "keep": True, "implication_property": None, }, "estimate": { "name": "estimate", "keep": True, "implication_property": None, }, "created_at": { "name": "created_at", "keep": True, "implication_property": None, }, "started_at": { "name": "started_at", "keep": True, "implication_property": None, }, "completed_at": { "name": "completed_at", "keep": True, "implication_property": None, }, "due_date": { "name": "due_date", "keep": True, "implication_property": None, }, "creator": { "name": "creator", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_creator_of", }, }, "assignee": { "name": "assignee", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_assignee_of", }, }, }, "JIRA": { "issuetype": { "name": "subtype", "keep": True, "implication_property": None, }, "status": {"name": "status", "keep": True, "implication_property": None}, "priority": { "name": "priority", "keep": True, "implication_property": None, }, "project_name": { "name": "project", "keep": True, "implication_property": None, }, "created": { "name": "created_at", "keep": True, "implication_property": None, }, "updated": { "name": "updated_at", "keep": True, "implication_property": None, }, "resolution_date": { "name": "completed_at", "keep": True, "implication_property": None, }, "duedate": {"name": "due_date", "keep": True, "implication_property": None}, "reporter_email": { "name": "creator", "keep": False, "implication_property": { "implied_entity_type": "from_email", 
"implied_relationship_name": "is_creator_of", }, }, "assignee_email": { "name": "assignee", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_assignee_of", }, }, "key": {"name": "key", "keep": True, "implication_property": None}, "parent": {"name": "parent", "keep": True, "implication_property": None}, }, "GITHUB_PR": { "repo": {"name": "repository", "keep": True, "implication_property": None}, "state": {"name": "state", "keep": True, "implication_property": None}, "num_commits": { "name": "num_commits", "keep": True, "implication_property": None, }, "num_files_changed": { "name": "num_files_changed", "keep": True, "implication_property": None, }, "labels": {"name": "labels", "keep": True, "implication_property": None}, "merged": {"name": "merged", "keep": True, "implication_property": None}, "merged_at": { "name": "merged_at", "keep": True, "implication_property": None, }, "closed_at": { "name": "closed_at", "keep": True, "implication_property": None, }, "created_at": { "name": "created_at", "keep": True, "implication_property": None, }, "updated_at": { "name": "updated_at", "keep": True, "implication_property": None, }, "user": { "name": "creator", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_creator_of", }, }, "assignees": { "name": "assignees", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_assignee_of", }, }, }, "GITHUB_ISSUE": { "repo": {"name": "repository", "keep": True, "implication_property": None}, "state": {"name": "state", "keep": True, "implication_property": None}, "labels": {"name": "labels", "keep": True, "implication_property": None}, "closed_at": { "name": "closed_at", "keep": True, "implication_property": None, }, "created_at": { "name": "created_at", "keep": True, "implication_property": None, }, "updated_at": { "name": "updated_at", "keep": True, 
"implication_property": None, }, "user": { "name": "creator", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_creator_of", }, }, "assignees": { "name": "assignees", "keep": False, "implication_property": { "implied_entity_type": "from_email", "implied_relationship_name": "is_assignee_of", }, }, }, "FIREFLIES": {}, "ACCOUNT": {}, "OPPORTUNITY": { "name": {"name": "name", "keep": True, "implication_property": None}, "stage_name": {"name": "stage", "keep": True, "implication_property": None}, "type": {"name": "type", "keep": True, "implication_property": None}, "amount": {"name": "amount", "keep": True, "implication_property": None}, "fiscal_year": { "name": "fiscal_year", "keep": True, "implication_property": None, }, "fiscal_quarter": { "name": "fiscal_quarter", "keep": True, "implication_property": None, }, "is_closed": { "name": "is_closed", "keep": True, "implication_property": None, }, "close_date": { "name": "close_date", "keep": True, "implication_property": None, }, "probability": { "name": "close_probability", "keep": True, "implication_property": None, }, "created_date": { "name": "created_at", "keep": True, "implication_property": None, }, "last_modified_date": { "name": "updated_at", "keep": True, "implication_property": None, }, "account": { "name": "account", "keep": False, "implication_property": { "implied_entity_type": "ACCOUNT", "implied_relationship_name": "is_account_of", }, }, }, "VENDOR": {}, "EMPLOYEE": {}, } current_entity_types = conn.execute( sa.text("SELECT id_name, attributes from kg_entity_type") ).all() for entity_type, attributes in current_entity_types: # delete removed entity types if entity_type not in new_entity_type_conversion: op.execute( sa.text(f"DELETE FROM kg_entity_type WHERE id_name = '{entity_type}'") ) continue # update entity type attributes if "metadata_attributes" in attributes: del attributes["metadata_attributes"] 
attributes["metadata_attribute_conversion"] = new_entity_type_conversion[ entity_type ] attributes_str = json.dumps(attributes).replace("'", "''") op.execute( sa.text( f"UPDATE kg_entity_type SET attributes = '{attributes_str}' WHERE id_name = '{entity_type}'" ), ) def downgrade() -> None: conn = op.get_bind() current_entity_types = conn.execute( sa.text("SELECT id_name, attributes from kg_entity_type") ).all() for entity_type, attributes in current_entity_types: conversion = {} if "metadata_attribute_conversion" in attributes: conversion = attributes.pop("metadata_attribute_conversion") attributes["metadata_attributes"] = { attr: prop["name"] for attr, prop in conversion.items() if prop["keep"] } attributes_str = json.dumps(attributes).replace("'", "''") op.execute( sa.text( f"UPDATE kg_entity_type SET attributes = '{attributes_str}' WHERE id_name = '{entity_type}'" ), ) ================================================ FILE: backend/alembic/versions/7bd55f264e1b_add_display_name_to_model_configuration.py ================================================ """Add display_name to model_configuration Revision ID: 7bd55f264e1b Revises: e8f0d2a38171 Create Date: 2025-12-04 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "7bd55f264e1b" down_revision = "e8f0d2a38171" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "model_configuration", sa.Column("display_name", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("model_configuration", "display_name") ================================================ FILE: backend/alembic/versions/7cb492013621_code_interpreter_server_model.py ================================================ """code interpreter server model Revision ID: 7cb492013621 Revises: 0bb4558f35df Create Date: 2026-02-22 18:54:54.007265 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "7cb492013621" down_revision = "0bb4558f35df" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "code_interpreter_server", sa.Column("id", sa.Integer, primary_key=True), sa.Column( "server_enabled", sa.Boolean, nullable=False, server_default=sa.true() ), ) def downgrade() -> None: op.drop_table("code_interpreter_server") ================================================ FILE: backend/alembic/versions/7cc3fcc116c1_user_file_uuid_primary_key_swap.py ================================================ """Migration 4: User file UUID primary key swap Revision ID: 7cc3fcc116c1 Revises: 16c37a30adf2 Create Date: 2025-09-22 09:54:38.292952 This migration performs the critical UUID primary key swap on the user_file table. It updates all foreign key references to use UUIDs instead of integers. """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql as psql import logging logger = logging.getLogger("alembic.runtime.migration") # revision identifiers, used by Alembic.
revision = "7cc3fcc116c1" down_revision = "16c37a30adf2" branch_labels = None depends_on = None def upgrade() -> None: """Swap user_file primary key from integer to UUID.""" bind = op.get_bind() inspector = sa.inspect(bind) # Verify we're in the expected state user_file_columns = [col["name"] for col in inspector.get_columns("user_file")] if "new_id" not in user_file_columns: logger.warning( "user_file.new_id not found - migration may have already been applied" ) return logger.info("Starting UUID primary key swap...") # === Step 1: Update persona__user_file foreign key to UUID === logger.info("Updating persona__user_file foreign key...") # Drop existing foreign key constraints op.execute( "ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_uuid_fkey" ) op.execute( "ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey" ) # Create new foreign key to user_file.new_id op.create_foreign_key( "persona__user_file_user_file_id_fkey", "persona__user_file", "user_file", local_cols=["user_file_id_uuid"], remote_cols=["new_id"], ) # Drop the old integer column and rename UUID column op.execute("ALTER TABLE persona__user_file DROP COLUMN IF EXISTS user_file_id") op.alter_column( "persona__user_file", "user_file_id_uuid", new_column_name="user_file_id", existing_type=psql.UUID(as_uuid=True), nullable=False, ) # Recreate composite primary key op.execute( "ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_pkey" ) op.execute( "ALTER TABLE persona__user_file ADD PRIMARY KEY (persona_id, user_file_id)" ) logger.info("Updated persona__user_file to use UUID foreign key") # === Step 2: Perform the primary key swap on user_file === logger.info("Swapping user_file primary key to UUID...") # Drop the primary key constraint op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS user_file_pkey") # Drop the old id column and rename new_id to id op.execute("ALTER TABLE user_file DROP 
COLUMN IF EXISTS id") op.alter_column( "user_file", "new_id", new_column_name="id", existing_type=psql.UUID(as_uuid=True), nullable=False, ) # Set default for new inserts op.alter_column( "user_file", "id", existing_type=psql.UUID(as_uuid=True), server_default=sa.text("gen_random_uuid()"), ) # Create new primary key op.execute("ALTER TABLE user_file ADD PRIMARY KEY (id)") logger.info("Swapped user_file primary key to UUID") # === Step 3: Update foreign key constraints === logger.info("Updating foreign key constraints...") # Recreate persona__user_file foreign key to point to user_file.id # Drop existing FK first to break dependency on the unique constraint op.execute( "ALTER TABLE persona__user_file DROP CONSTRAINT IF EXISTS persona__user_file_user_file_id_fkey" ) # Drop the unique constraint on (formerly) new_id BEFORE recreating the FK, # so the FK will bind to the primary key instead of the unique index. op.execute("ALTER TABLE user_file DROP CONSTRAINT IF EXISTS uq_user_file_new_id") # Now recreate FK to the primary key column op.create_foreign_key( "persona__user_file_user_file_id_fkey", "persona__user_file", "user_file", local_cols=["user_file_id"], remote_cols=["id"], ) # Add foreign keys for project__user_file existing_fks = inspector.get_foreign_keys("project__user_file") has_user_file_fk = any( fk.get("referred_table") == "user_file" and fk.get("constrained_columns") == ["user_file_id"] for fk in existing_fks ) if not has_user_file_fk: op.create_foreign_key( "fk_project__user_file_user_file_id", "project__user_file", "user_file", ["user_file_id"], ["id"], ) logger.info("Added project__user_file -> user_file foreign key") has_project_fk = any( fk.get("referred_table") == "user_project" and fk.get("constrained_columns") == ["project_id"] for fk in existing_fks ) if not has_project_fk: op.create_foreign_key( "fk_project__user_file_project_id", "project__user_file", "user_project", ["project_id"], ["id"], ) logger.info("Added project__user_file -> 
user_project foreign key") # === Step 4: Mark files for document_id migration === logger.info("Marking files for background document_id migration...") logger.info("Migration 4 (UUID primary key swap) completed successfully") logger.info( "NOTE: Background task will update document IDs in Vespa and search_doc" ) def downgrade() -> None: """Revert UUID primary key back to integer (data destructive!).""" logger.error("CRITICAL: Downgrading UUID primary key swap is data destructive!") logger.error( "This will break all UUID-based references created after the migration." ) logger.error("Only proceed if absolutely necessary and have backups.") bind = op.get_bind() inspector = sa.inspect(bind) # Capture existing primary key definitions so we can restore them after swaps persona_pk = inspector.get_pk_constraint("persona__user_file") or {} persona_pk_name = persona_pk.get("name") persona_pk_cols = persona_pk.get("constrained_columns") or [] project_pk = inspector.get_pk_constraint("project__user_file") or {} project_pk_name = project_pk.get("name") project_pk_cols = project_pk.get("constrained_columns") or [] # Drop foreign keys that reference the UUID primary key op.drop_constraint( "persona__user_file_user_file_id_fkey", "persona__user_file", type_="foreignkey", ) op.drop_constraint( "fk_project__user_file_user_file_id", "project__user_file", type_="foreignkey", ) # Drop primary keys that rely on the UUID column so we can replace it if persona_pk_name: op.drop_constraint(persona_pk_name, "persona__user_file", type_="primary") if project_pk_name: op.drop_constraint(project_pk_name, "project__user_file", type_="primary") # Rebuild integer IDs on user_file using a sequence-backed column op.execute("CREATE SEQUENCE IF NOT EXISTS user_file_id_seq") op.add_column( "user_file", sa.Column( "id_int", sa.Integer(), server_default=sa.text("nextval('user_file_id_seq')"), nullable=False, ), ) op.execute("ALTER SEQUENCE user_file_id_seq OWNED BY user_file.id_int") # Prepare integer 
foreign key columns on referencing tables op.add_column( "persona__user_file", sa.Column("user_file_id_int", sa.Integer(), nullable=True), ) op.add_column( "project__user_file", sa.Column("user_file_id_int", sa.Integer(), nullable=True), ) # Populate the new integer foreign key columns by mapping from the UUID IDs op.execute( """ UPDATE persona__user_file AS p SET user_file_id_int = uf.id_int FROM user_file AS uf WHERE p.user_file_id = uf.id """ ) op.execute( """ UPDATE project__user_file AS p SET user_file_id_int = uf.id_int FROM user_file AS uf WHERE p.user_file_id = uf.id """ ) op.alter_column( "persona__user_file", "user_file_id_int", existing_type=sa.Integer(), nullable=False, ) op.alter_column( "project__user_file", "user_file_id_int", existing_type=sa.Integer(), nullable=False, ) # Remove the UUID foreign key columns and rename the integer replacements op.drop_column("persona__user_file", "user_file_id") op.alter_column( "persona__user_file", "user_file_id_int", new_column_name="user_file_id", existing_type=sa.Integer(), nullable=False, ) op.drop_column("project__user_file", "user_file_id") op.alter_column( "project__user_file", "user_file_id_int", new_column_name="user_file_id", existing_type=sa.Integer(), nullable=False, ) # Swap the user_file primary key back to the integer column op.drop_constraint("user_file_pkey", "user_file", type_="primary") op.drop_column("user_file", "id") op.alter_column( "user_file", "id_int", new_column_name="id", existing_type=sa.Integer(), ) op.alter_column( "user_file", "id", existing_type=sa.Integer(), nullable=False, server_default=sa.text("nextval('user_file_id_seq')"), ) op.execute("ALTER SEQUENCE user_file_id_seq OWNED BY user_file.id") op.execute( """ SELECT setval( 'user_file_id_seq', GREATEST(COALESCE(MAX(id), 1), 1), MAX(id) IS NOT NULL ) FROM user_file """ ) op.create_primary_key("user_file_pkey", "user_file", ["id"]) # Restore primary keys on referencing tables if persona_pk_cols: op.create_primary_key( 
"persona__user_file_pkey", "persona__user_file", persona_pk_cols ) if project_pk_cols: op.create_primary_key( "project__user_file_pkey", "project__user_file", project_pk_cols, ) # Recreate foreign keys pointing at the integer primary key op.create_foreign_key( "persona__user_file_user_file_id_fkey", "persona__user_file", "user_file", ["user_file_id"], ["id"], ) op.create_foreign_key( "fk_project__user_file_user_file_id", "project__user_file", "user_file", ["user_file_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/7ccea01261f6_store_chat_retrieval_docs.py ================================================ """Store Chat Retrieval Docs Revision ID: 7ccea01261f6 Revises: a570b80a5f20 Create Date: 2023-10-15 10:39:23.317453 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "7ccea01261f6" down_revision = "a570b80a5f20" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_message", sa.Column( "reference_docs", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) def downgrade() -> None: op.drop_column("chat_message", "reference_docs") ================================================ FILE: backend/alembic/versions/7da0ae5ad583_add_description_to_persona.py ================================================ """Add description to persona Revision ID: 7da0ae5ad583 Revises: e86866a9c78a Create Date: 2023-11-27 00:16:19.959414 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "7da0ae5ad583" down_revision = "e86866a9c78a" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("persona", sa.Column("description", sa.String(), nullable=True)) def downgrade() -> None: op.drop_column("persona", "description") ================================================ FILE: backend/alembic/versions/7da543f5672f_add_slackbotconfig_table.py ================================================ """Add SlackBotConfig table Revision ID: 7da543f5672f Revises: febe9eaa0644 Create Date: 2023-09-24 16:34:17.526128 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "7da543f5672f" down_revision = "febe9eaa0644" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "slack_bot_config", sa.Column("id", sa.Integer(), nullable=False), sa.Column("persona_id", sa.Integer(), nullable=True), sa.Column( "channel_config", postgresql.JSONB(astext_type=sa.Text()), nullable=False, ), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("slack_bot_config") ================================================ FILE: backend/alembic/versions/7e490836d179_nullify_default_system_prompt.py ================================================ """nullify_default_system_prompt Revision ID: 7e490836d179 Revises: c1d2e3f4a5b6 Create Date: 2025-12-29 16:54:36.635574 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "7e490836d179" down_revision = "c1d2e3f4a5b6" branch_labels = None depends_on = None # This is the default system prompt from the previous migration (87c52ec39f84) # ruff: noqa: E501, W605 start PREVIOUS_DEFAULT_SYSTEM_PROMPT = """ You are a highly capable, thoughtful, and precise assistant. 
Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient. The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]] # Response Style You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging. You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline. For code you prefer to use Markdown and specify the language. You can use horizontal rules (---) to separate sections of your responses. You can use Markdown tables to format your responses for data, lists, and other structured information. """.lstrip() # ruff: noqa: E501, W605 end def upgrade() -> None: # Make system_prompt column nullable (model already has nullable=True but DB doesn't) op.alter_column( "persona", "system_prompt", nullable=True, ) # Set system_prompt to NULL where it matches the previous default conn = op.get_bind() conn.execute( sa.text( """ UPDATE persona SET system_prompt = NULL WHERE system_prompt = :previous_default """ ), {"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT}, ) def downgrade() -> None: # Restore the default system prompt for personas that have NULL # Note: This may restore the prompt to personas that originally had NULL # before this migration, but there's no way to distinguish them conn = op.get_bind() conn.execute( sa.text( """ UPDATE persona SET system_prompt = :previous_default WHERE system_prompt IS NULL """ ), {"previous_default": PREVIOUS_DEFAULT_SYSTEM_PROMPT}, ) # Revert system_prompt column to not nullable op.alter_column( "persona", "system_prompt", nullable=False, ) 
================================================ FILE: backend/alembic/versions/7ed603b64d5a_add_mcp_server_and_connection_config_.py ================================================ """add_mcp_server_and_connection_config_models Revision ID: 7ed603b64d5a Revises: b329d00a9ea6 Create Date: 2025-07-28 17:35:59.900680 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql from onyx.db.enums import MCPAuthenticationType # revision identifiers, used by Alembic. revision = "7ed603b64d5a" down_revision = "b329d00a9ea6" branch_labels = None depends_on = None def upgrade() -> None: """Create tables and columns for MCP Server support""" # 1. MCP Server main table (no FK constraints yet to avoid circular refs) op.create_table( "mcp_server", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("owner", sa.String(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.Column("description", sa.String(), nullable=True), sa.Column("server_url", sa.String(), nullable=False), sa.Column( "auth_type", sa.Enum( MCPAuthenticationType, name="mcp_authentication_type", native_enum=False, ), nullable=False, ), sa.Column("admin_connection_config_id", sa.Integer(), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), ) # 2. 
MCP Connection Config table (can reference mcp_server now that it exists) op.create_table( "mcp_connection_config", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("mcp_server_id", sa.Integer(), nullable=True), sa.Column("user_email", sa.String(), nullable=False, default=""), sa.Column("config", sa.LargeBinary(), nullable=False), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE" ), ) # Helpful indexes op.create_index( "ix_mcp_connection_config_server_user", "mcp_connection_config", ["mcp_server_id", "user_email"], ) op.create_index( "ix_mcp_connection_config_user_email", "mcp_connection_config", ["user_email"], ) # 3. Add the back-references from mcp_server to connection configs op.create_foreign_key( "mcp_server_admin_config_fk", "mcp_server", "mcp_connection_config", ["admin_connection_config_id"], ["id"], ondelete="SET NULL", ) # 4. Association / access-control tables op.create_table( "mcp_server__user", sa.Column("mcp_server_id", sa.Integer(), primary_key=True), sa.Column("user_id", sa.UUID(), primary_key=True), sa.ForeignKeyConstraint( ["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE" ), sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), ) op.create_table( "mcp_server__user_group", sa.Column("mcp_server_id", sa.Integer(), primary_key=True), sa.Column("user_group_id", sa.Integer(), primary_key=True), sa.ForeignKeyConstraint( ["mcp_server_id"], ["mcp_server.id"], ondelete="CASCADE" ), sa.ForeignKeyConstraint(["user_group_id"], ["user_group.id"]), ) # 5. 
Update existing `tool` table – allow tools to belong to an MCP server op.add_column( "tool", sa.Column("mcp_server_id", sa.Integer(), nullable=True), ) # Add column for MCP tool input schema op.add_column( "tool", sa.Column("mcp_input_schema", postgresql.JSONB(), nullable=True), ) op.create_foreign_key( "tool_mcp_server_fk", "tool", "mcp_server", ["mcp_server_id"], ["id"], ondelete="CASCADE", ) # 6. Update persona__tool foreign keys to cascade delete # This ensures that when a tool is deleted (including via MCP server deletion), # the corresponding persona__tool rows are also deleted op.drop_constraint( "persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey" ) op.drop_constraint( "persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey" ) op.create_foreign_key( "persona__tool_persona_id_fkey", "persona__tool", "persona", ["persona_id"], ["id"], ondelete="CASCADE", ) op.create_foreign_key( "persona__tool_tool_id_fkey", "persona__tool", "tool", ["tool_id"], ["id"], ondelete="CASCADE", ) # 7. Update research_agent_iteration_sub_step foreign key to SET NULL on delete # This ensures that when a tool is deleted, the sub_step_tool_id is set to NULL # instead of causing a foreign key constraint violation op.drop_constraint( "research_agent_iteration_sub_step_sub_step_tool_id_fkey", "research_agent_iteration_sub_step", type_="foreignkey", ) op.create_foreign_key( "research_agent_iteration_sub_step_sub_step_tool_id_fkey", "research_agent_iteration_sub_step", "tool", ["sub_step_tool_id"], ["id"], ondelete="SET NULL", ) def downgrade() -> None: """Drop all MCP-related tables / columns""" # # # 1. 
Drop FK & columns from tool # op.drop_constraint("tool_mcp_server_fk", "tool", type_="foreignkey") op.execute("DELETE FROM tool WHERE mcp_server_id IS NOT NULL") op.drop_constraint( "research_agent_iteration_sub_step_sub_step_tool_id_fkey", "research_agent_iteration_sub_step", type_="foreignkey", ) op.create_foreign_key( "research_agent_iteration_sub_step_sub_step_tool_id_fkey", "research_agent_iteration_sub_step", "tool", ["sub_step_tool_id"], ["id"], ) # Restore original persona__tool foreign keys (without CASCADE) op.drop_constraint( "persona__tool_persona_id_fkey", "persona__tool", type_="foreignkey" ) op.drop_constraint( "persona__tool_tool_id_fkey", "persona__tool", type_="foreignkey" ) op.create_foreign_key( "persona__tool_persona_id_fkey", "persona__tool", "persona", ["persona_id"], ["id"], ) op.create_foreign_key( "persona__tool_tool_id_fkey", "persona__tool", "tool", ["tool_id"], ["id"], ) op.drop_column("tool", "mcp_input_schema") op.drop_column("tool", "mcp_server_id") # 2. Drop association tables op.drop_table("mcp_server__user_group") op.drop_table("mcp_server__user") # 3. Drop FK from mcp_server to connection configs op.drop_constraint("mcp_server_admin_config_fk", "mcp_server", type_="foreignkey") # 4. Drop connection config indexes & table op.drop_index( "ix_mcp_connection_config_user_email", table_name="mcp_connection_config" ) op.drop_index( "ix_mcp_connection_config_server_user", table_name="mcp_connection_config" ) op.drop_table("mcp_connection_config") # 5. Finally drop mcp_server table op.drop_table("mcp_server") ================================================ FILE: backend/alembic/versions/7f726bad5367_slack_followup.py ================================================ """Slack Followup Revision ID: 7f726bad5367 Revises: 79acd316403a Create Date: 2024-01-15 00:19:55.991224 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "7f726bad5367" down_revision = "79acd316403a" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_feedback", sa.Column("required_followup", sa.Boolean(), nullable=True), ) def downgrade() -> None: op.drop_column("chat_feedback", "required_followup") ================================================ FILE: backend/alembic/versions/7f99be1cb9f5_add_index_for_getting_documents_just_by_.py ================================================ """Add index for getting documents just by connector id / credential id Revision ID: 7f99be1cb9f5 Revises: 78dbe7e38469 Create Date: 2023-10-15 22:48:15.487762 """ from alembic import op # revision identifiers, used by Alembic. revision = "7f99be1cb9f5" down_revision = "78dbe7e38469" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_index( op.f( "ix_document_by_connector_credential_pair_pkey__connector_id__credential_id" ), "document_by_connector_credential_pair", ["connector_id", "credential_id"], unique=False, ) def downgrade() -> None: op.drop_index( op.f( "ix_document_by_connector_credential_pair_pkey__connector_id__credential_id" ), table_name="document_by_connector_credential_pair", ) ================================================ FILE: backend/alembic/versions/800f48024ae9_add_id_to_connectorcredentialpair.py ================================================ """Add ID to ConnectorCredentialPair Revision ID: 800f48024ae9 Revises: 767f1c2a00eb Create Date: 2023-09-19 16:13:42.299715 """ from alembic import op import sqlalchemy as sa from sqlalchemy.schema import Sequence, CreateSequence # revision identifiers, used by Alembic. 
revision = "800f48024ae9" down_revision = "767f1c2a00eb" branch_labels: None = None depends_on: None = None def upgrade() -> None: sequence = Sequence("connector_credential_pair_id_seq") op.execute(CreateSequence(sequence)) # type: ignore op.add_column( "connector_credential_pair", sa.Column( "id", sa.Integer(), nullable=True, server_default=sequence.next_value() ), ) op.add_column( "connector_credential_pair", sa.Column("name", sa.String(), nullable=True), ) # fill in IDs for existing rows op.execute( "UPDATE connector_credential_pair SET id = nextval('connector_credential_pair_id_seq') WHERE id IS NULL" ) op.alter_column("connector_credential_pair", "id", nullable=False) op.create_unique_constraint( "connector_credential_pair__name__key", "connector_credential_pair", ["name"] ) op.create_unique_constraint( "connector_credential_pair__id__key", "connector_credential_pair", ["id"] ) def downgrade() -> None: op.drop_constraint( "connector_credential_pair__name__key", "connector_credential_pair", type_="unique", ) op.drop_constraint( "connector_credential_pair__id__key", "connector_credential_pair", type_="unique", ) op.drop_column("connector_credential_pair", "name") op.drop_column("connector_credential_pair", "id") op.execute("DROP SEQUENCE connector_credential_pair_id_seq") ================================================ FILE: backend/alembic/versions/80696cf850ae_add_chat_session_to_query_event.py ================================================ """Add chat session to query_event Revision ID: 80696cf850ae Revises: 15326fcec57e Create Date: 2023-11-26 02:38:35.008070 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "80696cf850ae" down_revision = "15326fcec57e" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "query_event", sa.Column("chat_session_id", sa.Integer(), nullable=True), ) op.create_foreign_key( "fk_query_event_chat_session_id", "query_event", "chat_session", ["chat_session_id"], ["id"], ) def downgrade() -> None: op.drop_constraint( "fk_query_event_chat_session_id", "query_event", type_="foreignkey" ) op.drop_column("query_event", "chat_session_id") ================================================ FILE: backend/alembic/versions/8188861f4e92_csv_to_tabular_chat_file_type.py ================================================ """csv to tabular chat file type Revision ID: 8188861f4e92 Revises: d8cdfee5df80 Create Date: 2026-03-31 19:23:05.753184 """ from alembic import op # revision identifiers, used by Alembic. revision = "8188861f4e92" down_revision = "d8cdfee5df80" branch_labels = None depends_on = None def upgrade() -> None: op.execute( """ UPDATE chat_message SET files = ( SELECT jsonb_agg( CASE WHEN elem->>'type' = 'csv' THEN jsonb_set(elem, '{type}', '"tabular"') ELSE elem END ) FROM jsonb_array_elements(files) AS elem ) WHERE files::text LIKE '%"type": "csv"%' """ ) def downgrade() -> None: op.execute( """ UPDATE chat_message SET files = ( SELECT jsonb_agg( CASE WHEN elem->>'type' = 'tabular' THEN jsonb_set(elem, '{type}', '"csv"') ELSE elem END ) FROM jsonb_array_elements(files) AS elem ) WHERE files::text LIKE '%"type": "tabular"%' """ ) ================================================ FILE: backend/alembic/versions/81c22b1e2e78_hierarchy_nodes_v1.py ================================================ """hierarchy_nodes_v1 Revision ID: 81c22b1e2e78 Revises: 72aa7de2e5cf Create Date: 2026-01-13 18:10:01.021451 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql from onyx.configs.constants import DocumentSource # revision identifiers, used by Alembic. 
revision = "81c22b1e2e78"
down_revision = "72aa7de2e5cf"
branch_labels = None
depends_on = None

# Human-readable display names for each source
SOURCE_DISPLAY_NAMES: dict[str, str] = {
    "ingestion_api": "Ingestion API",
    "slack": "Slack",
    "web": "Web",
    "google_drive": "Google Drive",
    "gmail": "Gmail",
    "requesttracker": "Request Tracker",
    "github": "GitHub",
    "gitbook": "GitBook",
    "gitlab": "GitLab",
    "guru": "Guru",
    "bookstack": "BookStack",
    "outline": "Outline",
    "confluence": "Confluence",
    "jira": "Jira",
    "slab": "Slab",
    "productboard": "Productboard",
    "file": "File",
    "coda": "Coda",
    "notion": "Notion",
    "zulip": "Zulip",
    "linear": "Linear",
    "hubspot": "HubSpot",
    "document360": "Document360",
    "gong": "Gong",
    "google_sites": "Google Sites",
    "zendesk": "Zendesk",
    "loopio": "Loopio",
    "dropbox": "Dropbox",
    "sharepoint": "SharePoint",
    "teams": "Teams",
    "salesforce": "Salesforce",
    "discourse": "Discourse",
    "axero": "Axero",
    "clickup": "ClickUp",
    "mediawiki": "MediaWiki",
    "wikipedia": "Wikipedia",
    "asana": "Asana",
    "s3": "S3",
    "r2": "R2",
    "google_cloud_storage": "Google Cloud Storage",
    "oci_storage": "OCI Storage",
    "xenforo": "XenForo",
    "not_applicable": "Not Applicable",
    "discord": "Discord",
    "freshdesk": "Freshdesk",
    "fireflies": "Fireflies",
    "egnyte": "Egnyte",
    "airtable": "Airtable",
    "highspot": "Highspot",
    "drupal_wiki": "Drupal Wiki",
    "imap": "IMAP",
    "bitbucket": "Bitbucket",
    "testrail": "TestRail",
    "mock_connector": "Mock Connector",
    "user_file": "User File",
}


def upgrade() -> None:
    # 1. Create hierarchy_node table
    op.create_table(
        "hierarchy_node",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("raw_node_id", sa.String(), nullable=False),
        sa.Column("display_name", sa.String(), nullable=False),
        sa.Column("link", sa.String(), nullable=True),
        sa.Column("source", sa.String(), nullable=False),
        sa.Column("node_type", sa.String(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=True),
        sa.Column("parent_id", sa.Integer(), nullable=True),
        # Permission fields - same pattern as Document table
        sa.Column(
            "external_user_emails",
            postgresql.ARRAY(sa.String()),
            nullable=True,
        ),
        sa.Column(
            "external_user_group_ids",
            postgresql.ARRAY(sa.String()),
            nullable=True,
        ),
        sa.Column("is_public", sa.Boolean(), nullable=False, server_default="false"),
        sa.PrimaryKeyConstraint("id"),
        # When document is deleted, just unlink (node can exist without document)
        sa.ForeignKeyConstraint(["document_id"], ["document.id"], ondelete="SET NULL"),
        # When parent node is deleted, orphan children (cleanup via pruning)
        sa.ForeignKeyConstraint(
            ["parent_id"], ["hierarchy_node.id"], ondelete="SET NULL"
        ),
        sa.UniqueConstraint(
            "raw_node_id", "source", name="uq_hierarchy_node_raw_id_source"
        ),
    )
    op.create_index("ix_hierarchy_node_parent_id", "hierarchy_node", ["parent_id"])
    op.create_index(
        "ix_hierarchy_node_source_type", "hierarchy_node", ["source", "node_type"]
    )

    # Add partial unique index to ensure only one SOURCE-type node per source
    # This prevents duplicate source root nodes from being created
    # NOTE: node_type stores enum NAME ('SOURCE'), not value ('source')
    op.execute(
        sa.text(
            """
            CREATE UNIQUE INDEX uq_hierarchy_node_one_source_per_type
            ON hierarchy_node (source)
            WHERE node_type = 'SOURCE'
            """
        )
    )

    # 2. Create hierarchy_fetch_attempt table
    op.create_table(
        "hierarchy_fetch_attempt",
        sa.Column("id", postgresql.UUID(as_uuid=True), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("status", sa.String(), nullable=False),
        sa.Column("nodes_fetched", sa.Integer(), nullable=True, server_default="0"),
        sa.Column("nodes_updated", sa.Integer(), nullable=True, server_default="0"),
        sa.Column("error_msg", sa.Text(), nullable=True),
        sa.Column("full_exception_trace", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column("time_started", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "time_updated",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
            ondelete="CASCADE",
        ),
    )
    op.create_index(
        "ix_hierarchy_fetch_attempt_status", "hierarchy_fetch_attempt", ["status"]
    )
    op.create_index(
        "ix_hierarchy_fetch_attempt_time_created",
        "hierarchy_fetch_attempt",
        ["time_created"],
    )
    op.create_index(
        "ix_hierarchy_fetch_attempt_cc_pair",
        "hierarchy_fetch_attempt",
        ["connector_credential_pair_id"],
    )

    # 3. Insert SOURCE-type hierarchy nodes for each DocumentSource
    # We insert these so every existing document can have a parent hierarchy node
    # NOTE: SQLAlchemy's Enum with native_enum=False stores the enum NAME (e.g., 'GOOGLE_DRIVE'),
    # not the VALUE (e.g., 'google_drive'). We must use .name for source and node_type columns.
    # SOURCE nodes are always public since they're just categorical roots.
    for source in DocumentSource:
        source_name = (
            source.name
        )  # e.g., 'GOOGLE_DRIVE' - what SQLAlchemy stores/expects
        source_value = source.value  # e.g., 'google_drive' - the raw_node_id
        display_name = SOURCE_DISPLAY_NAMES.get(
            source_value, source_value.replace("_", " ").title()
        )
        op.execute(
            sa.text(
                """
                INSERT INTO hierarchy_node (raw_node_id, display_name, source, node_type, parent_id, is_public)
                VALUES (:raw_node_id, :display_name, :source, 'SOURCE', NULL, true)
                ON CONFLICT (raw_node_id, source) DO NOTHING
                """
            ).bindparams(
                raw_node_id=source_value,  # Use .value for raw_node_id (human-readable identifier)
                display_name=display_name,
                source=source_name,  # Use .name for source column (SQLAlchemy enum storage)
            )
        )

    # 4. Add parent_hierarchy_node_id column to document table
    op.add_column(
        "document",
        sa.Column("parent_hierarchy_node_id", sa.Integer(), nullable=True),
    )
    # When hierarchy node is deleted, just unlink the document (SET NULL)
    op.create_foreign_key(
        "fk_document_parent_hierarchy_node",
        "document",
        "hierarchy_node",
        ["parent_hierarchy_node_id"],
        ["id"],
        ondelete="SET NULL",
    )
    op.create_index(
        "ix_document_parent_hierarchy_node_id",
        "document",
        ["parent_hierarchy_node_id"],
    )

    # 5. Set all existing documents' parent_hierarchy_node_id to their source's SOURCE node
    # For documents with multiple connectors, we pick one source deterministically (MIN connector_id)
    # NOTE: Both connector.source and hierarchy_node.source store enum NAMEs (e.g., 'GOOGLE_DRIVE')
    # because SQLAlchemy Enum(native_enum=False) uses the enum name for storage.
    op.execute(
        sa.text(
            """
            UPDATE document d
            SET parent_hierarchy_node_id = hn.id
            FROM (
                -- Get the source for each document (pick MIN connector_id for determinism)
                SELECT DISTINCT ON (dbcc.id)
                    dbcc.id as doc_id,
                    c.source as source
                FROM document_by_connector_credential_pair dbcc
                JOIN connector c ON dbcc.connector_id = c.id
                ORDER BY dbcc.id, dbcc.connector_id
            ) doc_source
            JOIN hierarchy_node hn ON hn.source = doc_source.source AND hn.node_type = 'SOURCE'
            WHERE d.id = doc_source.doc_id
            """
        )
    )

    # Create the persona__hierarchy_node association table
    op.create_table(
        "persona__hierarchy_node",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["hierarchy_node_id"],
            ["hierarchy_node.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("persona_id", "hierarchy_node_id"),
    )

    # Add index for efficient lookups
    op.create_index(
        "ix_persona__hierarchy_node_hierarchy_node_id",
        "persona__hierarchy_node",
        ["hierarchy_node_id"],
    )

    # Create the persona__document association table for attaching individual
    # documents directly to assistants
    op.create_table(
        "persona__document",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["document_id"],
            ["document.id"],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("persona_id", "document_id"),
    )

    # Add index for efficient lookups by document_id
    op.create_index(
        "ix_persona__document_document_id",
        "persona__document",
        ["document_id"],
    )

    # 6. Add last_time_hierarchy_fetch column to connector_credential_pair table
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "last_time_hierarchy_fetch", sa.DateTime(timezone=True), nullable=True
        ),
    )


def downgrade() -> None:
    # Remove last_time_hierarchy_fetch from connector_credential_pair
    op.drop_column("connector_credential_pair", "last_time_hierarchy_fetch")

    # Drop persona__document table
    op.drop_index("ix_persona__document_document_id", table_name="persona__document")
    op.drop_table("persona__document")

    # Drop persona__hierarchy_node table
    op.drop_index(
        "ix_persona__hierarchy_node_hierarchy_node_id",
        table_name="persona__hierarchy_node",
    )
    op.drop_table("persona__hierarchy_node")

    # Remove parent_hierarchy_node_id from document
    op.drop_index("ix_document_parent_hierarchy_node_id", table_name="document")
    op.drop_constraint(
        "fk_document_parent_hierarchy_node", "document", type_="foreignkey"
    )
    op.drop_column("document", "parent_hierarchy_node_id")

    # Drop hierarchy_fetch_attempt table
    op.drop_index(
        "ix_hierarchy_fetch_attempt_cc_pair", table_name="hierarchy_fetch_attempt"
    )
    op.drop_index(
        "ix_hierarchy_fetch_attempt_time_created", table_name="hierarchy_fetch_attempt"
    )
    op.drop_index(
        "ix_hierarchy_fetch_attempt_status", table_name="hierarchy_fetch_attempt"
    )
    op.drop_table("hierarchy_fetch_attempt")

    # Drop hierarchy_node table
    op.drop_index("uq_hierarchy_node_one_source_per_type", table_name="hierarchy_node")
    op.drop_index("ix_hierarchy_node_source_type", table_name="hierarchy_node")
    op.drop_index("ix_hierarchy_node_parent_id", table_name="hierarchy_node")
    op.drop_table("hierarchy_node")

================================================
FILE: backend/alembic/versions/8405ca81cc83_notifications_constraint.py
================================================
"""notifications constraint, sort index, and cleanup old notifications

Revision ID: 8405ca81cc83
Revises: a3c1a7904cd0
Create Date: 2026-01-07 16:43:44.855156

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "8405ca81cc83"
down_revision = "a3c1a7904cd0"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create unique index for notification deduplication.
    # This enables atomic ON CONFLICT DO NOTHING inserts in batch_create_notifications.
    #
    # Uses COALESCE to handle NULL additional_data (NULLs are normally distinct
    # in unique constraints, but we want NULL == NULL for deduplication).
    # The '{}' represents an empty JSONB object as the NULL replacement.

    # Clean up legacy notifications first
    op.execute("DELETE FROM notification WHERE title = 'New Notification'")

    op.execute(
        """
        CREATE UNIQUE INDEX IF NOT EXISTS ix_notification_user_type_data
        ON notification (user_id, notif_type, COALESCE(additional_data, '{}'::jsonb))
        """
    )

    # Create index for efficient notification sorting by user
    # Covers: WHERE user_id = ? ORDER BY dismissed, first_shown DESC
    op.execute(
        """
        CREATE INDEX IF NOT EXISTS ix_notification_user_sort
        ON notification (user_id, dismissed, first_shown DESC)
        """
    )


def downgrade() -> None:
    op.execute("DROP INDEX IF EXISTS ix_notification_user_type_data")
    op.execute("DROP INDEX IF EXISTS ix_notification_user_sort")

================================================
FILE: backend/alembic/versions/849b21c732f8_add_demo_data_enabled_to_build_session.py
================================================
"""add demo_data_enabled to build_session

Revision ID: 849b21c732f8
Revises: 81c22b1e2e78
Create Date: 2026-01-28 10:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "849b21c732f8"
down_revision = "81c22b1e2e78"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "build_session",
        sa.Column(
            "demo_data_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("true"),
        ),
    )


def downgrade() -> None:
    op.drop_column("build_session", "demo_data_enabled")

================================================
FILE: backend/alembic/versions/87c52ec39f84_update_default_system_prompt.py
================================================
"""update_default_system_prompt

Revision ID: 87c52ec39f84
Revises: 7bd55f264e1b
Create Date: 2025-12-05 15:54:06.002452

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "87c52ec39f84"
down_revision = "7bd55f264e1b"
branch_labels = None
depends_on = None


DEFAULT_PERSONA_ID = 0

# ruff: noqa: E501, W605 start
DEFAULT_SYSTEM_PROMPT = """
You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent, ask clarifying questions when needed, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. Always prioritize being truthful, nuanced, insightful, and efficient.

The current date is [[CURRENT_DATETIME]].[[CITATION_GUIDANCE]]

# Response Style
You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging.
You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline.
For code you prefer to use Markdown and specify the language.
You can use horizontal rules (---) to separate sections of your responses.
You can use Markdown tables to format your responses for data, lists, and other structured information.
""".lstrip()
# ruff: noqa: E501, W605 end


def upgrade() -> None:
    conn = op.get_bind()
    conn.execute(
        sa.text(
            """
            UPDATE persona
            SET system_prompt = :system_prompt
            WHERE id = :persona_id
            """
        ),
        {"system_prompt": DEFAULT_SYSTEM_PROMPT, "persona_id": DEFAULT_PERSONA_ID},
    )


def downgrade() -> None:
    # We don't revert the system prompt on downgrade since we don't know
    # what the previous value was. The new prompt is a reasonable default.
    pass

================================================
FILE: backend/alembic/versions/8818cf73fa1a_drop_include_citations.py
================================================
"""drop include citations

Revision ID: 8818cf73fa1a
Revises: 7ed603b64d5a
Create Date: 2025-09-02 19:43:50.060680

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8818cf73fa1a"
down_revision = "7ed603b64d5a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("prompt", "include_citations")


def downgrade() -> None:
    op.add_column(
        "prompt",
        sa.Column(
            "include_citations",
            sa.BOOLEAN(),
            autoincrement=False,
            nullable=True,
        ),
    )
    # Set include_citations based on prompt name: FALSE for ImageGeneration, TRUE for others
    op.execute(
        sa.text(
            "UPDATE prompt SET include_citations = CASE WHEN name = 'ImageGeneration' THEN FALSE ELSE TRUE END"
        )
    )

================================================
FILE: backend/alembic/versions/891cd83c87a8_add_is_visible_to_persona.py
================================================
"""Add is_visible to Persona

Revision ID: 891cd83c87a8
Revises: 76b60d407dfb
Create Date: 2023-12-21 11:55:54.132279

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "891cd83c87a8"
down_revision = "76b60d407dfb"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "persona",
        sa.Column("is_visible", sa.Boolean(), nullable=True),
    )
    op.execute("UPDATE persona SET is_visible = true")
    op.alter_column("persona", "is_visible", nullable=False)

    op.add_column(
        "persona",
        sa.Column("display_priority", sa.Integer(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("persona", "is_visible")
    op.drop_column("persona", "display_priority")

================================================
FILE: backend/alembic/versions/8987770549c0_add_full_exception_stack_trace.py
================================================
"""Add full exception stack trace

Revision ID: 8987770549c0
Revises: ec3ec2eabf7b
Create Date: 2024-02-10 19:31:28.339135

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8987770549c0"
down_revision = "ec3ec2eabf7b"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "index_attempt", sa.Column("full_exception_trace", sa.Text(), nullable=True)
    )


def downgrade() -> None:
    op.drop_column("index_attempt", "full_exception_trace")

================================================
FILE: backend/alembic/versions/8a87bd6ec550_associate_index_attempts_with_ccpair.py
================================================
"""associate index attempts with ccpair

Revision ID: 8a87bd6ec550
Revises: 4ea2c93919c1
Create Date: 2024-07-22 15:15:52.558451

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8a87bd6ec550"
down_revision = "4ea2c93919c1"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Add the new connector_credential_pair_id column
    op.add_column(
        "index_attempt",
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=True),
    )

    # Create a foreign key constraint to the connector_credential_pair table
    op.create_foreign_key(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        "connector_credential_pair",
        ["connector_credential_pair_id"],
        ["id"],
    )

    # Populate the new connector_credential_pair_id column using existing connector_id and credential_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_credential_pair_id = (
            SELECT id FROM connector_credential_pair ccp
            WHERE
                (ia.connector_id IS NULL OR ccp.connector_id = ia.connector_id)
                AND (ia.credential_id IS NULL OR ccp.credential_id = ia.credential_id)
            LIMIT 1
        )
        WHERE ia.connector_id IS NOT NULL OR ia.credential_id IS NOT NULL
        """
    )

    # For good measure
    op.execute(
        """
        DELETE FROM index_attempt
        WHERE connector_credential_pair_id IS NULL
        """
    )

    # Make the new connector_credential_pair_id column non-nullable
    op.alter_column("index_attempt", "connector_credential_pair_id", nullable=False)

    # Drop the old connector_id and credential_id columns
    op.drop_column("index_attempt", "connector_id")
    op.drop_column("index_attempt", "credential_id")

    # Update the index to use connector_credential_pair_id
    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_credential_pair_id", "time_created"],
    )


def downgrade() -> None:
    # Add back the old connector_id and credential_id columns
    op.add_column(
        "index_attempt", sa.Column("connector_id", sa.Integer(), nullable=True)
    )
    op.add_column(
        "index_attempt", sa.Column("credential_id", sa.Integer(), nullable=True)
    )

    # Populate the old connector_id and credential_id columns using the connector_credential_pair_id
    op.execute(
        """
        UPDATE index_attempt ia
        SET connector_id =
            ccp.connector_id,
            credential_id = ccp.credential_id
        FROM connector_credential_pair ccp
        WHERE ia.connector_credential_pair_id = ccp.id
        """
    )

    # Make the old connector_id and credential_id columns non-nullable
    op.alter_column("index_attempt", "connector_id", nullable=False)
    op.alter_column("index_attempt", "credential_id", nullable=False)

    # Drop the new connector_credential_pair_id column
    op.drop_constraint(
        "fk_index_attempt_connector_credential_pair_id",
        "index_attempt",
        type_="foreignkey",
    )
    op.drop_column("index_attempt", "connector_credential_pair_id")

    op.create_index(
        "ix_index_attempt_latest_for_connector_credential_pair",
        "index_attempt",
        ["connector_id", "credential_id", "time_created"],
    )

================================================
FILE: backend/alembic/versions/8aabb57f3b49_restructure_document_indices.py
================================================
"""Restructure Document Indices

Revision ID: 8aabb57f3b49
Revises: 5e84129c8be3
Create Date: 2023-08-18 21:15:57.629515

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "8aabb57f3b49"
down_revision = "5e84129c8be3"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_table("chunk")
    op.execute("DROP TYPE IF EXISTS documentstoretype")


def downgrade() -> None:
    op.create_table(
        "chunk",
        sa.Column("id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column(
            "document_store_type",
            postgresql.ENUM("VECTOR", "KEYWORD", name="documentstoretype"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column("document_id", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.ForeignKeyConstraint(
            ["document_id"], ["document.id"], name="chunk_document_id_fkey"
        ),
        sa.PrimaryKeyConstraint("id", "document_store_type", name="chunk_pkey"),
    )

================================================
FILE: backend/alembic/versions/8b5ce697290e_add_discord_bot_tables.py
================================================
"""Add Discord bot tables

Revision ID: 8b5ce697290e
Revises: a1b2c3d4e5f7
Create Date: 2025-01-14

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8b5ce697290e"
down_revision = "a1b2c3d4e5f7"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # DiscordBotConfig (singleton table - one per tenant)
    op.create_table(
        "discord_bot_config",
        sa.Column(
            "id",
            sa.String(),
            primary_key=True,
            server_default=sa.text("'SINGLETON'"),
        ),
        sa.Column("bot_token", sa.LargeBinary(), nullable=False),  # EncryptedString
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.CheckConstraint("id = 'SINGLETON'", name="ck_discord_bot_config_singleton"),
    )

    # DiscordGuildConfig
    op.create_table(
        "discord_guild_config",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("guild_id", sa.BigInteger(), nullable=True, unique=True),
        sa.Column("guild_name", sa.String(), nullable=True),
        sa.Column("registration_key", sa.String(), nullable=False, unique=True),
        sa.Column("registered_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "default_persona_id",
            sa.Integer(),
            sa.ForeignKey("persona.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column(
            "enabled", sa.Boolean(), server_default=sa.text("true"), nullable=False
        ),
    )

    # DiscordChannelConfig
    op.create_table(
        "discord_channel_config",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column(
            "guild_config_id",
            sa.Integer(),
            sa.ForeignKey("discord_guild_config.id", ondelete="CASCADE"),
            nullable=False,
        ),
        sa.Column("channel_id", sa.BigInteger(), nullable=False),
        sa.Column("channel_name", sa.String(), nullable=False),
        sa.Column(
            "channel_type",
            sa.String(20),
            server_default=sa.text("'text'"),
            nullable=False,
        ),
        sa.Column(
            "is_private",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.Column(
            "thread_only_mode",
            sa.Boolean(),
            server_default=sa.text("false"),
            nullable=False,
        ),
        sa.Column(
            "require_bot_invocation",
            sa.Boolean(),
            server_default=sa.text("true"),
            nullable=False,
        ),
        sa.Column(
            "persona_override_id",
            sa.Integer(),
            sa.ForeignKey("persona.id", ondelete="SET NULL"),
            nullable=True,
        ),
        sa.Column(
            "enabled", sa.Boolean(), server_default=sa.text("false"), nullable=False
        ),
    )

    # Unique constraint: one config per channel per guild
    op.create_unique_constraint(
        "uq_discord_channel_guild_channel",
        "discord_channel_config",
        ["guild_config_id", "channel_id"],
    )


def downgrade() -> None:
    op.drop_table("discord_channel_config")
    op.drop_table("discord_guild_config")
    op.drop_table("discord_bot_config")

================================================
FILE: backend/alembic/versions/8e1ac4f39a9f_enable_contextual_retrieval.py
================================================
"""enable contextual retrieval

Revision ID: 8e1ac4f39a9f
Revises: 9aadf32dfeb4
Create Date: 2024-12-20 13:29:09.918661

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8e1ac4f39a9f"
down_revision = "9aadf32dfeb4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "search_settings",
        sa.Column(
            "enable_contextual_rag",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )
    op.add_column(
        "search_settings",
        sa.Column(
            "contextual_rag_llm_name",
            sa.String(),
            nullable=True,
        ),
    )
    op.add_column(
        "search_settings",
        sa.Column(
            "contextual_rag_llm_provider",
            sa.String(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("search_settings", "enable_contextual_rag")
    op.drop_column("search_settings", "contextual_rag_llm_name")
    op.drop_column("search_settings", "contextual_rag_llm_provider")

================================================
FILE: backend/alembic/versions/8e26726b7683_chat_context_addition.py
================================================
"""Chat Context Addition

Revision ID: 8e26726b7683
Revises: 5809c0787398
Create Date: 2023-09-13 18:34:31.327944

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8e26726b7683"
down_revision = "5809c0787398"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "persona",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("system_text", sa.Text(), nullable=True),
        sa.Column("tools_text", sa.Text(), nullable=True),
        sa.Column("hint_text", sa.Text(), nullable=True),
        sa.Column("default_persona", sa.Boolean(), nullable=False),
        sa.Column("deleted", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
    )
    op.add_column("chat_message", sa.Column("persona_id", sa.Integer(), nullable=True))
    op.create_foreign_key(
        "fk_chat_message_persona_id", "chat_message", "persona", ["persona_id"], ["id"]
    )


def downgrade() -> None:
    op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
    op.drop_column("chat_message", "persona_id")
    op.drop_table("persona")

================================================
FILE: backend/alembic/versions/8f43500ee275_add_index.py
================================================
"""add index

Revision ID: 8f43500ee275
Revises: da42808081e3
Create Date: 2025-02-24 17:35:33.072714

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "8f43500ee275"
down_revision = "da42808081e3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create a basic index on the lowercase message column for direct text matching
    # Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
    # op.execute(
    #     """
    #     CREATE INDEX idx_chat_message_message_lower
    #     ON chat_message (LOWER(substring(message, 1, 1500)))
    #     """
    # )
    pass


def downgrade() -> None:
    # Drop the index
    op.execute("DROP INDEX IF EXISTS idx_chat_message_message_lower;")

================================================
FILE: backend/alembic/versions/8ffcc2bcfc11_add_needs_persona_sync_to_user_file.py
================================================
"""add needs_persona_sync to user_file

Revision ID: 8ffcc2bcfc11
Revises: 7616121f6e97
Create Date: 2026-02-23 10:48:48.343826

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "8ffcc2bcfc11"
down_revision = "7616121f6e97"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user_file",
        sa.Column(
            "needs_persona_sync",
            sa.Boolean(),
            nullable=False,
            server_default=sa.text("false"),
        ),
    )


def downgrade() -> None:
    op.drop_column("user_file", "needs_persona_sync")

================================================
FILE: backend/alembic/versions/904451035c9b_store_tool_details.py
================================================
"""Store Tool Details

Revision ID: 904451035c9b
Revises: 3b25685ff73c
Create Date: 2023-10-05 12:29:26.620000

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "904451035c9b"
down_revision = "3b25685ff73c"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "persona",
        sa.Column("tools", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    )
    op.drop_column("persona", "tools_text")


def downgrade() -> None:
    op.add_column(
        "persona",
        sa.Column("tools_text", sa.TEXT(), autoincrement=False, nullable=True),
    )
    op.drop_column("persona", "tools")

================================================
FILE: backend/alembic/versions/904e5138fffb_tags.py
================================================
"""Tags

Revision ID: 904e5138fffb
Revises: 891cd83c87a8
Create Date: 2024-01-01 10:44:43.733974

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "904e5138fffb"
down_revision = "891cd83c87a8"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "tag",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("tag_key", sa.String(), nullable=False),
        sa.Column("tag_value", sa.String(), nullable=False),
        sa.Column("source", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "tag_key", "tag_value", "source", name="_tag_key_value_source_uc"
        ),
    )
    op.create_table(
        "document__tag",
        sa.Column("document_id", sa.String(), nullable=False),
        sa.Column("tag_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["document_id"],
            ["document.id"],
        ),
        sa.ForeignKeyConstraint(
            ["tag_id"],
            ["tag.id"],
        ),
        sa.PrimaryKeyConstraint("document_id", "tag_id"),
    )

    op.add_column(
        "search_doc",
        sa.Column(
            "doc_metadata",
            postgresql.JSONB(astext_type=sa.Text()),
            nullable=True,
        ),
    )
    op.execute("UPDATE search_doc SET doc_metadata = '{}' WHERE doc_metadata IS NULL")
    op.alter_column("search_doc", "doc_metadata", nullable=False)


def downgrade() -> None:
    op.drop_table("document__tag")
    op.drop_table("tag")
    op.drop_column("search_doc", "doc_metadata")
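In the `904e5138fffb` tags migration, the `_tag_key_value_source_uc` unique constraint on `tag` and the composite primary key on `document__tag` together make attaching a tag idempotent: re-attaching the same key/value/source to a document creates no new rows, and a second document reuses the existing `tag` row. The following is a minimal sketch of that effect. It uses Python's built-in `sqlite3` as a stand-in for Postgres, and the helpers `create_schema` and `attach_tag` are hypothetical illustrations, not Onyx code. `INSERT OR IGNORE` is SQLite syntax; the Postgres counterpart is `ON CONFLICT DO NOTHING`.

```python
import sqlite3


def create_schema(conn: sqlite3.Connection) -> None:
    """Create simplified SQLite versions of the `tag` and `document__tag` tables."""
    conn.execute(
        """
        CREATE TABLE tag (
            id INTEGER PRIMARY KEY,
            tag_key TEXT NOT NULL,
            tag_value TEXT NOT NULL,
            source TEXT NOT NULL,
            CONSTRAINT _tag_key_value_source_uc UNIQUE (tag_key, tag_value, source)
        )
        """
    )
    conn.execute(
        """
        CREATE TABLE document__tag (
            document_id TEXT NOT NULL,
            tag_id INTEGER NOT NULL REFERENCES tag (id),
            PRIMARY KEY (document_id, tag_id)
        )
        """
    )


def attach_tag(
    conn: sqlite3.Connection,
    document_id: str,
    tag_key: str,
    tag_value: str,
    source: str,
) -> int:
    """Hypothetical helper: get-or-create the tag row, then link it to the document."""
    # The unique constraint turns a duplicate tag insert into a no-op.
    conn.execute(
        "INSERT OR IGNORE INTO tag (tag_key, tag_value, source) VALUES (?, ?, ?)",
        (tag_key, tag_value, source),
    )
    (tag_id,) = conn.execute(
        "SELECT id FROM tag WHERE tag_key = ? AND tag_value = ? AND source = ?",
        (tag_key, tag_value, source),
    ).fetchone()
    # The composite primary key turns a duplicate link into a no-op.
    conn.execute(
        "INSERT OR IGNORE INTO document__tag (document_id, tag_id) VALUES (?, ?)",
        (document_id, tag_id),
    )
    return tag_id


conn = sqlite3.connect(":memory:")
create_schema(conn)
first = attach_tag(conn, "doc-1", "labels", "bug", "GITHUB")
again = attach_tag(conn, "doc-1", "labels", "bug", "GITHUB")  # adds no rows
other = attach_tag(conn, "doc-2", "labels", "bug", "GITHUB")  # reuses the tag row
```

After these three calls, `tag` holds one row and `document__tag` holds two. This sketch intentionally leaves out the `search_doc.doc_metadata` backfill that the same migration performs.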
================================================
FILE: backend/alembic/versions/9087b548dd69_seed_default_image_gen_config.py
================================================
"""seed_default_image_gen_config

Revision ID: 9087b548dd69
Revises: 2b90f3af54b8
Create Date: 2026-01-05 00:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9087b548dd69"
down_revision = "2b90f3af54b8"
branch_labels = None
depends_on = None

# Constants for default image generation config
# Source: web/src/app/admin/configuration/image-generation/constants.ts
IMAGE_PROVIDER_ID = "openai_gpt_image_1"
MODEL_NAME = "gpt-image-1"
PROVIDER_NAME = "openai"


def upgrade() -> None:
    conn = op.get_bind()

    # Check if image_generation_config table already has records
    existing_configs = (
        conn.execute(sa.text("SELECT COUNT(*) FROM image_generation_config")).scalar()
        or 0
    )

    if existing_configs > 0:
        # Skip if configs already exist - user may have configured manually
        return

    # Find the first OpenAI LLM provider
    openai_provider = conn.execute(
        sa.text(
            """
            SELECT id, api_key
            FROM llm_provider
            WHERE provider = :provider
            ORDER BY id
            LIMIT 1
            """
        ),
        {"provider": PROVIDER_NAME},
    ).fetchone()

    if not openai_provider:
        # No OpenAI provider found - nothing to do
        return

    source_provider_id, api_key = openai_provider

    # Create new LLM provider for image generation (clone only api_key)
    result = conn.execute(
        sa.text(
            """
            INSERT INTO llm_provider (
                name, provider, api_key, api_base, api_version,
                deployment_name, default_model_name, is_public,
                is_default_provider, is_default_vision_provider, is_auto_mode
            )
            VALUES (
                :name, :provider, :api_key, NULL, NULL,
                NULL, :default_model_name, :is_public,
                NULL, NULL, :is_auto_mode
            )
            RETURNING id
            """
        ),
        {
            "name": f"Image Gen - {IMAGE_PROVIDER_ID}",
            "provider": PROVIDER_NAME,
            "api_key": api_key,
            "default_model_name": MODEL_NAME,
            "is_public": True,
            "is_auto_mode": False,
        },
    )
    new_provider_id = result.scalar()

    # Create model configuration
    result = conn.execute(
        sa.text(
            """
            INSERT INTO model_configuration (
                llm_provider_id, name, is_visible, max_input_tokens,
                supports_image_input, display_name
            )
            VALUES (
                :llm_provider_id, :name, :is_visible, :max_input_tokens,
                :supports_image_input, :display_name
            )
            RETURNING id
            """
        ),
        {
            "llm_provider_id": new_provider_id,
            "name": MODEL_NAME,
            "is_visible": True,
            "max_input_tokens": None,
            "supports_image_input": False,
            "display_name": None,
        },
    )
    model_config_id = result.scalar()

    # Create image generation config
    conn.execute(
        sa.text(
            """
            INSERT INTO image_generation_config (
                image_provider_id, model_configuration_id, is_default
            )
            VALUES (
                :image_provider_id, :model_configuration_id, :is_default
            )
            """
        ),
        {
            "image_provider_id": IMAGE_PROVIDER_ID,
            "model_configuration_id": model_config_id,
            "is_default": True,
        },
    )


def downgrade() -> None:
    # We don't remove the config on downgrade since it's safe to keep around
    # If we upgrade again, it will be a no-op due to the existing records check
    pass

================================================
FILE: backend/alembic/versions/90b409d06e50_add_chat_compression_fields.py
================================================
"""add_chat_compression_fields

Revision ID: 90b409d06e50
Revises: f220515df7b4
Create Date: 2026-01-26 09:13:09.635427

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "90b409d06e50"
down_revision = "f220515df7b4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add last_summarized_message_id to chat_message
    # This field marks a message as a summary and indicates the last message it covers.
    # Summaries are branch-aware via their parent_message_id pointing to the branch.
op.add_column( "chat_message", sa.Column( "last_summarized_message_id", sa.Integer(), sa.ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True, ), ) def downgrade() -> None: op.drop_column("chat_message", "last_summarized_message_id") ================================================ FILE: backend/alembic/versions/90e3b9af7da4_tag_fix.py ================================================ """tag-fix Revision ID: 90e3b9af7da4 Revises: 62c3a055a141 Create Date: 2025-08-01 20:58:14.607624 """ import json import logging import os from typing import cast from typing import Generator from alembic import op import sqlalchemy as sa from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.db.search_settings import SearchSettings from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.constants import AuthType from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client logger = logging.getLogger("alembic.runtime.migration") # revision identifiers, used by Alembic. revision = "90e3b9af7da4" down_revision = "62c3a055a141" branch_labels = None depends_on = None SKIP_TAG_FIX = os.environ.get("SKIP_TAG_FIX", "true").lower() == "true" # override for cloud if AUTH_TYPE == AuthType.CLOUD: SKIP_TAG_FIX = True def set_is_list_for_known_tags() -> None: """ Sets is_list to true for all tags that are known to be lists.
""" LIST_METADATA: list[tuple[str, str]] = [ ("CLICKUP", "tags"), ("CONFLUENCE", "labels"), ("DISCOURSE", "tags"), ("FRESHDESK", "emails"), ("GITHUB", "assignees"), ("GITHUB", "labels"), ("GURU", "tags"), ("GURU", "folders"), ("HUBSPOT", "associated_contact_ids"), ("HUBSPOT", "associated_company_ids"), ("HUBSPOT", "associated_deal_ids"), ("HUBSPOT", "associated_ticket_ids"), ("JIRA", "labels"), ("MEDIAWIKI", "categories"), ("ZENDESK", "labels"), ("ZENDESK", "content_tags"), ] bind = op.get_bind() for source, key in LIST_METADATA: bind.execute( sa.text( f""" UPDATE tag SET is_list = true WHERE tag_key = '{key}' AND source = '{source}' """ ) ) def set_is_list_for_list_tags() -> None: """ Sets is_list to true for all tags which have multiple values for a given document, key, and source triplet. This only works if we remove old tags from the database. """ bind = op.get_bind() bind.execute( sa.text( """ UPDATE tag SET is_list = true FROM ( SELECT DISTINCT tag.tag_key, tag.source FROM tag JOIN document__tag ON tag.id = document__tag.tag_id GROUP BY tag.tag_key, tag.source, document__tag.document_id HAVING count(*) > 1 ) AS list_tags WHERE tag.tag_key = list_tags.tag_key AND tag.source = list_tags.source """ ) ) def log_list_tags() -> None: bind = op.get_bind() result = bind.execute( sa.text( """ SELECT DISTINCT source, tag_key FROM tag WHERE is_list ORDER BY source, tag_key """ ) ).fetchall() logger.info( "List tags:\n" + "\n".join(f" {source}: {key}" for source, key in result) ) def remove_old_tags() -> None: """ Removes old tags from the database. Previously, there was a bug where if a document got indexed with a tag and then the document got reindexed, the old tag would not be removed. This function removes those old tags by comparing them against the tags in Vespa.
""" current_search_settings, _ = active_search_settings() # Get the index name if hasattr(current_search_settings, "index_name"): index_name = current_search_settings.index_name else: # Fall back to the default index name if the search settings don't provide one index_name = "danswer_index" for batch in _get_batch_documents_with_multiple_tags(): n_deleted = 0 for document_id in batch: true_metadata = _get_vespa_metadata(document_id, index_name) tags = _get_document_tags(document_id) # identify document__tags to delete to_delete: list[str] = [] for tag_id, tag_key, tag_value in tags: true_val = true_metadata.get(tag_key, "") if (isinstance(true_val, list) and tag_value not in true_val) or ( isinstance(true_val, str) and tag_value != true_val ): to_delete.append(str(tag_id)) if not to_delete: continue # delete old document__tags bind = op.get_bind() result = bind.execute( sa.text( f""" DELETE FROM document__tag WHERE document_id = '{document_id}' AND tag_id IN ({",".join(to_delete)}) """ ) ) n_deleted += result.rowcount logger.info(f"Processed {len(batch)} documents and deleted {n_deleted} tags") def active_search_settings() -> tuple[SearchSettings, SearchSettings | None]: result = op.get_bind().execute( sa.text( """ SELECT * FROM search_settings WHERE status = 'PRESENT' ORDER BY id DESC LIMIT 1 """ ) ) search_settings_fetch = result.fetchall() search_settings = ( SearchSettings(**search_settings_fetch[0]._asdict()) if search_settings_fetch else None ) result2 = op.get_bind().execute( sa.text( """ SELECT * FROM search_settings WHERE status = 'FUTURE' ORDER BY id DESC LIMIT 1 """ ) ) search_settings_future_fetch = result2.fetchall() search_settings_future = ( SearchSettings(**search_settings_future_fetch[0]._asdict()) if search_settings_future_fetch else None ) if not isinstance(search_settings, SearchSettings): raise RuntimeError( "current search settings is of type " + str(type(search_settings)) ) if ( not isinstance(search_settings_future, SearchSettings) and
search_settings_future is not None ): raise RuntimeError( "future search settings is of type " + str(type(search_settings_future)) ) return search_settings, search_settings_future def _get_batch_documents_with_multiple_tags( batch_size: int = 128, ) -> Generator[list[str], None, None]: """ Yields batches of document IDs that have a one-to-many tag, i.e. more than one tag with the same key and source. A document matches either because it has a list metadata value or because it still carries leftover old tags from reindexing. """ offset_clause = "" bind = op.get_bind() while True: batch = bind.execute( sa.text( f""" SELECT DISTINCT document__tag.document_id FROM tag JOIN document__tag ON tag.id = document__tag.tag_id GROUP BY tag.tag_key, tag.source, document__tag.document_id HAVING count(*) > 1 {offset_clause} ORDER BY document__tag.document_id LIMIT {batch_size} """ ) ).fetchall() if not batch: break doc_ids = [document_id for (document_id,) in batch] yield doc_ids offset_clause = f"AND document__tag.document_id > '{doc_ids[-1]}'" def _get_vespa_metadata( document_id: str, index_name: str ) -> dict[str, str | list[str]]: url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) # Document-Selector language selection = ( f"{index_name}.document_id=='{document_id}' and {index_name}.chunk_id==0" ) params: dict[str, str | int] = { "selection": selection, "wantedDocumentCount": 1, "fieldSet": f"{index_name}:metadata", } with get_vespa_http_client() as client: resp = client.get(url, params=params) resp.raise_for_status() docs = resp.json().get("documents", []) if not docs: raise RuntimeError(f"No chunk-0 found for document {document_id}") # the metadata field is returned as a JSON-encoded string, so decode it metadata = docs[0]["fields"]["metadata"] return json.loads(metadata) def _get_document_tags(document_id: str) -> list[tuple[int, str, str]]: bind = op.get_bind() result = bind.execute( sa.text( f""" SELECT tag.id, tag.tag_key, tag.tag_value FROM tag JOIN 
document__tag ON tag.id = document__tag.tag_id WHERE document__tag.document_id = '{document_id}' """ )
).fetchall() return cast(list[tuple[int, str, str]], result) def upgrade() -> None: op.add_column( "tag", sa.Column("is_list", sa.Boolean(), nullable=False, server_default="false"), ) op.drop_constraint( constraint_name="_tag_key_value_source_uc", table_name="tag", type_="unique", ) op.create_unique_constraint( constraint_name="_tag_key_value_source_list_uc", table_name="tag", columns=["tag_key", "tag_value", "source", "is_list"], ) set_is_list_for_known_tags() if SKIP_TAG_FIX: logger.warning( "Skipping removal of old tags. " "This can cause issues when using the knowledge graph, or " "when filtering for documents by tags." ) log_list_tags() return remove_old_tags() set_is_list_for_list_tags() # debug log_list_tags() def downgrade() -> None: # the migration adds and populates the is_list column, and removes old bugged tags # there isn't a point in adding back the bugged tags, so we just drop the column op.drop_constraint( constraint_name="_tag_key_value_source_list_uc", table_name="tag", type_="unique", ) op.create_unique_constraint( constraint_name="_tag_key_value_source_uc", table_name="tag", columns=["tag_key", "tag_value", "source"], ) op.drop_column("tag", "is_list") ================================================ FILE: backend/alembic/versions/91a0a4d62b14_milestone.py ================================================ """Milestone Revision ID: 91a0a4d62b14 Revises: dab04867cd88 Create Date: 2024-12-13 19:03:30.947551 """ from alembic import op import sqlalchemy as sa import fastapi_users_db_sqlalchemy from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "91a0a4d62b14" down_revision = "dab04867cd88" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "milestone", sa.Column("id", sa.UUID(), nullable=False), sa.Column("tenant_id", sa.String(), nullable=True), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column("event_type", sa.String(), nullable=False), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("event_tracker", postgresql.JSONB(), nullable=True), sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("event_type", name="uq_milestone_event_type"), ) def downgrade() -> None: op.drop_table("milestone") ================================================ FILE: backend/alembic/versions/91fd3b470d1a_remove_documentsource_from_tag.py ================================================ """Remove DocumentSource from Tag Revision ID: 91fd3b470d1a Revises: 173cae5bba26 Create Date: 2024-03-21 12:05:23.956734 """ from alembic import op import sqlalchemy as sa from onyx.configs.constants import DocumentSource # revision identifiers, used by Alembic. 
revision = "91fd3b470d1a" down_revision = "173cae5bba26" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.alter_column( "tag", "source", type_=sa.String(length=50), existing_type=sa.Enum(DocumentSource, native_enum=False), existing_nullable=False, ) def downgrade() -> None: op.alter_column( "tag", "source", type_=sa.Enum(DocumentSource, native_enum=False), existing_type=sa.String(length=50), existing_nullable=False, ) ================================================ FILE: backend/alembic/versions/91ffac7e65b3_add_expiry_time.py ================================================ """add expiry time Revision ID: 91ffac7e65b3 Revises: 795b20b85b4b Create Date: 2024-06-24 09:39:56.462242 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "91ffac7e65b3" down_revision = "795b20b85b4b" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "user", sa.Column("oidc_expiry", sa.DateTime(timezone=True), nullable=True) ) def downgrade() -> None: op.drop_column("user", "oidc_expiry") ================================================ FILE: backend/alembic/versions/93560ba1b118_add_web_ui_option_to_slack_config.py ================================================ """add web ui option to slack config Revision ID: 93560ba1b118 Revises: 6d562f86c78b Create Date: 2024-11-24 06:36:17.490612 """ from alembic import op # revision identifiers, used by Alembic. revision = "93560ba1b118" down_revision = "6d562f86c78b" branch_labels = None depends_on = None def upgrade() -> None: # Add show_continue_in_web_ui with default False to all existing channel_configs op.execute( """ UPDATE slack_channel_config SET channel_config = channel_config || '{"show_continue_in_web_ui": false}'::jsonb WHERE NOT channel_config ?
'show_continue_in_web_ui' """ ) def downgrade() -> None: # Remove show_continue_in_web_ui from all channel_configs op.execute( """ UPDATE slack_channel_config SET channel_config = channel_config - 'show_continue_in_web_ui' """ ) ================================================ FILE: backend/alembic/versions/93a2e195e25c_add_voice_provider_and_user_voice_prefs.py ================================================ """add_voice_provider_and_user_voice_prefs Revision ID: 93a2e195e25c Revises: 27fb147a843f Create Date: 2026-02-23 15:16:39.507304 """ from alembic import op import sqlalchemy as sa from sqlalchemy import column from sqlalchemy import true from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "93a2e195e25c" down_revision = "27fb147a843f" branch_labels = None depends_on = None def upgrade() -> None: # Create voice_provider table op.create_table( "voice_provider", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("name", sa.String(), unique=True, nullable=False), sa.Column("provider_type", sa.String(), nullable=False), sa.Column("api_key", sa.LargeBinary(), nullable=True), sa.Column("api_base", sa.String(), nullable=True), sa.Column("custom_config", postgresql.JSONB(), nullable=True), sa.Column("stt_model", sa.String(), nullable=True), sa.Column("tts_model", sa.String(), nullable=True), sa.Column("default_voice", sa.String(), nullable=True), sa.Column( "is_default_stt", sa.Boolean(), nullable=False, server_default="false" ), sa.Column( "is_default_tts", sa.Boolean(), nullable=False, server_default="false" ), sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.func.now(), onupdate=sa.func.now(), nullable=False, ), ) # Add partial unique indexes to enforce only one default STT/TTS provider op.create_index( 
"ix_voice_provider_one_default_stt", "voice_provider", ["is_default_stt"], unique=True, postgresql_where=column("is_default_stt") == true(), ) op.create_index( "ix_voice_provider_one_default_tts", "voice_provider", ["is_default_tts"], unique=True, postgresql_where=column("is_default_tts") == true(), ) # Add voice preference columns to user table op.add_column( "user", sa.Column( "voice_auto_send", sa.Boolean(), default=False, nullable=False, server_default="false", ), ) op.add_column( "user", sa.Column( "voice_auto_playback", sa.Boolean(), default=False, nullable=False, server_default="false", ), ) op.add_column( "user", sa.Column( "voice_playback_speed", sa.Float(), default=1.0, nullable=False, server_default="1.0", ), ) def downgrade() -> None: # Remove user voice preference columns op.drop_column("user", "voice_playback_speed") op.drop_column("user", "voice_auto_playback") op.drop_column("user", "voice_auto_send") op.drop_index("ix_voice_provider_one_default_tts", table_name="voice_provider") op.drop_index("ix_voice_provider_one_default_stt", table_name="voice_provider") # Drop voice_provider table op.drop_table("voice_provider") ================================================ FILE: backend/alembic/versions/93c15d6a6fbb_add_chunk_error_and_vespa_count_columns_.py ================================================ """add chunk error and vespa count columns to opensearch tenant migration Revision ID: 93c15d6a6fbb Revises: d3fd499c829c Create Date: 2026-02-11 23:07:34.576725 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "93c15d6a6fbb" down_revision = "d3fd499c829c" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "opensearch_tenant_migration_record", sa.Column( "total_chunks_errored", sa.Integer(), nullable=False, server_default="0", ), ) op.add_column( "opensearch_tenant_migration_record", sa.Column( "total_chunks_in_vespa", sa.Integer(), nullable=False, server_default="0", ), ) def downgrade() -> None: op.drop_column("opensearch_tenant_migration_record", "total_chunks_in_vespa") op.drop_column("opensearch_tenant_migration_record", "total_chunks_errored") ================================================ FILE: backend/alembic/versions/949b4a92a401_remove_rt.py ================================================ """remove rt Revision ID: 949b4a92a401 Revises: 1b10e1fda030 Create Date: 2024-10-26 13:06:06.937969 """ from alembic import op from sqlalchemy.orm import Session from sqlalchemy import text # ORM models used to delete RequestTracker connector data from onyx.db.models import ( Connector, ConnectorCredentialPair, Credential, IndexAttempt, ) # revision identifiers, used by Alembic.
revision = "949b4a92a401" down_revision = "1b10e1fda030" branch_labels = None depends_on = None def upgrade() -> None: # Deletes all RequestTracker connectors and associated data bind = op.get_bind() session = Session(bind=bind) # Get connectors using raw SQL result = bind.execute( text("SELECT id FROM connector WHERE source = 'requesttracker'") ) connector_ids = [row[0] for row in result] if connector_ids: cc_pairs_to_delete = ( session.query(ConnectorCredentialPair) .filter(ConnectorCredentialPair.connector_id.in_(connector_ids)) .all() ) cc_pair_ids = [cc_pair.id for cc_pair in cc_pairs_to_delete] if cc_pair_ids: session.query(IndexAttempt).filter( IndexAttempt.connector_credential_pair_id.in_(cc_pair_ids) ).delete(synchronize_session=False) session.query(ConnectorCredentialPair).filter( ConnectorCredentialPair.id.in_(cc_pair_ids) ).delete(synchronize_session=False) credential_ids = [cc_pair.credential_id for cc_pair in cc_pairs_to_delete] if credential_ids: session.query(Credential).filter(Credential.id.in_(credential_ids)).delete( synchronize_session=False ) session.query(Connector).filter(Connector.id.in_(connector_ids)).delete( synchronize_session=False ) session.commit() def downgrade() -> None: # No-op downgrade as we cannot restore deleted data pass ================================================ FILE: backend/alembic/versions/94dc3d0236f8_make_document_set_description_optional.py ================================================ """make document set description optional Revision ID: 94dc3d0236f8 Revises: bf7a81109301 Create Date: 2024-12-11 11:26:10.616722 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "94dc3d0236f8" down_revision = "bf7a81109301" branch_labels = None depends_on = None def upgrade() -> None: # Make document_set.description column nullable op.alter_column( "document_set", "description", existing_type=sa.String(), nullable=True ) def downgrade() -> None: # Revert document_set.description column to non-nullable op.alter_column( "document_set", "description", existing_type=sa.String(), nullable=False ) ================================================ FILE: backend/alembic/versions/96a5702df6aa_mcp_tool_enabled.py ================================================ """mcp_tool_enabled Revision ID: 96a5702df6aa Revises: 40926a4dab77 Create Date: 2025-10-09 12:10:21.733097 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "96a5702df6aa" down_revision = "40926a4dab77" branch_labels = None depends_on = None DELETE_DISABLED_TOOLS_SQL = "DELETE FROM tool WHERE enabled = false" def upgrade() -> None: op.add_column( "tool", sa.Column( "enabled", sa.Boolean(), nullable=False, server_default=sa.true(), ), ) op.create_index( "ix_tool_mcp_server_enabled", "tool", ["mcp_server_id", "enabled"], ) # Remove the server default so application controls defaulting op.alter_column("tool", "enabled", server_default=None) def downgrade() -> None: op.execute(DELETE_DISABLED_TOOLS_SQL) op.drop_index("ix_tool_mcp_server_enabled", table_name="tool") op.drop_column("tool", "enabled") ================================================ FILE: backend/alembic/versions/977e834c1427_seed_default_groups.py ================================================ """seed_default_groups Revision ID: 977e834c1427 Revises: 8188861f4e92 Create Date: 2026-03-25 14:59:41.313091 """ from typing import Any from alembic import op import sqlalchemy as sa from sqlalchemy.dialects.postgresql import insert as pg_insert # revision identifiers, used by Alembic. 
revision = "977e834c1427" down_revision = "8188861f4e92" branch_labels = None depends_on = None # (group_name, permission_value) DEFAULT_GROUPS = [ ("Admin", "admin"), ("Basic", "basic"), ] CUSTOM_SUFFIX = "(Custom)" MAX_RENAME_ATTEMPTS = 100 # Lightweight table constructs for use in DML (defined inline, not reflected from the database) user_group_table = sa.table( "user_group", sa.column("id", sa.Integer), sa.column("name", sa.String), sa.column("is_up_to_date", sa.Boolean), sa.column("is_up_for_deletion", sa.Boolean), sa.column("is_default", sa.Boolean), ) permission_grant_table = sa.table( "permission_grant", sa.column("group_id", sa.Integer), sa.column("permission", sa.String), sa.column("grant_source", sa.String), ) user__user_group_table = sa.table( "user__user_group", sa.column("user_group_id", sa.Integer), sa.column("user_id", sa.Uuid), ) def _find_available_name(conn: sa.engine.Connection, base: str) -> str: """Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken.""" candidate = f"{base} {CUSTOM_SUFFIX}" attempt = 1 while attempt <= MAX_RENAME_ATTEMPTS: exists: Any = conn.execute( sa.select(sa.literal(1)) .select_from(user_group_table) .where(user_group_table.c.name == candidate) .limit(1) ).fetchone() if exists is None: return candidate attempt += 1 candidate = f"{base} (Custom {attempt})" raise RuntimeError( f"Could not find an available name for group '{base}' " f"after {MAX_RENAME_ATTEMPTS} attempts" ) def upgrade() -> None: conn = op.get_bind() for group_name, permission_value in DEFAULT_GROUPS: # Step 1: Rename ALL existing groups that clash with the canonical name. conflicting = conn.execute( sa.select(user_group_table.c.id, user_group_table.c.name).where( user_group_table.c.name == group_name ) ).fetchall() for row_id, row_name in conflicting: new_name = _find_available_name(conn, row_name) op.execute( sa.update(user_group_table) .where(user_group_table.c.id == row_id) .values(name=new_name, 
is_up_to_date=False) ) # Step 2: Create a fresh default group.
result = conn.execute( user_group_table.insert() .values( name=group_name, is_up_to_date=True, is_up_for_deletion=False, is_default=True, ) .returning(user_group_table.c.id) ).fetchone() assert result is not None group_id = result[0] # Step 3: Insert the permission grant; do nothing if it already exists. op.execute( pg_insert(permission_grant_table) .values( group_id=group_id, permission=permission_value, grant_source="SYSTEM", ) .on_conflict_do_nothing(index_elements=["group_id", "permission"]) ) def downgrade() -> None: # Remove the default groups created by this migration. # First remove user-group memberships that reference default groups # to avoid FK violations, then delete the groups themselves. default_group_ids = sa.select(user_group_table.c.id).where( user_group_table.c.is_default == True # noqa: E712 ) conn = op.get_bind() conn.execute( sa.delete(user__user_group_table).where( user__user_group_table.c.user_group_id.in_(default_group_ids) ) ) conn.execute( sa.delete(user_group_table).where( user_group_table.c.is_default == True # noqa: E712 ) ) ================================================ FILE: backend/alembic/versions/97dbb53fa8c8_add_syncrecord.py ================================================ """Add SyncRecord Revision ID: 97dbb53fa8c8 Revises: be2ab2aa50ee Create Date: 2025-01-11 19:39:50.426302 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "97dbb53fa8c8" down_revision = "be2ab2aa50ee" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "sync_record", sa.Column("id", sa.Integer(), nullable=False), sa.Column("entity_id", sa.Integer(), nullable=False), sa.Column( "sync_type", sa.Enum( "DOCUMENT_SET", "USER_GROUP", "CONNECTOR_DELETION", name="synctype", native_enum=False, length=40, ), nullable=False, ), sa.Column( "sync_status", sa.Enum( "IN_PROGRESS", "SUCCESS", "FAILED", "CANCELED", name="syncstatus", native_enum=False, length=40, ), nullable=False, ), sa.Column("num_docs_synced", sa.Integer(), nullable=False), sa.Column("sync_start_time", sa.DateTime(timezone=True), nullable=False), sa.Column("sync_end_time", sa.DateTime(timezone=True), nullable=True), sa.PrimaryKeyConstraint("id"), ) # Add index for fetch_latest_sync_record query op.create_index( "ix_sync_record_entity_id_sync_type_sync_start_time", "sync_record", ["entity_id", "sync_type", "sync_start_time"], ) # Add index for cleanup_sync_records query op.create_index( "ix_sync_record_entity_id_sync_type_sync_status", "sync_record", ["entity_id", "sync_type", "sync_status"], ) def downgrade() -> None: op.drop_index("ix_sync_record_entity_id_sync_type_sync_status") op.drop_index("ix_sync_record_entity_id_sync_type_sync_start_time") op.drop_table("sync_record") ================================================ FILE: backend/alembic/versions/98a5008d8711_agent_tracking.py ================================================ """agent_tracking Revision ID: 98a5008d8711 Revises: 2f80c6a2550f Create Date: 2025-01-29 17:00:00.000001 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import UUID # revision identifiers, used by Alembic. 
revision = "98a5008d8711" down_revision = "2f80c6a2550f" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "agent__search_metrics", sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True), sa.Column("persona_id", sa.Integer(), nullable=True), sa.Column("agent_type", sa.String(), nullable=False), sa.Column("start_time", sa.DateTime(timezone=True), nullable=False), sa.Column("base_duration_s", sa.Float(), nullable=False), sa.Column("full_duration_s", sa.Float(), nullable=False), sa.Column("base_metrics", postgresql.JSONB(), nullable=True), sa.Column("refined_metrics", postgresql.JSONB(), nullable=True), sa.Column("all_metrics", postgresql.JSONB(), nullable=True), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) # Create sub_question table op.create_table( "agent__sub_question", sa.Column("id", sa.Integer, primary_key=True), sa.Column("primary_question_id", sa.Integer, sa.ForeignKey("chat_message.id")), sa.Column( "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id") ), sa.Column("sub_question", sa.Text), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.func.now() ), sa.Column("sub_answer", sa.Text), sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=True), sa.Column("level", sa.Integer(), nullable=False), sa.Column("level_question_num", sa.Integer(), nullable=False), ) # Create sub_query table op.create_table( "agent__sub_query", sa.Column("id", sa.Integer, primary_key=True), sa.Column( "parent_question_id", sa.Integer, sa.ForeignKey("agent__sub_question.id") ), sa.Column( "chat_session_id", UUID(as_uuid=True), sa.ForeignKey("chat_session.id") ), sa.Column("sub_query", sa.Text), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.func.now() ), ) # Create sub_query__search_doc association 
table op.create_table( "agent__sub_query__search_doc", sa.Column( "sub_query_id", sa.Integer, sa.ForeignKey("agent__sub_query.id"), primary_key=True, ), sa.Column( "search_doc_id", sa.Integer, sa.ForeignKey("search_doc.id"), primary_key=True, ), ) op.add_column( "chat_message", sa.Column( "refined_answer_improvement", sa.Boolean(), nullable=True, ), ) def downgrade() -> None: op.drop_column("chat_message", "refined_answer_improvement") op.drop_table("agent__sub_query__search_doc") op.drop_table("agent__sub_query") op.drop_table("agent__sub_question") op.drop_table("agent__search_metrics") ================================================ FILE: backend/alembic/versions/9a0296d7421e_add_is_auto_mode_to_llm_provider.py ================================================ """add_is_auto_mode_to_llm_provider Revision ID: 9a0296d7421e Revises: 7206234e012a Create Date: 2025-12-17 18:14:29.620981 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "9a0296d7421e" down_revision = "7206234e012a" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "llm_provider", sa.Column( "is_auto_mode", sa.Boolean(), nullable=False, server_default="false", ), ) def downgrade() -> None: op.drop_column("llm_provider", "is_auto_mode") ================================================ FILE: backend/alembic/versions/9aadf32dfeb4_add_user_files.py ================================================ """add user files Revision ID: 9aadf32dfeb4 Revises: 3781a5eb12cb Create Date: 2025-01-26 16:08:21.551022 """ import sqlalchemy as sa import datetime from alembic import op # revision identifiers, used by Alembic. 
revision = "9aadf32dfeb4" down_revision = "3781a5eb12cb" branch_labels = None depends_on = None def upgrade() -> None: # Create user_folder table without parent_id op.create_table( "user_folder", sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True), sa.Column("name", sa.String(length=255), nullable=True), sa.Column("description", sa.String(length=255), nullable=True), sa.Column("display_priority", sa.Integer(), nullable=True, default=0), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.func.now() ), ) # Create user_file table with folder_id instead of parent_folder_id op.create_table( "user_file", sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), sa.Column("user_id", sa.UUID(), sa.ForeignKey("user.id"), nullable=True), sa.Column( "folder_id", sa.Integer(), sa.ForeignKey("user_folder.id"), nullable=True, ), sa.Column("link_url", sa.String(), nullable=True), sa.Column("token_count", sa.Integer(), nullable=True), sa.Column("file_type", sa.String(), nullable=True), sa.Column("file_id", sa.String(length=255), nullable=False), sa.Column("document_id", sa.String(length=255), nullable=False), sa.Column("name", sa.String(length=255), nullable=False), sa.Column( "created_at", sa.DateTime(), default=datetime.datetime.utcnow, ), sa.Column( "cc_pair_id", sa.Integer(), sa.ForeignKey("connector_credential_pair.id"), nullable=True, unique=True, ), ) # Create persona__user_file table op.create_table( "persona__user_file", sa.Column( "persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True ), sa.Column( "user_file_id", sa.Integer(), sa.ForeignKey("user_file.id"), primary_key=True, ), ) # Create persona__user_folder table op.create_table( "persona__user_folder", sa.Column( "persona_id", sa.Integer(), sa.ForeignKey("persona.id"), primary_key=True ), sa.Column( "user_folder_id", sa.Integer(), sa.ForeignKey("user_folder.id"), 
primary_key=True, ), ) op.add_column( "connector_credential_pair", sa.Column("is_user_file", sa.Boolean(), nullable=True, default=False), ) # Update existing records to have is_user_file=False instead of NULL op.execute( "UPDATE connector_credential_pair SET is_user_file = FALSE WHERE is_user_file IS NULL" ) def downgrade() -> None: op.drop_column("connector_credential_pair", "is_user_file") # Drop the persona__user_folder table op.drop_table("persona__user_folder") # Drop the persona__user_file table op.drop_table("persona__user_file") # Drop the user_file table op.drop_table("user_file") # Drop the user_folder table op.drop_table("user_folder") ================================================ FILE: backend/alembic/versions/9b66d3156fc6_user_file_schema_additions.py ================================================ """Migration 1: User file schema additions Revision ID: 9b66d3156fc6 Revises: b4ef3ae0bf6e Create Date: 2025-09-22 09:42:06.086732 This migration adds new columns and tables without modifying existing data. It is safe to run and can be easily rolled back. """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql as psql import logging logger = logging.getLogger("alembic.runtime.migration") # revision identifiers, used by Alembic. 
revision = "9b66d3156fc6"
down_revision = "b4ef3ae0bf6e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """Add new columns and tables without modifying existing data."""

    # Enable pgcrypto for UUID generation
    op.execute("CREATE EXTENSION IF NOT EXISTS pgcrypto")

    bind = op.get_bind()
    inspector = sa.inspect(bind)

    # === USER_FILE: Add new columns ===
    logger.info("Adding new columns to user_file table...")

    user_file_columns = [col["name"] for col in inspector.get_columns("user_file")]

    # Check if ID is already UUID (in case of re-run after partial migration)
    id_is_uuid = any(
        col["name"] == "id" and "uuid" in str(col["type"]).lower()
        for col in inspector.get_columns("user_file")
    )

    # Add transitional UUID column only if ID is not already UUID
    if "new_id" not in user_file_columns and not id_is_uuid:
        op.add_column(
            "user_file",
            sa.Column(
                "new_id",
                psql.UUID(as_uuid=True),
                nullable=True,
                server_default=sa.text("gen_random_uuid()"),
            ),
        )
        op.create_unique_constraint("uq_user_file_new_id", "user_file", ["new_id"])
        logger.info("Added new_id column to user_file")

    # Add status column
    if "status" not in user_file_columns:
        op.add_column(
            "user_file",
            sa.Column(
                "status",
                sa.Enum(
                    "PROCESSING",
                    "COMPLETED",
                    "FAILED",
                    "CANCELED",
                    name="userfilestatus",
                    native_enum=False,
                ),
                nullable=False,
                server_default="PROCESSING",
            ),
        )
        logger.info("Added status column to user_file")

    # Add other tracking columns
    if "chunk_count" not in user_file_columns:
        op.add_column(
            "user_file", sa.Column("chunk_count", sa.Integer(), nullable=True)
        )
        logger.info("Added chunk_count column to user_file")

    if "last_accessed_at" not in user_file_columns:
        op.add_column(
            "user_file",
            sa.Column("last_accessed_at", sa.DateTime(timezone=True), nullable=True),
        )
        logger.info("Added last_accessed_at column to user_file")

    if "needs_project_sync" not in user_file_columns:
        op.add_column(
            "user_file",
            sa.Column(
                "needs_project_sync",
                sa.Boolean(),
                nullable=False,
                server_default=sa.text("false"),
            ),
        )
        logger.info("Added needs_project_sync column to user_file")

    if "last_project_sync_at" not in user_file_columns:
        op.add_column(
            "user_file",
            sa.Column(
                "last_project_sync_at", sa.DateTime(timezone=True), nullable=True
            ),
        )
        logger.info("Added last_project_sync_at column to user_file")

    if "document_id_migrated" not in user_file_columns:
        op.add_column(
            "user_file",
            sa.Column(
                "document_id_migrated",
                sa.Boolean(),
                nullable=False,
                server_default=sa.text("true"),
            ),
        )
        logger.info("Added document_id_migrated column to user_file")

    # === USER_FOLDER -> USER_PROJECT rename ===
    table_names = set(inspector.get_table_names())

    if "user_folder" in table_names:
        logger.info("Updating user_folder table...")
        # Make description nullable first
        op.alter_column("user_folder", "description", nullable=True)

        # Rename table if user_project doesn't exist
        if "user_project" not in table_names:
            op.execute("ALTER TABLE user_folder RENAME TO user_project")
            logger.info("Renamed user_folder to user_project")
    elif "user_project" in table_names:
        # If already renamed, ensure column nullability
        project_cols = [col["name"] for col in inspector.get_columns("user_project")]
        if "description" in project_cols:
            op.alter_column("user_project", "description", nullable=True)

    # Add instructions column to user_project
    inspector = sa.inspect(bind)  # Refresh after rename
    if "user_project" in inspector.get_table_names():
        project_columns = [col["name"] for col in inspector.get_columns("user_project")]
        if "instructions" not in project_columns:
            op.add_column(
                "user_project",
                sa.Column("instructions", sa.String(), nullable=True),
            )
            logger.info("Added instructions column to user_project")

    # === CHAT_SESSION: Add project_id ===
    chat_session_columns = [
        col["name"] for col in inspector.get_columns("chat_session")
    ]
    if "project_id" not in chat_session_columns:
        op.add_column(
            "chat_session",
            sa.Column("project_id", sa.Integer(), nullable=True),
        )
        logger.info("Added project_id column to chat_session")

    # === PERSONA__USER_FILE: Add UUID column ===
    persona_user_file_columns = [
        col["name"] for col in inspector.get_columns("persona__user_file")
    ]
    if "user_file_id_uuid" not in persona_user_file_columns:
        op.add_column(
            "persona__user_file",
            sa.Column("user_file_id_uuid", psql.UUID(as_uuid=True), nullable=True),
        )
        logger.info("Added user_file_id_uuid column to persona__user_file")

    # === PROJECT__USER_FILE: Create new table ===
    if "project__user_file" not in inspector.get_table_names():
        op.create_table(
            "project__user_file",
            sa.Column("project_id", sa.Integer(), nullable=False),
            sa.Column("user_file_id", psql.UUID(as_uuid=True), nullable=False),
            sa.PrimaryKeyConstraint("project_id", "user_file_id"),
        )
        logger.info("Created project__user_file table")

    # Only create the index if it doesn't exist
    existing_indexes = [
        ix["name"] for ix in inspector.get_indexes("project__user_file")
    ]
    if "idx_project__user_file_user_file_id" not in existing_indexes:
        op.create_index(
            "idx_project__user_file_user_file_id",
            "project__user_file",
            ["user_file_id"],
        )
        logger.info(
            "Created index idx_project__user_file_user_file_id on project__user_file"
        )

    logger.info("Migration 1 (schema additions) completed successfully")


def downgrade() -> None:
    """Remove added columns and tables."""

    bind = op.get_bind()
    inspector = sa.inspect(bind)

    logger.info("Starting downgrade of schema additions...")

    # Drop project__user_file table
    if "project__user_file" in inspector.get_table_names():
        # op.drop_index("idx_project__user_file_user_file_id", "project__user_file")
        op.drop_table("project__user_file")
        logger.info("Dropped project__user_file table")

    # Remove columns from persona__user_file
    if "persona__user_file" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("persona__user_file")]
        if "user_file_id_uuid" in columns:
            op.drop_column("persona__user_file", "user_file_id_uuid")
            logger.info("Dropped user_file_id_uuid from persona__user_file")

    # Remove columns from chat_session
    if "chat_session" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("chat_session")]
        if "project_id" in columns:
            op.drop_column("chat_session", "project_id")
            logger.info("Dropped project_id from chat_session")

    # Rename user_project back to user_folder and remove instructions
    if "user_project" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("user_project")]
        if "instructions" in columns:
            op.drop_column("user_project", "instructions")
        op.execute("ALTER TABLE user_project RENAME TO user_folder")
        # Update NULL descriptions to empty string before setting NOT NULL constraint
        op.execute("UPDATE user_folder SET description = '' WHERE description IS NULL")
        op.alter_column("user_folder", "description", nullable=False)
        logger.info("Renamed user_project back to user_folder")

    # Remove columns from user_file
    if "user_file" in inspector.get_table_names():
        columns = [col["name"] for col in inspector.get_columns("user_file")]

        columns_to_drop = [
            "document_id_migrated",
            "last_project_sync_at",
            "needs_project_sync",
            "last_accessed_at",
            "chunk_count",
            "status",
        ]

        for col in columns_to_drop:
            if col in columns:
                op.drop_column("user_file", col)
                logger.info(f"Dropped {col} from user_file")

        if "new_id" in columns:
            op.drop_constraint("uq_user_file_new_id", "user_file", type_="unique")
            op.drop_column("user_file", "new_id")
            logger.info("Dropped new_id from user_file")

    # Drop enum type if no columns use it
    bind.execute(sa.text("DROP TYPE IF EXISTS userfilestatus"))

    logger.info("Downgrade completed successfully")

================================================
FILE: backend/alembic/versions/9c00a2bccb83_chat_message_agentic.py
================================================
"""chat_message_agentic

Revision ID: 9c00a2bccb83
Revises: b7a7eee5aa15
Create Date: 2025-02-17 11:15:43.081150

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
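# Illustrative sketch only (not called by this migration): the upgrade below
# backfills in three steps: add the column as nullable, fill it with a
# correlated EXISTS subquery, then tighten it to NOT NULL. The hypothetical
# helper _sketch_backfill_is_agentic replays the first two steps on an
# in-memory SQLite database using the stdlib sqlite3 module. SQLite cannot
# ALTER an existing column to NOT NULL, so the third step is omitted.
import sqlite3


def _sketch_backfill_is_agentic() -> list[tuple[int, int]]:
    conn = sqlite3.connect(":memory:")
    conn.executescript(
        """
        CREATE TABLE chat_message (id INTEGER PRIMARY KEY);
        CREATE TABLE agent__sub_question (
            id INTEGER PRIMARY KEY,
            primary_question_id INTEGER
        );
        INSERT INTO chat_message (id) VALUES (1), (2), (3);
        -- Only message 2 has sub-questions, so only it should become agentic.
        INSERT INTO agent__sub_question (primary_question_id) VALUES (2), (2);
        """
    )
    # Step 1: add the column as nullable so existing rows are accepted.
    conn.execute("ALTER TABLE chat_message ADD COLUMN is_agentic BOOLEAN")
    # Step 2: backfill with the same correlated EXISTS the migration uses.
    conn.execute(
        """
        UPDATE chat_message
        SET is_agentic = EXISTS (
            SELECT 1
            FROM agent__sub_question
            WHERE agent__sub_question.primary_question_id = chat_message.id
        )
        WHERE is_agentic IS NULL
        """
    )
    rows = conn.execute(
        "SELECT id, is_agentic FROM chat_message ORDER BY id"
    ).fetchall()
    conn.close()
    return rows


# Usage: _sketch_backfill_is_agentic() marks only message 2 as agentic
# (SQLite stores the boolean as 0 or 1).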
revision = "9c00a2bccb83"
down_revision = "b7a7eee5aa15"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # First add the column as nullable
    op.add_column("chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=True))

    # Update existing rows based on presence of SubQuestions
    op.execute(
        """
        UPDATE chat_message
        SET is_agentic = EXISTS (
            SELECT 1
            FROM agent__sub_question
            WHERE agent__sub_question.primary_question_id = chat_message.id
        )
        WHERE is_agentic IS NULL
        """
    )

    # Make the column non-nullable with a default value of False
    op.alter_column(
        "chat_message", "is_agentic", nullable=False, server_default=sa.text("false")
    )


def downgrade() -> None:
    op.drop_column("chat_message", "is_agentic")

================================================
FILE: backend/alembic/versions/9c54986124c6_add_scim_tables.py
================================================
"""add_scim_tables

Revision ID: 9c54986124c6
Revises: b51c6844d1df
Create Date: 2026-02-12 20:29:47.448614

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9c54986124c6"
down_revision = "b51c6844d1df"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "scim_token",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("hashed_token", sa.String(length=64), nullable=False),
        sa.Column("token_display", sa.String(), nullable=False),
        sa.Column(
            "created_by_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column(
            "is_active",
            sa.Boolean(),
            server_default=sa.text("true"),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("last_used_at", sa.DateTime(timezone=True), nullable=True),
        sa.ForeignKeyConstraint(["created_by_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("hashed_token"),
    )

    op.create_table(
        "scim_group_mapping",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_id", sa.String(), nullable=False),
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"], ["user_group.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_group_id"),
    )
    op.create_index(
        op.f("ix_scim_group_mapping_external_id"),
        "scim_group_mapping",
        ["external_id"],
        unique=True,
    )

    op.create_table(
        "scim_user_mapping",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("external_id", sa.String(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            onupdate=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("user_id"),
    )
    op.create_index(
        op.f("ix_scim_user_mapping_external_id"),
        "scim_user_mapping",
        ["external_id"],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index(
        op.f("ix_scim_user_mapping_external_id"),
        table_name="scim_user_mapping",
    )
    op.drop_table("scim_user_mapping")
    op.drop_index(
        op.f("ix_scim_group_mapping_external_id"),
        table_name="scim_group_mapping",
    )
    op.drop_table("scim_group_mapping")
    op.drop_table("scim_token")

================================================
FILE: backend/alembic/versions/9cf5c00f72fe_add_creator_to_cc_pair.py
================================================
"""add creator to cc pair

Revision ID: 9cf5c00f72fe
Revises: 26b931506ecb
Create Date: 2024-11-12 15:16:42.682902

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9cf5c00f72fe"
down_revision = "26b931506ecb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "creator_id",
            sa.UUID(as_uuid=True),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "creator_id")

================================================
FILE: backend/alembic/versions/9d1543a37106_add_processing_duration_seconds_to_chat_.py
================================================
"""add processing_duration_seconds to chat_message

Revision ID: 9d1543a37106
Revises: cbc03e08d0f3
Create Date: 2026-01-21 11:42:18.546188

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9d1543a37106"
down_revision = "cbc03e08d0f3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column("processing_duration_seconds", sa.Float(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "processing_duration_seconds")

================================================
FILE: backend/alembic/versions/9d97fecfab7f_added_retrieved_docs_to_query_event.py
================================================
"""Added retrieved docs to query event

Revision ID: 9d97fecfab7f
Revises: ffc707a226b4
Create Date: 2023-10-20 12:22:31.930449

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "9d97fecfab7f"
down_revision = "ffc707a226b4"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "query_event",
        sa.Column(
            "retrieved_document_ids",
            postgresql.ARRAY(sa.String()),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("query_event", "retrieved_document_ids")

================================================
FILE: backend/alembic/versions/9drpiiw74ljy_add_config_to_federated_connector.py
================================================
"""add config to federated_connector

Revision ID: 9drpiiw74ljy
Revises: 2acdef638fc2
Create Date: 2025-11-03 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
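# Illustrative sketch only (not called by this migration): a plain-Python
# restatement of the CASE expression in the Slack data migration below.
# The function name _sketch_slack_config_from_entities is hypothetical. It
# receives the `entities` JSON from the first (lowest-id)
# federated_connector__document_set row that has entities. Connectors with
# no such row are skipped by the SQL and keep the server default of {}.
from typing import Any, Optional


def _sketch_slack_config_from_entities(
    entities: Optional[dict[str, Any]],
) -> Optional[dict[str, Any]]:
    if entities is None:
        # Mirrors `AND fcds.entities IS NOT NULL`: the row is not updated.
        return None
    channels = entities.get("channels")
    if isinstance(channels, list) and len(channels) > 0:
        # A non-empty channel array restricts the search to those channels.
        config: dict[str, Any] = {
            "channels": channels,
            "search_all_channels": False,
        }
    else:
        # Missing, empty, or non-array channels fall back to all channels.
        config = {"search_all_channels": True}
    # In PostgreSQL, `entities->'include_dm' IS NOT NULL` is true whenever the
    # key exists, even when its JSON value is null, so test key presence.
    if "include_dm" in entities:
        config["include_dm"] = entities["include_dm"]
    return config


# Usage: {"channels": []} maps to {"search_all_channels": True}.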
revision = "9drpiiw74ljy"
down_revision = "2acdef638fc2"
branch_labels = None
depends_on = None


def upgrade() -> None:
    connection = op.get_bind()

    # Check if column already exists in current schema
    result = connection.execute(
        sa.text(
            """
            SELECT column_name
            FROM information_schema.columns
            WHERE table_schema = current_schema()
            AND table_name = 'federated_connector'
            AND column_name = 'config'
            """
        )
    )
    column_exists = result.fetchone() is not None

    # Add config column with default empty object (only if it doesn't exist)
    if not column_exists:
        op.add_column(
            "federated_connector",
            sa.Column(
                "config", postgresql.JSONB(), nullable=False, server_default="{}"
            ),
        )

    # Data migration: Single bulk update for all Slack connectors
    connection.execute(
        sa.text(
            """
            WITH connector_configs AS (
                SELECT
                    fc.id as connector_id,
                    CASE
                        WHEN fcds.entities->'channels' IS NOT NULL
                            AND jsonb_typeof(fcds.entities->'channels') = 'array'
                            AND jsonb_array_length(fcds.entities->'channels') > 0
                        THEN
                            jsonb_build_object(
                                'channels', fcds.entities->'channels',
                                'search_all_channels', false
                            ) ||
                            CASE
                                WHEN fcds.entities->'include_dm' IS NOT NULL
                                THEN jsonb_build_object('include_dm', fcds.entities->'include_dm')
                                ELSE '{}'::jsonb
                            END
                        ELSE
                            jsonb_build_object('search_all_channels', true) ||
                            CASE
                                WHEN fcds.entities->'include_dm' IS NOT NULL
                                THEN jsonb_build_object('include_dm', fcds.entities->'include_dm')
                                ELSE '{}'::jsonb
                            END
                    END as config
                FROM federated_connector fc
                LEFT JOIN LATERAL (
                    SELECT entities
                    FROM federated_connector__document_set
                    WHERE federated_connector_id = fc.id
                    AND entities IS NOT NULL
                    ORDER BY id
                    LIMIT 1
                ) fcds ON true
                WHERE fc.source = 'FEDERATED_SLACK'
                AND fcds.entities IS NOT NULL
            )
            UPDATE federated_connector fc
            SET config = cc.config
            FROM connector_configs cc
            WHERE fc.id = cc.connector_id
            """
        )
    )


def downgrade() -> None:
    op.drop_column("federated_connector", "config")

================================================
FILE: backend/alembic/versions/9f696734098f_combine_search_and_chat.py
================================================
"""Combine Search and Chat

Revision ID: 9f696734098f
Revises: a8c2065484e6
Create Date: 2024-11-27 15:32:19.694972

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "9f696734098f"
down_revision = "a8c2065484e6"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column("chat_session", "description", nullable=True)
    op.drop_column("chat_session", "one_shot")
    op.drop_column("slack_channel_config", "response_type")


def downgrade() -> None:
    op.execute("UPDATE chat_session SET description = '' WHERE description IS NULL")
    op.alter_column("chat_session", "description", nullable=False)
    op.add_column(
        "chat_session",
        sa.Column("one_shot", sa.Boolean(), nullable=False, server_default=sa.false()),
    )
    op.add_column(
        "slack_channel_config",
        sa.Column(
            "response_type", sa.String(), nullable=False, server_default="citations"
        ),
    )

================================================
FILE: backend/alembic/versions/a01bf2971c5d_update_default_tool_descriptions.py
================================================
"""update_default_tool_descriptions

Revision ID: a01bf2971c5d
Revises: 18b5b2524446
Create Date: 2025-12-16 15:21:25.656375

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a01bf2971c5d"
down_revision = "18b5b2524446"
branch_labels = None
depends_on = None


# new tool descriptions (12/2025)
TOOL_DESCRIPTIONS = {
    "SearchTool": "The Search Action allows the agent to search through connected knowledge to help build an answer.",
    "ImageGenerationTool": (
        "The Image Generation Action allows the agent to use DALL-E 3 or GPT-IMAGE-1 to generate images. "
        "The action will be used when the user asks the agent to generate an image."
    ),
    "WebSearchTool": (
        "The Web Search Action allows the agent to perform internet searches for up-to-date information."
    ),
    "KnowledgeGraphTool": (
        "The Knowledge Graph Search Action allows the agent to search the "
        "Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Agent, "
        "and it requires the Knowledge Graph to be enabled."
    ),
    "OktaProfileTool": (
        "The Okta Profile Action allows the agent to fetch the current user's information from Okta. "
        "This may include the user's name, email, phone number, address, and other details such as their "
        "manager and direct reports."
    ),
}


def upgrade() -> None:
    conn = op.get_bind()
    for tool_id, description in TOOL_DESCRIPTIONS.items():
        conn.execute(
            sa.text(
                "UPDATE tool SET description = :description WHERE in_code_tool_id = :tool_id"
            ),
            {"description": description, "tool_id": tool_id},
        )


def downgrade() -> None:
    pass

================================================
FILE: backend/alembic/versions/a1b2c3d4e5f6_add_license_table.py
================================================
"""add license table

Revision ID: a1b2c3d4e5f6
Revises: a01bf2971c5d
Create Date: 2025-12-04 10:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f6"
down_revision = "a01bf2971c5d"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "license",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("license_data", sa.Text(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
    )

    # Singleton pattern - only ever one row in this table
    op.create_index(
        "idx_license_singleton",
        "license",
        [sa.text("(true)")],
        unique=True,
    )


def downgrade() -> None:
    op.drop_index("idx_license_singleton", table_name="license")
    op.drop_table("license")

================================================
FILE: backend/alembic/versions/a1b2c3d4e5f7_drop_agent_search_metrics_table.py
================================================
"""drop agent_search_metrics table

Revision ID: a1b2c3d4e5f7
Revises: 73e9983e5091
Create Date: 2026-01-17

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a1b2c3d4e5f7"
down_revision = "73e9983e5091"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("agent__search_metrics")


def downgrade() -> None:
    op.create_table(
        "agent__search_metrics",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", sa.UUID(), nullable=True),
        sa.Column("persona_id", sa.Integer(), nullable=True),
        sa.Column("agent_type", sa.String(), nullable=False),
        sa.Column("start_time", sa.DateTime(timezone=True), nullable=False),
        sa.Column("base_duration_s", sa.Float(), nullable=False),
        sa.Column("full_duration_s", sa.Float(), nullable=False),
        sa.Column("base_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("refined_metrics", postgresql.JSONB(), nullable=True),
        sa.Column("all_metrics", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )

================================================
FILE: backend/alembic/versions/a2b3c4d5e6f7_remove_fast_default_model_name.py
================================================
"""Remove fast_default_model_name from llm_provider

Revision ID: a2b3c4d5e6f7
Revises: 2a391f840e85
Create Date: 2024-12-17

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "a2b3c4d5e6f7"
down_revision = "2a391f840e85"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.drop_column("llm_provider", "fast_default_model_name")


def downgrade() -> None:
    op.add_column(
        "llm_provider",
        sa.Column("fast_default_model_name", sa.String(), nullable=True),
    )

================================================
FILE: backend/alembic/versions/a3795dce87be_migration_confluence_to_be_explicit.py
================================================
"""migration confluence to be explicit

Revision ID: a3795dce87be
Revises: 1f60f60c3401
Create Date: 2024-09-01 13:52:12.006740

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.sql import table, column

revision = "a3795dce87be"
down_revision = "1f60f60c3401"
branch_labels: None = None
depends_on: None = None


def extract_confluence_keys_from_url(wiki_url: str) -> tuple[str, str, str, bool]:
    from urllib.parse import urlparse

    def _extract_confluence_keys_from_cloud_url(wiki_url: str) -> tuple[str, str, str]:
        parsed_url = urlparse(wiki_url)
        wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split('/spaces')[0]}"
        path_parts = parsed_url.path.split("/")
        space = path_parts[3]

        page_id = path_parts[5] if len(path_parts) > 5 else ""
        return wiki_base, space, page_id

    def _extract_confluence_keys_from_datacenter_url(
        wiki_url: str,
    ) -> tuple[str, str, str]:
        DISPLAY = "/display/"
        PAGE = "/pages/"

        parsed_url = urlparse(wiki_url)
        wiki_base = f"{parsed_url.scheme}://{parsed_url.netloc}{parsed_url.path.split(DISPLAY)[0]}"
        space = DISPLAY.join(parsed_url.path.split(DISPLAY)[1:]).split("/")[0]
        page_id = ""
        if (content := parsed_url.path.split(PAGE)) and len(content) > 1:
            page_id = content[1]
        return wiki_base, space, page_id

    is_confluence_cloud = (
        ".atlassian.net/wiki/spaces/" in wiki_url
        or ".jira.com/wiki/spaces/" in wiki_url
    )

    if is_confluence_cloud:
        wiki_base, space, page_id = _extract_confluence_keys_from_cloud_url(wiki_url)
    else:
        wiki_base, space, page_id = _extract_confluence_keys_from_datacenter_url(
            wiki_url
        )

    return wiki_base, space, page_id, is_confluence_cloud


def reconstruct_confluence_url(
    wiki_base: str, space: str, page_id: str, is_cloud: bool
) -> str:
    if is_cloud:
        url = f"{wiki_base}/spaces/{space}"
        if page_id:
            url += f"/pages/{page_id}"
    else:
        url = f"{wiki_base}/display/{space}"
        if page_id:
            url += f"/pages/{page_id}"
    return url


def upgrade() -> None:
    connector = table(
        "connector",
        column("id", sa.Integer),
        column("source", sa.String()),
        column("input_type", sa.String()),
        column("connector_specific_config", postgresql.JSONB),
    )

    # Fetch all Confluence connectors
    connection = op.get_bind()
    confluence_connectors = connection.execute(
        sa.select(connector).where(
            sa.and_(
                connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
            )
        )
    ).fetchall()

    for row in confluence_connectors:
        config = row.connector_specific_config
        wiki_page_url = config["wiki_page_url"]
        wiki_base, space, page_id, is_cloud = extract_confluence_keys_from_url(
            wiki_page_url
        )

        new_config = {
            "wiki_base": wiki_base,
            "space": space,
            "page_id": page_id,
            "is_cloud": is_cloud,
        }

        for key, value in config.items():
            if key not in ["wiki_page_url"]:
                new_config[key] = value

        op.execute(
            connector.update()
            .where(connector.c.id == row.id)
            .values(connector_specific_config=new_config)
        )


def downgrade() -> None:
    connector = table(
        "connector",
        column("id", sa.Integer),
        column("source", sa.String()),
        column("input_type", sa.String()),
        column("connector_specific_config", postgresql.JSONB),
    )

    confluence_connectors = (
        op.get_bind()
        .execute(
            sa.select(connector).where(
                connector.c.source == "CONFLUENCE", connector.c.input_type == "POLL"
            )
        )
        .fetchall()
    )

    for row in confluence_connectors:
        config = row.connector_specific_config
        if all(key in config for key in ["wiki_base", "space", "is_cloud"]):
            wiki_page_url = reconstruct_confluence_url(
                config["wiki_base"],
                config["space"],
                config.get("page_id", ""),
                config["is_cloud"],
            )
            new_config = {"wiki_page_url": wiki_page_url}
            new_config.update(
                {
                    k: v
                    for k, v in config.items()
                    if k not in ["wiki_base", "space", "page_id", "is_cloud"]
                }
            )
            op.execute(
                connector.update()
                .where(connector.c.id == row.id)
                .values(connector_specific_config=new_config)
            )

================================================
FILE: backend/alembic/versions/a3b8d9e2f1c4_make_scim_external_id_nullable.py
================================================
"""make scim_user_mapping.external_id nullable

Revision ID: a3b8d9e2f1c4
Revises: 2664261bfaab
Create Date: 2026-03-02

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "a3b8d9e2f1c4"
down_revision = "2664261bfaab"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column(
        "scim_user_mapping",
        "external_id",
        nullable=True,
    )


def downgrade() -> None:
    # Delete any rows where external_id is NULL before re-applying NOT NULL
    op.execute("DELETE FROM scim_user_mapping WHERE external_id IS NULL")
    op.alter_column(
        "scim_user_mapping",
        "external_id",
        nullable=False,
    )

================================================
FILE: backend/alembic/versions/a3bfd0d64902_add_chosen_assistants_to_user_table.py
================================================
"""Add chosen_assistants to User table

Revision ID: a3bfd0d64902
Revises: ec85f2b3c544
Create Date: 2024-05-26 17:22:24.834741

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a3bfd0d64902"
down_revision = "ec85f2b3c544"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("user", "chosen_assistants")

================================================
FILE: backend/alembic/versions/a3c1a7904cd0_remove_userfile_related_deprecated_.py
================================================
"""remove userfile related deprecated fields

Revision ID: a3c1a7904cd0
Revises: 5c3dca366b35
Create Date: 2026-01-06 13:00:30.634396

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a3c1a7904cd0"
down_revision = "5c3dca366b35"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user_file", "document_id")
    op.drop_column("user_file", "document_id_migrated")
    op.drop_column("connector_credential_pair", "is_user_file")


def downgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column("is_user_file", sa.Boolean(), nullable=False, server_default="false"),
    )
    op.add_column(
        "user_file",
        sa.Column("document_id", sa.String(), nullable=True),
    )
    op.add_column(
        "user_file",
        sa.Column(
            "document_id_migrated", sa.Boolean(), nullable=False, server_default="true"
        ),
    )

================================================
FILE: backend/alembic/versions/a3f8b2c1d4e5_add_preferred_response_id_to_chat_message.py
================================================
"""add preferred_response_id and model_display_name to chat_message

Revision ID: a3f8b2c1d4e5
Revises: 25a5501dc766
Create Date: 2026-03-22

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
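# Illustrative sketch only (not called by this migration): the upgrade below
# adds a self-referencing foreign key with ondelete="SET NULL", so deleting
# the preferred response clears the pointer instead of deleting or blocking
# the referencing message. The hypothetical helper
# _sketch_preferred_response_after_delete shows that behavior on SQLite via
# the stdlib sqlite3 module. SQLite enforces foreign keys only after
# `PRAGMA foreign_keys = ON`.
import sqlite3
from typing import Optional


def _sketch_preferred_response_after_delete() -> Optional[int]:
    conn = sqlite3.connect(":memory:")
    conn.execute("PRAGMA foreign_keys = ON")
    conn.execute(
        """
        CREATE TABLE chat_message (
            id INTEGER PRIMARY KEY,
            preferred_response_id INTEGER
                REFERENCES chat_message (id) ON DELETE SET NULL
        )
        """
    )
    # Message 1 points at message 2 as its preferred response.
    conn.execute("INSERT INTO chat_message (id) VALUES (2)")
    conn.execute(
        "INSERT INTO chat_message (id, preferred_response_id) VALUES (1, 2)"
    )
    # Deleting message 2 should null out message 1's pointer.
    conn.execute("DELETE FROM chat_message WHERE id = 2")
    (preferred,) = conn.execute(
        "SELECT preferred_response_id FROM chat_message WHERE id = 1"
    ).fetchone()
    conn.close()
    return preferred


# Usage: _sketch_preferred_response_after_delete() returns None.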
revision = "a3f8b2c1d4e5"
down_revision = "25a5501dc766"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_message",
        sa.Column(
            "preferred_response_id",
            sa.Integer(),
            sa.ForeignKey("chat_message.id", ondelete="SET NULL"),
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("model_display_name", sa.String(), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("chat_message", "model_display_name")
    op.drop_column("chat_message", "preferred_response_id")

================================================
FILE: backend/alembic/versions/a4f23d6b71c8_add_llm_provider_persona_restrictions.py
================================================
"""add llm provider persona restrictions

Revision ID: a4f23d6b71c8
Revises: 5e1c073d48a3
Create Date: 2025-10-21 00:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "a4f23d6b71c8"
down_revision = "5e1c073d48a3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "llm_provider__persona",
        sa.Column("llm_provider_id", sa.Integer(), nullable=False),
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["llm_provider_id"], ["llm_provider.id"], ondelete="CASCADE"
        ),
        sa.ForeignKeyConstraint(["persona_id"], ["persona.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("llm_provider_id", "persona_id"),
    )
    op.create_index(
        "ix_llm_provider__persona_llm_provider_id",
        "llm_provider__persona",
        ["llm_provider_id"],
    )
    op.create_index(
        "ix_llm_provider__persona_persona_id",
        "llm_provider__persona",
        ["persona_id"],
    )
    op.create_index(
        "ix_llm_provider__persona_composite",
        "llm_provider__persona",
        ["persona_id", "llm_provider_id"],
    )


def downgrade() -> None:
    op.drop_index(
        "ix_llm_provider__persona_composite",
        table_name="llm_provider__persona",
    )
    op.drop_index(
        "ix_llm_provider__persona_persona_id",
        table_name="llm_provider__persona",
    )
    op.drop_index(
        "ix_llm_provider__persona_llm_provider_id",
        table_name="llm_provider__persona",
    )
    op.drop_table("llm_provider__persona")

================================================
FILE: backend/alembic/versions/a570b80a5f20_usergroup_tables.py
================================================
"""UserGroup tables

Revision ID: a570b80a5f20
Revises: 904451035c9b
Create Date: 2023-10-02 12:27:10.265725

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "a570b80a5f20"
down_revision = "904451035c9b"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "user_group",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("is_up_to_date", sa.Boolean(), nullable=False),
        sa.Column("is_up_for_deletion", sa.Boolean(), nullable=False),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("name"),
    )
    op.create_table(
        "user__user_group",
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("user_group_id", "user_id"),
    )
    op.create_table(
        "user_group__connector_credential_pair",
        sa.Column("user_group_id", sa.Integer(), nullable=False),
        sa.Column("cc_pair_id", sa.Integer(), nullable=False),
        sa.Column("is_current", sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(
            ["cc_pair_id"],
            ["connector_credential_pair.id"],
        ),
        sa.ForeignKeyConstraint(
            ["user_group_id"],
            ["user_group.id"],
        ),
        sa.PrimaryKeyConstraint("user_group_id", "cc_pair_id", "is_current"),
    )


def downgrade() -> None:
    op.drop_table("user_group__connector_credential_pair")
    op.drop_table("user__user_group")
    op.drop_table("user_group")

================================================
FILE: backend/alembic/versions/a6df6b88ef81_remove_recent_assistants.py
================================================
"""remove recent assistants

Revision ID: a6df6b88ef81
Revises: 4d58345da04a
Create Date: 2025-01-29 10:25:52.790407

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a6df6b88ef81"
down_revision = "4d58345da04a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_column("user", "recent_assistants")


def downgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False
        ),
    )

================================================
FILE: backend/alembic/versions/a7688ab35c45_add_public_external_user_group_table.py
================================================
"""Add public_external_user_group table

Revision ID: a7688ab35c45
Revises: 5c448911b12f
Create Date: 2025-05-06 20:55:12.747875

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "a7688ab35c45"
down_revision = "5c448911b12f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "public_external_user_group",
        sa.Column("external_user_group_id", sa.String(), nullable=False),
        sa.Column("cc_pair_id", sa.Integer(), nullable=False),
        sa.PrimaryKeyConstraint("external_user_group_id", "cc_pair_id"),
        sa.ForeignKeyConstraint(
            ["cc_pair_id"], ["connector_credential_pair.id"], ondelete="CASCADE"
        ),
    )


def downgrade() -> None:
    op.drop_table("public_external_user_group")

================================================
FILE: backend/alembic/versions/a852cbe15577_new_chat_history.py
================================================
"""New Chat History

Revision ID: a852cbe15577
Revises: 6436661d5b65
Create Date: 2025-11-08 15:16:37.781308

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "a852cbe15577"
down_revision = "6436661d5b65"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # 1. Drop old research/agent tables (CASCADE handles dependencies)
    op.execute("DROP TABLE IF EXISTS research_agent_iteration_sub_step CASCADE")
    op.execute("DROP TABLE IF EXISTS research_agent_iteration CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_query__search_doc CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_query CASCADE")
    op.execute("DROP TABLE IF EXISTS agent__sub_question CASCADE")

    # 2. ChatMessage table changes
    # Rename columns and add FKs
    op.alter_column(
        "chat_message", "parent_message", new_column_name="parent_message_id"
    )
    op.create_foreign_key(
        "fk_chat_message_parent_message_id",
        "chat_message",
        "chat_message",
        ["parent_message_id"],
        ["id"],
    )
    op.alter_column(
        "chat_message",
        "latest_child_message",
        new_column_name="latest_child_message_id",
    )
    op.create_foreign_key(
        "fk_chat_message_latest_child_message_id",
        "chat_message",
        "chat_message",
        ["latest_child_message_id"],
        ["id"],
    )

    # Add new column
    op.add_column(
        "chat_message", sa.Column("reasoning_tokens", sa.Text(), nullable=True)
    )

    # Drop old columns
    op.drop_column("chat_message", "rephrased_query")
    op.drop_column("chat_message", "alternate_assistant_id")
    op.drop_column("chat_message", "overridden_model")
    op.drop_column("chat_message", "is_agentic")
    op.drop_column("chat_message", "refined_answer_improvement")
    op.drop_column("chat_message", "research_type")
    op.drop_column("chat_message", "research_plan")
    op.drop_column("chat_message", "research_answer_purpose")

    # 3. ToolCall table changes
    # Drop the unique constraint first
    op.drop_constraint("uq_tool_call_message_id", "tool_call", type_="unique")

    # Delete orphaned tool_call rows (those without valid chat_message)
    op.execute(
        "DELETE FROM tool_call WHERE message_id NOT IN (SELECT id FROM chat_message)"
    )

    # Add chat_session_id as nullable first, populate, then make NOT NULL
    op.add_column(
        "tool_call",
        sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=True),
    )

    # Populate chat_session_id from the related chat_message
    op.execute(
        """
        UPDATE tool_call
        SET chat_session_id = chat_message.chat_session_id
        FROM chat_message
        WHERE tool_call.message_id = chat_message.id
    """
    )

    # Now make it NOT NULL and add FK
    op.alter_column("tool_call", "chat_session_id", nullable=False)
    op.create_foreign_key(
        "fk_tool_call_chat_session_id",
        "tool_call",
        "chat_session",
        ["chat_session_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Rename message_id and make nullable, recreate FK with CASCADE
    op.drop_constraint("tool_call_message_id_fkey", "tool_call", type_="foreignkey")
    op.alter_column(
        "tool_call",
        "message_id",
        new_column_name="parent_chat_message_id",
        nullable=True,
    )
    op.create_foreign_key(
        "fk_tool_call_parent_chat_message_id",
        "tool_call",
        "chat_message",
        ["parent_chat_message_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Add parent_tool_call_id with FK
    op.add_column(
        "tool_call", sa.Column("parent_tool_call_id", sa.Integer(), nullable=True)
    )
    op.create_foreign_key(
        "fk_tool_call_parent_tool_call_id",
        "tool_call",
        "tool_call",
        ["parent_tool_call_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Add other new columns
    op.add_column(
        "tool_call",
        sa.Column("turn_number", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "tool_call",
        sa.Column("tool_call_id", sa.String(), nullable=False, server_default=""),
    )
    op.add_column("tool_call", sa.Column("reasoning_tokens", sa.Text(), nullable=True))
    op.add_column(
        "tool_call",
        sa.Column("tool_call_tokens", sa.Integer(), nullable=False, server_default="0"),
    )
    op.add_column(
        "tool_call",
        sa.Column("generated_images", postgresql.JSONB(), nullable=True),
    )

    # Rename columns
    op.alter_column(
        "tool_call", "tool_arguments", new_column_name="tool_call_arguments"
    )
    op.alter_column("tool_call", "tool_result", new_column_name="tool_call_response")

    # Change tool_call_response type from JSONB to Text
    op.execute(
        """
        ALTER TABLE tool_call
        ALTER COLUMN tool_call_response TYPE TEXT
        USING tool_call_response::text
    """
    )

    # Drop old columns
    op.drop_column("tool_call", "tool_name")

    # 4. Create new association table
    op.create_table(
        "tool_call__search_doc",
        sa.Column("tool_call_id", sa.Integer(), nullable=False),
        sa.Column("search_doc_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(["tool_call_id"], ["tool_call.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(
            ["search_doc_id"], ["search_doc.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("tool_call_id", "search_doc_id"),
    )

    # 5. Persona table change
    op.add_column(
        "persona",
        sa.Column(
            "replace_base_system_prompt",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )


def downgrade() -> None:
    # Reverse persona changes
    op.drop_column("persona", "replace_base_system_prompt")

    # Drop new association table
    op.drop_table("tool_call__search_doc")

    # Reverse ToolCall changes
    op.add_column(
        "tool_call",
        sa.Column("tool_name", sa.String(), nullable=False, server_default=""),
    )

    # Change tool_call_response back to JSONB
    op.execute(
        """
        ALTER TABLE tool_call
        ALTER COLUMN tool_call_response TYPE JSONB
        USING tool_call_response::jsonb
    """
    )

    op.alter_column("tool_call", "tool_call_response", new_column_name="tool_result")
    op.alter_column(
        "tool_call", "tool_call_arguments", new_column_name="tool_arguments"
    )

    op.drop_column("tool_call", "generated_images")
    op.drop_column("tool_call", "tool_call_tokens")
    op.drop_column("tool_call", "reasoning_tokens")
    op.drop_column("tool_call", "tool_call_id")
    op.drop_column("tool_call", "turn_number")

    op.drop_constraint(
"fk_tool_call_parent_tool_call_id", "tool_call", type_="foreignkey" ) op.drop_column("tool_call", "parent_tool_call_id") op.drop_constraint( "fk_tool_call_parent_chat_message_id", "tool_call", type_="foreignkey" ) op.alter_column( "tool_call", "parent_chat_message_id", new_column_name="message_id", nullable=False, ) op.create_foreign_key( "tool_call_message_id_fkey", "tool_call", "chat_message", ["message_id"], ["id"], ) op.drop_constraint("fk_tool_call_chat_session_id", "tool_call", type_="foreignkey") op.drop_column("tool_call", "chat_session_id") op.create_unique_constraint("uq_tool_call_message_id", "tool_call", ["message_id"]) # Reverse ChatMessage changes # Note: research_answer_purpose and research_type were originally String columns, # not Enum types (see migrations 5ae8240accb3 and f8a9b2c3d4e5) op.add_column( "chat_message", sa.Column("research_answer_purpose", sa.String(), nullable=True), ) op.add_column( "chat_message", sa.Column("research_plan", postgresql.JSONB(), nullable=True) ) op.add_column( "chat_message", sa.Column("research_type", sa.String(), nullable=True), ) op.add_column( "chat_message", sa.Column("refined_answer_improvement", sa.Boolean(), nullable=True), ) op.add_column( "chat_message", sa.Column("is_agentic", sa.Boolean(), nullable=False, server_default="false"), ) op.add_column( "chat_message", sa.Column("overridden_model", sa.String(), nullable=True) ) op.add_column( "chat_message", sa.Column("alternate_assistant_id", sa.Integer(), nullable=True) ) # Recreate the FK constraint that was implicitly dropped when the column was dropped op.create_foreign_key( "fk_chat_message_persona", "chat_message", "persona", ["alternate_assistant_id"], ["id"], ) op.add_column( "chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True) ) op.drop_column("chat_message", "reasoning_tokens") op.drop_constraint( "fk_chat_message_latest_child_message_id", "chat_message", type_="foreignkey" ) op.alter_column( "chat_message", 
"latest_child_message_id", new_column_name="latest_child_message", ) op.drop_constraint( "fk_chat_message_parent_message_id", "chat_message", type_="foreignkey" ) op.alter_column( "chat_message", "parent_message_id", new_column_name="parent_message" ) # Recreate agent sub question and sub query tables op.create_table( "agent__sub_question", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("primary_question_id", sa.Integer(), nullable=False), sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("sub_question", sa.Text(), nullable=False), sa.Column("level", sa.Integer(), nullable=False), sa.Column("level_question_num", sa.Integer(), nullable=False), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("sub_answer", sa.Text(), nullable=False), sa.Column("sub_question_doc_results", postgresql.JSONB(), nullable=False), sa.ForeignKeyConstraint( ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE" ), sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]), sa.PrimaryKeyConstraint("id"), ) op.create_table( "agent__sub_query", sa.Column("id", sa.Integer(), primary_key=True), sa.Column("parent_question_id", sa.Integer(), nullable=False), sa.Column("chat_session_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("sub_query", sa.Text(), nullable=False), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["parent_question_id"], ["agent__sub_question.id"], ondelete="CASCADE" ), sa.ForeignKeyConstraint(["chat_session_id"], ["chat_session.id"]), sa.PrimaryKeyConstraint("id"), ) op.create_table( "agent__sub_query__search_doc", sa.Column("sub_query_id", sa.Integer(), nullable=False), sa.Column("search_doc_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["sub_query_id"], ["agent__sub_query.id"], ondelete="CASCADE" ), 
        sa.ForeignKeyConstraint(["search_doc_id"], ["search_doc.id"]),
        sa.PrimaryKeyConstraint("sub_query_id", "search_doc_id"),
    )

    # Recreate research agent tables
    op.create_table(
        "research_agent_iteration",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("purpose", sa.String(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id"], ["chat_message.id"], ondelete="CASCADE"
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint(
            "primary_question_id",
            "iteration_nr",
            name="_research_agent_iteration_unique_constraint",
        ),
    )

    op.create_table(
        "research_agent_iteration_sub_step",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("primary_question_id", sa.Integer(), nullable=False),
        sa.Column("iteration_nr", sa.Integer(), nullable=False),
        sa.Column("iteration_sub_step_nr", sa.Integer(), nullable=False),
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("sub_step_instructions", sa.String(), nullable=True),
        sa.Column("sub_step_tool_id", sa.Integer(), nullable=True),
        sa.Column("reasoning", sa.String(), nullable=True),
        sa.Column("sub_answer", sa.String(), nullable=True),
        sa.Column("cited_doc_results", postgresql.JSONB(), nullable=False),
        sa.Column("claims", postgresql.JSONB(), nullable=True),
        sa.Column("is_web_fetch", sa.Boolean(), nullable=True),
        sa.Column("queries", postgresql.JSONB(), nullable=True),
        sa.Column("generated_images", postgresql.JSONB(), nullable=True),
        sa.Column("additional_data", postgresql.JSONB(), nullable=True),
        sa.Column("file_ids", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(
            ["primary_question_id", "iteration_nr"],
            [
"research_agent_iteration.primary_question_id", "research_agent_iteration.iteration_nr", ], ondelete="CASCADE", ), sa.ForeignKeyConstraint(["sub_step_tool_id"], ["tool.id"], ondelete="SET NULL"), sa.PrimaryKeyConstraint("id"), ) ================================================ FILE: backend/alembic/versions/a8c2065484e6_add_auto_scroll_to_user_model.py ================================================ """add auto scroll to user model Revision ID: a8c2065484e6 Revises: abe7378b8217 Create Date: 2024-11-22 17:34:09.690295 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "a8c2065484e6" down_revision = "abe7378b8217" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column("auto_scroll", sa.Boolean(), nullable=True, server_default=None), ) def downgrade() -> None: op.drop_column("user", "auto_scroll") ================================================ FILE: backend/alembic/versions/abbfec3a5ac5_merge_prompt_into_persona.py ================================================ """merge prompt into persona Revision ID: abbfec3a5ac5 Revises: 8818cf73fa1a Create Date: 2024-12-19 12:00:00.000000 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "abbfec3a5ac5"
down_revision = "8818cf73fa1a"
branch_labels = None
depends_on = None


MAX_PROMPT_LENGTH = 5_000_000


def upgrade() -> None:
    """NOTE: Prompts without any Personas will just be lost."""
    # Step 1: Add new columns to persona table (only if they don't exist)

    # Check if columns exist before adding them
    connection = op.get_bind()
    inspector = sa.inspect(connection)
    existing_columns = [col["name"] for col in inspector.get_columns("persona")]

    if "system_prompt" not in existing_columns:
        op.add_column(
            "persona",
            sa.Column(
                "system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
            ),
        )

    if "task_prompt" not in existing_columns:
        op.add_column(
            "persona",
            sa.Column(
                "task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=True
            ),
        )

    if "datetime_aware" not in existing_columns:
        op.add_column(
            "persona",
            sa.Column(
                "datetime_aware", sa.Boolean(), nullable=False, server_default="true"
            ),
        )

    # Step 2: Migrate data from prompt table to persona table (only if tables exist)
    existing_tables = inspector.get_table_names()

    if "prompt" in existing_tables and "persona__prompt" in existing_tables:
        # For personas that have associated prompts, copy the prompt data
        op.execute(
            """
            UPDATE persona
            SET system_prompt = p.system_prompt,
                task_prompt = p.task_prompt,
                datetime_aware = p.datetime_aware
            FROM (
                -- Get the first prompt for each persona (in case there are multiple)
                SELECT DISTINCT ON (pp.persona_id)
                    pp.persona_id,
                    pr.system_prompt,
                    pr.task_prompt,
                    pr.datetime_aware
                FROM persona__prompt pp
                JOIN prompt pr ON pp.prompt_id = pr.id
            ) p
            WHERE persona.id = p.persona_id
        """
        )

        # Step 3: Update chat_message references
        # Since chat messages referenced prompt_id, we need to update them to use persona_id
        # This is complex as we need to map from prompt_id to persona_id

        # Check if chat_message has prompt_id column
        chat_message_columns = [
            col["name"] for col in inspector.get_columns("chat_message")
        ]
        if "prompt_id" in chat_message_columns:
            op.execute(
                """
                ALTER TABLE chat_message
                DROP CONSTRAINT IF EXISTS chat_message__prompt_fk
            """
            )
            op.drop_column("chat_message", "prompt_id")

    # Step 4: Handle personas without prompts - set default values if needed (always run this)
    op.execute(
        """
        UPDATE persona
        SET system_prompt = COALESCE(system_prompt, ''),
            task_prompt = COALESCE(task_prompt, '')
        WHERE system_prompt IS NULL OR task_prompt IS NULL
    """
    )

    # Step 5: Drop the persona__prompt association table (if it exists)
    if "persona__prompt" in existing_tables:
        op.drop_table("persona__prompt")

    # Step 6: Drop the prompt table (if it exists)
    if "prompt" in existing_tables:
        op.drop_table("prompt")

    # Step 7: Make system_prompt and task_prompt non-nullable after migration (only if they exist)
    op.alter_column(
        "persona",
        "system_prompt",
        existing_type=sa.String(length=MAX_PROMPT_LENGTH),
        nullable=False,
        server_default=None,
    )

    op.alter_column(
        "persona",
        "task_prompt",
        existing_type=sa.String(length=MAX_PROMPT_LENGTH),
        nullable=False,
        server_default=None,
    )


def downgrade() -> None:
    # Step 1: Recreate the prompt table
    op.create_table(
        "prompt",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=True),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=False),
        sa.Column("system_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
        sa.Column("task_prompt", sa.String(length=MAX_PROMPT_LENGTH), nullable=False),
        sa.Column(
            "datetime_aware", sa.Boolean(), nullable=False, server_default="true"
        ),
        sa.Column(
            "default_prompt", sa.Boolean(), nullable=False, server_default="false"
        ),
        sa.Column("deleted", sa.Boolean(), nullable=False, server_default="false"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
    )

    # Step 2: Recreate the persona__prompt association table
    op.create_table(
        "persona__prompt",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("prompt_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(
            ["prompt_id"],
            ["prompt.id"],
        ),
        sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
    )

    # Step 3: Migrate data back from persona to prompt table
    op.execute(
        """
        INSERT INTO prompt (
            name, description, system_prompt, task_prompt, datetime_aware,
            default_prompt, deleted, user_id
        )
        SELECT
            CONCAT('Prompt for ', name), description, system_prompt,
            task_prompt, datetime_aware, is_default_persona, deleted, user_id
        FROM persona
        WHERE system_prompt IS NOT NULL AND system_prompt != ''
        RETURNING id, name
    """
    )

    # Step 4: Re-establish persona__prompt relationships
    op.execute(
        """
        INSERT INTO persona__prompt (persona_id, prompt_id)
        SELECT
            p.id as persona_id,
            pr.id as prompt_id
        FROM persona p
        JOIN prompt pr ON pr.name = CONCAT('Prompt for ', p.name)
        WHERE p.system_prompt IS NOT NULL AND p.system_prompt != ''
    """
    )

    # Step 5: Add prompt_id column back to chat_message
    op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))

    # Step 6: Re-establish foreign key constraint
    op.create_foreign_key(
        "chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
    )

    # Step 7: Remove columns from persona table
    op.drop_column("persona", "datetime_aware")
    op.drop_column("persona", "task_prompt")
    op.drop_column("persona", "system_prompt")

================================================
FILE: backend/alembic/versions/abe7378b8217_add_indexing_trigger_to_cc_pair.py
================================================
"""add indexing trigger to cc_pair

Revision ID: abe7378b8217
Revises: 6d562f86c78b
Create Date: 2024-11-26 19:09:53.481171

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "abe7378b8217"
down_revision = "93560ba1b118"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "connector_credential_pair",
        sa.Column(
            "indexing_trigger",
            sa.Enum("UPDATE", "REINDEX", name="indexingmode", native_enum=False),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "indexing_trigger")

================================================
FILE: backend/alembic/versions/ac5eaac849f9_add_last_pruned_to_connector_table.py
================================================
"""add last_pruned to the connector_credential_pair table

Revision ID: ac5eaac849f9
Revises: 52a219fb5233
Create Date: 2024-09-10 15:04:26.437118

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "ac5eaac849f9"
down_revision = "46b7a812670f"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # last pruned represents the last time the connector was pruned
    op.add_column(
        "connector_credential_pair",
        sa.Column("last_pruned", sa.DateTime(timezone=True), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("connector_credential_pair", "last_pruned")

================================================
FILE: backend/alembic/versions/acaab4ef4507_remove_inactive_ccpair_status_on_.py
================================================
"""remove inactive ccpair status on downgrade

Revision ID: acaab4ef4507
Revises: b388730a2899
Create Date: 2025-02-16 18:21:41.330212

"""

from alembic import op
from onyx.db.models import ConnectorCredentialPair
from onyx.db.enums import ConnectorCredentialPairStatus
from sqlalchemy import update

# revision identifiers, used by Alembic.
revision = "acaab4ef4507"
down_revision = "b388730a2899"
branch_labels = None
depends_on = None


def upgrade() -> None:
    pass


def downgrade() -> None:
    op.execute(
        update(ConnectorCredentialPair)
        .where(ConnectorCredentialPair.status == ConnectorCredentialPairStatus.INVALID)
        .values(status=ConnectorCredentialPairStatus.ACTIVE)
    )

================================================
FILE: backend/alembic/versions/ae62505e3acc_add_saml_accounts.py
================================================
"""Add SAML Accounts

Revision ID: ae62505e3acc
Revises: 7da543f5672f
Create Date: 2023-09-26 16:19:30.933183

"""

import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "ae62505e3acc"
down_revision = "7da543f5672f"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "saml",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column("encrypted_cookie", sa.Text(), nullable=False),
        sa.Column("expires_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column(
            "updated_at",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("encrypted_cookie"),
        sa.UniqueConstraint("user_id"),
    )


def downgrade() -> None:
    op.drop_table("saml")

================================================
FILE: backend/alembic/versions/aeda5f2df4f6_add_pinned_assistants.py
================================================
"""add pinned assistants

Revision ID: aeda5f2df4f6
Revises: c5eae4a75a1b
Create Date: 2025-01-09 16:04:10.770636

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "aeda5f2df4f6"
down_revision = "c5eae4a75a1b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user", sa.Column("pinned_assistants", postgresql.JSONB(), nullable=True)
    )
    op.execute('UPDATE "user" SET pinned_assistants = chosen_assistants')


def downgrade() -> None:
    op.drop_column("user", "pinned_assistants")

================================================
FILE: backend/alembic/versions/b082fec533f0_make_last_attempt_status_nullable.py
================================================
"""Make 'last_attempt_status' nullable

Revision ID: b082fec533f0
Revises: df0c7ad8a076
Create Date: 2023-08-06 12:05:47.087325

"""

from alembic import op
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b082fec533f0"
down_revision = "df0c7ad8a076"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.alter_column(
        "connector_credential_pair",
        "last_attempt_status",
        existing_type=postgresql.ENUM(
            "NOT_STARTED",
            "IN_PROGRESS",
            "SUCCESS",
            "FAILED",
            name="indexingstatus",
        ),
        nullable=True,
    )


def downgrade() -> None:
    # First, update any null values to a default value
    op.execute(
        "UPDATE connector_credential_pair SET last_attempt_status = 'NOT_STARTED' WHERE last_attempt_status IS NULL"
    )

    # Then, make the column non-nullable
    op.alter_column(
        "connector_credential_pair",
        "last_attempt_status",
        existing_type=postgresql.ENUM(
            "NOT_STARTED",
            "IN_PROGRESS",
            "SUCCESS",
            "FAILED",
            name="indexingstatus",
        ),
        nullable=False,
    )

================================================
FILE: backend/alembic/versions/b156fa702355_chat_reworked.py
================================================
"""Chat Reworked

Revision ID: b156fa702355
Revises: baf71f781b9e
Create Date: 2023-12-12 00:57:41.823371

"""

import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
from sqlalchemy.dialects.postgresql import ENUM

from onyx.configs.constants import DocumentSource

# revision identifiers, used by Alembic.
revision = "b156fa702355"
down_revision = "baf71f781b9e"
branch_labels: None = None
depends_on: None = None


searchtype_enum = ENUM(
    "KEYWORD", "SEMANTIC", "HYBRID", name="searchtype", create_type=True
)
recencybiassetting_enum = ENUM(
    "FAVOR_RECENT",
    "BASE_DECAY",
    "NO_DECAY",
    "AUTO",
    name="recencybiassetting",
    create_type=True,
)


def upgrade() -> None:
    bind = op.get_bind()
    searchtype_enum.create(bind)
    recencybiassetting_enum.create(bind)

    # This is irrecoverable, whatever
    op.execute("DELETE FROM chat_feedback")
    op.execute("DELETE FROM document_retrieval_feedback")

    op.create_table(
        "search_doc",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=False),
        sa.Column("chunk_ind", sa.Integer(), nullable=False),
        sa.Column("semantic_id", sa.String(), nullable=False),
        sa.Column("link", sa.String(), nullable=True),
        sa.Column("blurb", sa.String(), nullable=False),
        sa.Column("boost", sa.Integer(), nullable=False),
        sa.Column(
            "source_type",
            sa.Enum(DocumentSource, native=False),
            nullable=False,
        ),
        sa.Column("hidden", sa.Boolean(), nullable=False),
        sa.Column("score", sa.Float(), nullable=False),
        sa.Column("match_highlights", postgresql.ARRAY(sa.String()), nullable=False),
        sa.Column("updated_at", sa.DateTime(timezone=True), nullable=True),
        sa.Column("primary_owners", postgresql.ARRAY(sa.String()), nullable=True),
        sa.Column("secondary_owners", postgresql.ARRAY(sa.String()), nullable=True),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "prompt",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column("name", sa.String(), nullable=False),
        sa.Column("description", sa.String(), nullable=False),
        sa.Column("system_prompt", sa.Text(), nullable=False),
        sa.Column("task_prompt", sa.Text(), nullable=False),
        sa.Column("include_citations", sa.Boolean(), nullable=False),
        sa.Column("datetime_aware", sa.Boolean(), nullable=False),
        sa.Column("default_prompt", sa.Boolean(), nullable=False),
        sa.Column("deleted", sa.Boolean(), nullable=False),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "persona__prompt",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("prompt_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.ForeignKeyConstraint(
            ["prompt_id"],
            ["prompt.id"],
        ),
        sa.PrimaryKeyConstraint("persona_id", "prompt_id"),
    )

    # Changes to persona first so chat_sessions can have the right persona
    # The empty persona will be overwritten on server startup
    op.add_column(
        "persona",
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
    )
    op.add_column(
        "persona",
        sa.Column(
            "search_type",
            searchtype_enum,
            nullable=True,
        ),
    )
    op.execute("UPDATE persona SET search_type = 'HYBRID'")
    op.alter_column("persona", "search_type", nullable=False)
    op.add_column(
        "persona",
        sa.Column("llm_relevance_filter", sa.Boolean(), nullable=True),
    )
    op.execute("UPDATE persona SET llm_relevance_filter = TRUE")
    op.alter_column("persona", "llm_relevance_filter", nullable=False)
    op.add_column(
        "persona",
        sa.Column("llm_filter_extraction", sa.Boolean(), nullable=True),
    )
    op.execute("UPDATE persona SET llm_filter_extraction = TRUE")
    op.alter_column("persona", "llm_filter_extraction", nullable=False)
    op.add_column(
        "persona",
        sa.Column(
            "recency_bias",
            recencybiassetting_enum,
            nullable=True,
        ),
    )
    op.execute("UPDATE persona SET recency_bias = 'BASE_DECAY'")
    op.alter_column("persona", "recency_bias", nullable=False)
    op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
    op.execute("UPDATE persona SET description = ''")
    op.alter_column("persona", "description", nullable=False)
    op.create_foreign_key("persona__user_fk", "persona", "user", ["user_id"], ["id"])
    op.drop_column("persona", "datetime_aware")
    op.drop_column("persona", "tools")
    op.drop_column("persona", "hint_text")
    op.drop_column("persona", "apply_llm_relevance_filter")
    op.drop_column("persona", "retrieval_enabled")
    op.drop_column("persona", "system_text")

    # Need to create a persona row so fk can work
    result = bind.execute(sa.text("SELECT 1 FROM persona WHERE id = 0"))
    exists = result.fetchone()
    if not exists:
        op.execute(
            sa.text(
                """
                INSERT INTO persona (
                    id, user_id, name, description, search_type, num_chunks,
                    llm_relevance_filter, llm_filter_extraction, recency_bias,
                    llm_model_version_override, default_persona, deleted
                ) VALUES (
                    0, NULL, '', '', 'HYBRID', NULL,
                    TRUE, TRUE, 'BASE_DECAY',
                    NULL, TRUE, FALSE
                )
                """
            )
        )

    delete_statement = sa.text(
        """
        DELETE FROM persona
        WHERE name = 'Danswer' AND default_persona = TRUE AND id != 0
        """
    )

    bind.execute(delete_statement)

    op.add_column(
        "chat_feedback",
        sa.Column("chat_message_id", sa.Integer(), nullable=False),
    )
    op.drop_constraint(
        "chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
        "chat_feedback",
        type_="foreignkey",
    )
    op.drop_column("chat_feedback", "chat_message_edit_number")
    op.drop_column("chat_feedback", "chat_message_chat_session_id")
    op.drop_column("chat_feedback", "chat_message_message_number")
    op.add_column(
        "chat_message",
        sa.Column(
            "id",
            sa.Integer(),
            primary_key=True,
            autoincrement=True,
            nullable=False,
            unique=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("parent_message", sa.Integer(), nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column("latest_child_message", sa.Integer(), nullable=True),
    )
    op.add_column(
        "chat_message", sa.Column("rephrased_query", sa.Text(), nullable=True)
    )
    op.add_column("chat_message", sa.Column("prompt_id", sa.Integer(), nullable=True))
    op.add_column(
        "chat_message",
        sa.Column("citations", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    )
    op.add_column("chat_message", sa.Column("error", sa.Text(), nullable=True))
    op.drop_constraint("fk_chat_message_persona_id", "chat_message", type_="foreignkey")
    op.create_foreign_key(
        "chat_message__prompt_fk", "chat_message", "prompt", ["prompt_id"], ["id"]
    )
    op.drop_column("chat_message", "parent_edit_number")
    op.drop_column("chat_message", "persona_id")
    op.drop_column("chat_message", "reference_docs")
    op.drop_column("chat_message", "edit_number")
    op.drop_column("chat_message", "latest")
    op.drop_column("chat_message", "message_number")
    op.add_column("chat_session", sa.Column("one_shot", sa.Boolean(), nullable=True))
    op.execute("UPDATE chat_session SET one_shot = TRUE")
    op.alter_column("chat_session", "one_shot", nullable=False)
    op.alter_column(
        "chat_session",
        "persona_id",
        existing_type=sa.INTEGER(),
        nullable=True,
    )
    op.execute("UPDATE chat_session SET persona_id = 0")
    op.alter_column("chat_session", "persona_id", nullable=False)
    op.add_column(
        "document_retrieval_feedback",
        sa.Column("chat_message_id", sa.Integer(), nullable=False),
    )
    op.drop_constraint(
        "document_retrieval_feedback_qa_event_id_fkey",
        "document_retrieval_feedback",
        type_="foreignkey",
    )
    op.create_foreign_key(
        "document_retrieval_feedback__chat_message_fk",
        "document_retrieval_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
    )
    op.drop_column("document_retrieval_feedback", "qa_event_id")

    # Relation table must be created after the other tables are correct
    op.create_table(
        "chat_message__search_doc",
        sa.Column("chat_message_id", sa.Integer(), nullable=False),
        sa.Column("search_doc_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["chat_message_id"],
            ["chat_message.id"],
        ),
        sa.ForeignKeyConstraint(
            ["search_doc_id"],
            ["search_doc.id"],
        ),
        sa.PrimaryKeyConstraint("chat_message_id", "search_doc_id"),
    )

    # Needs to be created after chat_message id field is added
    op.create_foreign_key(
        "chat_feedback__chat_message_fk",
        "chat_feedback",
        "chat_message",
        ["chat_message_id"],
        ["id"],
    )

    op.drop_table("query_event")


def downgrade() -> None:
    # NOTE: you will lose all chat history.
    # This is to satisfy the non-nullable constraints below.
    op.execute("DELETE FROM chat_feedback")
    op.execute("DELETE FROM chat_message__search_doc")
    op.execute("DELETE FROM document_retrieval_feedback")
    op.execute("DELETE FROM chat_message")
    op.execute("DELETE FROM chat_session")

    op.drop_constraint(
        "chat_feedback__chat_message_fk", "chat_feedback", type_="foreignkey"
    )
    op.drop_constraint(
        "document_retrieval_feedback__chat_message_fk",
        "document_retrieval_feedback",
        type_="foreignkey",
    )
    op.drop_constraint("persona__user_fk", "persona", type_="foreignkey")
    op.drop_constraint("chat_message__prompt_fk", "chat_message", type_="foreignkey")
    op.drop_constraint(
        "chat_message__search_doc_chat_message_id_fkey",
        "chat_message__search_doc",
        type_="foreignkey",
    )
    op.add_column(
        "persona",
        sa.Column("system_text", sa.TEXT(), autoincrement=False, nullable=True),
    )
    op.add_column(
        "persona",
        sa.Column(
            "retrieval_enabled",
            sa.BOOLEAN(),
            autoincrement=False,
            nullable=True,
        ),
    )
    op.execute("UPDATE persona SET retrieval_enabled = TRUE")
    op.alter_column("persona", "retrieval_enabled", nullable=False)
    op.add_column(
        "persona",
        sa.Column(
            "apply_llm_relevance_filter",
            sa.BOOLEAN(),
            autoincrement=False,
            nullable=True,
        ),
    )
    op.add_column(
        "persona",
        sa.Column("hint_text", sa.TEXT(), autoincrement=False, nullable=True),
    )
    op.add_column(
        "persona",
        sa.Column(
            "tools",
            postgresql.JSONB(astext_type=sa.Text()),
            autoincrement=False,
            nullable=True,
        ),
    )
    op.add_column(
        "persona",
        sa.Column("datetime_aware", sa.BOOLEAN(), autoincrement=False, nullable=True),
    )
    op.execute("UPDATE persona SET datetime_aware = TRUE")
    op.alter_column("persona", "datetime_aware", nullable=False)
    op.alter_column("persona", "description", existing_type=sa.VARCHAR(), nullable=True)
    op.drop_column("persona", "recency_bias")
    op.drop_column("persona", "llm_filter_extraction")
    op.drop_column("persona", "llm_relevance_filter")
    op.drop_column("persona", "search_type")
    op.drop_column("persona", "user_id")
    op.add_column(
        "document_retrieval_feedback",
        sa.Column("qa_event_id", sa.INTEGER(), autoincrement=False, nullable=False),
    )
    op.drop_column("document_retrieval_feedback", "chat_message_id")
    op.alter_column(
        "chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True
    )
    op.drop_column("chat_session", "one_shot")
    op.add_column(
        "chat_message",
        sa.Column(
            "message_number",
            sa.INTEGER(),
            autoincrement=False,
            nullable=False,
            primary_key=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("latest", sa.BOOLEAN(), autoincrement=False, nullable=False),
    )
    op.add_column(
        "chat_message",
        sa.Column(
            "edit_number",
            sa.INTEGER(),
            autoincrement=False,
            nullable=False,
            primary_key=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column(
            "reference_docs",
            postgresql.JSONB(astext_type=sa.Text()),
            autoincrement=False,
            nullable=True,
        ),
    )
    op.add_column(
        "chat_message",
        sa.Column("persona_id", sa.INTEGER(), autoincrement=False, nullable=True),
    )
    op.add_column(
        "chat_message",
        sa.Column(
            "parent_edit_number",
            sa.INTEGER(),
            autoincrement=False,
            nullable=True,
        ),
    )
    op.create_foreign_key(
        "fk_chat_message_persona_id",
        "chat_message",
        "persona",
        ["persona_id"],
        ["id"],
    )
    op.drop_column("chat_message", "error")
    op.drop_column("chat_message", "citations")
    op.drop_column("chat_message", "prompt_id")
    op.drop_column("chat_message", "rephrased_query")
    op.drop_column("chat_message", "latest_child_message")
    op.drop_column("chat_message", "parent_message")
    op.drop_column("chat_message", "id")
    op.add_column(
        "chat_feedback",
        sa.Column(
            "chat_message_message_number",
            sa.INTEGER(),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.add_column(
        "chat_feedback",
        sa.Column(
            "chat_message_chat_session_id",
            sa.INTEGER(),
            autoincrement=False,
            nullable=False,
            primary_key=True,
        ),
    )
    op.add_column(
        "chat_feedback",
        sa.Column(
            "chat_message_edit_number",
            sa.INTEGER(),
            autoincrement=False,
            nullable=False,
        ),
    )
    op.drop_column("chat_feedback", "chat_message_id")
    op.create_table(
        "query_event",
        sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
        sa.Column("query", sa.VARCHAR(), autoincrement=False, nullable=False),
        sa.Column(
            "selected_search_flow",
            sa.VARCHAR(),
            autoincrement=False,
            nullable=True,
        ),
        sa.Column("llm_answer", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("feedback", sa.VARCHAR(), autoincrement=False, nullable=True),
        sa.Column("user_id", sa.UUID(), autoincrement=False, nullable=True),
        sa.Column(
            "time_created",
            postgresql.TIMESTAMP(timezone=True),
            server_default=sa.text("now()"),
            autoincrement=False,
            nullable=False,
        ),
        sa.Column(
            "retrieved_document_ids",
            postgresql.ARRAY(sa.VARCHAR()),
            autoincrement=False,
            nullable=True,
        ),
        sa.Column("chat_session_id", sa.INTEGER(), autoincrement=False, nullable=True),
        sa.ForeignKeyConstraint(
            ["chat_session_id"],
            ["chat_session.id"],
            name="fk_query_event_chat_session_id",
        ),
        sa.ForeignKeyConstraint(
            ["user_id"], ["user.id"], name="query_event_user_id_fkey"
        ),
        sa.PrimaryKeyConstraint("id", name="query_event_pkey"),
    )
    op.drop_table("chat_message__search_doc")
    op.drop_table("persona__prompt")
    op.drop_table("prompt")
    op.drop_table("search_doc")
    op.create_unique_constraint(
        "uq_chat_message_combination",
        "chat_message",
        ["chat_session_id", "message_number", "edit_number"],
    )
    op.create_foreign_key(
        "chat_feedback_chat_message_chat_session_id_chat_message_me_fkey",
        "chat_feedback",
        "chat_message",
        [
            "chat_message_chat_session_id",
            "chat_message_message_number",
            "chat_message_edit_number",
        ],
        ["chat_session_id", "message_number", "edit_number"],
    )
    op.create_foreign_key(
        "document_retrieval_feedback_qa_event_id_fkey",
        "document_retrieval_feedback",
        "query_event",
        ["qa_event_id"],
        ["id"],
    )

    op.execute("DROP TYPE IF EXISTS searchtype")
    op.execute("DROP TYPE IF EXISTS recencybiassetting")
    op.execute("DROP TYPE IF EXISTS documentsource")

================================================
FILE: backend/alembic/versions/b30353be4eec_add_mcp_auth_performer.py
================================================
"""add_mcp_auth_performer

Revision ID: b30353be4eec
Revises: 2b75d0a8ffcb
Create Date: 2025-09-13 14:58:08.413534

"""

from alembic import op
import sqlalchemy as sa

from onyx.db.enums import MCPAuthenticationPerformer, MCPTransport

# revision identifiers, used by Alembic.
revision = "b30353be4eec"
down_revision = "2b75d0a8ffcb"
branch_labels = None
depends_on = None


def upgrade() -> None:
    """moving to a better way of handling auth performer and transport"""
    # Add nullable columns first for backward compatibility
    op.add_column(
        "mcp_server",
        sa.Column(
            "auth_performer",
            sa.Enum(MCPAuthenticationPerformer, native_enum=False),
            nullable=True,
        ),
    )
    op.add_column(
        "mcp_server",
        sa.Column(
            "transport",
            sa.Enum(MCPTransport, native_enum=False),
            nullable=True,
        ),
    )

    # Backfill values using existing data and inference rules
    bind = op.get_bind()

    # 1) OAUTH servers are always PER_USER
    bind.execute(
        sa.text(
            """
            UPDATE mcp_server
            SET auth_performer = 'PER_USER'
            WHERE auth_type = 'OAUTH'
            """
        )
    )

    # 2) If there is no admin connection config, mark as ADMIN (if not already set)
    bind.execute(
        sa.text(
            """
            UPDATE mcp_server
            SET auth_performer = 'ADMIN'
            WHERE admin_connection_config_id IS NULL
              AND auth_performer IS NULL
            """
        )
    )

    # 3) If there exists any user-specific connection config (user_email != ''), mark as PER_USER
    bind.execute(
        sa.text(
            """
            UPDATE mcp_server AS ms
            SET auth_performer = 'PER_USER'
            FROM mcp_connection_config AS mcc
            WHERE mcc.mcp_server_id = ms.id
              AND COALESCE(mcc.user_email, '') <> ''
              AND ms.auth_performer IS NULL
            """
        )
    )

    # 4) Default any remaining nulls to ADMIN (covers API_TOKEN admin-managed and NONE)
    bind.execute(
        sa.text(
            """
            UPDATE mcp_server
            SET auth_performer = 'ADMIN'
            WHERE auth_performer IS NULL
            """
        )
    )

    # Finally, make the column non-nullable
    op.alter_column(
        "mcp_server",
        "auth_performer",
        existing_type=sa.Enum(MCPAuthenticationPerformer, native_enum=False),
        nullable=False,
    )

    # Backfill transport for existing
    # rows to STREAMABLE_HTTP, then make non-nullable
    bind.execute(
        sa.text(
            """
            UPDATE mcp_server
            SET transport = 'STREAMABLE_HTTP'
            WHERE transport IS NULL
            """
        )
    )

    op.alter_column(
        "mcp_server",
        "transport",
        existing_type=sa.Enum(MCPTransport, native_enum=False),
        nullable=False,
    )


def downgrade() -> None:
    """remove cols"""
    op.drop_column("mcp_server", "transport")
    op.drop_column("mcp_server", "auth_performer")

================================================
FILE: backend/alembic/versions/b329d00a9ea6_adding_assistant_specific_user_.py
================================================
"""Adding assistant-specific user preferences

Revision ID: b329d00a9ea6
Revises: f9b8c7d6e5a4
Create Date: 2025-08-26 23:14:44.592985

"""

from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b329d00a9ea6"
down_revision = "f9b8c7d6e5a4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "assistant__user_specific_config",
        sa.Column("assistant_id", sa.Integer(), nullable=False),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=False,
        ),
        sa.Column("disabled_tool_ids", postgresql.ARRAY(sa.Integer()), nullable=False),
        sa.ForeignKeyConstraint(["assistant_id"], ["persona.id"], ondelete="CASCADE"),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("assistant_id", "user_id"),
    )


def downgrade() -> None:
    op.drop_table("assistant__user_specific_config")

================================================
FILE: backend/alembic/versions/b388730a2899_nullable_preferences.py
================================================
"""nullable preferences

Revision ID: b388730a2899
Revises: 1a03d2c2856b
Create Date: 2025-02-17 18:49:22.643902

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "b388730a2899"
down_revision = "1a03d2c2856b"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column("user", "temperature_override_enabled", nullable=True)
    op.alter_column("user", "auto_scroll", nullable=True)


def downgrade() -> None:
    # Ensure no null values before making columns non-nullable
    op.execute(
        'UPDATE "user" SET temperature_override_enabled = false WHERE temperature_override_enabled IS NULL'
    )
    op.execute('UPDATE "user" SET auto_scroll = false WHERE auto_scroll IS NULL')
    op.alter_column("user", "temperature_override_enabled", nullable=False)
    op.alter_column("user", "auto_scroll", nullable=False)

================================================
FILE: backend/alembic/versions/b4b7e1028dfd_grant_basic_to_existing_groups.py
================================================
"""grant_basic_to_existing_groups

Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.

Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498

"""

from collections.abc import Sequence

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None

user_group = sa.table(
    "user_group",
    sa.column("id", sa.Integer),
    sa.column("is_default", sa.Boolean),
)

permission_grant = sa.table(
    "permission_grant",
    sa.column("group_id", sa.Integer),
    sa.column("permission", sa.String),
    sa.column("grant_source", sa.String),
    sa.column("is_deleted", sa.Boolean),
)


def upgrade() -> None:
    conn = op.get_bind()

    already_has_basic = (
        sa.select(sa.literal(1))
        .select_from(permission_grant)
        .where(
            permission_grant.c.group_id == user_group.c.id,
            permission_grant.c.permission == "basic",
        )
        .exists()
    )

    groups_needing_basic = sa.select(
        user_group.c.id,
        sa.literal("basic").label("permission"),
        sa.literal("SYSTEM").label("grant_source"),
        sa.literal(False).label("is_deleted"),
    ).where(
        user_group.c.is_default == sa.false(),
        ~already_has_basic,
    )

    conn.execute(
        permission_grant.insert().from_select(
            ["group_id", "permission", "grant_source", "is_deleted"],
            groups_needing_basic,
        )
    )


def downgrade() -> None:
    conn = op.get_bind()

    non_default_group_ids = sa.select(user_group.c.id).where(
        user_group.c.is_default == sa.false()
    )

    conn.execute(
        permission_grant.delete().where(
            permission_grant.c.permission == "basic",
            permission_grant.c.grant_source == "SYSTEM",
            permission_grant.c.group_id.in_(non_default_group_ids),
        )
    )

================================================
FILE: backend/alembic/versions/b4ef3ae0bf6e_add_user_oauth_token_to_slack_bot.py
================================================
"""add_user_oauth_token_to_slack_bot

Revision ID: b4ef3ae0bf6e
Revises: 505c488f6662
Create Date: 2025-08-26 17:47:41.788462

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "b4ef3ae0bf6e"
down_revision = "505c488f6662"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add user_token column to slack_bot table
    op.add_column("slack_bot", sa.Column("user_token", sa.LargeBinary(), nullable=True))


def downgrade() -> None:
    # Remove user_token column from slack_bot table
    op.drop_column("slack_bot", "user_token")

================================================
FILE: backend/alembic/versions/b51c6844d1df_seed_memory_tool.py
================================================
"""seed_memory_tool and add enable_memory_tool to user

Revision ID: b51c6844d1df
Revises: 93c15d6a6fbb
Create Date: 2026-02-11 00:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "b51c6844d1df"
down_revision = "93c15d6a6fbb"
branch_labels = None
depends_on = None

MEMORY_TOOL = {
    "name": "MemoryTool",
    "display_name": "Add Memory",
    "description": "Save memories about the user for future conversations.",
    "in_code_tool_id": "MemoryTool",
    "enabled": True,
}


def upgrade() -> None:
    conn = op.get_bind()

    existing = conn.execute(
        sa.text(
            "SELECT in_code_tool_id FROM tool WHERE in_code_tool_id = :in_code_tool_id"
        ),
        {"in_code_tool_id": MEMORY_TOOL["in_code_tool_id"]},
    ).fetchone()

    if existing:
        conn.execute(
            sa.text(
                """
                UPDATE tool
                SET name = :name,
                    display_name = :display_name,
                    description = :description
                WHERE in_code_tool_id = :in_code_tool_id
                """
            ),
            MEMORY_TOOL,
        )
    else:
        conn.execute(
            sa.text(
                """
                INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled)
                VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled)
                """
            ),
            MEMORY_TOOL,
        )

    op.add_column(
        "user",
        sa.Column(
            "enable_memory_tool",
            sa.Boolean(),
            nullable=False,
            server_default=sa.true(),
        ),
    )


def downgrade() -> None:
    op.drop_column("user", "enable_memory_tool")

    conn = op.get_bind()
    conn.execute(
        sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"),
        {"in_code_tool_id":
MEMORY_TOOL["in_code_tool_id"]},
    )

================================================
FILE: backend/alembic/versions/b558f51620b4_pause_finished_user_file_connectors.py
================================================
"""Pause finished user file connectors

Revision ID: b558f51620b4
Revises: 90e3b9af7da4
Create Date: 2025-08-15 17:17:02.456704

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "b558f51620b4"
down_revision = "90e3b9af7da4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Set all user file connector credential pairs with ACTIVE status to PAUSED
    # This ensures user files don't continue to run indexing tasks after processing
    op.execute(
        """
        UPDATE connector_credential_pair
        SET status = 'PAUSED'
        WHERE is_user_file = true
        AND status = 'ACTIVE'
        """
    )


def downgrade() -> None:
    pass

================================================
FILE: backend/alembic/versions/b5c4d7e8f9a1_add_hierarchy_node_cc_pair_table.py
================================================
"""add hierarchy_node_by_connector_credential_pair table

Revision ID: b5c4d7e8f9a1
Revises: a3b8d9e2f1c4
Create Date: 2026-03-04

"""

import sqlalchemy as sa
from alembic import op

revision = "b5c4d7e8f9a1"
down_revision = "a3b8d9e2f1c4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "hierarchy_node_by_connector_credential_pair",
        sa.Column("hierarchy_node_id", sa.Integer(), nullable=False),
        sa.Column("connector_id", sa.Integer(), nullable=False),
        sa.Column("credential_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["hierarchy_node_id"],
            ["hierarchy_node.id"],
            ondelete="CASCADE",
        ),
        sa.ForeignKeyConstraint(
            ["connector_id", "credential_id"],
            [
                "connector_credential_pair.connector_id",
                "connector_credential_pair.credential_id",
            ],
            ondelete="CASCADE",
        ),
        sa.PrimaryKeyConstraint("hierarchy_node_id", "connector_id", "credential_id"),
    )
    op.create_index(
        "ix_hierarchy_node_cc_pair_connector_credential",
        "hierarchy_node_by_connector_credential_pair",
        ["connector_id", "credential_id"],
    )


def downgrade() -> None:
    op.drop_index(
        "ix_hierarchy_node_cc_pair_connector_credential",
        table_name="hierarchy_node_by_connector_credential_pair",
    )
    op.drop_table("hierarchy_node_by_connector_credential_pair")

================================================
FILE: backend/alembic/versions/b728689f45b1_rename_persona_is_visible_to_is_listed_.py
================================================
"""rename persona is_visible to is_listed and featured to is_featured

Revision ID: b728689f45b1
Revises: 689433b0d8de
Create Date: 2026-03-23 12:36:26.607305

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "b728689f45b1"
down_revision = "689433b0d8de"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column("persona", "is_visible", new_column_name="is_listed")
    op.alter_column("persona", "featured", new_column_name="is_featured")


def downgrade() -> None:
    op.alter_column("persona", "is_listed", new_column_name="is_visible")
    op.alter_column("persona", "is_featured", new_column_name="featured")

================================================
FILE: backend/alembic/versions/b72ed7a5db0e_remove_description_from_starter_messages.py
================================================
"""remove description from starter messages

Revision ID: b72ed7a5db0e
Revises: 33cb72ea4d80
Create Date: 2024-11-03 15:55:28.944408

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "b72ed7a5db0e"
down_revision = "33cb72ea4d80"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.execute(
        sa.text(
            """
            UPDATE persona
            SET starter_messages = (
                SELECT jsonb_agg(elem - 'description')
                FROM jsonb_array_elements(starter_messages) elem
            )
            WHERE starter_messages IS NOT NULL
            AND jsonb_typeof(starter_messages) = 'array'
            """
        )
    )


def downgrade() -> None:
    op.execute(
        sa.text(
            """
            UPDATE persona
            SET starter_messages = (
                SELECT jsonb_agg(elem || '{"description": ""}')
                FROM jsonb_array_elements(starter_messages) elem
            )
            WHERE starter_messages IS NOT NULL
            AND jsonb_typeof(starter_messages) = 'array'
            """
        )
    )

================================================
FILE: backend/alembic/versions/b7a7eee5aa15_add_checkpointing_failure_handling.py
================================================
"""Add checkpointing/failure handling

Revision ID: b7a7eee5aa15
Revises: f39c5794c10a
Create Date: 2025-01-24 15:17:36.763172

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b7a7eee5aa15"
down_revision = "f39c5794c10a"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "index_attempt",
        sa.Column("checkpoint_pointer", sa.String(), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column("poll_range_start", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "index_attempt",
        sa.Column("poll_range_end", sa.DateTime(timezone=True), nullable=True),
    )

    op.create_index(
        "ix_index_attempt_cc_pair_settings_poll",
        "index_attempt",
        [
            "connector_credential_pair_id",
            "search_settings_id",
            "status",
            sa.text("time_updated DESC"),
        ],
    )

    # Drop the old IndexAttemptError table
    op.drop_index("index_attempt_id", table_name="index_attempt_errors")
    op.drop_table("index_attempt_errors")

    # Create the new version of the table
    op.create_table(
        "index_attempt_errors",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("index_attempt_id", sa.Integer(), nullable=False),
        sa.Column("connector_credential_pair_id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=True),
        sa.Column("document_link", sa.String(), nullable=True),
        sa.Column("entity_id", sa.String(), nullable=True),
        sa.Column("failed_time_range_start", sa.DateTime(timezone=True), nullable=True),
        sa.Column("failed_time_range_end", sa.DateTime(timezone=True), nullable=True),
        sa.Column("failure_message", sa.Text(), nullable=False),
        sa.Column("is_resolved", sa.Boolean(), nullable=False, default=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["index_attempt_id"],
            ["index_attempt.id"],
        ),
        sa.ForeignKeyConstraint(
            ["connector_credential_pair_id"],
            ["connector_credential_pair.id"],
        ),
    )


def downgrade() -> None:
    op.execute("SET lock_timeout = '5s'")

    # Try a few times to drop the table; the drop has been observed to fail
    # when other locks block it.
    NUM_TRIES = 10
    for i in range(NUM_TRIES):
        try:
            op.drop_table("index_attempt_errors")
            break
        except Exception as e:
            if i == NUM_TRIES - 1:
                raise e
            print(f"Error dropping table: {e}. Retrying...")

    op.execute("SET lock_timeout = DEFAULT")

    # Recreate the old IndexAttemptError table
    op.create_table(
        "index_attempt_errors",
        sa.Column("id", sa.Integer(), primary_key=True),
        sa.Column("index_attempt_id", sa.Integer(), nullable=True),
        sa.Column("batch", sa.Integer(), nullable=True),
        sa.Column("doc_summaries", postgresql.JSONB(), nullable=False),
        sa.Column("error_msg", sa.Text(), nullable=True),
        sa.Column("traceback", sa.Text(), nullable=True),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
        ),
        sa.ForeignKeyConstraint(
            ["index_attempt_id"],
            ["index_attempt.id"],
        ),
    )
    op.create_index(
        "index_attempt_id",
        "index_attempt_errors",
        ["time_created"],
    )

    op.drop_index("ix_index_attempt_cc_pair_settings_poll")
    op.drop_column("index_attempt", "checkpoint_pointer")
    op.drop_column("index_attempt", "poll_range_start")
    op.drop_column("index_attempt", "poll_range_end")

================================================
FILE: backend/alembic/versions/b7bcc991d722_assign_users_to_default_groups.py
================================================
"""assign_users_to_default_groups

Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert

# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None

# The no-auth placeholder user must NOT be assigned to default groups.
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
# user when the first real user registers; group membership rows would cause
# an FK violation on that DELETE.
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"

# Lightweight table constructs (no reflection) for use in DML
user_group_table = sa.table(
    "user_group",
    sa.column("id", sa.Integer),
    sa.column("name", sa.String),
    sa.column("is_default", sa.Boolean),
)

user_table = sa.table(
    "user",
    sa.column("id", sa.Uuid),
    sa.column("role", sa.String),
    sa.column("account_type", sa.String),
    sa.column("is_active", sa.Boolean),
)

user__user_group_table = sa.table(
    "user__user_group",
    sa.column("user_group_id", sa.Integer),
    sa.column("user_id", sa.Uuid),
)


def upgrade() -> None:
    conn = op.get_bind()

    # Look up default group IDs
    admin_row = conn.execute(
        sa.select(user_group_table.c.id).where(
            user_group_table.c.name == "Admin",
            user_group_table.c.is_default == True,  # noqa: E712
        )
    ).fetchone()

    basic_row = conn.execute(
        sa.select(user_group_table.c.id).where(
            user_group_table.c.name == "Basic",
            user_group_table.c.is_default == True,  # noqa: E712
        )
    ).fetchone()

    if admin_row is None:
        raise RuntimeError(
            "Default 'Admin' group not found. "
            "Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
        )

    if basic_row is None:
        raise RuntimeError(
            "Default 'Basic' group not found. "
            "Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
        )

    # Users with role=admin → Admin group
    # Include inactive users so reactivation doesn't require reconciliation.
    # Exclude non-human account types (mirrors assign_user_to_default_groups logic).
    admin_users = sa.select(
        sa.literal(admin_row[0]).label("user_group_id"),
        user_table.c.id.label("user_id"),
    ).where(
        user_table.c.role == "ADMIN",
        user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
        user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
    )
    op.execute(
        pg_insert(user__user_group_table)
        .from_select(["user_group_id", "user_id"], admin_users)
        .on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
    )

    # STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
    # Include inactive users so reactivation doesn't require reconciliation.
    basic_users = sa.select(
        sa.literal(basic_row[0]).label("user_group_id"),
        user_table.c.id.label("user_id"),
    ).where(
        user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
        user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
        sa.or_(
            sa.and_(
                user_table.c.account_type == "STANDARD",
                user_table.c.role != "ADMIN",
            ),
            sa.and_(
                user_table.c.account_type == "SERVICE_ACCOUNT",
                user_table.c.role == "BASIC",
            ),
        ),
    )
    op.execute(
        pg_insert(user__user_group_table)
        .from_select(["user_group_id", "user_id"], basic_users)
        .on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
    )


def downgrade() -> None:
    # Group memberships are left in place; removing them risks
    # deleting memberships that existed before this migration.
    pass

================================================
FILE: backend/alembic/versions/b7c2b63c4a03_add_background_reindex_enabled_field.py
================================================
"""add background_reindex_enabled field

Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012

"""

from alembic import op
import sqlalchemy as sa

from onyx.db.enums import EmbeddingPrecision

# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add background_reindex_enabled column with default value of True
    op.add_column(
        "search_settings",
        sa.Column(
            "background_reindex_enabled",
            sa.Boolean(),
            nullable=False,
            server_default="true",
        ),
    )

    # Add embedding_precision column with default value of FLOAT
    op.add_column(
        "search_settings",
        sa.Column(
            "embedding_precision",
            sa.Enum(EmbeddingPrecision, native_enum=False),
            nullable=False,
            server_default=EmbeddingPrecision.FLOAT.name,
        ),
    )

    # Add reduced_dimension column with default value of None
    op.add_column(
        "search_settings",
        sa.Column("reduced_dimension", sa.Integer(), nullable=True),
    )


def downgrade() -> None:
    # Remove the three columns added in upgrade()
    op.drop_column("search_settings", "background_reindex_enabled")
    op.drop_column("search_settings", "embedding_precision")
    op.drop_column("search_settings", "reduced_dimension")

================================================
FILE: backend/alembic/versions/b7ec9b5b505f_adjust_prompt_length.py
================================================
"""adjust prompt length

Revision ID: b7ec9b5b505f
Revises: abbfec3a5ac5
Create Date: 2025-09-10 18:51:15.629197

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "b7ec9b5b505f"
down_revision = "abbfec3a5ac5"
branch_labels = None
depends_on = None


MAX_PROMPT_LENGTH = 5_000_000


def upgrade() -> None:
    # NOTE: need to run this since the previous migration PREVIOUSLY set the length to 8000
    op.alter_column(
        "persona",
        "system_prompt",
        existing_type=sa.String(length=8000),
        type_=sa.String(length=MAX_PROMPT_LENGTH),
        existing_nullable=False,
    )
    op.alter_column(
        "persona",
        "task_prompt",
        existing_type=sa.String(length=8000),
        type_=sa.String(length=MAX_PROMPT_LENGTH),
        existing_nullable=False,
    )


def downgrade() -> None:
    # Downgrade not necessary
    pass

================================================
FILE: backend/alembic/versions/b85f02ec1308_fix_file_type_migration.py
================================================
"""fix-file-type-migration

Revision ID: b85f02ec1308
Revises: a3bfd0d64902
Create Date: 2024-05-31 18:09:26.658164

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "b85f02ec1308"
down_revision = "a3bfd0d64902"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute(
        """
        UPDATE file_store
        SET file_origin = UPPER(file_origin)
        """
    )


def downgrade() -> None:
    # Let's not break anything on purpose :)
    pass

================================================
FILE: backend/alembic/versions/b896bbd0d5a7_backfill_is_internet_data_to_false.py
================================================
"""backfill is_internet data to False

Revision ID: b896bbd0d5a7
Revises: 44f856ae2a4a
Create Date: 2024-07-16 15:21:05.718571

"""

from alembic import op

# revision identifiers, used by Alembic.
revision = "b896bbd0d5a7"
down_revision = "44f856ae2a4a"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.execute("UPDATE search_doc SET is_internet = FALSE WHERE is_internet IS NULL")


def downgrade() -> None:
    pass

================================================
FILE: backend/alembic/versions/b8c9d0e1f2a3_drop_milestone_table.py
================================================
"""Drop milestone table

Revision ID: b8c9d0e1f2a3
Revises: a2b3c4d5e6f7
Create Date: 2025-12-18

"""

from alembic import op
import sqlalchemy as sa
import fastapi_users_db_sqlalchemy
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "b8c9d0e1f2a3"
down_revision = "a2b3c4d5e6f7"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.drop_table("milestone")


def downgrade() -> None:
    op.create_table(
        "milestone",
        sa.Column("id", sa.UUID(), nullable=False),
        sa.Column("tenant_id", sa.String(), nullable=True),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column("event_type", sa.String(), nullable=False),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.Column("event_tracker", postgresql.JSONB(), nullable=True),
        sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"),
        sa.PrimaryKeyConstraint("id"),
        sa.UniqueConstraint("event_type", name="uq_milestone_event_type"),
    )

================================================
FILE: backend/alembic/versions/ba98eba0f66a_add_support_for_litellm_proxy_in_.py
================================================
"""add support for litellm proxy in reranking

Revision ID: ba98eba0f66a
Revises: bceb1e139447
Create Date: 2024-09-06 10:36:04.507332

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "ba98eba0f66a" down_revision = "bceb1e139447" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "search_settings", sa.Column("rerank_api_url", sa.String(), nullable=True) ) def downgrade() -> None: op.drop_column("search_settings", "rerank_api_url") ================================================ FILE: backend/alembic/versions/baf71f781b9e_add_llm_model_version_override_to_.py ================================================ """Add llm_model_version_override to Persona Revision ID: baf71f781b9e Revises: 50b683a8295c Create Date: 2023-12-06 21:56:50.286158 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "baf71f781b9e" down_revision = "50b683a8295c" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "persona", sa.Column("llm_model_version_override", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("persona", "llm_model_version_override") ================================================ FILE: backend/alembic/versions/bc9771dccadf_create_usage_reports_table.py ================================================ """create usage reports table Revision ID: bc9771dccadf Revises: 0568ccf46a6b Create Date: 2024-06-18 10:04:26.800282 """ from alembic import op import sqlalchemy as sa import fastapi_users_db_sqlalchemy # revision identifiers, used by Alembic. 
revision = "bc9771dccadf" down_revision = "0568ccf46a6b" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "usage_reports", sa.Column("id", sa.Integer(), nullable=False), sa.Column("report_name", sa.String(), nullable=False), sa.Column( "requestor_user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("period_from", sa.DateTime(timezone=True), nullable=True), sa.Column("period_to", sa.DateTime(timezone=True), nullable=True), sa.ForeignKeyConstraint( ["report_name"], ["file_store.file_name"], ), sa.ForeignKeyConstraint( ["requestor_user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("usage_reports") ================================================ FILE: backend/alembic/versions/bceb1e139447_add_base_url_to_cloudembeddingprovider.py ================================================ """Add base_url to CloudEmbeddingProvider Revision ID: bceb1e139447 Revises: a3795dce87be Create Date: 2024-08-28 17:00:52.554580 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "bceb1e139447" down_revision = "a3795dce87be" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "embedding_provider", sa.Column("api_url", sa.String(), nullable=True) ) def downgrade() -> None: op.drop_column("embedding_provider", "api_url") ================================================ FILE: backend/alembic/versions/bd2921608c3a_non_nullable_default_persona.py ================================================ """non nullable default persona Revision ID: bd2921608c3a Revises: 797089dfb4d2 Create Date: 2024-09-20 10:28:37.992042 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "bd2921608c3a" down_revision = "797089dfb4d2" branch_labels = None depends_on = None def upgrade() -> None: # Set existing NULL values to False op.execute( "UPDATE persona SET is_default_persona = FALSE WHERE is_default_persona IS NULL" ) # Alter the column to be not nullable with a default value of False op.alter_column( "persona", "is_default_persona", existing_type=sa.Boolean(), nullable=False, server_default=sa.text("false"), ) def downgrade() -> None: # Revert the changes op.alter_column( "persona", "is_default_persona", existing_type=sa.Boolean(), nullable=True, server_default=None, ) ================================================ FILE: backend/alembic/versions/bd7c3bf8beba_migrate_agent_responses_to_research_.py ================================================ """migrate_agent_sub_questions_to_research_iterations Revision ID: bd7c3bf8beba Revises: f8a9b2c3d4e5 Create Date: 2025-08-18 11:33:27.098287 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "bd7c3bf8beba" down_revision = "f8a9b2c3d4e5" branch_labels = None depends_on = None def upgrade() -> None: # Get connection to execute raw SQL connection = op.get_bind() # First, insert data into research_agent_iteration table # This creates one iteration record per primary_question_id using the earliest time_created connection.execute( sa.text( """ INSERT INTO research_agent_iteration (primary_question_id, created_at, iteration_nr, purpose, reasoning) SELECT primary_question_id, MIN(time_created) as created_at, 1 as iteration_nr, 'Generating and researching subquestions' as purpose, '(No previous reasoning)' as reasoning FROM agent__sub_question JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id WHERE primary_question_id IS NOT NULL AND chat_message.is_agentic = true GROUP BY primary_question_id ON CONFLICT DO NOTHING; """ ) ) # Then, insert data into research_agent_iteration_sub_step table # This migrates each sub-question as a sub-step connection.execute( sa.text( """ INSERT INTO research_agent_iteration_sub_step ( primary_question_id, iteration_nr, iteration_sub_step_nr, created_at, sub_step_instructions, sub_step_tool_id, sub_answer, cited_doc_results ) SELECT primary_question_id, 1 as iteration_nr, level_question_num as iteration_sub_step_nr, time_created as created_at, sub_question as sub_step_instructions, 1 as sub_step_tool_id, sub_answer, sub_question_doc_results as cited_doc_results FROM agent__sub_question JOIN chat_message on agent__sub_question.primary_question_id = chat_message.id WHERE chat_message.is_agentic = true AND primary_question_id IS NOT NULL ON CONFLICT DO NOTHING; """ ) ) # Update chat_message records: set legacy agentic type and answer purpose for existing agentic messages connection.execute( sa.text( """ UPDATE chat_message SET research_answer_purpose = 'ANSWER' WHERE is_agentic = true AND research_type IS NULL and message_type = 'ASSISTANT'; """ ) ) connection.execute( sa.text( """ UPDATE 
chat_message SET research_type = 'LEGACY_AGENTIC' WHERE is_agentic = true AND research_type IS NULL; """ ) ) def downgrade() -> None: # Get connection to execute raw SQL connection = op.get_bind() # Note: This downgrade removes all research agent iteration data # There's no way to perfectly restore the original agent__sub_question data # if it was deleted after this migration # Delete all research_agent_iteration_sub_step records that were migrated connection.execute( sa.text( """ DELETE FROM research_agent_iteration_sub_step USING chat_message WHERE research_agent_iteration_sub_step.primary_question_id = chat_message.id AND chat_message.research_type = 'LEGACY_AGENTIC'; """ ) ) # Delete all research_agent_iteration records that were migrated connection.execute( sa.text( """ DELETE FROM research_agent_iteration USING chat_message WHERE research_agent_iteration.primary_question_id = chat_message.id AND chat_message.research_type = 'LEGACY_AGENTIC'; """ ) ) # Revert chat_message updates: clear research fields for legacy agentic messages connection.execute( sa.text( """ UPDATE chat_message SET research_type = NULL, research_answer_purpose = NULL WHERE is_agentic = true AND research_type = 'LEGACY_AGENTIC' AND message_type = 'ASSISTANT'; """ ) ) ================================================ FILE: backend/alembic/versions/be2ab2aa50ee_fix_capitalization.py ================================================ """fix_capitalization Revision ID: be2ab2aa50ee Revises: 369644546676 Create Date: 2025-01-10 13:13:26.228960 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "be2ab2aa50ee" down_revision = "369644546676" branch_labels = None depends_on = None def upgrade() -> None: op.execute( """ UPDATE document SET external_user_group_ids = ARRAY( SELECT LOWER(unnest(external_user_group_ids)) ), last_modified = NOW() WHERE external_user_group_ids IS NOT NULL AND external_user_group_ids::text[] <> ARRAY( SELECT LOWER(unnest(external_user_group_ids)) )::text[] """ ) def downgrade() -> None: # No way to cleanly persist the bad state through an upgrade/downgrade # cycle, so we just pass pass ================================================ FILE: backend/alembic/versions/be87a654d5af_persona_new_default_model_configuration_.py ================================================ """Persona new default model configuration id column Revision ID: be87a654d5af Revises: e7f8a9b0c1d2 Create Date: 2026-01-30 11:14:17.306275 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "be87a654d5af" down_revision = "e7f8a9b0c1d2" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "persona", sa.Column("default_model_configuration_id", sa.Integer(), nullable=True), ) op.create_foreign_key( "fk_persona_default_model_configuration_id", "persona", "model_configuration", ["default_model_configuration_id"], ["id"], ondelete="SET NULL", ) def downgrade() -> None: op.drop_constraint( "fk_persona_default_model_configuration_id", "persona", type_="foreignkey" ) op.drop_column("persona", "default_model_configuration_id") ================================================ FILE: backend/alembic/versions/bf7a81109301_delete_input_prompts.py ================================================ """delete_input_prompts Revision ID: bf7a81109301 Revises: f7a894b06d02 Create Date: 2024-12-09 12:00:49.884228 """ from alembic import op import sqlalchemy as sa import fastapi_users_db_sqlalchemy # revision identifiers, used by Alembic. 
revision = "bf7a81109301" down_revision = "f7a894b06d02" branch_labels = None depends_on = None def upgrade() -> None: op.drop_table("inputprompt__user") op.drop_table("inputprompt") def downgrade() -> None: op.create_table( "inputprompt", sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), sa.Column("prompt", sa.String(), nullable=False), sa.Column("content", sa.String(), nullable=False), sa.Column("active", sa.Boolean(), nullable=False), sa.Column("is_public", sa.Boolean(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "inputprompt__user", sa.Column("input_prompt_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["input_prompt_id"], ["inputprompt.id"], ), sa.ForeignKeyConstraint( ["user_id"], ["inputprompt.id"], ), sa.PrimaryKeyConstraint("input_prompt_id", "user_id"), ) ================================================ FILE: backend/alembic/versions/c0aab6edb6dd_delete_workspace.py ================================================ """delete workspace Revision ID: c0aab6edb6dd Revises: 35e518e0ddf4 Create Date: 2024-12-17 14:37:07.660631 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "c0aab6edb6dd" down_revision = "35e518e0ddf4" branch_labels = None depends_on = None def upgrade() -> None: op.execute( """ UPDATE connector SET connector_specific_config = connector_specific_config - 'workspace' WHERE source = 'SLACK' """ ) def downgrade() -> None: import json from sqlalchemy import text from slack_sdk import WebClient conn = op.get_bind() # Fetch all Slack credentials creds_result = conn.execute( text("SELECT id, credential_json FROM credential WHERE source = 'SLACK'") ) all_slack_creds = creds_result.fetchall() if not all_slack_creds: return for cred_row in all_slack_creds: credential_id, credential_json = cred_row credential_json = ( credential_json.tobytes().decode("utf-8") if isinstance(credential_json, memoryview) else credential_json.decode("utf-8") ) credential_data = json.loads(credential_json) slack_bot_token = credential_data.get("slack_bot_token") if not slack_bot_token: print( f"No slack_bot_token found for credential {credential_id}. " "Your Slack connector will not function until you upgrade and provide a valid token." ) continue client = WebClient(token=slack_bot_token) try: auth_response = client.auth_test() workspace = auth_response["url"].split("//")[1].split(".")[0] # Update only the connectors linked to this credential # (and which are Slack connectors). op.execute( f""" UPDATE connector AS c SET connector_specific_config = jsonb_set( connector_specific_config, '{{workspace}}', to_jsonb('{workspace}'::text) ) FROM connector_credential_pair AS ccp WHERE ccp.connector_id = c.id AND c.source = 'SLACK' AND ccp.credential_id = {credential_id} """ ) except Exception: print( f"We were unable to get the workspace url for your Slack Connector with id {credential_id}." 
) print("This connector will no longer work until you upgrade.") continue ================================================ FILE: backend/alembic/versions/c0c937d5c9e5_llm_provider_deprecate_fields.py ================================================ """llm provider deprecate fields Revision ID: c0c937d5c9e5 Revises: 8ffcc2bcfc11 Create Date: 2026-02-25 17:35:46.125102 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "c0c937d5c9e5" down_revision = "8ffcc2bcfc11" branch_labels = None depends_on = None def upgrade() -> None: # Make default_model_name nullable (was NOT NULL) op.alter_column( "llm_provider", "default_model_name", existing_type=sa.String(), nullable=True, ) # Drop unique constraint on is_default_provider (defaults now tracked via LLMModelFlow) op.drop_constraint( "llm_provider_is_default_provider_key", "llm_provider", type_="unique", ) # Remove server_default from is_default_vision_provider (was server_default=false()) op.alter_column( "llm_provider", "is_default_vision_provider", existing_type=sa.Boolean(), server_default=None, ) def downgrade() -> None: # Restore default_model_name to NOT NULL (set empty string for any NULLs first) op.execute( "UPDATE llm_provider SET default_model_name = '' WHERE default_model_name IS NULL" ) op.alter_column( "llm_provider", "default_model_name", existing_type=sa.String(), nullable=False, ) # Restore unique constraint on is_default_provider op.create_unique_constraint( "llm_provider_is_default_provider_key", "llm_provider", ["is_default_provider"], ) # Restore server_default for is_default_vision_provider op.alter_column( "llm_provider", "is_default_vision_provider", existing_type=sa.Boolean(), server_default=sa.false(), ) ================================================ FILE: backend/alembic/versions/c0fd6e4da83a_add_recent_assistants.py ================================================ """add recent assistants Revision ID: c0fd6e4da83a Revises: b72ed7a5db0e Create 
Date: 2024-11-03 17:28:54.916618 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "c0fd6e4da83a" down_revision = "b72ed7a5db0e" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column( "recent_assistants", postgresql.JSONB(), server_default="[]", nullable=False ), ) def downgrade() -> None: op.drop_column("user", "recent_assistants") ================================================ FILE: backend/alembic/versions/c18cdf4b497e_add_standard_answer_tables.py ================================================ """Add standard_answer tables Revision ID: c18cdf4b497e Revises: 3a7802814195 Create Date: 2024-06-06 15:15:02.000648 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "c18cdf4b497e" down_revision = "3a7802814195" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "standard_answer", sa.Column("id", sa.Integer(), nullable=False), sa.Column("keyword", sa.String(), nullable=False), sa.Column("answer", sa.String(), nullable=False), sa.Column("active", sa.Boolean(), nullable=False), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("keyword"), ) op.create_table( "standard_answer_category", sa.Column("id", sa.Integer(), nullable=False), sa.Column("name", sa.String(), nullable=False), sa.PrimaryKeyConstraint("id"), sa.UniqueConstraint("name"), ) op.create_table( "standard_answer__standard_answer_category", sa.Column("standard_answer_id", sa.Integer(), nullable=False), sa.Column("standard_answer_category_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["standard_answer_category_id"], ["standard_answer_category.id"], ), sa.ForeignKeyConstraint( ["standard_answer_id"], ["standard_answer.id"], ), sa.PrimaryKeyConstraint("standard_answer_id", "standard_answer_category_id"), ) op.create_table( "slack_bot_config__standard_answer_category", 
sa.Column("slack_bot_config_id", sa.Integer(), nullable=False), sa.Column("standard_answer_category_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["slack_bot_config_id"], ["slack_bot_config.id"], ), sa.ForeignKeyConstraint( ["standard_answer_category_id"], ["standard_answer_category.id"], ), sa.PrimaryKeyConstraint("slack_bot_config_id", "standard_answer_category_id"), ) op.add_column( "chat_session", sa.Column("slack_thread_id", sa.String(), nullable=True) ) def downgrade() -> None: op.drop_column("chat_session", "slack_thread_id") op.drop_table("slack_bot_config__standard_answer_category") op.drop_table("standard_answer__standard_answer_category") op.drop_table("standard_answer_category") op.drop_table("standard_answer") ================================================ FILE: backend/alembic/versions/c1d2e3f4a5b6_add_deep_research_tool.py ================================================ """add_deep_research_tool Revision ID: c1d2e3f4a5b6 Revises: b8c9d0e1f2a3 Create Date: 2025-12-18 16:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "c1d2e3f4a5b6" down_revision = "b8c9d0e1f2a3" branch_labels = None depends_on = None DEEP_RESEARCH_TOOL = { "name": "ResearchAgent", "display_name": "Research Agent", "description": "The Research Agent is a sub-agent that conducts research on a specific topic.", "in_code_tool_id": "ResearchAgent", } def upgrade() -> None: conn = op.get_bind() conn.execute( sa.text( """ INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled) VALUES (:name, :display_name, :description, :in_code_tool_id, false) """ ), DEEP_RESEARCH_TOOL, ) def downgrade() -> None: conn = op.get_bind() conn.execute( sa.text( """ DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id """ ), {"in_code_tool_id": DEEP_RESEARCH_TOOL["in_code_tool_id"]}, ) ================================================ FILE: backend/alembic/versions/c5b692fa265c_add_index_attempt_errors_table.py ================================================ """Add index_attempt_errors table Revision ID: c5b692fa265c Revises: 4a951134c801 Create Date: 2024-08-08 14:06:39.581972 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "c5b692fa265c" down_revision = "4a951134c801" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "index_attempt_errors", sa.Column("id", sa.Integer(), nullable=False), sa.Column("index_attempt_id", sa.Integer(), nullable=True), sa.Column("batch", sa.Integer(), nullable=True), sa.Column( "doc_summaries", postgresql.JSONB(astext_type=sa.Text()), nullable=False, ), sa.Column("error_msg", sa.Text(), nullable=True), sa.Column("traceback", sa.Text(), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["index_attempt_id"], ["index_attempt.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_index( "index_attempt_id", "index_attempt_errors", ["time_created"], unique=False, ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.drop_index("index_attempt_id", table_name="index_attempt_errors") op.drop_table("index_attempt_errors") # ### end Alembic commands ### ================================================ FILE: backend/alembic/versions/c5eae4a75a1b_add_chat_message__standard_answer_table.py ================================================ """Add chat_message__standard_answer table Revision ID: c5eae4a75a1b Revises: 0f7ff6d75b57 Create Date: 2025-01-15 14:08:49.688998 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "c5eae4a75a1b" down_revision = "0f7ff6d75b57" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "chat_message__standard_answer", sa.Column("chat_message_id", sa.Integer(), nullable=False), sa.Column("standard_answer_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["chat_message_id"], ["chat_message.id"], ), sa.ForeignKeyConstraint( ["standard_answer_id"], ["standard_answer.id"], ), sa.PrimaryKeyConstraint("chat_message_id", "standard_answer_id"), ) def downgrade() -> None: op.drop_table("chat_message__standard_answer") ================================================ FILE: backend/alembic/versions/c7bf5721733e_add_has_been_indexed_to_.py ================================================ """Add has_been_indexed to DocumentByConnectorCredentialPair Revision ID: c7bf5721733e Revises: 027381bce97c Create Date: 2025-01-13 12:39:05.831693 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "c7bf5721733e" down_revision = "027381bce97c" branch_labels = None depends_on = None def upgrade() -> None: # Assume all existing rows have already been indexed; there is no better way to backfill op.add_column( "document_by_connector_credential_pair", sa.Column("has_been_indexed", sa.Boolean(), nullable=True), ) op.execute( "UPDATE document_by_connector_credential_pair SET has_been_indexed = TRUE" ) op.alter_column( "document_by_connector_credential_pair", "has_been_indexed", nullable=False, ) # Add index to optimize get_document_counts_for_cc_pairs query pattern op.create_index( "idx_document_cc_pair_counts", "document_by_connector_credential_pair", ["connector_id", "credential_id", "has_been_indexed"], unique=False, ) def downgrade() -> None: # Drop the index before dropping the column it covers op.drop_index( "idx_document_cc_pair_counts", table_name="document_by_connector_credential_pair", ) op.drop_column("document_by_connector_credential_pair", "has_been_indexed")
================================================ FILE: backend/alembic/versions/c7e9f4a3b2d1_add_python_tool.py ================================================ """add_python_tool Revision ID: c7e9f4a3b2d1 Revises: 3c9a65f1207f Create Date: 2025-11-08 00:00:00.000000 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "c7e9f4a3b2d1" down_revision = "3c9a65f1207f" branch_labels = None depends_on = None def upgrade() -> None: """Add PythonTool to built-in tools""" conn = op.get_bind() conn.execute( sa.text( """ INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled) VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled) """ ), { "name": "PythonTool", # in the UI, call it `Code Interpreter` since this is a well known term for this tool "display_name": "Code Interpreter", "description": ( "The Code Interpreter Action allows the assistant to execute " "Python code in a secure, isolated environment for data analysis, " "computation, visualization, and file processing." 
), "in_code_tool_id": "PythonTool", "enabled": True, }, ) # needed to store files generated by the python tool op.add_column( "research_agent_iteration_sub_step", sa.Column( "file_ids", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) def downgrade() -> None: """Remove PythonTool from built-in tools""" conn = op.get_bind() conn.execute( sa.text( """ DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id """ ), { "in_code_tool_id": "PythonTool", }, ) op.drop_column("research_agent_iteration_sub_step", "file_ids") ================================================ FILE: backend/alembic/versions/c7f2e1b4a9d3_add_sharing_scope_to_build_session.py ================================================ """add sharing_scope to build_session Revision ID: c7f2e1b4a9d3 Revises: 19c0ccb01687 Create Date: 2026-02-17 12:00:00.000000 """ from alembic import op import sqlalchemy as sa revision = "c7f2e1b4a9d3" down_revision = "19c0ccb01687" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "build_session", sa.Column( "sharing_scope", sa.String(), nullable=False, server_default="private", ), ) def downgrade() -> None: op.drop_column("build_session", "sharing_scope") ================================================ FILE: backend/alembic/versions/c8a93a2af083_personalization_user_info.py ================================================ """personalization_user_info Revision ID: c8a93a2af083 Revises: 6f4f86aef280 Create Date: 2025-10-14 15:59:03.577343 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "c8a93a2af083" down_revision = "6f4f86aef280" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "user", sa.Column("personal_name", sa.String(), nullable=True), ) op.add_column( "user", sa.Column("personal_role", sa.String(), nullable=True), ) op.add_column( "user", sa.Column( "use_memories", sa.Boolean(), nullable=False, server_default=sa.true(), ), ) op.create_table( "memory", sa.Column("id", sa.Integer(), nullable=False), sa.Column("user_id", postgresql.UUID(as_uuid=True), nullable=False), sa.Column("memory_text", sa.Text(), nullable=False), sa.Column("conversation_id", postgresql.UUID(as_uuid=True), nullable=True), sa.Column("message_id", sa.Integer(), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.ForeignKeyConstraint(["user_id"], ["user.id"], ondelete="CASCADE"), sa.PrimaryKeyConstraint("id"), ) op.create_index("ix_memory_user_id", "memory", ["user_id"]) def downgrade() -> None: op.drop_index("ix_memory_user_id", table_name="memory") op.drop_table("memory") op.drop_column("user", "use_memories") op.drop_column("user", "personal_role") op.drop_column("user", "personal_name") ================================================ FILE: backend/alembic/versions/c99d76fcd298_add_nullable_to_persona_id_in_chat_.py ================================================ """add nullable to persona id in Chat Session Revision ID: c99d76fcd298 Revises: 5c7fdadae813 Create Date: 2024-07-09 19:27:01.579697 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "c99d76fcd298" down_revision = "5c7fdadae813" branch_labels = None depends_on = None def upgrade() -> None: op.alter_column( "chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=True ) def downgrade() -> None: # Delete chat messages and feedback first since they reference chat sessions # Get chat messages from sessions with null persona_id chat_messages_query = """ SELECT id FROM chat_message WHERE chat_session_id IN ( SELECT id FROM chat_session WHERE persona_id IS NULL ) """ # Delete dependent records first op.execute( f""" DELETE FROM document_retrieval_feedback WHERE chat_message_id IN ( {chat_messages_query} ) """ ) op.execute( f""" DELETE FROM chat_message__search_doc WHERE chat_message_id IN ( {chat_messages_query} ) """ ) # Delete chat messages op.execute( """ DELETE FROM chat_message WHERE chat_session_id IN ( SELECT id FROM chat_session WHERE persona_id IS NULL ) """ ) # Now we can safely delete the chat sessions op.execute( """ DELETE FROM chat_session WHERE persona_id IS NULL """ ) op.alter_column( "chat_session", "persona_id", existing_type=sa.INTEGER(), nullable=False, ) ================================================ FILE: backend/alembic/versions/c9e2cd766c29_add_s3_file_store_table.py ================================================ """modify_file_store_for_external_storage Revision ID: c9e2cd766c29 Revises: 03bf8be6b53a Create Date: 2025-06-13 14:02:09.867679 """ from alembic import op import sqlalchemy as sa from sqlalchemy.orm import Session from sqlalchemy import text from typing import cast, Any from botocore.exceptions import ClientError from onyx.db._deprecated.pg_file_store import delete_lobj_by_id, read_lobj from onyx.file_store.file_store import get_s3_file_store from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR # revision identifiers, used by Alembic. 
revision = "c9e2cd766c29" down_revision = "03bf8be6b53a" branch_labels = None depends_on = None def upgrade() -> None: try: # Modify existing file_store table to support external storage op.rename_table("file_store", "file_record") # Make lobj_oid nullable (for external storage files) op.alter_column("file_record", "lobj_oid", nullable=True) # Add external storage columns with generic names op.add_column( "file_record", sa.Column("bucket_name", sa.String(), nullable=True) ) op.add_column( "file_record", sa.Column("object_key", sa.String(), nullable=True) ) # Add timestamps for tracking op.add_column( "file_record", sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), ) op.add_column( "file_record", sa.Column( "updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), ) op.alter_column("file_record", "file_name", new_column_name="file_id") except Exception as e: if "does not exist" in str(e) or 'relation "file_store" does not exist' in str( e ): print( f"Ran into error - {e}. Likely means we had a partial success in the past, continuing..." ) else: raise print( "External storage configured - migrating files from PostgreSQL to external storage..." ) # if we fail midway through this, we'll have a partial success. Running the migration # again should allow us to continue. _migrate_files_to_external_storage() print("File migration completed successfully!") # Remove lobj_oid column op.drop_column("file_record", "lobj_oid") def downgrade() -> None: """Revert schema changes and migrate files from external storage back to PostgreSQL large objects.""" print( "Reverting to PostgreSQL-backed file store – migrating files from external storage …" ) # 1. Ensure `lobj_oid` exists on the current `file_record` table (nullable for now). op.add_column("file_record", sa.Column("lobj_oid", sa.Integer(), nullable=True)) # 2. 
Move content from external storage back into PostgreSQL large objects (table is still # called `file_record` so application code continues to work during the copy). try: _migrate_files_to_postgres() except Exception: print("Error during downgrade migration, rolling back …") op.drop_column("file_record", "lobj_oid") raise # 3. After migration every row should now have `lobj_oid` populated – mark NOT NULL. op.alter_column("file_record", "lobj_oid", nullable=False) # 4. Remove columns that are only relevant to external storage. op.drop_column("file_record", "updated_at") op.drop_column("file_record", "created_at") op.drop_column("file_record", "object_key") op.drop_column("file_record", "bucket_name") # 5. Rename `file_id` back to `file_name` (still on `file_record`). op.alter_column("file_record", "file_id", new_column_name="file_name") # 6. Finally, rename the table back to its original name expected by the legacy codebase. op.rename_table("file_record", "file_store") print( "Downgrade migration completed – files are now stored inside PostgreSQL again." ) # ----------------------------------------------------------------------------- # Helper: migrate from external storage (S3/MinIO) back into PostgreSQL large objects def _migrate_files_to_postgres() -> None: """Move any files whose content lives in external S3-compatible storage back into PostgreSQL. The logic is the inverse of `_migrate_files_to_external_storage`, which runs on upgrade.
""" # Obtain DB session from Alembic context bind = op.get_bind() session = Session(bind=bind) # Fetch rows that have external storage pointers (bucket/object_key not NULL) result = session.execute( text( "SELECT file_id, bucket_name, object_key FROM file_record WHERE bucket_name IS NOT NULL AND object_key IS NOT NULL" ) ) files_to_migrate = [row[0] for row in result.fetchall()] total_files = len(files_to_migrate) if total_files == 0: print("No files found in external storage to migrate back to PostgreSQL.") return print(f"Found {total_files} files to migrate back to PostgreSQL large objects.") _set_tenant_contextvar(session) migrated_count = 0 # only create external store if we have files to migrate. This line # makes it so we need to have S3/MinIO configured to run this migration. external_store = get_s3_file_store() for i, file_id in enumerate(files_to_migrate, 1): print(f"Migrating file {i}/{total_files}: {file_id}") # Read file content from external storage (always binary) try: file_io = external_store.read_file( file_id=file_id, mode="b", use_tempfile=True ) file_io.seek(0) # Import lazily to avoid circular deps at Alembic runtime from onyx.db._deprecated.pg_file_store import ( create_populate_lobj, ) # noqa: E402 # Create new Postgres large object and populate it lobj_oid = create_populate_lobj(content=file_io, db_session=session) # Update DB row: set lobj_oid, clear bucket/object_key session.execute( text( "UPDATE file_record SET lobj_oid = :lobj_oid, bucket_name = NULL, object_key = NULL WHERE file_id = :file_id" ), {"lobj_oid": lobj_oid, "file_id": file_id}, ) except ClientError as e: if "NoSuchKey" in str(e): print( f"File {file_id} not found in external storage. Deleting from database." 
) session.execute( text("DELETE FROM file_record WHERE file_id = :file_id"), {"file_id": file_id}, ) continue else: raise migrated_count += 1 print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}") # Flush the SQLAlchemy session so statements are sent to the DB, but **do not** # commit the transaction. The surrounding Alembic migration will commit once # the *entire* downgrade succeeds. This keeps the whole downgrade atomic and # avoids leaving the database in a partially-migrated state if a later schema # operation fails. session.flush() print( f"Migration back to PostgreSQL completed: {migrated_count} files staged for commit." ) def _migrate_files_to_external_storage() -> None: """Migrate files from PostgreSQL large objects to external storage""" # Get database session bind = op.get_bind() session = Session(bind=bind) external_store = get_s3_file_store() # Find all files currently stored in PostgreSQL (lobj_oid is not null) result = session.execute( text( "SELECT file_id FROM file_record WHERE lobj_oid IS NOT NULL AND bucket_name IS NULL AND object_key IS NULL" ) ) files_to_migrate = [row[0] for row in result.fetchall()] total_files = len(files_to_migrate) if total_files == 0: print("No files found in PostgreSQL storage to migrate.") return # might need to move this above the if statement when creating a new multi-tenant # system. VERY extreme edge case.
external_store.initialize() print(f"Found {total_files} files to migrate from PostgreSQL to external storage.") _set_tenant_contextvar(session) migrated_count = 0 for i, file_id in enumerate(files_to_migrate, 1): print(f"Migrating file {i}/{total_files}: {file_id}") # Read file record to get metadata file_record = session.execute( text("SELECT * FROM file_record WHERE file_id = :file_id"), {"file_id": file_id}, ).fetchone() if file_record is None: print(f"File {file_id} not found in PostgreSQL storage.") continue lobj_id = cast(int, file_record.lobj_oid) raw_metadata = cast(Any, file_record.file_metadata) # Read file content from PostgreSQL try: file_content = read_lobj( lobj_id, db_session=session, mode="b", use_tempfile=True ) except Exception as e: if "large object" in str(e) and "does not exist" in str(e): print(f"File {file_id} not found in PostgreSQL storage.") continue else: raise # Handle file_metadata type conversion: keep dicts as-is, coerce other mappings, otherwise drop file_metadata: dict[str, Any] | None = None if raw_metadata is not None: if isinstance(raw_metadata, dict): file_metadata = raw_metadata else: # Convert other types to dict if possible, otherwise None try: file_metadata = dict(raw_metadata) except (TypeError, ValueError): file_metadata = None # Save to external storage (this will handle the database record update and cleanup) # NOTE: this WILL .commit() the transaction. external_store.save_file( file_id=file_id, content=file_content, display_name=file_record.display_name, file_origin=file_record.file_origin, file_type=file_record.file_type, file_metadata=file_metadata, ) delete_lobj_by_id(lobj_id, db_session=session) migrated_count += 1 print(f"✓ Successfully migrated file {i}/{total_files}: {file_id}") # See note above – flush but do **not** commit so the outer Alembic transaction # controls atomicity. session.flush() print( f"Migration completed: {migrated_count} files staged for commit to external storage."
) def _set_tenant_contextvar(session: Session) -> None: """Set the tenant contextvar to the default schema""" current_tenant = session.execute(text("SELECT current_schema()")).scalar() print(f"Migrating files for tenant: {current_tenant}") CURRENT_TENANT_ID_CONTEXTVAR.set(current_tenant) ================================================ FILE: backend/alembic/versions/ca04500b9ee8_add_cascade_deletes_to_agent_tables.py ================================================ """add_cascade_deletes_to_agent_tables Revision ID: ca04500b9ee8 Revises: 238b84885828 Create Date: 2025-05-30 16:03:51.112263 """ from alembic import op # revision identifiers, used by Alembic. revision = "ca04500b9ee8" down_revision = "238b84885828" branch_labels = None depends_on = None def upgrade() -> None: # Drop existing foreign key constraints op.drop_constraint( "agent__sub_question_primary_question_id_fkey", "agent__sub_question", type_="foreignkey", ) op.drop_constraint( "agent__sub_query_parent_question_id_fkey", "agent__sub_query", type_="foreignkey", ) op.drop_constraint( "chat_message__standard_answer_chat_message_id_fkey", "chat_message__standard_answer", type_="foreignkey", ) op.drop_constraint( "agent__sub_query__search_doc_sub_query_id_fkey", "agent__sub_query__search_doc", type_="foreignkey", ) # Recreate foreign key constraints with CASCADE delete op.create_foreign_key( "agent__sub_question_primary_question_id_fkey", "agent__sub_question", "chat_message", ["primary_question_id"], ["id"], ondelete="CASCADE", ) op.create_foreign_key( "agent__sub_query_parent_question_id_fkey", "agent__sub_query", "agent__sub_question", ["parent_question_id"], ["id"], ondelete="CASCADE", ) op.create_foreign_key( "chat_message__standard_answer_chat_message_id_fkey", "chat_message__standard_answer", "chat_message", ["chat_message_id"], ["id"], ondelete="CASCADE", ) op.create_foreign_key( "agent__sub_query__search_doc_sub_query_id_fkey", "agent__sub_query__search_doc", "agent__sub_query", ["sub_query_id"], 
["id"], ondelete="CASCADE", ) def downgrade() -> None: # Drop CASCADE foreign key constraints op.drop_constraint( "agent__sub_question_primary_question_id_fkey", "agent__sub_question", type_="foreignkey", ) op.drop_constraint( "agent__sub_query_parent_question_id_fkey", "agent__sub_query", type_="foreignkey", ) op.drop_constraint( "chat_message__standard_answer_chat_message_id_fkey", "chat_message__standard_answer", type_="foreignkey", ) op.drop_constraint( "agent__sub_query__search_doc_sub_query_id_fkey", "agent__sub_query__search_doc", type_="foreignkey", ) # Recreate foreign key constraints without CASCADE delete op.create_foreign_key( "agent__sub_question_primary_question_id_fkey", "agent__sub_question", "chat_message", ["primary_question_id"], ["id"], ) op.create_foreign_key( "agent__sub_query_parent_question_id_fkey", "agent__sub_query", "agent__sub_question", ["parent_question_id"], ["id"], ) op.create_foreign_key( "chat_message__standard_answer_chat_message_id_fkey", "chat_message__standard_answer", "chat_message", ["chat_message_id"], ["id"], ) op.create_foreign_key( "agent__sub_query__search_doc_sub_query_id_fkey", "agent__sub_query__search_doc", "agent__sub_query", ["sub_query_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/cbc03e08d0f3_add_opensearch_migration_tables.py ================================================ """add_opensearch_migration_tables Revision ID: cbc03e08d0f3 Revises: be87a654d5af Create Date: 2026-01-31 17:00:45.176604 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "cbc03e08d0f3" down_revision = "be87a654d5af" branch_labels = None depends_on = None def upgrade() -> None: # 1. Create opensearch_document_migration_record table. 
op.create_table( "opensearch_document_migration_record", sa.Column("document_id", sa.String(), nullable=False), sa.Column("status", sa.String(), nullable=False, server_default="pending"), sa.Column("error_message", sa.Text(), nullable=True), sa.Column("attempts_count", sa.Integer(), nullable=False, server_default="0"), sa.Column("last_attempt_at", sa.DateTime(timezone=True), nullable=True), sa.Column( "created_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.PrimaryKeyConstraint("document_id"), sa.ForeignKeyConstraint( ["document_id"], ["document.id"], ondelete="CASCADE", ), ) # 2. Create indices. op.create_index( "ix_opensearch_document_migration_record_status", "opensearch_document_migration_record", ["status"], ) op.create_index( "ix_opensearch_document_migration_record_attempts_count", "opensearch_document_migration_record", ["attempts_count"], ) op.create_index( "ix_opensearch_document_migration_record_created_at", "opensearch_document_migration_record", ["created_at"], ) # 3. Create opensearch_tenant_migration_record table (singleton). op.create_table( "opensearch_tenant_migration_record", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "document_migration_record_table_population_status", sa.String(), nullable=False, server_default="pending", ), sa.Column( "num_times_observed_no_additional_docs_to_populate_migration_table", sa.Integer(), nullable=False, server_default="0", ), sa.Column( "overall_document_migration_status", sa.String(), nullable=False, server_default="pending", ), sa.Column( "num_times_observed_no_additional_docs_to_migrate", sa.Integer(), nullable=False, server_default="0", ), sa.Column( "last_updated_at", sa.DateTime(timezone=True), server_default=sa.func.now(), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) # 4. Create unique index on constant to enforce singleton pattern. 
op.execute( sa.text( """ CREATE UNIQUE INDEX idx_opensearch_tenant_migration_singleton ON opensearch_tenant_migration_record ((true)) """ ) ) def downgrade() -> None: # Drop opensearch_tenant_migration_record. op.drop_index( "idx_opensearch_tenant_migration_singleton", table_name="opensearch_tenant_migration_record", ) op.drop_table("opensearch_tenant_migration_record") # Drop opensearch_document_migration_record. op.drop_index( "ix_opensearch_document_migration_record_created_at", table_name="opensearch_document_migration_record", ) op.drop_index( "ix_opensearch_document_migration_record_attempts_count", table_name="opensearch_document_migration_record", ) op.drop_index( "ix_opensearch_document_migration_record_status", table_name="opensearch_document_migration_record", ) op.drop_table("opensearch_document_migration_record") ================================================ FILE: backend/alembic/versions/cec7ec36c505_kgentity_parent.py ================================================ """kgentity_parent Revision ID: cec7ec36c505 Revises: 495cb26ce93e Create Date: 2025-06-07 20:07:46.400770 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "cec7ec36c505" down_revision = "495cb26ce93e" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "kg_entity", sa.Column("parent_key", sa.String(), nullable=True, index=True), ) # NOTE: you will have to reindex the KG after this migration as the parent_key will be null def downgrade() -> None: op.drop_column("kg_entity", "parent_key") ================================================ FILE: backend/alembic/versions/cf90764725d8_larger_refresh_tokens.py ================================================ """larger refresh tokens Revision ID: cf90764725d8 Revises: 4794bc13e484 Create Date: 2025-04-04 10:56:39.769294 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "cf90764725d8" down_revision = "4794bc13e484" branch_labels = None depends_on = None def upgrade() -> None: op.alter_column("oauth_account", "refresh_token", type_=sa.Text()) def downgrade() -> None: op.alter_column("oauth_account", "refresh_token", type_=sa.String(length=1024)) ================================================ FILE: backend/alembic/versions/d09fc20a3c66_seed_builtin_tools.py ================================================ """seed_builtin_tools Revision ID: d09fc20a3c66 Revises: b7ec9b5b505f Create Date: 2025-09-09 19:32:16.824373 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "d09fc20a3c66" down_revision = "b7ec9b5b505f" branch_labels = None depends_on = None # Tool definitions - core tools that should always be seeded # Names/in_code_tool_id are the same as the class names in the tool_implementations package BUILT_IN_TOOLS = [ { "name": "SearchTool", "display_name": "Internal Search", "description": "The Search Action allows the Assistant to search through connected knowledge to help build an answer.", "in_code_tool_id": "SearchTool", }, { "name": "ImageGenerationTool", "display_name": "Image Generation", "description": ( "The Image Generation Action allows the assistant to use DALL-E 3 or GPT-IMAGE-1 to generate images. " "The action will be used when the user asks the assistant to generate an image." ), "in_code_tool_id": "ImageGenerationTool", }, { "name": "WebSearchTool", "display_name": "Web Search", "description": ( "The Web Search Action allows the assistant to perform internet searches for up-to-date information." ), "in_code_tool_id": "WebSearchTool", }, { "name": "KnowledgeGraphTool", "display_name": "Knowledge Graph Search", "description": ( "The Knowledge Graph Search Action allows the assistant to search the " "Knowledge Graph for information. This tool can (for now) only be active in the KG Beta Assistant, " "and it requires the Knowledge Graph to be enabled." 
), "in_code_tool_id": "KnowledgeGraphTool", }, { "name": "OktaProfileTool", "display_name": "Okta Profile", "description": ( "The Okta Profile Action allows the assistant to fetch the current user's information from Okta. " "This may include the user's name, email, phone number, address, and other details such as their " "manager and direct reports." ), "in_code_tool_id": "OktaProfileTool", }, ] def upgrade() -> None: conn = op.get_bind() # Get existing tools to check what already exists existing_tools = conn.execute( sa.text("SELECT in_code_tool_id FROM tool WHERE in_code_tool_id IS NOT NULL") ).fetchall() existing_tool_ids = {row[0] for row in existing_tools} # Insert or update built-in tools for tool in BUILT_IN_TOOLS: in_code_id = tool["in_code_tool_id"] # Handle historical rename: InternetSearchTool -> WebSearchTool if ( in_code_id == "WebSearchTool" and "WebSearchTool" not in existing_tool_ids and "InternetSearchTool" in existing_tool_ids ): # Rename the existing InternetSearchTool row in place and update fields conn.execute( sa.text( """ UPDATE tool SET name = :name, display_name = :display_name, description = :description, in_code_tool_id = :in_code_tool_id WHERE in_code_tool_id = 'InternetSearchTool' """ ), tool, ) # Keep the local view of existing ids in sync to avoid duplicate insert existing_tool_ids.discard("InternetSearchTool") existing_tool_ids.add("WebSearchTool") continue if in_code_id in existing_tool_ids: # Update existing tool conn.execute( sa.text( """ UPDATE tool SET name = :name, display_name = :display_name, description = :description WHERE in_code_tool_id = :in_code_tool_id """ ), tool, ) else: # Insert new tool conn.execute( sa.text( """ INSERT INTO tool (name, display_name, description, in_code_tool_id) VALUES (:name, :display_name, :description, :in_code_tool_id) """ ), tool, ) def downgrade() -> None: # We don't remove the tools on downgrade since it's totally fine to just # have them around. If we upgrade again, it will be a no-op. 
pass ================================================ FILE: backend/alembic/versions/d1b637d7050a_sync_exa_api_key_to_content_provider.py ================================================ """sync_exa_api_key_to_content_provider Revision ID: d1b637d7050a Revises: d25168c2beee Create Date: 2026-01-09 15:54:15.646249 """ from alembic import op from sqlalchemy import text # revision identifiers, used by Alembic. revision = "d1b637d7050a" down_revision = "d25168c2beee" branch_labels = None depends_on = None def upgrade() -> None: # Exa uses a shared API key between search and content providers. # For existing Exa search providers with API keys, create the corresponding # content provider if it doesn't exist yet. connection = op.get_bind() # Check if Exa search provider exists with an API key result = connection.execute( text( """ SELECT api_key FROM internet_search_provider WHERE provider_type = 'exa' AND api_key IS NOT NULL LIMIT 1 """ ) ) row = result.fetchone() if row: api_key = row[0] # Create Exa content provider with the shared key connection.execute( text( """ INSERT INTO internet_content_provider (name, provider_type, api_key, is_active) VALUES ('Exa', 'exa', :api_key, false) ON CONFLICT (name) DO NOTHING """ ), {"api_key": api_key}, ) def downgrade() -> None: # Remove the Exa content provider that was created by this migration connection = op.get_bind() connection.execute( text( """ DELETE FROM internet_content_provider WHERE provider_type = 'exa' """ ) ) ================================================ FILE: backend/alembic/versions/d25168c2beee_tool_name_consistency.py ================================================ """tool_name_consistency Revision ID: d25168c2beee Revises: 8405ca81cc83 Create Date: 2026-01-11 17:54:40.135777 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "d25168c2beee" down_revision = "8405ca81cc83" branch_labels = None depends_on = None # Currently the seeded tools have the in_code_tool_id == name CURRENT_TOOL_NAME_MAPPING = [ "SearchTool", "WebSearchTool", "ImageGenerationTool", "PythonTool", "OpenURLTool", "KnowledgeGraphTool", "ResearchAgent", ] # Mapping of in_code_tool_id -> name # These are the expected names that we want in the database EXPECTED_TOOL_NAME_MAPPING = { "SearchTool": "internal_search", "WebSearchTool": "web_search", "ImageGenerationTool": "generate_image", "PythonTool": "python", "OpenURLTool": "open_url", "KnowledgeGraphTool": "run_kg_search", "ResearchAgent": "research_agent", } def upgrade() -> None: conn = op.get_bind() # Mapping of in_code_tool_id to the NAME constant from each tool class # These match the .name property of each tool implementation tool_name_mapping = EXPECTED_TOOL_NAME_MAPPING # Update the name column for each tool based on its in_code_tool_id for in_code_tool_id, expected_name in tool_name_mapping.items(): conn.execute( sa.text( """ UPDATE tool SET name = :expected_name WHERE in_code_tool_id = :in_code_tool_id """ ), { "expected_name": expected_name, "in_code_tool_id": in_code_tool_id, }, ) def downgrade() -> None: conn = op.get_bind() # Reverse the migration by setting name back to in_code_tool_id # This matches the original pattern where name was the class name for in_code_tool_id in CURRENT_TOOL_NAME_MAPPING: conn.execute( sa.text( """ UPDATE tool SET name = :current_name WHERE in_code_tool_id = :in_code_tool_id """ ), { "current_name": in_code_tool_id, "in_code_tool_id": in_code_tool_id, }, ) ================================================ FILE: backend/alembic/versions/d3fd499c829c_add_file_reader_tool.py ================================================ """add_file_reader_tool Revision ID: d3fd499c829c Revises: 114a638452db Create Date: 2026-02-07 19:28:22.452337 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "d3fd499c829c" down_revision = "114a638452db" branch_labels = None depends_on = None FILE_READER_TOOL = { "name": "read_file", "display_name": "File Reader", "description": ( "Read sections of user-uploaded files by character offset. " "Useful for inspecting large files that cannot fit entirely in context." ), "in_code_tool_id": "FileReaderTool", "enabled": True, } def upgrade() -> None: conn = op.get_bind() # Check if tool already exists existing = conn.execute( sa.text("SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id"), {"in_code_tool_id": FILE_READER_TOOL["in_code_tool_id"]}, ).fetchone() if existing: # Update existing tool conn.execute( sa.text( """ UPDATE tool SET name = :name, display_name = :display_name, description = :description WHERE in_code_tool_id = :in_code_tool_id """ ), FILE_READER_TOOL, ) tool_id = existing[0] else: # Insert new tool result = conn.execute( sa.text( """ INSERT INTO tool (name, display_name, description, in_code_tool_id, enabled) VALUES (:name, :display_name, :description, :in_code_tool_id, :enabled) RETURNING id """ ), FILE_READER_TOOL, ) tool_id = result.scalar_one() # Attach to the default persona (id=0) if not already attached conn.execute( sa.text( """ INSERT INTO persona__tool (persona_id, tool_id) VALUES (0, :tool_id) ON CONFLICT DO NOTHING """ ), {"tool_id": tool_id}, ) def downgrade() -> None: conn = op.get_bind() in_code_tool_id = FILE_READER_TOOL["in_code_tool_id"] # Remove persona associations first (FK constraint) conn.execute( sa.text( """ DELETE FROM persona__tool WHERE tool_id IN ( SELECT id FROM tool WHERE in_code_tool_id = :in_code_tool_id ) """ ), {"in_code_tool_id": in_code_tool_id}, ) conn.execute( sa.text("DELETE FROM tool WHERE in_code_tool_id = :in_code_tool_id"), {"in_code_tool_id": in_code_tool_id}, ) ================================================ FILE: backend/alembic/versions/d5645c915d0e_remove_deletion_attempt_table.py ================================================ """Remove 
deletion_attempt table Revision ID: d5645c915d0e Revises: 8e26726b7683 Create Date: 2023-09-14 15:04:14.444909 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "d5645c915d0e" down_revision = "8e26726b7683" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.drop_table("deletion_attempt") # Remove the DeletionStatus enum op.execute("DROP TYPE IF EXISTS deletionstatus;") def downgrade() -> None: op.create_table( "deletion_attempt", sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False), sa.Column("connector_id", sa.INTEGER(), autoincrement=False, nullable=False), sa.Column("credential_id", sa.INTEGER(), autoincrement=False, nullable=False), sa.Column( "status", postgresql.ENUM( "NOT_STARTED", "IN_PROGRESS", "SUCCESS", "FAILED", name="deletionstatus", ), autoincrement=False, nullable=False, ), sa.Column( "num_docs_deleted", sa.INTEGER(), autoincrement=False, nullable=False, ), sa.Column("error_msg", sa.VARCHAR(), autoincrement=False, nullable=True), sa.Column( "time_created", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=False, ), sa.Column( "time_updated", postgresql.TIMESTAMP(timezone=True), server_default=sa.text("now()"), autoincrement=False, nullable=False, ), sa.ForeignKeyConstraint( ["connector_id"], ["connector.id"], name="deletion_attempt_connector_id_fkey", ), sa.ForeignKeyConstraint( ["credential_id"], ["credential.id"], name="deletion_attempt_credential_id_fkey", ), sa.PrimaryKeyConstraint("id", name="deletion_attempt_pkey"), ) ================================================ FILE: backend/alembic/versions/d56ffa94ca32_add_file_content.py ================================================ """add_file_content Revision ID: d56ffa94ca32 Revises: 01f8e6d95a33 Create Date: 2026-02-06 15:29:34.192960 """ from alembic import op import sqlalchemy as sa # revision identifiers, used 
by Alembic. revision = "d56ffa94ca32" down_revision = "01f8e6d95a33" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "file_content", sa.Column( "file_id", sa.String(), sa.ForeignKey("file_record.file_id", ondelete="CASCADE"), primary_key=True, ), sa.Column("lobj_oid", sa.BigInteger(), nullable=False), sa.Column("file_size", sa.BigInteger(), nullable=False, server_default="0"), ) def downgrade() -> None: op.drop_table("file_content") ================================================ FILE: backend/alembic/versions/d5c86e2c6dc6_add_cascade_delete_to_search_query_user_.py ================================================ """add_cascade_delete_to_search_query_user_id Revision ID: d5c86e2c6dc6 Revises: 90b409d06e50 Create Date: 2026-02-04 16:05:04.749804 """ from alembic import op # revision identifiers, used by Alembic. revision = "d5c86e2c6dc6" down_revision = "90b409d06e50" branch_labels = None depends_on = None def upgrade() -> None: op.drop_constraint("search_query_user_id_fkey", "search_query", type_="foreignkey") op.create_foreign_key( "search_query_user_id_fkey", "search_query", "user", ["user_id"], ["id"], ondelete="CASCADE", ) def downgrade() -> None: op.drop_constraint("search_query_user_id_fkey", "search_query", type_="foreignkey") op.create_foreign_key( "search_query_user_id_fkey", "search_query", "user", ["user_id"], ["id"] ) ================================================ FILE: backend/alembic/versions/d61e513bef0a_add_total_docs_for_index_attempt.py ================================================ """Add Total Docs for Index Attempt Revision ID: d61e513bef0a Revises: 46625e4745d4 Create Date: 2023-10-27 23:02:43.369964 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "d61e513bef0a" down_revision = "46625e4745d4" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "index_attempt", sa.Column("new_docs_indexed", sa.Integer(), nullable=True), ) op.alter_column( "index_attempt", "num_docs_indexed", new_column_name="total_docs_indexed" ) def downgrade() -> None: op.alter_column( "index_attempt", "total_docs_indexed", new_column_name="num_docs_indexed" ) op.drop_column("index_attempt", "new_docs_indexed") ================================================ FILE: backend/alembic/versions/d7111c1238cd_remove_document_ids.py ================================================ """Remove Document IDs Revision ID: d7111c1238cd Revises: 465f78d9b7f9 Create Date: 2023-07-29 15:06:25.126169 """ import sqlalchemy as sa from alembic import op from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "d7111c1238cd" down_revision = "465f78d9b7f9" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.drop_column("index_attempt", "document_ids") def downgrade() -> None: op.add_column( "index_attempt", sa.Column( "document_ids", postgresql.ARRAY(sa.VARCHAR()), autoincrement=False, nullable=True, ), ) ================================================ FILE: backend/alembic/versions/d716b0791ddd_combined_slack_id_fields.py ================================================ """combined slack id fields Revision ID: d716b0791ddd Revises: 7aea705850d5 Create Date: 2024-07-10 17:57:45.630550 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "d716b0791ddd" down_revision = "7aea705850d5" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.execute( """ UPDATE slack_bot_config SET channel_config = jsonb_set( channel_config, '{respond_member_group_list}', coalesce(channel_config->'respond_team_member_list', '[]'::jsonb) || coalesce(channel_config->'respond_slack_group_list', '[]'::jsonb) ) - 'respond_team_member_list' - 'respond_slack_group_list' """ ) def downgrade() -> None: op.execute( """ UPDATE slack_bot_config SET channel_config = jsonb_set( jsonb_set( channel_config - 'respond_member_group_list', '{respond_team_member_list}', '[]'::jsonb ), '{respond_slack_group_list}', '[]'::jsonb ) """ ) ================================================ FILE: backend/alembic/versions/d8cdfee5df80_add_skipped_to_userfilestatus.py ================================================ """add skipped to userfilestatus Revision ID: d8cdfee5df80 Revises: 1d78c0ca7853 Create Date: 2026-04-01 10:47:12.593950 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "d8cdfee5df80" down_revision = "1d78c0ca7853" branch_labels = None depends_on = None TABLE = "user_file" COLUMN = "status" CONSTRAINT_NAME = "ck_user_file_status" OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING") NEW_VALUES = ( "PROCESSING", "INDEXING", "COMPLETED", "SKIPPED", "FAILED", "CANCELED", "DELETING", ) def _drop_status_check_constraint() -> None: inspector = sa.inspect(op.get_bind()) for constraint in inspector.get_check_constraints(TABLE): if COLUMN in constraint.get("sqltext", ""): constraint_name = constraint["name"] if constraint_name is not None: op.drop_constraint(constraint_name, TABLE, type_="check") def upgrade() -> None: _drop_status_check_constraint() in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES) op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})") def downgrade() -> None: op.execute(f"UPDATE {TABLE} SET {COLUMN} = 'COMPLETED' WHERE {COLUMN} = 'SKIPPED'") _drop_status_check_constraint() in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES) op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})") ================================================ FILE: backend/alembic/versions/d929f0c1c6af_feedback_feature.py ================================================ """Feedback Feature Revision ID: d929f0c1c6af Revises: 8aabb57f3b49 Create Date: 2023-08-27 13:03:54.274987 """ import fastapi_users_db_sqlalchemy from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "d929f0c1c6af" down_revision = "8aabb57f3b49" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "query_event", sa.Column("id", sa.Integer(), nullable=False), sa.Column("query", sa.String(), nullable=False), sa.Column( "selected_search_flow", sa.Enum("KEYWORD", "SEMANTIC", name="searchtype", native_enum=False), nullable=True, ), sa.Column("llm_answer", sa.String(), nullable=True), sa.Column( "feedback", sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype", native_enum=False), nullable=True, ), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "document_retrieval_feedback", sa.Column("id", sa.Integer(), nullable=False), sa.Column("qa_event_id", sa.Integer(), nullable=False), sa.Column("document_id", sa.String(), nullable=False), sa.Column("document_rank", sa.Integer(), nullable=False), sa.Column("clicked", sa.Boolean(), nullable=False), sa.Column( "feedback", sa.Enum( "ENDORSE", "REJECT", "HIDE", "UNHIDE", name="searchfeedbacktype", native_enum=False, ), nullable=True, ), sa.ForeignKeyConstraint( ["document_id"], ["document.id"], ), sa.ForeignKeyConstraint( ["qa_event_id"], ["query_event.id"], ), sa.PrimaryKeyConstraint("id"), ) op.add_column("document", sa.Column("boost", sa.Integer(), nullable=False)) op.add_column("document", sa.Column("hidden", sa.Boolean(), nullable=False)) op.add_column("document", sa.Column("semantic_id", sa.String(), nullable=False)) op.add_column("document", sa.Column("link", sa.String(), nullable=True)) def downgrade() -> None: op.drop_column("document", "link") op.drop_column("document", "semantic_id") op.drop_column("document", "hidden") op.drop_column("document", "boost") op.drop_table("document_retrieval_feedback") op.drop_table("query_event") 
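The `d8cdfee5df80` migration above widens the `ck_user_file_status` CHECK constraint by dropping it and recreating it with a longer `IN (...)` list, and its downgrade rewrites `SKIPPED` rows to `COMPLETED` before reapplying the narrower list. The sketch below illustrates why that ordering matters. It uses Python's built-in `sqlite3` as a stand-in for Postgres, which is an assumption: SQLite cannot alter a CHECK constraint in place, so the hypothetical `rebuild_with_check` helper copies rows into a fresh table instead. As with a Postgres `ADD CONSTRAINT ... CHECK`, the new constraint is validated against existing rows, so narrowing it while `SKIPPED` rows remain fails.

```python
import sqlite3

# Status values copied from the d8cdfee5df80 migration.
OLD_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
NEW_VALUES = (
    "PROCESSING",
    "INDEXING",
    "COMPLETED",
    "SKIPPED",
    "FAILED",
    "CANCELED",
    "DELETING",
)


def status_check_sql(values: tuple[str, ...], column: str = "status") -> str:
    """Build the same `status IN ('A', 'B', ...)` expression the migration uses."""
    in_clause = ", ".join(f"'{v}'" for v in values)
    return f"{column} IN ({in_clause})"


def create_user_file(
    conn: sqlite3.Connection, values: tuple[str, ...], name: str = "user_file"
) -> None:
    """Create a minimal stand-in for the user_file table with a status CHECK."""
    conn.execute(
        f"CREATE TABLE {name} (id INTEGER PRIMARY KEY, "
        f"status TEXT NOT NULL CHECK ({status_check_sql(values)}))"
    )


def rebuild_with_check(conn: sqlite3.Connection, values: tuple[str, ...]) -> None:
    """Hypothetical stand-in for DROP CONSTRAINT + ADD CONSTRAINT.

    Copies every row into a new table carrying the new CHECK. If any existing
    row violates it, sqlite3.IntegrityError is raised and the whole rebuild
    (including the CREATE TABLE) is rolled back.
    """
    conn.execute("BEGIN")
    try:
        create_user_file(conn, values, name="user_file_new")
        conn.execute(
            "INSERT INTO user_file_new (id, status) SELECT id, status FROM user_file"
        )
        conn.execute("DROP TABLE user_file")
        conn.execute("ALTER TABLE user_file_new RENAME TO user_file")
        conn.execute("COMMIT")
    except sqlite3.Error:
        conn.execute("ROLLBACK")
        raise


def statuses(conn: sqlite3.Connection) -> list[str]:
    return [row[0] for row in conn.execute("SELECT status FROM user_file ORDER BY id")]


def run_demo() -> tuple[list[str], bool, list[str]]:
    # isolation_level=None puts sqlite3 in autocommit mode, so the explicit
    # BEGIN/COMMIT/ROLLBACK in rebuild_with_check control the transaction.
    conn = sqlite3.connect(":memory:", isolation_level=None)
    create_user_file(conn, OLD_VALUES)
    conn.execute("INSERT INTO user_file (status) VALUES ('COMPLETED')")

    # Upgrade: widen the constraint, after which 'SKIPPED' is accepted.
    rebuild_with_check(conn, NEW_VALUES)
    conn.execute("INSERT INTO user_file (status) VALUES ('SKIPPED')")
    after_upgrade = statuses(conn)

    # Downgrade without remapping first: the narrower CHECK rejects the SKIPPED row.
    try:
        rebuild_with_check(conn, OLD_VALUES)
        narrowing_failed = False
    except sqlite3.IntegrityError:
        narrowing_failed = True

    # Downgrade in the migration's order: remap SKIPPED, then narrow.
    conn.execute("UPDATE user_file SET status = 'COMPLETED' WHERE status = 'SKIPPED'")
    rebuild_with_check(conn, OLD_VALUES)
    after_downgrade = statuses(conn)
    conn.close()
    return after_upgrade, narrowing_failed, after_downgrade


if __name__ == "__main__":
    print(run_demo())
```

`run_demo()` returns the statuses after the upgrade, whether the premature narrowing was rejected, and the statuses after the remap-then-narrow downgrade. In Postgres the same failure would surface when `op.create_check_constraint` validates existing rows.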
================================================ FILE: backend/alembic/versions/d961aca62eb3_update_status_length.py ================================================ """Update status length Revision ID: d961aca62eb3 Revises: cf90764725d8 Create Date: 2025-03-23 16:10:05.683965 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "d961aca62eb3" down_revision = "cf90764725d8" branch_labels = None depends_on = None def upgrade() -> None: # Drop the existing enum type constraint op.execute("ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE varchar") # Cap the column length at VARCHAR(20) op.execute( "ALTER TABLE connector_credential_pair ALTER COLUMN status TYPE VARCHAR(20) USING status::varchar(20)" ) # Declare the full set of allowed statuses as a non-native (VARCHAR-backed) enum op.alter_column( "connector_credential_pair", "status", type_=sa.Enum( "SCHEDULED", "INITIAL_INDEXING", "ACTIVE", "PAUSED", "DELETING", "INVALID", name="connectorcredentialpairstatus", native_enum=False, ), existing_type=sa.String(20), nullable=False, ) op.add_column( "connector_credential_pair", sa.Column( "in_repeated_error_state", sa.Boolean, default=False, server_default="false" ), ) def downgrade() -> None: # no need to convert back to the old enum type, since we're not using it anymore op.drop_column("connector_credential_pair", "in_repeated_error_state") ================================================ FILE: backend/alembic/versions/d9ec13955951_remove__dim_suffix_from_model_name.py ================================================ """Remove _alt suffix from model_name Revision ID: d9ec13955951 Revises: da4c21c69164 Create Date: 2024-08-20 16:31:32.955686 """ from alembic import op # revision identifiers, used by Alembic.
revision = "d9ec13955951" down_revision = "da4c21c69164" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.execute( """ UPDATE embedding_model SET model_name = regexp_replace(model_name, '__danswer_alt_index$', '') WHERE model_name LIKE '%__danswer_alt_index' """ ) def downgrade() -> None: # We can't reliably add the __danswer_alt_index suffix back, so we'll leave this empty pass ================================================ FILE: backend/alembic/versions/da42808081e3_migrate_jira_connectors_to_new_format.py ================================================ """migrate jira connectors to new format Revision ID: da42808081e3 Revises: f13db29f3101 Create Date: 2025-02-24 11:24:54.396040 """ from alembic import op import sqlalchemy as sa import json from onyx.configs.constants import DocumentSource from onyx.connectors.jira.utils import extract_jira_project # revision identifiers, used by Alembic. revision = "da42808081e3" down_revision = "f13db29f3101" branch_labels = None depends_on = None PRESERVED_CONFIG_KEYS = ["comment_email_blacklist", "batch_size", "labels_to_skip"] def upgrade() -> None: # Get all Jira connectors conn = op.get_bind() # First get all Jira connectors jira_connectors = conn.execute( sa.text( """ SELECT id, connector_specific_config FROM connector WHERE source = :source """ ), {"source": DocumentSource.JIRA.value.upper()}, ).fetchall() # Update each connector's config for connector_id, old_config in jira_connectors: if not old_config: continue # Extract project key from URL if it exists new_config: dict[str, str | None] = {} if project_url := old_config.get("jira_project_url"): # Parse the URL to get base and project try: jira_base, project_key = extract_jira_project(project_url) new_config = {"jira_base_url": jira_base, "project_key": project_key} except ValueError: # If URL parsing fails, just use the URL as the base new_config = { "jira_base_url": project_url.split("/projects/")[0], "project_key": None, } else: # 
For connectors without a project URL, we need admin intervention # Mark these for review print( f"WARNING: Jira connector {connector_id} has no project URL configured" ) continue for old_key in PRESERVED_CONFIG_KEYS: if old_key in old_config: new_config[old_key] = old_config[old_key] # Update the connector config conn.execute( sa.text( """ UPDATE connector SET connector_specific_config = :new_config WHERE id = :id """ ), {"id": connector_id, "new_config": json.dumps(new_config)}, ) def downgrade() -> None: # Get all Jira connectors conn = op.get_bind() # First get all Jira connectors jira_connectors = conn.execute( sa.text( """ SELECT id, connector_specific_config FROM connector WHERE source = :source """ ), {"source": DocumentSource.JIRA.value.upper()}, ).fetchall() # Update each connector's config back to the old format for connector_id, new_config in jira_connectors: if not new_config: continue old_config = {} base_url = new_config.get("jira_base_url") project_key = new_config.get("project_key") if base_url and project_key: old_config = {"jira_project_url": f"{base_url}/projects/{project_key}"} elif base_url: old_config = {"jira_project_url": base_url} else: continue for old_key in PRESERVED_CONFIG_KEYS: if old_key in new_config: old_config[old_key] = new_config[old_key] # Update the connector config conn.execute( sa.text( """ UPDATE connector SET connector_specific_config = :old_config WHERE id = :id """ ), {"id": connector_id, "old_config": json.dumps(old_config)}, ) ================================================ FILE: backend/alembic/versions/da4c21c69164_chosen_assistants_changed_to_jsonb.py ================================================ """chosen_assistants changed to jsonb Revision ID: da4c21c69164 Revises: c5b692fa265c Create Date: 2024-08-18 19:06:47.291491 """ import json from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "da4c21c69164" down_revision = "c5b692fa265c" branch_labels: None = None depends_on: None = None def upgrade() -> None: conn = op.get_bind() existing_ids_and_chosen_assistants = conn.execute( sa.text('select id, chosen_assistants from "user"') ) op.drop_column( "user", "chosen_assistants", ) op.add_column( "user", sa.Column( "chosen_assistants", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) for id, chosen_assistants in existing_ids_and_chosen_assistants: conn.execute( sa.text( 'update "user" set chosen_assistants = :chosen_assistants where id = :id' ), {"chosen_assistants": json.dumps(chosen_assistants), "id": id}, ) def downgrade() -> None: conn = op.get_bind() existing_ids_and_chosen_assistants = conn.execute( sa.text('select id, chosen_assistants from "user"') ) op.drop_column( "user", "chosen_assistants", ) op.add_column( "user", sa.Column("chosen_assistants", postgresql.ARRAY(sa.Integer()), nullable=True), ) for id, chosen_assistants in existing_ids_and_chosen_assistants: conn.execute( sa.text( 'update "user" set chosen_assistants = :chosen_assistants where id = :id' ), {"chosen_assistants": chosen_assistants, "id": id}, ) ================================================ FILE: backend/alembic/versions/dab04867cd88_add_composite_index_to_document_by_.py ================================================ """Add composite index to document_by_connector_credential_pair Revision ID: dab04867cd88 Revises: 54a74a0417fc Create Date: 2024-12-13 22:43:20.119990 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "dab04867cd88" down_revision = "54a74a0417fc" branch_labels = None depends_on = None def upgrade() -> None: # Composite index on (connector_id, credential_id) op.create_index( "idx_document_cc_pair_connector_credential", "document_by_connector_credential_pair", ["connector_id", "credential_id"], unique=False, ) def downgrade() -> None: op.drop_index( "idx_document_cc_pair_connector_credential", table_name="document_by_connector_credential_pair", ) ================================================ FILE: backend/alembic/versions/dba7f71618f5_onyx_custom_tool_flow.py ================================================ """Onyx Custom Tool Flow Revision ID: dba7f71618f5 Revises: d5645c915d0e Create Date: 2023-09-18 15:18:37.370972 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "dba7f71618f5" down_revision = "d5645c915d0e" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "persona", sa.Column("retrieval_enabled", sa.Boolean(), nullable=True), ) op.execute("UPDATE persona SET retrieval_enabled = true") op.alter_column("persona", "retrieval_enabled", nullable=False) def downgrade() -> None: op.drop_column("persona", "retrieval_enabled") ================================================ FILE: backend/alembic/versions/dbaa756c2ccf_embedding_models.py ================================================ """Embedding Models Revision ID: dbaa756c2ccf Revises: 7f726bad5367 Create Date: 2024-01-25 17:12:31.813160 """ from alembic import op import sqlalchemy as sa from sqlalchemy import table, column, String, Integer, Boolean from onyx.configs.model_configs import ASYM_PASSAGE_PREFIX from onyx.configs.model_configs import ASYM_QUERY_PREFIX from onyx.configs.model_configs import DOC_EMBEDDING_DIM from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL from onyx.configs.model_configs import NORMALIZE_EMBEDDINGS from onyx.configs.model_configs import 
OLD_DEFAULT_DOCUMENT_ENCODER_MODEL from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS from onyx.db.enums import EmbeddingPrecision from onyx.db.models import IndexModelStatus from onyx.db.search_settings import user_has_overridden_embedding_model from onyx.indexing.models import IndexingSetting from onyx.natural_language_processing.search_nlp_models import clean_model_name # revision identifiers, used by Alembic. revision = "dbaa756c2ccf" down_revision = "7f726bad5367" branch_labels: None = None depends_on: None = None def _get_old_default_embedding_model() -> IndexingSetting: is_overridden = user_has_overridden_embedding_model() return IndexingSetting( model_name=( DOCUMENT_ENCODER_MODEL if is_overridden else OLD_DEFAULT_DOCUMENT_ENCODER_MODEL ), model_dim=( DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM ), embedding_precision=(EmbeddingPrecision.FLOAT), normalize=( NORMALIZE_EMBEDDINGS if is_overridden else OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS ), query_prefix=(ASYM_QUERY_PREFIX if is_overridden else ""), passage_prefix=(ASYM_PASSAGE_PREFIX if is_overridden else ""), index_name="danswer_chunk", multipass_indexing=False, enable_contextual_rag=False, api_url=None, ) def _get_new_default_embedding_model() -> IndexingSetting: return IndexingSetting( model_name=DOCUMENT_ENCODER_MODEL, model_dim=DOC_EMBEDDING_DIM, embedding_precision=(EmbeddingPrecision.BFLOAT16), normalize=NORMALIZE_EMBEDDINGS, query_prefix=ASYM_QUERY_PREFIX, passage_prefix=ASYM_PASSAGE_PREFIX, index_name=f"danswer_chunk_{clean_model_name(DOCUMENT_ENCODER_MODEL)}", multipass_indexing=False, enable_contextual_rag=False, api_url=None, ) def upgrade() -> None: op.create_table( "embedding_model", sa.Column("id", sa.Integer(), nullable=False), sa.Column("model_name", sa.String(), nullable=False), sa.Column("model_dim", sa.Integer(), nullable=False), sa.Column("normalize", 
sa.Boolean(), nullable=False), sa.Column("query_prefix", sa.String(), nullable=False), sa.Column("passage_prefix", sa.String(), nullable=False), sa.Column("index_name", sa.String(), nullable=False), sa.Column( "status", sa.Enum(IndexModelStatus, native=False), nullable=False, ), sa.PrimaryKeyConstraint("id"), ) # since all index attempts must be associated with an embedding model, # need to put something in here to avoid nulls. On server startup, # this value will be overridden EmbeddingModel = table( "embedding_model", column("id", Integer), column("model_name", String), column("model_dim", Integer), column("normalize", Boolean), column("query_prefix", String), column("passage_prefix", String), column("index_name", String), column( "status", sa.Enum(IndexModelStatus, name="indexmodelstatus", native=False) ), ) # insert an embedding model row that corresponds to the embedding model # the user selected via env variables before this change. This is needed since # all index_attempts must be associated with an embedding model, so without this # we will run into violations of non-null constraints old_embedding_model = _get_old_default_embedding_model() op.bulk_insert( EmbeddingModel, [ { "model_name": old_embedding_model.model_name, "model_dim": old_embedding_model.model_dim, "normalize": old_embedding_model.normalize, "query_prefix": old_embedding_model.query_prefix, "passage_prefix": old_embedding_model.passage_prefix, "index_name": old_embedding_model.index_name, "status": IndexModelStatus.PRESENT, } ], ) # if the user has not overridden the default embedding model via env variables, # insert the new default model into the database to auto-upgrade them if not user_has_overridden_embedding_model(): new_embedding_model = _get_new_default_embedding_model() op.bulk_insert( EmbeddingModel, [ { "model_name": new_embedding_model.model_name, "model_dim": new_embedding_model.model_dim, "normalize": new_embedding_model.normalize, "query_prefix": new_embedding_model.query_prefix,
"passage_prefix": new_embedding_model.passage_prefix, "index_name": new_embedding_model.index_name, "status": IndexModelStatus.FUTURE, } ], ) op.add_column( "index_attempt", sa.Column("embedding_model_id", sa.Integer(), nullable=True), ) op.execute( "UPDATE index_attempt SET embedding_model_id=1 WHERE embedding_model_id IS NULL" ) op.alter_column( "index_attempt", "embedding_model_id", existing_type=sa.Integer(), nullable=False, ) op.create_foreign_key( "index_attempt__embedding_model_fk", "index_attempt", "embedding_model", ["embedding_model_id"], ["id"], ) op.create_index( "ix_embedding_model_present_unique", "embedding_model", ["status"], unique=True, postgresql_where=sa.text("status = 'PRESENT'"), ) op.create_index( "ix_embedding_model_future_unique", "embedding_model", ["status"], unique=True, postgresql_where=sa.text("status = 'FUTURE'"), ) def downgrade() -> None: op.drop_constraint( "index_attempt__embedding_model_fk", "index_attempt", type_="foreignkey" ) op.drop_column("index_attempt", "embedding_model_id") op.drop_table("embedding_model") op.execute("DROP TYPE IF EXISTS indexmodelstatus;") ================================================ FILE: backend/alembic/versions/df0c7ad8a076_added_deletion_attempt_table.py ================================================ """Added deletion_attempt table Revision ID: df0c7ad8a076 Revises: d7111c1238cd Create Date: 2023-08-05 13:35:39.609619 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "df0c7ad8a076" down_revision = "d7111c1238cd" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.execute("DROP TABLE IF EXISTS document CASCADE") op.create_table( "document", sa.Column("id", sa.String(), nullable=False), sa.PrimaryKeyConstraint("id"), ) op.execute("DROP TABLE IF EXISTS chunk CASCADE") op.create_table( "chunk", sa.Column("id", sa.String(), nullable=False), sa.Column( "document_store_type", sa.Enum( "VECTOR", "KEYWORD", name="documentstoretype", native_enum=False, ), nullable=False, ), sa.Column("document_id", sa.String(), nullable=False), sa.ForeignKeyConstraint( ["document_id"], ["document.id"], ), sa.PrimaryKeyConstraint("id", "document_store_type"), ) op.execute("DROP TABLE IF EXISTS deletion_attempt CASCADE") op.create_table( "deletion_attempt", sa.Column("id", sa.Integer(), nullable=False), sa.Column("connector_id", sa.Integer(), nullable=False), sa.Column("credential_id", sa.Integer(), nullable=False), sa.Column( "status", sa.Enum( "NOT_STARTED", "IN_PROGRESS", "SUCCESS", "FAILED", name="deletionstatus", native_enum=False, ), nullable=False, ), sa.Column("num_docs_deleted", sa.Integer(), nullable=False), sa.Column("error_msg", sa.String(), nullable=True), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column( "time_updated", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.ForeignKeyConstraint( ["connector_id"], ["connector.id"], ), sa.ForeignKeyConstraint( ["credential_id"], ["credential.id"], ), sa.PrimaryKeyConstraint("id"), ) op.execute("DROP TABLE IF EXISTS document_by_connector_credential_pair CASCADE") op.create_table( "document_by_connector_credential_pair", sa.Column("id", sa.String(), nullable=False), sa.Column("connector_id", sa.Integer(), nullable=False), sa.Column("credential_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["connector_id"], ["connector.id"], ), sa.ForeignKeyConstraint( 
["credential_id"], ["credential.id"], ), sa.ForeignKeyConstraint( ["id"], ["document.id"], ), sa.PrimaryKeyConstraint("id", "connector_id", "credential_id"), ) def downgrade() -> None: # drop dependent (referencing) tables first op.drop_table("document_by_connector_credential_pair") op.drop_table("deletion_attempt") op.drop_table("chunk") # Alembic op.drop_table() has no "cascade" flag, so issue raw SQL op.execute("DROP TABLE IF EXISTS document CASCADE") ================================================ FILE: backend/alembic/versions/df46c75b714e_add_default_vision_provider_to_llm_.py ================================================ """add_default_vision_provider_to_llm_provider Revision ID: df46c75b714e Revises: 3934b1bc7b62 Create Date: 2025-03-11 16:20:19.038945 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "df46c75b714e" down_revision = "3934b1bc7b62" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "llm_provider", sa.Column( "is_default_vision_provider", sa.Boolean(), nullable=True, server_default=sa.false(), ), ) op.add_column( "llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True) ) def downgrade() -> None: op.drop_column("llm_provider", "default_vision_model") op.drop_column("llm_provider", "is_default_vision_provider") ================================================ FILE: backend/alembic/versions/dfbe9e93d3c7_extended_role_for_non_web.py ================================================ """extended_role_for_non_web Revision ID: dfbe9e93d3c7 Revises: 9cf5c00f72fe Create Date: 2024-11-16 07:54:18.727906 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic.
revision = "dfbe9e93d3c7" down_revision = "9cf5c00f72fe" branch_labels = None depends_on = None def upgrade() -> None: op.execute( """ UPDATE "user" SET role = 'EXT_PERM_USER' WHERE has_web_login = false """ ) op.drop_column("user", "has_web_login") def downgrade() -> None: op.add_column( "user", sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"), ) op.execute( """ UPDATE "user" SET has_web_login = false, role = 'BASIC' WHERE role IN ('SLACK_USER', 'EXT_PERM_USER') """ ) ================================================ FILE: backend/alembic/versions/e0a68a81d434_add_chat_feedback.py ================================================ """Add Chat Feedback Revision ID: e0a68a81d434 Revises: ae62505e3acc Create Date: 2023-10-04 20:22:33.380286 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "e0a68a81d434" down_revision = "ae62505e3acc" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "chat_feedback", sa.Column("id", sa.Integer(), nullable=False), sa.Column("chat_message_chat_session_id", sa.Integer(), nullable=False), sa.Column("chat_message_message_number", sa.Integer(), nullable=False), sa.Column("chat_message_edit_number", sa.Integer(), nullable=False), sa.Column("is_positive", sa.Boolean(), nullable=True), sa.Column("feedback_text", sa.Text(), nullable=True), sa.ForeignKeyConstraint( [ "chat_message_chat_session_id", "chat_message_message_number", "chat_message_edit_number", ], [ "chat_message.chat_session_id", "chat_message.message_number", "chat_message.edit_number", ], ), sa.PrimaryKeyConstraint("id"), ) def downgrade() -> None: op.drop_table("chat_feedback") ================================================ FILE: backend/alembic/versions/e1392f05e840_added_input_prompts.py ================================================ """Added input prompts Revision ID: e1392f05e840 Revises: 08a1eda20fe1 Create Date: 2024-07-13 19:09:22.556224 """ import 
fastapi_users_db_sqlalchemy from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "e1392f05e840" down_revision = "08a1eda20fe1" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "inputprompt", sa.Column("id", sa.Integer(), autoincrement=True, nullable=False), sa.Column("prompt", sa.String(), nullable=False), sa.Column("content", sa.String(), nullable=False), sa.Column("active", sa.Boolean(), nullable=False), sa.Column("is_public", sa.Boolean(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=True, ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("id"), ) op.create_table( "inputprompt__user", sa.Column("input_prompt_id", sa.Integer(), nullable=False), sa.Column("user_id", sa.Integer(), nullable=False), sa.ForeignKeyConstraint( ["input_prompt_id"], ["inputprompt.id"], ), sa.ForeignKeyConstraint( ["user_id"], ["inputprompt.id"], ), sa.PrimaryKeyConstraint("input_prompt_id", "user_id"), ) def downgrade() -> None: op.drop_table("inputprompt__user") op.drop_table("inputprompt") ================================================ FILE: backend/alembic/versions/e209dc5a8156_added_prune_frequency.py ================================================ """added-prune-frequency Revision ID: e209dc5a8156 Revises: 48d14957fe80 Create Date: 2024-06-16 16:02:35.273231 """ from alembic import op import sqlalchemy as sa revision = "e209dc5a8156" down_revision = "48d14957fe80" branch_labels = None depends_on = None def upgrade() -> None: op.add_column("connector", sa.Column("prune_freq", sa.Integer(), nullable=True)) def downgrade() -> None: op.drop_column("connector", "prune_freq") ================================================ FILE: backend/alembic/versions/e4334d5b33ba_add_deployment_name_to_llmprovider.py ================================================ """add_deployment_name_to_llmprovider Revision ID: e4334d5b33ba 
Revises: ac5eaac849f9 Create Date: 2024-10-04 09:52:34.896867 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "e4334d5b33ba" down_revision = "ac5eaac849f9" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "llm_provider", sa.Column("deployment_name", sa.String(), nullable=True) ) def downgrade() -> None: op.drop_column("llm_provider", "deployment_name") ================================================ FILE: backend/alembic/versions/e50154680a5c_no_source_enum.py ================================================ """No Source Enum Revision ID: e50154680a5c Revises: fcd135795f21 Create Date: 2024-03-14 18:06:08.523106 """ from alembic import op import sqlalchemy as sa from onyx.configs.constants import DocumentSource # revision identifiers, used by Alembic. revision = "e50154680a5c" down_revision = "fcd135795f21" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.alter_column( "search_doc", "source_type", type_=sa.String(length=50), existing_type=sa.Enum(DocumentSource, native_enum=False), existing_nullable=False, ) op.execute("DROP TYPE IF EXISTS documentsource") def downgrade() -> None: op.alter_column( "search_doc", "source_type", type_=sa.Enum(DocumentSource, native_enum=False), existing_type=sa.String(length=50), existing_nullable=False, ) ================================================ FILE: backend/alembic/versions/e6a4bbc13fe4_add_index_for_retrieving_latest_index_.py ================================================ """Add index for retrieving latest index_attempt Revision ID: e6a4bbc13fe4 Revises: b082fec533f0 Create Date: 2023-08-10 12:37:23.335471 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "e6a4bbc13fe4" down_revision = "b082fec533f0" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_index( op.f("ix_index_attempt_latest_for_connector_credential_pair"), "index_attempt", ["connector_id", "credential_id", "time_created"], unique=False, ) def downgrade() -> None: op.drop_index( op.f("ix_index_attempt_latest_for_connector_credential_pair"), table_name="index_attempt", ) ================================================ FILE: backend/alembic/versions/e7f8a9b0c1d2_create_anonymous_user.py ================================================ """create_anonymous_user This migration creates a permanent anonymous user in the database. When anonymous access is enabled, unauthenticated requests will use this user instead of returning user_id=NULL. Revision ID: e7f8a9b0c1d2 Revises: f7ca3e2f45d9 Create Date: 2026-01-15 14:00:00.000000 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "e7f8a9b0c1d2" down_revision = "f7ca3e2f45d9" branch_labels = None depends_on = None # Must match constants in onyx/configs/constants.py file ANONYMOUS_USER_UUID = "00000000-0000-0000-0000-000000000002" ANONYMOUS_USER_EMAIL = "anonymous@onyx.app" # Tables with user_id foreign key that may need migration TABLES_WITH_USER_ID = [ "chat_session", "credential", "document_set", "persona", "tool", "notification", "inputprompt", ] def _dedupe_null_notifications(connection: sa.Connection) -> None: # Multiple NULL-owned notifications can exist because the unique index treats # NULL user_id values as distinct. Before migrating them to the anonymous # user, collapse duplicates and remove rows that would conflict with an # already-existing anonymous notification. 
result = connection.execute( sa.text( """ WITH ranked_null_notifications AS ( SELECT id, ROW_NUMBER() OVER ( PARTITION BY notif_type, COALESCE(additional_data, '{}'::jsonb) ORDER BY first_shown DESC, last_shown DESC, id DESC ) AS row_num FROM notification WHERE user_id IS NULL ) DELETE FROM notification WHERE id IN ( SELECT id FROM ranked_null_notifications WHERE row_num > 1 ) """ ) ) if result.rowcount > 0: print(f"Deleted {result.rowcount} duplicate NULL-owned notifications") result = connection.execute( sa.text( """ DELETE FROM notification AS null_owned USING notification AS anonymous_owned WHERE null_owned.user_id IS NULL AND anonymous_owned.user_id = :user_id AND null_owned.notif_type = anonymous_owned.notif_type AND COALESCE(null_owned.additional_data, '{}'::jsonb) = COALESCE(anonymous_owned.additional_data, '{}'::jsonb) """ ), {"user_id": ANONYMOUS_USER_UUID}, ) if result.rowcount > 0: print( f"Deleted {result.rowcount} NULL-owned notifications that conflict with existing anonymous-owned notifications" ) def upgrade() -> None: """ Create the anonymous user for anonymous access feature. Also migrates any remaining user_id=NULL records to the anonymous user. 
""" connection = op.get_bind() # Create the anonymous user (using ON CONFLICT to be idempotent) connection.execute( sa.text( """ INSERT INTO "user" (id, email, hashed_password, is_active, is_superuser, is_verified, role) VALUES (:id, :email, :hashed_password, :is_active, :is_superuser, :is_verified, :role) ON CONFLICT (id) DO NOTHING """ ), { "id": ANONYMOUS_USER_UUID, "email": ANONYMOUS_USER_EMAIL, "hashed_password": "", # Empty password - user cannot log in directly "is_active": True, # Active so it can be used for anonymous access "is_superuser": False, "is_verified": True, # Verified since no email verification needed "role": "LIMITED", # Anonymous users have limited role to restrict access }, ) # Migrate any remaining user_id=NULL records to anonymous user for table in TABLES_WITH_USER_ID: # Dedup notifications outside the savepoint so deletions persist # even if the subsequent UPDATE rolls back if table == "notification": _dedupe_null_notifications(connection) with connection.begin_nested(): # Exclude public credential (id=0) which must remain user_id=NULL # Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL # Exclude builtin personas (builtin_persona=True) which must remain user_id=NULL # Exclude system input prompts (is_public=True with user_id=NULL) which must remain user_id=NULL if table == "credential": condition = "user_id IS NULL AND id != 0" elif table == "tool": condition = "user_id IS NULL AND in_code_tool_id IS NULL" elif table == "persona": condition = "user_id IS NULL AND builtin_persona = false" elif table == "inputprompt": condition = "user_id IS NULL AND is_public = false" else: condition = "user_id IS NULL" result = connection.execute( sa.text( f""" UPDATE "{table}" SET user_id = :user_id WHERE {condition} """ ), {"user_id": ANONYMOUS_USER_UUID}, ) if result.rowcount > 0: print(f"Updated {result.rowcount} rows in {table} to anonymous user") def downgrade() -> None: """ Set anonymous user's records back to NULL 
and delete the anonymous user. Note: Duplicate NULL-owned notifications removed during upgrade are not restored. """ connection = op.get_bind() # Set records back to NULL for table in TABLES_WITH_USER_ID: with connection.begin_nested(): connection.execute( sa.text( f""" UPDATE "{table}" SET user_id = NULL WHERE user_id = :user_id """ ), {"user_id": ANONYMOUS_USER_UUID}, ) # Delete the anonymous user connection.execute( sa.text('DELETE FROM "user" WHERE id = :user_id'), {"user_id": ANONYMOUS_USER_UUID}, ) ================================================ FILE: backend/alembic/versions/e86866a9c78a_add_persona_to_chat_session.py ================================================ """Add persona to chat_session Revision ID: e86866a9c78a Revises: 80696cf850ae Create Date: 2023-11-26 02:51:47.657357 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "e86866a9c78a" down_revision = "80696cf850ae" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column("chat_session", sa.Column("persona_id", sa.Integer(), nullable=True)) op.create_foreign_key( "fk_chat_session_persona_id", "chat_session", "persona", ["persona_id"], ["id"] ) def downgrade() -> None: op.drop_constraint("fk_chat_session_persona_id", "chat_session", type_="foreignkey") op.drop_column("chat_session", "persona_id") ================================================ FILE: backend/alembic/versions/e8f0d2a38171_add_status_to_mcp_server_and_make_auth_.py ================================================ """add status to mcp server and make auth fields nullable Revision ID: e8f0d2a38171 Revises: ed9e44312505 Create Date: 2025-11-28 11:15:37.667340 """ from alembic import op import sqlalchemy as sa from onyx.db.enums import ( MCPTransport, MCPAuthenticationType, MCPAuthenticationPerformer, MCPServerStatus, ) # revision identifiers, used by Alembic. 
revision = "e8f0d2a38171" down_revision = "ed9e44312505" branch_labels = None depends_on = None def upgrade() -> None: # Make auth fields nullable op.alter_column( "mcp_server", "transport", existing_type=sa.Enum(MCPTransport, name="mcp_transport", native_enum=False), nullable=True, ) op.alter_column( "mcp_server", "auth_type", existing_type=sa.Enum( MCPAuthenticationType, name="mcp_authentication_type", native_enum=False ), nullable=True, ) op.alter_column( "mcp_server", "auth_performer", existing_type=sa.Enum( MCPAuthenticationPerformer, name="mcp_authentication_performer", native_enum=False, ), nullable=True, ) # Add status column with default op.add_column( "mcp_server", sa.Column( "status", sa.Enum(MCPServerStatus, name="mcp_server_status", native_enum=False), nullable=False, server_default="CREATED", ), ) # Mark existing servers that already have an admin connection config as CONNECTED bind = op.get_bind() bind.execute( sa.text( """ UPDATE mcp_server SET status = 'CONNECTED' WHERE status != 'CONNECTED' and admin_connection_config_id IS NOT NULL """ ) ) def downgrade() -> None: # Remove status column op.drop_column("mcp_server", "status") # Make auth fields non-nullable (set defaults first) op.execute( "UPDATE mcp_server SET transport = 'STREAMABLE_HTTP' WHERE transport IS NULL" ) op.execute("UPDATE mcp_server SET auth_type = 'NONE' WHERE auth_type IS NULL") op.execute( "UPDATE mcp_server SET auth_performer = 'ADMIN' WHERE auth_performer IS NULL" ) op.alter_column( "mcp_server", "transport", existing_type=sa.Enum(MCPTransport, name="mcp_transport", native_enum=False), nullable=False, ) op.alter_column( "mcp_server", "auth_type", existing_type=sa.Enum( MCPAuthenticationType, name="mcp_authentication_type", native_enum=False ), nullable=False, ) op.alter_column( "mcp_server", "auth_performer", existing_type=sa.Enum( MCPAuthenticationPerformer, name="mcp_authentication_performer", native_enum=False, ), nullable=False, ) ================================================ FILE: 
backend/alembic/versions/e91df4e935ef_private_personas_documentsets.py ================================================ """Private Personas DocumentSets Revision ID: e91df4e935ef Revises: 91fd3b470d1a Create Date: 2024-03-17 11:47:24.675881 """ import fastapi_users_db_sqlalchemy from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "e91df4e935ef" down_revision = "91fd3b470d1a" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.create_table( "document_set__user", sa.Column("document_set_id", sa.Integer(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False, ), sa.ForeignKeyConstraint( ["document_set_id"], ["document_set.id"], ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("document_set_id", "user_id"), ) op.create_table( "persona__user", sa.Column("persona_id", sa.Integer(), nullable=False), sa.Column( "user_id", fastapi_users_db_sqlalchemy.generics.GUID(), nullable=False, ), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.ForeignKeyConstraint( ["user_id"], ["user.id"], ), sa.PrimaryKeyConstraint("persona_id", "user_id"), ) op.create_table( "document_set__user_group", sa.Column("document_set_id", sa.Integer(), nullable=False), sa.Column( "user_group_id", sa.Integer(), nullable=False, ), sa.ForeignKeyConstraint( ["document_set_id"], ["document_set.id"], ), sa.ForeignKeyConstraint( ["user_group_id"], ["user_group.id"], ), sa.PrimaryKeyConstraint("document_set_id", "user_group_id"), ) op.create_table( "persona__user_group", sa.Column("persona_id", sa.Integer(), nullable=False), sa.Column( "user_group_id", sa.Integer(), nullable=False, ), sa.ForeignKeyConstraint( ["persona_id"], ["persona.id"], ), sa.ForeignKeyConstraint( ["user_group_id"], ["user_group.id"], ), sa.PrimaryKeyConstraint("persona_id", "user_group_id"), ) op.add_column( "document_set", sa.Column("is_public", sa.Boolean(), nullable=True), ) # 
fill in is_public for existing rows op.execute("UPDATE document_set SET is_public = true WHERE is_public IS NULL") op.alter_column("document_set", "is_public", nullable=False) op.add_column( "persona", sa.Column("is_public", sa.Boolean(), nullable=True), ) # fill in is_public for existing rows op.execute("UPDATE persona SET is_public = true WHERE is_public IS NULL") op.alter_column("persona", "is_public", nullable=False) def downgrade() -> None: op.drop_column("persona", "is_public") op.drop_column("document_set", "is_public") op.drop_table("persona__user") op.drop_table("document_set__user") op.drop_table("persona__user_group") op.drop_table("document_set__user_group") ================================================ FILE: backend/alembic/versions/eaa3b5593925_add_default_slack_channel_config.py ================================================ """add default slack channel config Revision ID: eaa3b5593925 Revises: 98a5008d8711 Create Date: 2025-02-03 18:07:56.552526 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "eaa3b5593925" down_revision = "98a5008d8711" branch_labels = None depends_on = None def upgrade() -> None: # Add is_default column op.add_column( "slack_channel_config", sa.Column("is_default", sa.Boolean(), nullable=False, server_default="false"), ) op.create_index( "ix_slack_channel_config_slack_bot_id_default", "slack_channel_config", ["slack_bot_id", "is_default"], unique=True, postgresql_where=sa.text("is_default IS TRUE"), ) # Create default channel configs for existing slack bots without one conn = op.get_bind() slack_bots = conn.execute(sa.text("SELECT id FROM slack_bot")).fetchall() for slack_bot in slack_bots: slack_bot_id = slack_bot[0] existing_default = conn.execute( sa.text( "SELECT id FROM slack_channel_config WHERE slack_bot_id = :bot_id AND is_default = TRUE" ), {"bot_id": slack_bot_id}, ).fetchone() if not existing_default: conn.execute( sa.text( """ INSERT INTO slack_channel_config ( slack_bot_id, persona_id, channel_config, enable_auto_filters, is_default ) VALUES ( :bot_id, NULL, '{"channel_name": null, ' '"respond_member_group_list": [], ' '"answer_filters": [], ' '"follow_up_tags": [], ' '"respond_tag_only": true}', FALSE, TRUE ) """ ), {"bot_id": slack_bot_id}, ) def downgrade() -> None: # Delete default slack channel configs conn = op.get_bind() conn.execute(sa.text("DELETE FROM slack_channel_config WHERE is_default = TRUE")) # Remove index op.drop_index( "ix_slack_channel_config_slack_bot_id_default", table_name="slack_channel_config", ) # Remove is_default column op.drop_column("slack_channel_config", "is_default") ================================================ FILE: backend/alembic/versions/ec3ec2eabf7b_index_from_beginning.py ================================================ """Index From Beginning Revision ID: ec3ec2eabf7b Revises: dbaa756c2ccf Create Date: 2024-02-06 22:03:28.098158 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "ec3ec2eabf7b" down_revision = "dbaa756c2ccf" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "index_attempt", sa.Column("from_beginning", sa.Boolean(), nullable=True) ) op.execute("UPDATE index_attempt SET from_beginning = False") op.alter_column("index_attempt", "from_beginning", nullable=False) def downgrade() -> None: op.drop_column("index_attempt", "from_beginning") ================================================ FILE: backend/alembic/versions/ec85f2b3c544_remove_last_attempt_status_from_cc_pair.py ================================================ """Remove Last Attempt Status from CC Pair Revision ID: ec85f2b3c544 Revises: 3879338f8ba1 Create Date: 2024-05-23 21:39:46.126010 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "ec85f2b3c544" down_revision = "70f00c45c0f2" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.drop_column("connector_credential_pair", "last_attempt_status") def downgrade() -> None: op.add_column( "connector_credential_pair", sa.Column( "last_attempt_status", sa.VARCHAR(), autoincrement=False, nullable=True, ), ) ================================================ FILE: backend/alembic/versions/ecab2b3f1a3b_add_overrides_to_the_chat_session.py ================================================ """Add overrides to the chat session Revision ID: ecab2b3f1a3b Revises: 38eda64af7fe Create Date: 2024-04-01 19:08:21.359102 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. 
revision = "ecab2b3f1a3b" down_revision = "38eda64af7fe" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_session", sa.Column( "llm_override", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) op.add_column( "chat_session", sa.Column( "prompt_override", postgresql.JSONB(astext_type=sa.Text()), nullable=True, ), ) def downgrade() -> None: op.drop_column("chat_session", "prompt_override") op.drop_column("chat_session", "llm_override") ================================================ FILE: backend/alembic/versions/ed9e44312505_add_icon_name_field.py ================================================ """Add icon_name field Revision ID: ed9e44312505 Revises: 5e6f7a8b9c0d Create Date: 2025-12-03 16:35:07.828393 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "ed9e44312505" down_revision = "5e6f7a8b9c0d" branch_labels = None depends_on = None def upgrade() -> None: # Add icon_name column op.add_column("persona", sa.Column("icon_name", sa.String(), nullable=True)) # Remove old icon columns op.drop_column("persona", "icon_shape") op.drop_column("persona", "icon_color") def downgrade() -> None: # Re-add old icon columns op.add_column("persona", sa.Column("icon_color", sa.String(), nullable=True)) op.add_column("persona", sa.Column("icon_shape", sa.Integer(), nullable=True)) # Remove icon_name column op.drop_column("persona", "icon_name") ================================================ FILE: backend/alembic/versions/ee3f4b47fad5_added_alternate_model_to_chat_message.py ================================================ """Added alternate model to chat message Revision ID: ee3f4b47fad5 Revises: 2d2304e27d8c Create Date: 2024-08-12 00:11:50.915845 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "ee3f4b47fad5" down_revision = "2d2304e27d8c" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_message", sa.Column("overridden_model", sa.String(length=255), nullable=True), ) def downgrade() -> None: op.drop_column("chat_message", "overridden_model") ================================================ FILE: backend/alembic/versions/ef7da92f7213_add_files_to_chatmessage.py ================================================ """Add files to ChatMessage Revision ID: ef7da92f7213 Revises: 401c1ac29467 Create Date: 2024-04-28 16:59:33.199153 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "ef7da92f7213" down_revision = "401c1ac29467" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_message", sa.Column("files", postgresql.JSONB(astext_type=sa.Text()), nullable=True), ) def downgrade() -> None: op.drop_column("chat_message", "files") ================================================ FILE: backend/alembic/versions/efb35676026c_standard_answer_match_regex_flag.py ================================================ """standard answer match_regex flag Revision ID: efb35676026c Revises: 0ebb1d516877 Create Date: 2024-09-11 13:55:46.101149 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "efb35676026c" down_revision = "0ebb1d516877" branch_labels = None depends_on = None def upgrade() -> None: # ### commands auto generated by Alembic - please adjust! ### op.add_column( "standard_answer", sa.Column( "match_regex", sa.Boolean(), nullable=False, server_default=sa.false() ), ) # ### end Alembic commands ### def downgrade() -> None: # ### commands auto generated by Alembic - please adjust! 
### op.drop_column("standard_answer", "match_regex") # ### end Alembic commands ### ================================================ FILE: backend/alembic/versions/f11b408e39d3_force_lowercase_all_users.py ================================================ """force lowercase all users Revision ID: f11b408e39d3 Revises: 3bd4c84fe72f Create Date: 2025-02-26 17:04:55.683500 """ # revision identifiers, used by Alembic. revision = "f11b408e39d3" down_revision = "3bd4c84fe72f" branch_labels = None depends_on = None def upgrade() -> None: # 1) Convert all existing user emails to lowercase from alembic import op op.execute( """ UPDATE "user" SET email = LOWER(email) """ ) # 2) Add a check constraint to ensure emails are always lowercase op.create_check_constraint("ensure_lowercase_email", "user", "email = LOWER(email)") def downgrade() -> None: # Drop the check constraint from alembic import op op.drop_constraint("ensure_lowercase_email", "user", type_="check") ================================================ FILE: backend/alembic/versions/f13db29f3101_add_composite_index_for_last_modified_.py ================================================ """Add composite index for last_modified and last_synced to document Revision ID: f13db29f3101 Revises: b388730a2899 Create Date: 2025-02-18 22:48:11.511389 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "f13db29f3101" down_revision = "acaab4ef4507" branch_labels: str | None = None depends_on: str | None = None def upgrade() -> None: op.create_index( "ix_document_sync_status", "document", ["last_modified", "last_synced"], unique=False, ) def downgrade() -> None: op.drop_index("ix_document_sync_status", table_name="document") ================================================ FILE: backend/alembic/versions/f17bf3b0d9f1_embedding_provider_by_provider_type.py ================================================ """embedding provider by provider type Revision ID: f17bf3b0d9f1 Revises: 351faebd379d Create Date: 2024-08-21 13:13:31.120460 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "f17bf3b0d9f1" down_revision = "351faebd379d" branch_labels: None = None depends_on: None = None def upgrade() -> None: # Add provider_type column to embedding_provider op.add_column( "embedding_provider", sa.Column("provider_type", sa.String(50), nullable=True), ) # Update provider_type with existing name values op.execute("UPDATE embedding_provider SET provider_type = UPPER(name)") # Make provider_type not nullable op.alter_column("embedding_provider", "provider_type", nullable=False) # Drop the foreign key constraint in embedding_model table op.drop_constraint( "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey" ) # Drop the existing primary key constraint op.drop_constraint("embedding_provider_pkey", "embedding_provider", type_="primary") # Create a new primary key constraint on provider_type op.create_primary_key( "embedding_provider_pkey", "embedding_provider", ["provider_type"] ) # Add provider_type column to embedding_model op.add_column( "embedding_model", sa.Column("provider_type", sa.String(50), nullable=True), ) # Update provider_type for existing embedding models op.execute( """ UPDATE embedding_model SET provider_type = ( SELECT provider_type FROM embedding_provider WHERE embedding_provider.id 
= embedding_model.cloud_provider_id ) """ ) # Drop the old id column from embedding_provider op.drop_column("embedding_provider", "id") # Drop the name column from embedding_provider op.drop_column("embedding_provider", "name") # Drop the default_model_id column from embedding_provider op.drop_column("embedding_provider", "default_model_id") # Drop the old cloud_provider_id column from embedding_model op.drop_column("embedding_model", "cloud_provider_id") # Create the new foreign key constraint op.create_foreign_key( "fk_embedding_model_cloud_provider", "embedding_model", "embedding_provider", ["provider_type"], ["provider_type"], ) def downgrade() -> None: # Drop the foreign key constraint in embedding_model table op.drop_constraint( "fk_embedding_model_cloud_provider", "embedding_model", type_="foreignkey" ) # Add back the cloud_provider_id column to embedding_model op.add_column( "embedding_model", sa.Column("cloud_provider_id", sa.Integer(), nullable=True) ) op.add_column("embedding_provider", sa.Column("id", sa.Integer(), nullable=True)) # Assign incrementing IDs to embedding providers op.execute( """ CREATE SEQUENCE IF NOT EXISTS embedding_provider_id_seq;""" ) op.execute( """ UPDATE embedding_provider SET id = nextval('embedding_provider_id_seq'); """ ) # Update cloud_provider_id based on provider_type op.execute( """ UPDATE embedding_model SET cloud_provider_id = CASE WHEN provider_type IS NULL THEN NULL ELSE ( SELECT id FROM embedding_provider WHERE embedding_provider.provider_type = embedding_model.provider_type ) END """ ) # Drop the provider_type column from embedding_model op.drop_column("embedding_model", "provider_type") # Add back the columns to embedding_provider op.add_column("embedding_provider", sa.Column("name", sa.String(50), nullable=True)) op.add_column( "embedding_provider", sa.Column("default_model_id", sa.Integer(), nullable=True) ) # Drop the existing primary key constraint on provider_type op.drop_constraint("embedding_provider_pkey", 
"embedding_provider", type_="primary") # Create the original primary key constraint on id op.create_primary_key("embedding_provider_pkey", "embedding_provider", ["id"]) # Update name with existing provider_type values op.execute( """ UPDATE embedding_provider SET name = CASE WHEN provider_type = 'OPENAI' THEN 'OpenAI' WHEN provider_type = 'COHERE' THEN 'Cohere' WHEN provider_type = 'GOOGLE' THEN 'Google' WHEN provider_type = 'VOYAGE' THEN 'Voyage' ELSE provider_type END """ ) # Drop the provider_type column from embedding_provider op.drop_column("embedding_provider", "provider_type") # Recreate the foreign key constraint in embedding_model table op.create_foreign_key( "fk_embedding_model_cloud_provider", "embedding_model", "embedding_provider", ["cloud_provider_id"], ["id"], ) # Recreate the default-model foreign key constraint in embedding_provider table op.create_foreign_key( "fk_embedding_provider_default_model", "embedding_provider", "embedding_model", ["default_model_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/f1c6478c3fd8_add_pre_defined_feedback.py ================================================ """Add pre-defined feedback Revision ID: f1c6478c3fd8 Revises: 643a84a42a33 Create Date: 2024-05-09 18:11:49.210667 """ from alembic import op import sqlalchemy as sa revision = "f1c6478c3fd8" down_revision = "643a84a42a33" branch_labels: None = None depends_on: None = None def upgrade() -> None: op.add_column( "chat_feedback", sa.Column("predefined_feedback", sa.String(), nullable=True), ) def downgrade() -> None: op.drop_column("chat_feedback", "predefined_feedback") ================================================ FILE: backend/alembic/versions/f1ca58b2f2ec_add_passthrough_auth_to_tool.py ================================================ """add passthrough auth to tool Revision ID: f1ca58b2f2ec Revises: c7bf5721733e Create Date: 2024-03-19 """ from typing import Sequence, Union from alembic import op import sqlalchemy as sa # 
revision identifiers, used by Alembic. revision: str = "f1ca58b2f2ec" down_revision: Union[str, None] = "c7bf5721733e" branch_labels: Union[str, Sequence[str], None] = None depends_on: Union[str, Sequence[str], None] = None def upgrade() -> None: # Add passthrough_auth column to tool table with default value of False op.add_column( "tool", sa.Column( "passthrough_auth", sa.Boolean(), nullable=False, server_default=sa.false() ), ) def downgrade() -> None: # Remove passthrough_auth column from tool table op.drop_column("tool", "passthrough_auth") ================================================ FILE: backend/alembic/versions/f220515df7b4_add_flow_mapping_table.py ================================================ """Add flow mapping table Revision ID: f220515df7b4 Revises: cbc03e08d0f3 Create Date: 2026-01-30 12:21:24.955922 """ from onyx.db.enums import LLMModelFlowType from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "f220515df7b4" down_revision = "9d1543a37106" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "llm_model_flow", sa.Column("id", sa.Integer(), nullable=False), sa.Column( "llm_model_flow_type", sa.Enum(LLMModelFlowType, name="llmmodelflowtype", native_enum=False), nullable=False, ), sa.Column( "is_default", sa.Boolean(), nullable=False, server_default=sa.text("false") ), sa.Column("model_configuration_id", sa.Integer(), nullable=False), sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint( ["model_configuration_id"], ["model_configuration.id"], ondelete="CASCADE" ), sa.UniqueConstraint( "llm_model_flow_type", "model_configuration_id", name="uq_model_config_per_llm_model_flow_type", ), ) # Partial unique index so that there is at most one default for each flow type op.create_index( "ix_one_default_per_llm_model_flow", "llm_model_flow", ["llm_model_flow_type"], unique=True, postgresql_where=sa.text("is_default IS TRUE"), ) def downgrade() -> None: # Drop the llm_model_flow 
table (index is dropped automatically with table) op.drop_table("llm_model_flow") ================================================ FILE: backend/alembic/versions/f32615f71aeb_add_custom_headers_to_tools.py ================================================ """add custom headers to tools Revision ID: f32615f71aeb Revises: bd2921608c3a Create Date: 2024-09-12 20:26:38.932377 """ from alembic import op import sqlalchemy as sa from sqlalchemy.dialects import postgresql # revision identifiers, used by Alembic. revision = "f32615f71aeb" down_revision = "bd2921608c3a" branch_labels = None depends_on = None def upgrade() -> None: op.add_column( "tool", sa.Column("custom_headers", postgresql.JSONB(), nullable=True) ) def downgrade() -> None: op.drop_column("tool", "custom_headers") ================================================ FILE: backend/alembic/versions/f39c5794c10a_add_background_errors_table.py ================================================ """Add background errors table Revision ID: f39c5794c10a Revises: 2cdeff6d8c93 Create Date: 2025-02-12 17:11:14.527876 """ from alembic import op import sqlalchemy as sa # revision identifiers, used by Alembic. 
revision = "f39c5794c10a" down_revision = "2cdeff6d8c93" branch_labels = None depends_on = None def upgrade() -> None: op.create_table( "background_error", sa.Column("id", sa.Integer(), nullable=False), sa.Column("message", sa.String(), nullable=False), sa.Column( "time_created", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=False, ), sa.Column("cc_pair_id", sa.Integer(), nullable=True), sa.PrimaryKeyConstraint("id"), sa.ForeignKeyConstraint( ["cc_pair_id"], ["connector_credential_pair.id"], ondelete="CASCADE", ), ) def downgrade() -> None: op.drop_table("background_error") ================================================ FILE: backend/alembic/versions/f5437cc136c5_delete_non_search_assistants.py ================================================ """delete non-search assistants Revision ID: f5437cc136c5 Revises: eaa3b5593925 Create Date: 2025-02-04 16:17:15.677256 """ from alembic import op # revision identifiers, used by Alembic. revision = "f5437cc136c5" down_revision = "eaa3b5593925" branch_labels = None depends_on = None def upgrade() -> None: pass def downgrade() -> None: # Fix: split the statements into multiple op.execute() calls op.execute( """ WITH personas_without_search AS ( SELECT p.id FROM persona p LEFT JOIN persona__tool pt ON p.id = pt.persona_id LEFT JOIN tool t ON pt.tool_id = t.id GROUP BY p.id HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0 ) UPDATE slack_channel_config SET persona_id = NULL WHERE is_default = TRUE AND persona_id IN (SELECT id FROM personas_without_search) """ ) op.execute( """ WITH personas_without_search AS ( SELECT p.id FROM persona p LEFT JOIN persona__tool pt ON p.id = pt.persona_id LEFT JOIN tool t ON pt.tool_id = t.id GROUP BY p.id HAVING COUNT(CASE WHEN t.in_code_tool_id = 'run_search' THEN 1 END) = 0 ) DELETE FROM slack_channel_config WHERE is_default = FALSE AND persona_id IN (SELECT id FROM personas_without_search) """ ) ================================================ 
FILE: backend/alembic/versions/f71470ba9274_add_prompt_length_limit.py ================================================ """add prompt length limit Revision ID: f71470ba9274 Revises: 6a804aeb4830 Create Date: 2025-04-01 15:07:14.977435 """ # revision identifiers, used by Alembic. revision = "f71470ba9274" down_revision = "6a804aeb4830" branch_labels = None depends_on = None def upgrade() -> None: # op.alter_column( # "prompt", # "system_prompt", # existing_type=sa.TEXT(), # type_=sa.String(length=8000), # existing_nullable=False, # ) # op.alter_column( # "prompt", # "task_prompt", # existing_type=sa.TEXT(), # type_=sa.String(length=8000), # existing_nullable=False, # ) pass def downgrade() -> None: # op.alter_column( # "prompt", # "system_prompt", # existing_type=sa.String(length=8000), # type_=sa.TEXT(), # existing_nullable=False, # ) # op.alter_column( # "prompt", # "task_prompt", # existing_type=sa.String(length=8000), # type_=sa.TEXT(), # existing_nullable=False, # ) pass ================================================ FILE: backend/alembic/versions/f7505c5b0284_updated_constraints_for_ccpairs.py ================================================ """updated constraints for ccpairs Revision ID: f7505c5b0284 Revises: f71470ba9274 Create Date: 2025-04-01 17:50:42.504818 """ from alembic import op # revision identifiers, used by Alembic. 
revision = "f7505c5b0284" down_revision = "f71470ba9274" branch_labels = None depends_on = None def upgrade() -> None: # 1) Drop the old foreign-key constraints op.drop_constraint( "document_by_connector_credential_pair_connector_id_fkey", "document_by_connector_credential_pair", type_="foreignkey", ) op.drop_constraint( "document_by_connector_credential_pair_credential_id_fkey", "document_by_connector_credential_pair", type_="foreignkey", ) # 2) Re-add them with ondelete='CASCADE' op.create_foreign_key( "document_by_connector_credential_pair_connector_id_fkey", source_table="document_by_connector_credential_pair", referent_table="connector", local_cols=["connector_id"], remote_cols=["id"], ondelete="CASCADE", ) op.create_foreign_key( "document_by_connector_credential_pair_credential_id_fkey", source_table="document_by_connector_credential_pair", referent_table="credential", local_cols=["credential_id"], remote_cols=["id"], ondelete="CASCADE", ) def downgrade() -> None: # Reverse the changes for rollback op.drop_constraint( "document_by_connector_credential_pair_connector_id_fkey", "document_by_connector_credential_pair", type_="foreignkey", ) op.drop_constraint( "document_by_connector_credential_pair_credential_id_fkey", "document_by_connector_credential_pair", type_="foreignkey", ) # Recreate without CASCADE op.create_foreign_key( "document_by_connector_credential_pair_connector_id_fkey", "document_by_connector_credential_pair", "connector", ["connector_id"], ["id"], ) op.create_foreign_key( "document_by_connector_credential_pair_credential_id_fkey", "document_by_connector_credential_pair", "credential", ["credential_id"], ["id"], ) ================================================ FILE: backend/alembic/versions/f7a894b06d02_non_nullbale_slack_bot_id_in_channel_.py ================================================ """non-nullable slack bot id in channel config Revision ID: f7a894b06d02 Revises: 9f696734098f Create Date: 2024-12-06 12:55:42.845723 """ from alembic
import op import sqlalchemy as sa # revision identifiers, used by Alembic. revision = "f7a894b06d02" down_revision = "9f696734098f" branch_labels = None depends_on = None def upgrade() -> None: # Delete all rows with null slack_bot_id op.execute("DELETE FROM slack_channel_config WHERE slack_bot_id IS NULL") # Make slack_bot_id non-nullable op.alter_column( "slack_channel_config", "slack_bot_id", existing_type=sa.Integer(), nullable=False, ) def downgrade() -> None: # Make slack_bot_id nullable again op.alter_column( "slack_channel_config", "slack_bot_id", existing_type=sa.Integer(), nullable=True, ) ================================================ FILE: backend/alembic/versions/f7ca3e2f45d9_migrate_no_auth_data_to_placeholder.py ================================================ """migrate_no_auth_data_to_placeholder This migration handles the transition from AUTH_TYPE=disabled to requiring authentication. It creates a placeholder user and assigns all data that was created without a user (user_id=NULL) to this placeholder. A database trigger is installed that automatically transfers all data from the placeholder user to the first real user who registers, then drops itself. Revision ID: f7ca3e2f45d9 Revises: 78ebc66946a0 Create Date: 2026-01-15 12:49:53.802741 """ import os from alembic import op import sqlalchemy as sa from shared_configs.configs import MULTI_TENANT # revision identifiers, used by Alembic. 
revision = "f7ca3e2f45d9" down_revision = "78ebc66946a0" branch_labels = None depends_on = None # Must match constants in onyx/configs/constants.py file NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001" NO_AUTH_PLACEHOLDER_USER_EMAIL = "no-auth-placeholder@onyx.app" # Trigger and function names TRIGGER_NAME = "trg_migrate_no_auth_data" FUNCTION_NAME = "migrate_no_auth_data_to_user" # Trigger function that migrates data from placeholder to first real user MIGRATE_NO_AUTH_TRIGGER_FUNCTION = f""" CREATE OR REPLACE FUNCTION {FUNCTION_NAME}() RETURNS TRIGGER AS $$ DECLARE placeholder_uuid UUID := '00000000-0000-0000-0000-000000000001'::uuid; anonymous_uuid UUID := '00000000-0000-0000-0000-000000000002'::uuid; placeholder_row RECORD; schema_name TEXT; BEGIN -- Skip if this is the placeholder user being inserted IF NEW.id = placeholder_uuid THEN RETURN NULL; END IF; -- Skip if this is the anonymous user being inserted (not a real user) IF NEW.id = anonymous_uuid THEN RETURN NULL; END IF; -- Skip if the new user is not active IF NEW.is_active = FALSE THEN RETURN NULL; END IF; -- Get current schema for self-cleanup schema_name := current_schema(); -- Try to lock the placeholder user row with FOR UPDATE SKIP LOCKED -- This ensures only one concurrent transaction can proceed with migration -- SKIP LOCKED means if another transaction has the lock, we skip (don't wait) SELECT id INTO placeholder_row FROM "user" WHERE id = placeholder_uuid FOR UPDATE SKIP LOCKED; IF NOT FOUND THEN -- Either placeholder doesn't exist or another transaction has it locked -- Either way, drop the trigger and return without making admin EXECUTE format('DROP TRIGGER IF EXISTS {TRIGGER_NAME} ON %I."user"', schema_name); EXECUTE format('DROP FUNCTION IF EXISTS %I.{FUNCTION_NAME}()', schema_name); RETURN NULL; END IF; -- We have exclusive lock on placeholder - proceed with migration -- The INSERT has already completed (AFTER INSERT), so NEW.id exists in the table -- Migrate 
chat_session UPDATE "chat_session" SET user_id = NEW.id WHERE user_id = placeholder_uuid; -- Migrate credential (exclude public credential id=0) UPDATE "credential" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND id != 0; -- Migrate document_set UPDATE "document_set" SET user_id = NEW.id WHERE user_id = placeholder_uuid; -- Migrate persona (exclude builtin personas) UPDATE "persona" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND builtin_persona = FALSE; -- Migrate tool (exclude builtin tools) UPDATE "tool" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND in_code_tool_id IS NULL; -- Migrate notification UPDATE "notification" SET user_id = NEW.id WHERE user_id = placeholder_uuid; -- Migrate inputprompt (exclude system/public prompts) UPDATE "inputprompt" SET user_id = NEW.id WHERE user_id = placeholder_uuid AND is_public = FALSE; -- Make the new user an admin (they had admin access in no-auth mode) -- In AFTER INSERT trigger, we must UPDATE the row since it already exists UPDATE "user" SET role = 'ADMIN' WHERE id = NEW.id; -- Delete the placeholder user (we hold the lock so this is safe) DELETE FROM "user" WHERE id = placeholder_uuid; -- Drop the trigger and function (self-cleanup) EXECUTE format('DROP TRIGGER IF EXISTS {TRIGGER_NAME} ON %I."user"', schema_name); EXECUTE format('DROP FUNCTION IF EXISTS %I.{FUNCTION_NAME}()', schema_name); RETURN NULL; END; $$ LANGUAGE plpgsql; """ MIGRATE_NO_AUTH_TRIGGER = f""" CREATE TRIGGER {TRIGGER_NAME} AFTER INSERT ON "user" FOR EACH ROW EXECUTE FUNCTION {FUNCTION_NAME}(); """ def upgrade() -> None: """ Create a placeholder user and assign all NULL user_id records to it. Install a trigger that migrates data to the first real user and self-destructs. Only runs if AUTH_TYPE is currently disabled/none. Skipped in multi-tenant mode - each tenant starts fresh with no legacy data. 
    """
    # Skip in multi-tenant mode - this migration handles single-tenant
    # AUTH_TYPE=disabled -> auth transitions only
    if MULTI_TENANT:
        return

    # Only run if AUTH_TYPE is currently disabled/none
    # If they've already switched to auth-enabled, NULL data is stale anyway
    auth_type = (os.environ.get("AUTH_TYPE") or "").lower()
    if auth_type not in ("disabled", "none", ""):
        print(f"AUTH_TYPE is '{auth_type}', not disabled. Skipping migration.")
        return

    connection = op.get_bind()

    # Check if there are any NULL user_id records that need migration
    tables_to_check = [
        "chat_session",
        "credential",
        "document_set",
        "persona",
        "tool",
        "notification",
        "inputprompt",
    ]

    has_null_records = False
    for table in tables_to_check:
        try:
            result = connection.execute(
                sa.text(f'SELECT 1 FROM "{table}" WHERE user_id IS NULL LIMIT 1')
            )
            if result.fetchone():
                has_null_records = True
                break
        except Exception:
            # Table might not exist
            pass

    if not has_null_records:
        return

    # Create the placeholder user
    connection.execute(
        sa.text(
            """
            INSERT INTO "user" (id, email, hashed_password, is_active, is_superuser, is_verified, role)
            VALUES (:id, :email, :hashed_password, :is_active, :is_superuser, :is_verified, :role)
            """
        ),
        {
            "id": NO_AUTH_PLACEHOLDER_USER_UUID,
            "email": NO_AUTH_PLACEHOLDER_USER_EMAIL,
            "hashed_password": "",  # Empty password - user cannot log in
            "is_active": False,  # Inactive - user cannot log in
            "is_superuser": False,
            "is_verified": False,
            "role": "BASIC",
        },
    )

    # Assign NULL user_id records to the placeholder user
    for table in tables_to_check:
        try:
            # Base condition for all tables
            condition = "user_id IS NULL"
            # Exclude public credential (id=0) which must remain user_id=NULL
            if table == "credential":
                condition += " AND id != 0"
            # Exclude builtin tools (in_code_tool_id IS NOT NULL) which must remain user_id=NULL
            elif table == "tool":
                condition += " AND in_code_tool_id IS NULL"
            # Exclude builtin personas which must remain user_id=NULL
            elif table == "persona":
                condition += " AND builtin_persona = FALSE"
            # Exclude system/public input prompts which must remain user_id=NULL
            elif table == "inputprompt":
                condition += " AND is_public = FALSE"
            result = connection.execute(
                sa.text(
                    f"""
                    UPDATE "{table}"
                    SET user_id = :user_id
                    WHERE {condition}
                    """
                ),
                {"user_id": NO_AUTH_PLACEHOLDER_USER_UUID},
            )
            if result.rowcount > 0:
                print(f"Updated {result.rowcount} rows in {table}")
        except Exception as e:
            print(f"Skipping {table}: {e}")

    # Install the trigger function and trigger for automatic migration on first user registration
    connection.execute(sa.text(MIGRATE_NO_AUTH_TRIGGER_FUNCTION))
    connection.execute(sa.text(MIGRATE_NO_AUTH_TRIGGER))
    print("Installed trigger for automatic data migration on first user registration")


def downgrade() -> None:
    """
    Drop trigger and function, set placeholder user's records back to NULL,
    and delete the placeholder user.
    """
    # Skip in multi-tenant mode for consistency with upgrade
    if MULTI_TENANT:
        return

    connection = op.get_bind()

    # Drop trigger and function if they exist (they may have already self-destructed)
    connection.execute(sa.text(f'DROP TRIGGER IF EXISTS {TRIGGER_NAME} ON "user"'))
    connection.execute(sa.text(f"DROP FUNCTION IF EXISTS {FUNCTION_NAME}()"))

    tables_to_update = [
        "chat_session",
        "credential",
        "document_set",
        "persona",
        "tool",
        "notification",
        "inputprompt",
    ]

    # Set records back to NULL
    for table in tables_to_update:
        try:
            connection.execute(
                sa.text(
                    f"""
                    UPDATE "{table}"
                    SET user_id = NULL
                    WHERE user_id = :user_id
                    """
                ),
                {"user_id": NO_AUTH_PLACEHOLDER_USER_UUID},
            )
        except Exception:
            pass

    # Delete the placeholder user
    connection.execute(
        sa.text('DELETE FROM "user" WHERE id = :user_id'),
        {"user_id": NO_AUTH_PLACEHOLDER_USER_UUID},
    )


================================================
FILE: backend/alembic/versions/f7e58d357687_add_has_web_column_to_user.py
================================================
"""add has_web_login column to user

Revision ID: f7e58d357687
Revises: ba98eba0f66a
Create Date: 2024-09-07 20:20:54.522620
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "f7e58d357687"
down_revision = "ba98eba0f66a"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column("has_web_login", sa.Boolean(), nullable=False, server_default="true"),
    )


def downgrade() -> None:
    op.drop_column("user", "has_web_login")


================================================
FILE: backend/alembic/versions/f8a9b2c3d4e5_add_research_answer_purpose_to_chat_message.py
================================================
"""add research_answer_purpose to chat_message

Revision ID: f8a9b2c3d4e5
Revises: 5ae8240accb3
Create Date: 2025-01-27 12:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "f8a9b2c3d4e5"
down_revision = "5ae8240accb3"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add research_answer_purpose column to chat_message table
    op.add_column(
        "chat_message",
        sa.Column("research_answer_purpose", sa.String(), nullable=True),
    )


def downgrade() -> None:
    # Remove research_answer_purpose column from chat_message table
    op.drop_column("chat_message", "research_answer_purpose")


================================================
FILE: backend/alembic/versions/f9b8c7d6e5a4_update_parent_question_id_foreign_key_to_research_agent_iteration.py
================================================
"""remove foreign key constraints from research_agent_iteration_sub_step

Revision ID: f9b8c7d6e5a4
Revises: bd7c3bf8beba
Create Date: 2025-01-27 12:00:00.000000

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "f9b8c7d6e5a4"
down_revision = "bd7c3bf8beba"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop the existing foreign key constraint for parent_question_id
    op.drop_constraint(
        "research_agent_iteration_sub_step_parent_question_id_fkey",
        "research_agent_iteration_sub_step",
        type_="foreignkey",
    )

    # Drop the parent_question_id column entirely
    op.drop_column("research_agent_iteration_sub_step", "parent_question_id")

    # Drop the foreign key constraint for primary_question_id to chat_message.id
    # (keep the column as it's needed for the composite foreign key)
    op.drop_constraint(
        "research_agent_iteration_sub_step_primary_question_id_fkey",
        "research_agent_iteration_sub_step",
        type_="foreignkey",
    )


def downgrade() -> None:
    # Restore the foreign key constraint for primary_question_id to chat_message.id
    op.create_foreign_key(
        "research_agent_iteration_sub_step_primary_question_id_fkey",
        "research_agent_iteration_sub_step",
        "chat_message",
        ["primary_question_id"],
        ["id"],
        ondelete="CASCADE",
    )

    # Add back the parent_question_id column
    op.add_column(
        "research_agent_iteration_sub_step",
        sa.Column(
            "parent_question_id",
            sa.Integer(),
            nullable=True,
        ),
    )

    # Restore the foreign key constraint pointing to research_agent_iteration_sub_step.id
    op.create_foreign_key(
        "research_agent_iteration_sub_step_parent_question_id_fkey",
        "research_agent_iteration_sub_step",
        "research_agent_iteration_sub_step",
        ["parent_question_id"],
        ["id"],
        ondelete="CASCADE",
    )


================================================
FILE: backend/alembic/versions/fad14119fb92_delete_tags_with_wrong_enum.py
================================================
"""Delete Tags with wrong Enum

Revision ID: fad14119fb92
Revises: 72bdc9929a46
Create Date: 2024-04-25 17:05:09.695703

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "fad14119fb92"
down_revision = "72bdc9929a46"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    # Some documents may lose their tags but this is the only way as the enum
    # mapping may have changed since tag switched to string (it will be reindexed anyway)
    op.execute(
        """
        DELETE FROM document__tag
        WHERE tag_id IN (
            SELECT id FROM tag WHERE source ~ '^[0-9]+$'
        )
        """
    )

    op.execute(
        """
        DELETE FROM tag
        WHERE source ~ '^[0-9]+$'
        """
    )


def downgrade() -> None:
    pass


================================================
FILE: backend/alembic/versions/fb80bdd256de_add_chat_background_to_user.py
================================================
"""add chat_background to user

Revision ID: fb80bdd256de
Revises: 8b5ce697290e
Create Date: 2026-01-16 16:15:59.222617

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "fb80bdd256de"
down_revision = "8b5ce697290e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "chat_background",
            sa.String(),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("user", "chat_background")


================================================
FILE: backend/alembic/versions/fcd135795f21_add_slack_bot_display_type.py
================================================
"""Add slack bot display type

Revision ID: fcd135795f21
Revises: 0a2b51deb0b8
Create Date: 2024-03-04 17:03:27.116284

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "fcd135795f21"
down_revision = "0a2b51deb0b8"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "slack_bot_config",
        sa.Column(
            "response_type",
            sa.Enum(
                "QUOTES",
                "CITATIONS",
                name="slackbotresponsetype",
                native_enum=False,
            ),
            nullable=True,
        ),
    )
    op.execute(
        "UPDATE slack_bot_config SET response_type = 'QUOTES' WHERE response_type IS NULL"
    )
    op.alter_column("slack_bot_config", "response_type", nullable=False)


def downgrade() -> None:
    op.drop_column("slack_bot_config", "response_type")


================================================
FILE: backend/alembic/versions/febe9eaa0644_add_document_set_persona_relationship_.py
================================================
"""Add document_set / persona relationship table

Revision ID: febe9eaa0644
Revises: 57b53544726e
Create Date: 2023-09-24 13:06:24.018610

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "febe9eaa0644"
down_revision = "57b53544726e"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.create_table(
        "persona__document_set",
        sa.Column("persona_id", sa.Integer(), nullable=False),
        sa.Column("document_set_id", sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(
            ["document_set_id"],
            ["document_set.id"],
        ),
        sa.ForeignKeyConstraint(
            ["persona_id"],
            ["persona.id"],
        ),
        sa.PrimaryKeyConstraint("persona_id", "document_set_id"),
    )


def downgrade() -> None:
    op.drop_table("persona__document_set")


================================================
FILE: backend/alembic/versions/fec3db967bf7_add_time_updated_to_usergroup_and_.py
================================================
"""Add time_updated to UserGroup and DocumentSet

Revision ID: fec3db967bf7
Revises: 97dbb53fa8c8
Create Date: 2025-01-12 15:49:02.289100

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "fec3db967bf7"
down_revision = "97dbb53fa8c8"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "document_set",
        sa.Column(
            "time_last_modified_by_user",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )
    op.add_column(
        "user_group",
        sa.Column(
            "time_last_modified_by_user",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )


def downgrade() -> None:
    op.drop_column("user_group", "time_last_modified_by_user")
    op.drop_column("document_set", "time_last_modified_by_user")


================================================
FILE: backend/alembic/versions/feead2911109_add_opensearch_tenant_migration_columns.py
================================================
"""add_opensearch_tenant_migration_columns

Revision ID: feead2911109
Revises: 175ea04c7087
Create Date: 2026-02-10 17:46:34.029937

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "feead2911109"
down_revision = "175ea04c7087"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column("vespa_visit_continuation_token", sa.Text(), nullable=True),
    )
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column(
            "total_chunks_migrated",
            sa.Integer(),
            nullable=False,
            server_default="0",
        ),
    )
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column(
            "created_at",
            sa.DateTime(timezone=True),
            nullable=False,
            server_default=sa.func.now(),
        ),
    )
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column(
            "migration_completed_at",
            sa.DateTime(timezone=True),
            nullable=True,
        ),
    )
    op.add_column(
        "opensearch_tenant_migration_record",
        sa.Column(
            "enable_opensearch_retrieval",
            sa.Boolean(),
            nullable=False,
            server_default="false",
        ),
    )


def downgrade() -> None:
    op.drop_column("opensearch_tenant_migration_record", "enable_opensearch_retrieval")
    op.drop_column("opensearch_tenant_migration_record", "migration_completed_at")
    op.drop_column("opensearch_tenant_migration_record", "created_at")
    op.drop_column("opensearch_tenant_migration_record", "total_chunks_migrated")
    op.drop_column(
        "opensearch_tenant_migration_record", "vespa_visit_continuation_token"
    )


================================================
FILE: backend/alembic/versions/ffc707a226b4_basic_document_metadata.py
================================================
"""Basic Document Metadata

Revision ID: ffc707a226b4
Revises: 30c1d5744104
Create Date: 2023-10-18 16:52:25.967592

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "ffc707a226b4"
down_revision = "30c1d5744104"
branch_labels: None = None
depends_on: None = None


def upgrade() -> None:
    op.add_column(
        "document",
        sa.Column("doc_updated_at", sa.DateTime(timezone=True), nullable=True),
    )
    op.add_column(
        "document",
        sa.Column("primary_owners", postgresql.ARRAY(sa.String()), nullable=True),
    )
    op.add_column(
        "document",
        sa.Column("secondary_owners", postgresql.ARRAY(sa.String()), nullable=True),
    )


def downgrade() -> None:
    op.drop_column("document", "secondary_owners")
    op.drop_column("document", "primary_owners")
    op.drop_column("document", "doc_updated_at")


================================================
FILE: backend/alembic.ini
================================================
# A generic, single database configuration.

[DEFAULT]
# path to migration scripts
script_location = alembic

# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s
# Uncomment the line below if you want the files to be prepended with date and time
# file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s

# sys.path path, will be prepended to sys.path if present.
# defaults to the current working directory.
prepend_sys_path = .

# timezone to use when rendering the date within the migration file
# as well as the filename.
# If specified, requires the python-dateutil library that can be
# installed by adding `alembic[tz]` to the pip requirements
# string value is passed to dateutil.tz.gettz()
# leave blank for localtime
# timezone =

# max length of characters to apply to the
# "slug" field
# truncate_slug_length = 40

# set to 'true' to run the environment during
# the 'revision' command, regardless of autogenerate
# revision_environment = false

# set to 'true' to allow .pyc and .pyo files without
# a source .py file to be detected as revisions in the
# versions/ directory
# sourceless = false

# version location specification; This defaults
# to alembic/versions.  When using multiple version
# directories, initial revisions must be specified with --version-path.
# The path separator used here should be the separator specified by "version_path_separator" below.
# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions

# version path separator; As mentioned above, this is the character used to split
# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep.
# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas.
# Valid values for version_path_separator are:
#
# version_path_separator = :
# version_path_separator = ;
# version_path_separator = space
version_path_separator = os  # Use os.pathsep. Default configuration used for new projects.

# set to 'true' to search source files recursively
# in each "version_locations" directory
# new in Alembic version 1.10
# recursive_version_locations = false

# the output encoding used when revision files
# are written from script.py.mako
# output_encoding = utf-8

# sqlalchemy.url = driver://user:pass@localhost/dbname


[post_write_hooks]
# post_write_hooks defines scripts or Python functions that are run
# on newly generated revision scripts.  See the documentation for further
# detail and examples

# format using "black" - use the console_scripts runner, against the "black" entrypoint
hooks = black
black.type = console_scripts
black.entrypoint = black
black.options = -l 79 REVISION_SCRIPT_FILENAME

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = INFO
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

[alembic]
script_location = alembic
version_locations = %(script_location)s/versions

[schema_private]
script_location = alembic_tenants
version_locations = %(script_location)s/versions


================================================
FILE: backend/alembic_tenants/README.md
================================================
These files are for public table migrations when operating with multi tenancy.

If you are not an Onyx developer, you can ignore this directory entirely.


================================================
FILE: backend/alembic_tenants/__init__.py
================================================



================================================
FILE: backend/alembic_tenants/env.py
================================================
import asyncio
from logging.config import fileConfig
from typing import Literal

from sqlalchemy import pool
from sqlalchemy.engine import Connection
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy.schema import SchemaItem

from alembic import context
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.models import PublicBase

# this is the Alembic Config object, which provides
# access to the values within the .ini file in use.
config = context.config

# Interpret the config file for Python logging.
# This line sets up loggers basically.
if config.config_file_name is not None and config.attributes.get(
    "configure_logger", True
):
    # disable_existing_loggers=False prevents breaking pytest's caplog fixture
    # See: https://pytest-alembic.readthedocs.io/en/latest/setup.html#caplog-issues
    fileConfig(config.config_file_name, disable_existing_loggers=False)

# add your model's MetaData object here
# for 'autogenerate' support
# from myapp import mymodel
# target_metadata = mymodel.Base.metadata
target_metadata = [PublicBase.metadata]

# other values from the config, defined by the needs of env.py,
# can be acquired:
# my_important_option = config.get_main_option("my_important_option")
# ... etc.

EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}


def include_object(
    object: SchemaItem,  # noqa: ARG001
    name: str | None,
    type_: Literal[
        "schema",
        "table",
        "column",
        "index",
        "unique_constraint",
        "foreign_key_constraint",
    ],
    reflected: bool,  # noqa: ARG001
    compare_to: SchemaItem | None,  # noqa: ARG001
) -> bool:
    if type_ == "table" and name in EXCLUDE_TABLES:
        return False
    return True


def run_migrations_offline() -> None:
    """Run migrations in 'offline' mode.

    This configures the context with just a URL
    and not an Engine, though an Engine is acceptable
    here as well.  By skipping the Engine creation
    we don't even need a DBAPI to be available.

    Calls to context.execute() here emit the given string to the
    script output.
    """
    url = build_connection_string()
    context.configure(
        url=url,
        target_metadata=target_metadata,  # type: ignore
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def do_run_migrations(connection: Connection) -> None:
    context.configure(
        connection=connection,
        target_metadata=target_metadata,  # type: ignore[arg-type]
        include_object=include_object,
    )

    with context.begin_transaction():
        context.run_migrations()


async def run_async_migrations() -> None:
    """In this scenario we need to create an Engine
    and associate a connection with the context.

    """

    connectable = create_async_engine(
        build_connection_string(),
        poolclass=pool.NullPool,
    )

    async with connectable.connect() as connection:
        await connection.run_sync(do_run_migrations)

    await connectable.dispose()


def run_migrations_online() -> None:
    """Run migrations in 'online' mode.

    Supports pytest-alembic by checking for a pre-configured connection
    in context.config.attributes["connection"]. If present, uses that
    connection/engine directly instead of creating a new async engine.
    """
    # Check if pytest-alembic is providing a connection/engine
    connectable = context.config.attributes.get("connection", None)

    if connectable is not None:
        # pytest-alembic is providing an engine - use it directly
        with connectable.connect() as connection:
            do_run_migrations(connection)
            # Commit to ensure changes are visible to next migration
            connection.commit()
    else:
        # Normal operation - use async migrations
        asyncio.run(run_async_migrations())


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()


================================================
FILE: backend/alembic_tenants/script.py.mako
================================================
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision = ${repr(up_revision)}
down_revision = ${repr(down_revision)}
branch_labels = ${repr(branch_labels)}
depends_on = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}


================================================
FILE: backend/alembic_tenants/versions/14a83a331951_create_usertenantmapping_table.py
================================================
import sqlalchemy as sa

from alembic import op

# revision identifiers, used by Alembic.
revision = "14a83a331951"
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "user_tenant_mapping",
        sa.Column("email", sa.String(), nullable=False),
        sa.Column("tenant_id", sa.String(), nullable=False),
        sa.UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
        sa.UniqueConstraint("email", name="uq_email"),
        schema="public",
    )


def downgrade() -> None:
    op.drop_table("user_tenant_mapping", schema="public")


================================================
FILE: backend/alembic_tenants/versions/34e3630c7f32_lowercase_multi_tenant_user_auth.py
================================================
"""lowercase multi-tenant user auth

Revision ID: 34e3630c7f32
Revises: a4f6ee863c47
Create Date: 2025-02-26 15:03:01.211894

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "34e3630c7f32"
down_revision = "a4f6ee863c47"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # 1) Convert all existing rows to lowercase
    op.execute(
        """
        UPDATE user_tenant_mapping
        SET email = LOWER(email)
        """
    )
    # 2) Add a check constraint so that emails cannot be written in uppercase
    op.create_check_constraint(
        "ensure_lowercase_email",
        "user_tenant_mapping",
        "email = LOWER(email)",
        schema="public",
    )


def downgrade() -> None:
    # Drop the check constraint
    op.drop_constraint(
        "ensure_lowercase_email",
        "user_tenant_mapping",
        schema="public",
        type_="check",
    )


================================================
FILE: backend/alembic_tenants/versions/3b45e0018bf1_add_new_available_tenant_table.py
================================================
"""add new available tenant table

Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910

"""
import sqlalchemy as sa

from alembic import op


# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create available_tenant table
    op.create_table(
        "available_tenant",
        sa.Column("tenant_id", sa.String(), nullable=False),
        sa.Column("alembic_version", sa.String(), nullable=False),
        sa.Column("date_created", sa.DateTime(), nullable=False),
        sa.PrimaryKeyConstraint("tenant_id"),
    )


def downgrade() -> None:
    # Drop available_tenant table
    op.drop_table("available_tenant")


================================================
FILE: backend/alembic_tenants/versions/3b9f09038764_add_read_only_kg_user.py
================================================
"""add_db_readonly_user

Revision ID: 3b9f09038764
Revises: 3b45e0018bf1
Create Date: 2025-05-11 11:05:11.436977

"""
from sqlalchemy import text

from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER


# revision identifiers, used by Alembic.
revision = "3b9f09038764"
down_revision = "3b45e0018bf1"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Enable pg_trgm extension if not already enabled
    op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")

    # Create the read-only db user if it does not already exist.
    if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
        raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")

    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                -- Check if the read-only user already exists
                IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    -- Create the read-only user with the specified password
                    EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
                    -- First revoke all privileges to ensure a clean slate
                    EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                    -- Grant only the CONNECT privilege to allow the user to connect to the database
                    -- but not perform any operations without additional specific grants
                    EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )


def downgrade() -> None:
    op.execute(
        text(
            f"""
            DO $$
            BEGIN
                IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
                    -- First revoke all privileges from the database
                    EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
                    -- Then revoke all privileges from the public schema
                    EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
                    -- Then drop the user
                    EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
                END IF;
            END
            $$;
            """
        )
    )
    op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))


================================================
FILE: backend/alembic_tenants/versions/a4f6ee863c47_mapping_for_anonymous_user_path.py
================================================
"""mapping for anonymous user path

Revision ID: a4f6ee863c47
Revises: 14a83a331951
Create Date: 2025-01-04 14:16:58.697451

"""
import sqlalchemy as sa

from alembic import op


# revision identifiers, used by Alembic.
revision = "a4f6ee863c47"
down_revision = "14a83a331951"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "tenant_anonymous_user_path",
        sa.Column("tenant_id", sa.String(), primary_key=True, nullable=False),
        sa.Column("anonymous_user_path", sa.String(), nullable=False),
        sa.PrimaryKeyConstraint("tenant_id"),
        sa.UniqueConstraint("anonymous_user_path"),
    )


def downgrade() -> None:
    op.drop_table("tenant_anonymous_user_path")


================================================
FILE: backend/alembic_tenants/versions/ac842f85f932_new_column_user_tenant_mapping.py
================================================
"""new column user tenant mapping

Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874

"""
import sqlalchemy as sa

from alembic import op


# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Add active column with default value of True
    op.add_column(
        "user_tenant_mapping",
        sa.Column(
            "active",
            sa.Boolean(),
            nullable=False,
            server_default="true",
        ),
        schema="public",
    )

    op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")

    # Create a unique index for active=true records
    # This ensures a user can only be active in one tenant at a time
    op.execute(
        "CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
    )


def downgrade() -> None:
    # Drop the unique index for active=true records
    op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")

    op.create_unique_constraint(
        "uq_email", "user_tenant_mapping", ["email"], schema="public"
    )

    # Remove the active column
    op.drop_column("user_tenant_mapping", "active", schema="public")


================================================
FILE: backend/assets/.gitignore
================================================
*
!.gitignore


================================================
FILE: backend/ee/LICENSE
================================================
The Onyx Enterprise License (the "Enterprise License")
Copyright (c) 2023-present DanswerAI, Inc.

With regard to the Onyx Software:

This software and associated documentation files (the "Software") may only be used in production, if you (and any entity that you represent) have agreed to, and are in compliance with, the Onyx Subscription Terms of Service, available at https://www.onyx.app/legal/self-host (the "Enterprise Terms"), or other agreement governing the use of the Software, as agreed by you and DanswerAI, and otherwise have a valid Onyx Enterprise License for the correct number of user seats. Subject to the foregoing sentence, you are free to modify this Software and publish patches to the Software. You agree that DanswerAI and/or its licensors (as applicable) retain all right, title and interest in and to all such modifications and/or patches, and all such modifications and/or patches may only be used, copied, modified, displayed, distributed, or otherwise exploited with a valid Onyx Enterprise License for the correct number of user seats. Notwithstanding the foregoing, you may copy and modify the Software for development and testing purposes, without requiring a subscription. You agree that DanswerAI and/or its licensors (as applicable) retain all right, title and interest in and to all such modifications. You are not granted any other rights beyond what is expressly stated herein. Subject to the foregoing, it is forbidden to copy, merge, publish, distribute, sublicense, and/or sell the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

For all third party components incorporated into the Onyx Software, those components are licensed under the original license provided by the owner of the applicable component.


================================================
FILE: backend/ee/__init__.py
================================================



================================================
FILE: backend/ee/onyx/__init__.py
================================================



================================================
FILE: backend/ee/onyx/access/access.py
================================================
from sqlalchemy.orm import Session

from ee.onyx.db.external_perm import fetch_external_groups_for_user
from ee.onyx.db.external_perm import fetch_public_external_group_ids
from ee.onyx.db.user_group import fetch_user_groups_for_documents
from ee.onyx.db.user_group import fetch_user_groups_for_user
from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config
from onyx.access.access import (
    _get_access_for_documents as get_access_for_documents_without_groups,
)
from onyx.access.access import _get_acl_for_user as get_acl_for_user_without_groups
from onyx.access.access import collect_user_file_access
from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_group
from onyx.db.document import get_document_sources
from onyx.db.document import get_documents_by_ids
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _get_access_for_document(
    document_id: str,
    db_session: Session,
) -> DocumentAccess:
    id_to_access = _get_access_for_documents([document_id], db_session)
    if len(id_to_access) == 0:
        return DocumentAccess.build(
            user_emails=[],
            user_groups=[],
            external_user_emails=[],
            external_user_group_ids=[],
            is_public=False,
        )

    return next(iter(id_to_access.values()))


def _get_access_for_documents(
    document_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    non_ee_access_dict = get_access_for_documents_without_groups(
        document_ids=document_ids,
        db_session=db_session,
    )
    user_group_info: dict[str, list[str]] = {
        document_id: group_names
        for document_id, group_names in fetch_user_groups_for_documents(
            db_session=db_session,
            document_ids=document_ids,
        )
    }
    documents = get_documents_by_ids(
        db_session=db_session,
        document_ids=document_ids,
    )
    doc_id_map = {doc.id: doc for doc in documents}

    # Get all sources in one batch
    doc_id_to_source_map = get_document_sources(
        db_session=db_session,
        document_ids=document_ids,
    )

    all_public_ext_u_group_ids = set(fetch_public_external_group_ids(db_session))

    access_map = {}
    for document_id, non_ee_access in non_ee_access_dict.items():
        document = doc_id_map[document_id]
        source = doc_id_to_source_map.get(document_id)
        if source is None:
            logger.error(f"Document {document_id} has no source")
            continue

        perm_sync_config = get_source_perm_sync_config(source)
        is_only_censored = (
            perm_sync_config
            and perm_sync_config.censoring_config is not None
            and perm_sync_config.doc_sync_config is None
        )

        ext_u_emails = (
            set(document.external_user_emails)
            if document.external_user_emails
            else set()
        )

        ext_u_groups = (
            set(document.external_user_group_ids)
            if document.external_user_group_ids
            else set()
        )

        # If the document is determined to be "public" externally (through a SYNC connector)
        # then it's given the same access level as if it were marked public within Onyx
        # If it's censored, then it's public anywhere during the search and then permissions are
        # applied after the search
        is_public_anywhere = (
            document.is_public
            or non_ee_access.is_public
            or is_only_censored
            or any(u_group in all_public_ext_u_group_ids for u_group in ext_u_groups)
        )

        # To avoid collisions of group namings between connectors, they need to be prefixed
        access_map[document_id] = DocumentAccess.build(
            user_emails=list(non_ee_access.user_emails),
            user_groups=user_group_info.get(document_id, []),
            is_public=is_public_anywhere,
            external_user_emails=list(ext_u_emails),
            external_user_group_ids=list(ext_u_groups),
        )
    return access_map


def _collect_user_file_group_names(user_file: UserFile) -> set[str]:
    """Extract user-group names from the already-loaded Persona.groups
    relationships on a UserFile (skipping deleted personas)."""
    groups: set[str] = set()
    for persona in user_file.assistants:
        if persona.deleted:
            continue
        for group in persona.groups:
            groups.add(group.name)
    return groups


def get_access_for_user_files_impl(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    """EE version: extends the MIT user file ACL with user group names
    from personas shared via user groups.

    Uses a single DB query (via fetch_user_files_with_access_relationships)
    that eagerly loads both the MIT-needed and EE-needed relationships.

    NOTE: this is imported in onyx.access.access by `fetch_versioned_implementation`.
    DO NOT REMOVE."""
    user_files = fetch_user_files_with_access_relationships(
        user_file_ids, db_session, eager_load_groups=True
    )
    return build_access_for_user_files_impl(user_files)


def build_access_for_user_files_impl(
    user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
    """EE version: works on pre-loaded UserFile objects.
    Expects Persona.groups to be eagerly loaded.
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation` DO NOT REMOVE.""" result: dict[str, DocumentAccess] = {} for user_file in user_files: if user_file.user is None: result[str(user_file.id)] = DocumentAccess.build( user_emails=[], user_groups=[], is_public=True, external_user_emails=[], external_user_group_ids=[], ) continue emails, is_public = collect_user_file_access(user_file) group_names = _collect_user_file_group_names(user_file) result[str(user_file.id)] = DocumentAccess.build( user_emails=list(emails), user_groups=list(group_names), is_public=is_public, external_user_emails=[], external_user_group_ids=[], ) return result def _get_acl_for_user(user: User, db_session: Session) -> set[str]: """Returns the set of ACL entries that the user has access to. This is meant to be used downstream to filter out documents that the user does not have access to. The user should have access to a document if at least one entry in the document's ACL matches one entry in the returned set.
NOTE: is imported in onyx.access.access by `fetch_versioned_implementation` DO NOT REMOVE.""" is_anonymous = user.is_anonymous db_user_groups = ( [] if is_anonymous else fetch_user_groups_for_user(db_session, user.id) ) prefixed_user_groups = [ prefix_user_group(db_user_group.name) for db_user_group in db_user_groups ] db_external_groups = ( [] if is_anonymous else fetch_external_groups_for_user(db_session, user.id) ) prefixed_external_groups = [ prefix_external_group(db_external_group.external_user_group_id) for db_external_group in db_external_groups ] user_acl = set(prefixed_user_groups + prefixed_external_groups) user_acl.update(get_acl_for_user_without_groups(user, db_session)) return user_acl ================================================ FILE: backend/ee/onyx/access/hierarchy_access.py ================================================ from sqlalchemy.orm import Session from ee.onyx.db.external_perm import fetch_external_groups_for_user from onyx.db.models import User def _get_user_external_group_ids(db_session: Session, user: User) -> list[str]: if not user: return [] external_groups = fetch_external_groups_for_user(db_session, user.id) return [external_group.external_user_group_id for external_group in external_groups] ================================================ FILE: backend/ee/onyx/auth/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/auth/users.py ================================================ import os from datetime import datetime import jwt from fastapi import Depends from fastapi import HTTPException from fastapi import Request from fastapi import status from ee.onyx.configs.app_configs import SUPER_CLOUD_API_KEY from ee.onyx.configs.app_configs import SUPER_USERS from ee.onyx.server.seeding import get_seed_config from onyx.auth.users import current_admin_user from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import USER_AUTH_SECRET 
from onyx.db.models import User from onyx.utils.logger import setup_logger logger = setup_logger() def verify_auth_setting() -> None: # All the Auth flows are valid for EE version, but warn about deprecated 'disabled' raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower() if raw_auth_type == "disabled": logger.warning( "AUTH_TYPE='disabled' is no longer supported. Using 'basic' instead. Please update your configuration." ) logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") def get_default_admin_user_emails_() -> list[str]: seed_config = get_seed_config() if seed_config and seed_config.admin_user_emails: return seed_config.admin_user_emails return [] async def current_cloud_superuser( request: Request, user: User = Depends(current_admin_user), ) -> User: api_key = request.headers.get("Authorization", "").replace("Bearer ", "") if api_key != SUPER_CLOUD_API_KEY: raise HTTPException(status_code=401, detail="Invalid API key") if user and user.email not in SUPER_USERS: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, detail="Access denied. 
User must be a cloud superuser to perform this action.", ) return user def generate_anonymous_user_jwt_token(tenant_id: str) -> str: payload = { "tenant_id": tenant_id, # Token does not expire "iat": datetime.utcnow(), # Issued at time } return jwt.encode(payload, USER_AUTH_SECRET, algorithm="HS256") def decode_anonymous_user_jwt_token(token: str) -> dict: return jwt.decode(token, USER_AUTH_SECRET, algorithms=["HS256"]) ================================================ FILE: backend/ee/onyx/background/celery/apps/heavy.py ================================================ from onyx.background.celery.apps import app_base from onyx.background.celery.apps.heavy import celery_app celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "ee.onyx.background.celery.tasks.doc_permission_syncing", "ee.onyx.background.celery.tasks.external_group_syncing", "ee.onyx.background.celery.tasks.cleanup", "ee.onyx.background.celery.tasks.query_history", ] ) ) ================================================ FILE: backend/ee/onyx/background/celery/apps/light.py ================================================ from onyx.background.celery.apps import app_base from onyx.background.celery.apps.light import celery_app celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "ee.onyx.background.celery.tasks.doc_permission_syncing", "ee.onyx.background.celery.tasks.external_group_syncing", ] ) ) ================================================ FILE: backend/ee/onyx/background/celery/apps/monitoring.py ================================================ from onyx.background.celery.apps import app_base from onyx.background.celery.apps.monitoring import celery_app celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "ee.onyx.background.celery.tasks.tenant_provisioning", ] ) ) ================================================ FILE: backend/ee/onyx/background/celery/apps/primary.py ================================================ from onyx.background.celery.apps import app_base 
from onyx.background.celery.apps.primary import celery_app celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "ee.onyx.background.celery.tasks.hooks", "ee.onyx.background.celery.tasks.doc_permission_syncing", "ee.onyx.background.celery.tasks.external_group_syncing", "ee.onyx.background.celery.tasks.cloud", "ee.onyx.background.celery.tasks.ttl_management", "ee.onyx.background.celery.tasks.usage_reporting", ] ) ) ================================================ FILE: backend/ee/onyx/background/celery/tasks/beat_schedule.py ================================================ from datetime import timedelta from typing import Any from ee.onyx.configs.app_configs import CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS from onyx.background.celery.tasks.beat_schedule import ( beat_cloud_tasks as base_beat_system_tasks, ) from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT from onyx.background.celery.tasks.beat_schedule import ( beat_task_templates as base_beat_task_templates, ) from onyx.background.celery.tasks.beat_schedule import generate_cloud_tasks from onyx.background.celery.tasks.beat_schedule import ( get_tasks_to_schedule as base_get_tasks_to_schedule, ) from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from shared_configs.configs import MULTI_TENANT ee_beat_system_tasks: list[dict] = [] ee_beat_task_templates: list[dict] = [ { "name": "autogenerate-usage-report", "task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK, "schedule": timedelta(days=30), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-ttl-management", "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK, "schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "export-query-history-cleanup-task", 
"task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK, "schedule": timedelta(hours=1), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.CSV_GENERATION, }, }, ] ee_tasks_to_schedule: list[dict] = [] if not MULTI_TENANT: ee_tasks_to_schedule = [ { "name": "hook-execution-log-cleanup", "task": OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK, "schedule": timedelta(days=1), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "autogenerate-usage-report", "task": OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK, "schedule": timedelta(days=30), # TODO: change this to config flag "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-ttl-management", "task": OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK, "schedule": timedelta(hours=CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "export-query-history-cleanup-task", "task": OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK, "schedule": timedelta(hours=1), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.CSV_GENERATION, }, }, ] def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]: beat_system_tasks = ee_beat_system_tasks + base_beat_system_tasks beat_task_templates = ee_beat_task_templates + base_beat_task_templates cloud_tasks = generate_cloud_tasks( beat_system_tasks, beat_task_templates, beat_multiplier ) return cloud_tasks def get_tasks_to_schedule() -> list[dict[str, Any]]: return ee_tasks_to_schedule + base_get_tasks_to_schedule() ================================================ FILE: backend/ee/onyx/background/celery/tasks/cleanup/__init__.py ================================================ ================================================ FILE: 
backend/ee/onyx/background/celery/tasks/cleanup/tasks.py ================================================ from datetime import datetime from datetime import timedelta from celery import shared_task from ee.onyx.db.query_history import get_all_query_history_export_tasks from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import OnyxCeleryTask from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.enums import TaskStatus from onyx.db.tasks import delete_task_with_id from onyx.utils.logger import setup_logger logger = setup_logger() @shared_task( name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_CLEANUP_TASK, ignore_result=True, soft_time_limit=JOB_TIMEOUT, ) def export_query_history_cleanup_task(*, tenant_id: str) -> None: with get_session_with_tenant(tenant_id=tenant_id) as db_session: tasks = get_all_query_history_export_tasks(db_session=db_session) for task in tasks: if task.status == TaskStatus.SUCCESS: delete_task_with_id(db_session=db_session, task_id=task.task_id) elif task.status == TaskStatus.FAILURE: if task.start_time: deadline = task.start_time + timedelta(hours=24) now = datetime.now() if now < deadline: continue logger.error( f"Task with {task.task_id=} failed; it is being deleted now" ) delete_task_with_id(db_session=db_session, task_id=task.task_id) ================================================ FILE: backend/ee/onyx/background/celery/tasks/cloud/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/background/celery/tasks/cloud/tasks.py ================================================ import time from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis.lock import Lock as RedisLock from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.tasks.beat_schedule import BEAT_EXPIRES_DEFAULT from onyx.configs.constants import 
CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import ONYX_CLOUD_TENANT_ID from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks from onyx.db.engine.tenant_utils import get_all_tenant_ids from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import redis_lock_dump from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST @shared_task( name=OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR, ignore_result=True, trail=False, bind=True, ) def cloud_beat_task_generator( self: Task, task_name: str, queue: str = OnyxCeleryTask.DEFAULT, priority: int = OnyxCeleryPriority.MEDIUM, expires: int = BEAT_EXPIRES_DEFAULT, ) -> bool | None: """a lightweight task used to kick off individual beat tasks per tenant.""" time_start = time.monotonic() redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID) lock_beat: RedisLock = redis_client.lock( f"{OnyxRedisLocks.CLOUD_BEAT_TASK_GENERATOR_LOCK}:{task_name}", timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): return None last_lock_time = time.monotonic() tenant_ids: list[str] = [] num_processed_tenants = 0 try: tenant_ids = get_all_tenant_ids() # NOTE: for now, we are running tasks for gated tenants, since we want to allow # connector deletion to run successfully. The new plan is to continuously prune # the gated tenants set, so we won't have a buildup of old, unused gated tenants. # Keeping this around in case we want to revert to the previous behavior.
# gated_tenants = get_gated_tenants() for tenant_id in tenant_ids: # Same comment here as the above NOTE # if tenant_id in gated_tenants: # continue current_time = time.monotonic() if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4): lock_beat.reacquire() last_lock_time = current_time # needed in the cloud if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST: continue self.app.send_task( task_name, kwargs=dict( tenant_id=tenant_id, ), queue=queue, priority=priority, expires=expires, ignore_result=True, ) num_processed_tenants += 1 except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) except Exception: task_logger.exception("Unexpected exception during cloud_beat_task_generator") finally: if not lock_beat.owned(): task_logger.error( "cloud_beat_task_generator - Lock not owned on completion" ) redis_lock_dump(lock_beat, redis_client) else: lock_beat.release() time_elapsed = time.monotonic() - time_start task_logger.info( f"cloud_beat_task_generator finished: " f"task={task_name} " f"num_processed_tenants={num_processed_tenants} " f"num_tenants={len(tenant_ids)} " f"elapsed={time_elapsed:.2f}" ) return True ================================================ FILE: backend/ee/onyx/background/celery/tasks/doc_permission_syncing/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/background/celery/tasks/doc_permission_syncing/tasks.py ================================================ import time from datetime import datetime from datetime import timedelta from datetime import timezone from time import sleep from typing import Any from typing import cast from uuid import uuid4 from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from pydantic import ValidationError from redis import Redis from redis.exceptions 
import LockError from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from tenacity import retry from tenacity import retry_if_exception from tenacity import stop_after_delay from tenacity import wait_random_exponential from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs from ee.onyx.db.document import upsert_document_external_perms from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config from onyx.access.models import DocExternalAccess from onyx.access.models import ElementExternalAccess from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_find_task from onyx.background.celery.celery_redis import celery_get_broker_client from onyx.background.celery.celery_redis import celery_get_queue_length from onyx.background.celery.celery_redis import celery_get_queued_task_ids from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from onyx.configs.constants import DocumentSource from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisSignals from onyx.connectors.factory import validate_ccpair_for_user from onyx.db.connector import mark_cc_pair_as_permissions_synced from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.document 
import get_document_ids_for_connector_credential_pair from onyx.db.document import get_documents_for_connector_credential_pair_limited_columns from onyx.db.document import upsert_document_by_connector_credential_pair from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import SyncStatus from onyx.db.enums import SyncType from onyx.db.hierarchy import ( update_hierarchy_node_permissions as db_update_hierarchy_node_permissions, ) from onyx.db.models import ConnectorCredentialPair from onyx.db.permission_sync_attempt import complete_doc_permission_sync_attempt from onyx.db.permission_sync_attempt import create_doc_permission_sync_attempt from onyx.db.permission_sync_attempt import mark_doc_permission_sync_attempt_failed from onyx.db.permission_sync_attempt import ( mark_doc_permission_sync_attempt_in_progress, ) from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.db.users import batch_add_ext_perm_user_if_not_exists from onyx.db.utils import DocumentRow from onyx.db.utils import is_retryable_sqlalchemy_error from onyx.db.utils import SortOrder from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.redis.redis_pool import redis_lock_dump from onyx.server.runtime.onyx_runtime import OnyxRuntime from onyx.server.utils import make_short_id from onyx.utils.logger import doc_permission_sync_ctx from onyx.utils.logger import format_error_for_logging from 
onyx.utils.logger import LoggerContextVars from onyx.utils.logger import setup_logger from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType from shared_configs.configs import MULTI_TENANT logger = setup_logger() DOCUMENT_PERMISSIONS_UPDATE_MAX_RETRIES = 3 DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER = 10 * 60 DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT = 60 # 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT LIGHT_SOFT_TIME_LIMIT = 105 LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15 def _get_fence_validation_block_expiration() -> int: """ Compute the expiration time for the fence validation block signal. Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode. """ base_expiration = 300 # seconds if not MULTI_TENANT: return base_expiration try: beat_multiplier = OnyxRuntime.get_beat_multiplier() except Exception: beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT return int(base_expiration * beat_multiplier) """Jobs / utils for kicking off doc permissions sync tasks.""" def _fail_doc_permission_sync_attempt(attempt_id: int, error_msg: str) -> None: """Helper to mark a doc permission sync attempt as failed with an error message.""" with get_session_with_current_tenant() as db_session: mark_doc_permission_sync_attempt_failed( attempt_id, db_session, error_message=error_msg ) def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> bool: """Returns boolean indicating if external doc permissions sync is due.""" if cc_pair.access_type != AccessType.SYNC: return False # skip doc permissions sync if not active if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: return False sync_config = get_source_perm_sync_config(cc_pair.connector.source) if sync_config is None: logger.error(f"No sync config found for {cc_pair.connector.source}") return False if sync_config.doc_sync_config is None: logger.error(f"No doc sync config found for {cc_pair.connector.source}") return False # if 
indexing also does perm sync, don't start running doc_sync until at # least one indexing is done if ( sync_config.doc_sync_config.initial_index_should_sync and cc_pair.last_successful_index_time is None ): return False # If the last sync is None, it has never been run so we run the sync last_perm_sync = cc_pair.last_time_perm_sync if last_perm_sync is None: return True source_sync_period = sync_config.doc_sync_config.doc_sync_frequency source_sync_period *= int(OnyxRuntime.get_doc_permission_sync_multiplier()) # If the last sync is greater than the full fetch period, we run the sync next_sync = last_perm_sync + timedelta(seconds=source_sync_period) if datetime.now(timezone.utc) >= next_sync: return True return False @shared_task( name=OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC, ignore_result=True, soft_time_limit=JOB_TIMEOUT, bind=True, ) def check_for_doc_permissions_sync(self: Task, *, tenant_id: str) -> bool | None: # TODO(rkuo): merge into check function after lookup table for fences is added # we need to use celery's redis client to access its redis data # (which lives on a different db number) r = get_redis_client() r_replica = get_redis_replica_client() lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): return None try: # get all cc pairs that need to be synced cc_pair_ids_to_sync: list[int] = [] with get_session_with_current_tenant() as db_session: cc_pairs = get_all_auto_sync_cc_pairs(db_session) for cc_pair in cc_pairs: if _is_external_doc_permissions_sync_due(cc_pair): cc_pair_ids_to_sync.append(cc_pair.id) lock_beat.reacquire() for cc_pair_id in cc_pair_ids_to_sync: payload_id = try_creating_permissions_sync_task( self.app, cc_pair_id, r, tenant_id ) if not payload_id: continue task_logger.info( f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}" ) # we want to run this less 
frequently than the overall task lock_beat.reacquire() if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES): # clear any permission fences that don't have associated celery tasks in progress # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), # or be currently executing try: r_celery = celery_get_broker_client(self.app) validate_permission_sync_fences( tenant_id, r, r_replica, r_celery, lock_beat ) except Exception: task_logger.exception( "Exception while validating permission sync fences" ) r.set( OnyxRedisSignals.BLOCK_VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=_get_fence_validation_block_expiration(), ) # use a lookup table to find active fences. We still have to verify the fence # exists since it is an optimization and not the source of truth. lock_beat.reacquire() keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) if not r.exists(key_bytes): r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes) continue key_str = key_bytes.decode("utf-8") if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX): with get_session_with_current_tenant() as db_session: monitor_ccpair_permissions_taskset( tenant_id, key_bytes, r, db_session ) task_logger.info(f"check_for_doc_permissions_sync finished: tenant={tenant_id}") except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id} {error_msg}" ) task_logger.exception( f"Unexpected check_for_doc_permissions_sync exception: tenant={tenant_id}" ) finally: if lock_beat.owned(): lock_beat.release() return True def try_creating_permissions_sync_task( app: Celery, cc_pair_id: int, r: Redis, tenant_id: str, ) -> str | None: """Returns a randomized payload id on success. 
Returns None if no syncing is required.""" LOCK_TIMEOUT = 30 payload_id: str | None = None redis_connector = RedisConnector(tenant_id, cc_pair_id) lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks", timeout=LOCK_TIMEOUT, ) acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2) if not acquired: return None try: if redis_connector.permissions.fenced: return None if redis_connector.delete.fenced: return None if redis_connector.prune.fenced: return None redis_connector.permissions.generator_clear() redis_connector.permissions.taskset_clear() custom_task_id = f"{redis_connector.permissions.generator_task_key}_{uuid4()}" # create before setting fence to avoid race condition where the monitoring # task updates the sync record before it is created try: with get_session_with_current_tenant() as db_session: insert_sync_record( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.EXTERNAL_PERMISSIONS, ) except Exception: task_logger.exception("insert_sync_record exceptioned.") # set a basic fence to start redis_connector.permissions.set_active() payload = RedisConnectorPermissionSyncPayload( id=make_short_id(), submitted=datetime.now(timezone.utc), started=None, celery_task_id=None, ) redis_connector.permissions.set_fence(payload) result = app.send_task( OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK, kwargs=dict( cc_pair_id=cc_pair_id, tenant_id=tenant_id, ), queue=OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ) # fill in the celery task id payload.celery_task_id = result.id redis_connector.permissions.set_fence(payload) payload_id = payload.id except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"Unexpected try_creating_permissions_sync_task exception: cc_pair={cc_pair_id} {error_msg}" ) return None finally: if lock.owned(): lock.release() task_logger.info( f"try_creating_permissions_sync_task finished: 
cc_pair={cc_pair_id} payload_id={payload_id}" ) return payload_id @shared_task( name=OnyxCeleryTask.CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK, acks_late=False, soft_time_limit=JOB_TIMEOUT, track_started=True, trail=False, bind=True, ) def connector_permission_sync_generator_task( self: Task, cc_pair_id: int, tenant_id: str, ) -> None: """ Permission sync task that handles document permission syncing for a given connector credential pair This task assumes that the task has already been properly fenced """ payload_id: str | None = None LoggerContextVars.reset() doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get() doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id doc_permission_sync_ctx_dict["request_id"] = self.request.id doc_permission_sync_ctx.set(doc_permission_sync_ctx_dict) with get_session_with_current_tenant() as db_session: attempt_id = create_doc_permission_sync_attempt( connector_credential_pair_id=cc_pair_id, db_session=db_session, ) task_logger.info( f"Created doc permission sync attempt: {attempt_id} for cc_pair={cc_pair_id}" ) redis_connector = RedisConnector(tenant_id, cc_pair_id) r = get_redis_client() # this wait is needed to avoid a race condition where # the primary worker sends the task and it is immediately executed # before the primary worker can finalize the fence start = time.monotonic() while True: if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT: error_msg = ( f"connector_permission_sync_generator_task - timed out waiting for fence to be ready: " f"fence={redis_connector.permissions.fence_key}" ) _fail_doc_permission_sync_attempt(attempt_id, error_msg) raise ValueError(error_msg) if not redis_connector.permissions.fenced: # The fence must exist error_msg = f"connector_permission_sync_generator_task - fence not found: fence={redis_connector.permissions.fence_key}" _fail_doc_permission_sync_attempt(attempt_id, error_msg) raise ValueError(error_msg) payload = redis_connector.permissions.payload # The payload must 
exist if not payload: error_msg = ( "connector_permission_sync_generator_task: payload invalid or not found" ) _fail_doc_permission_sync_attempt(attempt_id, error_msg) raise ValueError(error_msg) if payload.celery_task_id is None: logger.info( f"connector_permission_sync_generator_task - Waiting for fence: fence={redis_connector.permissions.fence_key}" ) sleep(1) continue payload_id = payload.id logger.info( f"connector_permission_sync_generator_task - Fence found, continuing...: " f"fence={redis_connector.permissions.fence_key} " f"payload_id={payload.id}" ) break lock: RedisLock = r.lock( OnyxRedisLocks.CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}", timeout=CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT, thread_local=False, ) acquired = lock.acquire(blocking=False) if not acquired: error_msg = ( f"Permission sync task already running, exiting...: cc_pair={cc_pair_id}" ) task_logger.warning(error_msg) _fail_doc_permission_sync_attempt(attempt_id, error_msg) return None try: with get_session_with_current_tenant() as db_session: cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, eager_load_connector=True, eager_load_credential=True, ) if cc_pair is None: raise ValueError( f"No connector credential pair found for id: {cc_pair_id}" ) try: created = validate_ccpair_for_user( cc_pair.connector.id, cc_pair.credential.id, cc_pair.access_type, db_session, enforce_creation=False, ) if not created: task_logger.warning( f"Unable to create connector credential pair for id: {cc_pair_id}" ) except Exception: task_logger.exception( f"validate_ccpair_permissions_sync exceptioned: cc_pair={cc_pair_id}" ) # TODO: add some notification to the admins here raise source_type = cc_pair.connector.source sync_config = get_source_perm_sync_config(source_type) if sync_config is None: error_msg = f"No sync config found for {source_type}" logger.error(error_msg) _fail_doc_permission_sync_attempt(attempt_id, error_msg) return 
None if sync_config.doc_sync_config is None: if sync_config.censoring_config: error_msg = f"Doc sync config is None but censoring config exists for {source_type}" _fail_doc_permission_sync_attempt(attempt_id, error_msg) return None raise ValueError( f"No doc sync func found for {source_type} with cc_pair={cc_pair_id}" ) logger.info(f"Syncing docs for {source_type} with cc_pair={cc_pair_id}") mark_doc_permission_sync_attempt_in_progress(attempt_id, db_session) payload = redis_connector.permissions.payload if not payload: raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}") new_payload = RedisConnectorPermissionSyncPayload( id=payload.id, submitted=payload.submitted, started=datetime.now(timezone.utc), celery_task_id=payload.celery_task_id, ) redis_connector.permissions.set_fence(new_payload) callback = PermissionSyncCallback( redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT ) # pass in the capability to fetch all existing docs for the cc_pair # this can be used to determine documents that are "missing" and thus # should no longer be accessible. The decision as to whether we should find # every document during the doc sync process is connector-specific. def fetch_all_existing_docs_fn( sort_order: SortOrder | None = None, ) -> list[DocumentRow]: result = get_documents_for_connector_credential_pair_limited_columns( db_session=db_session, connector_id=cc_pair.connector.id, credential_id=cc_pair.credential.id, sort_order=sort_order, ) return list(result) def fetch_all_existing_docs_ids_fn() -> list[str]: result = get_document_ids_for_connector_credential_pair( db_session=db_session, connector_id=cc_pair.connector.id, credential_id=cc_pair.credential.id, ) return result doc_sync_func = sync_config.doc_sync_config.doc_sync_func document_external_accesses = doc_sync_func( cc_pair, fetch_all_existing_docs_fn, fetch_all_existing_docs_ids_fn, callback, ) task_logger.info( f"RedisConnector.permissions.generate_tasks starting.
cc_pair={cc_pair_id}" ) tasks_generated = 0 docs_with_errors = 0 for doc_external_access in document_external_accesses: if callback.should_stop(): raise RuntimeError( f"Permission sync task timed out or stop signal detected: " f"cc_pair={cc_pair_id} " f"tasks_generated={tasks_generated}" ) result = redis_connector.permissions.update_db( lock=lock, new_permissions=[doc_external_access], source_string=source_type, connector_id=cc_pair.connector.id, credential_id=cc_pair.credential.id, task_logger=task_logger, ) tasks_generated += result.num_updated docs_with_errors += result.num_errors task_logger.info( f"RedisConnector.permissions.generate_tasks finished. " f"cc_pair={cc_pair_id} tasks_generated={tasks_generated} docs_with_errors={docs_with_errors}" ) complete_doc_permission_sync_attempt( db_session=db_session, attempt_id=attempt_id, total_docs_synced=tasks_generated, docs_with_permission_errors=docs_with_errors, ) task_logger.info( f"Completed doc permission sync attempt {attempt_id}: {tasks_generated} docs, {docs_with_errors} errors" ) redis_connector.permissions.generator_complete = tasks_generated except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id} {error_msg}" ) task_logger.exception( f"Permission sync exceptioned: cc_pair={cc_pair_id} payload_id={payload_id}" ) with get_session_with_current_tenant() as db_session: mark_doc_permission_sync_attempt_failed( attempt_id, db_session, error_message=error_msg ) redis_connector.permissions.generator_clear() redis_connector.permissions.taskset_clear() redis_connector.permissions.set_fence(None) raise e finally: if lock.owned(): lock.release() # use payload_id (captured before the try block) because payload is re-read inside the try and may be None there task_logger.info( f"Permission sync finished: cc_pair={cc_pair_id} payload_id={payload_id}" ) # NOTE(rkuo): this should probably move to the db layer @retry( retry=retry_if_exception(is_retryable_sqlalchemy_error), wait=wait_random_exponential( multiplier=1,
max=DOCUMENT_PERMISSIONS_UPDATE_MAX_WAIT ), stop=stop_after_delay(DOCUMENT_PERMISSIONS_UPDATE_STOP_AFTER), ) def element_update_permissions( tenant_id: str, permissions: ElementExternalAccess, source_type_str: str, connector_id: int, credential_id: int, ) -> bool: """Update permissions for a document or hierarchy node.""" start = time.monotonic() external_access = permissions.external_access # Determine element type and identifier for logging if isinstance(permissions, DocExternalAccess): element_id = permissions.doc_id element_type = "doc" else: element_id = permissions.raw_node_id element_type = "node" try: with get_session_with_tenant(tenant_id=tenant_id) as db_session: # Add the users to the DB if they don't exist batch_add_ext_perm_user_if_not_exists( db_session=db_session, emails=list(external_access.external_user_emails), continue_on_error=True, ) if isinstance(permissions, DocExternalAccess): # Document permission update created_new_doc = upsert_document_external_perms( db_session=db_session, doc_id=permissions.doc_id, external_access=external_access, source_type=DocumentSource(source_type_str), ) if created_new_doc: # If a new document was created, we associate it with the cc_pair upsert_document_by_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, document_ids=[permissions.doc_id], ) else: # Hierarchy node permission update db_update_hierarchy_node_permissions( db_session=db_session, raw_node_id=permissions.raw_node_id, source=DocumentSource(permissions.source), is_public=external_access.is_public, external_user_emails=( list(external_access.external_user_emails) if external_access.external_user_emails else None ), external_user_group_ids=( list(external_access.external_user_group_ids) if external_access.external_user_group_ids else None ), ) elapsed = time.monotonic() - start task_logger.info( f"{element_type}={element_id} action=update_permissions elapsed={elapsed:.2f}" ) except Exception as e: 
task_logger.exception( f"element_update_permissions exceptioned: {element_type}={element_id}, {connector_id=} {credential_id=}" ) raise e finally: task_logger.info( f"element_update_permissions completed: {element_type}={element_id}, {connector_id=} {credential_id=}" ) return True def validate_permission_sync_fences( tenant_id: str, r: Redis, r_replica: Redis, r_celery: Redis, lock_beat: RedisLock, ) -> None: # building lookup table can be expensive, so we won't bother # validating until the queue is small PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024 queue_len = celery_get_queue_length( OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery ) if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN: return queued_upsert_tasks = celery_get_queued_task_ids( OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery ) reserved_generator_tasks = celery_get_unacked_task_ids( OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery ) # validate all existing permission sync jobs lock_beat.reacquire() keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) key_str = key_bytes.decode("utf-8") if not key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX): continue validate_permission_sync_fence( tenant_id, key_bytes, queued_upsert_tasks, reserved_generator_tasks, r, r_celery, ) lock_beat.reacquire() return def validate_permission_sync_fence( tenant_id: str, key_bytes: bytes, queued_tasks: set[str], reserved_tasks: set[str], r: Redis, r_celery: Redis, ) -> None: """Checks for the error condition where a permission sync fence is set but the associated celery tasks don't exist. This can happen if the worker running the permission sync hard crashes or is terminated. Being in this bad state means the fence will never clear without help, so this function provides that help. How this works: 1. This function renews the active signal with a 5 minute TTL under the following conditions 1.1. When the task is seen in the redis queue 1.2.
When the task is seen in the reserved / prefetched list 2. Externally, the active signal is renewed when: 2.1. The fence is created 2.2. The generator task reports progress through PermissionSyncCallback.progress(). 3. The TTL allows us to get through the transitions on fence startup and when the task starts executing. More TTL clarification: it is seemingly impossible to exactly query Celery for whether a task is in the queue or currently executing. 1. An unknown task id is always returned as state PENDING. 2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task and the time it actually starts on the worker. queued_tasks: the celery queue of lightweight permission sync tasks reserved_tasks: prefetched tasks for sync task generator """ # if the fence doesn't exist, there's nothing to do fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) if cc_pair_id_str is None: task_logger.warning( f"validate_permission_sync_fence - could not parse id from {fence_key}" ) return cc_pair_id = int(cc_pair_id_str) # parse out metadata and initialize the helper class with it redis_connector = RedisConnector(tenant_id, int(cc_pair_id)) # check to see if the fence/payload exists if not redis_connector.permissions.fenced: return # in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens # TODO: add intentional cleanup/abort logic try: payload = redis_connector.permissions.payload except ValidationError: task_logger.exception( "validate_permission_sync_fence - " "Resetting fence because fence schema is out of date: " f"cc_pair={cc_pair_id} " f"fence={fence_key}" ) redis_connector.permissions.reset() return if not payload: return if not payload.celery_task_id: return # OK, there's actually something for us to validate # either the generator task or its subtasks must be in flight found = celery_find_task( payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery, ) if found: # the celery task exists in the redis queue redis_connector.permissions.set_active() return if payload.celery_task_id in reserved_tasks: # the celery task was prefetched and is reserved within a worker redis_connector.permissions.set_active() return # look up every task in the current taskset in the celery queue # every entry in the taskset should have an associated entry in the celery task queue # because we get the celery tasks first, the entries in our own permissions taskset # should be roughly a subset of the tasks in celery # this check isn't very exact, but should be sufficient over a period of time # A single successful check over some number of attempts is sufficient. # TODO: if the number of tasks in celery is much lower than the taskset length # we might be able to shortcut the lookup since by definition some of the tasks # must not exist in celery.
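# Illustrative sketch only (not part of Onyx): the taskset check that follows,
# reduced to a pure function over plain Python sets so the rule is easy to see.
# The name _sketch_taskset_fully_in_celery is hypothetical; it assumes all ids
# are already decoded to str, as the loop below does with member_bytes.
def _sketch_taskset_fully_in_celery(
    taskset: set[str], queued: set[str], reserved: set[str]
) -> bool:
    # A taskset member is accounted for if it still waits in the broker queue
    # or has been prefetched (reserved) by a worker.
    not_in_celery = [t for t in taskset if t not in queued and t not in reserved]
    # An empty taskset proves nothing, so only a non-empty, fully accounted-for
    # taskset counts as evidence that the sync is still alive.
    return len(taskset) > 0 and len(not_in_celery) == 0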
tasks_scanned = 0 tasks_not_in_celery = 0 # a non-zero number after completing our check is bad for member in r.sscan_iter(redis_connector.permissions.taskset_key): tasks_scanned += 1 member_bytes = cast(bytes, member) member_str = member_bytes.decode("utf-8") if member_str in queued_tasks: continue if member_str in reserved_tasks: continue tasks_not_in_celery += 1 task_logger.info( f"validate_permission_sync_fence task check: tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}" ) # we're active if there are still tasks to run and those tasks all exist in celery if tasks_scanned > 0 and tasks_not_in_celery == 0: redis_connector.permissions.set_active() return # we may want to enable this check if using the active task list somehow isn't good enough # if redis_connector_index.generator_locked(): # logger.info(f"{payload.celery_task_id} is currently executing.") # if we get here, we didn't find any direct indication that the associated celery tasks exist, # but they still might be there due to gaps in our ability to check states during transitions # Checking the active signal safeguards us against these transition periods # (which has a duration that allows us to bridge those gaps) if redis_connector.permissions.active(): return # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up. 
task_logger.warning( "validate_permission_sync_fence - " "Resetting fence because no associated celery tasks were found: " f"cc_pair={cc_pair_id} " f"fence={fence_key} " f"payload_id={payload.id}" ) redis_connector.permissions.reset() return class PermissionSyncCallback(IndexingHeartbeatInterface): PARENT_CHECK_INTERVAL = 60 def __init__( self, redis_connector: RedisConnector, redis_lock: RedisLock, redis_client: Redis, timeout_seconds: int | None = None, ): super().__init__() self.redis_connector: RedisConnector = redis_connector self.redis_lock: RedisLock = redis_lock self.redis_client = redis_client self.started: datetime = datetime.now(timezone.utc) self.redis_lock.reacquire() self.last_tag: str = "PermissionSyncCallback.__init__" self.last_lock_reacquire: datetime = datetime.now(timezone.utc) self.last_lock_monotonic = time.monotonic() self.start_monotonic = time.monotonic() self.timeout_seconds = timeout_seconds def should_stop(self) -> bool: if self.redis_connector.stop.fenced: return True # Check if the task has exceeded its timeout # NOTE: Celery's soft_time_limit does not work with thread pools, # so we must enforce timeouts internally. 
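# Illustrative sketch only (not part of Onyx): a cooperative deadline measured
# on the monotonic clock, the same pattern should_stop() applies to
# timeout_seconds, needed because Celery's soft_time_limit does not work with
# thread pools. _SketchDeadline and its injectable clock parameter are
# hypothetical; the clock argument exists only so the check can be tested
# deterministically.
import time
from typing import Callable


class _SketchDeadline:
    def __init__(
        self,
        timeout_seconds: float | None,
        clock: Callable[[], float] = time.monotonic,
    ) -> None:
        self._clock = clock
        self._start = clock()  # monotonic time is unaffected by wall-clock changes
        self.timeout_seconds = timeout_seconds

    def expired(self) -> bool:
        # None means "no timeout", like timeout_seconds=None in the callback.
        if self.timeout_seconds is None:
            return False
        return self._clock() - self._start > self.timeout_seconds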
if self.timeout_seconds is not None: elapsed = time.monotonic() - self.start_monotonic if elapsed > self.timeout_seconds: logger.warning( f"PermissionSyncCallback - task timeout exceeded: " f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s " f"cc_pair={self.redis_connector.cc_pair_id}" ) return True return False def progress(self, tag: str, amount: int) -> None: # noqa: ARG002 try: self.redis_connector.permissions.set_active() current_time = time.monotonic() if current_time - self.last_lock_monotonic >= ( CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4 ): self.redis_lock.reacquire() self.last_lock_reacquire = datetime.now(timezone.utc) self.last_lock_monotonic = time.monotonic() self.last_tag = tag except LockError: logger.exception( f"PermissionSyncCallback - lock.reacquire exceptioned: " f"lock_timeout={self.redis_lock.timeout} " f"start={self.started} " f"last_tag={self.last_tag} " f"last_reacquired={self.last_lock_reacquire} " f"now={datetime.now(timezone.utc)}" ) redis_lock_dump(self.redis_lock, self.redis_client) raise """Monitoring CCPair permissions utils""" def monitor_ccpair_permissions_taskset( tenant_id: str, key_bytes: bytes, r: Redis, # noqa: ARG001 db_session: Session, ) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) if cc_pair_id_str is None: task_logger.warning( f"monitor_ccpair_permissions_taskset: could not parse cc_pair_id from {fence_key}" ) return cc_pair_id = int(cc_pair_id_str) redis_connector = RedisConnector(tenant_id, cc_pair_id) if not redis_connector.permissions.fenced: return initial = redis_connector.permissions.generator_complete if initial is None: return try: payload = redis_connector.permissions.payload except ValidationError: task_logger.exception( "Permissions sync payload failed to validate. Schema may have been updated." 
) return if not payload: return remaining = redis_connector.permissions.get_remaining() task_logger.info( f"Permissions sync progress: cc_pair={cc_pair_id} id={payload.id} remaining={remaining} initial={initial}" ) # Add telemetry for permission syncing progress optional_telemetry( record_type=RecordType.PERMISSION_SYNC_PROGRESS, data={ "cc_pair_id": cc_pair_id, "total_docs_synced": initial if initial is not None else 0, "remaining_docs_to_sync": remaining, }, tenant_id=tenant_id, ) if remaining > 0: return mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started) task_logger.info( f"Permissions sync finished: cc_pair={cc_pair_id} id={payload.id} num_synced={initial}" ) # Add telemetry for permission syncing complete optional_telemetry( record_type=RecordType.PERMISSION_SYNC_COMPLETE, data={"cc_pair_id": cc_pair_id}, tenant_id=tenant_id, ) update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.EXTERNAL_PERMISSIONS, sync_status=SyncStatus.SUCCESS, num_docs_synced=initial, ) redis_connector.permissions.reset() ================================================ FILE: backend/ee/onyx/background/celery/tasks/external_group_syncing/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/background/celery/tasks/external_group_syncing/group_sync_utils.py ================================================ from sqlalchemy.orm import Session from ee.onyx.external_permissions.sync_params import ( source_group_sync_is_cc_pair_agnostic, ) from onyx.db.connector import mark_cc_pair_as_external_group_synced from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_source from onyx.db.models import ConnectorCredentialPair def _get_all_cc_pair_ids_to_mark_as_group_synced( db_session: Session, cc_pair: ConnectorCredentialPair ) -> list[int]: if not source_group_sync_is_cc_pair_agnostic(cc_pair.connector.source): return [cc_pair.id] 
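# Illustrative sketch only (not part of Onyx): the selection rule implemented
# by _get_all_cc_pair_ids_to_mark_as_group_synced, written over plain data.
# For a source whose group sync is cc_pair-agnostic, one successful run marks
# every cc_pair of that source; otherwise only the synced cc_pair is marked.
# The function name, the agnostic_sources set and the ids_by_source mapping
# are hypothetical stand-ins for source_group_sync_is_cc_pair_agnostic() and
# get_connector_credential_pairs_for_source().
def _sketch_cc_pair_ids_to_mark(
    cc_pair_id: int,
    source: str,
    agnostic_sources: set[str],
    ids_by_source: dict[str, list[int]],
) -> list[int]:
    if source not in agnostic_sources:
        return [cc_pair_id]
    return list(ids_by_source.get(source, []))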
cc_pairs = get_connector_credential_pairs_for_source( db_session, cc_pair.connector.source ) return [cc_pair.id for cc_pair in cc_pairs] def mark_all_relevant_cc_pairs_as_external_group_synced( db_session: Session, cc_pair: ConnectorCredentialPair ) -> None: """For some source types, one successful group sync run should count for all cc pairs of that type. This function handles that case.""" cc_pair_ids = _get_all_cc_pair_ids_to_mark_as_group_synced(db_session, cc_pair) for cc_pair_id in cc_pair_ids: mark_cc_pair_as_external_group_synced(db_session, cc_pair_id) ================================================ FILE: backend/ee/onyx/background/celery/tasks/external_group_syncing/tasks.py ================================================ import time from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from typing import cast from uuid import uuid4 from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from pydantic import ValidationError from redis import Redis from redis.lock import Lock as RedisLock from ee.onyx.background.celery.tasks.external_group_syncing.group_sync_utils import ( mark_all_relevant_cc_pairs_as_external_group_synced, ) from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.db.external_perm import mark_old_external_groups_as_stale from ee.onyx.db.external_perm import remove_stale_external_groups from ee.onyx.db.external_perm import upsert_external_groups from ee.onyx.external_permissions.sync_params import ( get_all_cc_pair_agnostic_group_sync_sources, ) from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import 
celery_find_task from onyx.background.celery.celery_redis import celery_get_broker_client from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT from onyx.background.error_logging import emit_background_error from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisSignals from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import SyncStatus from onyx.db.enums import SyncType from onyx.db.models import ConnectorCredentialPair from onyx.db.permission_sync_attempt import complete_external_group_sync_attempt from onyx.db.permission_sync_attempt import ( create_external_group_sync_attempt, ) from onyx.db.permission_sync_attempt import ( mark_external_group_sync_attempt_failed, ) from onyx.db.permission_sync_attempt import ( mark_external_group_sync_attempt_in_progress, ) from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from onyx.redis.redis_connector_ext_group_sync import ( RedisConnectorExternalGroupSyncPayload, ) from onyx.redis.redis_pool import 
get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.server.runtime.onyx_runtime import OnyxRuntime from onyx.server.utils import make_short_id from onyx.utils.logger import format_error_for_logging from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() _EXTERNAL_GROUP_BATCH_SIZE = 100 def _fail_external_group_sync_attempt(attempt_id: int, error_msg: str) -> None: """Helper to mark an external group sync attempt as failed with an error message.""" with get_session_with_current_tenant() as db_session: mark_external_group_sync_attempt_failed( attempt_id, db_session, error_message=error_msg ) def _get_fence_validation_block_expiration() -> int: """ Compute the expiration time for the fence validation block signal. Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode. """ base_expiration = 300 # seconds if not MULTI_TENANT: return base_expiration try: beat_multiplier = OnyxRuntime.get_beat_multiplier() except Exception: beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT return int(base_expiration * beat_multiplier) def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool: """Returns boolean indicating if external group sync is due.""" if cc_pair.access_type != AccessType.SYNC: task_logger.error( f"Received non-sync CC Pair {cc_pair.id} for external group sync. 
Actual access type: {cc_pair.access_type}" ) return False if cc_pair.status == ConnectorCredentialPairStatus.DELETING: task_logger.debug( f"Skipping group sync for CC Pair {cc_pair.id} - CC Pair is being deleted" ) return False sync_config = get_source_perm_sync_config(cc_pair.connector.source) if sync_config is None: task_logger.debug( f"Skipping group sync for CC Pair {cc_pair.id} - no sync config found for {cc_pair.connector.source}" ) return False # If there is no group sync function for the connector, we don't run the sync # This is fine because not all sources have a concept of groups if sync_config.group_sync_config is None: task_logger.debug( f"Skipping group sync for CC Pair {cc_pair.id} - no group sync config found for {cc_pair.connector.source}" ) return False # If the last sync is None, it has never been run so we run the sync last_ext_group_sync = cc_pair.last_time_external_group_sync if last_ext_group_sync is None: return True source_sync_period = sync_config.group_sync_config.group_sync_frequency # If at least one group sync period has elapsed since the last sync, we run the sync next_sync = last_ext_group_sync + timedelta(seconds=source_sync_period) if datetime.now(timezone.utc) >= next_sync: return True return False @shared_task( name=OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC, ignore_result=True, soft_time_limit=JOB_TIMEOUT, bind=True, ) def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None: # we need to use celery's redis client to access its redis data # (which lives on a different db number) r = get_redis_client() r_replica = get_redis_replica_client() lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): task_logger.warning( f"Failed to acquire beat lock for external group sync: {tenant_id}" ) return None try: cc_pair_ids_to_sync: list[int] = [] with
get_session_with_current_tenant() as db_session: cc_pairs = get_all_auto_sync_cc_pairs(db_session) # For some sources, we only want to sync one cc_pair per source type for source in get_all_cc_pair_agnostic_group_sync_sources(): # These are ordered by cc_pair id so the first one is the one we want cc_pairs_to_dedupe = get_cc_pairs_by_source( db_session, source, access_type=AccessType.SYNC, status=ConnectorCredentialPairStatus.ACTIVE, ) # dedupe cc_pairs to only keep the first one for cc_pair_to_remove in cc_pairs_to_dedupe[1:]: cc_pairs = [ cc_pair for cc_pair in cc_pairs if cc_pair.id != cc_pair_to_remove.id ] for cc_pair in cc_pairs: if _is_external_group_sync_due(cc_pair): cc_pair_ids_to_sync.append(cc_pair.id) lock_beat.reacquire() for cc_pair_id in cc_pair_ids_to_sync: payload_id = try_creating_external_group_sync_task( self.app, cc_pair_id, r, tenant_id ) if not payload_id: continue task_logger.info( f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}" ) # we want to run this less frequently than the overall task lock_beat.reacquire() if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES): # clear fences that don't have associated celery tasks in progress # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker), # or be currently executing try: r_celery = celery_get_broker_client(self.app) validate_external_group_sync_fences( tenant_id, self.app, r, r_replica, r_celery, lock_beat ) except Exception: task_logger.exception( "Exception while validating external group sync fences" ) r.set( OnyxRedisSignals.BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=_get_fence_validation_block_expiration(), ) except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." 
) except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"Unexpected check_for_external_group_sync exception: tenant={tenant_id} {error_msg}" ) task_logger.exception(f"Unexpected exception: tenant={tenant_id}") finally: if lock_beat.owned(): lock_beat.release() task_logger.info(f"check_for_external_group_sync finished: tenant={tenant_id}") return True def try_creating_external_group_sync_task( app: Celery, cc_pair_id: int, r: Redis, # noqa: ARG001 tenant_id: str, ) -> str | None: """Returns the payload id if an external group sync task was queued. Returns None if no task was queued (a sync is already running or an error occurred).""" payload_id: str | None = None redis_connector = RedisConnector(tenant_id, cc_pair_id) try: # Don't kick off a new sync if the previous one is still running if redis_connector.external_group_sync.fenced: logger.warning( f"Skipping external group sync for CC Pair {cc_pair_id} - already running." ) return None redis_connector.external_group_sync.generator_clear() redis_connector.external_group_sync.taskset_clear() # create before setting fence to avoid race condition where the monitoring # task updates the sync record before it is created try: with get_session_with_current_tenant() as db_session: insert_sync_record( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.EXTERNAL_GROUP, ) except Exception: task_logger.exception("insert_sync_record exceptioned.") # Signal active before creating fence redis_connector.external_group_sync.set_active() payload = RedisConnectorExternalGroupSyncPayload( id=make_short_id(), submitted=datetime.now(timezone.utc), started=None, celery_task_id=None, ) redis_connector.external_group_sync.set_fence(payload) custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}" result = app.send_task( OnyxCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK, kwargs=dict( cc_pair_id=cc_pair_id, tenant_id=tenant_id, ), queue=OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ) payload.celery_task_id = result.id redis_connector.external_group_sync.set_fence(payload) payload_id = payload.id except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"Unexpected try_creating_external_group_sync_task exception: cc_pair={cc_pair_id} {error_msg}" ) task_logger.exception( f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}" ) return None task_logger.info( f"try_creating_external_group_sync_task finished: cc_pair={cc_pair_id} payload_id={payload_id}" ) return payload_id @shared_task( name=OnyxCeleryTask.CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK, acks_late=False, soft_time_limit=JOB_TIMEOUT, track_started=True, trail=False, bind=True, ) def connector_external_group_sync_generator_task( self: Task, # noqa: ARG001 cc_pair_id: int, tenant_id: str, ) -> None: """ External group sync task for a given connector credential pair This task assumes that the task has already been properly fenced """ redis_connector = RedisConnector(tenant_id, cc_pair_id) r = get_redis_client() # this wait is needed to avoid a race condition where # the primary worker sends the task and it is immediately executed # before the primary worker can finalize the fence start = time.monotonic() while True: if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT: msg = ( f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: " f"fence={redis_connector.external_group_sync.fence_key}" ) emit_background_error(msg, cc_pair_id=cc_pair_id) raise ValueError(msg) if not redis_connector.external_group_sync.fenced: # The fence must exist msg = ( f"connector_external_group_sync_generator_task - fence not found: " f"fence={redis_connector.external_group_sync.fence_key}" ) emit_background_error(msg, cc_pair_id=cc_pair_id) raise ValueError(msg) payload = redis_connector.external_group_sync.payload # The payload must exist if 
not payload: msg = "connector_external_group_sync_generator_task: payload invalid or not found" emit_background_error(msg, cc_pair_id=cc_pair_id) raise ValueError(msg) if payload.celery_task_id is None: logger.info( f"connector_external_group_sync_generator_task - Waiting for fence: " f"fence={redis_connector.external_group_sync.fence_key}" ) time.sleep(1) continue logger.info( f"connector_external_group_sync_generator_task - Fence found, continuing...: " f"fence={redis_connector.external_group_sync.fence_key} " f"payload_id={payload.id}" ) break lock: RedisLock = r.lock( OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}", timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT, ) acquired = lock.acquire(blocking=False) if not acquired: msg = f"External group sync task already running, exiting...: cc_pair={cc_pair_id}" emit_background_error(msg, cc_pair_id=cc_pair_id) task_logger.error(msg) return None try: payload.started = datetime.now(timezone.utc) redis_connector.external_group_sync.set_fence(payload) _perform_external_group_sync( cc_pair_id=cc_pair_id, tenant_id=tenant_id, ) with get_session_with_current_tenant() as db_session: update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.EXTERNAL_GROUP, sync_status=SyncStatus.SUCCESS, ) except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id} {error_msg}" ) msg = f"External group sync exceptioned: cc_pair={cc_pair_id} payload_id={payload.id}" task_logger.exception(msg) emit_background_error(msg + f"\n\n{e}", cc_pair_id=cc_pair_id) with get_session_with_current_tenant() as db_session: update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.EXTERNAL_GROUP, sync_status=SyncStatus.FAILED, )
        redis_connector.external_group_sync.generator_clear()
        redis_connector.external_group_sync.taskset_clear()
        raise e
    finally:
        # we always want to clear the fence after the task is done or failed so it doesn't get stuck
        redis_connector.external_group_sync.set_fence(None)
        if lock.owned():
            lock.release()

    task_logger.info(
        f"External group sync finished: cc_pair={cc_pair_id} payload_id={payload.id}"
    )


def _perform_external_group_sync(
    cc_pair_id: int,
    tenant_id: str,
    timeout_seconds: int = JOB_TIMEOUT,
) -> None:
    # Create attempt record at the start
    with get_session_with_current_tenant() as db_session:
        attempt_id = create_external_group_sync_attempt(
            connector_credential_pair_id=cc_pair_id,
            db_session=db_session,
        )
        logger.info(
            f"Created external group sync attempt: {attempt_id} for cc_pair={cc_pair_id}"
        )

    with get_session_with_current_tenant() as db_session:
        cc_pair = get_connector_credential_pair_from_id(
            db_session=db_session,
            cc_pair_id=cc_pair_id,
            eager_load_credential=True,
        )
        if cc_pair is None:
            raise ValueError(f"No connector credential pair found for id: {cc_pair_id}")

        source_type = cc_pair.connector.source
        sync_config = get_source_perm_sync_config(source_type)
        if sync_config is None:
            msg = f"No sync config found for {source_type} for cc_pair: {cc_pair_id}"
            emit_background_error(msg, cc_pair_id=cc_pair_id)
            _fail_external_group_sync_attempt(attempt_id, msg)
            raise ValueError(msg)

        if sync_config.group_sync_config is None:
            msg = f"No group sync config found for {source_type} for cc_pair: {cc_pair_id}"
            emit_background_error(msg, cc_pair_id=cc_pair_id)
            _fail_external_group_sync_attempt(attempt_id, msg)
            raise ValueError(msg)

        ext_group_sync_func = sync_config.group_sync_config.group_sync_func

        logger.info(
            f"Marking old external groups as stale for {source_type} for cc_pair: {cc_pair_id}"
        )
        mark_old_external_groups_as_stale(db_session, cc_pair_id)

        # Mark attempt as in progress
        mark_external_group_sync_attempt_in_progress(attempt_id, db_session)
        logger.info(f"Marked external group sync attempt {attempt_id} as in progress")

        logger.info(
            f"Syncing external groups for {source_type} for cc_pair: {cc_pair_id}"
        )
        external_user_group_batch: list[ExternalUserGroup] = []
        seen_users: set[str] = set()  # Track unique users across all groups
        total_groups_processed = 0
        total_group_memberships_synced = 0
        start_time = time.monotonic()
        try:
            external_user_group_generator = ext_group_sync_func(tenant_id, cc_pair)
            for external_user_group in external_user_group_generator:
                # Check if the task has exceeded its timeout
                # NOTE: Celery's soft_time_limit does not work with thread pools,
                # so we must enforce timeouts internally.
                elapsed = time.monotonic() - start_time
                if elapsed > timeout_seconds:
                    raise RuntimeError(
                        f"External group sync task timed out: "
                        f"cc_pair={cc_pair_id} "
                        f"elapsed={elapsed:.0f}s "
                        f"timeout={timeout_seconds}s "
                        f"groups_processed={total_groups_processed}"
                    )

                external_user_group_batch.append(external_user_group)

                # Track progress
                total_groups_processed += 1
                total_group_memberships_synced += len(external_user_group.user_emails)
                seen_users = seen_users.union(external_user_group.user_emails)

                if len(external_user_group_batch) >= _EXTERNAL_GROUP_BATCH_SIZE:
                    logger.debug(
                        f"New external user groups: {external_user_group_batch}"
                    )
                    upsert_external_groups(
                        db_session=db_session,
                        cc_pair_id=cc_pair_id,
                        external_groups=external_user_group_batch,
                        source=cc_pair.connector.source,
                    )
                    external_user_group_batch = []

            if external_user_group_batch:
                logger.debug(f"New external user groups: {external_user_group_batch}")
                upsert_external_groups(
                    db_session=db_session,
                    cc_pair_id=cc_pair_id,
                    external_groups=external_user_group_batch,
                    source=cc_pair.connector.source,
                )
        except Exception as e:
            format_error_for_logging(e)

            # Mark as failed (this also updates progress to show partial progress)
            mark_external_group_sync_attempt_failed(
                attempt_id, db_session, error_message=str(e)
            )

            # TODO: add some notification to the admins here
            logger.exception(
                f"Error syncing external groups for {source_type} for cc_pair: {cc_pair_id} {e}"
            )
            raise e

        logger.info(
            f"Removing stale external groups for {source_type} for cc_pair: {cc_pair_id}"
        )
        remove_stale_external_groups(db_session, cc_pair_id)

        # Calculate total unique users processed
        total_users_processed = len(seen_users)

        # Complete the sync attempt with final progress
        complete_external_group_sync_attempt(
            db_session=db_session,
            attempt_id=attempt_id,
            total_users_processed=total_users_processed,
            total_groups_processed=total_groups_processed,
            total_group_memberships_synced=total_group_memberships_synced,
            errors_encountered=0,
        )

        logger.info(
            f"Completed external group sync attempt {attempt_id}: "
            f"{total_groups_processed} groups, {total_users_processed} users, "
            f"{total_group_memberships_synced} memberships"
        )

        mark_all_relevant_cc_pairs_as_external_group_synced(db_session, cc_pair)


def validate_external_group_sync_fences(
    tenant_id: str,
    celery_app: Celery,  # noqa: ARG001
    r: Redis,  # noqa: ARG001
    r_replica: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    reserved_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )

    # validate all existing external group sync tasks
    lock_beat.reacquire()
    keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES))
    for key in keys:
        key_bytes = cast(bytes, key)
        key_str = key_bytes.decode("utf-8")
        if not key_str.startswith(RedisConnectorExternalGroupSync.FENCE_PREFIX):
            continue

        validate_external_group_sync_fence(
            tenant_id,
            key_bytes,
            reserved_tasks,
            r_celery,
        )

        lock_beat.reacquire()

    return


def validate_external_group_sync_fence(
    tenant_id: str,
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
) -> None:
    """Checks for the error condition where an external group sync fence is set but the associated celery tasks don't exist.
    This can happen if the worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.1. When the task is seen in the redis queue
    1.2. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
    if cc_pair_id_str is None:
        msg = (
            f"validate_external_group_sync_fence - could not parse id from {fence_key}"
        )
        emit_background_error(msg)
        task_logger.error(msg)
        return

    cc_pair_id = int(cc_pair_id_str)

    # parse out metadata and initialize the helper class with it
    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))

    # check to see if the fence/payload exists
    if not redis_connector.external_group_sync.fenced:
        return

    try:
        payload = redis_connector.external_group_sync.payload
    except ValidationError:
        msg = (
            "validate_external_group_sync_fence - "
            "Resetting fence because fence schema is out of date: "
            f"cc_pair={cc_pair_id} "
            f"fence={fence_key}"
        )
        task_logger.exception(msg)
        emit_background_error(msg, cc_pair_id=cc_pair_id)

        redis_connector.external_group_sync.reset()
        return

    if not payload:
        return

    if not payload.celery_task_id:
        return

    # OK, there's actually something for us to validate
    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )
    if found:
        # the celery task exists in the redis queue
        # redis_connector_index.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within the indexing worker
        # redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which have a duration that allows us to bridge those gaps)
    # if redis_connector_index.active():
    #     return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    emit_background_error(
        message=(
            "validate_external_group_sync_fence - "
            "Resetting fence because no associated celery tasks were found: "
            f"cc_pair={cc_pair_id} "
            f"fence={fence_key} "
            f"payload_id={payload.id}"
        ),
        cc_pair_id=cc_pair_id,
    )

    redis_connector.external_group_sync.reset()
    return


================================================
FILE: backend/ee/onyx/background/celery/tasks/hooks/__init__.py
================================================


================================================
FILE: backend/ee/onyx/background/celery/tasks/hooks/tasks.py
================================================
from celery import shared_task

from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.hook import cleanup_old_execution_logs__no_commit
from onyx.utils.logger import setup_logger

logger = setup_logger()

_HOOK_EXECUTION_LOG_RETENTION_DAYS: int = 30


@shared_task(
    name=OnyxCeleryTask.HOOK_EXECUTION_LOG_CLEANUP_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
)
def hook_execution_log_cleanup_task(*, tenant_id: str) -> None:  # noqa: ARG001
    try:
        with get_session_with_current_tenant() as db_session:
            deleted: int = cleanup_old_execution_logs__no_commit(
                db_session=db_session,
                max_age_days=_HOOK_EXECUTION_LOG_RETENTION_DAYS,
            )
            db_session.commit()
            if deleted:
                logger.info(
                    f"Deleted {deleted} hook execution log(s) older than "
                    f"{_HOOK_EXECUTION_LOG_RETENTION_DAYS} days."
                )
    except Exception:
        logger.exception("Failed to clean up hook execution logs")
        raise


================================================
FILE: backend/ee/onyx/background/celery/tasks/query_history/__init__.py
================================================


================================================
FILE: backend/ee/onyx/background/celery/tasks/query_history/tasks.py
================================================
import csv
import io
from datetime import datetime

from celery import shared_task
from celery import Task

from ee.onyx.server.query_history.api import fetch_and_process_chat_session_history
from ee.onyx.server.query_history.api import ONYX_ANONYMIZED_EMAIL
from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
from onyx.background.task_utils import construct_query_history_report_name
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import QueryHistoryType
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.tasks import delete_task_with_id
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import mark_task_as_started_with_id
from onyx.file_store.file_store import get_default_file_store
from onyx.utils.logger import setup_logger


logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def export_query_history_task(
    self: Task,
    *,
    start: datetime,
    end: datetime,
    start_time: datetime,
    # Need to include the tenant_id since the TenantAwareTask needs this
    tenant_id: str,  # noqa: ARG001
) -> None:
    if not self.request.id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    task_id = self.request.id
    stream = io.StringIO()
    writer = csv.DictWriter(
        stream,
        fieldnames=list(QuestionAnswerPairSnapshot.model_fields.keys()),
    )
    writer.writeheader()

    with get_session_with_current_tenant() as db_session:
        try:
            mark_task_as_started_with_id(
                db_session=db_session,
                task_id=task_id,
            )

            snapshot_generator = fetch_and_process_chat_session_history(
                db_session=db_session,
                start=start,
                end=end,
            )

            for snapshot in snapshot_generator:
                if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
                    snapshot.user_email = ONYX_ANONYMIZED_EMAIL

                writer.writerows(
                    qa_pair.to_json()
                    for qa_pair in QuestionAnswerPairSnapshot.from_chat_session_snapshot(
                        snapshot
                    )
                )

        except Exception:
            logger.exception(f"Failed to export query history with {task_id=}")
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise

    report_name = construct_query_history_report_name(task_id)
    with get_session_with_current_tenant() as db_session:
        try:
            stream.seek(0)
            get_default_file_store().save_file(
                content=stream,
                display_name=report_name,
                file_origin=FileOrigin.QUERY_HISTORY_CSV,
                file_type=FileType.CSV,
                file_metadata={
                    "start": start.isoformat(),
                    "end": end.isoformat(),
                    "start_time": start_time.isoformat(),
                },
                file_id=report_name,
            )

            delete_task_with_id(
                db_session=db_session,
                task_id=task_id,
            )
        except Exception:
            logger.exception(
                f"Failed to save query history export file; {report_name=}"
            )
            mark_task_as_finished_with_id(
                db_session=db_session,
                task_id=task_id,
                success=False,
            )
            raise


================================================
FILE: backend/ee/onyx/background/celery/tasks/tenant_provisioning/__init__.py
================================================


================================================
FILE: backend/ee/onyx/background/celery/tasks/tenant_provisioning/tasks.py
================================================
"""
Periodic tasks for tenant pre-provisioning.
"""

import asyncio
import datetime
import uuid

from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock

from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX

# Maximum tenants to provision in a single task run.
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
_MAX_TENANTS_PER_RUN = 5

# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20  # 20 minutes
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25  # 25 minutes


@shared_task(
    name=OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS,
    queue=OnyxCeleryQueues.MONITORING,
    ignore_result=True,
    soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
    time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
    trail=False,
    bind=True,
)
def check_available_tenants(self: Task) -> None:  # noqa: ARG001
    """
    Check if we have enough pre-provisioned tenants available.
    If not, trigger the pre-provisioning of new tenants.
    """
    task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
    if not MULTI_TENANT:
        task_logger.info(
            "Multi-tenancy is not enabled, skipping tenant pre-provisioning"
        )
        return

    r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    lock_check: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
        timeout=_TENANT_PROVISIONING_TIME_LIMIT,
    )

    # These tasks should never overlap
    if not lock_check.acquire(blocking=False):
        task_logger.info(
            "Skipping check_available_tenants task because it is already running"
        )
        return

    try:
        # Get the current count of available tenants
        with get_session_with_shared_schema() as db_session:
            num_available_tenants = db_session.query(AvailableTenant).count()

        # Get the target number of available tenants
        num_minimum_available_tenants = TARGET_AVAILABLE_TENANTS

        # Calculate how many new tenants we need to provision
        if num_available_tenants < num_minimum_available_tenants:
            tenants_to_provision = num_minimum_available_tenants - num_available_tenants
        else:
            tenants_to_provision = 0

        task_logger.info(
            f"Available tenants: {num_available_tenants}, "
            f"Target minimum available tenants: {num_minimum_available_tenants}, "
            f"To provision: {tenants_to_provision}"
        )

        batch_size = min(tenants_to_provision, _MAX_TENANTS_PER_RUN)
        if batch_size < tenants_to_provision:
            task_logger.info(
                f"Capping batch to {batch_size} (need {tenants_to_provision}, will catch up next cycle)"
            )

        provisioned = 0
        for i in range(batch_size):
            task_logger.info(f"Provisioning tenant {i + 1}/{batch_size}")
            try:
                if pre_provision_tenant():
                    provisioned += 1
            except Exception:
                task_logger.exception(
                    f"Failed to provision tenant {i + 1}/{batch_size}, continuing with remaining tenants"
                )

        task_logger.info(f"Provisioning complete: {provisioned}/{batch_size} succeeded")

        # Migrate any pool tenants that were provisioned before a new migration was deployed
        _migrate_stale_pool_tenants()

    except Exception:
        task_logger.exception("Error in check_available_tenants task")

    finally:
        try:
            lock_check.release()
        except Exception:
            task_logger.warning(
                "Could not release check lock (likely expired), continuing"
            )


def _migrate_stale_pool_tenants() -> None:
    """
    Run alembic upgrade head on all pool tenants. Since alembic upgrade head is
    idempotent, tenants already at head are a fast no-op. This ensures pool
    tenants are always current so that signup doesn't hit schema mismatches
    (e.g. missing columns added after the tenant was pre-provisioned).
    """
    with get_session_with_shared_schema() as db_session:
        pool_tenants = db_session.query(AvailableTenant).all()
        tenant_ids = [t.tenant_id for t in pool_tenants]

    if not tenant_ids:
        return

    task_logger.info(
        f"Checking {len(tenant_ids)} pool tenant(s) for pending migrations"
    )

    for tenant_id in tenant_ids:
        try:
            run_alembic_migrations(tenant_id)
            new_version = get_current_alembic_version(tenant_id)
            with get_session_with_shared_schema() as db_session:
                tenant = (
                    db_session.query(AvailableTenant)
                    .filter_by(tenant_id=tenant_id)
                    .first()
                )
                if tenant and tenant.alembic_version != new_version:
                    task_logger.info(
                        f"Migrated pool tenant {tenant_id}: {tenant.alembic_version} -> {new_version}"
                    )
                    tenant.alembic_version = new_version
                    db_session.commit()
        except Exception:
            task_logger.exception(
                f"Failed to migrate pool tenant {tenant_id}, skipping"
            )


def pre_provision_tenant() -> bool:
    """
    Pre-provision a new tenant and store it in the AvailableTenant table.
    This function fully sets up the tenant with all necessary configurations,
    so it's ready to be assigned to a user immediately.

    Returns True if a tenant was successfully provisioned, False otherwise.
    """
    # The MULTI_TENANT check is now done at the caller level (check_available_tenants)
    # rather than inside this function

    r = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID)
    lock_provision: RedisLock = r.lock(
        OnyxRedisLocks.CLOUD_PRE_PROVISION_TENANT_LOCK,
        timeout=_TENANT_PROVISIONING_TIME_LIMIT,
    )

    # Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
    if not lock_provision.acquire(blocking=False):
        task_logger.warning(
            "Skipping pre_provision_tenant — could not acquire provision lock"
        )
        return False

    tenant_id: str | None = None
    try:
        # Generate a new tenant ID
        tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
        task_logger.info(f"Pre-provisioning tenant: {tenant_id}")

        # Create the schema for the new tenant
        schema_created = create_schema_if_not_exists(tenant_id)
        if schema_created:
            task_logger.debug(f"Created schema for tenant: {tenant_id}")
        else:
            task_logger.debug(f"Schema already exists for tenant: {tenant_id}")

        # Set up the tenant with all necessary configurations
        task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
        asyncio.run(setup_tenant(tenant_id))
        task_logger.debug(f"Tenant configuration completed: {tenant_id}")

        # Get the current Alembic version
        alembic_version = get_current_alembic_version(tenant_id)
        task_logger.debug(
            f"Tenant {tenant_id} using Alembic version: {alembic_version}"
        )

        # Store the pre-provisioned tenant in the database
        task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
        with get_session_with_shared_schema() as db_session:
            # Use a transaction to ensure atomicity
            db_session.begin()
            try:
                new_tenant = AvailableTenant(
                    tenant_id=tenant_id,
                    alembic_version=alembic_version,
                    date_created=datetime.datetime.now(),
                )
                db_session.add(new_tenant)
                db_session.commit()
                task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
                return True
            except Exception:
                db_session.rollback()
                task_logger.error(
                    f"Failed to store pre-provisioned tenant: {tenant_id}",
                    exc_info=True,
                )
                raise

    except Exception:
        task_logger.error("Error in pre_provision_tenant task", exc_info=True)

        # If we have a tenant_id, attempt to rollback any partially completed provisioning
        if tenant_id:
            task_logger.info(
                f"Rolling back failed tenant provisioning for: {tenant_id}"
            )
            try:
                from ee.onyx.server.tenants.provisioning import (
                    rollback_tenant_provisioning,
                )

                asyncio.run(rollback_tenant_provisioning(tenant_id))
            except Exception:
                task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
        return False
    finally:
        try:
            lock_provision.release()
        except Exception:
            task_logger.warning(
                "Could not release provision lock (likely expired), continuing"
            )


================================================
FILE: backend/ee/onyx/background/celery/tasks/ttl_management/__init__.py
================================================


================================================
FILE: backend/ee/onyx/background/celery/tasks/ttl_management/tasks.py
================================================
from uuid import UUID

from celery import shared_task
from celery import Task

from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger

logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.PERFORM_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def perform_ttl_management_task(
    self: Task, retention_limit_days: int, *, tenant_id: str  # noqa: ARG001
) -> None:
    task_id = self.request.id
    if not task_id:
        raise RuntimeError("No task id defined for this task; cannot identify it")

    user_id: UUID | None = None
    session_id: UUID | None = None
    try:
        with get_session_with_current_tenant() as db_session:
            old_chat_sessions = get_chat_sessions_older_than(
                retention_limit_days, db_session
            )

        for user_id, session_id in old_chat_sessions:
            # one session per delete so that we don't blow up if a deletion fails.
            with get_session_with_current_tenant() as db_session:
                delete_chat_session(
                    user_id,
                    session_id,
                    db_session,
                    include_deleted=True,
                    hard_delete=True,
                )

    except Exception:
        logger.exception(
            f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
        )
        raise


@shared_task(
    name=OnyxCeleryTask.CHECK_TTL_MANAGEMENT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
)
def check_ttl_management_task(*, tenant_id: str) -> None:
    """Runs periodically to check if any ttl tasks should be run and adds them
    to the queue"""

    settings = load_settings()
    retention_limit_days = settings.maximum_chat_retention_days
    with get_session_with_current_tenant() as db_session:
        if should_perform_chat_ttl_check(retention_limit_days, db_session):
            perform_ttl_management_task.apply_async(
                kwargs=dict(
                    retention_limit_days=retention_limit_days, tenant_id=tenant_id
                ),
            )


================================================
FILE: backend/ee/onyx/background/celery/tasks/usage_reporting/__init__.py
================================================


================================================
FILE: backend/ee/onyx/background/celery/tasks/usage_reporting/tasks.py
================================================
from datetime import datetime
from uuid import UUID

from celery import shared_task
from celery import Task

from ee.onyx.server.reporting.usage_export_generation import create_new_usage_report
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.utils.logger import setup_logger

logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def generate_usage_report_task(
    self: Task,  # noqa: ARG001
    *,
    tenant_id: str,  # noqa: ARG001
    user_id: str | None = None,
    period_from: str | None = None,
    period_to: str | None = None,
) -> None:
    """User-initiated usage report generation task"""
    # Parse period if provided
    period = None
    if period_from and period_to:
        period = (
            datetime.fromisoformat(period_from),
            datetime.fromisoformat(period_to),
        )

    # Generate the report
    with get_session_with_current_tenant() as db_session:
        create_new_usage_report(
            db_session=db_session,
            user_id=UUID(user_id) if user_id else None,
            period=period,
        )


================================================
FILE: backend/ee/onyx/background/celery/tasks/vespa/__init__.py
================================================


================================================
FILE: backend/ee/onyx/background/celery/tasks/vespa/tasks.py
================================================
from typing import cast

from redis import Redis
from sqlalchemy.orm import Session

from ee.onyx.db.user_group import delete_user_group
from ee.onyx.db.user_group import fetch_user_group
from ee.onyx.db.user_group import mark_user_group_as_synced
from ee.onyx.db.user_group import prepare_user_group_for_deletion
from onyx.background.celery.apps.app_base import task_logger
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_usergroup import RedisUserGroup
from onyx.utils.logger import setup_logger

logger = setup_logger()


def monitor_usergroup_taskset(
    tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
) -> None:
    """This function is likely to move in the worker refactor happening next."""
    fence_key = key_bytes.decode("utf-8")
    usergroup_id_str = RedisUserGroup.get_id_from_fence_key(fence_key)
    if not usergroup_id_str:
        task_logger.warning(f"Could not parse usergroup id from {fence_key}")
        return

    try:
        usergroup_id = int(usergroup_id_str)
    except ValueError:
        task_logger.exception(f"usergroup_id ({usergroup_id_str}) is not an integer!")
        raise

    rug = RedisUserGroup(tenant_id, usergroup_id)
    if not rug.fenced:
        return

    initial_count = rug.payload
    if initial_count is None:
        return

    count = cast(int, r.scard(rug.taskset_key))
    task_logger.info(
        f"User group sync progress: usergroup_id={usergroup_id} remaining={count} initial={initial_count}"
    )
    if count > 0:
        update_sync_record_status(
            db_session=db_session,
            entity_id=usergroup_id,
            sync_type=SyncType.USER_GROUP,
            sync_status=SyncStatus.IN_PROGRESS,
            num_docs_synced=count,
        )
        return

    user_group = fetch_user_group(db_session=db_session, user_group_id=usergroup_id)
    if user_group:
        usergroup_name = user_group.name
        try:
            if user_group.is_up_for_deletion:
                # this prepare should have been run when the deletion was scheduled,
                # but run it again to be sure we're ready to go
                mark_user_group_as_synced(db_session, user_group)
                prepare_user_group_for_deletion(db_session, usergroup_id)
                delete_user_group(db_session=db_session, user_group=user_group)

                update_sync_record_status(
                    db_session=db_session,
                    entity_id=usergroup_id,
                    sync_type=SyncType.USER_GROUP,
                    sync_status=SyncStatus.SUCCESS,
                    num_docs_synced=initial_count,
                )

                task_logger.info(
                    f"Deleted usergroup: name={usergroup_name} id={usergroup_id}"
                )
            else:
                mark_user_group_as_synced(db_session=db_session, user_group=user_group)

                update_sync_record_status(
                    db_session=db_session,
                    entity_id=usergroup_id,
                    sync_type=SyncType.USER_GROUP,
                    sync_status=SyncStatus.SUCCESS,
                    num_docs_synced=initial_count,
                )

                task_logger.info(
                    f"Synced usergroup. name={usergroup_name} id={usergroup_id}"
                )
        except Exception as e:
            update_sync_record_status(
                db_session=db_session,
                entity_id=usergroup_id,
                sync_type=SyncType.USER_GROUP,
                sync_status=SyncStatus.FAILED,
                num_docs_synced=initial_count,
            )
            raise e

    rug.reset()


================================================
FILE: backend/ee/onyx/background/celery_utils.py
================================================
from sqlalchemy.orm import Session

from ee.onyx.background.task_name_builders import name_chat_ttl_task
from onyx.db.tasks import check_task_is_live_and_not_timed_out
from onyx.db.tasks import get_latest_task
from onyx.utils.logger import setup_logger

logger = setup_logger()


def should_perform_chat_ttl_check(
    retention_limit_days: float | None, db_session: Session
) -> bool:
    # TODO: make this a check for None and add behavior for 0 day TTL
    if not retention_limit_days:
        return False

    task_name = name_chat_ttl_task(retention_limit_days)
    latest_task = get_latest_task(task_name, db_session)
    if not latest_task:
        return True

    if check_task_is_live_and_not_timed_out(latest_task, db_session):
        logger.debug(f"{task_name} is already being performed. Skipping.")
        return False
    return True


================================================
FILE: backend/ee/onyx/background/task_name_builders.py
================================================
from datetime import datetime

from onyx.configs.constants import OnyxCeleryTask


QUERY_HISTORY_TASK_NAME_PREFIX = OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK


def name_chat_ttl_task(
    retention_limit_days: float,
    tenant_id: str | None = None,  # noqa: ARG001
) -> str:
    return f"chat_ttl_{retention_limit_days}_days"


def query_history_task_name(start: datetime, end: datetime) -> str:
    return f"{QUERY_HISTORY_TASK_NAME_PREFIX}_{start}_{end}"


================================================
FILE: backend/ee/onyx/configs/__init__.py
================================================


================================================
FILE: backend/ee/onyx/configs/app_configs.py
================================================
import json
import os


#####
# Auto Permission Sync
#####
# should generally only be used for sources that support polling of permissions
# e.g. can pull in only permission changes rather than having to go through all
# documents every time
DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)


#####
# Confluence
#####

# In seconds, default is 30 minutes
CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY") or 30 * 60
)
# In seconds, default is 30 minutes
CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
# This is a boolean that determines if anonymous access is public
# Default behavior is to not make the page public and instead add a group
# that contains all the users that we found in Confluence
CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC = (
    os.environ.get("CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC", "").lower() == "true"
)


#####
# JIRA
#####

# In seconds, default is 30 minutes
JIRA_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("JIRA_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)
# In seconds, default is 30 minutes
JIRA_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("JIRA_PERMISSION_GROUP_SYNC_FREQUENCY") or 30 * 60
)


#####
# Google Drive
#####
GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)


#####
# GitHub
#####
# In seconds, default is 5 minutes
GITHUB_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("GITHUB_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)
# In seconds, default is 5 minutes
GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)


#####
# Slack
#####
SLACK_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("SLACK_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

NUM_PERMISSION_WORKERS = int(os.environ.get("NUM_PERMISSION_WORKERS") or 2)


#####
# Teams
#####
# In seconds, default is 5 minutes
TEAMS_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("TEAMS_PERMISSION_DOC_SYNC_FREQUENCY") or 5 * 60
)

#####
# SharePoint
#####
# In seconds, default is 30 minutes
SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY = int(
    os.environ.get("SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY") or 30 * 60
)

# In seconds, default is 5 minutes
SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY = int(
    os.environ.get("SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY") or 5 * 60
)


####
# Celery Job Frequency
####
CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS = float(
    os.environ.get("CHECK_TTL_MANAGEMENT_TASK_FREQUENCY_IN_HOURS") or 1
)  # float for easier testing


STRIPE_SECRET_KEY = os.environ.get("STRIPE_SECRET_KEY")

# JWT Public Key URL
JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None)


# Super Users
SUPER_USERS = json.loads(os.environ.get("SUPER_USERS", "[]"))
SUPER_CLOUD_API_KEY = os.environ.get("SUPER_CLOUD_API_KEY", "api_key")

POSTHOG_API_KEY = os.environ.get("POSTHOG_API_KEY")
POSTHOG_HOST = os.environ.get("POSTHOG_HOST") or "https://us.i.posthog.com"
POSTHOG_DEBUG_LOGS_ENABLED = (
    os.environ.get("POSTHOG_DEBUG_LOGS_ENABLED", "").lower() == "true"
)

MARKETING_POSTHOG_API_KEY = os.environ.get("MARKETING_POSTHOG_API_KEY")

HUBSPOT_TRACKING_URL = os.environ.get("HUBSPOT_TRACKING_URL")

GATED_TENANTS_KEY = "gated_tenants"

# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
    os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
)

# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
# Used when MULTI_TENANT=false (self-hosted mode)
CLOUD_DATA_PLANE_URL = os.environ.get(
    "CLOUD_DATA_PLANE_URL", "https://cloud.onyx.app/api"
)


================================================
FILE: backend/ee/onyx/configs/license_enforcement_config.py
================================================
"""Constants for license enforcement.

This file is the single source of truth for:
1. Paths that bypass license enforcement (always accessible)
2. Paths that require an EE license (EE-only features)

Import these constants in both production code and tests to ensure consistency.
"""

# Paths that are ALWAYS accessible, even when license is expired/gated.
# These enable users to:
#   /auth - Log in/out (users can't fix billing if locked out of auth)
#   /license - Fetch, upload, or check license status
#   /health - Health checks for load balancers/orchestrators
#   /me - Basic user info needed for UI rendering
#   /settings, /enterprise-settings - View app status and branding
#   /billing - Unified billing API
#   /proxy - Self-hosted proxy endpoints (have own license-based auth)
#   /tenants/billing-* - Legacy billing endpoints (backwards compatibility)
#   /manage/users, /users - User management (needed for seat limit resolution)
#   /notifications - Needed for UI to load properly
LICENSE_ENFORCEMENT_ALLOWED_PREFIXES: frozenset[str] = frozenset(
    {
        "/auth",
        "/license",
        "/health",
        "/me",
        "/settings",
        "/enterprise-settings",
        # Billing endpoints (unified API for both MT and self-hosted)
        "/billing",
        "/admin/billing",
        # Proxy endpoints for self-hosted billing (no tenant context)
        "/proxy",
        # Legacy tenant billing endpoints (kept for backwards compatibility)
        "/tenants/billing-information",
        "/tenants/create-customer-portal-session",
        "/tenants/create-subscription-session",
        # User management - needed to remove users when seat limit exceeded
        "/manage/users",
        "/manage/admin/users",
        "/manage/admin/valid-domains",
        "/manage/admin/deactivate-user",
        "/manage/admin/delete-user",
        "/users",
        # Notifications - needed for UI to load properly
        "/notifications",
    }
)

# EE-only paths that require a valid license.
# Users without a license (community edition) cannot access these.
# These are blocked even when user has never subscribed (no license).
EE_ONLY_PATH_PREFIXES: frozenset[str] = frozenset( { # User groups and access control "/manage/admin/user-group", # Analytics and reporting "/analytics", # Query history (admin chat session endpoints) "/admin/chat-sessions", "/admin/chat-session-history", "/admin/query-history", # Usage reporting/export "/admin/usage-report", # Standard answers (canned responses) "/manage/admin/standard-answer", # Token rate limits "/admin/token-rate-limits", # Evals "/evals", # Hook extensions "/admin/hooks", } ) ================================================ FILE: backend/ee/onyx/connectors/perm_sync_valid.py ================================================ from onyx.connectors.confluence.connector import ConfluenceConnector from onyx.connectors.google_drive.connector import GoogleDriveConnector from onyx.connectors.interfaces import BaseConnector def validate_confluence_perm_sync(connector: ConfluenceConnector) -> None: """ Validate that the connector is configured correctly for permissions syncing. """ def validate_drive_perm_sync(connector: GoogleDriveConnector) -> None: """ Validate that the connector is configured correctly for permissions syncing. """ def validate_perm_sync(connector: BaseConnector) -> None: """ Override this if your connector needs to validate permissions syncing. Raise an exception if invalid, otherwise do nothing. Default is a no-op (always successful). 
""" if isinstance(connector, ConfluenceConnector): validate_confluence_perm_sync(connector) elif isinstance(connector, GoogleDriveConnector): validate_drive_perm_sync(connector) ================================================ FILE: backend/ee/onyx/db/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/db/analytics.py ================================================ import datetime from collections.abc import Sequence from uuid import UUID from sqlalchemy import and_ from sqlalchemy import case from sqlalchemy import cast from sqlalchemy import Date from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.configs.constants import MessageType from onyx.db.models import ChatMessage from onyx.db.models import ChatMessageFeedback from onyx.db.models import ChatSession from onyx.db.models import Persona from onyx.db.models import User from onyx.db.models import UserRole def fetch_query_analytics( start: datetime.datetime, end: datetime.datetime, db_session: Session, ) -> Sequence[tuple[int, int, int, datetime.date]]: stmt = ( select( func.count(ChatMessage.id), func.sum(case((ChatMessageFeedback.is_positive, 1), else_=0)), func.sum( case( (ChatMessageFeedback.is_positive == False, 1), # noqa: E712 else_=0, # noqa: E712 ) ), cast(ChatMessage.time_sent, Date), ) .join( ChatMessageFeedback, ChatMessageFeedback.chat_message_id == ChatMessage.id, isouter=True, ) .where( ChatMessage.time_sent >= start, ) .where( ChatMessage.time_sent <= end, ) .where(ChatMessage.message_type == MessageType.ASSISTANT) .group_by(cast(ChatMessage.time_sent, Date)) .order_by(cast(ChatMessage.time_sent, Date)) ) return db_session.execute(stmt).all() # type: ignore def fetch_per_user_query_analytics( start: datetime.datetime, end: datetime.datetime, db_session: Session, ) -> Sequence[tuple[int, int, int, datetime.date, UUID]]: stmt = ( select( 
func.count(ChatMessage.id), func.sum(case((ChatMessageFeedback.is_positive, 1), else_=0)), func.sum( case( (ChatMessageFeedback.is_positive == False, 1), # noqa: E712 else_=0, # noqa: E712 ) ), cast(ChatMessage.time_sent, Date), ChatSession.user_id, ) .join(ChatSession, ChatSession.id == ChatMessage.chat_session_id) # Include chats that have no explicit feedback instead of dropping them .join( ChatMessageFeedback, ChatMessageFeedback.chat_message_id == ChatMessage.id, isouter=True, ) .where( ChatMessage.time_sent >= start, ) .where( ChatMessage.time_sent <= end, ) .where(ChatMessage.message_type == MessageType.ASSISTANT) .group_by(cast(ChatMessage.time_sent, Date), ChatSession.user_id) .order_by(cast(ChatMessage.time_sent, Date), ChatSession.user_id) ) return db_session.execute(stmt).all() # type: ignore def fetch_onyxbot_analytics( start: datetime.datetime, end: datetime.datetime, db_session: Session, ) -> Sequence[tuple[int, int, datetime.date]]: """Gets the: Date of each set of aggregated statistics Number of OnyxBot Queries (Chat Sessions) Number of instances of Negative feedback OR Needing additional help (only counting the last feedback) """ # Get every chat session in the time range which is an Onyxbot flow # along with the first Assistant message which is the response to the user question.
# Generally there should not be more than one AI message per chat session of this type subquery_first_ai_response = ( db_session.query( ChatMessage.chat_session_id.label("chat_session_id"), func.min(ChatMessage.id).label("chat_message_id"), ) .join(ChatSession, ChatSession.id == ChatMessage.chat_session_id) .where( ChatSession.time_created >= start, ChatSession.time_created <= end, ChatSession.onyxbot_flow.is_(True), ) .where( ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(ChatMessage.chat_session_id) .subquery() ) # Get the chat message ids and most recent feedback for each of those chat messages, # not including the messages that have no feedback subquery_last_feedback = ( db_session.query( ChatMessageFeedback.chat_message_id.label("chat_message_id"), func.max(ChatMessageFeedback.id).label("max_feedback_id"), ) .group_by(ChatMessageFeedback.chat_message_id) .subquery() ) results = ( db_session.query( func.count(ChatSession.id).label("total_sessions"), # Need to explicitly specify this as False to handle the NULL case so the cases without # feedback aren't counted against Onyxbot func.sum( case( ( or_( ChatMessageFeedback.is_positive.is_(False), ChatMessageFeedback.required_followup.is_(True), ), 1, ), else_=0, ) ).label("negative_answer"), cast(ChatSession.time_created, Date).label("session_date"), ) .join( subquery_first_ai_response, ChatSession.id == subquery_first_ai_response.c.chat_session_id, ) # Combine the chat sessions with latest feedback to get the latest feedback for the first AI # message of the chat session where the chat session is Onyxbot type and within the time # range specified. 
Left/outer join used here to ensure that if no feedback, a null is used # for the feedback id .outerjoin( subquery_last_feedback, subquery_first_ai_response.c.chat_message_id == subquery_last_feedback.c.chat_message_id, ) # Join the actual feedback table to get the feedback info for the sums # Outer join because the "last feedback" may be null .outerjoin( ChatMessageFeedback, ChatMessageFeedback.id == subquery_last_feedback.c.max_feedback_id, ) .group_by(cast(ChatSession.time_created, Date)) .order_by(cast(ChatSession.time_created, Date)) .all() ) return [tuple(row) for row in results] def fetch_persona_message_analytics( db_session: Session, persona_id: int, start: datetime.datetime, end: datetime.datetime, ) -> list[tuple[int, datetime.date]]: """Gets the daily message counts for a specific persona within the given time range.""" query = ( select( func.count(ChatMessage.id), cast(ChatMessage.time_sent, Date), ) .join( ChatSession, ChatMessage.chat_session_id == ChatSession.id, ) .where( ChatSession.persona_id == persona_id, ChatMessage.time_sent >= start, ChatMessage.time_sent <= end, ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(cast(ChatMessage.time_sent, Date)) .order_by(cast(ChatMessage.time_sent, Date)) ) return [tuple(row) for row in db_session.execute(query).all()] def fetch_persona_unique_users( db_session: Session, persona_id: int, start: datetime.datetime, end: datetime.datetime, ) -> list[tuple[int, datetime.date]]: """Gets the daily unique user counts for a specific persona within the given time range.""" query = ( select( func.count(func.distinct(ChatSession.user_id)), cast(ChatMessage.time_sent, Date), ) .join( ChatSession, ChatMessage.chat_session_id == ChatSession.id, ) .where( ChatSession.persona_id == persona_id, ChatMessage.time_sent >= start, ChatMessage.time_sent <= end, ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(cast(ChatMessage.time_sent, Date)) .order_by(cast(ChatMessage.time_sent, Date)) ) return 
[tuple(row) for row in db_session.execute(query).all()] def fetch_assistant_message_analytics( db_session: Session, assistant_id: int, start: datetime.datetime, end: datetime.datetime, ) -> list[tuple[int, datetime.date]]: """ Gets the daily message counts for a specific assistant in the given time range. """ query = ( select( func.count(ChatMessage.id), cast(ChatMessage.time_sent, Date), ) .join( ChatSession, ChatMessage.chat_session_id == ChatSession.id, ) .where( ChatSession.persona_id == assistant_id, ChatMessage.time_sent >= start, ChatMessage.time_sent <= end, ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(cast(ChatMessage.time_sent, Date)) .order_by(cast(ChatMessage.time_sent, Date)) ) return [tuple(row) for row in db_session.execute(query).all()] def fetch_assistant_unique_users( db_session: Session, assistant_id: int, start: datetime.datetime, end: datetime.datetime, ) -> list[tuple[int, datetime.date]]: """ Gets the daily unique user counts for a specific assistant in the given time range. """ query = ( select( func.count(func.distinct(ChatSession.user_id)), cast(ChatMessage.time_sent, Date), ) .join( ChatSession, ChatMessage.chat_session_id == ChatSession.id, ) .where( ChatSession.persona_id == assistant_id, ChatMessage.time_sent >= start, ChatMessage.time_sent <= end, ChatMessage.message_type == MessageType.ASSISTANT, ) .group_by(cast(ChatMessage.time_sent, Date)) .order_by(cast(ChatMessage.time_sent, Date)) ) return [tuple(row) for row in db_session.execute(query).all()] def fetch_assistant_unique_users_total( db_session: Session, assistant_id: int, start: datetime.datetime, end: datetime.datetime, ) -> int: """ Gets the total number of distinct users who have sent or received messages from the specified assistant in the given time range. 
""" query = ( select(func.count(func.distinct(ChatSession.user_id))) .select_from(ChatMessage) .join( ChatSession, ChatMessage.chat_session_id == ChatSession.id, ) .where( ChatSession.persona_id == assistant_id, ChatMessage.time_sent >= start, ChatMessage.time_sent <= end, ChatMessage.message_type == MessageType.ASSISTANT, ) ) result = db_session.execute(query).scalar() return result if result else 0 # Users can view assistant stats if they created the persona, # or if they are an admin def user_can_view_assistant_stats( db_session: Session, user: User, assistant_id: int ) -> bool: if user.role == UserRole.ADMIN: return True # Check if the user created the persona stmt = select(Persona).where( and_(Persona.id == assistant_id, Persona.user_id == user.id) ) persona = db_session.execute(stmt).scalar_one_or_none() return persona is not None ================================================ FILE: backend/ee/onyx/db/connector.py ================================================ from sqlalchemy import distinct from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.db.models import Connector from onyx.utils.logger import setup_logger logger = setup_logger() def fetch_sources_with_connectors(db_session: Session) -> list[DocumentSource]: sources = db_session.query(distinct(Connector.source)).all() # type: ignore document_sources = [source[0] for source in sources] return document_sources ================================================ FILE: backend/ee/onyx/db/connector_credential_pair.py ================================================ from sqlalchemy import delete from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.db.connector_credential_pair import get_connector_credential_pair from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.models import Connector from onyx.db.models import ConnectorCredentialPair from onyx.db.models import 
UserGroup__ConnectorCredentialPair from onyx.utils.logger import setup_logger logger = setup_logger() def _delete_connector_credential_pair_user_groups_relationship__no_commit( db_session: Session, connector_id: int, credential_id: int ) -> None: cc_pair = get_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, ) if cc_pair is None: raise ValueError( f"ConnectorCredentialPair with connector_id: {connector_id} and credential_id: {credential_id} not found" ) stmt = delete(UserGroup__ConnectorCredentialPair).where( UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair.id, ) db_session.execute(stmt) def get_cc_pairs_by_source( db_session: Session, source_type: DocumentSource, access_type: AccessType | None = None, status: ConnectorCredentialPairStatus | None = None, ) -> list[ConnectorCredentialPair]: """ Get all cc_pairs for a given source type with optional filtering by access_type and status result is sorted by cc_pair id """ query = ( db_session.query(ConnectorCredentialPair) .join(ConnectorCredentialPair.connector) .filter(Connector.source == source_type) .order_by(ConnectorCredentialPair.id) ) if access_type is not None: query = query.filter(ConnectorCredentialPair.access_type == access_type) if status is not None: query = query.filter(ConnectorCredentialPair.status == status) cc_pairs = query.all() return cc_pairs def get_all_auto_sync_cc_pairs( db_session: Session, ) -> list[ConnectorCredentialPair]: return ( db_session.query(ConnectorCredentialPair) .where( ConnectorCredentialPair.access_type == AccessType.SYNC, ) .all() ) ================================================ FILE: backend/ee/onyx/db/document.py ================================================ from datetime import datetime from datetime import timezone from sqlalchemy import select from sqlalchemy.orm import Session from onyx.access.models import ExternalAccess from onyx.access.utils import build_ext_group_name_for_onyx from 
onyx.configs.constants import DocumentSource from onyx.db.models import Document as DbDocument def upsert_document_external_perms__no_commit( db_session: Session, doc_id: str, external_access: ExternalAccess, source_type: DocumentSource, ) -> None: """ This sets the permissions for a document in postgres. NOTE: this will replace any existing external access, it will not do a union """ document = db_session.scalars( select(DbDocument).where(DbDocument.id == doc_id) ).first() prefixed_external_groups = [ build_ext_group_name_for_onyx( ext_group_name=group_id, source=source_type, ) for group_id in external_access.external_user_group_ids ] if not document: # If the document does not exist, still store the external access # So that if the document is added later, the external access is already stored document = DbDocument( id=doc_id, semantic_id="", external_user_emails=external_access.external_user_emails, external_user_group_ids=prefixed_external_groups, is_public=external_access.is_public, ) db_session.add(document) return document.external_user_emails = list(external_access.external_user_emails) document.external_user_group_ids = prefixed_external_groups document.is_public = external_access.is_public def upsert_document_external_perms( db_session: Session, doc_id: str, external_access: ExternalAccess, source_type: DocumentSource, ) -> bool: """ This sets the permissions for a document in postgres. Returns True if a new document was created, False otherwise.
NOTE: this will replace any existing external access, it will not do a union """ document = db_session.scalars( select(DbDocument).where(DbDocument.id == doc_id) ).first() prefixed_external_groups: set[str] = { build_ext_group_name_for_onyx( ext_group_name=group_id, source=source_type, ) for group_id in external_access.external_user_group_ids } if not document: # If the document does not exist, still store the external access # So that if the document is added later, the external access is already stored # The upsert function in the indexing pipeline does not overwrite the permissions fields document = DbDocument( id=doc_id, semantic_id="", external_user_emails=external_access.external_user_emails, external_user_group_ids=prefixed_external_groups, is_public=external_access.is_public, ) db_session.add(document) db_session.commit() return True # If the document exists, we need to check if the external access has changed if ( external_access.external_user_emails != set(document.external_user_emails or []) or prefixed_external_groups != set(document.external_user_group_ids or []) or external_access.is_public != document.is_public ): document.external_user_emails = list(external_access.external_user_emails) document.external_user_group_ids = list(prefixed_external_groups) document.is_public = external_access.is_public document.last_modified = datetime.now(timezone.utc) db_session.commit() return False ================================================ FILE: backend/ee/onyx/db/document_set.py ================================================ from uuid import UUID from sqlalchemy.orm import Session from onyx.db.models import ConnectorCredentialPair from onyx.db.models import DocumentSet from onyx.db.models import DocumentSet__ConnectorCredentialPair from onyx.db.models import DocumentSet__User from onyx.db.models import DocumentSet__UserGroup from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup def make_doc_set_private( document_set_id: int, 
user_ids: list[UUID] | None, group_ids: list[int] | None, db_session: Session, ) -> None: db_session.query(DocumentSet__User).filter( DocumentSet__User.document_set_id == document_set_id ).delete(synchronize_session="fetch") db_session.query(DocumentSet__UserGroup).filter( DocumentSet__UserGroup.document_set_id == document_set_id ).delete(synchronize_session="fetch") if user_ids: for user_uuid in user_ids: db_session.add( DocumentSet__User(document_set_id=document_set_id, user_id=user_uuid) ) if group_ids: for group_id in group_ids: db_session.add( DocumentSet__UserGroup( document_set_id=document_set_id, user_group_id=group_id ) ) def delete_document_set_privacy__no_commit( document_set_id: int, db_session: Session ) -> None: db_session.query(DocumentSet__User).filter( DocumentSet__User.document_set_id == document_set_id ).delete(synchronize_session="fetch") db_session.query(DocumentSet__UserGroup).filter( DocumentSet__UserGroup.document_set_id == document_set_id ).delete(synchronize_session="fetch") def fetch_document_sets( user_id: UUID | None, db_session: Session, include_outdated: bool = True, # Parameter only for versioned implementation, unused # noqa: ARG001 ) -> list[tuple[DocumentSet, list[ConnectorCredentialPair]]]: assert user_id is not None # Public document sets public_document_sets = ( db_session.query(DocumentSet) .filter(DocumentSet.is_public == True) # noqa .all() ) # Document sets via shared user relationships shared_document_sets = ( db_session.query(DocumentSet) .join(DocumentSet__User, DocumentSet.id == DocumentSet__User.document_set_id) .filter(DocumentSet__User.user_id == user_id) .all() ) # Document sets via groups # First, find the user groups the user belongs to user_groups = ( db_session.query(UserGroup) .join(User__UserGroup, UserGroup.id == User__UserGroup.user_group_id) .filter(User__UserGroup.user_id == user_id) .all() ) group_document_sets = [] for group in user_groups: group_document_sets.extend( db_session.query(DocumentSet) .join( 
DocumentSet__UserGroup, DocumentSet.id == DocumentSet__UserGroup.document_set_id, ) .filter(DocumentSet__UserGroup.user_group_id == group.id) .all() ) # Combine and deduplicate document sets from all sources all_document_sets = list( set(public_document_sets + shared_document_sets + group_document_sets) ) document_set_with_cc_pairs: list[ tuple[DocumentSet, list[ConnectorCredentialPair]] ] = [] for document_set in all_document_sets: # Fetch the associated ConnectorCredentialPairs cc_pairs = ( db_session.query(ConnectorCredentialPair) .join( DocumentSet__ConnectorCredentialPair, ConnectorCredentialPair.id == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id, ) .filter( DocumentSet__ConnectorCredentialPair.document_set_id == document_set.id, ) .all() ) document_set_with_cc_pairs.append((document_set, cc_pairs)) return document_set_with_cc_pairs ================================================ FILE: backend/ee/onyx/db/external_perm.py ================================================ from collections.abc import Sequence from uuid import UUID from pydantic import BaseModel from sqlalchemy import delete from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.constants import DocumentSource from onyx.db.models import PublicExternalUserGroup from onyx.db.models import User from onyx.db.models import User__ExternalUserGroupId from onyx.db.users import batch_add_ext_perm_user_if_not_exists from onyx.db.users import get_user_by_email from onyx.utils.logger import setup_logger logger = setup_logger() class ExternalUserGroup(BaseModel): id: str user_emails: list[str] # `True` for cases like a Folder in Google Drive that give domain-wide # or "Anyone with link" access to all files in the folder. # if this is set, `user_emails` don't really matter. 
# When this is `True`, this `ExternalUserGroup` object doesn't really represent # an actual "group" in the source. gives_anyone_access: bool = False def delete_user__ext_group_for_user__no_commit( db_session: Session, user_id: UUID, ) -> None: db_session.execute( delete(User__ExternalUserGroupId).where( User__ExternalUserGroupId.user_id == user_id ) ) def delete_user__ext_group_for_cc_pair__no_commit( db_session: Session, cc_pair_id: int, ) -> None: db_session.execute( delete(User__ExternalUserGroupId).where( User__ExternalUserGroupId.cc_pair_id == cc_pair_id ) ) def delete_public_external_group_for_cc_pair__no_commit( db_session: Session, cc_pair_id: int, ) -> None: db_session.execute( delete(PublicExternalUserGroup).where( PublicExternalUserGroup.cc_pair_id == cc_pair_id ) ) def mark_old_external_groups_as_stale( db_session: Session, cc_pair_id: int, ) -> None: db_session.execute( update(User__ExternalUserGroupId) .where(User__ExternalUserGroupId.cc_pair_id == cc_pair_id) .values(stale=True) ) db_session.execute( update(PublicExternalUserGroup) .where(PublicExternalUserGroup.cc_pair_id == cc_pair_id) .values(stale=True) ) def upsert_external_groups( db_session: Session, cc_pair_id: int, external_groups: list[ExternalUserGroup], source: DocumentSource, ) -> None: """ Performs a true upsert operation for external user groups: - For existing groups (same user_id, external_user_group_id, cc_pair_id), updates the stale flag to False - For new groups, inserts them with stale=False - For public groups, uses upsert logic as well """ # If there are no groups to add, return early if not external_groups: return # collect all emails from all groups to batch add all users at once for efficiency all_group_member_emails = set() for external_group in external_groups: for user_email in external_group.user_emails: all_group_member_emails.add(user_email) # batch add users if they don't exist and get their ids all_group_members: list[User] = batch_add_ext_perm_user_if_not_exists( 
db_session=db_session, # NOTE: this function handles case sensitivity for emails emails=list(all_group_member_emails), ) # map emails to ids email_id_map = {user.email.lower(): user.id for user in all_group_members} # Process each external group for external_group in external_groups: external_group_id = build_ext_group_name_for_onyx( ext_group_name=external_group.id, source=source, ) # Handle user-group mappings for user_email in external_group.user_emails: user_id = email_id_map.get(user_email.lower()) if user_id is None: logger.warning( f"User in group {external_group.id} with email {user_email} not found" ) continue # Check if the user-group mapping already exists existing_user_group = db_session.scalar( select(User__ExternalUserGroupId).where( User__ExternalUserGroupId.user_id == user_id, User__ExternalUserGroupId.external_user_group_id == external_group_id, User__ExternalUserGroupId.cc_pair_id == cc_pair_id, ) ) if existing_user_group: # Update existing record existing_user_group.stale = False else: # Insert new record new_user_group = User__ExternalUserGroupId( user_id=user_id, external_user_group_id=external_group_id, cc_pair_id=cc_pair_id, stale=False, ) db_session.add(new_user_group) # Handle public group if needed if external_group.gives_anyone_access: # Check if the public group already exists existing_public_group = db_session.scalar( select(PublicExternalUserGroup).where( PublicExternalUserGroup.external_user_group_id == external_group_id, PublicExternalUserGroup.cc_pair_id == cc_pair_id, ) ) if existing_public_group: # Update existing record existing_public_group.stale = False else: # Insert new record new_public_group = PublicExternalUserGroup( external_user_group_id=external_group_id, cc_pair_id=cc_pair_id, stale=False, ) db_session.add(new_public_group) db_session.commit() def remove_stale_external_groups( db_session: Session, cc_pair_id: int, ) -> None: db_session.execute( delete(User__ExternalUserGroupId).where( 
User__ExternalUserGroupId.cc_pair_id == cc_pair_id, User__ExternalUserGroupId.stale.is_(True), ) ) db_session.execute( delete(PublicExternalUserGroup).where( PublicExternalUserGroup.cc_pair_id == cc_pair_id, PublicExternalUserGroup.stale.is_(True), ) ) db_session.commit() def fetch_external_groups_for_user( db_session: Session, user_id: UUID, ) -> Sequence[User__ExternalUserGroupId]: return db_session.scalars( select(User__ExternalUserGroupId).where( User__ExternalUserGroupId.user_id == user_id ) ).all() def fetch_external_groups_for_user_email_and_group_ids( db_session: Session, user_email: str, group_ids: list[str], ) -> list[User__ExternalUserGroupId]: user = get_user_by_email(db_session=db_session, email=user_email) if user is None: return [] user_id = user.id user_ext_groups = db_session.scalars( select(User__ExternalUserGroupId).where( User__ExternalUserGroupId.user_id == user_id, User__ExternalUserGroupId.external_user_group_id.in_(group_ids), ) ).all() return list(user_ext_groups) def fetch_public_external_group_ids( db_session: Session, ) -> list[str]: return list( db_session.scalars(select(PublicExternalUserGroup.external_user_group_id)).all() ) ================================================ FILE: backend/ee/onyx/db/hierarchy.py ================================================ """EE version of hierarchy node access control. This module provides permission-aware hierarchy node access for Enterprise Edition. It filters hierarchy nodes based on user email and external group membership. 
""" from sqlalchemy import any_ from sqlalchemy import cast from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import String from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ColumnElement from onyx.configs.constants import DocumentSource from onyx.db.models import HierarchyNode def _build_hierarchy_access_filter( user_email: str, external_group_ids: list[str], ) -> ColumnElement[bool]: """Build SQLAlchemy filter for hierarchy node access. A user can access a hierarchy node if any of the following are true: - The node is marked as public (is_public=True) - The user's email is in the node's external_user_emails list - Any of the user's external group IDs overlap with the node's external_user_group_ids """ access_filters: list[ColumnElement[bool]] = [HierarchyNode.is_public.is_(True)] if user_email: access_filters.append(any_(HierarchyNode.external_user_emails) == user_email) if external_group_ids: access_filters.append( HierarchyNode.external_user_group_ids.overlap( cast(postgresql.array(external_group_ids), postgresql.ARRAY(String)) ) ) return or_(*access_filters) def _get_accessible_hierarchy_nodes_for_source( db_session: Session, source: DocumentSource, user_email: str, external_group_ids: list[str], ) -> list[HierarchyNode]: """ EE version: Returns hierarchy nodes filtered by user permissions. 
    A user can access a hierarchy node if any of the following are true:
    - The node is marked as public (is_public=True)
    - The user's email is in the node's external_user_emails list
    - Any of the user's external group IDs overlap with the node's external_user_group_ids

    Args:
        db_session: SQLAlchemy session
        source: Document source type
        user_email: User's email for permission checking
        external_group_ids: User's external group IDs for permission checking

    Returns:
        List of HierarchyNode objects the user has access to
    """
    stmt = select(HierarchyNode).where(HierarchyNode.source == source)
    stmt = stmt.where(_build_hierarchy_access_filter(user_email, external_group_ids))
    stmt = stmt.order_by(HierarchyNode.display_name)
    return list(db_session.execute(stmt).scalars().all())

================================================
FILE: backend/ee/onyx/db/license.py
================================================
"""Database and cache operations for the license table."""

from datetime import datetime
from typing import NamedTuple

from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session

from ee.onyx.server.license.models import LicenseMetadata
from ee.onyx.server.license.models import LicensePayload
from ee.onyx.server.license.models import LicenseSource
from onyx.auth.schemas import UserRole
from onyx.cache.factory import get_cache_backend
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.db.models import License
from onyx.db.models import User
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

LICENSE_METADATA_KEY = "license:metadata"
LICENSE_CACHE_TTL_SECONDS = 86400  # 24 hours


class SeatAvailabilityResult(NamedTuple):
    """Result of a seat availability check."""

    available: bool
    error_message: str | None = None


# -----------------------------------------------------------------------------
# Database CRUD Operations
# -----------------------------------------------------------------------------


def get_license(db_session: Session) -> License | None:
    """
    Get the current license (singleton pattern - only one row).

    Args:
        db_session: Database session

    Returns:
        License object if exists, None otherwise
    """
    return db_session.execute(select(License)).scalars().first()


def upsert_license(db_session: Session, license_data: str) -> License:
    """
    Insert or update the license (singleton pattern).

    Args:
        db_session: Database session
        license_data: Base64-encoded signed license blob

    Returns:
        The created or updated License object
    """
    existing = get_license(db_session)

    if existing:
        existing.license_data = license_data
        db_session.commit()
        db_session.refresh(existing)
        logger.info("License updated")
        return existing

    new_license = License(license_data=license_data)
    db_session.add(new_license)
    db_session.commit()
    db_session.refresh(new_license)
    logger.info("License created")
    return new_license


def delete_license(db_session: Session) -> bool:
    """
    Delete the current license.

    Args:
        db_session: Database session

    Returns:
        True if deleted, False if no license existed
    """
    existing = get_license(db_session)
    if existing:
        db_session.delete(existing)
        db_session.commit()
        logger.info("License deleted")
        return True
    return False


# -----------------------------------------------------------------------------
# Seat Counting
# -----------------------------------------------------------------------------


def get_used_seats(tenant_id: str | None = None) -> int:
    """
    Get current seat usage directly from database.

    For multi-tenant: counts users in UserTenantMapping for this tenant.
    For self-hosted: counts all active users (excludes EXT_PERM_USER role
    and the anonymous system user).

    TODO: Exclude API key dummy users from seat counting. API keys create
    users with emails like `__DANSWER_API_KEY_*` that should not count toward
    seat limits.
    See: https://linear.app/onyx-app/issue/ENG-3518
    """
    if MULTI_TENANT:
        from ee.onyx.server.tenants.user_mapping import get_tenant_count

        return get_tenant_count(tenant_id or get_current_tenant_id())
    else:
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        with get_session_with_current_tenant() as db_session:
            result = db_session.execute(
                select(func.count())
                .select_from(User)
                .where(
                    User.is_active == True,  # type: ignore  # noqa: E712
                    User.role != UserRole.EXT_PERM_USER,
                    User.email != ANONYMOUS_USER_EMAIL,  # type: ignore
                )
            )
            return result.scalar() or 0


# -----------------------------------------------------------------------------
# Redis Cache Operations
# -----------------------------------------------------------------------------


def get_cached_license_metadata(tenant_id: str | None = None) -> LicenseMetadata | None:
    """
    Get license metadata from cache.

    Args:
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        LicenseMetadata if cached, None otherwise
    """
    cache = get_cache_backend(tenant_id=tenant_id)
    cached = cache.get(LICENSE_METADATA_KEY)
    if not cached:
        return None

    try:
        cached_str = (
            cached.decode("utf-8") if isinstance(cached, bytes) else str(cached)
        )
        return LicenseMetadata.model_validate_json(cached_str)
    except Exception as e:
        logger.warning(f"Failed to parse cached license metadata: {e}")
        return None


def invalidate_license_cache(tenant_id: str | None = None) -> None:
    """
    Invalidate the license metadata cache (not the license itself).

    Deletes the cached LicenseMetadata. The actual license in the database
    is not affected. Delete is idempotent — if the key doesn't exist, this
    is a no-op.
    Args:
        tenant_id: Tenant ID (for multi-tenant deployments)
    """
    cache = get_cache_backend(tenant_id=tenant_id)
    cache.delete(LICENSE_METADATA_KEY)
    logger.info("License cache invalidated")


def update_license_cache(
    payload: LicensePayload,
    source: LicenseSource | None = None,
    grace_period_end: datetime | None = None,
    tenant_id: str | None = None,
) -> LicenseMetadata:
    """
    Update the cache with license metadata.

    We cache all license statuses (ACTIVE, GRACE_PERIOD, GATED_ACCESS) because:
    1. Frontend needs status to show appropriate UI/banners
    2. Caching avoids repeated DB + crypto verification on every request
    3. Status enforcement happens at the feature level, not here

    Args:
        payload: Verified license payload
        source: How the license was obtained
        grace_period_end: Optional grace period end time
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        The cached LicenseMetadata
    """
    from ee.onyx.utils.license import get_license_status

    tenant = tenant_id or get_current_tenant_id()
    cache = get_cache_backend(tenant_id=tenant_id)

    used_seats = get_used_seats(tenant)
    status = get_license_status(payload, grace_period_end)

    metadata = LicenseMetadata(
        tenant_id=payload.tenant_id,
        organization_name=payload.organization_name,
        seats=payload.seats,
        used_seats=used_seats,
        plan_type=payload.plan_type,
        issued_at=payload.issued_at,
        expires_at=payload.expires_at,
        grace_period_end=grace_period_end,
        status=status,
        source=source,
        stripe_subscription_id=payload.stripe_subscription_id,
    )

    cache.set(
        LICENSE_METADATA_KEY,
        metadata.model_dump_json(),
        ex=LICENSE_CACHE_TTL_SECONDS,
    )

    logger.info(f"License cache updated: {metadata.seats} seats, status={status.value}")
    return metadata


def refresh_license_cache(
    db_session: Session,
    tenant_id: str | None = None,
) -> LicenseMetadata | None:
    """
    Refresh the license cache from the database.
    Args:
        db_session: Database session
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        LicenseMetadata if license exists, None otherwise
    """
    from ee.onyx.utils.license import verify_license_signature

    license_record = get_license(db_session)
    if not license_record:
        invalidate_license_cache(tenant_id)
        return None

    try:
        payload = verify_license_signature(license_record.license_data)
        # Derive source from payload: manual licenses lack stripe_customer_id
        source: LicenseSource = (
            LicenseSource.AUTO_FETCH
            if payload.stripe_customer_id
            else LicenseSource.MANUAL_UPLOAD
        )
        return update_license_cache(
            payload,
            source=source,
            tenant_id=tenant_id,
        )
    except ValueError as e:
        logger.error(f"Failed to verify license during cache refresh: {e}")
        invalidate_license_cache(tenant_id)
        return None


def get_license_metadata(
    db_session: Session,
    tenant_id: str | None = None,
) -> LicenseMetadata | None:
    """
    Get license metadata, using cache if available.

    Args:
        db_session: Database session
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        LicenseMetadata if license exists, None otherwise
    """
    # Try cache first
    cached = get_cached_license_metadata(tenant_id)
    if cached:
        return cached

    # Refresh from database
    return refresh_license_cache(db_session, tenant_id)


def check_seat_availability(
    db_session: Session,
    seats_needed: int = 1,
    tenant_id: str | None = None,
) -> SeatAvailabilityResult:
    """
    Check if there are enough seats available to add users.

    Args:
        db_session: Database session
        seats_needed: Number of seats needed (default 1)
        tenant_id: Tenant ID (for multi-tenant deployments)

    Returns:
        SeatAvailabilityResult with available=True if seats are available,
        or available=False with error_message if limit would be exceeded.
        Returns available=True if no license exists (self-hosted = unlimited).
    """
    metadata = get_license_metadata(db_session, tenant_id)

    # No license = no enforcement (self-hosted without license)
    if metadata is None:
        return SeatAvailabilityResult(available=True)

    # Calculate current usage directly from DB (not cache) for accuracy
    current_used = get_used_seats(tenant_id)
    total_seats = metadata.seats

    # Use > (not >=) to allow filling to exactly 100% capacity
    would_exceed_limit = current_used + seats_needed > total_seats
    if would_exceed_limit:
        return SeatAvailabilityResult(
            available=False,
            error_message=f"Seat limit would be exceeded: {current_used} of {total_seats} seats used, "
            f"cannot add {seats_needed} more user(s).",
        )

    return SeatAvailabilityResult(available=True)

================================================
FILE: backend/ee/onyx/db/persona.py
================================================
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.configs.constants import NotificationType
from onyx.db.models import Persona
from onyx.db.models import Persona__User
from onyx.db.models import Persona__UserGroup
from onyx.db.notification import create_notification
from onyx.db.persona import mark_persona_user_files_for_sync
from onyx.server.features.persona.models import PersonaSharedNotificationData


def update_persona_access(
    persona_id: int,
    creator_user_id: UUID | None,
    db_session: Session,
    is_public: bool | None = None,
    user_ids: list[UUID] | None = None,
    group_ids: list[int] | None = None,
) -> None:
    """Updates the access settings for a persona including public status, user shares,
    and group shares.

    NOTE: This function batches all updates. If we don't dedupe the inputs,
    the commit will raise an exception.
    NOTE: Callers are responsible for committing."""
    needs_sync = False
    if is_public is not None:
        needs_sync = True
        persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
        if persona:
            persona.is_public = is_public

    # NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
    # and a non-empty list means "replace with these shares".

    if user_ids is not None:
        needs_sync = True
        db_session.query(Persona__User).filter(
            Persona__User.persona_id == persona_id
        ).delete(synchronize_session="fetch")

        user_ids_set = set(user_ids)
        for user_id in user_ids_set:
            db_session.add(Persona__User(persona_id=persona_id, user_id=user_id))
            if user_id != creator_user_id:
                create_notification(
                    user_id=user_id,
                    notif_type=NotificationType.PERSONA_SHARED,
                    title="A new agent was shared with you!",
                    db_session=db_session,
                    additional_data=PersonaSharedNotificationData(
                        persona_id=persona_id,
                    ).model_dump(),
                )

    if group_ids is not None:
        needs_sync = True
        db_session.query(Persona__UserGroup).filter(
            Persona__UserGroup.persona_id == persona_id
        ).delete(synchronize_session="fetch")

        group_ids_set = set(group_ids)
        for group_id in group_ids_set:
            db_session.add(
                Persona__UserGroup(persona_id=persona_id, user_group_id=group_id)
            )

    # When sharing changes, user file ACLs need to be updated in the vector DB
    if needs_sync:
        mark_persona_user_files_for_sync(persona_id, db_session)

================================================
FILE: backend/ee/onyx/db/query_history.py
================================================
from collections.abc import Sequence
from datetime import datetime

from sqlalchemy import asc
from sqlalchemy import BinaryExpression
from sqlalchemy import ColumnElement
from sqlalchemy import desc
from sqlalchemy import distinct
from sqlalchemy.orm import contains_eager
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from sqlalchemy.sql import case
from sqlalchemy.sql import func
from sqlalchemy.sql import select
from sqlalchemy.sql.expression import literal
from sqlalchemy.sql.expression import UnaryExpression

from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX
from onyx.configs.constants import QAFeedbackType
from onyx.db.models import ChatMessage
from onyx.db.models import ChatMessageFeedback
from onyx.db.models import ChatSession
from onyx.db.models import TaskQueueState
from onyx.db.tasks import get_all_tasks_with_prefix


def _build_filter_conditions(
    start_time: datetime | None,
    end_time: datetime | None,
    feedback_filter: QAFeedbackType | None,
) -> list[ColumnElement]:
    """
    Helper function to build all filter conditions for chat sessions.
    Filters by start and end time, feedback type, and any sessions without messages.
    start_time: Date from which to filter
    end_time: Date to which to filter
    feedback_filter: Feedback type to filter by
    Returns: List of filter conditions
    """
    conditions = []

    if start_time is not None:
        conditions.append(ChatSession.time_created >= start_time)
    if end_time is not None:
        conditions.append(ChatSession.time_created <= end_time)

    if feedback_filter is not None:
        feedback_subq = (
            select(ChatMessage.chat_session_id)
            .join(ChatMessageFeedback)
            .group_by(ChatMessage.chat_session_id)
            .having(
                case(
                    (
                        case(
                            {literal(feedback_filter == QAFeedbackType.LIKE): True},
                            else_=False,
                        ),
                        func.bool_and(ChatMessageFeedback.is_positive),
                    ),
                    (
                        case(
                            {literal(feedback_filter == QAFeedbackType.DISLIKE): True},
                            else_=False,
                        ),
                        func.bool_and(func.not_(ChatMessageFeedback.is_positive)),
                    ),
                    else_=func.bool_or(ChatMessageFeedback.is_positive)
                    & func.bool_or(func.not_(ChatMessageFeedback.is_positive)),
                )
            )
        )
        conditions.append(ChatSession.id.in_(feedback_subq))

    return conditions


def get_total_filtered_chat_sessions_count(
    db_session: Session,
    start_time: datetime | None,
    end_time: datetime | None,
    feedback_filter: QAFeedbackType | None,
) -> int:
    conditions = _build_filter_conditions(start_time, end_time, feedback_filter)
    stmt = (
        select(func.count(distinct(ChatSession.id)))
        .select_from(ChatSession)
        .filter(*conditions)
    )
    return db_session.scalar(stmt) or 0


def get_page_of_chat_sessions(
    start_time: datetime | None,
    end_time: datetime | None,
    db_session: Session,
    page_num: int,
    page_size: int,
    feedback_filter: QAFeedbackType | None = None,
) -> Sequence[ChatSession]:
    conditions = _build_filter_conditions(start_time, end_time, feedback_filter)

    subquery = (
        select(ChatSession.id)
        .filter(*conditions)
        .order_by(desc(ChatSession.time_created), ChatSession.id)
        .limit(page_size)
        .offset(page_num * page_size)
        .subquery()
    )

    stmt = (
        select(ChatSession)
        .join(subquery, ChatSession.id == subquery.c.id)
        .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
        .options(
            joinedload(ChatSession.user),
            joinedload(ChatSession.persona),
            contains_eager(ChatSession.messages).joinedload(
                ChatMessage.chat_message_feedbacks
            ),
        )
        .order_by(
            desc(ChatSession.time_created),
            ChatSession.id,
            asc(ChatMessage.id),  # Ensure chronological message order
        )
    )

    return db_session.scalars(stmt).unique().all()


def fetch_chat_sessions_eagerly_by_time(
    start: datetime,
    end: datetime,
    db_session: Session,
    limit: int | None = 500,
    initial_time: datetime | None = None,
) -> list[ChatSession]:
    """Sorted by oldest to newest, then by message id"""

    asc_time_order: UnaryExpression = asc(ChatSession.time_created)
    message_order: UnaryExpression = asc(ChatMessage.id)

    filters: list[ColumnElement | BinaryExpression] = [
        ChatSession.time_created.between(start, end)
    ]

    if initial_time:
        filters.append(ChatSession.time_created > initial_time)

    subquery = (
        db_session.query(ChatSession.id, ChatSession.time_created)
        .filter(*filters)
        .order_by(asc_time_order)
        .limit(limit)
        .subquery()
    )

    query = (
        db_session.query(ChatSession)
        .join(subquery, ChatSession.id == subquery.c.id)
        .outerjoin(ChatMessage, ChatSession.id == ChatMessage.chat_session_id)
        .options(
            joinedload(ChatSession.user),
            joinedload(ChatSession.persona),
            contains_eager(ChatSession.messages).joinedload(
                ChatMessage.chat_message_feedbacks
            ),
        )
        .order_by(asc_time_order, message_order)
    )

    chat_sessions = query.all()

    return chat_sessions


def get_all_query_history_export_tasks(
    db_session: Session,
) -> list[TaskQueueState]:
    return get_all_tasks_with_prefix(db_session, QUERY_HISTORY_TASK_NAME_PREFIX)

================================================
FILE: backend/ee/onyx/db/saml.py
================================================
import datetime
from typing import cast
from uuid import UUID

from sqlalchemy import and_
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.db.models import SamlAccount


def upsert_saml_account(
    user_id: UUID,
    cookie: str,
    db_session: Session,
    expiration_offset: int = SESSION_EXPIRE_TIME_SECONDS,
) -> datetime.datetime:
    expires_at = func.now() + datetime.timedelta(seconds=expiration_offset)

    existing_saml_acc = (
        db_session.query(SamlAccount)
        .filter(SamlAccount.user_id == user_id)
        .one_or_none()
    )

    if existing_saml_acc:
        existing_saml_acc.encrypted_cookie = cookie
        existing_saml_acc.expires_at = cast(datetime.datetime, expires_at)
        existing_saml_acc.updated_at = func.now()
        saml_acc = existing_saml_acc
    else:
        saml_acc = SamlAccount(
            user_id=user_id,
            encrypted_cookie=cookie,
            expires_at=expires_at,
        )
        db_session.add(saml_acc)

    db_session.commit()

    return saml_acc.expires_at


async def get_saml_account(
    cookie: str, async_db_session: AsyncSession
) -> SamlAccount | None:
    """NOTE: this is async, since it's used during auth
    (which is necessarily async due to FastAPI Users)"""
    stmt = (
        select(SamlAccount)
        .options(selectinload(SamlAccount.user))  # Use selectinload for collections
        .where(
            and_(
                SamlAccount.encrypted_cookie == cookie,
                SamlAccount.expires_at > func.now(),
            )
        )
    )

    result = await async_db_session.execute(stmt)
    return result.scalars().unique().one_or_none()


async def expire_saml_account(
    saml_account: SamlAccount, async_db_session: AsyncSession
) -> None:
    saml_account.expires_at = func.now()
    await async_db_session.commit()

================================================
FILE: backend/ee/onyx/db/scim.py
================================================
"""SCIM Data Access Layer.

All database operations for SCIM provisioning — token management, user
mappings, and group mappings. Extends the base DAL (see ``onyx.db.dal``).

Usage from FastAPI::

    def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL:
        return ScimDAL(db_session)

    @router.post("/tokens")
    def create_token(dal: ScimDAL = Depends(get_scim_dal)) -> ...:
        token = dal.create_token(name=..., hashed_token=..., ...)
        dal.commit()
        return token

Usage from background tasks::

    with ScimDAL.from_tenant("tenant_abc") as dal:
        mapping = dal.create_user_mapping(external_id="idp-123", user_id=uid)
        dal.commit()
"""

from __future__ import annotations

from uuid import UUID

from sqlalchemy import delete as sa_delete
from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import SQLColumnExpression
from sqlalchemy.dialects.postgresql import insert as pg_insert

from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger

logger = setup_logger()


class ScimDAL(DAL):
    """Data Access Layer
    for SCIM provisioning operations.

    Methods mutate but do NOT commit — call ``dal.commit()`` explicitly
    when you want to persist changes. This follows the existing ``_no_commit``
    convention and lets callers batch multiple operations into one transaction.
    """

    # ------------------------------------------------------------------
    # Token operations
    # ------------------------------------------------------------------

    def create_token(
        self,
        name: str,
        hashed_token: str,
        token_display: str,
        created_by_id: UUID,
    ) -> ScimToken:
        """Create a new SCIM bearer token.

        Only one token is active at a time — this method automatically revokes
        all existing active tokens before creating the new one.
        """
        # Revoke any currently active tokens
        active_tokens = list(
            self._session.scalars(
                select(ScimToken).where(ScimToken.is_active.is_(True))
            ).all()
        )
        for t in active_tokens:
            t.is_active = False

        token = ScimToken(
            name=name,
            hashed_token=hashed_token,
            token_display=token_display,
            created_by_id=created_by_id,
        )
        self._session.add(token)
        self._session.flush()
        return token

    def get_active_token(self) -> ScimToken | None:
        """Return the single currently active token, or None."""
        return self._session.scalar(
            select(ScimToken).where(ScimToken.is_active.is_(True))
        )

    def get_token_by_hash(self, hashed_token: str) -> ScimToken | None:
        """Look up a token by its SHA-256 hash."""
        return self._session.scalar(
            select(ScimToken).where(ScimToken.hashed_token == hashed_token)
        )

    def revoke_token(self, token_id: int) -> None:
        """Deactivate a token by ID.

        Raises:
            ValueError: If the token does not exist.
        """
        token = self._session.get(ScimToken, token_id)
        if not token:
            raise ValueError(f"SCIM token with id {token_id} not found")
        token.is_active = False

    def update_token_last_used(self, token_id: int) -> None:
        """Update the last_used_at timestamp for a token."""
        token = self._session.get(ScimToken, token_id)
        if token:
            token.last_used_at = func.now()  # type: ignore[assignment]

    # ------------------------------------------------------------------
    # User mapping operations
    # ------------------------------------------------------------------

    def create_user_mapping(
        self,
        external_id: str | None,
        user_id: UUID,
        scim_username: str | None = None,
        fields: ScimMappingFields | None = None,
    ) -> ScimUserMapping:
        """Create a SCIM mapping for a user.

        ``external_id`` may be ``None`` when the IdP omits it (RFC 7643
        allows this). The mapping still marks the user as SCIM-managed.
        """
        f = fields or ScimMappingFields()
        mapping = ScimUserMapping(
            external_id=external_id,
            user_id=user_id,
            scim_username=scim_username,
            department=f.department,
            manager=f.manager,
            given_name=f.given_name,
            family_name=f.family_name,
            scim_emails_json=f.scim_emails_json,
        )
        self._session.add(mapping)
        self._session.flush()
        return mapping

    def get_user_mapping_by_external_id(
        self, external_id: str
    ) -> ScimUserMapping | None:
        """Look up a user mapping by the IdP's external identifier."""
        return self._session.scalar(
            select(ScimUserMapping).where(ScimUserMapping.external_id == external_id)
        )

    def get_user_mapping_by_user_id(self, user_id: UUID) -> ScimUserMapping | None:
        """Look up a user mapping by the Onyx user ID."""
        return self._session.scalar(
            select(ScimUserMapping).where(ScimUserMapping.user_id == user_id)
        )

    def list_user_mappings(
        self,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[ScimUserMapping], int]:
        """List user mappings with SCIM-style pagination.

        Args:
            start_index: 1-based start index (SCIM convention).
            count: Maximum number of results to return.

        Returns:
            A tuple of (mappings, total_count).
        """
        total = (
            self._session.scalar(select(func.count()).select_from(ScimUserMapping))
            or 0
        )

        offset = max(start_index - 1, 0)
        mappings = list(
            self._session.scalars(
                select(ScimUserMapping)
                .order_by(ScimUserMapping.id)
                .offset(offset)
                .limit(count)
            ).all()
        )

        return mappings, total

    def update_user_mapping_external_id(
        self,
        mapping_id: int,
        external_id: str,
    ) -> ScimUserMapping:
        """Update the external ID on a user mapping.

        Raises:
            ValueError: If the mapping does not exist.
        """
        mapping = self._session.get(ScimUserMapping, mapping_id)
        if not mapping:
            raise ValueError(f"SCIM user mapping with id {mapping_id} not found")
        mapping.external_id = external_id
        return mapping

    def delete_user_mapping(self, mapping_id: int) -> None:
        """Delete a user mapping by ID. No-op if already deleted."""
        mapping = self._session.get(ScimUserMapping, mapping_id)
        if not mapping:
            logger.warning("SCIM user mapping %d not found during delete", mapping_id)
            return
        self._session.delete(mapping)

    # ------------------------------------------------------------------
    # User query operations
    # ------------------------------------------------------------------

    def get_user(self, user_id: UUID) -> User | None:
        """Fetch a user by ID."""
        return self._session.scalar(
            select(User).where(User.id == user_id)  # type: ignore[arg-type]
        )

    def get_user_by_email(self, email: str) -> User | None:
        """Fetch a user by email (case-insensitive)."""
        return self._session.scalar(
            select(User).where(func.lower(User.email) == func.lower(email))
        )

    def add_user(self, user: User) -> None:
        """Add a new user to the session and flush to assign an ID."""
        self._session.add(user)
        self._session.flush()

    def update_user(
        self,
        user: User,
        *,
        email: str | None = None,
        is_active: bool | None = None,
        personal_name: str | None = None,
    ) -> None:
        """Update user attributes.
        Only sets fields that are provided."""
        if email is not None:
            user.email = email
        if is_active is not None:
            user.is_active = is_active
        if personal_name is not None:
            user.personal_name = personal_name

    def deactivate_user(self, user: User) -> None:
        """Mark a user as inactive."""
        user.is_active = False

    def list_users(
        self,
        scim_filter: ScimFilter | None,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
        """Query users with optional SCIM filter and pagination.

        Returns:
            A tuple of (list of (user, mapping) pairs, total_count).

        Raises:
            ValueError: If the filter uses an unsupported attribute.
        """
        # Inner-join with ScimUserMapping so only SCIM-managed users appear.
        # Pre-existing system accounts (anonymous, admin, etc.) are excluded
        # unless they were explicitly linked via SCIM provisioning.
        query = (
            select(User)
            .join(ScimUserMapping, ScimUserMapping.user_id == User.id)
            .where(
                User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
            )
        )

        if scim_filter:
            attr = scim_filter.attribute.lower()
            if attr == "username":
                # arg-type: fastapi-users types User.email as str, not a column expression
                # assignment: union return type widens but query is still Select[tuple[User]]
                query = _apply_scim_string_op(query, User.email, scim_filter)  # type: ignore[arg-type, assignment]
            elif attr == "active":
                query = query.where(
                    User.is_active.is_(scim_filter.value.lower() == "true")  # type: ignore[attr-defined]
                )
            elif attr == "externalid":
                mapping = self.get_user_mapping_by_external_id(scim_filter.value)
                if not mapping:
                    return [], 0
                query = query.where(User.id == mapping.user_id)  # type: ignore[arg-type]
            else:
                raise ValueError(
                    f"Unsupported filter attribute: {scim_filter.attribute}"
                )

        # Count total matching rows first, then paginate. SCIM uses 1-based
        # indexing (RFC 7644 §3.4.2), so we convert to a 0-based offset.
        total = (
            self._session.scalar(select(func.count()).select_from(query.subquery()))
            or 0
        )

        offset = max(start_index - 1, 0)
        users = list(
            self._session.scalars(
                query.order_by(User.id).offset(offset).limit(count)  # type: ignore[arg-type]
            )
            .unique()
            .all()
        )

        # Batch-fetch SCIM mappings to avoid N+1 queries
        mapping_map = self._get_user_mappings_batch([u.id for u in users])
        return [(u, mapping_map.get(u.id)) for u in users], total

    def sync_user_external_id(
        self,
        user_id: UUID,
        new_external_id: str | None,
        scim_username: str | None = None,
        fields: ScimMappingFields | None = None,
    ) -> None:
        """Sync the SCIM mapping for a user.

        If a mapping already exists, its fields are updated (including
        setting ``external_id`` to ``None`` when the IdP omits it).

        If no mapping exists and ``new_external_id`` is provided, a new
        mapping is created. A mapping is never deleted here — SCIM-managed
        users must retain their mapping to remain visible in ``GET /Users``.

        When *fields* is provided, all mapping fields are written
        unconditionally — including ``None`` values — so that a caller can
        clear a previously-set field (e.g. removing a department).
        """
        mapping = self.get_user_mapping_by_user_id(user_id)
        if mapping:
            if mapping.external_id != new_external_id:
                mapping.external_id = new_external_id
            if scim_username is not None:
                mapping.scim_username = scim_username
            if fields is not None:
                mapping.department = fields.department
                mapping.manager = fields.manager
                mapping.given_name = fields.given_name
                mapping.family_name = fields.family_name
                mapping.scim_emails_json = fields.scim_emails_json
        elif new_external_id:
            self.create_user_mapping(
                external_id=new_external_id,
                user_id=user_id,
                scim_username=scim_username,
                fields=fields,
            )

    def _get_user_mappings_batch(
        self, user_ids: list[UUID]
    ) -> dict[UUID, ScimUserMapping]:
        """Batch-fetch SCIM user mappings keyed by user ID."""
        if not user_ids:
            return {}
        mappings = self._session.scalars(
            select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
        ).all()
        return {m.user_id: m for m in mappings}

    def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
        """Get groups a user belongs to as ``(group_id, group_name)`` pairs.

        Excludes groups marked for deletion.
        """
        rels = self._session.scalars(
            select(User__UserGroup).where(User__UserGroup.user_id == user_id)
        ).all()
        group_ids = [r.user_group_id for r in rels]
        if not group_ids:
            return []
        groups = self._session.scalars(
            select(UserGroup).where(
                UserGroup.id.in_(group_ids),
                UserGroup.is_up_for_deletion.is_(False),
            )
        ).all()
        return [(g.id, g.name) for g in groups]

    def get_users_groups_batch(
        self, user_ids: list[UUID]
    ) -> dict[UUID, list[tuple[int, str]]]:
        """Batch-fetch group memberships for multiple users.

        Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
        Avoids N+1 queries when building user list responses.
        """
        if not user_ids:
            return {}
        rels = self._session.scalars(
            select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
        ).all()
        group_ids = list({r.user_group_id for r in rels})
        if not group_ids:
            return {}
        groups = self._session.scalars(
            select(UserGroup).where(
                UserGroup.id.in_(group_ids),
                UserGroup.is_up_for_deletion.is_(False),
            )
        ).all()
        groups_by_id = {g.id: g.name for g in groups}
        result: dict[UUID, list[tuple[int, str]]] = {}
        for r in rels:
            if r.user_id and r.user_group_id in groups_by_id:
                result.setdefault(r.user_id, []).append(
                    (r.user_group_id, groups_by_id[r.user_group_id])
                )
        return result

    # ------------------------------------------------------------------
    # Group mapping operations
    # ------------------------------------------------------------------

    def create_group_mapping(
        self,
        external_id: str,
        user_group_id: int,
    ) -> ScimGroupMapping:
        """Create a mapping between a SCIM externalId and an Onyx user group."""
        mapping = ScimGroupMapping(external_id=external_id, user_group_id=user_group_id)
        self._session.add(mapping)
        self._session.flush()
        return mapping

    def get_group_mapping_by_external_id(
        self, external_id: str
    ) -> ScimGroupMapping | None:
        """Look up a group mapping by the IdP's external identifier."""
        return self._session.scalar(
            select(ScimGroupMapping).where(ScimGroupMapping.external_id == external_id)
        )

    def get_group_mapping_by_group_id(
        self, user_group_id: int
    ) -> ScimGroupMapping | None:
        """Look up a group mapping by the Onyx user group ID."""
        return self._session.scalar(
            select(ScimGroupMapping).where(
                ScimGroupMapping.user_group_id == user_group_id
            )
        )

    def list_group_mappings(
        self,
        start_index: int = 1,
        count: int = 100,
    ) -> tuple[list[ScimGroupMapping], int]:
        """List group mappings with SCIM-style pagination.

        Args:
            start_index: 1-based start index (SCIM convention).
            count: Maximum number of results to return.

        Returns:
            A tuple of (mappings, total_count).
""" total = ( self._session.scalar(select(func.count()).select_from(ScimGroupMapping)) or 0 ) offset = max(start_index - 1, 0) mappings = list( self._session.scalars( select(ScimGroupMapping) .order_by(ScimGroupMapping.id) .offset(offset) .limit(count) ).all() ) return mappings, total def delete_group_mapping(self, mapping_id: int) -> None: """Delete a group mapping by ID. No-op if already deleted.""" mapping = self._session.get(ScimGroupMapping, mapping_id) if not mapping: logger.warning("SCIM group mapping %d not found during delete", mapping_id) return self._session.delete(mapping) # ------------------------------------------------------------------ # Group query operations # ------------------------------------------------------------------ def get_group(self, group_id: int) -> UserGroup | None: """Fetch a group by ID, returning None if deleted or missing.""" group = self._session.get(UserGroup, group_id) if group and group.is_up_for_deletion: return None return group def get_group_by_name(self, name: str) -> UserGroup | None: """Fetch a group by exact name.""" return self._session.scalar(select(UserGroup).where(UserGroup.name == name)) def add_group(self, group: UserGroup) -> None: """Add a new group to the session and flush to assign an ID.""" self._session.add(group) self._session.flush() def add_permission_grant_to_group( self, group_id: int, permission: Permission, grant_source: GrantSource, ) -> None: """Grant a permission to a group and flush.""" self._session.add( PermissionGrant( group_id=group_id, permission=permission, grant_source=grant_source, ) ) self._session.flush() def update_group( self, group: UserGroup, *, name: str | None = None, ) -> None: """Update group attributes and set the modification timestamp.""" if name is not None: group.name = name group.time_last_modified_by_user = func.now() def delete_group(self, group: UserGroup) -> None: """Delete a group from the session.""" self._session.delete(group) def list_groups( self, scim_filter: 
ScimFilter | None, start_index: int = 1, count: int = 100, ) -> tuple[list[tuple[UserGroup, str | None]], int]: """Query groups with optional SCIM filter and pagination. Returns: A tuple of (list of (group, external_id) pairs, total_count). Raises: ValueError: If the filter uses an unsupported attribute. """ query = select(UserGroup).where(UserGroup.is_up_for_deletion.is_(False)) if scim_filter: attr = scim_filter.attribute.lower() if attr == "displayname": # assignment: union return type widens but query is still Select[tuple[UserGroup]] query = _apply_scim_string_op(query, UserGroup.name, scim_filter) # type: ignore[assignment] elif attr == "externalid": mapping = self.get_group_mapping_by_external_id(scim_filter.value) if not mapping: return [], 0 query = query.where(UserGroup.id == mapping.user_group_id) else: raise ValueError( f"Unsupported filter attribute: {scim_filter.attribute}" ) total = ( self._session.scalar(select(func.count()).select_from(query.subquery())) or 0 ) offset = max(start_index - 1, 0) groups = list( self._session.scalars( query.order_by(UserGroup.id).offset(offset).limit(count) ).all() ) ext_id_map = self._get_group_external_ids([g.id for g in groups]) return [(g, ext_id_map.get(g.id)) for g in groups], total def get_group_members(self, group_id: int) -> list[tuple[UUID, str | None]]: """Get group members as (user_id, email) pairs.""" rels = self._session.scalars( select(User__UserGroup).where(User__UserGroup.user_group_id == group_id) ).all() user_ids = [r.user_id for r in rels if r.user_id] if not user_ids: return [] users = ( self._session.scalars( select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined] ) .unique() .all() ) users_by_id = {u.id: u for u in users} return [ ( r.user_id, users_by_id[r.user_id].email if r.user_id in users_by_id else None, ) for r in rels if r.user_id ] def validate_member_ids(self, uuids: list[UUID]) -> list[UUID]: """Return the subset of UUIDs that don't exist as users. 
Returns an empty list if all IDs are valid. """ if not uuids: return [] existing_users = ( self._session.scalars( select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined] ) .unique() .all() ) existing_ids = {u.id for u in existing_users} return [uid for uid in uuids if uid not in existing_ids] def upsert_group_members(self, group_id: int, user_ids: list[UUID]) -> None: """Add user-group relationships, ignoring duplicates.""" if not user_ids: return self._session.execute( pg_insert(User__UserGroup) .values([{"user_id": uid, "user_group_id": group_id} for uid in user_ids]) .on_conflict_do_nothing( index_elements=[ User__UserGroup.user_group_id, User__UserGroup.user_id, ] ) ) def replace_group_members(self, group_id: int, user_ids: list[UUID]) -> None: """Replace all members of a group.""" self._session.execute( sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group_id) ) self.upsert_group_members(group_id, user_ids) def remove_group_members(self, group_id: int, user_ids: list[UUID]) -> None: """Remove specific members from a group.""" if not user_ids: return self._session.execute( sa_delete(User__UserGroup).where( User__UserGroup.user_group_id == group_id, User__UserGroup.user_id.in_(user_ids), ) ) def delete_group_with_members(self, group: UserGroup) -> None: """Remove all member relationships and delete the group.""" self._session.execute( sa_delete(User__UserGroup).where(User__UserGroup.user_group_id == group.id) ) self._session.delete(group) def sync_group_external_id( self, group_id: int, new_external_id: str | None ) -> None: """Create, update, or delete the external ID mapping for a group.""" mapping = self.get_group_mapping_by_group_id(group_id) if new_external_id: if mapping: if mapping.external_id != new_external_id: mapping.external_id = new_external_id else: self.create_group_mapping( external_id=new_external_id, user_group_id=group_id ) elif mapping: self.delete_group_mapping(mapping.id) def _get_group_external_ids(self, 
group_ids: list[int]) -> dict[int, str]: """Batch-fetch external IDs for a list of group IDs.""" if not group_ids: return {} mappings = self._session.scalars( select(ScimGroupMapping).where( ScimGroupMapping.user_group_id.in_(group_ids) ) ).all() return {m.user_group_id: m.external_id for m in mappings} # --------------------------------------------------------------------------- # Module-level helpers (used by DAL methods above) # --------------------------------------------------------------------------- def _apply_scim_string_op( query: Select[tuple[User]] | Select[tuple[UserGroup]], column: SQLColumnExpression[str], scim_filter: ScimFilter, ) -> Select[tuple[User]] | Select[tuple[UserGroup]]: """Apply a SCIM string filter operator using SQLAlchemy column operators. Handles eq (case-insensitive exact), co (contains), and sw (starts with). SQLAlchemy's operators handle LIKE-pattern escaping internally. """ val = scim_filter.value if scim_filter.operator == ScimFilterOperator.EQUAL: return query.where(func.lower(column) == val.lower()) elif scim_filter.operator == ScimFilterOperator.CONTAINS: return query.where(column.icontains(val, autoescape=True)) elif scim_filter.operator == ScimFilterOperator.STARTS_WITH: return query.where(column.istartswith(val, autoescape=True)) else: raise ValueError(f"Unsupported string filter operator: {scim_filter.operator}") ================================================ FILE: backend/ee/onyx/db/search.py ================================================ import uuid from datetime import timedelta from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.engine.time_utils import get_db_current_time from onyx.db.models import SearchQuery def create_search_query( db_session: Session, user_id: UUID, query: str, query_expansions: list[str] | None = None, ) -> SearchQuery: """Create and persist a `SearchQuery` row. 
Notes: - `SearchQuery.id` is a UUID PK without a server-side default, so we generate it. - `created_at` is filled by the DB (server_default=now()). """ search_query = SearchQuery( id=uuid.uuid4(), user_id=user_id, query=query, query_expansions=query_expansions, ) db_session.add(search_query) db_session.commit() db_session.refresh(search_query) return search_query def fetch_search_queries_for_user( db_session: Session, user_id: UUID, filter_days: int | None = None, limit: int | None = None, ) -> list[SearchQuery]: """Fetch `SearchQuery` rows for a user. Args: user_id: User UUID. filter_days: Optional time filter. If provided, only rows created within the last `filter_days` days are returned. limit: Optional max number of rows to return. """ if filter_days is not None and filter_days <= 0: raise ValueError("filter_days must be > 0") stmt = select(SearchQuery).where(SearchQuery.user_id == user_id) if filter_days is not None and filter_days > 0: cutoff = get_db_current_time(db_session) - timedelta(days=filter_days) stmt = stmt.where(SearchQuery.created_at >= cutoff) stmt = stmt.order_by(SearchQuery.created_at.desc()) if limit is not None: stmt = stmt.limit(limit) return list(db_session.scalars(stmt).all()) ================================================ FILE: backend/ee/onyx/db/standard_answer.py ================================================ import re import string from collections.abc import Sequence from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.models import StandardAnswer from onyx.db.models import StandardAnswerCategory from onyx.utils.logger import setup_logger logger = setup_logger() def check_category_validity(category_name: str) -> bool: """If a category name is too long, it should not be used (it will cause an error in Postgres as the unique constraint can only apply to entries that are less than 2704 bytes). 
Additionally, extremely long categories are not really usable / useful.""" if len(category_name) > 255: logger.error( f"Category with name '{category_name}' is too long, cannot be used" ) return False return True def insert_standard_answer_category( category_name: str, db_session: Session ) -> StandardAnswerCategory: if not check_category_validity(category_name): raise ValueError(f"Invalid category name: {category_name}") standard_answer_category = StandardAnswerCategory(name=category_name) db_session.add(standard_answer_category) db_session.commit() return standard_answer_category def insert_standard_answer( keyword: str, answer: str, category_ids: list[int], match_regex: bool, match_any_keywords: bool, db_session: Session, ) -> StandardAnswer: existing_categories = fetch_standard_answer_categories_by_ids( standard_answer_category_ids=category_ids, db_session=db_session, ) if len(existing_categories) != len(category_ids): raise ValueError(f"Some or all categories with ids {category_ids} do not exist") standard_answer = StandardAnswer( keyword=keyword, answer=answer, categories=existing_categories, active=True, match_regex=match_regex, match_any_keywords=match_any_keywords, ) db_session.add(standard_answer) db_session.commit() return standard_answer def update_standard_answer( standard_answer_id: int, keyword: str, answer: str, category_ids: list[int], match_regex: bool, match_any_keywords: bool, db_session: Session, ) -> StandardAnswer: standard_answer = db_session.scalar( select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) ) if standard_answer is None: raise ValueError(f"No standard answer with id {standard_answer_id}") existing_categories = fetch_standard_answer_categories_by_ids( standard_answer_category_ids=category_ids, db_session=db_session, ) if len(existing_categories) != len(category_ids): raise ValueError(f"Some or all categories with ids {category_ids} do not exist") standard_answer.keyword = keyword standard_answer.answer = answer 
standard_answer.categories = list(existing_categories) standard_answer.match_regex = match_regex standard_answer.match_any_keywords = match_any_keywords db_session.commit() return standard_answer def remove_standard_answer( standard_answer_id: int, db_session: Session, ) -> None: standard_answer = db_session.scalar( select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) ) if standard_answer is None: raise ValueError(f"No standard answer with id {standard_answer_id}") standard_answer.active = False db_session.commit() def update_standard_answer_category( standard_answer_category_id: int, category_name: str, db_session: Session, ) -> StandardAnswerCategory: standard_answer_category = db_session.scalar( select(StandardAnswerCategory).where( StandardAnswerCategory.id == standard_answer_category_id ) ) if standard_answer_category is None: raise ValueError( f"No standard answer category with id {standard_answer_category_id}" ) if not check_category_validity(category_name): raise ValueError(f"Invalid category name: {category_name}") standard_answer_category.name = category_name db_session.commit() return standard_answer_category def fetch_standard_answer_category( standard_answer_category_id: int, db_session: Session, ) -> StandardAnswerCategory | None: return db_session.scalar( select(StandardAnswerCategory).where( StandardAnswerCategory.id == standard_answer_category_id ) ) def fetch_standard_answer_categories_by_ids( standard_answer_category_ids: list[int], db_session: Session, ) -> Sequence[StandardAnswerCategory]: return db_session.scalars( select(StandardAnswerCategory).where( StandardAnswerCategory.id.in_(standard_answer_category_ids) ) ).all() def fetch_standard_answer_categories( db_session: Session, ) -> Sequence[StandardAnswerCategory]: return db_session.scalars(select(StandardAnswerCategory)).all() def fetch_standard_answer( standard_answer_id: int, db_session: Session, ) -> StandardAnswer | None: return db_session.scalar( 
select(StandardAnswer).where(StandardAnswer.id == standard_answer_id) ) def fetch_standard_answers(db_session: Session) -> Sequence[StandardAnswer]: return db_session.scalars( select(StandardAnswer).where(StandardAnswer.active.is_(True)) ).all() def create_initial_default_standard_answer_category(db_session: Session) -> None: default_category_id = 0 default_category_name = "General" default_category = fetch_standard_answer_category( standard_answer_category_id=default_category_id, db_session=db_session, ) if default_category is not None: if default_category.name != default_category_name: raise ValueError( "DB is not in a valid initial state. Default standard answer category does not have expected name." ) return standard_answer_category = StandardAnswerCategory( id=default_category_id, name=default_category_name, ) db_session.add(standard_answer_category) db_session.commit() def fetch_standard_answer_categories_by_names( standard_answer_category_names: list[str], db_session: Session, ) -> Sequence[StandardAnswerCategory]: return db_session.scalars( select(StandardAnswerCategory).where( StandardAnswerCategory.name.in_(standard_answer_category_names) ) ).all() def find_matching_standard_answers( id_in: list[int], query: str, db_session: Session, ) -> list[tuple[StandardAnswer, str]]: """ Returns a list of tuples, where each tuple is a StandardAnswer definition matching the query and a string representing the match (the regex match, the first matching word, or the comma-separated keyword list). If `standard_answer.match_regex` is true, the definition is considered "matched" if `re.search` finds `standard_answer.keyword` in the query (case-insensitive).
Otherwise, the definition is considered "matched" if any (when `match_any_keywords` is true) or all (otherwise) of the space-delimited tokens in `keyword` appear in `query`, ignoring case and punctuation. """ stmt = ( select(StandardAnswer) .where(StandardAnswer.active.is_(True)) .where(StandardAnswer.id.in_(id_in)) ) possible_standard_answers: Sequence[StandardAnswer] = db_session.scalars(stmt).all() matching_standard_answers: list[tuple[StandardAnswer, str]] = [] for standard_answer in possible_standard_answers: if standard_answer.match_regex: maybe_matches = re.search(standard_answer.keyword, query, re.IGNORECASE) if maybe_matches is not None: match_group = maybe_matches.group(0) matching_standard_answers.append((standard_answer, match_group)) else: # Remove punctuation and split the keyword into individual words keyword_words = set( "".join( char for char in standard_answer.keyword.lower() if char not in string.punctuation ).split() ) # Remove punctuation and split the query into individual words query_words = "".join( char for char in query.lower() if char not in string.punctuation ).split() # Match on any keyword word or require all of them, depending on match_any_keywords if standard_answer.match_any_keywords: for word in query_words: if word in keyword_words: matching_standard_answers.append((standard_answer, word)) break else: if all(word in query_words for word in keyword_words): matching_standard_answers.append( ( standard_answer, re.sub(r"\s+", ", ", standard_answer.keyword), ) ) return matching_standard_answers ================================================ FILE: backend/ee/onyx/db/token_limit.py ================================================ from collections.abc import Sequence from sqlalchemy import exists from sqlalchemy import Row from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from onyx.configs.constants import TokenRateLimitScope from onyx.db.models import TokenRateLimit from onyx.db.models import TokenRateLimit__UserGroup from
onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup from onyx.db.models import UserRole from onyx.server.token_rate_limits.models import TokenRateLimitArgs def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select: if user.role == UserRole.ADMIN: return stmt # If anonymous user, only show global/public token_rate_limits if user.is_anonymous: where_clause = TokenRateLimit.scope == TokenRateLimitScope.GLOBAL return stmt.where(where_clause) stmt = stmt.distinct() TRLimit_UG = aliased(TokenRateLimit__UserGroup) User__UG = aliased(User__UserGroup) """ Here we select token_rate_limits by relation: User -> User__UserGroup -> TokenRateLimit__UserGroup -> TokenRateLimit """ stmt = stmt.outerjoin(TRLimit_UG).outerjoin( User__UG, User__UG.user_group_id == TRLimit_UG.user_group_id, ) """ Filter token_rate_limits by: - if the user is in the user_group that owns the token_rate_limit - if the user is not a global_curator, they must also have a curator relationship to the user_group - if editing is being done, we also filter out token_rate_limits that are owned by groups that the user isn't a curator for - if we are not editing, we show all token_rate_limits in the groups the user curates """ where_clause = User__UG.user_id == user.id if user.role == UserRole.CURATOR and get_editable: where_clause &= User__UG.is_curator == True # noqa: E712 if get_editable: user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) if user.role == UserRole.CURATOR: user_groups = user_groups.where( User__UG.is_curator == True # noqa: E712 ) where_clause &= ( ~exists() .where(TRLimit_UG.rate_limit_id == TokenRateLimit.id) .where(~TRLimit_UG.user_group_id.in_(user_groups)) .correlate(TokenRateLimit) ) return stmt.where(where_clause) def fetch_all_user_group_token_rate_limits_by_group( db_session: Session, ) -> Sequence[Row[tuple[TokenRateLimit, str]]]: query = ( select(TokenRateLimit,
UserGroup.name) .join( TokenRateLimit__UserGroup, TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, ) .join(UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id) ) return db_session.execute(query).all() def insert_user_group_token_rate_limit( db_session: Session, token_rate_limit_settings: TokenRateLimitArgs, group_id: int, ) -> TokenRateLimit: token_limit = TokenRateLimit( enabled=token_rate_limit_settings.enabled, token_budget=token_rate_limit_settings.token_budget, period_hours=token_rate_limit_settings.period_hours, scope=TokenRateLimitScope.USER_GROUP, ) db_session.add(token_limit) db_session.flush() rate_limit = TokenRateLimit__UserGroup( rate_limit_id=token_limit.id, user_group_id=group_id ) db_session.add(rate_limit) db_session.commit() return token_limit def fetch_user_group_token_rate_limits_for_user( db_session: Session, group_id: int, user: User, enabled_only: bool = False, ordered: bool = True, get_editable: bool = True, ) -> Sequence[TokenRateLimit]: stmt = ( select(TokenRateLimit) .join( TokenRateLimit__UserGroup, TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, ) .where(TokenRateLimit__UserGroup.user_group_id == group_id) ) stmt = _add_user_filters(stmt, user, get_editable) if enabled_only: stmt = stmt.where(TokenRateLimit.enabled.is_(True)) if ordered: stmt = stmt.order_by(TokenRateLimit.created_at.desc()) return db_session.scalars(stmt).all() ================================================ FILE: backend/ee/onyx/db/usage_export.py ================================================ import uuid from collections.abc import Generator from datetime import datetime from typing import IO from typing import Optional from fastapi_users_db_sqlalchemy import UUID_ID from sqlalchemy import cast from sqlalchemy.dialects.postgresql import UUID from sqlalchemy.orm import Session from ee.onyx.db.query_history import fetch_chat_sessions_eagerly_by_time from ee.onyx.server.reporting.usage_export_models import ChatMessageSkeleton 
from ee.onyx.server.reporting.usage_export_models import FlowType from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata from onyx.configs.constants import MessageType from onyx.db.models import UsageReport from onyx.db.models import User from onyx.file_store.file_store import get_default_file_store # Gets skeletons of user messages in the given range, one page of sessions at a time def get_empty_chat_messages_entries__paginated( db_session: Session, period: tuple[datetime, datetime], limit: int | None = 500, initial_time: datetime | None = None, ) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]: """Returns a tuple of: (1) the creation time of the last (most recent) session fetched, which can be passed back as `initial_time` to paginate forward in time, and (2) the messages belonging to all fetched sessions. Only messages of type USER are returned. """ chat_sessions = fetch_chat_sessions_eagerly_by_time( start=period[0], end=period[1], db_session=db_session, limit=limit, initial_time=initial_time, ) message_skeletons: list[ChatMessageSkeleton] = [] for chat_session in chat_sessions: flow_type = FlowType.SLACK if chat_session.onyxbot_flow else FlowType.CHAT for message in chat_session.messages: # Only count user messages if message.message_type != MessageType.USER: continue # Get user email user_email = chat_session.user.email if chat_session.user else None # Get assistant name from the session's persona, if any assistant_name = None if chat_session.persona: assistant_name = chat_session.persona.name message_skeletons.append( ChatMessageSkeleton( message_id=message.id, chat_session_id=chat_session.id, user_id=str(chat_session.user_id) if chat_session.user_id else None, flow_type=flow_type, time_sent=message.time_sent, assistant_name=assistant_name, user_email=user_email, number_of_tokens=message.token_count, ) ) if len(chat_sessions) == 0: return None, [] return chat_sessions[-1].time_created, message_skeletons def
get_all_empty_chat_message_entries( db_session: Session, period: tuple[datetime, datetime], ) -> Generator[list[ChatMessageSkeleton], None, None]: """period is the range of time over which to fetch messages.""" initial_time: Optional[datetime] = period[0] while True: # iterate from oldest to newest time_created, message_skeletons = get_empty_chat_messages_entries__paginated( db_session, period, initial_time=initial_time, ) if not message_skeletons: return yield message_skeletons # Update initial_time for the next iteration initial_time = time_created def get_all_usage_reports(db_session: Session) -> list[UsageReportMetadata]: # Fetch all reports, then resolve requestor IDs to emails in a single query usage_reports = db_session.query(UsageReport).all() user_ids = {r.requestor_user_id for r in usage_reports if r.requestor_user_id} user_emails = { user.id: user.email for user in db_session.query(User) .filter(cast(User.id, UUID).in_(user_ids)) .all() } return [ UsageReportMetadata( report_name=r.report_name, requestor=( user_emails.get(r.requestor_user_id) if r.requestor_user_id else None ), time_created=r.time_created, period_from=r.period_from, period_to=r.period_to, ) for r in usage_reports ] def get_usage_report_data( report_display_name: str, ) -> IO: """ Get the usage report data from the file store. Args: report_display_name: The display name of the usage report. Also assumes that the file is stored with this as the ID in the file store. Returns: A binary file-like object (backed by a temporary file) containing the report data.
""" file_store = get_default_file_store() # usage report may be very large, so don't load it all into memory return file_store.read_file( file_id=report_display_name, mode="b", use_tempfile=True ) def write_usage_report( db_session: Session, report_name: str, user_id: uuid.UUID | UUID_ID | None, period: tuple[datetime, datetime] | None, ) -> UsageReport: new_report = UsageReport( report_name=report_name, requestor_user_id=user_id, period_from=period[0] if period else None, period_to=period[1] if period else None, ) db_session.add(new_report) db_session.commit() return new_report ================================================ FILE: backend/ee/onyx/db/user_group.py ================================================ from collections.abc import Sequence from operator import and_ from uuid import UUID from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from ee.onyx.server.user_group.models import SetCuratorRequest from ee.onyx.server.user_group.models import UserGroupCreate from ee.onyx.server.user_group.models import UserGroupUpdate from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import GrantSource from onyx.db.enums import Permission from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Credential from onyx.db.models import Credential__UserGroup from onyx.db.models import Document from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import DocumentSet from onyx.db.models import DocumentSet__UserGroup from onyx.db.models import FederatedConnector__DocumentSet from 
onyx.db.models import LLMProvider__UserGroup from onyx.db.models import PermissionGrant from onyx.db.models import Persona from onyx.db.models import Persona__UserGroup from onyx.db.models import TokenRateLimit__UserGroup from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup from onyx.db.models import UserGroup__ConnectorCredentialPair from onyx.db.models import UserRole from onyx.db.permissions import recompute_user_permissions__no_commit from onyx.db.users import fetch_user_by_id from onyx.utils.logger import setup_logger logger = setup_logger() def _cleanup_user__user_group_relationships__no_commit( db_session: Session, user_group_id: int, user_ids: list[UUID] | None = None, ) -> None: """NOTE: does not commit the transaction.""" where_clause = User__UserGroup.user_group_id == user_group_id if user_ids: where_clause &= User__UserGroup.user_id.in_(user_ids) user__user_group_relationships = db_session.scalars( select(User__UserGroup).where(where_clause) ).all() for user__user_group_relationship in user__user_group_relationships: db_session.delete(user__user_group_relationship) def _cleanup_credential__user_group_relationships__no_commit( db_session: Session, user_group_id: int, ) -> None: """NOTE: does not commit the transaction.""" db_session.query(Credential__UserGroup).filter( Credential__UserGroup.user_group_id == user_group_id ).delete(synchronize_session=False) def _cleanup_llm_provider__user_group_relationships__no_commit( db_session: Session, user_group_id: int ) -> None: """NOTE: does not commit the transaction.""" db_session.query(LLMProvider__UserGroup).filter( LLMProvider__UserGroup.user_group_id == user_group_id ).delete(synchronize_session=False) def _cleanup_persona__user_group_relationships__no_commit( db_session: Session, user_group_id: int ) -> None: """NOTE: does not commit the transaction.""" db_session.query(Persona__UserGroup).filter( Persona__UserGroup.user_group_id == user_group_id 
).delete(synchronize_session=False) def _cleanup_token_rate_limit__user_group_relationships__no_commit( db_session: Session, user_group_id: int ) -> None: """NOTE: does not commit the transaction.""" token_rate_limit__user_group_relationships = db_session.scalars( select(TokenRateLimit__UserGroup).where( TokenRateLimit__UserGroup.user_group_id == user_group_id ) ).all() for ( token_rate_limit__user_group_relationship ) in token_rate_limit__user_group_relationships: db_session.delete(token_rate_limit__user_group_relationship) def _cleanup_user_group__cc_pair_relationships__no_commit( db_session: Session, user_group_id: int, outdated_only: bool ) -> None: """NOTE: does not commit the transaction.""" stmt = select(UserGroup__ConnectorCredentialPair).where( UserGroup__ConnectorCredentialPair.user_group_id == user_group_id ) if outdated_only: stmt = stmt.where( UserGroup__ConnectorCredentialPair.is_current == False # noqa: E712 ) user_group__cc_pair_relationships = db_session.scalars(stmt) for user_group__cc_pair_relationship in user_group__cc_pair_relationships: db_session.delete(user_group__cc_pair_relationship) def _cleanup_document_set__user_group_relationships__no_commit( db_session: Session, user_group_id: int ) -> None: """NOTE: does not commit the transaction.""" db_session.execute( delete(DocumentSet__UserGroup).where( DocumentSet__UserGroup.user_group_id == user_group_id ) ) def validate_object_creation_for_user( db_session: Session, user: User, target_group_ids: list[int] | None = None, object_is_public: bool | None = None, object_is_perm_sync: bool | None = None, object_is_owned_by_user: bool = False, object_is_new: bool = False, ) -> None: """ All users can create/edit permission synced objects if they don't specify a group All admin actions are allowed. Curators and global curators can create public objects. 
Prevents other non-admins from creating/editing: - public objects - objects with no groups - objects that belong to a group they don't curate """ if object_is_perm_sync and not target_group_ids: return # Admins are allowed if user.role == UserRole.ADMIN: return # Allow curators and global curators to create public objects # w/o associated groups IF the object is new/owned by them if ( object_is_public and user.role in [UserRole.CURATOR, UserRole.GLOBAL_CURATOR] and (object_is_new or object_is_owned_by_user) ): return if object_is_public and user.role == UserRole.BASIC: detail = "User does not have permission to create public objects" logger.error(detail) raise HTTPException( status_code=400, detail=detail, ) if not target_group_ids: detail = "Curators must specify 1+ groups" logger.error(detail) raise HTTPException( status_code=400, detail=detail, ) user_curated_groups = fetch_user_groups_for_user( db_session=db_session, user_id=user.id, # Global curators can curate all groups they are member of only_curator_groups=user.role != UserRole.GLOBAL_CURATOR, ) user_curated_group_ids = set([group.id for group in user_curated_groups]) target_group_ids_set = set(target_group_ids) if not target_group_ids_set.issubset(user_curated_group_ids): detail = "Curators cannot control groups they don't curate" logger.error(detail) raise HTTPException( status_code=400, detail=detail, ) def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) return db_session.scalar(stmt) def _add_user_group_snapshot_eager_loads( stmt: Select, ) -> Select: """Add eager loading options needed by UserGroup.from_model snapshot creation.""" return stmt.options( selectinload(UserGroup.users), selectinload(UserGroup.user_group_relationships), selectinload(UserGroup.cc_pair_relationships) .selectinload(UserGroup__ConnectorCredentialPair.cc_pair) .options( selectinload(ConnectorCredentialPair.connector), 
selectinload(ConnectorCredentialPair.credential).selectinload( Credential.user ), ), selectinload(UserGroup.document_sets).options( selectinload(DocumentSet.connector_credential_pairs).selectinload( ConnectorCredentialPair.connector ), selectinload(DocumentSet.users), selectinload(DocumentSet.groups), selectinload(DocumentSet.federated_connectors).selectinload( FederatedConnector__DocumentSet.federated_connector ), ), selectinload(UserGroup.personas).options( selectinload(Persona.tools), selectinload(Persona.hierarchy_nodes), selectinload(Persona.attached_documents).selectinload( Document.parent_hierarchy_node ), selectinload(Persona.labels), selectinload(Persona.document_sets).options( selectinload(DocumentSet.connector_credential_pairs).selectinload( ConnectorCredentialPair.connector ), selectinload(DocumentSet.users), selectinload(DocumentSet.groups), selectinload(DocumentSet.federated_connectors).selectinload( FederatedConnector__DocumentSet.federated_connector ), ), selectinload(Persona.user), selectinload(Persona.user_files), selectinload(Persona.users), selectinload(Persona.groups), ), ) def fetch_user_groups( db_session: Session, only_up_to_date: bool = True, eager_load_for_snapshot: bool = False, include_default: bool = True, ) -> Sequence[UserGroup]: """ Fetches user groups from the database. This function retrieves a sequence of `UserGroup` objects from the database. If `only_up_to_date` is set to `True`, it filters the user groups to return only those that are marked as up-to-date (`is_up_to_date` is `True`). Args: db_session (Session): The SQLAlchemy session used to query the database. only_up_to_date (bool, optional): Flag to determine whether to filter the results to include only up to date user groups. Defaults to `True`. eager_load_for_snapshot: If True, adds eager loading for all relationships needed by UserGroup.from_model snapshot creation. include_default: If False, excludes system default groups (is_default=True). 
Returns: Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria. """ stmt = select(UserGroup) if only_up_to_date: stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712 if not include_default: stmt = stmt.where(UserGroup.is_default == False) # noqa: E712 if eager_load_for_snapshot: stmt = _add_user_group_snapshot_eager_loads(stmt) return db_session.scalars(stmt).unique().all() def fetch_user_groups_for_user( db_session: Session, user_id: UUID, only_curator_groups: bool = False, eager_load_for_snapshot: bool = False, include_default: bool = True, ) -> Sequence[UserGroup]: stmt = ( select(UserGroup) .join(User__UserGroup, User__UserGroup.user_group_id == UserGroup.id) .join(User, User.id == User__UserGroup.user_id) # type: ignore .where(User.id == user_id) # type: ignore ) if only_curator_groups: stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712 if not include_default: stmt = stmt.where(UserGroup.is_default == False) # noqa: E712 if eager_load_for_snapshot: stmt = _add_user_group_snapshot_eager_loads(stmt) return db_session.scalars(stmt).unique().all() def construct_document_id_select_by_usergroup( user_group_id: int, ) -> Select: """This returns a statement that should be executed using .yield_per() to minimize overhead. 
The primary consumers of this function are background processing task generators.""" stmt = ( select(Document.id) .join( DocumentByConnectorCredentialPair, Document.id == DocumentByConnectorCredentialPair.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join( UserGroup__ConnectorCredentialPair, UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id, ) .join( UserGroup, UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id, ) .where(UserGroup.id == user_group_id) .order_by(Document.id) ) stmt = stmt.distinct() return stmt def fetch_documents_for_user_group_paginated( db_session: Session, user_group_id: int, last_document_id: str | None = None, limit: int = 100, ) -> tuple[Sequence[Document], str | None]: stmt = ( select(Document) .join( DocumentByConnectorCredentialPair, Document.id == DocumentByConnectorCredentialPair.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join( UserGroup__ConnectorCredentialPair, UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id, ) .join( UserGroup, UserGroup__ConnectorCredentialPair.user_group_id == UserGroup.id, ) .where(UserGroup.id == user_group_id) .order_by(Document.id) .limit(limit) ) if last_document_id is not None: stmt = stmt.where(Document.id > last_document_id) stmt = stmt.distinct() documents = db_session.scalars(stmt).all() return documents, documents[-1].id if documents else None def fetch_user_groups_for_documents( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, list[str]]]: """ Fetches all user groups that have access to the given documents. 
NOTE: this doesn't include groups if the cc_pair is access type SYNC """ stmt = ( select(Document.id, func.array_agg(UserGroup.name)) .join( UserGroup__ConnectorCredentialPair, UserGroup.id == UserGroup__ConnectorCredentialPair.user_group_id, ) .join( ConnectorCredentialPair, and_( ConnectorCredentialPair.id == UserGroup__ConnectorCredentialPair.cc_pair_id, ConnectorCredentialPair.access_type != AccessType.SYNC, ), ) .join( DocumentByConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join(Document, Document.id == DocumentByConnectorCredentialPair.id) .where(Document.id.in_(document_ids)) .where(UserGroup__ConnectorCredentialPair.is_current == True) # noqa: E712 # don't include CC pairs that are being deleted # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING) .group_by(Document.id) ) return db_session.execute(stmt).all() # type: ignore def _check_user_group_is_modifiable(user_group: UserGroup) -> None: if not user_group.is_up_to_date: raise ValueError( "Specified user group is currently syncing. Wait until the current sync has finished before editing." ) def _add_user__user_group_relationships__no_commit( db_session: Session, user_group_id: int, user_ids: list[UUID] ) -> None: """NOTE: does not commit the transaction. This function is idempotent - it will skip users who are already in the group to avoid duplicate key violations during concurrent operations or re-syncs. Uses ON CONFLICT DO NOTHING to keep inserts atomic under concurrency. 
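`_add_user__user_group_relationships__no_commit` stays idempotent by letting the database discard duplicate `(user_group_id, user_id)` rows through `ON CONFLICT DO NOTHING`, which onyx builds with SQLAlchemy's `on_conflict_do_nothing(index_elements=...)`. The standard-library `sqlite3` module can demonstrate the same clause, since SQLite has supported `ON CONFLICT ... DO NOTHING` since version 3.24. This is a minimal sketch. The table name and columns are illustrative only and do not reflect onyx's real schema.

```python
import sqlite3
from uuid import UUID, uuid4


def add_users_to_group(
    conn: sqlite3.Connection, user_group_id: int, user_ids: list[UUID]
) -> None:
    """Insert (user_id, user_group_id) pairs, silently skipping existing ones."""
    if not user_ids:
        return
    conn.executemany(
        "INSERT INTO user__user_group (user_id, user_group_id) VALUES (?, ?) "
        # The conflict target must match a unique index (here, the primary key)
        "ON CONFLICT (user_group_id, user_id) DO NOTHING",
        # sqlite3 cannot bind UUID objects directly, so store them as text
        [(str(user_id), user_group_id) for user_id in user_ids],
    )


conn = sqlite3.connect(":memory:")
conn.execute(
    "CREATE TABLE user__user_group ("
    " user_id TEXT NOT NULL, user_group_id INTEGER NOT NULL,"
    " PRIMARY KEY (user_group_id, user_id))"
)
alice, bob = uuid4(), uuid4()
add_users_to_group(conn, 1, [alice])
# Re-adding alice alongside bob does not raise sqlite3.IntegrityError
add_users_to_group(conn, 1, [alice, bob])
```

The conflict is resolved inside a single statement, so two concurrent requests that add the same user cannot both pass a "not yet a member" check and then collide on insert.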
""" if not user_ids: return insert_stmt = ( insert(User__UserGroup) .values( [ {"user_id": user_id, "user_group_id": user_group_id} for user_id in user_ids ] ) .on_conflict_do_nothing( index_elements=[User__UserGroup.user_group_id, User__UserGroup.user_id] ) ) db_session.execute(insert_stmt) def _add_user_group__cc_pair_relationships__no_commit( db_session: Session, user_group_id: int, cc_pair_ids: list[int] ) -> list[UserGroup__ConnectorCredentialPair]: """NOTE: does not commit the transaction.""" relationships = [ UserGroup__ConnectorCredentialPair( user_group_id=user_group_id, cc_pair_id=cc_pair_id ) for cc_pair_id in cc_pair_ids ] db_session.add_all(relationships) return relationships def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserGroup: db_user_group = UserGroup( name=user_group.name, time_last_modified_by_user=func.now(), is_up_to_date=DISABLE_VECTOR_DB, ) db_session.add(db_user_group) db_session.flush() # give the group an ID # Every group gets the "basic" permission by default db_session.add( PermissionGrant( group_id=db_user_group.id, permission=Permission.BASIC_ACCESS, grant_source=GrantSource.SYSTEM, ) ) db_session.flush() _add_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=db_user_group.id, user_ids=user_group.user_ids, ) _add_user_group__cc_pair_relationships__no_commit( db_session=db_session, user_group_id=db_user_group.id, cc_pair_ids=user_group.cc_pair_ids, ) recompute_user_permissions__no_commit(user_group.user_ids, db_session) db_session.commit() return db_user_group def _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session: Session, user_group_id: int ) -> None: """NOTE: does not commit the transaction.""" user_group__cc_pair_relationships = db_session.scalars( select(UserGroup__ConnectorCredentialPair).where( UserGroup__ConnectorCredentialPair.user_group_id == user_group_id ) ) for user_group__cc_pair_relationship in user_group__cc_pair_relationships: 
user_group__cc_pair_relationship.is_current = False def _validate_curator_status__no_commit( db_session: Session, users: list[User], ) -> None: for user in users: # Check if the user is a curator in any of their groups curator_relationships = ( db_session.query(User__UserGroup) .filter( User__UserGroup.user_id == user.id, User__UserGroup.is_curator == True, # noqa: E712 ) .all() ) # if the user is a curator in any of their groups, set their role to CURATOR # otherwise, set their role to BASIC only if they were previously a CURATOR if curator_relationships: user.role = UserRole.CURATOR elif user.role == UserRole.CURATOR: user.role = UserRole.BASIC db_session.add(user) def remove_curator_status__no_commit(db_session: Session, user: User) -> None: stmt = ( update(User__UserGroup) .where(User__UserGroup.user_id == user.id) .values(is_curator=False) ) db_session.execute(stmt) _validate_curator_status__no_commit(db_session, [user]) def _validate_curator_relationship_update_requester( db_session: Session, user_group_id: int, user_making_change: User, ) -> None: """ This function validates that the user making the change has the necessary permissions to update the curator relationship for the target user in the given user group. 
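`_validate_curator_status__no_commit` derives a user's role from their group memberships. Any membership with `is_curator=True` makes the user a CURATOR. A CURATOR who no longer has any such membership is demoted to BASIC, and other roles are left alone. The sketch below expresses this as a pure function. It is a minimal illustration that uses plain strings in place of `UserRole` and takes the per-group `is_curator` flags as a list.

```python
def recompute_role(current_role: str, curator_flags: list[bool]) -> str:
    """Return the role implied by a user's per-group is_curator flags."""
    if any(curator_flags):
        return "curator"
    # Only demote users whose role came from group curatorship
    if current_role == "curator":
        return "basic"
    return current_role
```

As in the original, the function itself does not exempt admins. Instead, `update_user_group` filters ADMIN and GLOBAL_CURATOR users out before it calls the validator.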
""" # Admins can update curator relationships for any group if user_making_change.role == UserRole.ADMIN: return # check if the user making the change is a curator in the group they are changing the curator relationship for user_making_change_curator_groups = fetch_user_groups_for_user( db_session=db_session, user_id=user_making_change.id, # only check if the user making the change is a curator if they are a curator # otherwise, they are a global_curator and can update the curator relationship # for any group they are a member of only_curator_groups=user_making_change.role == UserRole.CURATOR, ) requestor_curator_group_ids = [ group.id for group in user_making_change_curator_groups ] if user_group_id not in requestor_curator_group_ids: raise ValueError( f"user making change {user_making_change.email} is not a curator," f" admin, or global_curator for group '{user_group_id}'" ) def _validate_curator_relationship_update_request( db_session: Session, user_group_id: int, target_user: User, ) -> None: """ This function validates that the curator_relationship_update request itself is valid. """ if target_user.role == UserRole.ADMIN: raise ValueError( f"User '{target_user.email}' is an admin and therefore has all permissions " "of a curator. If you'd like this user to only have curator permissions, " "you must update their role to BASIC then assign them to be CURATOR in the " "appropriate groups." ) elif target_user.role == UserRole.GLOBAL_CURATOR: raise ValueError( f"User '{target_user.email}' is a global_curator and therefore has all " "permissions of a curator for all groups. If you'd like this user to only " "have curator permissions for a specific group, you must update their role " "to BASIC then assign them to be CURATOR in the appropriate groups." ) elif target_user.role not in [UserRole.CURATOR, UserRole.BASIC]: raise ValueError( f"This endpoint can only be used to update the curator relationship for " "users with the CURATOR or BASIC role. 
\n" f"Target user: {target_user.email} \n" f"Target user role: {target_user.role} \n" ) # check if the target user is in the group they are changing the curator relationship for requested_user_groups = fetch_user_groups_for_user( db_session=db_session, user_id=target_user.id, only_curator_groups=False, ) group_ids = [group.id for group in requested_user_groups] if user_group_id not in group_ids: raise ValueError( f"target user {target_user.email} is not in group '{user_group_id}'" ) def update_user_curator_relationship( db_session: Session, user_group_id: int, set_curator_request: SetCuratorRequest, user_making_change: User, ) -> None: target_user = fetch_user_by_id(db_session, set_curator_request.user_id) if not target_user: raise ValueError(f"User with id '{set_curator_request.user_id}' not found") _validate_curator_relationship_update_request( db_session=db_session, user_group_id=user_group_id, target_user=target_user, ) _validate_curator_relationship_update_requester( db_session=db_session, user_group_id=user_group_id, user_making_change=user_making_change, ) logger.info( f"user_making_change={user_making_change.email if user_making_change else 'None'} is " f"updating the curator relationship for user={target_user.email} " f"in group={user_group_id} to is_curator={set_curator_request.is_curator}" ) relationship_to_update = ( db_session.query(User__UserGroup) .filter( User__UserGroup.user_group_id == user_group_id, User__UserGroup.user_id == set_curator_request.user_id, ) .first() ) if relationship_to_update: relationship_to_update.is_curator = set_curator_request.is_curator else: relationship_to_update = User__UserGroup( user_group_id=user_group_id, user_id=set_curator_request.user_id, is_curator=True, ) db_session.add(relationship_to_update) _validate_curator_status__no_commit(db_session, [target_user]) db_session.commit() def add_users_to_user_group( db_session: Session, user: User, user_group_id: int, user_ids: list[UUID], ) -> UserGroup: db_user_group = 
fetch_user_group(db_session=db_session, user_group_id=user_group_id) if db_user_group is None: raise ValueError(f"UserGroup with id '{user_group_id}' not found") missing_users = [ user_id for user_id in user_ids if fetch_user_by_id(db_session, user_id) is None ] if missing_users: raise ValueError( f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}" ) _check_user_group_is_modifiable(db_user_group) current_user_ids = [user.id for user in db_user_group.users] current_user_ids_set = set(current_user_ids) new_user_ids = [ user_id for user_id in user_ids if user_id not in current_user_ids_set ] if not new_user_ids: return db_user_group user_group_update = UserGroupUpdate( user_ids=current_user_ids + new_user_ids, cc_pair_ids=[cc_pair.id for cc_pair in db_user_group.cc_pairs], ) return update_user_group( db_session=db_session, user=user, user_group_id=user_group_id, user_group_update=user_group_update, ) def update_user_group( db_session: Session, user: User, # noqa: ARG001 user_group_id: int, user_group_update: UserGroupUpdate, ) -> UserGroup: """If successful, this can set db_user_group.is_up_to_date = False. That will be processed by check_for_vespa_user_groups_sync_task and trigger a long running background sync to Vespa. 
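`update_user_group` treats the request as the desired final state and diffs it against the current one. Only added or removed memberships touch the database, and any change to the set of cc pairs marks the group as not up to date (unless `DISABLE_VECTOR_DB` is set), which later triggers the Vespa sync. The sketch below shows only the diffing step on plain sets. The `GroupUpdateDiff` dataclass and `diff_group_update` are hypothetical names.

```python
from dataclasses import dataclass
from uuid import UUID


@dataclass(frozen=True)
class GroupUpdateDiff:
    added_user_ids: set[UUID]
    removed_user_ids: set[UUID]
    cc_pairs_updated: bool


def diff_group_update(
    current_user_ids: set[UUID],
    requested_user_ids: list[UUID],
    current_cc_pair_ids: set[int],
    requested_cc_pair_ids: list[int],
) -> GroupUpdateDiff:
    """Compute membership changes and whether the cc-pair set changed."""
    requested_users = set(requested_user_ids)
    return GroupUpdateDiff(
        added_user_ids=requested_users - current_user_ids,
        removed_user_ids=current_user_ids - requested_users,
        # Order is irrelevant: only a different set of cc pairs forces a resync
        cc_pairs_updated=current_cc_pair_ids != set(requested_cc_pair_ids),
    )
```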
""" stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) if db_user_group is None: raise ValueError(f"UserGroup with id '{user_group_id}' not found") _check_user_group_is_modifiable(db_user_group) current_user_ids = set([user.id for user in db_user_group.users]) updated_user_ids = set(user_group_update.user_ids) added_user_ids = list(updated_user_ids - current_user_ids) removed_user_ids = list(current_user_ids - updated_user_ids) if added_user_ids: missing_users = [ user_id for user_id in added_user_ids if fetch_user_by_id(db_session, user_id) is None ] if missing_users: raise ValueError( f"User(s) not found: {', '.join(str(user_id) for user_id in missing_users)}" ) # LEAVING THIS HERE FOR NOW FOR GIVING DIFFERENT ROLES # ACCESS TO DIFFERENT PERMISSIONS # if (removed_user_ids or added_user_ids) and ( # not user or user.role != UserRole.ADMIN # ): # raise ValueError("Only admins can add or remove users from user groups") if removed_user_ids: _cleanup_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id, user_ids=removed_user_ids, ) if added_user_ids: _add_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id, user_ids=added_user_ids, ) cc_pairs_updated = set([cc_pair.id for cc_pair in db_user_group.cc_pairs]) != set( user_group_update.cc_pair_ids ) if cc_pairs_updated: _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session=db_session, user_group_id=user_group_id ) _add_user_group__cc_pair_relationships__no_commit( db_session=db_session, user_group_id=db_user_group.id, cc_pair_ids=user_group_update.cc_pair_ids, ) if cc_pairs_updated and not DISABLE_VECTOR_DB: db_user_group.is_up_to_date = False removed_users = db_session.scalars( select(User).where(User.id.in_(removed_user_ids)) # type: ignore ).unique() # Filter out admin and global curator users before validating curator status users_to_validate = [ user for user in 
removed_users if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR] ] if users_to_validate: _validate_curator_status__no_commit(db_session, users_to_validate) # update "time_last_modified_by_user" to now db_user_group.time_last_modified_by_user = func.now() recompute_user_permissions__no_commit( list(set(added_user_ids) | set(removed_user_ids)), db_session ) db_session.commit() return db_user_group def rename_user_group( db_session: Session, user_group_id: int, new_name: str, ) -> UserGroup: stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) if db_user_group is None: raise ValueError(f"UserGroup with id '{user_group_id}' not found") _check_user_group_is_modifiable(db_user_group) db_user_group.name = new_name db_user_group.time_last_modified_by_user = func.now() # CC pair documents in Vespa contain the group name, so we need to # trigger a sync to update them with the new name. _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session=db_session, user_group_id=user_group_id ) if not DISABLE_VECTOR_DB: db_user_group.is_up_to_date = False db_session.commit() return db_user_group def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) -> None: stmt = select(UserGroup).where(UserGroup.id == user_group_id) db_user_group = db_session.scalar(stmt) if db_user_group is None: raise ValueError(f"UserGroup with id '{user_group_id}' not found") _check_user_group_is_modifiable(db_user_group) # Collect affected user IDs before cleanup deletes the relationships affected_user_ids: list[UUID] = [ uid for uid in db_session.execute( select(User__UserGroup.user_id).where( User__UserGroup.user_group_id == user_group_id ) ) .scalars() .all() if uid is not None ] _mark_user_group__cc_pair_relationships_outdated__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_credential__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id )
_cleanup_user__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_token_rate_limit__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_document_set__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_persona__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) _cleanup_user_group__cc_pair_relationships__no_commit( db_session=db_session, user_group_id=user_group_id, outdated_only=False, ) _cleanup_llm_provider__user_group_relationships__no_commit( db_session=db_session, user_group_id=user_group_id ) # Recompute permissions for affected users now that their # membership in this group has been removed recompute_user_permissions__no_commit(affected_user_ids, db_session) db_user_group.is_up_to_date = False db_user_group.is_up_for_deletion = True db_session.commit() def delete_user_group(db_session: Session, user_group: UserGroup) -> None: """ This assumes that all the fk cleanup has already been done. """ db_session.delete(user_group) db_session.commit() def mark_user_group_as_synced(db_session: Session, user_group: UserGroup) -> None: # cleanup outdated relationships _cleanup_user_group__cc_pair_relationships__no_commit( db_session=db_session, user_group_id=user_group.id, outdated_only=True ) user_group.is_up_to_date = True db_session.commit() def delete_user_group_cc_pair_relationship__no_commit( cc_pair_id: int, db_session: Session ) -> None: """Deletes all rows from UserGroup__ConnectorCredentialPair where the connector_credential_pair_id matches the given cc_pair_id. 
Should be used very carefully (only for connectors that are being deleted).""" cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if not cc_pair: raise ValueError(f"Connector Credential Pair '{cc_pair_id}' does not exist") if cc_pair.status != ConnectorCredentialPairStatus.DELETING: raise ValueError( f"Connector Credential Pair '{cc_pair_id}' is not in the DELETING state. status={cc_pair.status}" ) delete_stmt = delete(UserGroup__ConnectorCredentialPair).where( UserGroup__ConnectorCredentialPair.cc_pair_id == cc_pair_id, ) db_session.execute(delete_stmt) ================================================ FILE: backend/ee/onyx/document_index/vespa/app_config/cloud-services.xml.jinja ================================================ {{ document_elements }} 2 3 750 350 300 2 ================================================ FILE: backend/ee/onyx/external_permissions/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/external_permissions/confluence/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/external_permissions/confluence/constants.py ================================================ # This is a group that we use to store all the users that we found in Confluence # Instead of setting a page to public, we just add this group so that the page # is only accessible to users who have confluence accounts. 
ALL_CONF_EMAILS_GROUP_NAME = "All_Confluence_Users_Found_By_Onyx" VIEWSPACE_PERMISSION_TYPE = "VIEWSPACE" REQUEST_PAGINATION_LIMIT = 5000 ================================================ FILE: backend/ee/onyx/external_permissions/confluence/doc_sync.py ================================================ """ Rules defined here: https://confluence.atlassian.com/conf85/check-who-can-view-a-page-1283360557.html """ from collections.abc import Generator from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from ee.onyx.external_permissions.utils import generic_doc_sync from onyx.access.models import ElementExternalAccess from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.connector import ConfluenceConnector from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() CONFLUENCE_DOC_SYNC_LABEL = "confluence_doc_sync" def confluence_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, callback: IndexingHeartbeatInterface | None, ) -> Generator[ElementExternalAccess, None, None]: """ Fetches document permissions from Confluence and yields DocExternalAccess objects. Compares fetched documents against existing documents in the DB for the connector. If a document exists in the DB but not in the Confluence fetch, it's marked as restricted. 
""" confluence_connector = ConfluenceConnector( **cc_pair.connector.connector_specific_config ) provider = OnyxDBCredentialsProvider( get_current_tenant_id(), "confluence", cc_pair.credential_id ) confluence_connector.set_credentials_provider(provider) yield from generic_doc_sync( cc_pair=cc_pair, fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn, callback=callback, doc_source=DocumentSource.CONFLUENCE, slim_connector=confluence_connector, label=CONFLUENCE_DOC_SYNC_LABEL, ) ================================================ FILE: backend/ee/onyx/external_permissions/confluence/group_sync.py ================================================ from collections.abc import Generator from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME from onyx.background.error_logging import emit_background_error from onyx.configs.app_configs import CONFLUENCE_USE_ONYX_USERS_FOR_GROUP_SYNC from onyx.connectors.confluence.onyx_confluence import ( get_user_email_from_username__server, ) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import ConnectorCredentialPair from onyx.db.users import get_all_users from onyx.utils.logger import setup_logger logger = setup_logger() def _build_group_member_email_map( confluence_client: OnyxConfluence, cc_pair_id: int ) -> dict[str, set[str]]: group_member_emails: dict[str, set[str]] = {} for user in confluence_client.paginated_cql_user_retrieval(): logger.info(f"Processing groups for user: {user}") email = user.email if not email: # This field is only present in Confluence Server user_name = user.username # If it is present, try to get the email using a Server-specific method if user_name: email = get_user_email_from_username__server( confluence_client=confluence_client, 
user_name=user_name, ) else: logger.error(f"user result missing username field: {user}") if not email: # If we still don't have an email, skip this user msg = f"user result missing email field: {user}" if user.type == "app": logger.warning(msg) else: emit_background_error(msg, cc_pair_id=cc_pair_id) logger.error(msg) continue all_users_groups: set[str] = set() for group in confluence_client.paginated_groups_by_user_retrieval(user.user_id): # group name uniqueness is enforced by Confluence, so we can use it as a group ID group_id = group["name"] group_member_emails.setdefault(group_id, set()).add(email) all_users_groups.add(group_id) if not all_users_groups: msg = f"No groups found for user with email: {email}" emit_background_error(msg, cc_pair_id=cc_pair_id) logger.error(msg) else: logger.debug(f"Found groups {all_users_groups} for user with email {email}") if not group_member_emails: msg = "No groups found for any users." emit_background_error(msg, cc_pair_id=cc_pair_id) logger.error(msg) return group_member_emails def _build_group_member_email_map_from_onyx_users( confluence_client: OnyxConfluence, ) -> dict[str, set[str]]: """Hacky, but it's the only way to do this as long as the Confluence APIs are broken. This is fixed in Confluence Data Center 10.1.0, so first choice is to tell users to upgrade to 10.1.0. 
https://jira.atlassian.com/browse/CONFSERVER-95999 """ with get_session_with_current_tenant() as db_session: # don't include external since they are handled by the "through confluence" # user fetching mechanism user_emails = [ user.email for user in get_all_users(db_session, include_external=False) ] def _infer_username_from_email(email: str) -> str: return email.split("@")[0] group_member_emails: dict[str, set[str]] = {} for email in user_emails: logger.info(f"Processing groups for user with email: {email}") try: user_name = _infer_username_from_email(email) response = confluence_client.get_user_details_by_username(user_name) user_key = response.get("userKey") if not user_key: logger.error(f"User key not found for user with email {email}") continue all_users_groups: set[str] = set() for group in confluence_client.paginated_groups_by_user_retrieval(user_key): # group name uniqueness is enforced by Confluence, so we can use it as a group ID group_id = group["name"] group_member_emails.setdefault(group_id, set()).add(email) all_users_groups.add(group_id) if not all_users_groups: msg = f"No groups found for user with email: {email}" logger.error(msg) else: logger.info( f"Found groups {all_users_groups} for user with email {email}" ) except Exception: logger.exception(f"Error getting user details for user with email {email}") return group_member_emails def _build_final_group_to_member_email_map( confluence_client: OnyxConfluence, cc_pair_id: int, # if set, will infer confluence usernames from onyx users in addition to using the # confluence users API. This is a hacky workaround for the fact that the Confluence # users API is broken before Confluence Data Center 10.1.0. 
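`_build_final_group_to_member_email_map` combines two group-to-emails maps: one built from the Confluence users API and an optional one inferred from Onyx users. For every group ID present in either map, it takes the union of the member sets. The sketch below shows that merge on plain dictionaries. It is a minimal illustration, and `merge_group_maps` is a hypothetical name.

```python
def merge_group_maps(
    primary: dict[str, set[str]], secondary: dict[str, set[str]]
) -> dict[str, set[str]]:
    """Union the member emails of every group present in either map."""
    merged: dict[str, set[str]] = {}
    for group_id in primary.keys() | secondary.keys():
        # `|` builds a new set, so neither input map is mutated
        merged[group_id] = primary.get(group_id, set()) | secondary.get(
            group_id, set()
        )
    return merged
```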
use_onyx_users: bool = CONFLUENCE_USE_ONYX_USERS_FOR_GROUP_SYNC, ) -> dict[str, set[str]]: group_to_member_email_map = _build_group_member_email_map( confluence_client=confluence_client, cc_pair_id=cc_pair_id, ) group_to_member_email_map_from_onyx_users = ( ( _build_group_member_email_map_from_onyx_users( confluence_client=confluence_client, ) ) if use_onyx_users else {} ) all_group_ids = set(group_to_member_email_map.keys()) | set( group_to_member_email_map_from_onyx_users.keys() ) final_group_to_member_email_map = {} for group_id in all_group_ids: group_member_emails = group_to_member_email_map.get( group_id, set() ) | group_to_member_email_map_from_onyx_users.get(group_id, set()) final_group_to_member_email_map[group_id] = group_member_emails return final_group_to_member_email_map def confluence_group_sync( tenant_id: str, cc_pair: ConnectorCredentialPair, ) -> Generator[ExternalUserGroup, None, None]: provider = OnyxDBCredentialsProvider(tenant_id, "confluence", cc_pair.credential_id) is_cloud = cc_pair.connector.connector_specific_config.get("is_cloud", False) wiki_base: str = cc_pair.connector.connector_specific_config["wiki_base"] url = wiki_base.rstrip("/") probe_kwargs = { "max_backoff_retries": 6, "max_backoff_seconds": 10, } final_kwargs = { "max_backoff_retries": 10, "max_backoff_seconds": 60, } confluence_client = OnyxConfluence(is_cloud, url, provider) confluence_client._probe_connection(**probe_kwargs) confluence_client._initialize_connection(**final_kwargs) group_to_member_email_map = _build_final_group_to_member_email_map( confluence_client, cc_pair.id ) all_found_emails = set() for group_id, group_member_emails in group_to_member_email_map.items(): yield ( ExternalUserGroup( id=group_id, user_emails=list(group_member_emails), ) ) all_found_emails.update(group_member_emails) # This is so that when we find a public Confluence server page, we can # give access to all users, but only if they have an email in Confluence if
cc_pair.connector.connector_specific_config.get("is_cloud", False): all_found_group = ExternalUserGroup( id=ALL_CONF_EMAILS_GROUP_NAME, user_emails=list(all_found_emails), ) yield all_found_group ================================================ FILE: backend/ee/onyx/external_permissions/confluence/page_access.py ================================================ from typing import Any from onyx.access.models import ExternalAccess from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.onyx_confluence import ( get_user_email_from_username__server, ) from onyx.connectors.confluence.onyx_confluence import OnyxConfluence from onyx.utils.logger import setup_logger logger = setup_logger() def _extract_read_access_restrictions( confluence_client: OnyxConfluence, restrictions: dict[str, Any] ) -> tuple[set[str], set[str], bool]: """ Extracts the read access from a page's restrictions dict. Returns a tuple of (emails of users with read access, names of groups with read access, whether any read restriction was found). """ read_access = restrictions.get("read", {}) read_access_restrictions = read_access.get("restrictions", {}) # Extract the users with read access read_access_user = read_access_restrictions.get("user", {}) read_access_user_jsons = read_access_user.get("results", []) # any items found means that there is a restriction found_any_restriction = bool(read_access_user_jsons) read_access_user_emails = [] for user in read_access_user_jsons: # If the user has an email, then add it to the list if user.get("email"): read_access_user_emails.append(user["email"]) # If the user has a username and not an email, then get the email from Confluence elif user.get("username"): email = get_user_email_from_username__server( confluence_client=confluence_client, user_name=user["username"] ) if email: read_access_user_emails.append(email) else: logger.warning( f"Email for user {user['username']} not found in Confluence" ) else: if user.get("email") is not None:
                logger.warning(f"Can't find email for user {user.get('displayName')}")
                logger.warning(
                    "This user needs to make their email accessible in Confluence Settings"
                )

            logger.warning(f"no user email or username for {user}")

    # Extract the groups with read access
    read_access_group = read_access_restrictions.get("group", {})
    read_access_group_jsons = read_access_group.get("results", [])
    # any items found means that there is a restriction
    found_any_restriction |= bool(read_access_group_jsons)
    read_access_group_names = [
        group["name"] for group in read_access_group_jsons if group.get("name")
    ]

    return (
        set(read_access_user_emails),
        set(read_access_group_names),
        found_any_restriction,
    )


def get_page_restrictions(
    confluence_client: OnyxConfluence,
    page_id: str,
    page_restrictions: dict[str, Any],
    ancestors: list[dict[str, Any]],
    add_prefix: bool = False,
) -> ExternalAccess | None:
    """
    This function gets the restrictions for a page. In Confluence, a child can have
    at MOST the same level of accessibility as its immediate parent.

    If no restrictions are found anywhere, then return None, indicating that the page
    should inherit the space's restrictions.

    add_prefix: When True, prefix group IDs with source type (for indexing path).
                When False (default), leave unprefixed (for permission sync path).
    """
    found_user_emails: set[str] = set()
    found_group_names: set[str] = set()

    # NOTE: need the found_any_restriction, since we can find restrictions
    # but not be able to extract any user emails or group names
    # in this case, we should just give no access
    found_user_emails, found_group_names, found_any_page_level_restriction = (
        _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=page_restrictions,
        )
    )

    def _maybe_prefix_groups(group_names: set[str]) -> set[str]:
        if add_prefix:
            return {
                build_ext_group_name_for_onyx(g, DocumentSource.CONFLUENCE)
                for g in group_names
            }
        return group_names

    # if there are individual page-level restrictions, then this is the accurate
    # restriction for the page. You cannot both have page-level restrictions AND
    # inherit restrictions from the parent.
    if found_any_page_level_restriction:
        return ExternalAccess(
            external_user_emails=found_user_emails,
            external_user_group_ids=_maybe_prefix_groups(found_group_names),
            is_public=False,
        )

    # ancestors seem to be in order from root to immediate parent
    # https://community.atlassian.com/forums/Confluence-questions/Order-of-ancestors-in-REST-API-response-Confluence-Server-amp/qaq-p/2385981
    # we want the restrictions from the immediate parent to take precedence, so we should
    # reverse the list
    for ancestor in reversed(ancestors):
        (
            ancestor_user_emails,
            ancestor_group_names,
            found_any_restrictions_in_ancestor,
        ) = _extract_read_access_restrictions(
            confluence_client=confluence_client,
            restrictions=ancestor.get("restrictions", {}),
        )
        if found_any_restrictions_in_ancestor:
            # if inheriting restrictions from the parent, then the first one we run into
            # should be applied (the reason why we'd traverse more than one ancestor is if
            # the ancestor also is in "inherit" mode.)
            logger.debug(
                f"Found user restrictions {ancestor_user_emails} and group restrictions {ancestor_group_names} "
                f"for document {page_id} based on ancestor {ancestor}"
            )
            return ExternalAccess(
                external_user_emails=ancestor_user_emails,
                external_user_group_ids=_maybe_prefix_groups(ancestor_group_names),
                is_public=False,
            )

    # we didn't find any restrictions, so the page inherits the space's restrictions
    return None


================================================
FILE: backend/ee/onyx/external_permissions/confluence/space_access.py
================================================
from ee.onyx.configs.app_configs import CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC
from ee.onyx.external_permissions.confluence.constants import ALL_CONF_EMAILS_GROUP_NAME
from ee.onyx.external_permissions.confluence.constants import REQUEST_PAGINATION_LIMIT
from ee.onyx.external_permissions.confluence.constants import VIEWSPACE_PERMISSION_TYPE
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.onyx_confluence import (
    get_user_email_from_username__server,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _get_server_space_permissions(
    confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
    space_permissions = confluence_client.get_all_space_permissions_server(
        space_key=space_key
    )

    viewspace_permissions = []
    for permission_category in space_permissions:
        if permission_category.get("type") == VIEWSPACE_PERMISSION_TYPE:
            viewspace_permissions.extend(
                permission_category.get("spacePermissions", [])
            )

    is_public = False
    user_names = set()
    group_names = set()
    for permission in viewspace_permissions:
        if user_name := permission.get("userName"):
            user_names.add(user_name)
        if group_name := permission.get("groupName"):
            group_names.add(group_name)

        # It seems that if anonymous access is turned on for the site and space,
        # then the space is publicly accessible.
        # For Confluence Server, we make a group that contains all users
        # that exist in Confluence and then just add that group to the space permissions
        # if anonymous access is turned on for the site and space, or we set is_public = True
        # if they set the env variable CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC to True so
        # that we can support Confluence Server deployments that want anonymous access
        # to be public (we can't test this because it's paywalled)
        if user_name is None and group_name is None:
            # Defaults to False
            if CONFLUENCE_ANONYMOUS_ACCESS_IS_PUBLIC:
                is_public = True
            else:
                group_names.add(ALL_CONF_EMAILS_GROUP_NAME)

    user_emails = set()
    for user_name in user_names:
        user_email = get_user_email_from_username__server(confluence_client, user_name)
        if user_email:
            user_emails.add(user_email)
        else:
            logger.warning(f"Email for user {user_name} not found in Confluence")

    if not user_emails and not group_names:
        logger.warning(
            "No user emails or group names found in Confluence space permissions"
            f"\nSpace key: {space_key}"
            f"\nSpace permissions: {space_permissions}"
        )

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_names,
        is_public=is_public,
    )


def _get_cloud_space_permissions(
    confluence_client: OnyxConfluence, space_key: str
) -> ExternalAccess:
    space_permissions_result = confluence_client.get_space(
        space_key=space_key, expand="permissions"
    )
    space_permissions = space_permissions_result.get("permissions", [])

    user_emails = set()
    group_names = set()
    is_externally_public = False
    for permission in space_permissions:
        subs = permission.get("subjects")
        if subs:
            # If there are subjects, then there are explicit users or groups with access.
            # `or [{}]` guards against an empty "results" list, which would otherwise
            # raise an IndexError.
            user_results = subs.get("user", {}).get("results") or [{}]
            if email := user_results[0].get("email"):
                user_emails.add(email)
            group_results = subs.get("group", {}).get("results") or [{}]
            if group_name := group_results[0].get("name"):
                group_names.add(group_name)
        else:
            # If there are no subjects, then the permission is for everyone
            if permission.get("operation", {}).get(
                "operation"
            ) == "read" and permission.get("anonymousAccess", False):
                # If the permission specifies read access for anonymous users, then
                # the space is publicly accessible
                is_externally_public = True

    return ExternalAccess(
        external_user_emails=user_emails,
        external_user_group_ids=group_names,
        is_public=is_externally_public,
    )


def get_space_permission(
    confluence_client: OnyxConfluence,
    space_key: str,
    is_cloud: bool,
    add_prefix: bool = False,
) -> ExternalAccess:
    if is_cloud:
        space_permissions = _get_cloud_space_permissions(confluence_client, space_key)
    else:
        space_permissions = _get_server_space_permissions(confluence_client, space_key)

    if (
        not space_permissions.is_public
        and not space_permissions.external_user_emails
        and not space_permissions.external_user_group_ids
    ):
        logger.warning(
            f"No permissions found for space '{space_key}'. This is very unlikely "
            "to be correct and is more likely caused by an access token with "
            "insufficient permissions. Make sure that the access token has Admin "
            f"permissions for space '{space_key}'"
        )

    # Prefix group IDs with source type if requested (for indexing path)
    if add_prefix and space_permissions.external_user_group_ids:
        prefixed_groups = {
            build_ext_group_name_for_onyx(g, DocumentSource.CONFLUENCE)
            for g in space_permissions.external_user_group_ids
        }
        return ExternalAccess(
            external_user_emails=space_permissions.external_user_emails,
            external_user_group_ids=prefixed_groups,
            is_public=space_permissions.is_public,
        )

    return space_permissions


def get_all_space_permissions(
    confluence_client: OnyxConfluence,
    is_cloud: bool,
    add_prefix: bool = False,
) -> dict[str, ExternalAccess]:
    """
    Get access permissions for all spaces in Confluence.

    add_prefix: When True, prefix group IDs with source type (for indexing path).
                When False (default), leave unprefixed (for permission sync path).
    """
    logger.debug("Getting space permissions")
    # Gets all the spaces in the Confluence instance
    all_space_keys = [
        key
        for space in confluence_client.retrieve_confluence_spaces(
            limit=REQUEST_PAGINATION_LIMIT,
        )
        if (key := space.get("key"))
    ]

    # Gets the permissions for each space
    logger.debug(f"Got {len(all_space_keys)} spaces from confluence")
    space_permissions_by_space_key: dict[str, ExternalAccess] = {}
    for space_key in all_space_keys:
        space_permissions = get_space_permission(
            confluence_client, space_key, is_cloud, add_prefix
        )

        # Stores the permissions for each space
        space_permissions_by_space_key[space_key] = space_permissions

    return space_permissions_by_space_key


================================================
FILE: backend/ee/onyx/external_permissions/github/doc_sync.py
================================================
import json
from collections.abc import Generator

from github import Github
from github.Repository import Repository

from ee.onyx.external_permissions.github.utils import fetch_repository_team_slugs
from ee.onyx.external_permissions.github.utils import form_collaborators_group_id
from ee.onyx.external_permissions.github.utils import form_organization_group_id
from ee.onyx.external_permissions.github.utils import (
    form_outside_collaborators_group_id,
)
from ee.onyx.external_permissions.github.utils import get_external_access_permission
from ee.onyx.external_permissions.github.utils import get_repository_visibility
from ee.onyx.external_permissions.github.utils import GitHubVisibility
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.connector import DocMetadata
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()

GITHUB_DOC_SYNC_LABEL = "github_doc_sync"


def github_doc_sync(
    cc_pair: ConnectorCredentialPair,
    fetch_all_existing_docs_fn: FetchAllDocumentsFunction,
    fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction,  # noqa: ARG001
    callback: IndexingHeartbeatInterface | None = None,
) -> Generator[DocExternalAccess, None, None]:
    """
    Sync GitHub documents with external access permissions.

    This function checks each repository for visibility/team changes and updates
    document permissions accordingly without using checkpoints.
    """
    logger.info(f"Starting GitHub document sync for CC pair ID: {cc_pair.id}")

    # Initialize GitHub connector with credentials
    github_connector: GithubConnector = GithubConnector(
        **cc_pair.connector.connector_specific_config
    )

    credential_json = (
        cc_pair.credential.credential_json.get_value(apply_mask=False)
        if cc_pair.credential.credential_json
        else {}
    )

    github_connector.load_credentials(credential_json)
    logger.info("GitHub connector credentials loaded successfully")

    if not github_connector.github_client:
        logger.error("GitHub client initialization failed")
        raise ValueError("github_client is required")

    # Get all repositories from GitHub API
    logger.info("Fetching all repositories from GitHub API")
    try:
        repos = github_connector.fetch_configured_repos()
        logger.info(f"Found {len(repos)} repositories to check")
    except Exception as e:
        logger.error(f"Failed to fetch repositories: {e}")
        raise

    repo_to_doc_list_map: dict[str, list[DocumentRow]] = {}
    # sort order is ascending because we want to get the oldest documents first
    existing_docs: list[DocumentRow] = fetch_all_existing_docs_fn(
        sort_order=SortOrder.ASC
    )
    logger.info(f"Found {len(existing_docs)} documents to check")
    for doc in existing_docs:
        try:
            doc_metadata = DocMetadata.model_validate_json(json.dumps(doc.doc_metadata))
            if doc_metadata.repo not in repo_to_doc_list_map:
                repo_to_doc_list_map[doc_metadata.repo] = []
            repo_to_doc_list_map[doc_metadata.repo].append(doc)
        except Exception as e:
            logger.error(f"Failed to parse doc metadata: {e} for doc {doc.id}")
            continue
    logger.info(
        f"Found {len(repo_to_doc_list_map)} repositories with documents to check"
    )

    # Process each repository individually
    for repo in repos:
        try:
            logger.info(f"Processing repository: {repo.id} (name: {repo.name})")
            repo_doc_list: list[DocumentRow] = repo_to_doc_list_map.get(
                repo.full_name, []
            )
            if not repo_doc_list:
                logger.warning(
                    f"No documents found for repository {repo.id} ({repo.name})"
                )
                continue

            current_external_group_ids = repo_doc_list[0].external_user_group_ids or []

            # Check if repository has any permission changes
            has_changes = _check_repository_for_changes(
                repo=repo,
                github_client=github_connector.github_client,
                current_external_group_ids=current_external_group_ids,
            )

            if has_changes:
                logger.info(
                    f"Repository {repo.id} ({repo.name}) has changes, updating documents"
                )

                # Get new external access permissions for this repository
                new_external_access = get_external_access_permission(
                    repo, github_connector.github_client
                )

                logger.info(
                    f"Found {len(repo_doc_list)} documents for repository {repo.full_name}"
                )

                # Yield updated external access for each document
                for doc in repo_doc_list:
                    if callback:
                        callback.progress(GITHUB_DOC_SYNC_LABEL, 1)

                    yield DocExternalAccess(
                        doc_id=doc.id,
                        external_access=new_external_access,
                    )
            else:
                logger.info(
                    f"Repository {repo.id} ({repo.name}) has no changes, skipping"
                )

        except Exception as e:
            logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")

    logger.info(f"GitHub document sync completed for CC pair ID: {cc_pair.id}")


def _check_repository_for_changes(
    repo: Repository,
    github_client: Github,
    current_external_group_ids: list[str],
) -> bool:
    """
    Check if repository has any permission changes (visibility or team updates).
    """
    logger.info(f"Checking repository {repo.id} ({repo.name}) for changes")

    # Check for repository visibility changes using the sample document data
    if _is_repo_visibility_changed_from_groups(
        repo=repo,
        current_external_group_ids=current_external_group_ids,
    ):
        logger.info(f"Repository {repo.id} ({repo.name}) has visibility changes")
        return True

    # Check for team membership changes if repository is private
    if get_repository_visibility(
        repo
    ) == GitHubVisibility.PRIVATE and _teams_updated_from_groups(
        repo=repo,
        github_client=github_client,
        current_external_group_ids=current_external_group_ids,
    ):
        logger.info(f"Repository {repo.id} ({repo.name}) has team changes")
        return True

    logger.info(f"Repository {repo.id} ({repo.name}) has no changes")
    return False


def _is_repo_visibility_changed_from_groups(
    repo: Repository,
    current_external_group_ids: list[str],
) -> bool:
    """
    Check if repository visibility has changed by analyzing existing external group IDs.

    Args:
        repo: GitHub repository object
        current_external_group_ids: List of external group IDs from existing document

    Returns:
        True if visibility has changed
    """
    current_repo_visibility = get_repository_visibility(repo)
    logger.info(f"Current repository visibility: {current_repo_visibility.value}")

    # Build expected group IDs for current visibility
    collaborators_group_id = build_ext_group_name_for_onyx(
        source=DocumentSource.GITHUB,
        ext_group_name=form_collaborators_group_id(repo.id),
    )

    org_group_id = None
    if repo.organization:
        org_group_id = build_ext_group_name_for_onyx(
            source=DocumentSource.GITHUB,
            ext_group_name=form_organization_group_id(repo.organization.id),
        )

    # Determine existing visibility from group IDs
    has_collaborators_group = collaborators_group_id in current_external_group_ids
    has_org_group = org_group_id and org_group_id in current_external_group_ids

    if has_collaborators_group:
        existing_repo_visibility = GitHubVisibility.PRIVATE
    elif has_org_group:
        existing_repo_visibility = GitHubVisibility.INTERNAL
    else:
        existing_repo_visibility = GitHubVisibility.PUBLIC

    logger.info(f"Inferred existing visibility: {existing_repo_visibility.value}")

    visibility_changed = existing_repo_visibility != current_repo_visibility
    if visibility_changed:
        logger.info(
            f"Visibility changed for repo {repo.id} ({repo.name}): "
            f"{existing_repo_visibility.value} -> {current_repo_visibility.value}"
        )

    return visibility_changed


def _teams_updated_from_groups(
    repo: Repository,
    github_client: Github,
    current_external_group_ids: list[str],
) -> bool:
    """
    Check if repository team memberships have changed using existing group IDs.
    """
    # Fetch current team slugs for the repository
    current_teams = fetch_repository_team_slugs(repo=repo, github_client=github_client)
    logger.info(
        f"Current teams for repository {repo.id} (name: {repo.name}): {current_teams}"
    )

    # Build group IDs to exclude from team comparison (non-team groups)
    collaborators_group_id = build_ext_group_name_for_onyx(
        source=DocumentSource.GITHUB,
        ext_group_name=form_collaborators_group_id(repo.id),
    )
    outside_collaborators_group_id = build_ext_group_name_for_onyx(
        source=DocumentSource.GITHUB,
        ext_group_name=form_outside_collaborators_group_id(repo.id),
    )
    non_team_group_ids = {collaborators_group_id, outside_collaborators_group_id}

    # Extract existing team IDs from current external group IDs
    existing_team_ids = set()
    for group_id in current_external_group_ids:
        # Skip all non-team groups, keep only team groups
        if group_id not in non_team_group_ids:
            existing_team_ids.add(group_id)

    # Note: existing_team_ids from DB are already prefixed (e.g., "github__team-slug")
    # but current_teams from API are raw team slugs, so we need to add the prefix
    current_team_ids = set()
    for team_slug in current_teams:
        team_group_id = build_ext_group_name_for_onyx(
            source=DocumentSource.GITHUB,
            ext_group_name=team_slug,
        )
        current_team_ids.add(team_group_id)

    logger.info(
        f"Existing team IDs: {existing_team_ids}, Current team IDs: {current_team_ids}"
    )

    # Compare actual team IDs to detect changes
    teams_changed = current_team_ids != existing_team_ids
    if teams_changed:
        logger.info(
            f"Team changes detected for repo {repo.id} (name: {repo.name}): "
            f"existing={existing_team_ids}, current={current_team_ids}"
        )

    return teams_changed


================================================
FILE: backend/ee/onyx/external_permissions/github/group_sync.py
================================================
from collections.abc import Generator

from github import Repository

from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.github.utils import get_external_user_group
from onyx.connectors.github.connector import GithubConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.logger import setup_logger

logger = setup_logger()


def github_group_sync(
    tenant_id: str,  # noqa: ARG001
    cc_pair: ConnectorCredentialPair,
) -> Generator[ExternalUserGroup, None, None]:
    github_connector: GithubConnector = GithubConnector(
        **cc_pair.connector.connector_specific_config
    )
    credential_json = (
        cc_pair.credential.credential_json.get_value(apply_mask=False)
        if cc_pair.credential.credential_json
        else {}
    )
    github_connector.load_credentials(credential_json)
    if not github_connector.github_client:
        raise ValueError("github_client is required")

    logger.info("Starting GitHub group sync...")
    repos: list[Repository.Repository] = []
    if github_connector.repositories:
        if "," in github_connector.repositories:
            # Multiple repositories specified
            repos = github_connector.get_github_repos(github_connector.github_client)
        else:
            # Single repository (backward compatibility)
            repos = [github_connector.get_github_repo(github_connector.github_client)]
    else:
        # All repositories
        repos = github_connector.get_all_repos(github_connector.github_client)

    for repo in repos:
        try:
            for external_group in get_external_user_group(
                repo, github_connector.github_client
            ):
                logger.info(f"External group: {external_group}")
                yield external_group
        except Exception as e:
            logger.error(f"Error processing repository {repo.id} ({repo.name}): {e}")


================================================
FILE: backend/ee/onyx/external_permissions/github/utils.py
================================================
from collections.abc import Callable
from enum import Enum
from typing import List
from typing import Optional
from typing import Tuple
from typing import TypeVar

from github import Github
from github import RateLimitExceededException
from github.GithubException import GithubException
from github.NamedUser import NamedUser
from github.Organization import Organization
from github.PaginatedList import PaginatedList
from github.Repository import Repository
from github.Team import Team
from pydantic import BaseModel

from ee.onyx.db.external_perm import ExternalUserGroup
from onyx.access.models import ExternalAccess
from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception
from onyx.utils.logger import setup_logger

logger = setup_logger()


class GitHubVisibility(Enum):
    """GitHub repository visibility options."""

    PUBLIC = "public"
    PRIVATE = "private"
    INTERNAL = "internal"


MAX_RETRY_COUNT = 3

T = TypeVar("T")


# Higher-order function to wrap GitHub operations with retry and exception handling
def _run_with_retry(
    operation: Callable[[], T],
    description: str,
    github_client: Github,
    retry_count: int = 0,
) -> Optional[T]:
    """Execute a GitHub operation with retry on rate limit and exception handling."""
    logger.debug(f"Starting operation '{description}', attempt {retry_count + 1}")
    try:
        result = operation()
        logger.debug(f"Operation '{description}' completed successfully")
        return result
    except RateLimitExceededException:
        if retry_count < MAX_RETRY_COUNT:
            sleep_after_rate_limit_exception(github_client)
            logger.warning(
                f"Rate limit exceeded while {description}. Retrying... (attempt {retry_count + 1}/{MAX_RETRY_COUNT})"
            )
            return _run_with_retry(
                operation, description, github_client, retry_count + 1
            )
        else:
            error_msg = f"Max retries exceeded for {description}"
            logger.exception(error_msg)
            raise RuntimeError(error_msg)
    except GithubException as e:
        logger.warning(f"GitHub API error during {description}: {e}")
        return None
    except Exception as e:
        logger.exception(f"Unexpected error during {description}: {e}")
        return None


class UserInfo(BaseModel):
    """Represents a GitHub user with their basic information."""

    login: str
    name: Optional[str] = None
    email: Optional[str] = None


class TeamInfo(BaseModel):
    """Represents a GitHub team with its members."""

    name: str
    slug: str
    members: List[UserInfo]


def _fetch_organization_members(
    github_client: Github,
    org_name: str,
    retry_count: int = 0,  # noqa: ARG001
) -> List[UserInfo]:
    """Fetch all organization members including owners and regular members."""
    org_members: List[UserInfo] = []
    logger.info(f"Fetching organization members for {org_name}")

    org = _run_with_retry(
        lambda: github_client.get_organization(org_name),
        f"get organization {org_name}",
        github_client,
    )
    if not org:
        logger.error(f"Failed to fetch organization {org_name}")
        raise RuntimeError(f"Failed to fetch organization {org_name}")

    member_objs: PaginatedList[NamedUser] | list[NamedUser] = (
        _run_with_retry(
            lambda: org.get_members(filter_="all"),
            f"get members for organization {org_name}",
            github_client,
        )
        or []
    )

    for member in member_objs:
        user_info = UserInfo(login=member.login, name=member.name, email=member.email)
        org_members.append(user_info)

    logger.info(f"Fetched {len(org_members)} members for organization {org_name}")
    return org_members


def _fetch_repository_teams_detailed(
    repo: Repository,
    github_client: Github,
    retry_count: int = 0,  # noqa: ARG001
) -> List[TeamInfo]:
    """Fetch teams with access to the repository and their members."""
    teams_data: List[TeamInfo] = []
    logger.info(f"Fetching teams for repository {repo.full_name}")

    team_objs: PaginatedList[Team] | list[Team] = (
        _run_with_retry(
            lambda: repo.get_teams(),
            f"get teams for repository {repo.full_name}",
            github_client,
        )
        or []
    )

    for team in team_objs:
        logger.info(
            f"Processing team {team.name} (slug: {team.slug}) for repository {repo.full_name}"
        )

        members: PaginatedList[NamedUser] | list[NamedUser] = (
            _run_with_retry(
                lambda: team.get_members(),
                f"get members for team {team.name}",
                github_client,
            )
            or []
        )

        team_members = []
        for m in members:
            user_info = UserInfo(login=m.login, name=m.name, email=m.email)
            team_members.append(user_info)

        team_info = TeamInfo(name=team.name, slug=team.slug, members=team_members)
        teams_data.append(team_info)
        logger.info(f"Team {team.name} has {len(team_members)} members")

    logger.info(f"Fetched {len(teams_data)} teams for repository {repo.full_name}")
    return teams_data


def fetch_repository_team_slugs(
    repo: Repository,
    github_client: Github,
    retry_count: int = 0,  # noqa: ARG001
) -> List[str]:
    """Fetch team slugs with access to the repository."""
    logger.info(f"Fetching team slugs for repository {repo.full_name}")
    teams_data: List[str] = []

    team_objs: PaginatedList[Team] | list[Team] = (
        _run_with_retry(
            lambda: repo.get_teams(),
            f"get teams for repository {repo.full_name}",
            github_client,
        )
        or []
    )

    for team in team_objs:
        teams_data.append(team.slug)

    logger.info(f"Fetched {len(teams_data)} team slugs for repository {repo.full_name}")
    return teams_data


def _get_collaborators_and_outside_collaborators(
    github_client: Github,
    repo: Repository,
) -> Tuple[List[UserInfo], List[UserInfo]]:
    """Fetch and categorize collaborators into regular and outside collaborators."""
    collaborators: List[UserInfo] = []
    outside_collaborators: List[UserInfo] = []
    logger.info(f"Fetching collaborators for repository {repo.full_name}")

    repo_collaborators: PaginatedList[NamedUser] | list[NamedUser] = (
        _run_with_retry(
            lambda: repo.get_collaborators(),
            f"get collaborators for repository {repo.full_name}",
            github_client,
        )
        or []
    )

    for collaborator in repo_collaborators:
        is_outside = False

        # Check if collaborator is outside the organization
        if repo.organization:
            org: Organization | None = _run_with_retry(
                lambda: github_client.get_organization(repo.organization.login),
                f"get organization {repo.organization.login}",
                github_client,
            )

            if org is not None:
                org_obj = org
                membership = _run_with_retry(
                    lambda: org_obj.has_in_members(collaborator),
                    f"check membership for {collaborator.login} in org {org_obj.login}",
                    github_client,
                )
                is_outside = membership is not None and not membership

        info = UserInfo(
            login=collaborator.login, name=collaborator.name, email=collaborator.email
        )
        if repo.organization and is_outside:
            outside_collaborators.append(info)
        else:
            collaborators.append(info)

    logger.info(
        f"Categorized {len(collaborators)} regular and {len(outside_collaborators)} outside collaborators for {repo.full_name}"
    )
    return collaborators, outside_collaborators


def form_collaborators_group_id(repository_id: int) -> str:
    """Generate group ID for repository collaborators."""
    if not repository_id:
        logger.exception("Repository ID is required to generate collaborators group ID")
        raise ValueError("Repository ID must be set to generate group ID.")
    group_id = f"{repository_id}_collaborators"
    return group_id


def form_organization_group_id(organization_id: int) -> str:
    """Generate group ID for organization using organization ID."""
    if not organization_id:
        logger.exception(
            "Organization ID is required to generate organization group ID"
        )
        raise ValueError("Organization ID must be set to generate group ID.")
    group_id = f"{organization_id}_organization"
    return group_id


def form_outside_collaborators_group_id(repository_id: int) -> str:
    """Generate group ID for outside collaborators."""
    if not repository_id:
        logger.exception(
            "Repository ID is required to generate outside collaborators group ID"
        )
        raise ValueError("Repository ID must be set to generate group ID.")
    group_id = f"{repository_id}_outside_collaborators"
    return group_id


def get_repository_visibility(repo: Repository) -> GitHubVisibility:
    """
    Get the visibility of a repository.
    Returns GitHubVisibility enum member.
    """
    if hasattr(repo, "visibility"):
        visibility = repo.visibility
        logger.info(
            f"Repository {repo.full_name} visibility from attribute: {visibility}"
        )
        try:
            return GitHubVisibility(visibility)
        except ValueError:
            logger.warning(
                f"Unknown visibility '{visibility}' for repo {repo.full_name}, defaulting to private"
            )
            return GitHubVisibility.PRIVATE

    logger.info(f"Repository {repo.full_name} is private")
    return GitHubVisibility.PRIVATE


def get_external_access_permission(
    repo: Repository, github_client: Github, add_prefix: bool = False
) -> ExternalAccess:
    """
    Get the external access permission for a repository.
    Uses group-based permissions for efficiency and scalability.

    add_prefix: When this method is called during the initial permission sync via the
                connector, the group ID isn't prefixed with the source while inserting
                the document record, so set add_prefix to True and let this method
                handle the prefixing. When the same method is invoked from doc_sync,
                the system already adds the prefix to the group ID while processing
                the ExternalAccess object, so leave add_prefix as False.
    """
    # We maintain collaborators and outside collaborators as two separate groups
    # instead of adding individual user emails to ExternalAccess.external_user_emails for two reasons:
    # 1. Changes in repo collaborators (additions/removals) would require updating all documents.
    # 2. Repo permissions can change without updating the repo's updated_at timestamp,
    #    forcing full permission syncs for all documents every time, which is inefficient.

    repo_visibility = get_repository_visibility(repo)
    logger.info(
        f"Generating ExternalAccess for {repo.full_name}: visibility={repo_visibility.value}"
    )

    if repo_visibility == GitHubVisibility.PUBLIC:
        logger.info(
            f"Repository {repo.full_name} is public - allowing access to all users"
        )
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=True,
        )
    elif repo_visibility == GitHubVisibility.PRIVATE:
        logger.info(
            f"Repository {repo.full_name} is private - setting up restricted access"
        )

        collaborators_group_id = form_collaborators_group_id(repo.id)
        outside_collaborators_group_id = form_outside_collaborators_group_id(repo.id)
        if add_prefix:
            collaborators_group_id = build_ext_group_name_for_onyx(
                source=DocumentSource.GITHUB,
                ext_group_name=collaborators_group_id,
            )
            outside_collaborators_group_id = build_ext_group_name_for_onyx(
                source=DocumentSource.GITHUB,
                ext_group_name=outside_collaborators_group_id,
            )
        group_ids = {collaborators_group_id, outside_collaborators_group_id}

        team_slugs = fetch_repository_team_slugs(repo, github_client)
        if add_prefix:
            team_slugs = [
                build_ext_group_name_for_onyx(
                    source=DocumentSource.GITHUB,
                    ext_group_name=slug,
                )
                for slug in team_slugs
            ]
        group_ids.update(team_slugs)

        logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=group_ids,
            is_public=False,
        )
    else:
        # Internal repositories - accessible to organization members
        logger.info(
            f"Repository {repo.full_name} is internal - accessible to org members"
        )
        org_group_id = form_organization_group_id(repo.organization.id)
        if add_prefix:
            org_group_id = build_ext_group_name_for_onyx(
                source=DocumentSource.GITHUB,
                ext_group_name=org_group_id,
            )
        group_ids = {org_group_id}
        logger.info(f"ExternalAccess groups for {repo.full_name}: {group_ids}")
        return ExternalAccess(
            external_user_emails=set(),
            external_user_group_ids=group_ids,
            is_public=False,
        )


def get_external_user_group(
    repo: Repository, github_client: Github
) -> list[ExternalUserGroup]:
    """
    Get the external user group for a repository.
    Creates ExternalUserGroup objects with actual user emails for each permission group.
    """
    repo_visibility = get_repository_visibility(repo)
    logger.info(
        f"Generating ExternalUserGroups for {repo.full_name}: visibility={repo_visibility.value}"
    )

    if repo_visibility == GitHubVisibility.PRIVATE:
        logger.info(f"Processing private repository {repo.full_name}")

        collaborators, outside_collaborators = (
            _get_collaborators_and_outside_collaborators(github_client, repo)
        )
        teams = _fetch_repository_teams_detailed(repo, github_client)
        external_user_groups = []

        user_emails = set()
        for collab in collaborators:
            if collab.email:
                user_emails.add(collab.email)
            else:
                logger.error(f"Collaborator {collab.login} has no email")

        if user_emails:
            collaborators_group = ExternalUserGroup(
                id=form_collaborators_group_id(repo.id),
                user_emails=list(user_emails),
            )
            external_user_groups.append(collaborators_group)
            logger.info(f"Created collaborators group with {len(user_emails)} emails")

        # Create group for outside collaborators
        user_emails = set()
        for collab in outside_collaborators:
            if collab.email:
                user_emails.add(collab.email)
            else:
                logger.error(f"Outside collaborator {collab.login} has no email")

        if user_emails:
            outside_collaborators_group = ExternalUserGroup(
                id=form_outside_collaborators_group_id(repo.id),
                user_emails=list(user_emails),
            )
            external_user_groups.append(outside_collaborators_group)
            logger.info(
                f"Created outside collaborators group with {len(user_emails)} emails"
            )

        # Create groups for teams
        for team in teams:
            user_emails = set()
            for member in team.members:
                if member.email:
                    user_emails.add(member.email)
                else:
                    logger.error(f"Team member {member.login} has no email")

            if user_emails:
                team_group = ExternalUserGroup(
                    id=team.slug,
                    user_emails=list(user_emails),
                )
                external_user_groups.append(team_group)
                logger.info(
                    f"Created team group {team.name} with {len(user_emails)} emails"
                )

        logger.info(
            f"Created {len(external_user_groups)} ExternalUserGroups for private repository {repo.full_name}"
        )
        return external_user_groups

    if repo_visibility == GitHubVisibility.INTERNAL:
        logger.info(f"Processing internal repository {repo.full_name}")

        org_group_id = form_organization_group_id(repo.organization.id)
        org_members = _fetch_organization_members(
            github_client, repo.organization.login
        )

        user_emails = set()
        for member in org_members:
            if member.email:
                user_emails.add(member.email)
            else:
                logger.error(f"Org member {member.login} has no email")

        org_group = ExternalUserGroup(
            id=org_group_id,
            user_emails=list(user_emails),
        )
        logger.info(
            f"Created organization group with {len(user_emails)} emails for internal repository {repo.full_name}"
        )
        return [org_group]

    logger.info(f"Repository {repo.full_name} is public - no user groups needed")
    return []


================================================
FILE: backend/ee/onyx/external_permissions/gmail/doc_sync.py
================================================
from collections.abc import Generator
from datetime import datetime
from datetime import timezone

from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction
from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction
from onyx.access.models import DocExternalAccess
from onyx.access.models import ElementExternalAccess
from onyx.access.models import NodeExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import HierarchyNode
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _get_slim_doc_generator(
    cc_pair: ConnectorCredentialPair,
    gmail_connector: GmailConnector,
    callback:
IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: current_time = datetime.now(timezone.utc) start_time = ( cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() if cc_pair.last_time_perm_sync else 0.0 ) return gmail_connector.retrieve_all_slim_docs_perm_sync( start=start_time, end=current_time.timestamp(), callback=callback, ) def gmail_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, # noqa: ARG001 callback: IndexingHeartbeatInterface | None, ) -> Generator[ElementExternalAccess, None, None]: """ Adds the external permissions to the documents and hierarchy nodes in postgres. If the document doesn't already exist in postgres, we create it in postgres so that when it gets created later, the permissions are already populated. """ gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) gmail_connector.load_credentials(credential_json) slim_doc_generator = _get_slim_doc_generator( cc_pair, gmail_connector, callback=callback ) for slim_doc_batch in slim_doc_generator: for slim_doc in slim_doc_batch: if callback: if callback.should_stop(): raise RuntimeError("gmail_doc_sync: Stop signal detected") callback.progress("gmail_doc_sync", 1) if isinstance(slim_doc, HierarchyNode): # Yield hierarchy node permissions to be processed in outer layer if slim_doc.external_access: yield NodeExternalAccess( external_access=slim_doc.external_access, raw_node_id=slim_doc.raw_node_id, source=DocumentSource.GMAIL.value, ) continue if slim_doc.external_access is None: logger.warning(f"No permissions found for document {slim_doc.id}") continue yield DocExternalAccess( doc_id=slim_doc.id, external_access=slim_doc.external_access, ) 
================================================ FILE: backend/ee/onyx/external_permissions/google_drive/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/external_permissions/google_drive/doc_sync.py ================================================ from collections.abc import Generator from datetime import datetime from datetime import timezone from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission from ee.onyx.external_permissions.google_drive.models import PermissionType from ee.onyx.external_permissions.google_drive.permission_retrieval import ( get_permissions_by_ids, ) from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from onyx.access.models import DocExternalAccess from onyx.access.models import ElementExternalAccess from onyx.access.models import ExternalAccess from onyx.access.models import NodeExternalAccess from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.constants import DocumentSource from onyx.connectors.google_drive.connector import GoogleDriveConnector from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.models import HierarchyNode from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() def _get_slim_doc_generator( cc_pair: ConnectorCredentialPair, google_drive_connector: GoogleDriveConnector, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: current_time = datetime.now(timezone.utc) start_time = ( 
cc_pair.last_time_perm_sync.replace(tzinfo=timezone.utc).timestamp() if cc_pair.last_time_perm_sync else 0.0 ) return google_drive_connector.retrieve_all_slim_docs_perm_sync( start=start_time, end=current_time.timestamp(), callback=callback, ) def _merge_permissions_lists( permission_lists: list[list[GoogleDrivePermission]], ) -> list[GoogleDrivePermission]: """ Merge several permission lists into one, de-duplicating by permission ID (the first occurrence wins). """ seen_permission_ids: set[str] = set() merged_permissions: list[GoogleDrivePermission] = [] for permission_list in permission_lists: for permission in permission_list: if permission.id not in seen_permission_ids: merged_permissions.append(permission) seen_permission_ids.add(permission.id) return merged_permissions def get_external_access_for_raw_gdrive_file( file: GoogleDriveFileType, company_domain: str, retriever_drive_service: GoogleDriveService | None, admin_drive_service: GoogleDriveService, fallback_user_email: str, add_prefix: bool = False, ) -> ExternalAccess: """ Get the external access for a raw Google Drive file. Assumes the retrieved file has EITHER `permissions` or `permissionIds`. add_prefix: When this method is called during the initial indexing via the connector, set add_prefix to True so group IDs are prefixed with the source type. When invoked from doc_sync (permission sync), use the default (False) since upsert_document_external_perms handles prefixing. fallback_user_email: When we cannot retrieve any permission info for a file (e.g. externally-owned files where the API returns no permissions and permissions.list returns 403), fall back to granting access to this user. This is typically the impersonated org user whose drive contained the file.
""" doc_id = file.get("id") if not doc_id: raise ValueError("No doc_id found in file") permissions = file.get("permissions") permission_ids = file.get("permissionIds") drive_id = file.get("driveId") permissions_list: list[GoogleDrivePermission] = [] if permissions: permissions_list = [ GoogleDrivePermission.from_drive_permission(p) for p in permissions ] elif permission_ids: def _get_permissions( drive_service: GoogleDriveService, ) -> list[GoogleDrivePermission]: return get_permissions_by_ids( drive_service=drive_service, doc_id=doc_id, permission_ids=permission_ids, ) permissions_list = _get_permissions( retriever_drive_service or admin_drive_service ) if len(permissions_list) != len(permission_ids) and retriever_drive_service: logger.warning( f"Failed to get all permissions for file {doc_id} with retriever service, trying admin service" ) backup_permissions_list = _get_permissions(admin_drive_service) permissions_list = _merge_permissions_lists( [permissions_list, backup_permissions_list] ) # For externally-owned files, the Drive API may return no permissions # and permissions.list may return 403. In this case, fall back to # granting access to the user who found the file in their drive. # Note, even if other users also have access to this file, # they will not be granted access in Onyx. # We check permissions_list (the final result after all fetch attempts) # rather than the raw fields, because permission_ids may be present # but the actual fetch can still return empty due to a 403. if not permissions_list: logger.info( f"No permission info available for file {doc_id} " f"(likely owned by a user outside of your organization). 
" f"Falling back to granting access to retriever user: {fallback_user_email}" ) return ExternalAccess( external_user_emails={fallback_user_email}, external_user_group_ids=set(), is_public=False, ) folder_ids_to_inherit_permissions_from: set[str] = set() user_emails: set[str] = set() group_emails: set[str] = set() public = False for permission in permissions_list: # if the permission is inherited, also add the source folder ID as a group # that has access to the file (the permission itself is still processed below); # the group sync job then maps that folder to the list of Onyx users # NOTE: this doesn't handle the case where a folder initially has no # permissions but is later shared with a user or group. # We could fetch all ancestors of the file to get the list of folders that # might affect the permissions of the file, but this will be replaced with # an audit-log-based approach in the future, so we are not doing it now. if permission.inherited_from: folder_ids_to_inherit_permissions_from.add(permission.inherited_from) if permission.type == PermissionType.USER: if permission.email_address: user_emails.add(permission.email_address) else: logger.error( f"Permission is type `user` but no email address is provided for document {doc_id}\n {permission}" ) elif permission.type == PermissionType.GROUP: # groups are represented as email addresses within Drive if permission.email_address: group_emails.add(permission.email_address) else: logger.error( f"Permission is type `group` but no email address is provided for document {doc_id}\n {permission}" ) elif permission.type == PermissionType.DOMAIN and company_domain: if permission.domain == company_domain: public = True else: logger.warning( f"Permission is type domain but does not match company domain:\n {permission}" ) elif permission.type == PermissionType.ANYONE: public = True group_ids = ( group_emails | folder_ids_to_inherit_permissions_from | ({drive_id} if drive_id is not None else set()) ) #
Prefix group IDs with source type if requested (for indexing path) if add_prefix: group_ids = { build_ext_group_name_for_onyx(group_id, DocumentSource.GOOGLE_DRIVE) for group_id in group_ids } return ExternalAccess( external_user_emails=user_emails, external_user_group_ids=group_ids, is_public=public, ) def get_external_access_for_folder( folder: GoogleDriveFileType, google_domain: str, drive_service: GoogleDriveService, add_prefix: bool = False, ) -> ExternalAccess: """ Extract ExternalAccess from a folder's permissions. This fetches permissions using the Drive API (via permissionIds) and extracts user emails, group emails, and public access status. Args: folder: The folder metadata from Google Drive API (must include permissionIds field) google_domain: The company's Google Workspace domain (e.g., "company.com") drive_service: Google Drive service for fetching permission details add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path). 
Returns: ExternalAccess with extracted permission info """ folder_id = folder.get("id") if not folder_id: logger.warning("Folder missing ID, returning empty permissions") return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=False, ) # Get permission IDs from folder metadata permission_ids = folder.get("permissionIds") or [] if not permission_ids: logger.debug(f"No permissionIds found for folder {folder_id}") return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=False, ) # Fetch full permission objects using the permission IDs permissions_list = get_permissions_by_ids( drive_service=drive_service, doc_id=folder_id, permission_ids=permission_ids, ) user_emails: set[str] = set() group_emails: set[str] = set() is_public = False for permission in permissions_list: if permission.type == PermissionType.USER: if permission.email_address: user_emails.add(permission.email_address) else: logger.warning(f"User permission without email for folder {folder_id}") elif permission.type == PermissionType.GROUP: # Groups are represented as email addresses in Google Drive if permission.email_address: group_emails.add(permission.email_address) else: logger.warning(f"Group permission without email for folder {folder_id}") elif permission.type == PermissionType.DOMAIN: # Domain permission - check if it matches company domain if permission.domain == google_domain: # Only public if discoverable (allowFileDiscovery is not False) # If allowFileDiscovery is False, it's "link only" access is_public = permission.allow_file_discovery is not False else: logger.debug( f"Domain permission for {permission.domain} does not match " f"company domain {google_domain} for folder {folder_id}" ) elif permission.type == PermissionType.ANYONE: # Only public if discoverable (allowFileDiscovery is not False) # If allowFileDiscovery is False, it's "link only" access is_public = permission.allow_file_discovery is not False # Prefix group 
IDs with source type if requested (for indexing path) group_ids: set[str] = group_emails if add_prefix: group_ids = { build_ext_group_name_for_onyx(group_id, DocumentSource.GOOGLE_DRIVE) for group_id in group_emails } return ExternalAccess( external_user_emails=user_emails, external_user_group_ids=group_ids, is_public=is_public, ) def gdrive_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, # noqa: ARG001 callback: IndexingHeartbeatInterface | None, ) -> Generator[ElementExternalAccess, None, None]: """ Adds the external permissions to the documents and hierarchy nodes in postgres. If the document doesn't already exist in postgres, we create it in postgres so that when it gets created later, the permissions are already populated. """ google_drive_connector = GoogleDriveConnector( **cc_pair.connector.connector_specific_config ) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) google_drive_connector.load_credentials(credential_json) slim_doc_generator = _get_slim_doc_generator( cc_pair, google_drive_connector, callback=callback ) total_processed = 0 for slim_doc_batch in slim_doc_generator: logger.info(f"Drive perm sync: Processing {len(slim_doc_batch)} documents") for slim_doc in slim_doc_batch: if callback: if callback.should_stop(): raise RuntimeError("gdrive_doc_sync: Stop signal detected") callback.progress("gdrive_doc_sync", 1) if isinstance(slim_doc, HierarchyNode): # Yield hierarchy node permissions to be processed in outer layer if slim_doc.external_access: yield NodeExternalAccess( external_access=slim_doc.external_access, raw_node_id=slim_doc.raw_node_id, source=DocumentSource.GOOGLE_DRIVE.value, ) continue if slim_doc.external_access is None: raise ValueError( f"Drive perm sync: No external access for document {slim_doc.id}" ) yield DocExternalAccess(
external_access=slim_doc.external_access, doc_id=slim_doc.id, ) total_processed += len(slim_doc_batch) logger.info(f"Drive perm sync: Processed {total_processed} total documents") ================================================ FILE: backend/ee/onyx/external_permissions/google_drive/folder_retrieval.py ================================================ from collections.abc import Iterator from googleapiclient.discovery import Resource # type: ignore from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission from ee.onyx.external_permissions.google_drive.permission_retrieval import ( get_permissions_by_ids, ) from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE from onyx.connectors.google_drive.file_retrieval import generate_time_range_filter from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.utils.logger import setup_logger logger = setup_logger() # Only include fields we need - folder ID and permissions # IMPORTANT: must fetch permissionIds, since sometimes the drive API # seems to miss permissions when requesting them directly FOLDER_PERMISSION_FIELDS = "nextPageToken, files(id, name, permissionIds, permissions(id, emailAddress, type, domain, permissionDetails))" def get_folder_permissions_by_ids( service: Resource, folder_id: str, permission_ids: list[str], ) -> list[GoogleDrivePermission]: """ Retrieves permissions for a specific folder filtered by permission IDs. 
Args: service: The Google Drive service instance folder_id: The ID of the folder to fetch permissions for permission_ids: A list of permission IDs to filter by Returns: A list of permissions matching the provided permission IDs """ return get_permissions_by_ids( drive_service=service, doc_id=folder_id, permission_ids=permission_ids, ) def get_modified_folders( service: Resource, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[GoogleDriveFileType]: """ Retrieves all folders that were modified within the specified time range. Only includes folder ID and permission information, not any contained files. Args: service: The Google Drive service instance start: The start time as seconds since Unix epoch (inclusive) end: The end time as seconds since Unix epoch (inclusive) Returns: An iterator yielding folder information including ID and permissions """ # Build query for folders query = f"mimeType = '{DRIVE_FOLDER_TYPE}'" query += " and trashed = false" query += generate_time_range_filter(start, end) # Retrieve and yield folders for folder in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", continue_on_404_or_403=True, corpora="allDrives", supportsAllDrives=True, includeItemsFromAllDrives=True, includePermissionsForView="published", fields=FOLDER_PERMISSION_FIELDS, q=query, ): yield folder ================================================ FILE: backend/ee/onyx/external_permissions/google_drive/group_sync.py ================================================ from collections.abc import Generator from googleapiclient.errors import HttpError # type: ignore from pydantic import BaseModel from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.google_drive.folder_retrieval import ( get_folder_permissions_by_ids, ) from ee.onyx.external_permissions.google_drive.folder_retrieval import ( get_modified_folders, ) from 
ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission from ee.onyx.external_permissions.google_drive.models import PermissionType from onyx.connectors.google_drive.connector import GoogleDriveConnector from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval from onyx.connectors.google_utils.resources import AdminService from onyx.connectors.google_utils.resources import get_admin_service from onyx.connectors.google_utils.resources import get_drive_service from onyx.db.models import ConnectorCredentialPair from onyx.utils.logger import setup_logger logger = setup_logger() """ Folder Permission Sync. Each folder is treated as a group. Each file has all ancestor folders as groups. """ class FolderInfo(BaseModel): id: str permissions: list[GoogleDrivePermission] def _get_all_folders( google_drive_connector: GoogleDriveConnector, skip_folders_without_permissions: bool ) -> list[FolderInfo]: """Have to get all folders since the group syncing system assumes all groups are returned every time. TODO: tweak things so we can fetch deltas. """ MAX_FAILED_PERCENTAGE = 0.5 all_folders: list[FolderInfo] = [] seen_folder_ids: set[str] = set() def _get_all_folders_for_user( google_drive_connector: GoogleDriveConnector, skip_folders_without_permissions: bool, user_email: str, ) -> None: """Helper to get folders for a specific user + update shared seen_folder_ids""" drive_service = get_drive_service( google_drive_connector.creds, user_email, ) for folder in get_modified_folders( service=drive_service, ): folder_id = folder["id"] if folder_id in seen_folder_ids: logger.debug(f"Folder {folder_id} has already been seen. 
Skipping.") continue seen_folder_ids.add(folder_id) # Check if the folder has permission IDs but no permissions permission_ids = folder.get("permissionIds", []) raw_permissions = folder.get("permissions", []) if not raw_permissions and permission_ids: # Fetch permissions using the IDs permissions = get_folder_permissions_by_ids( drive_service, folder_id, permission_ids ) else: permissions = [ GoogleDrivePermission.from_drive_permission(permission) for permission in raw_permissions ] # Don't include inherited permissions, those will be captured # by the folder/shared drive itself permissions = [ permission for permission in permissions if permission.inherited_from is None ] if not permissions and skip_folders_without_permissions: logger.debug(f"Folder {folder_id} has no permissions. Skipping.") continue all_folders.append( FolderInfo( id=folder_id, permissions=permissions, ) ) failed_count = 0 user_emails = google_drive_connector._get_all_user_emails() for user_email in user_emails: try: _get_all_folders_for_user( google_drive_connector, skip_folders_without_permissions, user_email ) except Exception: logger.exception(f"Error getting folders for user {user_email}") failed_count += 1 if failed_count > MAX_FAILED_PERCENTAGE * len(user_emails): raise RuntimeError("Too many failed folder fetches during group sync") return all_folders def _drive_folder_to_onyx_group( folder: FolderInfo, group_email_to_member_emails_map: dict[str, list[str]], ) -> ExternalUserGroup: """ Converts a folder into an Onyx group. 
""" anyone_can_access = False folder_member_emails: set[str] = set() for permission in folder.permissions: if permission.type == PermissionType.USER: if permission.email_address is None: logger.warning( f"User email is None for folder {folder.id} permission {permission}" ) continue folder_member_emails.add(permission.email_address) elif permission.type == PermissionType.GROUP: if permission.email_address not in group_email_to_member_emails_map: logger.warning( f"Group email {permission.email_address} for folder {folder.id} not found in group_email_to_member_emails_map" ) continue folder_member_emails.update( group_email_to_member_emails_map[permission.email_address] ) elif permission.type == PermissionType.ANYONE: anyone_can_access = True return ExternalUserGroup( id=folder.id, user_emails=list(folder_member_emails), gives_anyone_access=anyone_can_access, ) """Individual Shared Drive / My Drive Permission Sync""" def _get_drive_members( google_drive_connector: GoogleDriveConnector, admin_service: AdminService, ) -> dict[str, tuple[set[str], set[str]]]: """ This builds a map of drive ids to their members (group and user emails). E.g. 
{ "drive_id_1": ({"group_email_1"}, {"user_email_1", "user_email_2"}), "drive_id_2": ({"group_email_3"}, {"user_email_3"}), } """ # fetches shared drives only drive_ids = google_drive_connector.get_all_drive_ids() drive_id_to_members_map: dict[str, tuple[set[str], set[str]]] = {} drive_service = get_drive_service( google_drive_connector.creds, google_drive_connector.primary_admin_email, ) admin_user_info = ( admin_service.users() .get(userKey=google_drive_connector.primary_admin_email) .execute() ) is_admin = admin_user_info.get("isAdmin", False) or admin_user_info.get( "isDelegatedAdmin", False ) for drive_id in drive_ids: group_emails: set[str] = set() user_emails: set[str] = set() try: for permission in execute_paginated_retrieval( drive_service.permissions().list, list_key="permissions", fileId=drive_id, fields="permissions(emailAddress, type),nextPageToken", supportsAllDrives=True, # can only set `useDomainAdminAccess` to true if the user # is an admin useDomainAdminAccess=is_admin, ): # NOTE: don't need to check for PermissionType.ANYONE since # you can't share a drive with the internet if permission["type"] == PermissionType.GROUP: group_emails.add(permission["emailAddress"]) elif permission["type"] == PermissionType.USER: user_emails.add(permission["emailAddress"]) except HttpError as e: if e.status_code == 404: logger.warning( f"Error getting permissions for drive id {drive_id}. " f"User '{google_drive_connector.primary_admin_email}' likely " f"does not have access to this drive. 
Exception: {e}" ) else: raise e drive_id_to_members_map[drive_id] = (group_emails, user_emails) return drive_id_to_members_map def _drive_member_map_to_onyx_groups( drive_id_to_members_map: dict[str, tuple[set[str], set[str]]], group_email_to_member_emails_map: dict[str, list[str]], ) -> Generator[ExternalUserGroup, None, None]: """The `user_emails` for a Shared Drive are all individual members of the Shared Drive plus the flattened members of every group with access to it.""" for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items(): drive_member_emails: set[str] = set(user_emails) for group_email in group_emails: if group_email not in group_email_to_member_emails_map: logger.warning( f"Group email {group_email} for drive {drive_id} not found in group_email_to_member_emails_map" ) continue drive_member_emails.update(group_email_to_member_emails_map[group_email]) yield ExternalUserGroup( id=drive_id, user_emails=list(drive_member_emails), ) def _get_all_google_groups( admin_service: AdminService, google_domain: str, ) -> set[str]: """ Returns the email addresses of all Google groups in the given domain. """ group_emails: set[str] = set() for group in execute_paginated_retrieval( admin_service.groups().list, list_key="groups", domain=google_domain, fields="groups(email),nextPageToken", ): group_emails.add(group["email"]) return group_emails def _google_group_to_onyx_group( admin_service: AdminService, group_email: str, ) -> ExternalUserGroup: """ Builds an ExternalUserGroup for a single Google group, keyed by the group email and containing its members' emails. """ group_member_emails: set[str] = set() for member in execute_paginated_retrieval( admin_service.members().list, list_key="members", groupKey=group_email, fields="members(email),nextPageToken", ): group_member_emails.add(member["email"]) return ExternalUserGroup( id=group_email, user_emails=list(group_member_emails), ) def _map_group_email_to_member_emails( admin_service: AdminService, group_emails: set[str], ) -> dict[str, set[str]]: """ This maps group emails to their member emails.
""" group_to_member_map: dict[str, set[str]] = {} for group_email in group_emails: group_member_emails: set[str] = set() for member in execute_paginated_retrieval( admin_service.members().list, list_key="members", groupKey=group_email, fields="members(email),nextPageToken", ): group_member_emails.add(member["email"]) group_to_member_map[group_email] = group_member_emails return group_to_member_map def _build_onyx_groups( drive_id_to_members_map: dict[str, tuple[set[str], set[str]]], group_email_to_member_emails_map: dict[str, set[str]], folder_info: list[FolderInfo], ) -> list[ExternalUserGroup]: onyx_groups: list[ExternalUserGroup] = [] # Convert all drive member definitions to onyx groups # Drive-level membership grants access to every file in the drive, # and file-level sharing cannot revoke it. for drive_id, (group_emails, user_emails) in drive_id_to_members_map.items(): drive_member_emails: set[str] = set(user_emails) for group_email in group_emails: if group_email not in group_email_to_member_emails_map: logger.warning( f"Group email {group_email} for drive {drive_id} not found in group_email_to_member_emails_map" ) continue drive_member_emails.update(group_email_to_member_emails_map[group_email]) onyx_groups.append( ExternalUserGroup( id=drive_id, user_emails=list(drive_member_emails), ) ) # Convert all folder permissions to onyx groups for folder in folder_info: anyone_can_access = False folder_member_emails: set[str] = set() for permission in folder.permissions: if permission.type == PermissionType.USER: if permission.email_address is None: logger.warning( f"User email is None for folder {folder.id} permission {permission}" ) continue folder_member_emails.add(permission.email_address) elif permission.type == PermissionType.GROUP: if permission.email_address not in group_email_to_member_emails_map: logger.warning( f"Group email {permission.email_address} for folder {folder.id} " "not found in group_email_to_member_emails_map" ) continue
folder_member_emails.update( group_email_to_member_emails_map[permission.email_address] ) elif permission.type == PermissionType.ANYONE: anyone_can_access = True onyx_groups.append( ExternalUserGroup( id=folder.id, user_emails=list(folder_member_emails), gives_anyone_access=anyone_can_access, ) ) # Convert all group member definitions to onyx groups for group_email, member_emails in group_email_to_member_emails_map.items(): onyx_groups.append( ExternalUserGroup( id=group_email, user_emails=list(member_emails), ) ) return onyx_groups def gdrive_group_sync( tenant_id: str, # noqa: ARG001 cc_pair: ConnectorCredentialPair, ) -> Generator[ExternalUserGroup, None, None]: # Initialize connector and build credential/service objects google_drive_connector = GoogleDriveConnector( **cc_pair.connector.connector_specific_config ) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) google_drive_connector.load_credentials(credential_json) admin_service = get_admin_service( google_drive_connector.creds, google_drive_connector.primary_admin_email ) # Get all drive members drive_id_to_members_map = _get_drive_members(google_drive_connector, admin_service) # Get all group emails all_group_emails = _get_all_google_groups( admin_service, google_drive_connector.google_domain ) # Each google group is an Onyx group, yield those group_email_to_member_emails_map: dict[str, list[str]] = {} for group_email in all_group_emails: onyx_group = _google_group_to_onyx_group(admin_service, group_email) group_email_to_member_emails_map[group_email] = onyx_group.user_emails yield onyx_group # Each drive is a group, yield those for onyx_group in _drive_member_map_to_onyx_groups( drive_id_to_members_map, group_email_to_member_emails_map ): yield onyx_group # Get all folder permissions folder_info = _get_all_folders( google_drive_connector=google_drive_connector, skip_folders_without_permissions=True, ) for folder in 
folder_info: yield _drive_folder_to_onyx_group(folder, group_email_to_member_emails_map) ================================================ FILE: backend/ee/onyx/external_permissions/google_drive/models.py ================================================ from enum import Enum from typing import Any from pydantic import BaseModel class PermissionType(str, Enum): USER = "user" GROUP = "group" DOMAIN = "domain" ANYONE = "anyone" class GoogleDrivePermissionDetails(BaseModel): # this is "file", "member", etc. # different from the `type` field within `GoogleDrivePermission` # Sometimes this is None, although it is not clear why permission_type: str | None # this is "reader", "writer", "owner", etc. role: str # this is the id of the parent permission inherited_from: str | None class GoogleDrivePermission(BaseModel): id: str # groups are also represented as email addresses within Drive # will be None for domain/global permissions email_address: str | None type: PermissionType domain: str | None # only applies to domain permissions permission_details: GoogleDrivePermissionDetails | None # Whether this permission makes the file discoverable in search # False means "anyone with the link" (not searchable/discoverable) # Only applicable for domain/anyone permission types allow_file_discovery: bool | None @classmethod def from_drive_permission( cls, drive_permission: dict[str, Any] ) -> "GoogleDrivePermission": # we seem to only get details for permissions that are inherited # we can get multiple details if a permission is inherited from multiple # parents; only the first entry is used permission_details_list = drive_permission.get("permissionDetails", []) permission_details: dict[str, Any] | None = ( permission_details_list[0] if permission_details_list else None ) return cls( id=drive_permission["id"], email_address=drive_permission.get("emailAddress"), type=PermissionType(drive_permission["type"]), domain=drive_permission.get("domain"), allow_file_discovery=drive_permission.get("allowFileDiscovery"), permission_details=(
GoogleDrivePermissionDetails( permission_type=permission_details.get("type"), role=permission_details.get("role", ""), inherited_from=permission_details.get("inheritedFrom"), ) if permission_details else None ), ) @property def inherited_from(self) -> str | None: if self.permission_details: return self.permission_details.inherited_from return None ================================================ FILE: backend/ee/onyx/external_permissions/google_drive/permission_retrieval.py ================================================ from retry import retry from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.utils.logger import setup_logger logger = setup_logger() @retry(tries=3, delay=2, backoff=2) def get_permissions_by_ids( drive_service: GoogleDriveService, doc_id: str, permission_ids: list[str], ) -> list[GoogleDrivePermission]: """ Fetches permissions for a document based on a list of permission IDs. 
Args: drive_service: The Google Drive service instance doc_id: The ID of the document to fetch permissions for permission_ids: A list of permission IDs to filter by Returns: A list of GoogleDrivePermission objects matching the provided permission IDs """ if not permission_ids: return [] # Create a set for faster lookup permission_id_set = set(permission_ids) # Fetch all permissions for the document fetched_permissions = execute_paginated_retrieval( retrieval_function=drive_service.permissions().list, list_key="permissions", fileId=doc_id, fields="permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails),nextPageToken", supportsAllDrives=True, continue_on_404_or_403=True, ) # Filter permissions by ID and convert to GoogleDrivePermission objects filtered_permissions = [] for permission in fetched_permissions: permission_id = permission.get("id") if permission_id in permission_id_set: google_drive_permission = GoogleDrivePermission.from_drive_permission( permission ) filtered_permissions.append(google_drive_permission) # Log if we couldn't find all requested permission IDs if len(filtered_permissions) < len(permission_ids): missing_ids = permission_id_set - {p.id for p in filtered_permissions if p.id} logger.warning( f"Could not find all requested permission IDs for document {doc_id}. 
Missing IDs: {missing_ids}" ) return filtered_permissions ================================================ FILE: backend/ee/onyx/external_permissions/jira/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/external_permissions/jira/doc_sync.py ================================================ from collections.abc import Generator from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from ee.onyx.external_permissions.utils import generic_doc_sync from onyx.access.models import ElementExternalAccess from onyx.configs.constants import DocumentSource from onyx.connectors.jira.connector import JiraConnector from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() JIRA_DOC_SYNC_TAG = "jira_doc_sync" def jira_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, callback: IndexingHeartbeatInterface | None = None, ) -> Generator[ElementExternalAccess, None, None]: jira_connector = JiraConnector( **cc_pair.connector.connector_specific_config, ) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) jira_connector.load_credentials(credential_json) yield from generic_doc_sync( cc_pair=cc_pair, fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn, callback=callback, doc_source=DocumentSource.JIRA, slim_connector=jira_connector, label=JIRA_DOC_SYNC_TAG, ) ================================================ FILE: backend/ee/onyx/external_permissions/jira/group_sync.py ================================================ from collections.abc import Generator 
from typing import Any from jira import JIRA from jira.exceptions import JIRAError from ee.onyx.db.external_perm import ExternalUserGroup from onyx.connectors.jira.utils import build_jira_client from onyx.db.models import ConnectorCredentialPair from onyx.utils.logger import setup_logger logger = setup_logger() _ATLASSIAN_ACCOUNT_TYPE = "atlassian" _GROUP_MEMBER_PAGE_SIZE = 50 # The GET /group/member endpoint was introduced in Jira 6.0. # Jira versions older than 6.0 do not have group management REST APIs at all. _MIN_JIRA_VERSION_FOR_GROUP_MEMBER = "6.0" def _fetch_group_member_page( jira_client: JIRA, group_name: str, start_at: int, ) -> dict[str, Any]: """Fetch a single page from the non-deprecated GET /group/member endpoint. The old GET /group endpoint (used by jira_client.group_members()) is deprecated and decommissioned in Jira Server 10.3+. This uses the replacement endpoint directly via the library's internal _get_json helper, following the same pattern as enhanced_search_ids / bulk_fetch_issues in connector.py. A PR that switches the library to this endpoint has been open since last year (https://github.com/pycontribs/jira/pull/2356); once it is merged and released, we can use the library function instead. """ try: return jira_client._get_json( "group/member", params={ "groupname": group_name, "includeInactiveUsers": "false", "startAt": start_at, "maxResults": _GROUP_MEMBER_PAGE_SIZE, }, ) except JIRAError as e: if e.status_code == 404: raise RuntimeError( f"GET /group/member returned 404 for group '{group_name}'. " f"This endpoint requires Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}+. " f"If you are running a self-hosted Jira instance, please upgrade " f"to at least Jira {_MIN_JIRA_VERSION_FOR_GROUP_MEMBER}." ) from e raise def _get_group_member_emails( jira_client: JIRA, group_name: str, ) -> set[str]: """Get all member emails for a single Jira group.
Uses the non-deprecated GET /group/member endpoint which returns full user objects including accountType, so we can filter out app/customer accounts without making separate user() calls. """ emails: set[str] = set() start_at = 0 while True: try: page = _fetch_group_member_page(jira_client, group_name, start_at) except Exception as e: logger.error(f"Error fetching members for group {group_name}: {e}") raise members: list[dict[str, Any]] = page.get("values", []) for member in members: account_type = member.get("accountType") # On Jira DC < 9.0, accountType is absent; include those users. # On Cloud / DC 9.0+, filter to real user accounts only. if account_type is not None and account_type != _ATLASSIAN_ACCOUNT_TYPE: continue email = member.get("emailAddress") if email: emails.add(email) else: logger.warning( f"Atlassian user {member.get('accountId', 'unknown')} in group {group_name} has no visible email address" ) if page.get("isLast", True) or not members: break start_at += len(members) return emails def jira_group_sync( tenant_id: str, # noqa: ARG001 cc_pair: ConnectorCredentialPair, ) -> Generator[ExternalUserGroup, None, None]: """Sync Jira groups and their members, yielding one group at a time. Streams group-by-group rather than accumulating all groups in memory. 
""" jira_base_url = cc_pair.connector.connector_specific_config.get("jira_base_url", "") scoped_token = cc_pair.connector.connector_specific_config.get( "scoped_token", False ) if not jira_base_url: raise ValueError("No jira_base_url found in connector config") credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) jira_client = build_jira_client( credentials=credential_json, jira_base=jira_base_url, scoped_token=scoped_token, ) group_names = jira_client.groups() if not group_names: raise ValueError(f"No groups found for cc_pair_id={cc_pair.id}") logger.info(f"Found {len(group_names)} groups in Jira") for group_name in group_names: if not group_name: continue member_emails = _get_group_member_emails( jira_client=jira_client, group_name=group_name, ) if not member_emails: logger.debug(f"No members found for group {group_name}") continue logger.debug(f"Found {len(member_emails)} members for group {group_name}") yield ExternalUserGroup( id=group_name, user_emails=list(member_emails), ) ================================================ FILE: backend/ee/onyx/external_permissions/jira/models.py ================================================ from typing import Any from pydantic import BaseModel from pydantic import ConfigDict from pydantic.alias_generators import to_camel Holder = dict[str, Any] class Permission(BaseModel): id: int permission: str holder: Holder | None class User(BaseModel): account_id: str email_address: str display_name: str active: bool model_config = ConfigDict( alias_generator=to_camel, ) ================================================ FILE: backend/ee/onyx/external_permissions/jira/page_access.py ================================================ from collections import defaultdict from jira import JIRA from jira.resources import PermissionScheme from pydantic import ValidationError from ee.onyx.external_permissions.jira.models import Holder from 
ee.onyx.external_permissions.jira.models import Permission from ee.onyx.external_permissions.jira.models import User from onyx.access.models import ExternalAccess from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.constants import DocumentSource from onyx.utils.logger import setup_logger HolderMap = dict[str, list[Holder]] logger = setup_logger() def _get_role_id(holder: Holder) -> str | None: return holder.get("value") or holder.get("parameter") def _build_holder_map(permissions: list[dict]) -> dict[str, list[Holder]]: """ A "Holder" in JIRA is a person / entity who "holds" the corresponding permission. It can have different types. They can be one of (but not limited to): - user (an explicitly whitelisted user) - projectRole (for project level "roles") - reporter (the reporter of an issue) A "Holder" usually has the following structure: - `{ "type": "user", "value": "$USER_ID", "user": { .. }, .. }` - `{ "type": "projectRole", "value": "$ROLE_ID", .. }` When we fetch the PermissionScheme from JIRA, we retrieve a list of "Holder"s. The list of "Holder"s can have multiple "Holder"s of the same type in the list (e.g., you can have two `"type": "user"`s in there, each corresponding to a different user). This function constructs a map from each "Holder" type to the list of "Holder"s that have that type. Returns: A dict from the "Holder" type to the list of "Holder" instances of that type. Example: ``` { "user": [ { "type": "user", "value": "10000", "user": { .. }, .. }, { "type": "user", "value": "10001", "user": { .. }, .. }, ], "projectRole": [ { "type": "projectRole", "value": "10010", .. }, { "type": "projectRole", "value": "10011", .. }, ], "applicationRole": [ { "type": "applicationRole" }, ], ..
} ``` """ holder_map: defaultdict[str, list[Holder]] = defaultdict(list) for raw_perm in permissions: if not hasattr(raw_perm, "raw"): logger.warning(f"Expected a 'raw' field, but none was found: {raw_perm=}") continue permission = Permission(**raw_perm.raw) # We only care about the ability to browse projects + issues (not other permissions such as read/write). if permission.permission != "BROWSE_PROJECTS": continue # In order to associate this permission with some Atlassian entity, we need the "Holder". # If it doesn't exist, we cannot associate this permission with anyone; just skip. if not permission.holder: logger.warning( f"Expected to find a permission holder, but none was found: {permission=}" ) continue holder_type = permission.holder.get("type") if not holder_type: logger.warning( f"Expected to find the type of permission holder, but none was found: {permission=}" ) continue holder_map[holder_type].append(permission.holder) return holder_map def _get_user_emails(user_holders: list[Holder]) -> list[str]: emails = [] for user_holder in user_holders: if "user" not in user_holder: continue raw_user_dict = user_holder["user"] try: user_model = User.model_validate(raw_user_dict) except ValidationError: logger.error( "Expected to be able to validate the raw user dict into an instance of `User`, but validation failed; " f"{raw_user_dict=}" ) continue emails.append(user_model.email_address) return emails def _get_user_emails_and_groups_from_project_roles( jira_client: JIRA, jira_project: str, project_role_holders: list[Holder], ) -> tuple[list[str], list[str]]: """ Get user emails and group names from project roles. Returns a tuple of (emails, group_names).
""" # Get role IDs - Cloud uses "value", Data Center uses "parameter" role_ids = [] for holder in project_role_holders: role_id = _get_role_id(holder) if role_id: role_ids.append(role_id) else: logger.warning(f"No value or parameter in projectRole holder: {holder}") roles = [ jira_client.project_role(project=jira_project, id=role_id) for role_id in role_ids ] emails = [] groups = [] for role in roles: if not hasattr(role, "actors"): logger.warning(f"Project role {role} has no actors attribute") continue for actor in role.actors: # Handle group actors if hasattr(actor, "actorGroup"): group_name = getattr(actor.actorGroup, "name", None) or getattr( actor.actorGroup, "displayName", None ) if group_name: groups.append(group_name) continue # Handle user actors if hasattr(actor, "actorUser"): account_id = getattr(actor.actorUser, "accountId", None) if not account_id: logger.error(f"No accountId in actorUser: {actor.actorUser}") continue user = jira_client.user(id=account_id) if not hasattr(user, "accountType") or user.accountType != "atlassian": logger.info( f"Skipping user {account_id} because it is not an atlassian user" ) continue if not hasattr(user, "emailAddress"): msg = f"User's email address could not be retrieved; {account_id=}" # Read displayName from the same object the hasattr check guards if hasattr(user, "displayName"): msg += f" {user.displayName=}" logger.warning(msg) continue emails.append(user.emailAddress) continue logger.debug(f"Skipping actor type: {actor}") return emails, groups def _build_external_access_from_holder_map( jira_client: JIRA, jira_project: str, holder_map: HolderMap ) -> ExternalAccess: """ Build ExternalAccess from the holder map.
Holder types handled: - "anyone": Public project, anyone can access - "applicationRole": All users with a Jira license can access (treated as public) - "user": Specific users with access - "projectRole": Project roles containing users and/or groups - "group": Groups directly assigned in the permission scheme """ # Public access - anyone can view if "anyone" in holder_map: return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=True ) # applicationRole means all users with a Jira license can access - treat as public if "applicationRole" in holder_map: return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=True ) # Get emails from explicit user holders user_emails = ( _get_user_emails(user_holders=holder_map["user"]) if "user" in holder_map else [] ) # Get emails and groups from project roles project_role_user_emails: list[str] = [] project_role_groups: list[str] = [] if "projectRole" in holder_map: project_role_user_emails, project_role_groups = ( _get_user_emails_and_groups_from_project_roles( jira_client=jira_client, jira_project=jira_project, project_role_holders=holder_map["projectRole"], ) ) # Get groups directly assigned in permission scheme (common in Data Center) # Format: {'type': 'group', 'parameter': 'group-name', 'expand': 'group'} direct_groups: list[str] = [] if "group" in holder_map: for group_holder in holder_map["group"]: group_name = _get_role_id(group_holder) if group_name: direct_groups.append(group_name) else: logger.error(f"No parameter/value in group holder: {group_holder}") external_user_emails = set(user_emails + project_role_user_emails) external_user_group_ids = set(project_role_groups + direct_groups) return ExternalAccess( external_user_emails=external_user_emails, external_user_group_ids=external_user_group_ids, is_public=False, ) def get_project_permissions( jira_client: JIRA, jira_project: str, add_prefix: bool = False, ) -> ExternalAccess | None: """ Get project 
permissions from Jira. add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path). """ project_permissions: PermissionScheme = jira_client.project_permissionscheme( project=jira_project ) if not hasattr(project_permissions, "permissions"): logger.error(f"Project {jira_project} has no permissions attribute") return None if not isinstance(project_permissions.permissions, list): logger.error(f"Project {jira_project} permissions is not a list") return None holder_map = _build_holder_map(permissions=project_permissions.permissions) external_access = _build_external_access_from_holder_map( jira_client=jira_client, jira_project=jira_project, holder_map=holder_map ) # Prefix group IDs with source type if requested (for indexing path) if add_prefix and external_access and external_access.external_user_group_ids: prefixed_groups = { build_ext_group_name_for_onyx(g, DocumentSource.JIRA) for g in external_access.external_user_group_ids } return ExternalAccess( external_user_emails=external_access.external_user_emails, external_user_group_ids=prefixed_groups, is_public=external_access.is_public, ) return external_access ================================================ FILE: backend/ee/onyx/external_permissions/perm_sync_types.py ================================================ from collections.abc import Callable from collections.abc import Generator from typing import Optional from typing import Protocol from ee.onyx.db.external_perm import ExternalUserGroup # noqa from onyx.access.models import DocExternalAccess # noqa from onyx.access.models import ElementExternalAccess # noqa from onyx.access.models import NodeExternalAccess # noqa from onyx.context.search.models import InferenceChunk from onyx.db.models import ConnectorCredentialPair # noqa from onyx.db.utils import DocumentRow from onyx.db.utils import SortOrder from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # 
noqa class FetchAllDocumentsFunction(Protocol): """Protocol for a function that fetches documents for a connector credential pair. This protocol defines the interface for functions that retrieve documents from the database, typically used in permission synchronization workflows. """ def __call__( self, sort_order: SortOrder | None, ) -> list[DocumentRow]: """ Fetches documents for a connector credential pair. """ ... class FetchAllDocumentsIdsFunction(Protocol): """Protocol for a function that fetches document IDs for a connector credential pair. This protocol defines the interface for functions that retrieve document IDs from the database, typically used in permission synchronization workflows. """ def __call__( self, ) -> list[str]: """ Fetches document IDs for a connector credential pair. """ ... # Defining the input/output types for the sync functions DocSyncFuncType = Callable[ [ ConnectorCredentialPair, FetchAllDocumentsFunction, FetchAllDocumentsIdsFunction, Optional[IndexingHeartbeatInterface], ], Generator[ElementExternalAccess, None, None], ] GroupSyncFuncType = Callable[ [ str, # tenant_id ConnectorCredentialPair, # cc_pair ], Generator[ExternalUserGroup, None, None], ] # list of chunks to be censored and the user email. 
returns censored chunks CensoringFuncType = Callable[[list[InferenceChunk], str], list[InferenceChunk]] ================================================ FILE: backend/ee/onyx/external_permissions/post_query_censoring.py ================================================ from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs from ee.onyx.external_permissions.sync_params import get_all_censoring_enabled_sources from ee.onyx.external_permissions.sync_params import get_source_perm_sync_config from onyx.configs.constants import DocumentSource from onyx.context.search.pipeline import InferenceChunk from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import User from onyx.utils.logger import setup_logger logger = setup_logger() def _get_all_censoring_enabled_sources() -> set[DocumentSource]: """ Returns the set of sources that have censoring enabled. A source is included if at least one of its cc_pairs has access_type set to sync and the source has a censoring config. NOTE: This means that if a source has even a single cc_pair that is sync, all chunks for that source will be censored, even if the connector that indexed a given chunk is not sync. This was done to avoid looking up the cc_pair for every single chunk. """ all_censoring_enabled_sources = get_all_censoring_enabled_sources() with get_session_with_current_tenant() as db_session: enabled_sync_connectors = get_all_auto_sync_cc_pairs(db_session) return { cc_pair.connector.source for cc_pair in enabled_sync_connectors if cc_pair.connector.source in all_censoring_enabled_sources } # NOTE: This is only called if ee is enabled. def _post_query_chunk_censoring( chunks: list[InferenceChunk], user: User, ) -> list[InferenceChunk]: """ This function checks all chunks to see if they need to be sent to a censoring function. If they do, it sends them to the censoring function and returns the censored chunks. If they don't, it returns the original chunks.
""" sources_to_censor = _get_all_censoring_enabled_sources() # Anonymous users can only access public (non-permission-synced) content if user.is_anonymous: return [chunk for chunk in chunks if chunk.source_type not in sources_to_censor] final_chunk_dict: dict[str, InferenceChunk] = {} chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {} for chunk in chunks: # Separate out chunks that require permission post-processing by source if chunk.source_type in sources_to_censor: chunks_to_process.setdefault(chunk.source_type, []).append(chunk) else: final_chunk_dict[chunk.unique_id] = chunk # For each source, filter out the chunks using the permission # check function for that source # TODO: Use a threadpool/multiprocessing to process the sources in parallel for source, chunks_for_source in chunks_to_process.items(): sync_config = get_source_perm_sync_config(source) if sync_config is None or sync_config.censoring_config is None: raise ValueError(f"No sync config found for {source}") censor_chunks_for_source = sync_config.censoring_config.chunk_censoring_func try: censored_chunks = censor_chunks_for_source(chunks_for_source, user.email) except Exception as e: logger.exception( f"Failed to censor chunks for source {source} so throwing out all chunks for this source and continuing: {e}" ) continue for censored_chunk in censored_chunks: final_chunk_dict[censored_chunk.unique_id] = censored_chunk # IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in final_chunk_list: list[InferenceChunk] = [] for chunk in chunks: # only if the chunk is in the final censored chunks, add it to the final list # if it is missing, that means it was intentionally left out if chunk.unique_id in final_chunk_dict: final_chunk_list.append(final_chunk_dict[chunk.unique_id]) return final_chunk_list ================================================ FILE: backend/ee/onyx/external_permissions/salesforce/postprocessing.py 
================================================ import time from ee.onyx.db.external_perm import fetch_external_groups_for_user_email_and_group_ids from ee.onyx.external_permissions.salesforce.utils import ( get_any_salesforce_client_for_doc_id, ) from ee.onyx.external_permissions.salesforce.utils import get_objects_access_for_user_id from ee.onyx.external_permissions.salesforce.utils import ( get_salesforce_user_id_from_email, ) from onyx.configs.app_configs import BLURB_SIZE from onyx.context.search.models import InferenceChunk from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.utils.logger import setup_logger logger = setup_logger() # Types ChunkKey = tuple[str, int] # (doc_id, chunk_id) ContentRange = tuple[int, int | None] # (start_index, end_index) None means to the end # NOTE: Used for testing timing def _get_dummy_object_access_map( object_ids: set[str], user_email: str, # noqa: ARG001 chunks: list[InferenceChunk], # noqa: ARG001 ) -> dict[str, bool]: time.sleep(0.15) # return {object_id: True for object_id in object_ids} import random return {object_id: random.choice([True, False]) for object_id in object_ids} def _get_objects_access_for_user_email_from_salesforce( object_ids: set[str], user_email: str, chunks: list[InferenceChunk], ) -> dict[str, bool] | None: """ This function wraps the salesforce call as we may want to change how this is done in the future. (E.g. 
replace it with the above function) """ # This is cached in the function so the first query takes an extra 0.1-0.3 seconds # but subsequent queries for this source are essentially instant first_doc_id = chunks[0].document_id with get_session_with_current_tenant() as db_session: salesforce_client = get_any_salesforce_client_for_doc_id( db_session, first_doc_id ) # This is cached in the function so the first query takes an extra 0.1-0.3 seconds # but subsequent queries by the same user are essentially instant start_time = time.monotonic() user_id = get_salesforce_user_id_from_email(salesforce_client, user_email) end_time = time.monotonic() logger.info( f"Time taken to get Salesforce user ID: {end_time - start_time} seconds" ) if user_id is None: logger.warning(f"User '{user_email}' not found in Salesforce") return None # This is the only query that is not cached in the function # so it takes 0.1-0.2 seconds total object_id_to_access = get_objects_access_for_user_id( salesforce_client, user_id, list(object_ids) ) logger.debug(f"Object ID to access: {object_id_to_access}") return object_id_to_access def _extract_salesforce_object_id_from_url(url: str) -> str: return url.split("/")[-1] def _get_object_ranges_for_chunk( chunk: InferenceChunk, ) -> dict[str, list[ContentRange]]: """ Given a chunk, return a dictionary of salesforce object ids and the content ranges for that object id in the current chunk """ if chunk.source_links is None: return {} object_ranges: dict[str, list[ContentRange]] = {} end_index = None descending_source_links = sorted( chunk.source_links.items(), key=lambda x: x[0], reverse=True ) for start_index, url in descending_source_links: object_id = _extract_salesforce_object_id_from_url(url) if object_id not in object_ranges: object_ranges[object_id] = [] object_ranges[object_id].append((start_index, end_index)) end_index = start_index return object_ranges def _create_empty_censored_chunk(uncensored_chunk: InferenceChunk) -> InferenceChunk: """ Create 
a copy of the unfiltered chunk where potentially sensitive content is removed to be added later if the user has access to each of the sub-objects """ empty_censored_chunk = InferenceChunk( **uncensored_chunk.model_dump(), ) empty_censored_chunk.content = "" empty_censored_chunk.blurb = "" empty_censored_chunk.source_links = {} return empty_censored_chunk def _update_censored_chunk( censored_chunk: InferenceChunk, uncensored_chunk: InferenceChunk, content_range: ContentRange, ) -> InferenceChunk: """ Update the filtered chunk with the content and source links from the unfiltered chunk using the content ranges """ start_index, end_index = content_range # Update the content of the filtered chunk permitted_content = uncensored_chunk.content[start_index:end_index] permitted_section_start_index = len(censored_chunk.content) censored_chunk.content = permitted_content + censored_chunk.content # Update the source links of the filtered chunk if uncensored_chunk.source_links is not None: if censored_chunk.source_links is None: censored_chunk.source_links = {} link_content = uncensored_chunk.source_links[start_index] censored_chunk.source_links[permitted_section_start_index] = link_content # Update the blurb of the filtered chunk censored_chunk.blurb = censored_chunk.content[:BLURB_SIZE] return censored_chunk # TODO: Generalize this to other sources def censor_salesforce_chunks( chunks: list[InferenceChunk], user_email: str, # This is so we can provide a mock access map for testing access_map: dict[str, bool] | None = None, ) -> list[InferenceChunk]: # object_id -> list[((doc_id, chunk_id), (start_index, end_index))] object_to_content_map: dict[str, list[tuple[ChunkKey, ContentRange]]] = {} # (doc_id, chunk_id) -> chunk uncensored_chunks: dict[ChunkKey, InferenceChunk] = {} # keep track of all object ids that we have seen to make it easier to get # the access for these object ids object_ids: set[str] = set() for chunk in chunks: chunk_key = (chunk.document_id, chunk.chunk_id) 
# create a dictionary to quickly look up the unfiltered chunk uncensored_chunks[chunk_key] = chunk # for each chunk, get a dictionary of object ids and the content ranges # for that object id in the current chunk object_ranges_for_chunk = _get_object_ranges_for_chunk(chunk) for object_id, ranges in object_ranges_for_chunk.items(): object_ids.add(object_id) for start_index, end_index in ranges: object_to_content_map.setdefault(object_id, []).append( (chunk_key, (start_index, end_index)) ) # This is so we can provide a mock access map for testing if access_map is None: access_map = _get_objects_access_for_user_email_from_salesforce( object_ids=object_ids, user_email=user_email, chunks=chunks, ) if access_map is None: # If the user is not found in Salesforce, access_map will be None. # Return an empty list: without a Salesforce user we cannot grant # access to any object, so every chunk is withheld return [] censored_chunks: dict[ChunkKey, InferenceChunk] = {} for object_id, content_list in object_to_content_map.items(): # if the user does not have access to the object, or the object is not in the # access_map, do not include its content in the filtered chunks if not access_map.get(object_id, False): continue # if we got this far, the user has access to the object so we can create or update # the filtered chunk(s) for this object # NOTE: we only create a censored chunk if the user has access to some # part of the chunk for chunk_key, content_range in content_list: if chunk_key not in censored_chunks: censored_chunks[chunk_key] = _create_empty_censored_chunk( uncensored_chunks[chunk_key] ) uncensored_chunk = uncensored_chunks[chunk_key] censored_chunk = _update_censored_chunk( censored_chunk=censored_chunks[chunk_key], uncensored_chunk=uncensored_chunk, content_range=content_range, ) censored_chunks[chunk_key] = censored_chunk return list(censored_chunks.values()) # NOTE: This is not used anywhere.
def _get_objects_access_for_user_email( object_ids: set[str], user_email: str ) -> dict[str, bool]: with get_session_with_current_tenant() as db_session: external_groups = fetch_external_groups_for_user_email_and_group_ids( db_session=db_session, user_email=user_email, # Maybe make a function that adds a salesforce prefix to the group ids group_ids=list(object_ids), ) external_group_ids = {group.external_user_group_id for group in external_groups} return {group_id: group_id in external_group_ids for group_id in object_ids} ================================================ FILE: backend/ee/onyx/external_permissions/salesforce/utils.py ================================================ from simple_salesforce import Salesforce from sqlalchemy.orm import Session from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.document import get_cc_pairs_for_document from onyx.utils.logger import setup_logger logger = setup_logger() _ANY_SALESFORCE_CLIENT: Salesforce | None = None def get_any_salesforce_client_for_doc_id( db_session: Session, doc_id: str ) -> Salesforce: """ We create a salesforce client for the first cc_pair for the first doc_id where salesforce censoring is enabled. After that we just cache and reuse the same client for all queries. We do this to reduce the number of postgres queries we make at query time. This may be problematic if they are using multiple cc_pairs for salesforce. E.g. there are 2 different credential sets for 2 different salesforce cc_pairs but only one of them has the permissions needed for the query.
""" # NOTE: this global seems very very bad global _ANY_SALESFORCE_CLIENT if _ANY_SALESFORCE_CLIENT is None: cc_pairs = get_cc_pairs_for_document(db_session, doc_id) first_cc_pair = cc_pairs[0] credential_json = ( first_cc_pair.credential.credential_json.get_value(apply_mask=False) if first_cc_pair.credential.credential_json else {} ) _ANY_SALESFORCE_CLIENT = Salesforce( username=credential_json["sf_username"], password=credential_json["sf_password"], security_token=credential_json["sf_security_token"], ) return _ANY_SALESFORCE_CLIENT def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None: # Escape backslashes and single quotes so the email is safe inside a SOQL string literal escaped_email = user_email.replace("\\", "\\\\").replace("'", "\\'") query = f"SELECT Id FROM User WHERE Username = '{escaped_email}' AND IsActive = true" result = sf_client.query(query) if len(result["records"]) > 0: return result["records"][0]["Id"] # try emails query = f"SELECT Id FROM User WHERE Email = '{escaped_email}' AND IsActive = true" result = sf_client.query(query) if len(result["records"]) > 0: return result["records"][0]["Id"] return None # This contains only the user_ids that we have found in Salesforce. # If we don't know their user_id, we don't store anything in the cache. _CACHED_SF_EMAIL_TO_ID_MAP: dict[str, str] = {} def get_salesforce_user_id_from_email( sf_client: Salesforce, user_email: str, ) -> str | None: """ We cache this so we don't have to query Salesforce for every query, and Salesforce user IDs never change. Memory usage is fine because we just store 2 small strings per user. If the email is in the in-memory cache, we return the cached user_id. Otherwise we query Salesforce, matching on Username first and then on Email.
Only user_ids that are actually found are cached. If no user_id is found, we return None without caching anything, so the next call queries Salesforce again and picks up a user that has been created in the meantime. NOTE: The first lookup for an email queries Salesforce and may be slow (around 0.1-0.3 seconds). Cached lookups are fast (<0.001 seconds). """ # NOTE: this global seems bad global _CACHED_SF_EMAIL_TO_ID_MAP if user_email in _CACHED_SF_EMAIL_TO_ID_MAP: if _CACHED_SF_EMAIL_TO_ID_MAP[user_email] is not None: return _CACHED_SF_EMAIL_TO_ID_MAP[user_email] # some caching via sqlite existed here before ... check history if interested # Not cached: query Salesforce directly (the result is not persisted to any database) user_id = _query_salesforce_user_id(sf_client, user_email) if user_id is None: return None # If the found user_id is real, cache it _CACHED_SF_EMAIL_TO_ID_MAP[user_email] = user_id return user_id _MAX_RECORD_IDS_PER_QUERY = 200 def get_objects_access_for_user_id( salesforce_client: Salesforce, user_id: str, record_ids: list[str], ) -> dict[str, bool]: """ Salesforce has a limit of 200 record ids per query. So we just truncate the list of record ids to 200. We only ever retrieve 50 chunks at a time, so this should be fine: it is unlikely that the 50 retrieved chunks reference more than 200 unique objects (more than 4 per chunk on average). If we decide this isn't acceptable, we can use multiple queries, but they should run in parallel so query time doesn't get too long.
""" truncated_record_ids = record_ids[:_MAX_RECORD_IDS_PER_QUERY] record_ids_str = "'" + "','".join(truncated_record_ids) + "'" access_query = f""" SELECT RecordId, HasReadAccess FROM UserRecordAccess WHERE RecordId IN ({record_ids_str}) AND UserId = '{user_id}' """ result = salesforce_client.query_all(access_query) return {record["RecordId"]: record["HasReadAccess"] for record in result["records"]} _CC_PAIR_ID_SALESFORCE_CLIENT_MAP: dict[int, Salesforce] = {} _DOC_ID_TO_CC_PAIR_ID_MAP: dict[str, int] = {} # NOTE: This is not used anywhere. def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Salesforce: """ Uses a document id to get the cc_pair that indexed that document and uses the credentials for that cc_pair to create a Salesforce client. Problems: - There may be multiple cc_pairs for a document, and we don't know which one to use. - right now we just use the first one - Building a new Salesforce client for each document is slow. - Memory usage could be an issue as we build these dictionaries. 
""" if doc_id not in _DOC_ID_TO_CC_PAIR_ID_MAP: cc_pairs = get_cc_pairs_for_document(db_session, doc_id) first_cc_pair = cc_pairs[0] _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] = first_cc_pair.id cc_pair_id = _DOC_ID_TO_CC_PAIR_ID_MAP[doc_id] if cc_pair_id not in _CC_PAIR_ID_SALESFORCE_CLIENT_MAP: cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if cc_pair is None: raise ValueError(f"CC pair {cc_pair_id} not found") credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce( username=credential_json["sf_username"], password=credential_json["sf_password"], security_token=credential_json["sf_security_token"], ) return _CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] ================================================ FILE: backend/ee/onyx/external_permissions/sharepoint/doc_sync.py ================================================ from collections.abc import Generator from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from ee.onyx.external_permissions.utils import generic_doc_sync from onyx.access.models import ElementExternalAccess from onyx.configs.constants import DocumentSource from onyx.connectors.sharepoint.connector import SharepointConnector from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() SHAREPOINT_DOC_SYNC_TAG = "sharepoint_doc_sync" def sharepoint_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, callback: IndexingHeartbeatInterface | None = None, ) -> Generator[ElementExternalAccess, None, None]: sharepoint_connector 
= SharepointConnector( **cc_pair.connector.connector_specific_config, ) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) sharepoint_connector.load_credentials(credential_json) yield from generic_doc_sync( cc_pair=cc_pair, fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn, callback=callback, doc_source=DocumentSource.SHAREPOINT, slim_connector=sharepoint_connector, label=SHAREPOINT_DOC_SYNC_TAG, ) ================================================ FILE: backend/ee/onyx/external_permissions/sharepoint/group_sync.py ================================================ from collections.abc import Generator from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped] from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.sharepoint.permission_utils import ( get_sharepoint_external_groups, ) from onyx.configs.app_configs import SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION from onyx.connectors.sharepoint.connector import acquire_token_for_rest from onyx.connectors.sharepoint.connector import SharepointConnector from onyx.db.models import ConnectorCredentialPair from onyx.utils.logger import setup_logger logger = setup_logger() def sharepoint_group_sync( tenant_id: str, # noqa: ARG001 cc_pair: ConnectorCredentialPair, ) -> Generator[ExternalUserGroup, None, None]: """Sync SharePoint groups and their members""" # Get site URLs from connector config connector_config = cc_pair.connector.connector_specific_config # Create SharePoint connector instance and load credentials connector = SharepointConnector(**connector_config) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) connector.load_credentials(credential_json) if not connector.msal_app: raise RuntimeError("MSAL app not initialized in connector") if not connector.sp_tenant_domain: raise 
RuntimeError("Tenant domain not initialized in connector") # Get site descriptors from connector (either configured sites or all sites) site_descriptors = connector.site_descriptors or connector.fetch_sites() if not site_descriptors: raise RuntimeError("No SharePoint sites found for group sync") logger.info(f"Processing {len(site_descriptors)} sites for group sync") enumerate_all = connector_config.get( "exhaustive_ad_enumeration", SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION ) msal_app = connector.msal_app sp_tenant_domain = connector.sp_tenant_domain sp_domain_suffix = connector.sharepoint_domain_suffix for site_descriptor in site_descriptors: logger.debug(f"Processing site: {site_descriptor.url}") ctx = ClientContext(site_descriptor.url).with_access_token( lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix) ) external_groups = get_sharepoint_external_groups( ctx, connector.graph_client, graph_api_base=connector.graph_api_base, get_access_token=connector._get_graph_access_token, enumerate_all_ad_groups=enumerate_all, ) # Yield each group for group in external_groups: logger.debug( f"Found group: {group.id} with {len(group.user_emails)} members" ) yield group ================================================ FILE: backend/ee/onyx/external_permissions/sharepoint/permission_utils.py ================================================ import re import time from collections import deque from collections.abc import Callable from collections.abc import Generator from typing import Any from urllib.parse import urlparse import requests as _requests from office365.graph_client import GraphClient # type: ignore[import-untyped] from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped] from office365.runtime.client_request import ClientRequestException # type: ignore from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped] from office365.sharepoint.permissions.securable_object import 
RoleAssignmentCollection # type: ignore[import-untyped] from pydantic import BaseModel from ee.onyx.db.external_perm import ExternalUserGroup from onyx.access.models import ExternalAccess from onyx.access.utils import build_ext_group_name_for_onyx from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS from onyx.configs.constants import DocumentSource from onyx.connectors.sharepoint.connector import GRAPH_API_MAX_RETRIES from onyx.connectors.sharepoint.connector import GRAPH_API_RETRYABLE_STATUSES from onyx.connectors.sharepoint.connector import SHARED_DOCUMENTS_MAP_REVERSE from onyx.connectors.sharepoint.connector import sleep_and_retry from onyx.utils.logger import setup_logger logger = setup_logger() # These values represent different types of SharePoint principals used in permission assignments USER_PRINCIPAL_TYPE = 1 # Individual user accounts ANONYMOUS_USER_PRINCIPAL_TYPE = 3 # Anonymous/unauthenticated users (public access) AZURE_AD_GROUP_PRINCIPAL_TYPE = 4 # Azure Active Directory security groups SHAREPOINT_GROUP_PRINCIPAL_TYPE = 8 # SharePoint site groups (local to the site) MICROSOFT_DOMAIN = ".onmicrosoft" # Limited Access role types. Limited Access is a travel-through permission, not an actual permission LIMITED_ACCESS_ROLE_TYPES = [1, 9] LIMITED_ACCESS_ROLE_NAMES = ["Limited Access", "Web-Only Limited Access"] AD_GROUP_ENUMERATION_THRESHOLD = 100_000 def _graph_api_get( url: str, get_access_token: Callable[[], str], params: dict[str, str] | None = None, ) -> dict[str, Any]: """Authenticated Graph API GET with retry on transient errors.""" for attempt in range(GRAPH_API_MAX_RETRIES + 1): access_token = get_access_token() headers = {"Authorization": f"Bearer {access_token}"} try: resp = _requests.get( url, headers=headers, params=params, timeout=REQUEST_TIMEOUT_SECONDS ) if ( resp.status_code in GRAPH_API_RETRYABLE_STATUSES and attempt < GRAPH_API_MAX_RETRIES ): wait = min(int(resp.headers.get("Retry-After", str(2**attempt))), 60) logger.warning(
f"Graph API {resp.status_code} on attempt {attempt + 1}, retrying in {wait}s: {url}" ) time.sleep(wait) continue resp.raise_for_status() return resp.json() except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError): if attempt < GRAPH_API_MAX_RETRIES: wait = min(2**attempt, 60) logger.warning( f"Graph API connection error on attempt {attempt + 1}, retrying in {wait}s: {url}" ) time.sleep(wait) continue raise raise RuntimeError( f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}" ) def _iter_graph_collection( initial_url: str, get_access_token: Callable[[], str], params: dict[str, str] | None = None, ) -> Generator[dict[str, Any], None, None]: """Paginate through a Graph API collection, yielding items one at a time.""" url: str | None = initial_url while url: data = _graph_api_get(url, get_access_token, params) params = None yield from data.get("value", []) url = data.get("@odata.nextLink") def _normalize_email(email: str) -> str: if MICROSOFT_DOMAIN in email: return email.replace(MICROSOFT_DOMAIN, "") return email class SharepointGroup(BaseModel): model_config = {"frozen": True} name: str login_name: str principal_type: int class GroupsResult(BaseModel): groups_to_emails: dict[str, set[str]] found_public_group: bool def _get_azuread_group_guid_by_name( graph_client: GraphClient, group_name: str ) -> str | None: try: # Search for groups by display name groups = sleep_and_retry( graph_client.groups.filter(f"displayName eq '{group_name}'").get(), "get_azuread_group_guid_by_name", ) if groups and len(groups) > 0: return groups[0].id return None except Exception as e: logger.error(f"Failed to get Azure AD group GUID for name {group_name}: {e}") return None def _extract_guid_from_claims_token(claims_token: str) -> str | None: try: # Pattern to match GUID in claims token # Claims tokens often have format: c:0o.c|provider|GUID_suffix guid_pattern = r"([0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12})" match = 
re.search(guid_pattern, claims_token, re.IGNORECASE) if match: return match.group(1) return None except Exception as e: logger.error(f"Failed to extract GUID from claims token {claims_token}: {e}") return None def _get_group_guid_from_identifier( graph_client: GraphClient, identifier: str ) -> str | None: try: # Check if it's already a GUID guid_pattern = r"^[0-9a-f]{8}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{4}-[0-9a-f]{12}$" if re.match(guid_pattern, identifier, re.IGNORECASE): return identifier # Check if it's a SharePoint claims token if identifier.startswith("c:0") and "|" in identifier: guid = _extract_guid_from_claims_token(identifier) if guid: logger.info(f"Extracted GUID {guid} from claims token {identifier}") return guid # Try to search by display name as fallback return _get_azuread_group_guid_by_name(graph_client, identifier) except Exception as e: logger.error(f"Failed to get group GUID from identifier {identifier}: {e}") return None def _get_security_group_owners(graph_client: GraphClient, group_id: str) -> list[str]: try: # Get group owners using Graph API group = graph_client.groups[group_id] owners = sleep_and_retry( group.owners.get_all(page_loaded=lambda _: None), "get_security_group_owners", ) owner_emails: list[str] = [] logger.info(f"Owners: {owners}") for owner in owners: owner_data = owner.to_json() # Extract email from the JSON data mail: str | None = owner_data.get("mail") user_principal_name: str | None = owner_data.get("userPrincipalName") # Check if owner is a user and has an email if mail: if MICROSOFT_DOMAIN in mail: mail = mail.replace(MICROSOFT_DOMAIN, "") owner_emails.append(mail) elif user_principal_name: if MICROSOFT_DOMAIN in user_principal_name: user_principal_name = user_principal_name.replace( MICROSOFT_DOMAIN, "" ) owner_emails.append(user_principal_name) logger.info( f"Retrieved {len(owner_emails)} owners from security group {group_id}" ) return owner_emails except Exception as e: logger.error(f"Failed to get security group owners 
for group {group_id}: {e}") return [] def _get_sharepoint_list_item_id(drive_item: DriveItem) -> str | None: try: # First try to get the list item directly from the drive item if hasattr(drive_item, "listItem"): list_item = drive_item.listItem if list_item: # Load the list item properties to get the ID sleep_and_retry(list_item.get(), "get_sharepoint_list_item_id") if hasattr(list_item, "id") and list_item.id: return str(list_item.id) # The SharePoint list item ID is typically available in the sharepointIds property sharepoint_ids = getattr(drive_item, "sharepoint_ids", None) if sharepoint_ids and hasattr(sharepoint_ids, "listItemId"): return sharepoint_ids.listItemId # Alternative: try to get it from the properties properties = getattr(drive_item, "properties", None) if properties: # Sometimes the SharePoint list item ID is in the properties for prop_name, prop_value in properties.items(): if "listitemid" in prop_name.lower(): return str(prop_value) return None except Exception as e: logger.error( f"Error getting SharePoint list item ID for item {drive_item.id}: {e}" ) raise e def _is_public_item( drive_item: DriveItem, treat_sharing_link_as_public: bool = False, ) -> bool: if not treat_sharing_link_as_public: return False try: permissions = sleep_and_retry( drive_item.permissions.get_all(page_loaded=lambda _: None), "is_public_item" ) for permission in permissions: if permission.link and permission.link.scope in ( "anonymous", "organization", ): return True return False except Exception as e: logger.error(f"Failed to check if item {drive_item.id} is public: {e}") return False def _is_public_login_name(login_name: str) -> bool: # Patterns that indicate public access # This list is derived from the below link # https://learn.microsoft.com/en-us/answers/questions/2085339/guid-in-the-loginname-of-site-user-everyone-except public_login_patterns: list[str] = [ "c:0-.f|rolemanager|spo-grid-all-users/", "c:0(.s|true", ] for pattern in public_login_patterns: if pattern in 
login_name: logger.info(f"Login name {login_name} is public") return True return False # AD groups allow the same display name for multiple groups, so we append the GUID to the name def _get_group_name_with_suffix( login_name: str, group_name: str, graph_client: GraphClient ) -> str: ad_group_suffix = _get_group_guid_from_identifier(graph_client, login_name) return f"{group_name}_{ad_group_suffix}" def _get_sharepoint_groups( client_context: ClientContext, group_name: str, graph_client: GraphClient ) -> tuple[set[SharepointGroup], set[str]]: groups: set[SharepointGroup] = set() user_emails: set[str] = set() def process_users(users: list[Any]) -> None: nonlocal groups, user_emails for user in users: logger.debug(f"User: {user.to_json()}") if user.principal_type == USER_PRINCIPAL_TYPE and hasattr( user, "user_principal_name" ): if user.user_principal_name: email = user.user_principal_name if MICROSOFT_DOMAIN in email: email = email.replace(MICROSOFT_DOMAIN, "") user_emails.add(email) else: logger.warning( f"User doesn't have a user principal name: {user.login_name}" ) elif user.principal_type in [ AZURE_AD_GROUP_PRINCIPAL_TYPE, SHAREPOINT_GROUP_PRINCIPAL_TYPE, ]: name = user.title if user.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE: name = _get_group_name_with_suffix( user.login_name, name, graph_client ) groups.add( SharepointGroup( login_name=user.login_name, principal_type=user.principal_type, name=name, ) ) group = client_context.web.site_groups.get_by_name(group_name) sleep_and_retry( group.users.get_all(page_loaded=process_users), "get_sharepoint_groups" ) return groups, user_emails def _get_azuread_groups( graph_client: GraphClient, group_name: str ) -> tuple[set[SharepointGroup], set[str]]: group_id = _get_group_guid_from_identifier(graph_client, group_name) if not group_id: logger.error(f"Failed to get Azure AD group GUID for name {group_name}") return set(), set() group = graph_client.groups[group_id] groups: set[SharepointGroup] = set() user_emails:
set[str] = set() def process_members(members: list[Any]) -> None: nonlocal groups, user_emails for member in members: member_data = member.to_json() logger.debug(f"Member: {member_data}") # Check for user-specific attributes user_principal_name = member_data.get("userPrincipalName") mail = member_data.get("mail") display_name = member_data.get("displayName") or member_data.get( "display_name" ) # Check object attributes directly (if available) is_user = False is_group = False # Users typically have userPrincipalName or mail if user_principal_name or (mail and "@" in str(mail)): is_user = True # Groups typically have displayName but no userPrincipalName elif display_name and not user_principal_name: # Additional check: try to access group-specific properties if ( hasattr(member, "groupTypes") or member_data.get("groupTypes") is not None ): is_group = True # Or check if it has an 'id' field typical for groups elif member_data.get("id") and not user_principal_name: is_group = True # Check the object type name (fallback) if not is_user and not is_group: obj_type = type(member).__name__.lower() if "user" in obj_type: is_user = True elif "group" in obj_type: is_group = True # Process based on identification if is_user: if user_principal_name: email = user_principal_name if MICROSOFT_DOMAIN in email: email = email.replace(MICROSOFT_DOMAIN, "") user_emails.add(email) elif mail: email = mail if MICROSOFT_DOMAIN in email: email = email.replace(MICROSOFT_DOMAIN, "") user_emails.add(email) logger.info(f"Added user: {user_principal_name or mail}") elif is_group: if not display_name: logger.error(f"No display name for group: {member_data.get('id')}") continue name = _get_group_name_with_suffix( member_data.get("id", ""), display_name, graph_client ) groups.add( SharepointGroup( login_name=member_data.get("id", ""), # Use ID for groups principal_type=AZURE_AD_GROUP_PRINCIPAL_TYPE, name=name, ) ) logger.info(f"Added group: {name}") else: # Log unidentified members for debugging 
logger.warning(f"Could not identify member type for: {member_data}") sleep_and_retry( group.members.get_all(page_loaded=process_members), "get_azuread_groups" ) owner_emails = _get_security_group_owners(graph_client, group_id) user_emails.update(owner_emails) return groups, user_emails def _get_groups_and_members_recursively( client_context: ClientContext, graph_client: GraphClient, groups: set[SharepointGroup], is_group_sync: bool = False, ) -> GroupsResult: """ Get all groups and their members recursively. """ group_queue: deque[SharepointGroup] = deque(groups) visited_groups: set[str] = set() visited_group_name_to_emails: dict[str, set[str]] = {} found_public_group = False while group_queue: group = group_queue.popleft() if group.login_name in visited_groups: continue visited_groups.add(group.login_name) visited_group_name_to_emails[group.name] = set() logger.info( f"Processing group: {group.name} principal type: {group.principal_type}" ) if group.principal_type == SHAREPOINT_GROUP_PRINCIPAL_TYPE: group_info, user_emails = _get_sharepoint_groups( client_context, group.login_name, graph_client ) visited_group_name_to_emails[group.name].update(user_emails) if group_info: group_queue.extend(group_info) if group.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE: try: # if the site is public, we have default groups assigned to it, so we return early if _is_public_login_name(group.login_name): found_public_group = True if not is_group_sync: return GroupsResult( groups_to_emails={}, found_public_group=True ) else: # we don't want to sync public groups, so we skip them continue group_info, user_emails = _get_azuread_groups( graph_client, group.login_name ) visited_group_name_to_emails[group.name].update(user_emails) if group_info: group_queue.extend(group_info) except ClientRequestException as e: # If the group is not found, we skip it. There is a chance that group is still referenced # in sharepoint but it is removed from Azure AD. 
There is no actual documentation on this, but based on # our testing we have seen this happen. if e.response is not None and e.response.status_code == 404: logger.warning(f"Group {group.login_name} not found") continue raise e return GroupsResult( groups_to_emails=visited_group_name_to_emails, found_public_group=found_public_group, ) def get_external_access_from_sharepoint( client_context: ClientContext, graph_client: GraphClient, drive_name: str | None, drive_item: DriveItem | None, site_page: dict[str, Any] | None, add_prefix: bool = False, treat_sharing_link_as_public: bool = False, ) -> ExternalAccess: """ Get external access information from SharePoint. """ groups: set[SharepointGroup] = set() user_emails: set[str] = set() group_ids: set[str] = set() # Add all members to a processing set first def add_user_and_group_to_sets( role_assignments: RoleAssignmentCollection, ) -> None: nonlocal user_emails, groups for assignment in role_assignments: logger.debug(f"Assignment: {assignment.to_json()}") if assignment.role_definition_bindings: is_limited_access = True for role_definition_binding in assignment.role_definition_bindings: if ( role_definition_binding.role_type_kind not in LIMITED_ACCESS_ROLE_TYPES or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES ): is_limited_access = False break # Skip if the role is only Limited Access, because this is not an actual permission; it's a travel-through permission if is_limited_access: logger.info( "Skipping assignment because it has only Limited Access role" ) continue if assignment.member: member = assignment.member if member.principal_type == USER_PRINCIPAL_TYPE and hasattr( member, "user_principal_name" ): email = member.user_principal_name if MICROSOFT_DOMAIN in email: email = email.replace(MICROSOFT_DOMAIN, "") user_emails.add(email) elif member.principal_type in [ AZURE_AD_GROUP_PRINCIPAL_TYPE, SHAREPOINT_GROUP_PRINCIPAL_TYPE, ]: name = member.title if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE:
name = _get_group_name_with_suffix( member.login_name, name, graph_client ) groups.add( SharepointGroup( login_name=member.login_name, principal_type=member.principal_type, name=name, ) ) if drive_item and drive_name: is_public = _is_public_item(drive_item, treat_sharing_link_as_public) if is_public: logger.info(f"Item {drive_item.id} is public") return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=True, ) item_id = _get_sharepoint_list_item_id(drive_item) if not item_id: raise RuntimeError( f"Failed to get SharePoint list item ID for item {drive_item.id}" ) if drive_name in SHARED_DOCUMENTS_MAP_REVERSE: drive_name = SHARED_DOCUMENTS_MAP_REVERSE[drive_name] item = client_context.web.lists.get_by_title(drive_name).items.get_by_id( item_id ) sleep_and_retry( item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all( page_loaded=add_user_and_group_to_sets, ), "get_external_access_from_sharepoint", ) elif site_page: site_url = site_page.get("webUrl") # Keep percent-encoding intact so the path matches the encoding # used by the Office365 library's SPResPath.create_relative(), # which compares against urlparse(context.base_url).path. # Decoding (e.g. %27 → ') causes a mismatch that duplicates # the site prefix in the constructed URL. 
server_relative_url = urlparse(site_url).path file_obj = client_context.web.get_file_by_server_relative_url( server_relative_url ) item = file_obj.listItemAllFields sleep_and_retry( item.role_assignments.expand(["Member", "RoleDefinitionBindings"]).get_all( page_loaded=add_user_and_group_to_sets, ), "get_external_access_from_sharepoint", ) else: raise RuntimeError("No drive item or site page provided") groups_and_members: GroupsResult = _get_groups_and_members_recursively( client_context, graph_client, groups ) # If the site is public, we have default groups assigned to it, so we return early if groups_and_members.found_public_group: return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=True, ) for group_name, _ in groups_and_members.groups_to_emails.items(): if add_prefix: group_name = build_ext_group_name_for_onyx( group_name, DocumentSource.SHAREPOINT ) group_ids.add(group_name.lower()) logger.info(f"User emails: {len(user_emails)}") logger.info(f"Group IDs: {len(group_ids)}") return ExternalAccess( external_user_emails=user_emails, external_user_group_ids=group_ids, is_public=False, ) def _enumerate_ad_groups_paginated( get_access_token: Callable[[], str], already_resolved: set[str], graph_api_base: str, ) -> Generator[ExternalUserGroup, None, None]: """Paginate through all Azure AD groups and yield ExternalUserGroup for each. Skips groups whose suffixed name is already in *already_resolved*. Stops early if the number of groups exceeds AD_GROUP_ENUMERATION_THRESHOLD.
""" groups_url = f"{graph_api_base}/groups" groups_params: dict[str, str] = {"$select": "id,displayName", "$top": "999"} total_groups = 0 for group_json in _iter_graph_collection( groups_url, get_access_token, groups_params ): group_id: str = group_json.get("id", "") display_name: str = group_json.get("displayName", "") if not group_id or not display_name: continue total_groups += 1 if total_groups > AD_GROUP_ENUMERATION_THRESHOLD: logger.warning( f"Azure AD group enumeration exceeded {AD_GROUP_ENUMERATION_THRESHOLD} " "groups — stopping to avoid excessive memory/API usage. " "Remaining groups will be resolved from role assignments only." ) return name = f"{display_name}_{group_id}" if name in already_resolved: continue member_emails: list[str] = [] members_url = f"{graph_api_base}/groups/{group_id}/members" members_params: dict[str, str] = { "$select": "userPrincipalName,mail", "$top": "999", } for member_json in _iter_graph_collection( members_url, get_access_token, members_params ): email = member_json.get("userPrincipalName") or member_json.get("mail") if email: member_emails.append(_normalize_email(email)) yield ExternalUserGroup(id=name, user_emails=member_emails) logger.info(f"Enumerated {total_groups} Azure AD groups via paginated Graph API") def get_sharepoint_external_groups( client_context: ClientContext, graph_client: GraphClient, graph_api_base: str, get_access_token: Callable[[], str] | None = None, enumerate_all_ad_groups: bool = False, ) -> list[ExternalUserGroup]: groups: set[SharepointGroup] = set() def add_group_to_sets(role_assignments: RoleAssignmentCollection) -> None: nonlocal groups for assignment in role_assignments: if assignment.role_definition_bindings: is_limited_access = True for role_definition_binding in assignment.role_definition_bindings: if ( role_definition_binding.role_type_kind not in LIMITED_ACCESS_ROLE_TYPES or role_definition_binding.name not in LIMITED_ACCESS_ROLE_NAMES ): is_limited_access = False break # Skip if the role 
assignment is only Limited Access, because this is not an actual permission; it is # a pass-through (traversal) permission if is_limited_access: logger.info( "Skipping assignment because it has only Limited Access role" ) continue if assignment.member: member = assignment.member if member.principal_type in [ AZURE_AD_GROUP_PRINCIPAL_TYPE, SHAREPOINT_GROUP_PRINCIPAL_TYPE, ]: name = member.title if member.principal_type == AZURE_AD_GROUP_PRINCIPAL_TYPE: name = _get_group_name_with_suffix( member.login_name, name, graph_client ) groups.add( SharepointGroup( login_name=member.login_name, principal_type=member.principal_type, name=name, ) ) sleep_and_retry( client_context.web.role_assignments.expand( ["Member", "RoleDefinitionBindings"] ).get_all(page_loaded=add_group_to_sets), "get_sharepoint_external_groups", ) groups_and_members: GroupsResult = _get_groups_and_members_recursively( client_context, graph_client, groups, is_group_sync=True ) external_user_groups: list[ExternalUserGroup] = [ ExternalUserGroup(id=group_name, user_emails=list(emails)) for group_name, emails in groups_and_members.groups_to_emails.items() ] if not enumerate_all_ad_groups or get_access_token is None: logger.info( "Skipping exhaustive Azure AD group enumeration. Only groups found in site role assignments are included."
) return external_user_groups already_resolved = set(groups_and_members.groups_to_emails.keys()) for group in _enumerate_ad_groups_paginated( get_access_token, already_resolved, graph_api_base ): external_user_groups.append(group) return external_user_groups ================================================ FILE: backend/ee/onyx/external_permissions/slack/channel_access.py ================================================ from slack_sdk import WebClient from onyx.access.models import ExternalAccess from onyx.connectors.models import BasicExpertInfo from onyx.connectors.slack.connector import ChannelType from onyx.connectors.slack.utils import expert_info_from_slack_id from onyx.connectors.slack.utils import make_paginated_slack_api_call def get_channel_access( client: WebClient, channel: ChannelType, user_cache: dict[str, BasicExpertInfo | None], ) -> ExternalAccess: """ Get channel access permissions for a Slack channel. Args: client: Slack WebClient instance channel: Slack channel object containing channel info user_cache: Cache of user IDs to BasicExpertInfo objects. May be updated in place. Returns: ExternalAccess object for the channel. 
""" channel_is_public = not channel["is_private"] if channel_is_public: return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=True, ) channel_id = channel["id"] # Get all member IDs for the channel member_ids = [] for result in make_paginated_slack_api_call( client.conversations_members, channel=channel_id, ): member_ids.extend(result.get("members", [])) member_emails = set() for member_id in member_ids: # Try to get user info from cache or fetch it user_info = expert_info_from_slack_id( user_id=member_id, client=client, user_cache=user_cache, ) # If we have user info and an email, add it to the set if user_info and user_info.email: member_emails.add(user_info.email) return ExternalAccess( external_user_emails=member_emails, # NOTE: groups are not used, since adding a group to a channel just adds all # users that are in the group. external_user_group_ids=set(), is_public=False, ) ================================================ FILE: backend/ee/onyx/external_permissions/slack/doc_sync.py ================================================ from collections.abc import Generator from slack_sdk import WebClient from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map from onyx.access.models import DocExternalAccess from onyx.access.models import ExternalAccess from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import HierarchyNode from onyx.connectors.slack.connector import get_channels from onyx.connectors.slack.connector import make_paginated_slack_api_call from onyx.connectors.slack.connector import SlackConnector from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import 
IndexingHeartbeatInterface from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() def _fetch_workspace_permissions( user_id_to_email_map: dict[str, str], ) -> ExternalAccess: user_emails = set() for email in user_id_to_email_map.values(): user_emails.add(email) return ExternalAccess( external_user_emails=user_emails, # No group<->document mapping for slack external_user_group_ids=set(), # No way to determine if slack is invite only without enterprise license is_public=False, ) def _fetch_channel_permissions( slack_client: WebClient, workspace_permissions: ExternalAccess, user_id_to_email_map: dict[str, str], ) -> dict[str, ExternalAccess]: channel_permissions = {} public_channels = get_channels( client=slack_client, get_public=True, get_private=False, ) public_channel_ids = [ channel["id"] for channel in public_channels if "id" in channel ] for channel_id in public_channel_ids: channel_permissions[channel_id] = workspace_permissions private_channels = get_channels( client=slack_client, get_public=False, get_private=True, ) private_channel_ids = [ channel["id"] for channel in private_channels if "id" in channel ] for channel_id in private_channel_ids: # Collect all member IDs for the channel across paginated calls member_ids = [] for result in make_paginated_slack_api_call( slack_client.conversations_members, channel=channel_id, ): member_ids.extend(result.get("members", [])) # Collect all member emails for the channel member_emails = set() for member_id in member_ids: member_email = user_id_to_email_map.get(member_id) if not member_email: # If the user is an external user, they won't be in the users_list results # used to build user_id_to_email_map, so we need to make a separate call to users_info # and add them to the user_id_to_email_map member_info = slack_client.users_info(user=member_id) member_email = member_info["user"]["profile"].get("email") if not
member_email: # If no email is found, we skip the user continue user_id_to_email_map[member_id] = member_email member_emails.add(member_email) channel_permissions[channel_id] = ExternalAccess( external_user_emails=member_emails, # No group<->document mapping for slack external_user_group_ids=set(), # No way to determine if slack is invite only without enterprise license is_public=False, ) return channel_permissions def _get_slack_document_access( slack_connector: SlackConnector, channel_permissions: dict[str, ExternalAccess], # noqa: ARG001 callback: IndexingHeartbeatInterface | None, indexing_start: SecondsSinceUnixEpoch | None = None, ) -> Generator[DocExternalAccess, None, None]: slim_doc_generator = slack_connector.retrieve_all_slim_docs_perm_sync( callback=callback, start=indexing_start, ) for doc_metadata_batch in slim_doc_generator: for doc_metadata in doc_metadata_batch: if isinstance(doc_metadata, HierarchyNode): # TODO: handle hierarchynodes during sync continue if doc_metadata.external_access is None: raise ValueError( f"No external access for document {doc_metadata.id}. 
" "Please check to make sure that your Slack bot token has the " "`channels:read` scope" ) yield DocExternalAccess( external_access=doc_metadata.external_access, doc_id=doc_metadata.id, ) if callback: if callback.should_stop(): raise RuntimeError("_get_slack_document_access: Stop signal detected") callback.progress("_get_slack_document_access", 1) def slack_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, # noqa: ARG001 callback: IndexingHeartbeatInterface | None, ) -> Generator[DocExternalAccess, None, None]: """ Adds the external permissions to the documents in Postgres. If a document doesn't already exist in Postgres, we create it there so that when it is indexed later, its permissions are already populated. """ # Use credentials provider instead of directly loading credentials tenant_id = get_current_tenant_id() provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id) r = get_redis_client(tenant_id=tenant_id) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) slack_client = SlackConnector.make_slack_web_client( provider.get_provider_key(), credential_json["slack_bot_token"], SlackConnector.MAX_RETRIES, r, ) user_id_to_email_map = fetch_user_id_to_email_map(slack_client) if not user_id_to_email_map: raise ValueError( "No user id to email map found. 
Please check to make sure that your Slack bot token has the `users:read.email` scope" ) workspace_permissions = _fetch_workspace_permissions( user_id_to_email_map=user_id_to_email_map, ) channel_permissions = _fetch_channel_permissions( slack_client=slack_client, workspace_permissions=workspace_permissions, user_id_to_email_map=user_id_to_email_map, ) slack_connector = SlackConnector(**cc_pair.connector.connector_specific_config) slack_connector.set_credentials_provider(provider) indexing_start_ts: SecondsSinceUnixEpoch | None = ( cc_pair.connector.indexing_start.timestamp() if cc_pair.connector.indexing_start is not None else None ) yield from _get_slack_document_access( slack_connector=slack_connector, channel_permissions=channel_permissions, callback=callback, indexing_start=indexing_start_ts, ) ================================================ FILE: backend/ee/onyx/external_permissions/slack/group_sync.py ================================================ """ THIS IS NOT USEFUL OR USED FOR PERMISSION SYNCING WHEN USERGROUPS ARE ADDED TO A CHANNEL, IT JUST RESOLVES ALL THE USERS TO THAT CHANNEL SO WHEN CHECKING IF A USER CAN ACCESS A DOCUMENT, WE ONLY NEED TO CHECK THEIR EMAIL THERE IS NO USERGROUP <-> DOCUMENT PERMISSION MAPPING """ from slack_sdk import WebClient from ee.onyx.db.external_perm import ExternalUserGroup from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.slack.connector import SlackConnector from onyx.connectors.slack.utils import make_paginated_slack_api_call from onyx.db.models import ConnectorCredentialPair from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() def _get_slack_group_ids( slack_client: WebClient, ) -> list[str]: group_ids = [] for result in make_paginated_slack_api_call(slack_client.usergroups_list): for group in result.get("usergroups", []): 
group_ids.append(group.get("id")) return group_ids def _get_slack_group_members_email( slack_client: WebClient, group_name: str, user_id_to_email_map: dict[str, str], ) -> list[str]: group_member_emails = [] for result in make_paginated_slack_api_call( slack_client.usergroups_users_list, usergroup=group_name ): for member_id in result.get("users", []): member_email = user_id_to_email_map.get(member_id) if not member_email: # If the user is an external user, they won't be in the users_list results # used to build user_id_to_email_map, so we need to make a separate call to users_info member_info = slack_client.users_info(user=member_id) member_email = member_info["user"]["profile"].get("email") if not member_email: # If no email is found, we skip the user continue user_id_to_email_map[member_id] = member_email group_member_emails.append(member_email) return group_member_emails def slack_group_sync( tenant_id: str, cc_pair: ConnectorCredentialPair, ) -> list[ExternalUserGroup]: """NOTE: not currently used. All channel access is done at the individual user level.
Leaving in for now in case we need it later.""" provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id) r = get_redis_client(tenant_id=tenant_id) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) slack_client = SlackConnector.make_slack_web_client( provider.get_provider_key(), credential_json["slack_bot_token"], SlackConnector.MAX_RETRIES, r, ) user_id_to_email_map = fetch_user_id_to_email_map(slack_client) onyx_groups: list[ExternalUserGroup] = [] for group_name in _get_slack_group_ids(slack_client): group_member_emails = _get_slack_group_members_email( slack_client=slack_client, group_name=group_name, user_id_to_email_map=user_id_to_email_map, ) if not group_member_emails: continue onyx_groups.append( ExternalUserGroup(id=group_name, user_emails=group_member_emails) ) return onyx_groups ================================================ FILE: backend/ee/onyx/external_permissions/slack/utils.py ================================================ from slack_sdk import WebClient from onyx.connectors.slack.utils import make_paginated_slack_api_call def fetch_user_id_to_email_map( slack_client: WebClient, ) -> dict[str, str]: user_id_to_email_map = {} for user_info in make_paginated_slack_api_call( slack_client.users_list, ): for user in user_info.get("members", []): if user.get("profile", {}).get("email"): user_id_to_email_map[user.get("id")] = user.get("profile", {}).get( "email" ) return user_id_to_email_map ================================================ FILE: backend/ee/onyx/external_permissions/sync_params.py ================================================ from collections.abc import Generator from typing import Optional from typing import TYPE_CHECKING from pydantic import BaseModel from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY from 
ee.onyx.configs.app_configs import DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import GITHUB_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY from ee.onyx.configs.app_configs import GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY from ee.onyx.configs.app_configs import JIRA_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import JIRA_PERMISSION_GROUP_SYNC_FREQUENCY from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY from ee.onyx.configs.app_configs import SLACK_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.configs.app_configs import TEAMS_PERMISSION_DOC_SYNC_FREQUENCY from ee.onyx.external_permissions.confluence.doc_sync import confluence_doc_sync from ee.onyx.external_permissions.confluence.group_sync import confluence_group_sync from ee.onyx.external_permissions.github.doc_sync import github_doc_sync from ee.onyx.external_permissions.github.group_sync import github_group_sync from ee.onyx.external_permissions.gmail.doc_sync import gmail_doc_sync from ee.onyx.external_permissions.google_drive.doc_sync import gdrive_doc_sync from ee.onyx.external_permissions.google_drive.group_sync import gdrive_group_sync from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync from ee.onyx.external_permissions.jira.group_sync import jira_group_sync from ee.onyx.external_permissions.perm_sync_types import CensoringFuncType from ee.onyx.external_permissions.perm_sync_types import DocSyncFuncType from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from ee.onyx.external_permissions.perm_sync_types import GroupSyncFuncType from ee.onyx.external_permissions.salesforce.postprocessing import ( censor_salesforce_chunks, ) from 
ee.onyx.external_permissions.sharepoint.doc_sync import sharepoint_doc_sync from ee.onyx.external_permissions.sharepoint.group_sync import sharepoint_group_sync from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync from ee.onyx.external_permissions.teams.doc_sync import teams_doc_sync from onyx.configs.constants import DocumentSource if TYPE_CHECKING: from onyx.access.models import DocExternalAccess # noqa from onyx.db.models import ConnectorCredentialPair # noqa from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface # noqa class DocSyncConfig(BaseModel): doc_sync_frequency: int doc_sync_func: DocSyncFuncType initial_index_should_sync: bool class GroupSyncConfig(BaseModel): group_sync_frequency: int group_sync_func: GroupSyncFuncType group_sync_is_cc_pair_agnostic: bool class CensoringConfig(BaseModel): chunk_censoring_func: CensoringFuncType class SyncConfig(BaseModel): # None means we don't perform a doc_sync doc_sync_config: DocSyncConfig | None = None # None means we don't perform a group_sync group_sync_config: GroupSyncConfig | None = None # None means we don't perform chunk censoring censoring_config: CensoringConfig | None = None # Mock doc sync function for testing (no-op) def mock_doc_sync( cc_pair: "ConnectorCredentialPair", # noqa: ARG001 fetch_all_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_docs_ids_fn: FetchAllDocumentsIdsFunction, # noqa: ARG001 callback: Optional["IndexingHeartbeatInterface"], # noqa: ARG001 ) -> Generator["DocExternalAccess", None, None]: """Mock doc sync function for testing. Yields nothing, since permissions are fetched during indexing.""" yield from [] _SOURCE_TO_SYNC_CONFIG: dict[DocumentSource, SyncConfig] = { DocumentSource.GOOGLE_DRIVE: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=gdrive_doc_sync, initial_index_should_sync=True, ), group_sync_config=GroupSyncConfig(
group_sync_frequency=GOOGLE_DRIVE_PERMISSION_GROUP_SYNC_FREQUENCY, group_sync_func=gdrive_group_sync, group_sync_is_cc_pair_agnostic=False, ), ), DocumentSource.CONFLUENCE: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=CONFLUENCE_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=confluence_doc_sync, initial_index_should_sync=False, ), group_sync_config=GroupSyncConfig( group_sync_frequency=CONFLUENCE_PERMISSION_GROUP_SYNC_FREQUENCY, group_sync_func=confluence_group_sync, group_sync_is_cc_pair_agnostic=True, ), ), DocumentSource.JIRA: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=JIRA_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=jira_doc_sync, initial_index_should_sync=True, ), group_sync_config=GroupSyncConfig( group_sync_frequency=JIRA_PERMISSION_GROUP_SYNC_FREQUENCY, group_sync_func=jira_group_sync, group_sync_is_cc_pair_agnostic=True, ), ), # Groups are not needed for Slack. # All channel access is done at the individual user level. DocumentSource.SLACK: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=SLACK_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=slack_doc_sync, initial_index_should_sync=True, ), ), DocumentSource.GMAIL: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=gmail_doc_sync, initial_index_should_sync=False, ), ), DocumentSource.GITHUB: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=GITHUB_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=github_doc_sync, initial_index_should_sync=True, ), group_sync_config=GroupSyncConfig( group_sync_frequency=GITHUB_PERMISSION_GROUP_SYNC_FREQUENCY, group_sync_func=github_group_sync, group_sync_is_cc_pair_agnostic=False, ), ), DocumentSource.SALESFORCE: SyncConfig( censoring_config=CensoringConfig( chunk_censoring_func=censor_salesforce_chunks, ), ), DocumentSource.MOCK_CONNECTOR: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=DEFAULT_PERMISSION_DOC_SYNC_FREQUENCY, 
doc_sync_func=mock_doc_sync, initial_index_should_sync=True, ), ), # Groups are not needed for Teams. # All channel access is done at the individual user level. DocumentSource.TEAMS: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=TEAMS_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=teams_doc_sync, initial_index_should_sync=True, ), ), DocumentSource.SHAREPOINT: SyncConfig( doc_sync_config=DocSyncConfig( doc_sync_frequency=SHAREPOINT_PERMISSION_DOC_SYNC_FREQUENCY, doc_sync_func=sharepoint_doc_sync, initial_index_should_sync=True, ), group_sync_config=GroupSyncConfig( group_sync_frequency=SHAREPOINT_PERMISSION_GROUP_SYNC_FREQUENCY, group_sync_func=sharepoint_group_sync, group_sync_is_cc_pair_agnostic=False, ), ), } def source_requires_doc_sync(source: DocumentSource) -> bool: """Checks if the given DocumentSource requires doc syncing.""" if source not in _SOURCE_TO_SYNC_CONFIG: return False return _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config is not None def source_requires_external_group_sync(source: DocumentSource) -> bool: """Checks if the given DocumentSource requires external group syncing.""" if source not in _SOURCE_TO_SYNC_CONFIG: return False return _SOURCE_TO_SYNC_CONFIG[source].group_sync_config is not None def get_source_perm_sync_config(source: DocumentSource) -> SyncConfig | None: """Returns the permission sync config for the given DocumentSource, or None if it has none.""" return _SOURCE_TO_SYNC_CONFIG.get(source) def source_group_sync_is_cc_pair_agnostic(source: DocumentSource) -> bool: """Checks if the given DocumentSource has an external group sync that is cc_pair agnostic.""" if source not in _SOURCE_TO_SYNC_CONFIG: return False group_sync_config = _SOURCE_TO_SYNC_CONFIG[source].group_sync_config if group_sync_config is None: return False return group_sync_config.group_sync_is_cc_pair_agnostic def get_all_cc_pair_agnostic_group_sync_sources() -> set[DocumentSource]: """Returns the set of sources that have external group syncing that is cc_pair agnostic.""" return
{ source for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items() if sync_config.group_sync_config is not None and sync_config.group_sync_config.group_sync_is_cc_pair_agnostic } def check_if_valid_sync_source(source_type: DocumentSource) -> bool: return source_type in _SOURCE_TO_SYNC_CONFIG def get_all_censoring_enabled_sources() -> set[DocumentSource]: """Returns the set of sources that have censoring enabled.""" return { source for source, sync_config in _SOURCE_TO_SYNC_CONFIG.items() if sync_config.censoring_config is not None } def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool: """Returns True if the given DocumentSource requires permissions to be fetched during indexing.""" if source not in _SOURCE_TO_SYNC_CONFIG: return False doc_sync_config = _SOURCE_TO_SYNC_CONFIG[source].doc_sync_config if doc_sync_config is None: return False return doc_sync_config.initial_index_should_sync ================================================ FILE: backend/ee/onyx/external_permissions/teams/doc_sync.py ================================================ from collections.abc import Generator from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsFunction from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from ee.onyx.external_permissions.utils import generic_doc_sync from onyx.access.models import ElementExternalAccess from onyx.configs.constants import DocumentSource from onyx.connectors.teams.connector import TeamsConnector from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() TEAMS_DOC_SYNC_LABEL = "teams_doc_sync" def teams_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_fn: FetchAllDocumentsFunction, # noqa: ARG001 fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, callback: IndexingHeartbeatInterface | None, ) -> 
Generator[ElementExternalAccess, None, None]: teams_connector = TeamsConnector( **cc_pair.connector.connector_specific_config, ) credential_json = ( cc_pair.credential.credential_json.get_value(apply_mask=False) if cc_pair.credential.credential_json else {} ) teams_connector.load_credentials(credential_json) yield from generic_doc_sync( cc_pair=cc_pair, fetch_all_existing_docs_ids_fn=fetch_all_existing_docs_ids_fn, callback=callback, doc_source=DocumentSource.TEAMS, slim_connector=teams_connector, label=TEAMS_DOC_SYNC_LABEL, ) ================================================ FILE: backend/ee/onyx/external_permissions/utils.py ================================================ from collections.abc import Generator from ee.onyx.external_permissions.perm_sync_types import FetchAllDocumentsIdsFunction from onyx.access.models import DocExternalAccess from onyx.access.models import ElementExternalAccess from onyx.access.models import ExternalAccess from onyx.access.models import NodeExternalAccess from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import HierarchyNode from onyx.db.models import ConnectorCredentialPair from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() def generic_doc_sync( cc_pair: ConnectorCredentialPair, fetch_all_existing_docs_ids_fn: FetchAllDocumentsIdsFunction, callback: IndexingHeartbeatInterface | None, doc_source: DocumentSource, slim_connector: SlimConnectorWithPermSync, label: str, ) -> Generator[ElementExternalAccess, None, None]: """ A convenience function for performing a generic document synchronization. 
Notes: A generic doc sync includes: - fetching existing docs - fetching *all* new (slim) docs - yielding external-access permissions for existing docs which do not exist in the newly fetched slim-docs set (with their `external_access` set to "private") - yielding external-access permissions for newly fetched docs and hierarchy nodes Returns: A `Generator` which yields existing and newly fetched external-access permissions. """ logger.info(f"Starting {doc_source} doc sync for CC Pair ID: {cc_pair.id}") indexing_start: SecondsSinceUnixEpoch | None = ( cc_pair.connector.indexing_start.timestamp() if cc_pair.connector.indexing_start is not None else None ) newly_fetched_doc_ids: set[str] = set() logger.info(f"Fetching all slim documents from {doc_source}") for doc_batch in slim_connector.retrieve_all_slim_docs_perm_sync( start=indexing_start, callback=callback, ): logger.info(f"Got {len(doc_batch)} slim documents from {doc_source}") if callback: if callback.should_stop(): raise RuntimeError(f"{label}: Stop signal detected") callback.progress(label, 1) for doc in doc_batch: if isinstance(doc, HierarchyNode): # Yield hierarchy node permissions to be processed in outer layer if doc.external_access: yield NodeExternalAccess( external_access=doc.external_access, raw_node_id=doc.raw_node_id, source=doc_source.value, ) continue if not doc.external_access: raise RuntimeError( f"No external access found for document ID; {cc_pair.id=} {doc_source=} {doc.id=}" ) newly_fetched_doc_ids.add(doc.id) yield DocExternalAccess( doc_id=doc.id, external_access=doc.external_access, ) logger.info(f"Querying existing document IDs for CC Pair ID: {cc_pair.id=}") existing_doc_ids: list[str] = fetch_all_existing_docs_ids_fn() missing_doc_ids = set(existing_doc_ids) - newly_fetched_doc_ids if not missing_doc_ids: return logger.warning( f"Found {len(missing_doc_ids)=} documents that are in the DB but not present in fetch. Making them inaccessible." 
) for missing_id in missing_doc_ids: logger.warning(f"Removing access for {missing_id=}") yield DocExternalAccess( doc_id=missing_id, external_access=ExternalAccess.empty(), ) logger.info(f"Finished {doc_source} doc sync") ================================================ FILE: backend/ee/onyx/feature_flags/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/feature_flags/factory.py ================================================ from ee.onyx.feature_flags.posthog_provider import PostHogFeatureFlagProvider from onyx.feature_flags.interface import FeatureFlagProvider def get_posthog_feature_flag_provider() -> FeatureFlagProvider: """ Get the PostHog feature flag provider instance. This is the EE implementation that gets loaded by the versioned implementation loader. Returns: PostHogFeatureFlagProvider: The PostHog-based feature flag provider """ return PostHogFeatureFlagProvider() ================================================ FILE: backend/ee/onyx/feature_flags/posthog_provider.py ================================================ from typing import Any from uuid import UUID from ee.onyx.utils.posthog_client import posthog from onyx.feature_flags.interface import FeatureFlagProvider from onyx.utils.logger import setup_logger logger = setup_logger() class PostHogFeatureFlagProvider(FeatureFlagProvider): """ PostHog-based feature flag provider. Uses PostHog's feature flag API to determine if features are enabled for specific users. Only active in multi-tenant mode. """ def feature_enabled( self, flag_key: str, user_id: UUID, user_properties: dict[str, Any] | None = None, ) -> bool: """ Check if a feature flag is enabled for a user via PostHog. 
Args: flag_key: The identifier for the feature flag to check user_id: The unique identifier for the user user_properties: Optional dictionary of user properties/attributes that may influence flag evaluation Returns: True if the feature is enabled for the user, False otherwise. """ if not posthog: return False try: posthog.set( distinct_id=user_id, properties=user_properties, ) is_enabled = posthog.feature_enabled( flag_key, str(user_id), person_properties=user_properties, ) return bool(is_enabled) if is_enabled is not None else False except Exception as e: logger.error( f"Error checking feature flag {flag_key} for user {user_id}: {e}" ) return False ================================================ FILE: backend/ee/onyx/hooks/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/hooks/executor.py ================================================ """Hook executor — calls a customer's external HTTP endpoint for a given hook point. Usage (Celery tasks and FastAPI handlers): result = execute_hook( db_session=db_session, hook_point=HookPoint.QUERY_PROCESSING, payload={"query": "...", "user_email": "...", "chat_session_id": "..."}, response_type=QueryProcessingResponse, ) if isinstance(result, HookSkipped): # no active hook configured — continue with original behavior ... elif isinstance(result, HookSoftFailed): # hook failed but fail strategy is SOFT — continue with original behavior ... else: # result is a validated Pydantic model instance (response_type) ... 

is_reachable update policy
--------------------------
``is_reachable`` on the Hook row is updated selectively — only when the outcome
carries meaningful signal about physical reachability:

    NetworkError (DNS, connection refused)  → False  (cannot reach the server)
    HTTP 401 / 403                          → False  (api_key revoked or invalid)
    TimeoutException                        → None   (server may be slow, skip write)
    Other HTTP errors (4xx / 5xx)           → None   (server responded, skip write)
    Unknown exception                       → None   (no signal, skip write)
    Non-JSON / non-dict response            → None   (server responded, skip write)
    Success (2xx, valid dict)               → True   (confirmed reachable)

None means "leave the current value unchanged" — no DB round-trip is made.

DB session design
-----------------
The executor uses three sessions:

  1. Caller's session (db_session) — used only for the hook lookup read. All
     needed fields are extracted from the Hook object before the HTTP call, so
     the caller's session is not held open during the external HTTP request.

  2. Log session — a separate short-lived session opened after the HTTP call
     completes to write the HookExecutionLog row on failure. Success runs are
     not recorded. Committed independently of everything else.

  3. Reachable session — a second short-lived session to update is_reachable on
     the Hook. Kept separate from the log session so a concurrent hook deletion
     (which causes update_hook__no_commit to raise OnyxError(NOT_FOUND)) cannot
     prevent the execution log from being written. This update is best-effort.
"""

import json
import time
from typing import Any
from typing import TypeVar

import httpx
from pydantic import BaseModel
from pydantic import ValidationError
from sqlalchemy.orm import Session

from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.db.hook import create_hook_execution_log__no_commit
from onyx.db.hook import get_non_deleted_hook_by_hook_point
from onyx.db.hook import update_hook__no_commit
from onyx.db.models import Hook
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

T = TypeVar("T", bound=BaseModel)


# ---------------------------------------------------------------------------
# Private helpers
# ---------------------------------------------------------------------------


class _HttpOutcome(BaseModel):
    """Structured result of an HTTP hook call, returned by _process_response."""

    is_success: bool
    updated_is_reachable: (
        bool | None
    )  # True/False = write to DB, None = unchanged (skip write)
    status_code: int | None
    error_message: str | None
    response_payload: dict[str, Any] | None


def _lookup_hook(
    db_session: Session,
    hook_point: HookPoint,
) -> Hook | HookSkipped:
    """Return the active Hook or HookSkipped if hooks are unavailable/unconfigured.

    No HTTP call is made and no DB writes are performed for any HookSkipped path.
    There is nothing to log and no reachability information to update.
    """
    if MULTI_TENANT:
        return HookSkipped()
    hook = get_non_deleted_hook_by_hook_point(
        db_session=db_session, hook_point=hook_point
    )
    if hook is None or not hook.is_active:
        return HookSkipped()
    if not hook.endpoint_url:
        return HookSkipped()
    return hook


def _process_response(
    *,
    response: httpx.Response | None,
    exc: Exception | None,
    timeout: float,
) -> _HttpOutcome:
    """Process the result of an HTTP call and return a structured outcome.

    Called after the client.post() try/except. If post() raised, exc is set and
    response is None. Otherwise response is set and exc is None.
    Handles raise_for_status(), JSON decoding, and the dict shape check.
    """
    if exc is not None:
        if isinstance(exc, httpx.NetworkError):
            msg = f"Hook network error (endpoint unreachable): {exc}"
            logger.warning(msg, exc_info=exc)
            return _HttpOutcome(
                is_success=False,
                updated_is_reachable=False,
                status_code=None,
                error_message=msg,
                response_payload=None,
            )
        if isinstance(exc, httpx.TimeoutException):
            msg = f"Hook timed out after {timeout}s: {exc}"
            logger.warning(msg, exc_info=exc)
            return _HttpOutcome(
                is_success=False,
                updated_is_reachable=None,  # timeout doesn't indicate unreachability
                status_code=None,
                error_message=msg,
                response_payload=None,
            )
        msg = f"Hook call failed: {exc}"
        logger.exception(msg, exc_info=exc)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=None,  # unknown error — don't make assumptions
            status_code=None,
            error_message=msg,
            response_payload=None,
        )

    if response is None:
        raise ValueError(
            "exactly one of response or exc must be non-None; both are None"
        )
    status_code = response.status_code

    try:
        response.raise_for_status()
    except httpx.HTTPStatusError as e:
        msg = f"Hook returned HTTP {e.response.status_code}: {e.response.text}"
        logger.warning(msg, exc_info=e)
        # 401/403 means the api_key has been revoked or is invalid — mark unreachable
        # so the operator knows to update it. All other HTTP errors keep is_reachable
        # as-is (server is up, the request just failed for application reasons).
        auth_failed = e.response.status_code in (401, 403)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=False if auth_failed else None,
            status_code=status_code,
            error_message=msg,
            response_payload=None,
        )

    try:
        response_payload = response.json()
    except (json.JSONDecodeError, httpx.DecodingError) as e:
        msg = f"Hook returned non-JSON response: {e}"
        logger.warning(msg, exc_info=e)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=None,  # server responded — reachability unchanged
            status_code=status_code,
            error_message=msg,
            response_payload=None,
        )

    if not isinstance(response_payload, dict):
        msg = f"Hook returned non-dict JSON (got {type(response_payload).__name__})"
        logger.warning(msg)
        return _HttpOutcome(
            is_success=False,
            updated_is_reachable=None,  # server responded — reachability unchanged
            status_code=status_code,
            error_message=msg,
            response_payload=None,
        )

    return _HttpOutcome(
        is_success=True,
        updated_is_reachable=True,
        status_code=status_code,
        error_message=None,
        response_payload=response_payload,
    )


def _persist_result(
    *,
    hook_id: int,
    outcome: _HttpOutcome,
    duration_ms: int,
) -> None:
    """Write the execution log on failure and optionally update is_reachable, each
    in its own session so a failure in one does not affect the other."""
    # Only write the execution log on failure — success runs are not recorded.
    # Must not be skipped if the is_reachable update fails (e.g. hook concurrently
    # deleted between the initial lookup and here).
    if not outcome.is_success:
        try:
            with get_session_with_current_tenant() as log_session:
                create_hook_execution_log__no_commit(
                    db_session=log_session,
                    hook_id=hook_id,
                    is_success=False,
                    error_message=outcome.error_message,
                    status_code=outcome.status_code,
                    duration_ms=duration_ms,
                )
                log_session.commit()
        except Exception:
            logger.exception(
                f"Failed to persist hook execution log for hook_id={hook_id}"
            )

    # Update is_reachable separately — best-effort, non-critical.
    # None means the value is unchanged (set by the caller to skip the no-op write).
    # update_hook__no_commit can raise OnyxError(NOT_FOUND) if the hook was
    # concurrently deleted, so keep this isolated from the log write above.
    if outcome.updated_is_reachable is not None:
        try:
            with get_session_with_current_tenant() as reachable_session:
                update_hook__no_commit(
                    db_session=reachable_session,
                    hook_id=hook_id,
                    is_reachable=outcome.updated_is_reachable,
                )
                reachable_session.commit()
        except Exception:
            logger.warning(f"Failed to update is_reachable for hook_id={hook_id}")


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------


def _execute_hook_inner(
    hook: Hook,
    payload: dict[str, Any],
    response_type: type[T],
) -> T | HookSoftFailed:
    """Make the HTTP call, validate the response, and return a typed model.

    Raises OnyxError on HARD failure. Returns HookSoftFailed on SOFT failure.
    """
    timeout = hook.timeout_seconds
    hook_id = hook.id
    fail_strategy = hook.fail_strategy
    endpoint_url = hook.endpoint_url
    current_is_reachable: bool | None = hook.is_reachable

    if not endpoint_url:
        raise ValueError(
            f"hook_id={hook_id} is active but has no endpoint_url — "
            "active hooks without an endpoint_url must be rejected by _lookup_hook"
        )

    start = time.monotonic()
    response: httpx.Response | None = None
    exc: Exception | None = None
    try:
        api_key: str | None = (
            hook.api_key.get_value(apply_mask=False) if hook.api_key else None
        )
        headers: dict[str, str] = {"Content-Type": "application/json"}
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        with httpx.Client(
            timeout=timeout, follow_redirects=False
        ) as client:  # SSRF guard: never follow redirects
            response = client.post(endpoint_url, json=payload, headers=headers)
    except Exception as e:
        exc = e
    duration_ms = int((time.monotonic() - start) * 1000)

    outcome = _process_response(response=response, exc=exc, timeout=timeout)

    # Validate the response payload against response_type.
    # A validation failure downgrades the outcome to a failure so it is logged,
    # is_reachable is left unchanged (server responded — just a bad payload),
    # and fail_strategy is respected below.
    validated_model: T | None = None
    if outcome.is_success and outcome.response_payload is not None:
        try:
            validated_model = response_type.model_validate(outcome.response_payload)
        except ValidationError as e:
            msg = (
                f"Hook response failed validation against {response_type.__name__}: {e}"
            )
            outcome = _HttpOutcome(
                is_success=False,
                updated_is_reachable=None,  # server responded — reachability unchanged
                status_code=outcome.status_code,
                error_message=msg,
                response_payload=None,
            )

    # Skip the is_reachable write when the value would not change — avoids a
    # no-op DB round-trip on every call when the hook is already in the expected state.
    if outcome.updated_is_reachable == current_is_reachable:
        outcome = outcome.model_copy(update={"updated_is_reachable": None})
    _persist_result(hook_id=hook_id, outcome=outcome, duration_ms=duration_ms)

    if not outcome.is_success:
        if fail_strategy == HookFailStrategy.HARD:
            raise OnyxError(
                OnyxErrorCode.HOOK_EXECUTION_FAILED,
                outcome.error_message or "Hook execution failed.",
            )
        logger.warning(
            f"Hook execution failed (soft fail) for hook_id={hook_id}: {outcome.error_message}"
        )
        return HookSoftFailed()

    if validated_model is None:
        raise OnyxError(
            OnyxErrorCode.INTERNAL_ERROR,
            f"validated_model is None for successful hook call (hook_id={hook_id})",
        )
    return validated_model


def _execute_hook_impl(
    *,
    db_session: Session,
    hook_point: HookPoint,
    payload: dict[str, Any],
    response_type: type[T],
) -> T | HookSkipped | HookSoftFailed:
    """EE implementation — loaded by CE's execute_hook via fetch_versioned_implementation.

    Returns HookSkipped if no active hook is configured, HookSoftFailed if the hook
    failed with SOFT fail strategy, or a validated response model on success.
    Raises OnyxError on HARD failure or if the hook is misconfigured.
    """
    hook = _lookup_hook(db_session, hook_point)
    if isinstance(hook, HookSkipped):
        return hook

    fail_strategy = hook.fail_strategy
    hook_id = hook.id

    try:
        return _execute_hook_inner(hook, payload, response_type)
    except Exception:
        if fail_strategy == HookFailStrategy.SOFT:
            logger.exception(
                f"Unexpected error in hook execution (soft fail) for hook_id={hook_id}"
            )
            return HookSoftFailed()
        raise


================================================
FILE: backend/ee/onyx/main.py
================================================
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager

from fastapi import FastAPI
from httpx_oauth.clients.google import GoogleOAuth2

from ee.onyx.server.analytics.api import router as analytics_router
from ee.onyx.server.auth_check import check_ee_router_auth
from ee.onyx.server.billing.api import router as billing_router
from ee.onyx.server.documents.cc_pair import router as ee_document_cc_pair_router
from ee.onyx.server.enterprise_settings.api import (
    admin_router as enterprise_settings_admin_router,
)
from ee.onyx.server.enterprise_settings.api import (
    basic_router as enterprise_settings_router,
)
from ee.onyx.server.evals.api import router as evals_router
from ee.onyx.server.features.hooks.api import router as hook_router
from ee.onyx.server.license.api import router as license_router
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.license_enforcement import (
    add_license_enforcement_middleware,
)
from ee.onyx.server.middleware.tenant_tracking import (
    add_api_server_tenant_id_middleware,
)
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.query_backend import (
    basic_router as ee_query_router,
)
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import register_scim_exception_handlers
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
    router as token_rate_limit_settings_router,
)
from ee.onyx.server.user_group.api import router as user_group_router
from ee.onyx.utils.encryption import test_encryption
from onyx.auth.users import auth_backend
from onyx.auth.users import create_onyx_oauth_router
from onyx.auth.users import fastapi_users
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.main import get_application as get_application_base
from onyx.main import include_auth_router_with_prefix
from onyx.main import include_router_with_global_prefix_prepended
from onyx.main import lifespan as lifespan_base
from onyx.main import use_route_function_names_as_operation_ids
from onyx.server.query_and_chat.query_backend import (
    basic_router as query_router,
)
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Small wrapper around the lifespan of the MIT application.
    Basically just calls the base lifespan, and then adds EE-only
    steps after."""

    async with lifespan_base(app):
        # seed the Onyx environment with LLMs, Assistants, etc. based on an optional
        # environment variable. Used to automate deployment for multiple environments.
        seed_db()

        yield


def get_application() -> FastAPI:
    # Anything that happens at import time is not guaranteed to be running ee-version
    # Anything after the server startup will be running ee version
    global_version.set_ee()

    test_encryption()

    application = get_application_base(lifespan_override=lifespan)

    if MULTI_TENANT:
        add_api_server_tenant_id_middleware(application, logger)
    else:
        # License enforcement middleware for self-hosted deployments only
        # Checks LICENSE_ENFORCEMENT_ENABLED at runtime (can be toggled without restart)
        # MT deployments use control plane gating via is_tenant_gated() instead
        add_license_enforcement_middleware(application, logger)

    if AUTH_TYPE == AuthType.CLOUD:
        # For Google OAuth, refresh tokens are requested by:
        # 1. Adding the right scopes
        # 2. Properly configuring OAuth in Google Cloud Console to allow offline access
        oauth_client = GoogleOAuth2(
            OAUTH_CLIENT_ID,
            OAUTH_CLIENT_SECRET,
            # Use standard scopes that include profile and email
            scopes=["openid", "email", "profile"],
        )
        include_auth_router_with_prefix(
            application,
            create_onyx_oauth_router(
                oauth_client,
                auth_backend,
                USER_AUTH_SECRET,
                associate_by_email=True,
                is_verified_by_default=True,
                # Points the user back to the login page
                redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback",
            ),
            prefix="/auth/oauth",
        )

        # Need basic auth router for `logout` endpoint
        include_auth_router_with_prefix(
            application,
            fastapi_users.get_logout_router(auth_backend),
            prefix="/auth",
        )

    # RBAC / group access control
    include_router_with_global_prefix_prepended(application, user_group_router)
    # Analytics endpoints
    include_router_with_global_prefix_prepended(application, analytics_router)
    include_router_with_global_prefix_prepended(application, query_history_router)
    # EE only backend APIs
    include_router_with_global_prefix_prepended(application, query_router)
    include_router_with_global_prefix_prepended(application, ee_query_router)
    include_router_with_global_prefix_prepended(application, search_router)
    include_router_with_global_prefix_prepended(application, standard_answer_router)
    include_router_with_global_prefix_prepended(application, ee_oauth_router)
    include_router_with_global_prefix_prepended(application, ee_document_cc_pair_router)
    include_router_with_global_prefix_prepended(application, evals_router)
    include_router_with_global_prefix_prepended(application, hook_router)

    # Enterprise-only global settings
    include_router_with_global_prefix_prepended(
        application, enterprise_settings_admin_router
    )
    # Token rate limit settings
    include_router_with_global_prefix_prepended(
        application, token_rate_limit_settings_router
    )
    include_router_with_global_prefix_prepended(application, enterprise_settings_router)
    include_router_with_global_prefix_prepended(application, usage_export_router)
    # License management
    include_router_with_global_prefix_prepended(application, license_router)

    # Unified billing API - always registered in EE.
    # Each endpoint is protected by the `current_admin_user` dependency (admin auth).
    include_router_with_global_prefix_prepended(application, billing_router)

    if MULTI_TENANT:
        # Tenant management
        include_router_with_global_prefix_prepended(application, tenants_router)

    # SCIM 2.0 — protocol endpoints (unauthenticated by Onyx session auth;
    # they use their own SCIM bearer token auth).
    # Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
    application.include_router(scim_router)
    register_scim_exception_handlers(application)

    # Ensure all routes have auth enabled or are explicitly marked as public
    check_ee_router_auth(application)

    # for debugging discovered routes
    # for route in application.router.routes:
    #     print(f"Path: {route.path}, Methods: {route.methods}")

    use_route_function_names_as_operation_ids(application)

    return application


================================================
FILE: backend/ee/onyx/onyxbot/slack/handlers/__init__.py
================================================


================================================
FILE: backend/ee/onyx/onyxbot/slack/handlers/handle_standard_answers.py
================================================
from slack_sdk import WebClient
from slack_sdk.models.blocks import ActionsBlock
from slack_sdk.models.blocks import Block
from slack_sdk.models.blocks import ButtonElement
from slack_sdk.models.blocks import SectionBlock
from sqlalchemy.orm import Session

from ee.onyx.db.standard_answer import fetch_standard_answer_categories_by_names
from ee.onyx.db.standard_answer import find_matching_standard_answers
from onyx.configs.constants import MessageType
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_messages_by_sessions
from onyx.db.chat import get_chat_sessions_by_slack_thread_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.models import SlackChannelConfig
from onyx.db.models import StandardAnswer as StandardAnswerModel
from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.server.manage.models import StandardAnswer as PydanticStandardAnswer
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger

logger = setup_logger()


def build_standard_answer_blocks(
    answer_message: str,
) -> list[Block]:
    generate_button_block = ButtonElement(
        action_id=GENERATE_ANSWER_BUTTON_ACTION_ID,
        text="Generate Full Answer",
    )
    answer_block = SectionBlock(text=answer_message)
    return [
        answer_block,
        ActionsBlock(
            elements=[generate_button_block],
        ),
    ]


def oneoff_standard_answers(
    message: str,
    slack_bot_categories: list[str],
    db_session: Session,
) -> list[PydanticStandardAnswer]:
    """
    Find the configured standard answers that match the user message.

    Returns a list of matching StandardAnswers, or an empty list if none match.
    """
    configured_standard_answers = {
        standard_answer
        for category in fetch_standard_answer_categories_by_names(
            slack_bot_categories, db_session=db_session
        )
        for standard_answer in category.standard_answers
    }

    matching_standard_answers = find_matching_standard_answers(
        query=message,
        id_in=[answer.id for answer in configured_standard_answers],
        db_session=db_session,
    )

    server_standard_answers = [
        PydanticStandardAnswer.from_model(standard_answer_model)
        for (standard_answer_model, _) in matching_standard_answers
    ]
    return server_standard_answers


def _handle_standard_answers(
    message_info: SlackMessageInfo,
    receiver_ids: list[str] | None,
    slack_channel_config: SlackChannelConfig,
    logger: OnyxLoggingAdapter,
    client: WebClient,
    db_session: Session,
) -> bool:
    """
    Potentially respond to the user message depending on whether the user's message
    matches any of the configured standard answers and whether those answers have
    already been provided in the current thread.

    Returns True if standard answers matched the user's message and were posted
    to the thread, False otherwise.
    """
    slack_thread_id = message_info.thread_to_respond
    configured_standard_answer_categories = (
        slack_channel_config.standard_answer_categories
    )
    configured_standard_answers = set(
        [
            standard_answer
            for standard_answer_category in configured_standard_answer_categories
            for standard_answer in standard_answer_category.standard_answers
        ]
    )
    query_msg = message_info.thread_messages[-1]

    if slack_thread_id is None:
        used_standard_answer_ids = set([])
    else:
        chat_sessions = get_chat_sessions_by_slack_thread_id(
            slack_thread_id=slack_thread_id,
            user_id=None,
            db_session=db_session,
        )
        chat_messages = get_chat_messages_by_sessions(
            chat_session_ids=[chat_session.id for chat_session in chat_sessions],
            user_id=None,
            db_session=db_session,
            skip_permission_check=True,
        )
        used_standard_answer_ids = set(
            [
                standard_answer.id
                for chat_message in chat_messages
                for standard_answer in chat_message.standard_answers
            ]
        )

    # Filter by id: configured_standard_answers holds ORM objects while
    # used_standard_answer_ids holds ints, so a plain set difference would
    # never remove anything.
    usable_standard_answers = {
        standard_answer
        for standard_answer in configured_standard_answers
        if standard_answer.id not in used_standard_answer_ids
    }

    matching_standard_answers: list[tuple[StandardAnswerModel, str]] = []
    if usable_standard_answers:
        matching_standard_answers = find_matching_standard_answers(
            query=query_msg.message,
            id_in=[standard_answer.id for standard_answer in usable_standard_answers],
            db_session=db_session,
        )

    if matching_standard_answers:
        chat_session = create_chat_session(
            db_session=db_session,
            description="",
            user_id=None,
            persona_id=(
                slack_channel_config.persona.id if slack_channel_config.persona else 0
            ),
            onyxbot_flow=True,
            slack_thread_id=slack_thread_id,
        )

        root_message = get_or_create_root_message(
            chat_session_id=chat_session.id, db_session=db_session
        )

        new_user_message = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=root_message,
            message=query_msg.message,
            token_count=0,
            message_type=MessageType.USER,
            db_session=db_session,
            commit=True,
        )

        formatted_answers = []
        for standard_answer, match_str in matching_standard_answers:
            since_you_mentioned_pretext = (
                f'Since your question contains "_{match_str}_"'
            )
            block_quotified_answer = ">" + standard_answer.answer.replace("\n", "\n> ")
            formatted_answer = f"{since_you_mentioned_pretext}, I thought this might be useful: \n\n{block_quotified_answer}"
            formatted_answers.append(formatted_answer)
        answer_message = "\n\n".join(formatted_answers)

        chat_message = create_new_chat_message(
            chat_session_id=chat_session.id,
            parent_message=new_user_message,
            message=answer_message,
            token_count=0,
            message_type=MessageType.ASSISTANT,
            error=None,
            db_session=db_session,
            commit=False,
        )
        # attach the standard answers to the chat message
        chat_message.standard_answers = [
            standard_answer for standard_answer, _ in matching_standard_answers
        ]
        db_session.commit()

        update_emote_react(
            emoji=ONYX_BOT_REACT_EMOJI,
            channel=message_info.channel_to_respond,
            message_ts=message_info.msg_to_respond,
            remove=True,
            client=client,
        )

        restate_question_blocks = get_restate_blocks(
            msg=query_msg.message,
            is_slash_command=message_info.is_slash_command,
        )

        answer_blocks = build_standard_answer_blocks(
            answer_message=answer_message,
        )

        all_blocks = restate_question_blocks + answer_blocks

        try:
            respond_in_thread_or_channel(
                client=client,
                channel=message_info.channel_to_respond,
                receiver_ids=receiver_ids,
                text="Hello! Onyx has some results for you!",
                blocks=all_blocks,
                thread_ts=message_info.msg_to_respond,
                unfurl=False,
            )

            if receiver_ids and slack_thread_id:
                send_team_member_message(
                    client=client,
                    channel=message_info.channel_to_respond,
                    thread_ts=slack_thread_id,
                    receiver_ids=receiver_ids,
                )

            return True
        except Exception as e:
            logger.exception(f"Unable to send standard answer message: {e}")
            return False
    else:
        return False


================================================
FILE: backend/ee/onyx/prompts/__init__.py
================================================


================================================
FILE: backend/ee/onyx/prompts/query_expansion.py
================================================
# Single message is likely most reliable and generally better for this task
# No final reminders at the end since the user query is expected to be short
# If it is not short, it should go into the chat flow so we do not need to account for this.
KEYWORD_EXPANSION_PROMPT = """
Generate a set of keyword-only queries to help find relevant documents for the provided query. \
These queries will be passed to a bm25-based keyword search engine. \
Provide a single query per line (where each query consists of one or more keywords). \
The queries must be purely keywords and not contain any filler natural language. \
Each query should have as few keywords as necessary to represent the user's search intent. \
If there are no useful expansions, simply return the original query with no additional keyword queries. \
CRITICAL: Do not include any additional formatting, comments, or anything aside from the keyword queries.

The user query is:
{user_query}
""".strip()


QUERY_TYPE_PROMPT = """
Determine if the provided query is better suited for a keyword search or a semantic search.
Respond with "keyword" or "semantic" literally and nothing else.
Do not provide any additional text or reasoning to your response.

CRITICAL: It must only be 1 single word - EITHER "keyword" or "semantic".

The user query is:
{user_query}
""".strip()


================================================
FILE: backend/ee/onyx/prompts/search_flow_classification.py
================================================
# ruff: noqa: E501, W605 start
SEARCH_CLASS = "search"
CHAT_CLASS = "chat"

# Note that with many larger LLMs, the latency of running this prompt via third-party APIs
# is as high as 2 seconds, which is too slow for many use cases.
SEARCH_CHAT_PROMPT = f"""
Determine if the following query is better suited for a search UI or a chat UI. Respond with "{SEARCH_CLASS}" or "{CHAT_CLASS}" literally and nothing else. \
Do not provide any additional text or reasoning to your response.

CRITICAL, IT MUST ONLY BE 1 SINGLE WORD - EITHER "{SEARCH_CLASS}" or "{CHAT_CLASS}".

# Classification Guidelines:
## {SEARCH_CLASS}
- If the query consists entirely of keywords or query doesn't require any answer from the AI
- If the query is a short statement that seems like a search query rather than a question
- If the query feels nonsensical or is a short phrase that possibly describes a document or information that could be found in an internal document

### Examples of {SEARCH_CLASS} queries:
- Find me the document that goes over the onboarding process for a new hire
- Pull requests since last week
- Sales Runbook AMEA Region
- Procurement process
- Retrieve the PRD for project X

## {CHAT_CLASS}
- If the query is asking a question that requires an answer rather than a document
- If the query is asking for a solution, suggestion, or general help
- If the query is seeking information that is on the web and likely not in a company internal document
- If the query should be answered without any context from additional documents or searches

### Examples of {CHAT_CLASS} queries:
- What led us to win the deal with company X? (seeking answer)
- Google Drive not sync-ing files to my computer (seeking solution)
- Review my email: (general help)
- Write me a script to... (general help)
- Cheap flights Europe to Tokyo (information likely found on the web, not internal)

# User Query:
{{user_query}}

REMEMBER TO ONLY RESPOND WITH "{SEARCH_CLASS}" OR "{CHAT_CLASS}" AND NOTHING ELSE.
""".strip()
# ruff: noqa: E501, W605 end


================================================
FILE: backend/ee/onyx/search/process_search_query.py
================================================
from collections.abc import Generator

from sqlalchemy.orm import Session

from ee.onyx.db.search import create_search_query
from ee.onyx.secondary_llm_flows.query_expansion import expand_keywords
from ee.onyx.server.query_and_chat.models import SearchDocWithContent
from ee.onyx.server.query_and_chat.models import SearchFullResponse
from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest
from ee.onyx.server.query_and_chat.streaming_models import LLMSelectedDocsPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchDocsPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchErrorPacket
from ee.onyx.server.query_and_chat.streaming_models import SearchQueriesPacket
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import DocumentIndex
from onyx.llm.factory import get_default_llm
from onyx.secondary_llm_flows.document_filter import select_sections_for_expansion
from onyx.tools.tool_implementations.search.search_utils import (
    weighted_reciprocal_rank_fusion,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()

# This is just a heuristic that also happens to work well for the UI/UX
# Users would not find it useful to see a huge list of suggested docs
# but more than 1 is also likely good as many questions may target more than 1 doc.
TARGET_NUM_SECTIONS_FOR_LLM_SELECTION = 3


def _run_single_search(
    query: str,
    filters: BaseFilters | None,
    document_index: DocumentIndex,
    user: User,
    db_session: Session,
    num_hits: int | None = None,
    hybrid_alpha: float | None = None,
) -> list[InferenceChunk]:
    """Execute a single search query and return chunks."""
    chunk_search_request = ChunkSearchRequest(
        query=query,
        user_selected_filters=filters,
        limit=num_hits,
        hybrid_alpha=hybrid_alpha,
    )

    return search_pipeline(
        chunk_search_request=chunk_search_request,
        document_index=document_index,
        user=user,
        persona_search_info=None,
        db_session=db_session,
    )


def stream_search_query(
    request: SendSearchQueryRequest,
    user: User,
    db_session: Session,
) -> Generator[
    SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket,
    None,
    None,
]:
    """
    Core search function that yields streaming packets.
    Used by both streaming and non-streaming endpoints.
    """
    # Get document index.
    search_settings = get_current_search_settings(db_session)
    # This flow is for search so we do not get all indices.
    document_index = get_default_document_index(search_settings, None, db_session)

    # Determine queries to execute
    original_query = request.search_query
    keyword_expansions: list[str] = []

    if request.run_query_expansion:
        try:
            llm = get_default_llm()
            keyword_expansions = expand_keywords(
                user_query=original_query,
                llm=llm,
            )
            if keyword_expansions:
                logger.debug(
                    f"Query expansion generated {len(keyword_expansions)} keyword queries"
                )
        except Exception as e:
            logger.warning(f"Query expansion failed: {e}; using original query only.")
            keyword_expansions = []

    # Build list of all executed queries for tracking
    all_executed_queries = [original_query] + keyword_expansions

    if not user.is_anonymous:
        create_search_query(
            db_session=db_session,
            user_id=user.id,
            query=request.search_query,
            query_expansions=keyword_expansions if keyword_expansions else None,
        )

    # Execute search(es)
    if not keyword_expansions:
        # Single query (original only) - no threading needed
        chunks = _run_single_search(
            query=original_query,
            filters=request.filters,
            document_index=document_index,
            user=user,
            db_session=db_session,
            num_hits=request.num_hits,
            hybrid_alpha=request.hybrid_alpha,
        )
    else:
        # Multiple queries - run in parallel and merge with RRF
        # First query is the original (semantic), rest are keyword expansions
        search_functions = [
            (
                _run_single_search,
                (
                    query,
                    request.filters,
                    document_index,
                    user,
                    db_session,
                    request.num_hits,
                    request.hybrid_alpha,
                ),
            )
            for query in all_executed_queries
        ]

        # Run all searches in parallel
        all_search_results: list[list[InferenceChunk]] = (
            run_functions_tuples_in_parallel(
                search_functions,
                allow_failures=True,
            )
        )

        # Separate original query results from keyword expansion results
        # Note that in rare cases, the original query may have failed and so we may be
        # just overweighting one set of keyword results; this should not be a big deal.
original_result = all_search_results[0] if all_search_results else [] keyword_results = all_search_results[1:] if len(all_search_results) > 1 else [] # Build valid results and weights # Original query (semantic): weight 2.0 # Keyword expansions: weight 1.0 each valid_results: list[list[InferenceChunk]] = [] weights: list[float] = [] if original_result: valid_results.append(original_result) weights.append(2.0) for keyword_result in keyword_results: if keyword_result: valid_results.append(keyword_result) weights.append(1.0) if not valid_results: logger.warning("All parallel searches returned empty results") chunks = [] else: chunks = weighted_reciprocal_rank_fusion( ranked_results=valid_results, weights=weights, id_extractor=lambda chunk: f"{chunk.document_id}_{chunk.chunk_id}", ) # Merge chunks into sections sections = merge_individual_chunks(chunks) # Truncate to the requested number of hits sections = sections[: request.num_hits] # Apply LLM document selection if requested # num_docs_fed_to_llm_selection specifies how many sections to feed to the LLM for selection # The LLM will always try to select TARGET_NUM_SECTIONS_FOR_LLM_SELECTION sections from those fed to it # llm_selected_doc_ids will be: # - None if LLM selection was not requested or failed # - Empty list if LLM selection ran but selected nothing # - List of doc IDs if LLM selection succeeded run_llm_selection = ( request.num_docs_fed_to_llm_selection is not None and request.num_docs_fed_to_llm_selection >= 1 ) llm_selected_doc_ids: list[str] | None = None llm_selection_failed = False if run_llm_selection and sections: try: llm = get_default_llm() sections_to_evaluate = sections[: request.num_docs_fed_to_llm_selection] selected_sections, _ = select_sections_for_expansion( sections=sections_to_evaluate, user_query=original_query, llm=llm, max_sections=TARGET_NUM_SECTIONS_FOR_LLM_SELECTION, try_to_fill_to_max=True, ) # Extract unique document IDs from selected sections (may be empty) llm_selected_doc_ids = 
list( dict.fromkeys( section.center_chunk.document_id for section in selected_sections ) ) logger.debug( f"LLM document selection evaluated {len(sections_to_evaluate)} sections, " f"selected {len(selected_sections)} sections with doc IDs: {llm_selected_doc_ids}" ) except Exception as e: # Allowing a blanket exception here as this step is not critical and the rest of the results are still valid logger.warning(f"LLM document selection failed: {e}") llm_selection_failed = True elif run_llm_selection and not sections: # LLM selection requested but no sections to evaluate llm_selected_doc_ids = [] # Convert to SearchDocWithContent list, optionally including content search_docs = SearchDocWithContent.from_inference_sections( sections, include_content=request.include_content, is_internet=False, ) # Yield queries packet yield SearchQueriesPacket(all_executed_queries=all_executed_queries) # Yield docs packet yield SearchDocsPacket(search_docs=search_docs) # Yield LLM selected docs packet if LLM selection was requested # - llm_selected_doc_ids is None if selection failed # - llm_selected_doc_ids is empty list if no docs were selected # - llm_selected_doc_ids is list of IDs if docs were selected if run_llm_selection: yield LLMSelectedDocsPacket( llm_selected_doc_ids=None if llm_selection_failed else llm_selected_doc_ids ) def gather_search_stream( packets: Generator[ SearchQueriesPacket | SearchDocsPacket | LLMSelectedDocsPacket | SearchErrorPacket, None, None, ], ) -> SearchFullResponse: """ Aggregate all streaming packets into SearchFullResponse. 
""" all_executed_queries: list[str] = [] search_docs: list[SearchDocWithContent] = [] llm_selected_doc_ids: list[str] | None = None error: str | None = None for packet in packets: if isinstance(packet, SearchQueriesPacket): all_executed_queries = packet.all_executed_queries elif isinstance(packet, SearchDocsPacket): search_docs = packet.search_docs elif isinstance(packet, LLMSelectedDocsPacket): llm_selected_doc_ids = packet.llm_selected_doc_ids elif isinstance(packet, SearchErrorPacket): error = packet.error return SearchFullResponse( all_executed_queries=all_executed_queries, search_docs=search_docs, doc_selection_reasoning=None, llm_selected_doc_ids=llm_selected_doc_ids, error=error, ) ================================================ FILE: backend/ee/onyx/secondary_llm_flows/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/secondary_llm_flows/query_expansion.py ================================================ import re from ee.onyx.prompts.query_expansion import KEYWORD_EXPANSION_PROMPT from onyx.llm.interfaces import LLM from onyx.llm.models import LanguageModelInput from onyx.llm.models import ReasoningEffort from onyx.llm.models import UserMessage from onyx.llm.utils import llm_response_to_string from onyx.utils.logger import setup_logger logger = setup_logger() # Pattern to remove common LLM artifacts: brackets, quotes, list markers, etc. CLEANUP_PATTERN = re.compile(r'[\[\]"\'`]') def _clean_keyword_line(line: str) -> str: """Clean a keyword line by removing common LLM artifacts. Removes brackets, quotes, and other characters that LLMs may accidentally include in their output. 
""" # Remove common artifacts cleaned = CLEANUP_PATTERN.sub("", line) # Remove leading list markers like "1.", "2.", "-", "*" cleaned = re.sub(r"^\s*(?:\d+[\.\)]\s*|[-*]\s*)", "", cleaned) return cleaned.strip() def expand_keywords( user_query: str, llm: LLM, ) -> list[str]: """Expand a user query into multiple keyword-only queries for BM25 search. Uses an LLM to generate keyword-based search queries that capture different aspects of the user's search intent. Returns only the expanded queries, not the original query. Args: user_query: The original search query from the user llm: Language model to use for keyword expansion Returns: List of expanded keyword queries (excluding the original query). Returns empty list if expansion fails or produces no useful expansions. """ messages: LanguageModelInput = [ UserMessage(content=KEYWORD_EXPANSION_PROMPT.format(user_query=user_query)) ] try: response = llm.invoke( prompt=messages, reasoning_effort=ReasoningEffort.OFF, # Limit output - we only expect a few short keyword queries max_tokens=150, ) content = llm_response_to_string(response).strip() if not content: logger.warning("Keyword expansion returned empty response.") return [] # Parse response - each line is a separate keyword query # Clean each line to remove LLM artifacts and drop empty lines parsed_queries = [] for line in content.strip().split("\n"): cleaned = _clean_keyword_line(line) if cleaned: parsed_queries.append(cleaned) if not parsed_queries: logger.warning("Keyword expansion parsing returned no queries.") return [] # Filter out duplicates and queries that match the original expanded_queries: list[str] = [] seen_lower: set[str] = {user_query.lower()} for query in parsed_queries: query_lower = query.lower() if query_lower not in seen_lower: seen_lower.add(query_lower) expanded_queries.append(query) logger.debug(f"Keyword expansion generated {len(expanded_queries)} queries") return expanded_queries except Exception as e: logger.warning(f"Keyword expansion 
failed: {e}") return [] ================================================ FILE: backend/ee/onyx/secondary_llm_flows/search_flow_classification.py ================================================ from ee.onyx.prompts.search_flow_classification import CHAT_CLASS from ee.onyx.prompts.search_flow_classification import SEARCH_CHAT_PROMPT from ee.onyx.prompts.search_flow_classification import SEARCH_CLASS from onyx.llm.interfaces import LLM from onyx.llm.models import LanguageModelInput from onyx.llm.models import ReasoningEffort from onyx.llm.models import UserMessage from onyx.llm.utils import llm_response_to_string from onyx.utils.logger import setup_logger from onyx.utils.timing import log_function_time logger = setup_logger() @log_function_time(print_only=True) def classify_is_search_flow( query: str, llm: LLM, ) -> bool: messages: LanguageModelInput = [ UserMessage(content=SEARCH_CHAT_PROMPT.format(user_query=query)) ] response = llm.invoke( prompt=messages, reasoning_effort=ReasoningEffort.OFF, # Nothing can happen in the UI until this call finishes, so we need to be aggressive with the timeout timeout_override=2, # Far more than necessary, but ensures completion even if the model classifies correctly and then # keeps rambling max_tokens=20, ) content = llm_response_to_string(response).strip().lower() if not content: logger.warning( "Search flow classification returned empty response; defaulting to chat flow." ) return False # Prefer chat if both appear. if CHAT_CLASS in content: return False if SEARCH_CLASS in content: return True logger.warning( "Search flow classification returned unexpected response; defaulting to chat flow.
Response=%r", content, ) return False ================================================ FILE: backend/ee/onyx/server/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/analytics/api.py ================================================ import datetime from collections import defaultdict from typing import List from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from pydantic import BaseModel from sqlalchemy.orm import Session from ee.onyx.db.analytics import fetch_assistant_message_analytics from ee.onyx.db.analytics import fetch_assistant_unique_users from ee.onyx.db.analytics import fetch_assistant_unique_users_total from ee.onyx.db.analytics import fetch_onyxbot_analytics from ee.onyx.db.analytics import fetch_per_user_query_analytics from ee.onyx.db.analytics import fetch_persona_message_analytics from ee.onyx.db.analytics import fetch_persona_unique_users from ee.onyx.db.analytics import fetch_query_analytics from ee.onyx.db.analytics import user_can_view_assistant_stats from onyx.auth.users import current_admin_user from onyx.auth.users import current_user from onyx.configs.constants import PUBLIC_API_TAGS from onyx.db.engine.sql_engine import get_session from onyx.db.models import User router = APIRouter(prefix="/analytics", tags=PUBLIC_API_TAGS) _DEFAULT_LOOKBACK_DAYS = 30 class QueryAnalyticsResponse(BaseModel): total_queries: int total_likes: int total_dislikes: int date: datetime.date @router.get("/admin/query") def get_query_analytics( start: datetime.datetime | None = None, end: datetime.datetime | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[QueryAnalyticsResponse]: daily_query_usage_info = fetch_query_analytics( start=start or ( datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ), # default is 30d lookback end=end or 
datetime.datetime.utcnow(), db_session=db_session, ) return [ QueryAnalyticsResponse( total_queries=total_queries, total_likes=total_likes, total_dislikes=total_dislikes, date=date, ) for total_queries, total_likes, total_dislikes, date in daily_query_usage_info ] class UserAnalyticsResponse(BaseModel): total_active_users: int date: datetime.date @router.get("/admin/user") def get_user_analytics( start: datetime.datetime | None = None, end: datetime.datetime | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[UserAnalyticsResponse]: daily_query_usage_info_per_user = fetch_per_user_query_analytics( start=start or ( datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ), # default is 30d lookback end=end or datetime.datetime.utcnow(), db_session=db_session, ) user_analytics: dict[datetime.date, int] = defaultdict(int) for __, ___, ____, date, _____ in daily_query_usage_info_per_user: user_analytics[date] += 1 return [ UserAnalyticsResponse( total_active_users=cnt, date=date, ) for date, cnt in user_analytics.items() ] class OnyxbotAnalyticsResponse(BaseModel): total_queries: int auto_resolved: int date: datetime.date @router.get("/admin/onyxbot") def get_onyxbot_analytics( start: datetime.datetime | None = None, end: datetime.datetime | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[OnyxbotAnalyticsResponse]: daily_onyxbot_info = fetch_onyxbot_analytics( start=start or ( datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ), # default is 30d lookback end=end or datetime.datetime.utcnow(), db_session=db_session, ) resolution_results = [ OnyxbotAnalyticsResponse( total_queries=total_queries, # If it hits negatives, something has gone wrong... 
auto_resolved=max(0, total_queries - total_negatives), date=date, ) for total_queries, total_negatives, date in daily_onyxbot_info ] return resolution_results class PersonaMessageAnalyticsResponse(BaseModel): total_messages: int date: datetime.date persona_id: int @router.get("/admin/persona/messages") def get_persona_messages( persona_id: int, start: datetime.datetime | None = None, end: datetime.datetime | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[PersonaMessageAnalyticsResponse]: """Fetch daily message counts for a single persona within the given time range.""" start = start or ( datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ) end = end or datetime.datetime.utcnow() persona_message_counts = [] for count, date in fetch_persona_message_analytics( db_session=db_session, persona_id=persona_id, start=start, end=end, ): persona_message_counts.append( PersonaMessageAnalyticsResponse( total_messages=count, date=date, persona_id=persona_id, ) ) return persona_message_counts class PersonaUniqueUsersResponse(BaseModel): unique_users: int date: datetime.date persona_id: int @router.get("/admin/persona/unique-users") def get_persona_unique_users( persona_id: int, start: datetime.datetime, end: datetime.datetime, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[PersonaUniqueUsersResponse]: """Get unique users per day for a single persona.""" unique_user_counts = [] daily_counts = fetch_persona_unique_users( db_session=db_session, persona_id=persona_id, start=start, end=end, ) for count, date in daily_counts: unique_user_counts.append( PersonaUniqueUsersResponse( unique_users=count, date=date, persona_id=persona_id, ) ) return unique_user_counts class AssistantDailyUsageResponse(BaseModel): date: datetime.date total_messages: int total_unique_users: int class AssistantStatsResponse(BaseModel): daily_stats: List[AssistantDailyUsageResponse] 
total_messages: int total_unique_users: int @router.get("/assistant/{assistant_id}/stats") def get_assistant_stats( assistant_id: int, start: datetime.datetime | None = None, end: datetime.datetime | None = None, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> AssistantStatsResponse: """ Returns daily message and unique user counts for a user's assistant, along with the overall total messages and total distinct users. """ start = start or ( datetime.datetime.utcnow() - datetime.timedelta(days=_DEFAULT_LOOKBACK_DAYS) ) end = end or datetime.datetime.utcnow() if not user_can_view_assistant_stats(db_session, user, assistant_id): raise HTTPException( status_code=403, detail="Not allowed to access this assistant's stats." ) # Pull daily usage from the DB messages_data = fetch_assistant_message_analytics( db_session, assistant_id, start, end ) unique_users_data = fetch_assistant_unique_users( db_session, assistant_id, start, end ) # Map each day => (messages, unique_users).
daily_messages_map = {date: count for count, date in messages_data} daily_unique_users_map = {date: count for count, date in unique_users_data} all_dates = set(daily_messages_map.keys()) | set(daily_unique_users_map.keys()) # Merge both sets of metrics by date daily_results: list[AssistantDailyUsageResponse] = [] for date in sorted(all_dates): daily_results.append( AssistantDailyUsageResponse( date=date, total_messages=daily_messages_map.get(date, 0), total_unique_users=daily_unique_users_map.get(date, 0), ) ) # Now pull a single total distinct user count across the entire time range total_msgs = sum(d.total_messages for d in daily_results) total_users = fetch_assistant_unique_users_total( db_session, assistant_id, start, end ) return AssistantStatsResponse( daily_stats=daily_results, total_messages=total_msgs, total_unique_users=total_users, ) ================================================ FILE: backend/ee/onyx/server/auth_check.py ================================================ from fastapi import FastAPI from onyx.server.auth_check import check_router_auth from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [ # SCIM 2.0 service discovery — unauthenticated so IdPs can probe # before bearer token configuration is complete ("/scim/v2/ServiceProviderConfig", {"GET"}), ("/scim/v2/ResourceTypes", {"GET"}), ("/scim/v2/Schemas", {"GET"}), # needs to be accessible prior to user login ("/enterprise-settings", {"GET"}), ("/enterprise-settings/logo", {"GET"}), ("/enterprise-settings/logotype", {"GET"}), ("/enterprise-settings/custom-analytics-script", {"GET"}), # Stripe publishable key is safe to expose publicly ("/tenants/stripe-publishable-key", {"GET"}), ("/admin/billing/stripe-publishable-key", {"GET"}), # Proxy endpoints use license-based auth, not user auth ("/proxy/create-checkout-session", {"POST"}), ("/proxy/claim-license", {"POST"}), ("/proxy/create-customer-portal-session", {"POST"}), 
("/proxy/billing-information", {"GET"}), ("/proxy/license/{tenant_id}", {"GET"}), ("/proxy/seats/update", {"POST"}), ] def check_ee_router_auth( application: FastAPI, public_endpoint_specs: list[tuple[str, set[str]]] = EE_PUBLIC_ENDPOINT_SPECS, ) -> None: # similar to the open source version of this function, but checking for the EE-only # endpoints as well check_router_auth(application, public_endpoint_specs) ================================================ FILE: backend/ee/onyx/server/billing/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/billing/api.py ================================================ """Unified Billing API endpoints. These endpoints provide Stripe billing functionality for both cloud and self-hosted deployments. The service layer routes requests appropriately: - Self-hosted: Routes through cloud data plane proxy Flow: Backend /admin/billing/* → Cloud DP /proxy/* → Control plane - Cloud (MULTI_TENANT): Routes directly to control plane Flow: Backend /admin/billing/* → Control plane License claiming is handled separately by /license/claim endpoint (self-hosted only). 
Migration Note (ENG-3533): This /admin/billing/* API replaces the older /tenants/* billing endpoints: - /tenants/billing-information -> /admin/billing/billing-information - /tenants/create-customer-portal-session -> /admin/billing/create-customer-portal-session - /tenants/create-subscription-session -> /admin/billing/create-checkout-session - /tenants/stripe-publishable-key -> /admin/billing/stripe-publishable-key See: https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling """ import asyncio import httpx from fastapi import APIRouter from fastapi import Depends from pydantic import BaseModel from sqlalchemy.orm import Session from ee.onyx.auth.users import current_admin_user from ee.onyx.db.license import get_license from ee.onyx.db.license import get_used_seats from ee.onyx.server.billing.models import BillingInformationResponse from ee.onyx.server.billing.models import CreateCheckoutSessionRequest from ee.onyx.server.billing.models import CreateCheckoutSessionResponse from ee.onyx.server.billing.models import CreateCustomerPortalSessionRequest from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse from ee.onyx.server.billing.models import SeatUpdateRequest from ee.onyx.server.billing.models import SeatUpdateResponse from ee.onyx.server.billing.models import StripePublishableKeyResponse from ee.onyx.server.billing.models import SubscriptionStatusResponse from ee.onyx.server.billing.service import ( create_checkout_session as create_checkout_service, ) from ee.onyx.server.billing.service import ( create_customer_portal_session as create_portal_service, ) from ee.onyx.server.billing.service import ( get_billing_information as get_billing_service, ) from ee.onyx.server.billing.service import update_seat_count as update_seat_service from onyx.auth.users import User from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL from 
onyx.configs.app_configs import WEB_DOMAIN from onyx.db.engine.sql_engine import get_session from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError from onyx.redis.redis_pool import get_shared_redis_client from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/admin/billing") # Cache for Stripe publishable key to avoid hitting S3 on every request _stripe_publishable_key_cache: str | None = None _stripe_key_lock = asyncio.Lock() # Redis key for billing circuit breaker (self-hosted only) # When set, billing requests to Stripe are disabled until user manually retries BILLING_CIRCUIT_BREAKER_KEY = "billing_circuit_open" # Circuit breaker auto-expires after 1 hour (user can manually retry sooner) BILLING_CIRCUIT_BREAKER_TTL_SECONDS = 3600 def _is_billing_circuit_open() -> bool: """Check if the billing circuit breaker is open (self-hosted only).""" if MULTI_TENANT: return False try: redis_client = get_shared_redis_client() is_open = bool(redis_client.exists(BILLING_CIRCUIT_BREAKER_KEY)) logger.debug( f"Circuit breaker check: key={BILLING_CIRCUIT_BREAKER_KEY}, is_open={is_open}" ) return is_open except Exception as e: logger.error(f"Failed to check circuit breaker: {e}") return False def _open_billing_circuit() -> None: """Open the billing circuit breaker after a failure (self-hosted only).""" if MULTI_TENANT: return try: redis_client = get_shared_redis_client() redis_client.set( BILLING_CIRCUIT_BREAKER_KEY, "1", ex=BILLING_CIRCUIT_BREAKER_TTL_SECONDS, ) # Verify it was set exists = redis_client.exists(BILLING_CIRCUIT_BREAKER_KEY) logger.warning( f"Billing circuit breaker opened (TTL={BILLING_CIRCUIT_BREAKER_TTL_SECONDS}s, " f"verified={exists}). Stripe billing requests are disabled until manually reset." 
) except Exception as e: logger.error(f"Failed to open circuit breaker: {e}") def _close_billing_circuit() -> None: """Close the billing circuit breaker (re-enable Stripe requests).""" if MULTI_TENANT: return try: redis_client = get_shared_redis_client() redis_client.delete(BILLING_CIRCUIT_BREAKER_KEY) logger.info( "Billing circuit breaker closed. Stripe billing requests re-enabled." ) except Exception as e: logger.error(f"Failed to close circuit breaker: {e}") def _get_license_data(db_session: Session) -> str | None: """Get license data from database if exists (self-hosted only).""" if MULTI_TENANT: return None license_record = get_license(db_session) return license_record.license_data if license_record else None def _get_tenant_id() -> str | None: """Get tenant ID for cloud deployments.""" if MULTI_TENANT: return get_current_tenant_id() return None @router.post("/create-checkout-session") async def create_checkout_session( request: CreateCheckoutSessionRequest | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> CreateCheckoutSessionResponse: """Create a Stripe checkout session for new subscription or renewal. For new customers, no license/tenant is required. For renewals, existing license (self-hosted) or tenant_id (cloud) is used. After checkout completion: - Self-hosted: Use /license/claim to retrieve the license - Cloud: Subscription is automatically activated """ license_data = _get_license_data(db_session) tenant_id = _get_tenant_id() billing_period = request.billing_period if request else "monthly" seats = request.seats if request else None email = request.email if request else None # Validate that requested seats is not less than current used seats if seats is not None: used_seats = get_used_seats(tenant_id) if seats < used_seats: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, f"Cannot subscribe with fewer seats than current usage. 
" f"You have {used_seats} active users/integrations but requested {seats} seats.", ) # Build redirect URL for after checkout completion redirect_url = f"{WEB_DOMAIN}/admin/billing?checkout=success" return await create_checkout_service( billing_period=billing_period, seats=seats, email=email, license_data=license_data, redirect_url=redirect_url, tenant_id=tenant_id, ) @router.post("/create-customer-portal-session") async def create_customer_portal_session( request: CreateCustomerPortalSessionRequest | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> CreateCustomerPortalSessionResponse: """Create a Stripe customer portal session for managing subscription. Requires existing license (self-hosted) or active tenant (cloud). """ license_data = _get_license_data(db_session) tenant_id = _get_tenant_id() # Self-hosted requires license if not MULTI_TENANT and not license_data: raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found") return_url = request.return_url if request else f"{WEB_DOMAIN}/admin/billing" return await create_portal_service( license_data=license_data, return_url=return_url, tenant_id=tenant_id, ) @router.get("/billing-information") async def get_billing_information( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> BillingInformationResponse | SubscriptionStatusResponse: """Get billing information for the current subscription. Returns subscription status and details from Stripe. For self-hosted: If the circuit breaker is open (previous failure), returns a 503 error without making the request. 
""" license_data = _get_license_data(db_session) tenant_id = _get_tenant_id() # Self-hosted without license = no subscription if not MULTI_TENANT and not license_data: return SubscriptionStatusResponse(subscribed=False) # Check circuit breaker (self-hosted only) if _is_billing_circuit_open(): raise OnyxError( OnyxErrorCode.SERVICE_UNAVAILABLE, "Stripe connection temporarily disabled. Click 'Connect to Stripe' to retry.", ) try: return await get_billing_service( license_data=license_data, tenant_id=tenant_id, ) except OnyxError as e: # Open circuit breaker on connection failures (self-hosted only) if e.status_code in ( OnyxErrorCode.BAD_GATEWAY.status_code, OnyxErrorCode.SERVICE_UNAVAILABLE.status_code, OnyxErrorCode.GATEWAY_TIMEOUT.status_code, ): _open_billing_circuit() raise @router.post("/seats/update") async def update_seats( request: SeatUpdateRequest, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> SeatUpdateResponse: """Update the seat count for the current subscription. Handles Stripe proration and license regeneration via control plane. For self-hosted, the frontend should call /license/claim after a short delay to fetch the regenerated license. """ license_data = _get_license_data(db_session) tenant_id = _get_tenant_id() # Self-hosted requires license if not MULTI_TENANT and not license_data: raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "No license found") # Validate that new seat count is not less than current used seats used_seats = get_used_seats(tenant_id) if request.new_seat_count < used_seats: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, f"Cannot reduce seats below current usage. " f"You have {used_seats} active users/integrations but requested {request.new_seat_count} seats.", ) # Note: Don't store license here - the control plane may still be processing # the subscription update. The frontend should call /license/claim after a # short delay to get the freshly generated license. 
return await update_seat_service( new_seat_count=request.new_seat_count, license_data=license_data, tenant_id=tenant_id, ) @router.get("/stripe-publishable-key") async def get_stripe_publishable_key() -> StripePublishableKeyResponse: """Fetch the Stripe publishable key. Priority: env var override (for testing) > S3 bucket (production). This endpoint is public (no auth required) since publishable keys are safe to expose. The key is cached in memory to avoid hitting S3 on every request. """ global _stripe_publishable_key_cache # Fast path: return cached value without lock if _stripe_publishable_key_cache: return StripePublishableKeyResponse( publishable_key=_stripe_publishable_key_cache ) # Use lock to prevent concurrent S3 requests async with _stripe_key_lock: # Double-check after acquiring lock (another request may have populated cache) if _stripe_publishable_key_cache: return StripePublishableKeyResponse( publishable_key=_stripe_publishable_key_cache ) # Check for env var override first (for local testing with pk_test_* keys) if STRIPE_PUBLISHABLE_KEY_OVERRIDE: key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip() if not key.startswith("pk_"): raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) # Fall back to S3 bucket if not STRIPE_PUBLISHABLE_KEY_URL: raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Stripe publishable key is not configured", ) try: async with httpx.AsyncClient() as client: response = await client.get(STRIPE_PUBLISHABLE_KEY_URL) response.raise_for_status() key = response.text.strip() # Validate key format if not key.startswith("pk_"): raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) except httpx.HTTPError: raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Failed to fetch Stripe publishable key", ) 
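`get_stripe_publishable_key` above uses a lock-free fast path, an `asyncio.Lock`, and a second check after acquiring the lock. Together these ensure that concurrent first requests trigger only one S3 fetch. Below is a minimal sketch of the same double-checked caching pattern in isolation. It assumes Python 3.10+. `AsyncCachedValue` and the injected `fetch` coroutine are hypothetical names, not part of Onyx. The sketch raises `ValueError` where the endpoint raises `OnyxError`.

```python
import asyncio
from collections.abc import Awaitable, Callable


class AsyncCachedValue:
    """Hypothetical helper: caches a validated value from an async fetch.

    Concurrent first callers share a single fetch (fast path, lock, double-check).
    """

    def __init__(self, fetch: Callable[[], Awaitable[str]]) -> None:
        self._fetch = fetch
        self._value: str | None = None
        self._lock = asyncio.Lock()

    async def get(self) -> str:
        # Fast path: once cached, return without touching the lock.
        if self._value:
            return self._value
        async with self._lock:
            # Double-check: another coroutine may have populated the cache
            # while this one was waiting for the lock.
            if self._value:
                return self._value
            value = (await self._fetch()).strip()
            # Validate before caching so a bad value is never stored.
            if not value.startswith("pk_"):
                raise ValueError("Invalid Stripe publishable key format")
            self._value = value
            return value


if __name__ == "__main__":

    async def main() -> None:
        calls = 0

        async def fake_fetch() -> str:
            nonlocal calls
            calls += 1
            await asyncio.sleep(0.01)  # simulate network latency
            return "pk_test_123\n"

        cache = AsyncCachedValue(fake_fetch)
        results = await asyncio.gather(*(cache.get() for _ in range(5)))
        print(results[0], calls)  # prints: pk_test_123 1

    asyncio.run(main())
```

A failed or invalid fetch leaves the cache empty, so the next call simply retries. This mirrors the endpoint, which only assigns `_stripe_publishable_key_cache` after the key passes the `pk_` check.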
class ResetConnectionResponse(BaseModel): success: bool message: str @router.post("/reset-connection") async def reset_stripe_connection( _: User = Depends(current_admin_user), ) -> ResetConnectionResponse: """Reset the Stripe connection circuit breaker. Called when user clicks "Connect to Stripe" to retry after a previous failure. This clears the circuit breaker flag, allowing billing requests to proceed again. Self-hosted only - cloud deployments don't use the circuit breaker. """ if MULTI_TENANT: return ResetConnectionResponse( success=True, message="Circuit breaker not applicable for cloud deployments", ) _close_billing_circuit() return ResetConnectionResponse( success=True, message="Stripe connection reset. Billing requests re-enabled.", ) ================================================ FILE: backend/ee/onyx/server/billing/models.py ================================================ """Pydantic models for the billing API.""" from datetime import datetime from typing import Literal from pydantic import BaseModel class CreateCheckoutSessionRequest(BaseModel): """Request to create a Stripe checkout session.""" billing_period: Literal["monthly", "annual"] = "monthly" seats: int | None = None email: str | None = None class CreateCheckoutSessionResponse(BaseModel): """Response containing the Stripe checkout session URL.""" stripe_checkout_url: str class CreateCustomerPortalSessionRequest(BaseModel): """Request to create a Stripe customer portal session.""" return_url: str | None = None class CreateCustomerPortalSessionResponse(BaseModel): """Response containing the Stripe customer portal URL.""" stripe_customer_portal_url: str class BillingInformationResponse(BaseModel): """Billing information for the current subscription.""" tenant_id: str status: str | None = None plan_type: str | None = None seats: int | None = None billing_period: str | None = None current_period_start: datetime | None = None current_period_end: datetime | None = None cancel_at_period_end: bool = 
False canceled_at: datetime | None = None trial_start: datetime | None = None trial_end: datetime | None = None payment_method_enabled: bool = False class SubscriptionStatusResponse(BaseModel): """Response when no subscription exists.""" subscribed: bool = False class SeatUpdateRequest(BaseModel): """Request to update seat count.""" new_seat_count: int class SeatUpdateResponse(BaseModel): """Response from seat update operation.""" success: bool current_seats: int used_seats: int message: str | None = None license: str | None = None # Regenerated license (self-hosted stores this) class StripePublishableKeyResponse(BaseModel): """Response containing the Stripe publishable key.""" publishable_key: str ================================================ FILE: backend/ee/onyx/server/billing/service.py ================================================ """Service layer for billing operations. This module provides functions for billing operations that route differently based on deployment type: - Self-hosted (not MULTI_TENANT): Routes through cloud data plane proxy Flow: Self-hosted backend → Cloud DP /proxy/* → Control plane - Cloud (MULTI_TENANT): Routes directly to control plane Flow: Cloud backend → Control plane """ from typing import Literal import httpx from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL from ee.onyx.server.billing.models import BillingInformationResponse from ee.onyx.server.billing.models import CreateCheckoutSessionResponse from ee.onyx.server.billing.models import CreateCustomerPortalSessionResponse from ee.onyx.server.billing.models import SeatUpdateResponse from ee.onyx.server.billing.models import SubscriptionStatusResponse from ee.onyx.server.tenants.access import generate_data_plane_token from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger from shared_configs.configs import 
MULTI_TENANT logger = setup_logger() # HTTP request timeout for billing service calls _REQUEST_TIMEOUT = 30.0 def _get_proxy_headers(license_data: str | None) -> dict[str, str]: """Build headers for proxy requests (self-hosted). Self-hosted instances authenticate with their license. """ headers = {"Content-Type": "application/json"} if license_data: headers["Authorization"] = f"Bearer {license_data}" return headers def _get_direct_headers() -> dict[str, str]: """Build headers for direct control plane requests (cloud). Cloud instances authenticate with JWT. """ token = generate_data_plane_token() return { "Content-Type": "application/json", "Authorization": f"Bearer {token}", } def _get_base_url() -> str: """Get the base URL based on deployment type.""" if MULTI_TENANT: return CONTROL_PLANE_API_BASE_URL return f"{CLOUD_DATA_PLANE_URL}/proxy" def _get_headers(license_data: str | None) -> dict[str, str]: """Get appropriate headers based on deployment type.""" if MULTI_TENANT: return _get_direct_headers() return _get_proxy_headers(license_data) async def _make_billing_request( method: Literal["GET", "POST"], path: str, license_data: str | None = None, body: dict | None = None, params: dict | None = None, error_message: str = "Billing service request failed", ) -> dict: """Make an HTTP request to the billing service. Consolidates the common HTTP request pattern used by all billing operations. 
Args: method: HTTP method (GET or POST) path: URL path (appended to base URL) license_data: License for authentication (self-hosted) body: Request body for POST requests params: Query parameters for GET requests error_message: Default error message if request fails Returns: Response JSON as dict Raises: OnyxError: If request fails """ base_url = _get_base_url() url = f"{base_url}{path}" headers = _get_headers(license_data) try: async with httpx.AsyncClient( timeout=_REQUEST_TIMEOUT, follow_redirects=True ) as client: if method == "GET": response = await client.get(url, headers=headers, params=params) else: response = await client.post(url, headers=headers, json=body) response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: detail = error_message try: error_data = e.response.json() detail = error_data.get("detail", detail) except Exception: pass logger.error(f"{error_message}: {e.response.status_code} - {detail}") raise OnyxError( OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=e.response.status_code, ) except httpx.RequestError: logger.exception("Failed to connect to billing service") raise OnyxError( OnyxErrorCode.BAD_GATEWAY, "Failed to connect to billing service" ) async def create_checkout_session( billing_period: str = "monthly", seats: int | None = None, email: str | None = None, license_data: str | None = None, redirect_url: str | None = None, tenant_id: str | None = None, ) -> CreateCheckoutSessionResponse: """Create a Stripe checkout session. 
Args: billing_period: "monthly" or "annual" seats: Number of seats to purchase (optional, uses default if not provided) email: Customer email for new subscriptions license_data: Existing license for renewals (self-hosted) redirect_url: URL to redirect after successful checkout tenant_id: Tenant ID (cloud only, for renewals) Returns: CreateCheckoutSessionResponse with checkout URL """ body: dict = {"billing_period": billing_period} if seats is not None: body["seats"] = seats if email: body["email"] = email if redirect_url: body["redirect_url"] = redirect_url if tenant_id and MULTI_TENANT: body["tenant_id"] = tenant_id data = await _make_billing_request( method="POST", path="/create-checkout-session", license_data=license_data, body=body, error_message="Failed to create checkout session", ) return CreateCheckoutSessionResponse(stripe_checkout_url=data["url"]) async def create_customer_portal_session( license_data: str | None = None, return_url: str | None = None, tenant_id: str | None = None, ) -> CreateCustomerPortalSessionResponse: """Create a Stripe customer portal session. Args: license_data: License blob for authentication (self-hosted) return_url: URL to return to after portal session tenant_id: Tenant ID (cloud only) Returns: CreateCustomerPortalSessionResponse with portal URL """ body: dict = {} if return_url: body["return_url"] = return_url if tenant_id and MULTI_TENANT: body["tenant_id"] = tenant_id data = await _make_billing_request( method="POST", path="/create-customer-portal-session", license_data=license_data, body=body, error_message="Failed to create customer portal session", ) return CreateCustomerPortalSessionResponse(stripe_customer_portal_url=data["url"]) async def get_billing_information( license_data: str | None = None, tenant_id: str | None = None, ) -> BillingInformationResponse | SubscriptionStatusResponse: """Fetch billing information. 
Args: license_data: License blob for authentication (self-hosted) tenant_id: Tenant ID (cloud only) Returns: BillingInformationResponse or SubscriptionStatusResponse if no subscription """ params = {} if tenant_id and MULTI_TENANT: params["tenant_id"] = tenant_id data = await _make_billing_request( method="GET", path="/billing-information", license_data=license_data, params=params or None, error_message="Failed to fetch billing information", ) # Check if no subscription if isinstance(data, dict) and data.get("subscribed") is False: return SubscriptionStatusResponse(subscribed=False) return BillingInformationResponse(**data) async def update_seat_count( new_seat_count: int, license_data: str | None = None, tenant_id: str | None = None, ) -> SeatUpdateResponse: """Update the seat count for the current subscription. Args: new_seat_count: New number of seats license_data: License blob for authentication (self-hosted) tenant_id: Tenant ID (cloud only) Returns: SeatUpdateResponse with updated seat information """ body: dict = {"new_seat_count": new_seat_count} if tenant_id and MULTI_TENANT: body["tenant_id"] = tenant_id data = await _make_billing_request( method="POST", path="/seats/update", license_data=license_data, body=body, error_message="Failed to update seat count", ) return SeatUpdateResponse( success=data.get("success", False), current_seats=data.get("current_seats", 0), used_seats=data.get("used_seats", 0), message=data.get("message"), license=data.get("license"), ) ================================================ FILE: backend/ee/onyx/server/documents/cc_pair.py ================================================ from datetime import datetime from http import HTTPStatus from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from ee.onyx.background.celery.tasks.doc_permission_syncing.tasks import ( try_creating_permissions_sync_task, ) from 
ee.onyx.background.celery.tasks.external_group_syncing.tasks import ( try_creating_external_group_sync_task, ) from onyx.auth.users import current_curator_or_admin_user from onyx.background.celery.versioned_apps.client import app as client_app from onyx.db.connector_credential_pair import ( get_connector_credential_pair_from_id_for_user, ) from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_pool import get_redis_client from onyx.server.models import StatusResponse from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/manage") @router.get("/admin/cc-pair/{cc_pair_id}/sync-permissions") def get_cc_pair_latest_sync( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> datetime | None: cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=False, ) if not cc_pair: raise HTTPException( status_code=400, detail="cc_pair not found for current user's permissions", ) return cc_pair.last_time_perm_sync @router.post("/admin/cc-pair/{cc_pair_id}/sync-permissions") def sync_cc_pair( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[None]: """Triggers permissions sync on a particular cc_pair immediately""" tenant_id = get_current_tenant_id() cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=False, ) if not cc_pair: raise HTTPException( status_code=400, detail="Connection not found for current user's permissions", ) r = get_redis_client() redis_connector = RedisConnector(tenant_id, cc_pair_id) if redis_connector.permissions.fenced: raise HTTPException( 
status_code=HTTPStatus.CONFLICT, detail="Permissions sync task already in progress.", ) logger.info( f"Permissions sync cc_pair={cc_pair_id} " f"connector_id={cc_pair.connector_id} " f"credential_id={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." ) payload_id = try_creating_permissions_sync_task( client_app, cc_pair_id, r, tenant_id ) if not payload_id: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Permissions sync task creation failed.", ) logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}") return StatusResponse( success=True, message="Successfully created the permissions sync task.", ) @router.get("/admin/cc-pair/{cc_pair_id}/sync-groups") def get_cc_pair_latest_group_sync( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> datetime | None: cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=False, ) if not cc_pair: raise HTTPException( status_code=400, detail="cc_pair not found for current user's permissions", ) return cc_pair.last_time_external_group_sync @router.post("/admin/cc-pair/{cc_pair_id}/sync-groups") def sync_cc_pair_groups( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[None]: """Triggers group sync on a particular cc_pair immediately""" tenant_id = get_current_tenant_id() cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=False, ) if not cc_pair: raise HTTPException( status_code=400, detail="Connection not found for current user's permissions", ) r = get_redis_client() redis_connector = RedisConnector(tenant_id, cc_pair_id) if redis_connector.external_group_sync.fenced: raise HTTPException( status_code=HTTPStatus.CONFLICT, detail="External group sync task already in 
progress.", ) logger.info( f"External group sync cc_pair={cc_pair_id} " f"connector_id={cc_pair.connector_id} " f"credential_id={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." ) payload_id = try_creating_external_group_sync_task( client_app, cc_pair_id, r, tenant_id ) if not payload_id: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="External group sync task creation failed.", ) logger.info(f"External group sync queued: cc_pair={cc_pair_id} id={payload_id}") return StatusResponse( success=True, message="Successfully created the external group sync task.", ) ================================================ FILE: backend/ee/onyx/server/enterprise_settings/api.py ================================================ from datetime import datetime from datetime import timezone from typing import Any import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Response from fastapi import status from fastapi import UploadFile from pydantic import BaseModel from pydantic import Field from sqlalchemy.orm import Session from ee.onyx.db.scim import ScimDAL from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload from ee.onyx.server.enterprise_settings.models import EnterpriseSettings from ee.onyx.server.enterprise_settings.store import get_logo_filename from ee.onyx.server.enterprise_settings.store import get_logotype_filename from ee.onyx.server.enterprise_settings.store import load_analytics_script from ee.onyx.server.enterprise_settings.store import load_settings from ee.onyx.server.enterprise_settings.store import store_analytics_script from ee.onyx.server.enterprise_settings.store import store_settings from ee.onyx.server.enterprise_settings.store import upload_logo from ee.onyx.server.scim.auth import generate_scim_token from ee.onyx.server.scim.models import ScimTokenCreate from ee.onyx.server.scim.models import ScimTokenCreatedResponse from 
ee.onyx.server.scim.models import ScimTokenResponse from onyx.auth.users import current_admin_user from onyx.auth.users import current_user_with_expired_token from onyx.auth.users import get_user_manager from onyx.auth.users import UserManager from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.file_store.file_store import get_default_file_store from onyx.server.utils import BasicAuthenticationError from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import get_current_tenant_id admin_router = APIRouter(prefix="/admin/enterprise-settings") basic_router = APIRouter(prefix="/enterprise-settings") logger = setup_logger() class RefreshTokenData(BaseModel): access_token: str refresh_token: str session: dict = Field(..., description="Contains session information") userinfo: dict = Field(..., description="Contains user information") def __init__(self, **data: Any) -> None: super().__init__(**data) if "exp" not in self.session: raise ValueError("'exp' must be set in the session dictionary") if "userId" not in self.userinfo or "email" not in self.userinfo: raise ValueError( "'userId' and 'email' must be set in the userinfo dictionary" ) @basic_router.post("/refresh-token") async def refresh_access_token( refresh_token: RefreshTokenData, user: User = Depends(current_user_with_expired_token), user_manager: UserManager = Depends(get_user_manager), ) -> None: try: logger.debug(f"Received response from Meechum auth URL for user {user.id}") # Extract new tokens new_access_token = refresh_token.access_token new_refresh_token = refresh_token.refresh_token new_expiry = datetime.fromtimestamp( refresh_token.session["exp"] / 1000, tz=timezone.utc ) expires_at_timestamp = int(new_expiry.timestamp()) logger.debug(f"Access token has been refreshed for user {user.id}") await user_manager.oauth_callback( 
oauth_name="custom", access_token=new_access_token, account_id=refresh_token.userinfo["userId"], account_email=refresh_token.userinfo["email"], expires_at=expires_at_timestamp, refresh_token=new_refresh_token, associate_by_email=True, ) logger.info(f"Successfully refreshed tokens for user {user.id}") except httpx.HTTPStatusError as e: if e.response.status_code == 401: logger.warning(f"Full authentication required for user {user.id}") raise HTTPException( status_code=status.HTTP_401_UNAUTHORIZED, detail="Full authentication required", ) logger.error( f"HTTP error occurred while refreshing token for user {user.id}: {str(e)}" ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Failed to refresh token", ) except Exception as e: logger.error( f"Unexpected error occurred while refreshing token for user {user.id}: {str(e)}" ) raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="An unexpected error occurred", ) @admin_router.put("") def admin_ee_put_settings( settings: EnterpriseSettings, _: User = Depends(current_admin_user) ) -> None: store_settings(settings) @basic_router.get("") def ee_fetch_settings() -> EnterpriseSettings: if MULTI_TENANT: tenant_id = get_current_tenant_id() if not tenant_id or tenant_id == POSTGRES_DEFAULT_SCHEMA: raise BasicAuthenticationError(detail="User must authenticate") return load_settings() @admin_router.put("/logo") def put_logo( file: UploadFile, is_logotype: bool = False, _: User = Depends(current_admin_user), ) -> None: upload_logo(file=file, is_logotype=is_logotype) def fetch_logo_helper(db_session: Session) -> Response: # noqa: ARG001 try: file_store = get_default_file_store() onyx_file = file_store.get_file_with_mime_type(get_logo_filename()) if not onyx_file: raise ValueError("get_onyx_file returned None!") except Exception: logger.exception("Failed to fetch logo file") raise HTTPException( status_code=404, detail="No logo file found", ) else: return Response(
content=onyx_file.data, media_type=onyx_file.mime_type, headers={"Cache-Control": "no-cache"}, ) def fetch_logotype_helper(db_session: Session) -> Response: # noqa: ARG001 try: file_store = get_default_file_store() onyx_file = file_store.get_file_with_mime_type(get_logotype_filename()) if not onyx_file: raise ValueError("get_onyx_file returned None!") except Exception: raise HTTPException( status_code=404, detail="No logotype file found", ) else: return Response(content=onyx_file.data, media_type=onyx_file.mime_type) @basic_router.get("/logotype") def fetch_logotype(db_session: Session = Depends(get_session)) -> Response: return fetch_logotype_helper(db_session) @basic_router.get("/logo") def fetch_logo( is_logotype: bool = False, db_session: Session = Depends(get_session) ) -> Response: if is_logotype: return fetch_logotype_helper(db_session) return fetch_logo_helper(db_session) @admin_router.put("/custom-analytics-script") def upload_custom_analytics_script( script_upload: AnalyticsScriptUpload, _: User = Depends(current_admin_user) ) -> None: try: store_analytics_script(script_upload) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @basic_router.get("/custom-analytics-script") def fetch_custom_analytics_script() -> str | None: return load_analytics_script() # --------------------------------------------------------------------------- # SCIM token management # --------------------------------------------------------------------------- def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL: return ScimDAL(db_session) @admin_router.get("/scim/token") def get_active_scim_token( _: User = Depends(current_admin_user), dal: ScimDAL = Depends(_get_scim_dal), ) -> ScimTokenResponse: """Return the currently active SCIM token's metadata, or 404 if none.""" token = dal.get_active_token() if not token: raise HTTPException(status_code=404, detail="No active SCIM token") # Derive the IdP domain from the first synced user as a 
heuristic. idp_domain: str | None = None mappings, _total = dal.list_user_mappings(start_index=1, count=1) if mappings: user = dal.get_user(mappings[0].user_id) if user and "@" in user.email: idp_domain = user.email.rsplit("@", 1)[1] return ScimTokenResponse( id=token.id, name=token.name, token_display=token.token_display, is_active=token.is_active, created_at=token.created_at, last_used_at=token.last_used_at, idp_domain=idp_domain, ) @admin_router.post("/scim/token", status_code=201) def create_scim_token( body: ScimTokenCreate, user: User = Depends(current_admin_user), dal: ScimDAL = Depends(_get_scim_dal), ) -> ScimTokenCreatedResponse: """Create a new SCIM bearer token. Only one token is active at a time — creating a new token automatically revokes all previous tokens. The raw token value is returned exactly once in the response; it cannot be retrieved again. """ raw_token, hashed_token, token_display = generate_scim_token() token = dal.create_token( name=body.name, hashed_token=hashed_token, token_display=token_display, created_by_id=user.id, ) dal.commit() return ScimTokenCreatedResponse( id=token.id, name=token.name, token_display=token.token_display, is_active=token.is_active, created_at=token.created_at, last_used_at=token.last_used_at, raw_token=raw_token, ) ================================================ FILE: backend/ee/onyx/server/enterprise_settings/models.py ================================================ from enum import Enum from typing import Any from typing import List from pydantic import BaseModel from pydantic import Field class NavigationItem(BaseModel): link: str title: str # Right now must be one of the FA icons icon: str | None = None # NOTE: SVG must not have a width / height specified # This is the actual SVG as a string. 
Done this way to reduce # complexity / having to store additional "logos" in Postgres svg_logo: str | None = None @classmethod def model_validate(cls, *args: Any, **kwargs: Any) -> "NavigationItem": instance = super().model_validate(*args, **kwargs) if bool(instance.icon) == bool(instance.svg_logo): raise ValueError("Exactly one of icon or svg_logo must be specified") return instance class LogoDisplayStyle(str, Enum): LOGO_AND_NAME = "logo_and_name" LOGO_ONLY = "logo_only" NAME_ONLY = "name_only" class EnterpriseSettings(BaseModel): """General settings that only apply to the Enterprise Edition of Onyx NOTE: don't put anything sensitive in here, as this is accessible without auth.""" application_name: str | None = None use_custom_logo: bool = False use_custom_logotype: bool = False logo_display_style: LogoDisplayStyle | None = None # custom navigation custom_nav_items: List[NavigationItem] = Field(default_factory=list) # custom Chat components two_lines_for_chat_header: bool | None = None custom_lower_disclaimer_content: str | None = None custom_header_content: str | None = None custom_popup_header: str | None = None custom_popup_content: str | None = None enable_consent_screen: bool | None = None consent_screen_prompt: str | None = None show_first_visit_notice: bool | None = None custom_greeting_message: str | None = None def check_validity(self) -> None: return class AnalyticsScriptUpload(BaseModel): script: str secret_key: str ================================================ FILE: backend/ee/onyx/server/enterprise_settings/store.py ================================================ import os from io import BytesIO from typing import Any from typing import cast from typing import IO from fastapi import HTTPException from fastapi import UploadFile from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload from ee.onyx.server.enterprise_settings.models import EnterpriseSettings from onyx.configs.constants import FileOrigin from
onyx.configs.constants import KV_CUSTOM_ANALYTICS_SCRIPT_KEY from onyx.configs.constants import KV_ENTERPRISE_SETTINGS_KEY from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME from onyx.file_store.file_store import get_default_file_store from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.utils.logger import setup_logger logger = setup_logger() _LOGO_FILENAME = "__logo__" _LOGOTYPE_FILENAME = "__logotype__" def load_settings() -> EnterpriseSettings: """Loads settings data directly from DB. This should be used primarily for checking what is actually in the DB, aka for editing and saving back settings. Runtime settings actually used by the application should be checked with load_runtime_settings as defaults may be applied at runtime. """ dynamic_config_store = get_kv_store() try: settings = EnterpriseSettings( **cast(dict, dynamic_config_store.load(KV_ENTERPRISE_SETTINGS_KEY)) ) except KvKeyNotFoundError: settings = EnterpriseSettings() dynamic_config_store.store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump()) return settings def store_settings(settings: EnterpriseSettings) -> None: """Stores settings directly to the kv store / db.""" get_kv_store().store(KV_ENTERPRISE_SETTINGS_KEY, settings.model_dump()) def load_runtime_settings() -> EnterpriseSettings: """Loads settings from DB and applies any defaults or transformations for use at runtime. Should not be stored back to the DB.
""" enterprise_settings = load_settings() if not enterprise_settings.application_name: enterprise_settings.application_name = ONYX_DEFAULT_APPLICATION_NAME return enterprise_settings _CUSTOM_ANALYTICS_SECRET_KEY = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY") def load_analytics_script() -> str | None: dynamic_config_store = get_kv_store() try: return cast(str, dynamic_config_store.load(KV_CUSTOM_ANALYTICS_SCRIPT_KEY)) except KvKeyNotFoundError: return None def store_analytics_script(analytics_script_upload: AnalyticsScriptUpload) -> None: if ( not _CUSTOM_ANALYTICS_SECRET_KEY or analytics_script_upload.secret_key != _CUSTOM_ANALYTICS_SECRET_KEY ): raise ValueError("Invalid secret key") get_kv_store().store(KV_CUSTOM_ANALYTICS_SCRIPT_KEY, analytics_script_upload.script) def is_valid_file_type(filename: str) -> bool: valid_extensions = (".png", ".jpg", ".jpeg") return filename.lower().endswith(valid_extensions) def guess_file_type(filename: str) -> str: if filename.lower().endswith(".png"): return "image/png" elif filename.lower().endswith(".jpg") or filename.lower().endswith(".jpeg"): return "image/jpeg" return "application/octet-stream" def upload_logo(file: UploadFile | str, is_logotype: bool = False) -> bool: content: IO[Any] if isinstance(file, str): logger.notice(f"Uploading logo from local path {file}") if not os.path.isfile(file) or not is_valid_file_type(file): logger.error( "Invalid file type: only .png, .jpg, and .jpeg files are allowed" ) return False with open(file, "rb") as file_handle: file_content = file_handle.read() content = BytesIO(file_content) display_name = file file_type = guess_file_type(file) else: logger.notice("Uploading logo from uploaded file") if not file.filename or not is_valid_file_type(file.filename): raise HTTPException( status_code=400, detail="Invalid file type: only .png, .jpg, and .jpeg files are allowed", ) content = file.file display_name = file.filename file_type = file.content_type or "image/jpeg" file_store =
get_default_file_store() file_store.save_file( content=content, display_name=display_name, file_origin=FileOrigin.OTHER, file_type=file_type, file_id=_LOGOTYPE_FILENAME if is_logotype else _LOGO_FILENAME, ) return True def get_logo_filename() -> str: return _LOGO_FILENAME def get_logotype_filename() -> str: return _LOGOTYPE_FILENAME ================================================ FILE: backend/ee/onyx/server/evals/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/evals/api.py ================================================ from fastapi import APIRouter from fastapi import Depends from ee.onyx.auth.users import current_cloud_superuser from onyx.background.celery.apps.client import celery_app as client_app from onyx.configs.constants import OnyxCeleryTask from onyx.db.models import User from onyx.evals.models import EvalConfigurationOptions from onyx.server.evals.models import EvalRunAck from onyx.utils.logger import setup_logger logger = setup_logger() router = APIRouter(prefix="/evals") @router.post("/eval_run", response_model=EvalRunAck) def eval_run( request: EvalConfigurationOptions, user: User = Depends(current_cloud_superuser), # noqa: ARG001 ) -> EvalRunAck: """ Run an evaluation with the given message and optional dataset. This endpoint requires a valid API key for authentication. 
""" client_app.send_task( OnyxCeleryTask.EVAL_RUN_TASK, kwargs={ "configuration_dict": request.model_dump(), }, ) return EvalRunAck(success=True) ================================================ FILE: backend/ee/onyx/server/features/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/features/hooks/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/features/hooks/api.py ================================================ import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import Query from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user from onyx.auth.users import User from onyx.db.constants import UNSET from onyx.db.constants import UnsetType from onyx.db.engine.sql_engine import get_session from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.hook import create_hook__no_commit from onyx.db.hook import delete_hook__no_commit from onyx.db.hook import get_hook_by_id from onyx.db.hook import get_hook_execution_logs from onyx.db.hook import get_hooks from onyx.db.hook import update_hook__no_commit from onyx.db.models import Hook from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError from onyx.hooks.api_dependencies import require_hook_enabled from onyx.hooks.models import HookCreateRequest from onyx.hooks.models import HookExecutionRecord from onyx.hooks.models import HookPointMetaResponse from onyx.hooks.models import HookResponse from onyx.hooks.models import HookUpdateRequest from onyx.hooks.models import HookValidateResponse from onyx.hooks.models import HookValidateStatus from onyx.hooks.registry import get_all_specs from onyx.hooks.registry import get_hook_point_spec from onyx.utils.logger import setup_logger from onyx.utils.url import SSRFException 
from onyx.utils.url import validate_outbound_http_url logger = setup_logger() # --------------------------------------------------------------------------- # SSRF protection # --------------------------------------------------------------------------- def _check_ssrf_safety(endpoint_url: str) -> None: """Raise OnyxError if endpoint_url could be used for SSRF. Delegates to validate_outbound_http_url with https_only=True. Uses BAD_GATEWAY so the frontend maps the error to the Endpoint URL field. """ try: validate_outbound_http_url(endpoint_url, https_only=True) except (SSRFException, ValueError) as e: raise OnyxError(OnyxErrorCode.BAD_GATEWAY, str(e)) # --------------------------------------------------------------------------- # Helpers # --------------------------------------------------------------------------- def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookResponse: return HookResponse( id=hook.id, name=hook.name, hook_point=hook.hook_point, endpoint_url=hook.endpoint_url, api_key_masked=( hook.api_key.get_value(apply_mask=True) if hook.api_key else None ), fail_strategy=hook.fail_strategy, timeout_seconds=hook.timeout_seconds, is_active=hook.is_active, is_reachable=hook.is_reachable, creator_email=( creator_email if creator_email is not None else (hook.creator.email if hook.creator else None) ), created_at=hook.created_at, updated_at=hook.updated_at, ) def _get_hook_or_404( db_session: Session, hook_id: int, include_creator: bool = False, ) -> Hook: hook = get_hook_by_id( db_session=db_session, hook_id=hook_id, include_creator=include_creator, ) if hook is None: raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook {hook_id} not found.") return hook def _raise_for_validation_failure(validation: HookValidateResponse) -> None: """Raise an appropriate OnyxError for a non-passed validation result.""" if validation.status == HookValidateStatus.auth_failed: raise OnyxError(OnyxErrorCode.CREDENTIAL_INVALID, validation.error_message) if 
validation.status == HookValidateStatus.timeout: raise OnyxError( OnyxErrorCode.GATEWAY_TIMEOUT, f"Endpoint validation failed: {validation.error_message}", ) raise OnyxError( OnyxErrorCode.BAD_GATEWAY, f"Endpoint validation failed: {validation.error_message}", ) def _validate_endpoint( endpoint_url: str, api_key: str | None, timeout_seconds: float, ) -> HookValidateResponse: """Check whether endpoint_url is reachable by sending an empty POST request. We use POST since hook endpoints expect POST requests. The server will typically respond with 4xx (missing/invalid body) — that is fine. Any HTTP response means the server is up and routable. A 401/403 response returns auth_failed (not reachable — indicates the api_key is invalid). Timeout handling: - Any httpx.TimeoutException (ConnectTimeout, ReadTimeout, WriteTimeout, PoolTimeout) → timeout (operator should consider increasing timeout_seconds). - All other exceptions → cannot_connect. """ _check_ssrf_safety(endpoint_url) headers: dict[str, str] = {} if api_key: headers["Authorization"] = f"Bearer {api_key}" try: with httpx.Client(timeout=timeout_seconds, follow_redirects=False) as client: response = client.post(endpoint_url, headers=headers) if response.status_code in (401, 403): return HookValidateResponse( status=HookValidateStatus.auth_failed, error_message=f"Authentication failed (HTTP {response.status_code})", ) return HookValidateResponse(status=HookValidateStatus.passed) except httpx.TimeoutException as exc: # Any timeout (connect, read, or write) means the configured timeout_seconds # is too low for this endpoint. Report as timeout so the UI directs the user # to increase the timeout setting. 
logger.warning( "Hook endpoint validation: timeout for %s", endpoint_url, exc_info=exc, ) return HookValidateResponse( status=HookValidateStatus.timeout, error_message="Endpoint timed out — consider increasing timeout_seconds.", ) except Exception as exc: logger.warning( "Hook endpoint validation: connection error for %s", endpoint_url, exc_info=exc, ) return HookValidateResponse( status=HookValidateStatus.cannot_connect, error_message=str(exc) ) # --------------------------------------------------------------------------- # Routers # --------------------------------------------------------------------------- router = APIRouter(prefix="/admin/hooks") # --------------------------------------------------------------------------- # Hook endpoints # --------------------------------------------------------------------------- @router.get("/specs") def get_hook_point_specs( _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), ) -> list[HookPointMetaResponse]: return [ HookPointMetaResponse( hook_point=spec.hook_point, display_name=spec.display_name, description=spec.description, docs_url=spec.docs_url, input_schema=spec.input_schema, output_schema=spec.output_schema, default_timeout_seconds=spec.default_timeout_seconds, default_fail_strategy=spec.default_fail_strategy, fail_hard_description=spec.fail_hard_description, ) for spec in get_all_specs() ] @router.get("") def list_hooks( _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> list[HookResponse]: hooks = get_hooks(db_session=db_session, include_creator=True) return [_hook_to_response(h) for h in hooks] @router.post("") def create_hook( req: HookCreateRequest, user: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> HookResponse: """Create a new hook. 
The endpoint is validated before persisting — creation fails if the endpoint cannot be reached or the api_key is invalid. Hooks are created active. """ spec = get_hook_point_spec(req.hook_point) api_key = req.api_key.get_secret_value() if req.api_key else None validation = _validate_endpoint( endpoint_url=req.endpoint_url, api_key=api_key, timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds, ) if validation.status != HookValidateStatus.passed: _raise_for_validation_failure(validation) hook = create_hook__no_commit( db_session=db_session, name=req.name, hook_point=req.hook_point, endpoint_url=req.endpoint_url, api_key=api_key, fail_strategy=req.fail_strategy or spec.default_fail_strategy, timeout_seconds=req.timeout_seconds or spec.default_timeout_seconds, is_active=True, is_reachable=True, creator_id=user.id, ) db_session.commit() return _hook_to_response(hook, creator_email=user.email) @router.get("/{hook_id}") def get_hook( hook_id: int, _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> HookResponse: hook = _get_hook_or_404(db_session, hook_id, include_creator=True) return _hook_to_response(hook) @router.patch("/{hook_id}") def update_hook( hook_id: int, req: HookUpdateRequest, _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> HookResponse: """Update hook fields. If endpoint_url, api_key, or timeout_seconds changes, the endpoint is re-validated using the effective values. For active hooks the update is rejected on validation failure, keeping live traffic unaffected. For inactive hooks the update goes through regardless and is_reachable is updated to reflect the result. Note: if an active hook's endpoint is currently down, even a timeout_seconds-only increase will be rejected. The recovery flow is: deactivate → update → reactivate. 
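The re-validation policy described above can be modeled in plain Python. This is a minimal sketch, not the endpoint's implementation: `StoredHook`, `plan_update`, `ValidationRejected`, and the local `UNSET` sentinel are hypothetical stand-ins for the ORM row, the handler body, the `OnyxError` raised on failure, and `onyx.db.constants.UNSET`; `validate` stands in for `_validate_endpoint` reduced to passed/not passed.

```python
from dataclasses import dataclass
from typing import Callable, Optional, Union


class _Unset:
    # Sentinel type: the field was not sent in the PATCH body (distinct from None).
    pass


UNSET = _Unset()  # hypothetical local stand-in for onyx.db.constants.UNSET


@dataclass
class StoredHook:
    # Hypothetical subset of the persisted Hook row.
    endpoint_url: str
    api_key: Optional[str]
    timeout_seconds: float
    is_active: bool


class ValidationRejected(Exception):
    # Stands in for the OnyxError raised by _raise_for_validation_failure.
    pass


Validator = Callable[[str, Optional[str], float], bool]


def plan_update(
    hook: StoredHook,
    validate: Validator,
    endpoint_url: Union[str, _Unset] = UNSET,
    api_key: Union[str, None, _Unset] = UNSET,
    timeout_seconds: Union[float, _Unset] = UNSET,
) -> Optional[bool]:
    # Returns the new is_reachable value, or None when nothing needs re-validation.
    if all(isinstance(v, _Unset) for v in (endpoint_url, api_key, timeout_seconds)):
        return None  # e.g. a name-only or fail_strategy-only update

    # Effective values: the request value if sent, otherwise the stored value.
    # An explicit api_key=None means "clear the key" and is validated that way.
    url = hook.endpoint_url if isinstance(endpoint_url, _Unset) else endpoint_url
    key = hook.api_key if isinstance(api_key, _Unset) else api_key
    timeout = (
        hook.timeout_seconds if isinstance(timeout_seconds, _Unset) else timeout_seconds
    )

    passed = validate(url, key, timeout)
    if hook.is_active and not passed:
        # Active hooks keep their current config so live traffic is unaffected.
        raise ValidationRejected("endpoint validation failed; update rejected")
    # Inactive hooks are saved regardless; is_reachable records the result.
    return passed
```

In this model a timeout-only change on an active hook whose endpoint is down still raises, which is why the recovery flow goes through deactivation first.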
""" # api_key: UNSET = no change, None = clear, value = update api_key: str | None | UnsetType if "api_key" not in req.model_fields_set: api_key = UNSET elif req.api_key is None: api_key = None else: api_key = req.api_key.get_secret_value() endpoint_url_changing = "endpoint_url" in req.model_fields_set api_key_changing = not isinstance(api_key, UnsetType) timeout_changing = "timeout_seconds" in req.model_fields_set validated_is_reachable: bool | None = None if endpoint_url_changing or api_key_changing or timeout_changing: existing = _get_hook_or_404(db_session, hook_id) effective_url: str = ( req.endpoint_url if endpoint_url_changing else existing.endpoint_url # type: ignore[assignment] # endpoint_url is required on create and cannot be cleared on update ) effective_api_key: str | None = ( (api_key if not isinstance(api_key, UnsetType) else None) if api_key_changing else ( existing.api_key.get_value(apply_mask=False) if existing.api_key else None ) ) effective_timeout: float = ( req.timeout_seconds if timeout_changing else existing.timeout_seconds # type: ignore[assignment] # req.timeout_seconds is non-None when timeout_changing (validated by HookUpdateRequest) ) validation = _validate_endpoint( endpoint_url=effective_url, api_key=effective_api_key, timeout_seconds=effective_timeout, ) if existing.is_active and validation.status != HookValidateStatus.passed: _raise_for_validation_failure(validation) validated_is_reachable = validation.status == HookValidateStatus.passed hook = update_hook__no_commit( db_session=db_session, hook_id=hook_id, name=req.name, endpoint_url=(req.endpoint_url if endpoint_url_changing else UNSET), api_key=api_key, fail_strategy=req.fail_strategy, timeout_seconds=req.timeout_seconds, is_reachable=validated_is_reachable, include_creator=True, ) db_session.commit() return _hook_to_response(hook) @router.delete("/{hook_id}") def delete_hook( hook_id: int, _: User = Depends(current_admin_user), _hook_enabled: None = 
Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> None: delete_hook__no_commit(db_session=db_session, hook_id=hook_id) db_session.commit() @router.post("/{hook_id}/activate") def activate_hook( hook_id: int, _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> HookResponse: hook = _get_hook_or_404(db_session, hook_id) if not hook.endpoint_url: raise OnyxError( OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured." ) api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None validation = _validate_endpoint( endpoint_url=hook.endpoint_url, api_key=api_key, timeout_seconds=hook.timeout_seconds, ) if validation.status != HookValidateStatus.passed: # Persist is_reachable=False in a separate session so the request # session has no commits on the failure path and the transaction # boundary stays clean. if hook.is_reachable is not False: with get_session_with_current_tenant() as side_session: update_hook__no_commit( db_session=side_session, hook_id=hook_id, is_reachable=False ) side_session.commit() _raise_for_validation_failure(validation) hook = update_hook__no_commit( db_session=db_session, hook_id=hook_id, is_active=True, is_reachable=True, include_creator=True, ) db_session.commit() return _hook_to_response(hook) @router.post("/{hook_id}/validate") def validate_hook( hook_id: int, _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> HookValidateResponse: hook = _get_hook_or_404(db_session, hook_id) if not hook.endpoint_url: raise OnyxError( OnyxErrorCode.INVALID_INPUT, "Hook has no endpoint URL configured." 
) api_key = hook.api_key.get_value(apply_mask=False) if hook.api_key else None validation = _validate_endpoint( endpoint_url=hook.endpoint_url, api_key=api_key, timeout_seconds=hook.timeout_seconds, ) validation_passed = validation.status == HookValidateStatus.passed if hook.is_reachable != validation_passed: update_hook__no_commit( db_session=db_session, hook_id=hook_id, is_reachable=validation_passed ) db_session.commit() return validation @router.post("/{hook_id}/deactivate") def deactivate_hook( hook_id: int, _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> HookResponse: hook = update_hook__no_commit( db_session=db_session, hook_id=hook_id, is_active=False, include_creator=True, ) db_session.commit() return _hook_to_response(hook) # --------------------------------------------------------------------------- # Execution log endpoints # --------------------------------------------------------------------------- @router.get("/{hook_id}/execution-logs") def list_hook_execution_logs( hook_id: int, limit: int = Query(default=10, ge=1, le=100), _: User = Depends(current_admin_user), _hook_enabled: None = Depends(require_hook_enabled), db_session: Session = Depends(get_session), ) -> list[HookExecutionRecord]: _get_hook_or_404(db_session, hook_id) logs = get_hook_execution_logs(db_session=db_session, hook_id=hook_id, limit=limit) return [ HookExecutionRecord( error_message=log.error_message, status_code=log.status_code, duration_ms=log.duration_ms, created_at=log.created_at, ) for log in logs ] ================================================ FILE: backend/ee/onyx/server/license/api.py ================================================ """License API endpoints for self-hosted deployments. These endpoints allow self-hosted Onyx instances to: 1. Claim a license after Stripe checkout (via cloud data plane proxy) 2. Upload a license file manually (for air-gapped deployments) 3. 
View license status and seat usage 4. Refresh/delete the local license NOTE: Cloud (MULTI_TENANT) deployments do NOT use these endpoints. Cloud licensing is managed via the control plane and gated_tenants Redis key. """ import requests from fastapi import APIRouter from fastapi import Depends from fastapi import File from fastapi import UploadFile from sqlalchemy.orm import Session from ee.onyx.auth.users import current_admin_user from ee.onyx.configs.app_configs import CLOUD_DATA_PLANE_URL from ee.onyx.db.license import delete_license as db_delete_license from ee.onyx.db.license import get_license from ee.onyx.db.license import get_license_metadata from ee.onyx.db.license import invalidate_license_cache from ee.onyx.db.license import refresh_license_cache from ee.onyx.db.license import update_license_cache from ee.onyx.db.license import upsert_license from ee.onyx.server.license.models import LicenseResponse from ee.onyx.server.license.models import LicenseSource from ee.onyx.server.license.models import LicenseStatusResponse from ee.onyx.server.license.models import LicenseUploadResponse from ee.onyx.server.license.models import SeatUsageResponse from ee.onyx.utils.license import verify_license_signature from onyx.auth.users import User from onyx.db.engine.sql_engine import get_session from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() router = APIRouter(prefix="/license") # PEM-style delimiters used in license file format _PEM_BEGIN = "-----BEGIN ONYX LICENSE-----" _PEM_END = "-----END ONYX LICENSE-----" def _strip_pem_delimiters(content: str) -> str: """Strip PEM-style delimiters from license content if present.""" content = content.strip() if content.startswith(_PEM_BEGIN) and content.endswith(_PEM_END): # Remove first and last lines (the delimiters) lines = content.split("\n") return 
"\n".join(lines[1:-1]).strip() return content @router.get("") async def get_license_status( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> LicenseStatusResponse: """Get current license status and seat usage.""" metadata = get_license_metadata(db_session) if not metadata: return LicenseStatusResponse(has_license=False) return LicenseStatusResponse( has_license=True, seats=metadata.seats, used_seats=metadata.used_seats, plan_type=metadata.plan_type, issued_at=metadata.issued_at, expires_at=metadata.expires_at, grace_period_end=metadata.grace_period_end, status=metadata.status, source=metadata.source, ) @router.get("/seats") async def get_seat_usage( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> SeatUsageResponse: """Get detailed seat usage information.""" metadata = get_license_metadata(db_session) if not metadata: return SeatUsageResponse( total_seats=0, used_seats=0, available_seats=0, ) return SeatUsageResponse( total_seats=metadata.seats, used_seats=metadata.used_seats, available_seats=max(0, metadata.seats - metadata.used_seats), ) @router.post("/claim") async def claim_license( session_id: str | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> LicenseResponse: """ Claim a license from the control plane (self-hosted only). Two modes: 1. With session_id: After Stripe checkout, exchange session_id for license 2. Without session_id: Re-claim using existing license for auth Use without session_id after: - Updating seats via the billing API - Returning from the Stripe customer portal - Any operation that regenerates the license on control plane
""" if MULTI_TENANT: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, "License claiming is only available for self-hosted deployments", ) try: if session_id: # Claim license after checkout using session_id url = f"{CLOUD_DATA_PLANE_URL}/proxy/claim-license" response = requests.post( url, json={"session_id": session_id}, headers={"Content-Type": "application/json"}, timeout=30, ) else: # Re-claim using existing license for auth metadata = get_license_metadata(db_session) if not metadata or not metadata.tenant_id: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, "No license found. Provide session_id after checkout.", ) license_row = get_license(db_session) if not license_row or not license_row.license_data: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, "No license found in database", ) url = f"{CLOUD_DATA_PLANE_URL}/proxy/license/{metadata.tenant_id}" response = requests.get( url, headers={ "Authorization": f"Bearer {license_row.license_data}", "Content-Type": "application/json", }, timeout=30, ) response.raise_for_status() data = response.json() license_data = data.get("license") if not license_data: raise OnyxError(OnyxErrorCode.NOT_FOUND, "No license in response") # Verify signature before persisting payload = verify_license_signature(license_data) # Store in DB upsert_license(db_session, license_data) try: update_license_cache(payload, source=LicenseSource.AUTO_FETCH) except Exception as cache_error: logger.warning(f"Failed to update license cache: {cache_error}") logger.info( f"License claimed: seats={payload.seats}, expires={payload.expires_at.date()}" ) return LicenseResponse(success=True, license=payload) except requests.HTTPError as e: status_code = e.response.status_code if e.response is not None else 502 detail = "Failed to claim license" try: error_data = e.response.json() if e.response is not None else {} detail = error_data.get("detail", detail) except Exception: pass raise OnyxError(
OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=status_code ) except ValueError as e: raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) except requests.RequestException: raise OnyxError( OnyxErrorCode.BAD_GATEWAY, "Failed to connect to license server" ) @router.post("/upload") async def upload_license( license_file: UploadFile = File(...), _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> LicenseUploadResponse: """ Upload a license file manually (self-hosted only). Used for air-gapped deployments where the cloud data plane is not accessible. The license file must be cryptographically signed by Onyx. """ if MULTI_TENANT: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, "License upload is only available for self-hosted deployments", ) try: content = await license_file.read() license_data = content.decode("utf-8").strip() # Strip PEM-style delimiters if present (used in .lic file format) license_data = _strip_pem_delimiters(license_data) # Remove any stray whitespace/newlines from user input license_data = license_data.strip() except UnicodeDecodeError: raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Invalid license file format") # Verify cryptographic signature - this is the only validation needed # The license's tenant_id identifies the customer in control plane, not locally try: payload = verify_license_signature(license_data) except ValueError as e: raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, str(e)) # Persist to DB and update cache upsert_license(db_session, license_data) try: update_license_cache(payload, source=LicenseSource.MANUAL_UPLOAD) except Exception as cache_error: logger.warning(f"Failed to update license cache: {cache_error}") return LicenseUploadResponse( success=True, message=f"License uploaded successfully. 
{payload.seats} seats, expires {payload.expires_at.date()}", ) @router.post("/refresh") async def refresh_license_cache_endpoint( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> LicenseStatusResponse: """ Force refresh the license cache from the local database. Useful after manual database changes or to verify license validity. Does NOT fetch from control plane - use /claim for that. """ metadata = refresh_license_cache(db_session) if not metadata: return LicenseStatusResponse(has_license=False) return LicenseStatusResponse( has_license=True, seats=metadata.seats, used_seats=metadata.used_seats, plan_type=metadata.plan_type, issued_at=metadata.issued_at, expires_at=metadata.expires_at, grace_period_end=metadata.grace_period_end, status=metadata.status, source=metadata.source, ) @router.delete("") async def delete_license( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> dict[str, bool]: """ Delete the current license. Admin only - removes license from database and invalidates cache. 
""" if MULTI_TENANT: raise OnyxError( OnyxErrorCode.VALIDATION_ERROR, "License deletion is only available for self-hosted deployments", ) try: invalidate_license_cache() except Exception as cache_error: logger.warning(f"Failed to invalidate license cache: {cache_error}") deleted = db_delete_license(db_session) return {"deleted": deleted} ================================================ FILE: backend/ee/onyx/server/license/models.py ================================================ from datetime import datetime from enum import Enum from pydantic import BaseModel from onyx.server.settings.models import ApplicationStatus class PlanType(str, Enum): MONTHLY = "monthly" ANNUAL = "annual" class LicenseSource(str, Enum): AUTO_FETCH = "auto_fetch" MANUAL_UPLOAD = "manual_upload" class LicensePayload(BaseModel): """The payload portion of a signed license.""" version: str tenant_id: str organization_name: str | None = None issued_at: datetime expires_at: datetime seats: int plan_type: PlanType billing_cycle: str | None = None grace_period_days: int = 30 stripe_subscription_id: str | None = None stripe_customer_id: str | None = None class LicenseData(BaseModel): """Full signed license structure.""" payload: LicensePayload signature: str class LicenseMetadata(BaseModel): """Cached license metadata stored in Redis.""" tenant_id: str organization_name: str | None = None seats: int used_seats: int plan_type: PlanType issued_at: datetime expires_at: datetime grace_period_end: datetime | None = None status: ApplicationStatus source: LicenseSource | None = None stripe_subscription_id: str | None = None class LicenseStatusResponse(BaseModel): """Response for license status API.""" has_license: bool seats: int = 0 used_seats: int = 0 plan_type: PlanType | None = None issued_at: datetime | None = None expires_at: datetime | None = None grace_period_end: datetime | None = None status: ApplicationStatus | None = None source: LicenseSource | None = None class LicenseResponse(BaseModel): 
"""Response after license fetch/upload.""" success: bool message: str | None = None license: LicensePayload | None = None class LicenseUploadResponse(BaseModel): """Response after license upload.""" success: bool message: str | None = None class SeatUsageResponse(BaseModel): """Response for seat usage API.""" total_seats: int used_seats: int available_seats: int ================================================ FILE: backend/ee/onyx/server/manage/standard_answer.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from ee.onyx.db.standard_answer import fetch_standard_answer from ee.onyx.db.standard_answer import fetch_standard_answer_categories from ee.onyx.db.standard_answer import fetch_standard_answer_category from ee.onyx.db.standard_answer import fetch_standard_answers from ee.onyx.db.standard_answer import insert_standard_answer from ee.onyx.db.standard_answer import insert_standard_answer_category from ee.onyx.db.standard_answer import remove_standard_answer from ee.onyx.db.standard_answer import update_standard_answer from ee.onyx.db.standard_answer import update_standard_answer_category from onyx.auth.users import current_admin_user from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.server.manage.models import StandardAnswer from onyx.server.manage.models import StandardAnswerCategory from onyx.server.manage.models import StandardAnswerCategoryCreationRequest from onyx.server.manage.models import StandardAnswerCreationRequest router = APIRouter(prefix="/manage") @router.post("/admin/standard-answer") def create_standard_answer( standard_answer_creation_request: StandardAnswerCreationRequest, db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> StandardAnswer: standard_answer_model = insert_standard_answer( keyword=standard_answer_creation_request.keyword, 
answer=standard_answer_creation_request.answer, category_ids=standard_answer_creation_request.categories, match_regex=standard_answer_creation_request.match_regex, match_any_keywords=standard_answer_creation_request.match_any_keywords, db_session=db_session, ) return StandardAnswer.from_model(standard_answer_model) @router.get("/admin/standard-answer") def list_standard_answers( db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> list[StandardAnswer]: standard_answer_models = fetch_standard_answers(db_session=db_session) return [ StandardAnswer.from_model(standard_answer_model) for standard_answer_model in standard_answer_models ] @router.patch("/admin/standard-answer/{standard_answer_id}") def patch_standard_answer( standard_answer_id: int, standard_answer_creation_request: StandardAnswerCreationRequest, db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> StandardAnswer: existing_standard_answer = fetch_standard_answer( standard_answer_id=standard_answer_id, db_session=db_session, ) if existing_standard_answer is None: raise HTTPException(status_code=404, detail="Standard answer not found") standard_answer_model = update_standard_answer( standard_answer_id=standard_answer_id, keyword=standard_answer_creation_request.keyword, answer=standard_answer_creation_request.answer, category_ids=standard_answer_creation_request.categories, match_regex=standard_answer_creation_request.match_regex, match_any_keywords=standard_answer_creation_request.match_any_keywords, db_session=db_session, ) return StandardAnswer.from_model(standard_answer_model) @router.delete("/admin/standard-answer/{standard_answer_id}") def delete_standard_answer( standard_answer_id: int, db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> None: return remove_standard_answer( standard_answer_id=standard_answer_id, db_session=db_session, ) @router.post("/admin/standard-answer/category") def 
create_standard_answer_category( standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest, db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> StandardAnswerCategory: standard_answer_category_model = insert_standard_answer_category( category_name=standard_answer_category_creation_request.name, db_session=db_session, ) return StandardAnswerCategory.from_model(standard_answer_category_model) @router.get("/admin/standard-answer/category") def list_standard_answer_categories( db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> list[StandardAnswerCategory]: standard_answer_category_models = fetch_standard_answer_categories( db_session=db_session ) return [ StandardAnswerCategory.from_model(standard_answer_category_model) for standard_answer_category_model in standard_answer_category_models ] @router.patch("/admin/standard-answer/category/{standard_answer_category_id}") def patch_standard_answer_category( standard_answer_category_id: int, standard_answer_category_creation_request: StandardAnswerCategoryCreationRequest, db_session: Session = Depends(get_session), _: User = Depends(current_admin_user), ) -> StandardAnswerCategory: existing_standard_answer_category = fetch_standard_answer_category( standard_answer_category_id=standard_answer_category_id, db_session=db_session, ) if existing_standard_answer_category is None: raise HTTPException( status_code=404, detail="Standard answer category not found" ) standard_answer_category_model = update_standard_answer_category( standard_answer_category_id=standard_answer_category_id, category_name=standard_answer_category_creation_request.name, db_session=db_session, ) return StandardAnswerCategory.from_model(standard_answer_category_model) ================================================ FILE: backend/ee/onyx/server/middleware/license_enforcement.py ================================================ """Middleware to enforce license 
status for SELF-HOSTED deployments only. NOTE: This middleware is NOT used for multi-tenant (cloud) deployments. Multi-tenant gating is handled separately by the control plane via the /tenants/product-gating endpoint and is_tenant_gated() checks. IMPORTANT: Mutual Exclusivity with ENTERPRISE_EDITION_ENABLED ============================================================ This middleware is controlled by LICENSE_ENFORCEMENT_ENABLED env var. It works alongside the legacy ENTERPRISE_EDITION_ENABLED system: - LICENSE_ENFORCEMENT_ENABLED=false (default): Middleware is disabled. EE features are controlled solely by ENTERPRISE_EDITION_ENABLED. This preserves legacy behavior. - LICENSE_ENFORCEMENT_ENABLED=true: Middleware actively enforces license status. EE features require a valid license, regardless of ENTERPRISE_EDITION_ENABLED. Eventually, ENTERPRISE_EDITION_ENABLED will be removed and license enforcement will be the only mechanism for gating EE features. License Enforcement States (when enabled) ========================================= For self-hosted deployments: 1. No license (never subscribed): - Allow community features (basic connectors, search, chat) - Block EE-only features (analytics, user groups, etc.) 2. GATED_ACCESS (fully expired): - Block all routes except billing/auth/license - User must renew subscription to continue 3. 
Valid license (ACTIVE, GRACE_PERIOD, PAYMENT_REMINDER): - Full access to all EE features - Seat limits enforced - GRACE_PERIOD/PAYMENT_REMINDER are for notifications only, not blocking """ import logging from collections.abc import Awaitable from collections.abc import Callable from fastapi import FastAPI from fastapi import Request from fastapi import Response from fastapi.responses import JSONResponse from sqlalchemy.exc import SQLAlchemyError from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED from ee.onyx.configs.license_enforcement_config import EE_ONLY_PATH_PREFIXES from ee.onyx.configs.license_enforcement_config import ( LICENSE_ENFORCEMENT_ALLOWED_PREFIXES, ) from ee.onyx.db.license import get_cached_license_metadata from ee.onyx.db.license import refresh_license_cache from onyx.cache.interface import CACHE_TRANSIENT_ERRORS from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.server.settings.models import ApplicationStatus from shared_configs.contextvars import get_current_tenant_id def _is_path_allowed(path: str) -> bool: """Check if path is in allowlist (prefix match).""" return any( path.startswith(prefix) for prefix in LICENSE_ENFORCEMENT_ALLOWED_PREFIXES ) def _is_ee_only_path(path: str) -> bool: """Check if path requires EE license (prefix match).""" return any(path.startswith(prefix) for prefix in EE_ONLY_PATH_PREFIXES) def add_license_enforcement_middleware( app: FastAPI, logger: logging.LoggerAdapter ) -> None: logger.info("License enforcement middleware registered") @app.middleware("http") async def enforce_license( request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: """Block requests when license is expired/gated.""" if not LICENSE_ENFORCEMENT_ENABLED: return await call_next(request) path = request.url.path if path.startswith("/api"): path = path[4:] if _is_path_allowed(path): return await call_next(request) is_gated = False tenant_id = get_current_tenant_id() try: 
metadata = get_cached_license_metadata(tenant_id) # If no cached metadata, check database (cache may have been cleared) if not metadata: logger.debug( "[license_enforcement] No cached license, checking database..." ) try: with get_session_with_current_tenant() as db_session: metadata = refresh_license_cache(db_session, tenant_id) if metadata: logger.info( "[license_enforcement] Loaded license from database" ) except SQLAlchemyError as db_error: logger.warning( f"[license_enforcement] Failed to check database for license: {db_error}" ) if metadata: # User HAS a license (current or expired) if metadata.status == ApplicationStatus.GATED_ACCESS: # License fully expired - gate the user # Note: GRACE_PERIOD and PAYMENT_REMINDER are for notifications only, # they don't block access is_gated = True else: # License is active - check seat limit # used_seats in cache is kept accurate via invalidation # when users are added/removed if metadata.used_seats > metadata.seats: logger.info( f"[license_enforcement] Blocking request: " f"seat limit exceeded ({metadata.used_seats}/{metadata.seats})" ) return JSONResponse( status_code=402, content={ "detail": { "error": "seat_limit_exceeded", "message": f"Seat limit exceeded: {metadata.used_seats} of {metadata.seats} seats used.", "used_seats": metadata.used_seats, "seats": metadata.seats, } }, ) else: # No license in cache OR database = never subscribed # Allow community features, but block EE-only features if _is_ee_only_path(path): logger.info( f"[license_enforcement] Blocking EE-only path (no license): {path}" ) return JSONResponse( status_code=402, content={ "detail": { "error": "enterprise_license_required", "message": "This feature requires an Enterprise license. 
" "Please upgrade to access this functionality.", } }, ) logger.debug( "[license_enforcement] No license, allowing community features" ) is_gated = False except CACHE_TRANSIENT_ERRORS as e: logger.warning(f"Failed to check license metadata: {e}") # Fail open - don't block users due to cache connectivity issues is_gated = False if is_gated: logger.info( f"[license_enforcement] Blocking request (license expired): {path}" ) return JSONResponse( status_code=402, content={ "detail": { "error": "license_expired", "message": "Your subscription has expired. Please update your billing.", } }, ) return await call_next(request) ================================================ FILE: backend/ee/onyx/server/middleware/tenant_tracking.py ================================================ import logging from collections.abc import Awaitable from collections.abc import Callable from fastapi import FastAPI from fastapi import HTTPException from fastapi import Request from fastapi import Response from ee.onyx.auth.users import decode_anonymous_user_jwt_token from onyx.auth.utils import extract_tenant_from_auth_header from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME from onyx.configs.constants import TENANT_ID_COOKIE_NAME from onyx.db.engine.sql_engine import is_valid_schema_name from onyx.redis.redis_pool import retrieve_auth_token_data_from_redis from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR def add_api_server_tenant_id_middleware( app: FastAPI, logger: logging.LoggerAdapter ) -> None: @app.middleware("http") async def set_tenant_id( request: Request, call_next: Callable[[Request], Awaitable[Response]] ) -> Response: """Extracts the tenant id from multiple locations and sets the context var. This is very specific to the api server and probably not something you'd want to use elsewhere. 
""" try: if MULTI_TENANT: tenant_id = await _get_tenant_id_from_request(request, logger) else: tenant_id = POSTGRES_DEFAULT_SCHEMA CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) return await call_next(request) except Exception as e: logger.exception(f"Error in tenant ID middleware: {str(e)}") raise async def _get_tenant_id_from_request( request: Request, logger: logging.LoggerAdapter ) -> str: """ Attempt to extract tenant_id from: 1) The API key or PAT (Personal Access Token) header 2) The Redis-based token (stored in Cookie: fastapiusersauth) 3) The anonymous user cookie Fallback: POSTGRES_DEFAULT_SCHEMA """ # Check for API key or PAT in Authorization header tenant_id = extract_tenant_from_auth_header(request) if tenant_id is not None: return tenant_id try: # Look up token data in Redis token_data = await retrieve_auth_token_data_from_redis(request) if token_data: tenant_id_from_payload = token_data.get( "tenant_id", POSTGRES_DEFAULT_SCHEMA ) tenant_id = ( str(tenant_id_from_payload) if tenant_id_from_payload is not None else None ) if tenant_id and not is_valid_schema_name(tenant_id): raise HTTPException(status_code=400, detail="Invalid tenant ID format") # Check for anonymous user cookie anonymous_user_cookie = request.cookies.get(ANONYMOUS_USER_COOKIE_NAME) if anonymous_user_cookie: try: anonymous_user_data = decode_anonymous_user_jwt_token( anonymous_user_cookie ) tenant_id = anonymous_user_data.get( "tenant_id", POSTGRES_DEFAULT_SCHEMA ) if not tenant_id or not is_valid_schema_name(tenant_id): raise HTTPException( status_code=400, detail="Invalid tenant ID format" ) return tenant_id except Exception as e: logger.error(f"Error decoding anonymous user cookie: {str(e)}") # Continue and attempt to authenticate logger.debug( "Token data not found or expired in Redis, defaulting to POSTGRES_DEFAULT_SCHEMA" ) # Return POSTGRES_DEFAULT_SCHEMA, so non-authenticated requests are sent to the default schema # The CURRENT_TENANT_ID_CONTEXTVAR is initialized with 
POSTGRES_DEFAULT_SCHEMA, # so we maintain consistency by returning it here when no valid tenant is found. return POSTGRES_DEFAULT_SCHEMA except HTTPException: # Propagate deliberate HTTP errors (e.g. 400 for an invalid tenant ID) unchanged raise except Exception as e: logger.error(f"Unexpected error in _get_tenant_id_from_request: {str(e)}") raise HTTPException(status_code=500, detail="Internal server error") finally: # Only ever return a tenant ID that passed schema-name validation if tenant_id and is_valid_schema_name(tenant_id): return tenant_id # As a final step, check for explicit tenant_id cookie tenant_id_cookie = request.cookies.get(TENANT_ID_COOKIE_NAME) if tenant_id_cookie and is_valid_schema_name(tenant_id_cookie): return tenant_id_cookie # If we've reached this point, return the default schema return POSTGRES_DEFAULT_SCHEMA ================================================ FILE: backend/ee/onyx/server/oauth/api.py ================================================ import base64 import uuid from fastapi import Depends from fastapi import HTTPException from fastapi.responses import JSONResponse from ee.onyx.server.oauth.api_router import router from ee.onyx.server.oauth.confluence_cloud import ConfluenceCloudOAuth from ee.onyx.server.oauth.google_drive import GoogleDriveOAuth from ee.onyx.server.oauth.slack import SlackOAuth from onyx.auth.users import current_admin_user from onyx.configs.app_configs import DEV_MODE from onyx.configs.constants import DocumentSource from onyx.db.models import User from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @router.post("/prepare-authorization-request") def prepare_authorization_request( connector: DocumentSource, redirect_on_success: str | None, user: User = Depends(current_admin_user), tenant_id: str | None = Depends(get_current_tenant_id), ) -> JSONResponse: """Used by the frontend to generate the url for the user's browser during auth request.
Example: https://www.oauth.com/oauth2-servers/authorization/the-authorization-request/ """ # create random oauth state param for security and to retrieve user data later oauth_uuid = uuid.uuid4() oauth_uuid_str = str(oauth_uuid) # urlsafe b64 encode the uuid for the oauth url oauth_state = ( base64.urlsafe_b64encode(oauth_uuid.bytes).rstrip(b"=").decode("utf-8") ) session: str | None = None if connector == DocumentSource.SLACK: if not DEV_MODE: oauth_url = SlackOAuth.generate_oauth_url(oauth_state) else: oauth_url = SlackOAuth.generate_dev_oauth_url(oauth_state) session = SlackOAuth.session_dump_json( email=user.email, redirect_on_success=redirect_on_success ) elif connector == DocumentSource.CONFLUENCE: if not DEV_MODE: oauth_url = ConfluenceCloudOAuth.generate_oauth_url(oauth_state) else: oauth_url = ConfluenceCloudOAuth.generate_dev_oauth_url(oauth_state) session = ConfluenceCloudOAuth.session_dump_json( email=user.email, redirect_on_success=redirect_on_success ) elif connector == DocumentSource.GOOGLE_DRIVE: if not DEV_MODE: oauth_url = GoogleDriveOAuth.generate_oauth_url(oauth_state) else: oauth_url = GoogleDriveOAuth.generate_dev_oauth_url(oauth_state) session = GoogleDriveOAuth.session_dump_json( email=user.email, redirect_on_success=redirect_on_success ) else: oauth_url = None if not oauth_url: raise HTTPException( status_code=404, detail=f"The document source type {connector} does not have OAuth implemented", ) if not session: raise HTTPException( status_code=500, detail=f"The document source type {connector} failed to generate an OAuth session.", ) r = get_redis_client(tenant_id=tenant_id) # store important session state to retrieve when the user is redirected back # 10 min is the max we want an oauth flow to be valid r.set(f"da_oauth:{oauth_uuid_str}", session, ex=600) return JSONResponse(content={"url": oauth_url}) ================================================ FILE: backend/ee/onyx/server/oauth/api_router.py 
================================================ from fastapi import APIRouter router: APIRouter = APIRouter(prefix="/oauth") ================================================ FILE: backend/ee/onyx/server/oauth/confluence_cloud.py ================================================ import base64 import uuid from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from typing import cast import requests from fastapi import Depends from fastapi import HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel from pydantic import ValidationError from sqlalchemy.orm import Session from ee.onyx.server.oauth.api_router import router from onyx.auth.users import current_admin_user from onyx.configs.app_configs import DEV_MODE from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.utils import CONFLUENCE_OAUTH_TOKEN_URL from onyx.db.credentials import create_credential from onyx.db.credentials import fetch_credential_by_id_for_user from onyx.db.credentials import update_credential_json from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import CredentialBase from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() class ConfluenceCloudOAuth: # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/ class OAuthSession(BaseModel): """Stored in redis to be looked up on callback""" email: str redirect_on_success: str | None # Where to send the user if OAuth flow succeeds class TokenResponse(BaseModel): access_token: str expires_in: int token_type: str refresh_token: str scope: str class 
AccessibleResources(BaseModel): id: str name: str url: str scopes: list[str] avatarUrl: str CLIENT_ID = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID CLIENT_SECRET = OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET TOKEN_URL = CONFLUENCE_OAUTH_TOKEN_URL ACCESSIBLE_RESOURCE_URL = ( "https://api.atlassian.com/oauth/token/accessible-resources" ) # All read scopes per https://developer.atlassian.com/cloud/confluence/scopes-for-oauth-2-3LO-and-forge-apps/ CONFLUENCE_OAUTH_SCOPE = ( # classic scope "read:confluence-space.summary%20" "read:confluence-props%20" "read:confluence-content.all%20" "read:confluence-content.summary%20" "read:confluence-content.permission%20" "read:confluence-user%20" "read:confluence-groups%20" "read:space:confluence%20" "readonly:content.attachment:confluence%20" "search:confluence%20" # granular scope "read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api "read:content-details:confluence%20" # for permission sync "offline_access" ) REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/callback" DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" # eventually for Confluence Data Center # oauth_url = ( # f"http://localhost:8090/rest/oauth/v2/authorize?client_id={CONFLUENCE_OAUTH_CLIENT_ID}" # f"&scope={CONFLUENCE_OAUTH_SCOPE_2}" # f"&redirect_uri={redirectme_uri}" # ) @classmethod def generate_oauth_url(cls, state: str) -> str: return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) @classmethod def generate_dev_oauth_url(cls, state: str) -> str: """dev mode workaround for localhost testing - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https """ return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) @classmethod def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: # https://developer.atlassian.com/cloud/jira/platform/oauth-2-3lo-apps/#1--direct-the-user-to-the-authorization-url-to-get-an-authorization-code url = ( "https://auth.atlassian.com/authorize" 
f"?audience=api.atlassian.com" f"&client_id={cls.CLIENT_ID}" f"&scope={cls.CONFLUENCE_OAUTH_SCOPE}" f"&redirect_uri={redirect_uri}" f"&state={state}" "&response_type=code" "&prompt=consent" ) return url @classmethod def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: """Temporary state to store in Redis, to be looked up on the auth response. Returns a JSON string. """ session = ConfluenceCloudOAuth.OAuthSession( email=email, redirect_on_success=redirect_on_success ) return session.model_dump_json() @classmethod def parse_session(cls, session_json: str) -> OAuthSession: session = ConfluenceCloudOAuth.OAuthSession.model_validate_json(session_json) return session @classmethod def generate_finalize_url(cls, credential_id: int) -> str: return f"{WEB_DOMAIN}/admin/connectors/confluence/oauth/finalize?credential={credential_id}" @router.post("/connector/confluence/callback") def confluence_oauth_callback( code: str, state: str, user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> JSONResponse: """Handles the backend logic for the frontend page that the user is redirected to after visiting the oauth authorization url.""" if not ConfluenceCloudOAuth.CLIENT_ID or not ConfluenceCloudOAuth.CLIENT_SECRET: raise HTTPException( status_code=500, detail="Confluence Cloud client ID or client secret is not configured.", ) r = get_redis_client(tenant_id=tenant_id) # recover the state padded_state = state + "=" * ( -len(state) % 4 ) # Add padding back (Base64 decoding requires padding) uuid_bytes = base64.urlsafe_b64decode( padded_state ) # Decode the Base64 string back to bytes # Convert bytes back to a UUID oauth_uuid = uuid.UUID(bytes=uuid_bytes) oauth_uuid_str = str(oauth_uuid) r_key = f"da_oauth:{oauth_uuid_str}" session_json_bytes = cast(bytes, r.get(r_key)) if not session_json_bytes: raise HTTPException( status_code=400, detail=f"Confluence Cloud OAuth 
failed -
OAuth state key not found: key={r_key}", ) session_json = session_json_bytes.decode("utf-8") try: session = ConfluenceCloudOAuth.parse_session(session_json) if not DEV_MODE: redirect_uri = ConfluenceCloudOAuth.REDIRECT_URI else: redirect_uri = ConfluenceCloudOAuth.DEV_REDIRECT_URI # Exchange the authorization code for an access token response = requests.post( ConfluenceCloudOAuth.TOKEN_URL, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "client_id": ConfluenceCloudOAuth.CLIENT_ID, "client_secret": ConfluenceCloudOAuth.CLIENT_SECRET, "code": code, "redirect_uri": redirect_uri, "grant_type": "authorization_code", }, ) token_response: ConfluenceCloudOAuth.TokenResponse | None = None try: token_response = ConfluenceCloudOAuth.TokenResponse.model_validate_json( response.text ) except Exception: raise RuntimeError( "Confluence Cloud OAuth failed during code/token exchange." ) now = datetime.now(timezone.utc) expires_at = now + timedelta(seconds=token_response.expires_in) credential_info = CredentialBase( credential_json={ "confluence_access_token": token_response.access_token, "confluence_refresh_token": token_response.refresh_token, "created_at": now.isoformat(), "expires_at": expires_at.isoformat(), "expires_in": token_response.expires_in, "scope": token_response.scope, }, admin_public=True, source=DocumentSource.CONFLUENCE, name="Confluence Cloud OAuth", ) credential = create_credential(credential_info, user, db_session) except Exception as e: return JSONResponse( status_code=500, content={ "success": False, "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}", }, ) finally: r.delete(r_key) # return the result return JSONResponse( content={ "success": True, "message": "Confluence Cloud OAuth completed successfully.", "finalize_url": ConfluenceCloudOAuth.generate_finalize_url(credential.id), "redirect_on_success": session.redirect_on_success, } ) @router.get("/connector/confluence/accessible-resources") def 
confluence_oauth_accessible_resources( credential_id: int, user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001 ) -> JSONResponse: """After authorization, Atlassian does not return enough information to make the credential usable: every API call requires a cloud ID. We therefore list the accessible resources (sites) and let the user choose which site to use.""" credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if not credential: raise HTTPException(400, f"Credential {credential_id} not found.") credential_dict = ( credential.credential_json.get_value(apply_mask=False) if credential.credential_json else {} ) access_token = credential_dict["confluence_access_token"] try: # List the Atlassian sites (and their cloud IDs) that this access token can reach response = requests.get( ConfluenceCloudOAuth.ACCESSIBLE_RESOURCE_URL, headers={ "Authorization": f"Bearer {access_token}", "Accept": "application/json", }, ) response.raise_for_status() accessible_resources_data = response.json() # Validate the list of AccessibleResources try: accessible_resources = [ ConfluenceCloudOAuth.AccessibleResources(**resource) for resource in accessible_resources_data ] except ValidationError as e: raise RuntimeError(f"Failed to parse accessible resources: {e}") except Exception as e: return JSONResponse( status_code=500, content={ "success": False, "message": f"An error occurred retrieving Confluence Cloud accessible resources: {str(e)}", }, ) # return the result return JSONResponse( content={ "success": True, "message": "Confluence Cloud get accessible resources completed successfully.", "accessible_resources": [ resource.model_dump() for resource in accessible_resources ], } ) @router.post("/connector/confluence/finalize") def confluence_oauth_finalize( credential_id: int, cloud_id: str, cloud_name: str, cloud_url: str, user: User = 
Depends(current_admin_user), db_session: Session =
Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), # noqa: ARG001 ) -> JSONResponse: """Saves the info for the selected cloud site to the credential. This is the final step in the confluence oauth flow where after the traditional OAuth process, the user has to select a site to associate with the credentials. After this, the credential is usable.""" credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if not credential: raise HTTPException( status_code=400, detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.", ) existing_credential_json = ( credential.credential_json.get_value(apply_mask=False) if credential.credential_json else {} ) new_credential_json: dict[str, Any] = dict(existing_credential_json) new_credential_json["cloud_id"] = cloud_id new_credential_json["cloud_name"] = cloud_name new_credential_json["wiki_base"] = cloud_url try: update_credential_json(credential_id, new_credential_json, user, db_session) except Exception as e: return JSONResponse( status_code=500, content={ "success": False, "message": f"An error occurred during Confluence Cloud OAuth: {str(e)}", }, ) # return the result return JSONResponse( content={ "success": True, "message": "Confluence Cloud OAuth finalized successfully.", "redirect_url": f"{WEB_DOMAIN}/admin/connectors/confluence", } ) ================================================ FILE: backend/ee/onyx/server/oauth/google_drive.py ================================================ import base64 import json import uuid from typing import Any from typing import cast import requests from fastapi import Depends from fastapi import HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel from sqlalchemy.orm import Session from ee.onyx.server.oauth.api_router import router from onyx.auth.users import current_admin_user from onyx.configs.app_configs import DEV_MODE from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID 
from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import DocumentSource from onyx.connectors.google_utils.google_auth import get_google_oauth_creds from onyx.connectors.google_utils.google_auth import sanitize_oauth_credentials from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_AUTHENTICATION_METHOD, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_TOKEN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) from onyx.db.credentials import create_credential from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import CredentialBase from shared_configs.contextvars import get_current_tenant_id class GoogleDriveOAuth: # https://developers.google.com/identity/protocols/oauth2 # https://developers.google.com/identity/protocols/oauth2/web-server class OAuthSession(BaseModel): """Stored in redis to be looked up on callback""" email: str redirect_on_success: str | None # Where to send the user if OAuth flow succeeds CLIENT_ID = OAUTH_GOOGLE_DRIVE_CLIENT_ID CLIENT_SECRET = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET TOKEN_URL = "https://oauth2.googleapis.com/token" # SCOPE is per https://docs.danswer.dev/connectors/google-drive # TODO: Merge with or use google_utils.GOOGLE_SCOPES SCOPE = ( "https://www.googleapis.com/auth/drive.readonly%20" "https://www.googleapis.com/auth/drive.metadata.readonly%20" "https://www.googleapis.com/auth/admin.directory.user.readonly%20" "https://www.googleapis.com/auth/admin.directory.group.readonly" ) REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/google-drive/oauth/callback" DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" @classmethod def 
generate_oauth_url(cls, state: str) -> str: return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) @classmethod def generate_dev_oauth_url(cls, state: str) -> str: """dev mode workaround for localhost testing - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https """ return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) @classmethod def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: # without prompt=consent, a refresh token is only issued the first time the user approves url = ( f"https://accounts.google.com/o/oauth2/v2/auth" f"?client_id={cls.CLIENT_ID}" f"&redirect_uri={redirect_uri}" "&response_type=code" f"&scope={cls.SCOPE}" "&access_type=offline" f"&state={state}" "&prompt=consent" ) return url @classmethod def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: """Temporary state to store in Redis, to be looked up on the auth response. Returns a JSON string. """ session = GoogleDriveOAuth.OAuthSession( email=email, redirect_on_success=redirect_on_success ) return session.model_dump_json() @classmethod def parse_session(cls, session_json: str) -> OAuthSession: session = GoogleDriveOAuth.OAuthSession.model_validate_json(session_json) return session @router.post("/connector/google-drive/callback") def handle_google_drive_oauth_callback( code: str, state: str, user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> JSONResponse: if not GoogleDriveOAuth.CLIENT_ID or not GoogleDriveOAuth.CLIENT_SECRET: raise HTTPException( status_code=500, detail="Google Drive client ID or client secret is not configured.", ) r = get_redis_client(tenant_id=tenant_id) # recover the state padded_state = state + "=" * ( -len(state) % 4 ) # Add padding back (Base64 decoding requires padding) uuid_bytes = base64.urlsafe_b64decode( padded_state ) # Decode the Base64 string back to bytes # 
Convert bytes back to a UUID
oauth_uuid = uuid.UUID(bytes=uuid_bytes) oauth_uuid_str = str(oauth_uuid) r_key = f"da_oauth:{oauth_uuid_str}" session_json_bytes = cast(bytes, r.get(r_key)) if not session_json_bytes: raise HTTPException( status_code=400, detail=f"Google Drive OAuth failed - OAuth state key not found: key={r_key}", ) session_json = session_json_bytes.decode("utf-8") try: session = GoogleDriveOAuth.parse_session(session_json) if not DEV_MODE: redirect_uri = GoogleDriveOAuth.REDIRECT_URI else: redirect_uri = GoogleDriveOAuth.DEV_REDIRECT_URI # Exchange the authorization code for an access token response = requests.post( GoogleDriveOAuth.TOKEN_URL, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "client_id": GoogleDriveOAuth.CLIENT_ID, "client_secret": GoogleDriveOAuth.CLIENT_SECRET, "code": code, "redirect_uri": redirect_uri, "grant_type": "authorization_code", }, ) response.raise_for_status() authorization_response: dict[str, Any] = response.json() # the connector wants us to store the json in its authorized_user_info format # returned from OAuthCredentials.get_authorized_user_info(). 
# So refresh immediately via get_google_oauth_creds with the params filled in # from fields in authorization_response to get the json we need authorized_user_info = {} authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET authorized_user_info["refresh_token"] = authorization_response["refresh_token"] token_json_str = json.dumps(authorized_user_info) oauth_creds = get_google_oauth_creds( token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE ) if not oauth_creds: raise RuntimeError("get_google_oauth_creds returned None.") # save off the credentials oauth_creds_sanitized_json_str = sanitize_oauth_credentials(oauth_creds) credential_dict: dict[str, str] = {} credential_dict[DB_CREDENTIALS_DICT_TOKEN_KEY] = oauth_creds_sanitized_json_str credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = session.email credential_dict[DB_CREDENTIALS_AUTHENTICATION_METHOD] = ( GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value ) credential_info = CredentialBase( credential_json=credential_dict, admin_public=True, source=DocumentSource.GOOGLE_DRIVE, name="OAuth (interactive)", ) create_credential(credential_info, user, db_session) except Exception as e: return JSONResponse( status_code=500, content={ "success": False, "message": f"An error occurred during Google Drive OAuth: {str(e)}", }, ) finally: r.delete(r_key) # return the result return JSONResponse( content={ "success": True, "message": "Google Drive OAuth completed successfully.", "finalize_url": None, "redirect_on_success": session.redirect_on_success, } ) ================================================ FILE: backend/ee/onyx/server/oauth/slack.py ================================================ import base64 import uuid from typing import cast import requests from fastapi import Depends from fastapi import HTTPException from fastapi.responses import JSONResponse from pydantic import BaseModel from sqlalchemy.orm import Session 
from ee.onyx.server.oauth.api_router import router from onyx.auth.users import current_admin_user from onyx.configs.app_configs import DEV_MODE from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_ID from onyx.configs.app_configs import OAUTH_SLACK_CLIENT_SECRET from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import DocumentSource from onyx.db.credentials import create_credential from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import CredentialBase from shared_configs.contextvars import get_current_tenant_id class SlackOAuth: # https://knock.app/blog/how-to-authenticate-users-in-slack-using-oauth # Example: https://api.slack.com/authentication/oauth-v2#exchanging class OAuthSession(BaseModel): """Stored in redis to be looked up on callback""" email: str redirect_on_success: str | None # Where to send the user if OAuth flow succeeds CLIENT_ID = OAUTH_SLACK_CLIENT_ID CLIENT_SECRET = OAUTH_SLACK_CLIENT_SECRET TOKEN_URL = "https://slack.com/api/oauth.v2.access" # SCOPE is per https://docs.danswer.dev/connectors/slack BOT_SCOPE = ( "channels:history," "channels:read," "groups:history," "groups:read," "channels:join," "im:history," "users:read," "users:read.email," "usergroups:read" ) REDIRECT_URI = f"{WEB_DOMAIN}/admin/connectors/slack/oauth/callback" DEV_REDIRECT_URI = f"https://redirectmeto.com/{REDIRECT_URI}" @classmethod def generate_oauth_url(cls, state: str) -> str: return cls._generate_oauth_url_helper(cls.REDIRECT_URI, state) @classmethod def generate_dev_oauth_url(cls, state: str) -> str: """dev mode workaround for localhost testing - https://www.nango.dev/blog/oauth-redirects-on-localhost-with-https """ return cls._generate_oauth_url_helper(cls.DEV_REDIRECT_URI, state) @classmethod def _generate_oauth_url_helper(cls, redirect_uri: str, state: str) -> str: url = ( f"https://slack.com/oauth/v2/authorize" 
f"?client_id={cls.CLIENT_ID}" f"&redirect_uri={redirect_uri}" f"&scope={cls.BOT_SCOPE}" f"&state={state}" ) return url @classmethod def session_dump_json(cls, email: str, redirect_on_success: str | None) -> str: """Temporary state to store in Redis, to be looked up on the auth response. Returns a JSON string. """ session = SlackOAuth.OAuthSession( email=email, redirect_on_success=redirect_on_success ) return session.model_dump_json() @classmethod def parse_session(cls, session_json: str) -> OAuthSession: session = SlackOAuth.OAuthSession.model_validate_json(session_json) return session @router.post("/connector/slack/callback") def handle_slack_oauth_callback( code: str, state: str, user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), tenant_id: str | None = Depends(get_current_tenant_id), ) -> JSONResponse: if not SlackOAuth.CLIENT_ID or not SlackOAuth.CLIENT_SECRET: raise HTTPException( status_code=500, detail="Slack client ID or client secret is not configured.", ) r = get_redis_client(tenant_id=tenant_id) # recover the state padded_state = state + "=" * ( -len(state) % 4 ) # Add padding back (Base64 decoding requires padding) uuid_bytes = base64.urlsafe_b64decode( padded_state ) # Decode the Base64 string back to bytes # Convert bytes back to a UUID oauth_uuid = uuid.UUID(bytes=uuid_bytes) oauth_uuid_str = str(oauth_uuid) r_key = f"da_oauth:{oauth_uuid_str}" session_json_bytes = cast(bytes, r.get(r_key)) if not session_json_bytes: raise HTTPException( status_code=400, detail=f"Slack OAuth failed - OAuth state key not found: key={r_key}", ) session_json = session_json_bytes.decode("utf-8") try: session = SlackOAuth.parse_session(session_json) if not DEV_MODE: redirect_uri = SlackOAuth.REDIRECT_URI else: redirect_uri = SlackOAuth.DEV_REDIRECT_URI # Exchange the authorization code for an access token response = requests.post( SlackOAuth.TOKEN_URL, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ 
"client_id":
SlackOAuth.CLIENT_ID, "client_secret": SlackOAuth.CLIENT_SECRET, "code": code, "redirect_uri": redirect_uri, }, ) response_data = response.json() if not response_data.get("ok"): raise HTTPException( status_code=400, detail=f"Slack OAuth failed: {response_data.get('error')}", ) # Extract token and team information access_token: str = response_data.get("access_token") team_id: str = response_data.get("team", {}).get("id") authed_user_id: str = response_data.get("authed_user", {}).get("id") credential_info = CredentialBase( credential_json={"slack_bot_token": access_token}, admin_public=True, source=DocumentSource.SLACK, name="Slack OAuth", ) create_credential(credential_info, user, db_session) except Exception as e: return JSONResponse( status_code=500, content={ "success": False, "message": f"An error occurred during Slack OAuth: {str(e)}", }, ) finally: r.delete(r_key) # return the result return JSONResponse( content={ "success": True, "message": "Slack OAuth completed successfully.", "finalize_url": None, "redirect_on_success": session.redirect_on_success, "team_id": team_id, "authed_user_id": authed_user_id, } ) ================================================ FILE: backend/ee/onyx/server/query_and_chat/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/query_and_chat/models.py ================================================ from collections.abc import Sequence from datetime import datetime from pydantic import BaseModel from pydantic import Field from onyx.context.search.models import BaseFilters from onyx.context.search.models import InferenceSection from onyx.context.search.models import SearchDoc from onyx.server.manage.models import StandardAnswer class StandardAnswerRequest(BaseModel): message: str slack_bot_categories: list[str] class StandardAnswerResponse(BaseModel): standard_answers: list[StandardAnswer] = Field(default_factory=list) class 
SearchFlowClassificationRequest(BaseModel): user_query: str class SearchFlowClassificationResponse(BaseModel): is_search_flow: bool # NOTE: This model is used for the core flow of the Onyx application, any # changes to it should be reviewed and approved by an experienced team member. # It is very important to 1. avoid bloat and 2. that this remains backwards # compatible across versions. class SendSearchQueryRequest(BaseModel): search_query: str filters: BaseFilters | None = None num_docs_fed_to_llm_selection: int | None = None run_query_expansion: bool = False num_hits: int = 30 hybrid_alpha: float | None = None include_content: bool = False stream: bool = False class SearchDocWithContent(SearchDoc): # Allows None because this is determined by a flag but the object used in code # of the search path uses this type content: str | None @classmethod def from_inference_sections( cls, sections: Sequence[InferenceSection], include_content: bool = False, is_internet: bool = False, ) -> list["SearchDocWithContent"]: """Convert InferenceSections to SearchDocWithContent objects. 
Args: sections: Sequence of InferenceSection objects include_content: If True, populate content field with combined_content is_internet: Whether these are internet search results Returns: List of SearchDocWithContent with optional content """ if not sections: return [] return [ cls( document_id=(chunk := section.center_chunk).document_id, chunk_ind=chunk.chunk_id, semantic_identifier=chunk.semantic_identifier or "Unknown", link=chunk.source_links[0] if chunk.source_links else None, blurb=chunk.blurb, source_type=chunk.source_type, boost=chunk.boost, hidden=chunk.hidden, metadata=chunk.metadata, score=chunk.score, match_highlights=chunk.match_highlights, updated_at=chunk.updated_at, primary_owners=chunk.primary_owners, secondary_owners=chunk.secondary_owners, is_internet=is_internet, content=section.combined_content if include_content else None, ) for section in sections ] class SearchFullResponse(BaseModel): all_executed_queries: list[str] search_docs: list[SearchDocWithContent] # Reasoning tokens output by the LLM for the document selection doc_selection_reasoning: str | None = None # This is a list of document ids that are in the search_docs list llm_selected_doc_ids: list[str] | None = None # Error message if the search failed partway through error: str | None = None class SearchQueryResponse(BaseModel): query: str query_expansions: list[str] | None created_at: datetime class SearchHistoryResponse(BaseModel): search_queries: list[SearchQueryResponse] ================================================ FILE: backend/ee/onyx/server/query_and_chat/query_backend.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from ee.onyx.onyxbot.slack.handlers.handle_standard_answers import ( oneoff_standard_answers, ) from ee.onyx.server.query_and_chat.models import StandardAnswerRequest from ee.onyx.server.query_and_chat.models 
import StandardAnswerResponse
from onyx.auth.users import current_user from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.utils.logger import setup_logger logger = setup_logger() basic_router = APIRouter(prefix="/query") @basic_router.get("/standard-answer") def get_standard_answer( request: StandardAnswerRequest, db_session: Session = Depends(get_session), _: User = Depends(current_user), ) -> StandardAnswerResponse: try: standard_answers = oneoff_standard_answers( message=request.message, slack_bot_categories=request.slack_bot_categories, db_session=db_session, ) return StandardAnswerResponse(standard_answers=standard_answers) except Exception as e: logger.error(f"Error in get_standard_answer: {str(e)}", exc_info=True) raise HTTPException(status_code=500, detail="An internal server error occurred") ================================================ FILE: backend/ee/onyx/server/query_and_chat/search_backend.py ================================================ from collections.abc import Generator from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from ee.onyx.db.search import fetch_search_queries_for_user from ee.onyx.search.process_search_query import gather_search_stream from ee.onyx.search.process_search_query import stream_search_query from ee.onyx.secondary_llm_flows.search_flow_classification import ( classify_is_search_flow, ) from ee.onyx.server.query_and_chat.models import SearchFlowClassificationRequest from ee.onyx.server.query_and_chat.models import SearchFlowClassificationResponse from ee.onyx.server.query_and_chat.models import SearchFullResponse from ee.onyx.server.query_and_chat.models import SearchHistoryResponse from ee.onyx.server.query_and_chat.models import SearchQueryResponse from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest from ee.onyx.server.query_and_chat.streaming_models import 
SearchErrorPacket from onyx.auth.users import current_user from onyx.configs.app_configs import ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH from onyx.db.engine.sql_engine import get_session from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import User from onyx.llm.factory import get_default_llm from onyx.server.usage_limits import check_llm_cost_limit_for_provider from onyx.server.utils import get_json_line from onyx.server.utils_vector_db import require_vector_db from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/search") @router.post("/search-flow-classification") def search_flow_classification( request: SearchFlowClassificationRequest, _: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> SearchFlowClassificationResponse: query = request.user_query # This is a heuristic that if the user is typing a lot of text, it's unlikely they're looking for some specific document # Most likely something needs to be done with the text included so we'll just classify it as a chat flow if len(query) > 200: return SearchFlowClassificationResponse(is_search_flow=False) llm = get_default_llm() check_llm_cost_limit_for_provider( db_session=db_session, tenant_id=get_current_tenant_id(), llm_provider_api_key=llm.config.api_key, ) try: is_search_flow = classify_is_search_flow(query=query, llm=llm) except Exception as e: logger.exception( "Search flow classification failed; defaulting to chat flow", exc_info=e, ) is_search_flow = False return SearchFlowClassificationResponse(is_search_flow=is_search_flow) # NOTE: This endpoint is used for the core flow of the Onyx application, any # changes to it should be reviewed and approved by an experienced team member. # It is very important to 1. avoid bloat and 2. that this remains backwards # compatible across versions. 
@router.post( "/send-search-message", response_model=None, dependencies=[Depends(require_vector_db)], ) def handle_send_search_message( request: SendSearchQueryRequest, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> StreamingResponse | SearchFullResponse: """ Executes a search query with optional streaming. If hybrid_alpha is unset and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH is True, executes pure keyword search. Returns: StreamingResponse with SSE if stream=True, otherwise SearchFullResponse. """ logger.debug(f"Received search query: {request.search_query}") if request.hybrid_alpha is None and ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH: request.hybrid_alpha = 0.0 # Non-streaming path if not request.stream: try: packets = stream_search_query(request, user, db_session) return gather_search_stream(packets) except NotImplementedError as e: return SearchFullResponse( all_executed_queries=[], search_docs=[], error=str(e), ) # Streaming path def stream_generator() -> Generator[str, None, None]: try: with get_session_with_current_tenant() as streaming_db_session: for packet in stream_search_query(request, user, streaming_db_session): yield get_json_line(packet.model_dump()) except NotImplementedError as e: yield get_json_line(SearchErrorPacket(error=str(e)).model_dump()) except HTTPException: raise except Exception as e: logger.exception("Error in search streaming") yield get_json_line(SearchErrorPacket(error=str(e)).model_dump()) return StreamingResponse(stream_generator(), media_type="text/event-stream") @router.get("/search-history") def get_search_history( limit: int = 100, filter_days: int | None = None, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> SearchHistoryResponse: """ Fetch past search queries for the authenticated user. 
Args: limit: Maximum number of queries to return (default 100) filter_days: Only return queries from the last N days (optional) Returns: SearchHistoryResponse with list of search queries, ordered by most recent first. """ # Validate limit if limit <= 0: raise HTTPException( status_code=400, detail="limit must be greater than 0", ) if limit > 1000: raise HTTPException( status_code=400, detail="limit must be at most 1000", ) # Validate filter_days if filter_days is not None and filter_days <= 0: raise HTTPException( status_code=400, detail="filter_days must be greater than 0", ) search_queries = fetch_search_queries_for_user( db_session=db_session, user_id=user.id, filter_days=filter_days, limit=limit, ) return SearchHistoryResponse( search_queries=[ SearchQueryResponse( query=sq.query, query_expansions=sq.query_expansions, created_at=sq.created_at, ) for sq in search_queries ] ) ================================================ FILE: backend/ee/onyx/server/query_and_chat/streaming_models.py ================================================ from typing import Literal from pydantic import BaseModel from pydantic import ConfigDict from ee.onyx.server.query_and_chat.models import SearchDocWithContent class SearchQueriesPacket(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["search_queries"] = "search_queries" all_executed_queries: list[str] class SearchDocsPacket(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["search_docs"] = "search_docs" search_docs: list[SearchDocWithContent] class SearchErrorPacket(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["search_error"] = "search_error" error: str class LLMSelectedDocsPacket(BaseModel): model_config = ConfigDict(frozen=True) type: Literal["llm_selected_docs"] = "llm_selected_docs" # None if LLM selection failed, empty list if no docs selected, list of IDs otherwise llm_selected_doc_ids: list[str] | None ================================================ FILE: 
backend/ee/onyx/server/query_and_chat/token_limit.py ================================================ from collections import defaultdict from collections.abc import Sequence from datetime import datetime from itertools import groupby from typing import Dict from typing import List from typing import Tuple from uuid import UUID from fastapi import HTTPException from sqlalchemy import func from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.api_key import is_api_key_email_address from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import TokenRateLimit from onyx.db.models import TokenRateLimit__UserGroup from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup from onyx.db.token_limit import fetch_all_user_token_rate_limits from onyx.server.query_and_chat.token_limit import _get_cutoff_time from onyx.server.query_and_chat.token_limit import _is_rate_limited from onyx.server.query_and_chat.token_limit import _user_is_rate_limited_by_global from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel def _check_token_rate_limits(user: User) -> None: # Anonymous users are only rate limited by global settings if user.is_anonymous: _user_is_rate_limited_by_global() elif is_api_key_email_address(user.email): # API keys are only rate limited by global settings _user_is_rate_limited_by_global() else: run_functions_tuples_in_parallel( [ (_user_is_rate_limited, (user.id,)), (_user_is_rate_limited_by_group, (user.id,)), (_user_is_rate_limited_by_global, ()), ] ) """ User rate limits """ def _user_is_rate_limited(user_id: UUID) -> None: with get_session_with_current_tenant() as db_session: user_rate_limits = fetch_all_user_token_rate_limits( db_session=db_session, enabled_only=True, ordered=False ) if user_rate_limits: user_cutoff_time = 
_get_cutoff_time(user_rate_limits) user_usage = _fetch_user_usage(user_id, user_cutoff_time, db_session) if _is_rate_limited(user_rate_limits, user_usage): raise HTTPException( status_code=429, detail="Token budget exceeded for user. Try again later.", ) def _fetch_user_usage( user_id: UUID, cutoff_time: datetime, db_session: Session ) -> Sequence[tuple[datetime, int]]: """ Fetch user usage within the cutoff time, grouped by minute """ result = db_session.execute( select( func.date_trunc("minute", ChatMessage.time_sent), func.sum(ChatMessage.token_count), ) .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) .where(ChatSession.user_id == user_id, ChatMessage.time_sent >= cutoff_time) .group_by(func.date_trunc("minute", ChatMessage.time_sent)) ).all() return [(row[0], row[1]) for row in result] """ User Group rate limits """ def _user_is_rate_limited_by_group(user_id: UUID) -> None: with get_session_with_current_tenant() as db_session: group_rate_limits = _fetch_all_user_group_rate_limits(user_id, db_session) if group_rate_limits: # Group cutoff time is the same for all groups. # This could be optimized to only fetch the maximum cutoff time for # a specific group, but seems unnecessary for now. group_cutoff_time = _get_cutoff_time( [e for sublist in group_rate_limits.values() for e in sublist] ) user_group_ids = list(group_rate_limits.keys()) group_usage = _fetch_user_group_usage( user_group_ids, group_cutoff_time, db_session ) has_at_least_one_untriggered_limit = False for user_group_id, rate_limits in group_rate_limits.items(): usage = group_usage.get(user_group_id, []) if not _is_rate_limited(rate_limits, usage): has_at_least_one_untriggered_limit = True break if not has_at_least_one_untriggered_limit: raise HTTPException( status_code=429, detail="Token budget exceeded for user's groups. 
Try again later.", ) def _fetch_all_user_group_rate_limits( user_id: UUID, db_session: Session ) -> Dict[int, List[TokenRateLimit]]: group_limits = ( select(TokenRateLimit, User__UserGroup.user_group_id) .join( TokenRateLimit__UserGroup, TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id, ) .join( UserGroup, UserGroup.id == TokenRateLimit__UserGroup.user_group_id, ) .join( User__UserGroup, User__UserGroup.user_group_id == UserGroup.id, ) .where( User__UserGroup.user_id == user_id, TokenRateLimit.enabled.is_(True), ) ) raw_rate_limits = db_session.execute(group_limits).all() group_rate_limits = defaultdict(list) for rate_limit, user_group_id in raw_rate_limits: group_rate_limits[user_group_id].append(rate_limit) return group_rate_limits def _fetch_user_group_usage( user_group_ids: list[int], cutoff_time: datetime, db_session: Session ) -> dict[int, list[Tuple[datetime, int]]]: """ Fetch user group usage within the cutoff time, grouped by minute """ user_group_usage = db_session.execute( select( func.sum(ChatMessage.token_count), func.date_trunc("minute", ChatMessage.time_sent), UserGroup.id, ) .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) .join(User__UserGroup, User__UserGroup.user_id == ChatSession.user_id) .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id) .filter(UserGroup.id.in_(user_group_ids), ChatMessage.time_sent >= cutoff_time) .group_by(func.date_trunc("minute", ChatMessage.time_sent), UserGroup.id) # itertools.groupby only merges consecutive rows, so order by the group id .order_by(UserGroup.id) ).all() return { user_group_id: [(time_sent, usage) for usage, time_sent, _ in group_usage] for user_group_id, group_usage in groupby( user_group_usage, key=lambda row: row[2] ) } ================================================ FILE: backend/ee/onyx/server/query_history/api.py ================================================ import uuid from collections.abc import Generator from datetime import datetime from 
datetime import timezone from http import HTTPStatus from uuid import UUID from fastapi import APIRouter from
fastapi import Depends from fastapi import HTTPException from fastapi import Query from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from ee.onyx.background.task_name_builders import query_history_task_name from ee.onyx.db.query_history import get_all_query_history_export_tasks from ee.onyx.db.query_history import get_page_of_chat_sessions from ee.onyx.db.query_history import get_total_filtered_chat_sessions_count from ee.onyx.server.query_history.models import ChatSessionMinimal from ee.onyx.server.query_history.models import ChatSessionSnapshot from ee.onyx.server.query_history.models import MessageSnapshot from ee.onyx.server.query_history.models import QueryHistoryExport from onyx.auth.users import current_admin_user from onyx.auth.users import get_display_email from onyx.background.celery.versioned_apps.client import app as client_app from onyx.background.task_utils import construct_query_history_report_name from onyx.chat.chat_utils import create_chat_history_chain from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE from onyx.configs.constants import FileOrigin from onyx.configs.constants import FileType from onyx.configs.constants import MessageType from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import PUBLIC_API_TAGS from onyx.configs.constants import QAFeedbackType from onyx.configs.constants import QueryHistoryType from onyx.configs.constants import SessionType from onyx.db.chat import get_chat_session_by_id from onyx.db.chat import get_chat_sessions_by_user from onyx.db.engine.sql_engine import get_session from onyx.db.enums import TaskStatus from onyx.db.file_record import get_query_history_export_files from onyx.db.models import ChatSession from onyx.db.models import User from onyx.db.tasks import get_task_with_id from onyx.db.tasks import register_task from 
onyx.file_store.file_store import get_default_file_store from onyx.server.documents.models import PaginatedReturn from onyx.server.query_and_chat.models import ChatSessionDetails from onyx.server.query_and_chat.models import ChatSessionsResponse from onyx.utils.threadpool_concurrency import parallel_yield from shared_configs.contextvars import get_current_tenant_id router = APIRouter() ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid" def ensure_query_history_is_enabled( disallowed: list[QueryHistoryType], ) -> None: if ONYX_QUERY_HISTORY_TYPE in disallowed: raise HTTPException( status_code=HTTPStatus.FORBIDDEN, detail="Query history has been disabled by the administrator.", ) def yield_snapshot_from_chat_session( chat_session: ChatSession, db_session: Session, ) -> Generator[ChatSessionSnapshot | None]: yield snapshot_from_chat_session(chat_session=chat_session, db_session=db_session) def fetch_and_process_chat_session_history( db_session: Session, start: datetime, end: datetime, limit: int | None = 500, # noqa: ARG001 ) -> Generator[ChatSessionSnapshot]: PAGE_SIZE = 100 page = 0 while True: paged_chat_sessions = get_page_of_chat_sessions( start_time=start, end_time=end, db_session=db_session, page_num=page, page_size=PAGE_SIZE, ) if not paged_chat_sessions: break paged_snapshots = parallel_yield( [ yield_snapshot_from_chat_session( db_session=db_session, chat_session=chat_session, ) for chat_session in paged_chat_sessions ] ) for snapshot in paged_snapshots: if snapshot: yield snapshot # If we've fetched *less* than a `PAGE_SIZE` worth # of data, we have reached the end of the # pagination sequence; break. 
if len(paged_chat_sessions) < PAGE_SIZE: break page += 1 def snapshot_from_chat_session( chat_session: ChatSession, db_session: Session, ) -> ChatSessionSnapshot | None: try: # Older chats may not have the right structure messages = create_chat_history_chain( chat_session_id=chat_session.id, db_session=db_session ) except RuntimeError: return None flow_type = SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT return ChatSessionSnapshot( id=chat_session.id, user_email=get_display_email( chat_session.user.email if chat_session.user else None ), name=chat_session.description, messages=[ MessageSnapshot.build(message) for message in messages if message.message_type != MessageType.SYSTEM ], assistant_id=chat_session.persona_id, assistant_name=chat_session.persona.name if chat_session.persona else None, time_created=chat_session.time_created, flow_type=flow_type, ) @router.get("/admin/chat-sessions") def admin_get_chat_sessions( user_id: UUID, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ChatSessionsResponse: # we specifically don't allow this endpoint if "anonymized" since # this is a direct query on the user id ensure_query_history_is_enabled( [ QueryHistoryType.DISABLED, QueryHistoryType.ANONYMIZED, ] ) try: chat_sessions = get_chat_sessions_by_user( user_id=user_id, deleted=False, db_session=db_session, limit=0 ) except ValueError: raise ValueError("Chat session does not exist or has been deleted") return ChatSessionsResponse( sessions=[ ChatSessionDetails( id=chat.id, name=chat.description, persona_id=chat.persona_id, time_created=chat.time_created.isoformat(), time_updated=chat.time_updated.isoformat(), shared_status=chat.shared_status, current_alternate_model=chat.current_alternate_model, ) for chat in chat_sessions ] ) @router.get("/admin/chat-session-history") def get_chat_session_history( page_num: int = Query(0, ge=0), page_size: int = Query(10, ge=1), feedback_type: QAFeedbackType | None = None, 
start_time: datetime | None = None, end_time: datetime | None = None, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> PaginatedReturn[ChatSessionMinimal]: ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED]) page_of_chat_sessions = get_page_of_chat_sessions( page_num=page_num, page_size=page_size, db_session=db_session, start_time=start_time, end_time=end_time, feedback_filter=feedback_type, ) total_filtered_chat_sessions_count = get_total_filtered_chat_sessions_count( db_session=db_session, start_time=start_time, end_time=end_time, feedback_filter=feedback_type, ) minimal_chat_sessions: list[ChatSessionMinimal] = [] for chat_session in page_of_chat_sessions: minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session) if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED: minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL minimal_chat_sessions.append(minimal_chat_session) return PaginatedReturn( items=minimal_chat_sessions, total_items=total_filtered_chat_sessions_count, ) @router.get("/admin/chat-session-history/{chat_session_id}") def get_chat_session_admin( chat_session_id: UUID, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ChatSessionSnapshot: ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED]) try: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=None, # view chat regardless of user db_session=db_session, include_deleted=True, ) except ValueError: raise HTTPException( HTTPStatus.BAD_REQUEST, f"Chat session with id '{chat_session_id}' does not exist.", ) snapshot = snapshot_from_chat_session( chat_session=chat_session, db_session=db_session ) if snapshot is None: raise HTTPException( HTTPStatus.BAD_REQUEST, f"Could not create snapshot for chat session with id '{chat_session_id}'", ) if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED: snapshot.user_email = ONYX_ANONYMIZED_EMAIL 
return snapshot @router.get("/admin/query-history/list") def list_all_query_history_exports( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[QueryHistoryExport]: ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED]) try: pending_tasks = [ QueryHistoryExport.from_task(task) for task in get_all_query_history_export_tasks(db_session=db_session) ] generated_files = [ QueryHistoryExport.from_file(file) for file in get_query_history_export_files(db_session=db_session) ] merged = pending_tasks + generated_files # We sort based off of the start-time of the task. # We also return it in reverse order since viewing generated reports in most-recent to least-recent is most common. merged.sort(key=lambda task: task.start_time, reverse=True) return merged except Exception as e: raise HTTPException( HTTPStatus.INTERNAL_SERVER_ERROR, f"Failed to get all tasks: {e}" ) @router.post("/admin/query-history/start-export", tags=PUBLIC_API_TAGS) def start_query_history_export( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), start: datetime | None = None, end: datetime | None = None, ) -> dict[str, str]: ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED]) start = start or datetime.fromtimestamp(0, tz=timezone.utc) end = end or datetime.now(tz=timezone.utc) if start >= end: raise HTTPException( HTTPStatus.BAD_REQUEST, f"Start time must come before end time, but instead got the start time coming after; {start=} {end=}", ) task_id_uuid = uuid.uuid4() task_id = str(task_id_uuid) start_time = datetime.now(tz=timezone.utc) register_task( db_session=db_session, task_name=query_history_task_name(start=start, end=end), task_id=task_id, status=TaskStatus.PENDING, start_time=start_time, ) client_app.send_task( OnyxCeleryTask.EXPORT_QUERY_HISTORY_TASK, task_id=task_id, priority=OnyxCeleryPriority.MEDIUM, queue=OnyxCeleryQueues.CSV_GENERATION, kwargs={ "start": start, "end": end, 
"start_time": start_time, "tenant_id": get_current_tenant_id(), }, ) return {"request_id": task_id} @router.get("/admin/query-history/export-status", tags=PUBLIC_API_TAGS) def get_query_history_export_status( request_id: str, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> dict[str, str]: ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED]) task = get_task_with_id(db_session=db_session, task_id=request_id) if task: return {"status": task.status} # If task is None, then it's possible that the task has already finished processing. # Therefore, we should then check if the export file has already been stored inside of the file-store. # If that *also* doesn't exist, then we can return a 404. file_store = get_default_file_store() report_name = construct_query_history_report_name(request_id) has_file = file_store.has_file( file_id=report_name, file_origin=FileOrigin.QUERY_HISTORY_CSV, file_type=FileType.CSV, ) if not has_file: raise HTTPException( HTTPStatus.NOT_FOUND, f"No task with {request_id=} was found", ) return {"status": TaskStatus.SUCCESS} @router.get("/admin/query-history/download", tags=PUBLIC_API_TAGS) def download_query_history_csv( request_id: str, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> StreamingResponse: ensure_query_history_is_enabled(disallowed=[QueryHistoryType.DISABLED]) report_name = construct_query_history_report_name(request_id) file_store = get_default_file_store() has_file = file_store.has_file( file_id=report_name, file_origin=FileOrigin.QUERY_HISTORY_CSV, file_type=FileType.CSV, ) if has_file: try: csv_stream = file_store.read_file(report_name) except Exception as e: raise HTTPException( HTTPStatus.INTERNAL_SERVER_ERROR, f"Failed to read query history file: {str(e)}", ) csv_stream.seek(0) return StreamingResponse( iter(csv_stream), media_type=FileType.CSV, headers={"Content-Disposition": f"attachment;filename={report_name}"}, ) # If the 
file doesn't exist yet, it may still be processing. # Therefore, we check the task queue to determine its status, if there is any. task = get_task_with_id(db_session=db_session, task_id=request_id) if not task: raise HTTPException( HTTPStatus.NOT_FOUND, f"No task with {request_id=} was found", ) if task.status in [TaskStatus.STARTED, TaskStatus.PENDING]: raise HTTPException( HTTPStatus.ACCEPTED, f"Task with {request_id=} is still being worked on" ) elif task.status == TaskStatus.FAILURE: raise HTTPException( HTTPStatus.INTERNAL_SERVER_ERROR, f"Task with {request_id=} failed to be processed", ) else: # This is the final case in which `task.status == SUCCESS` raise RuntimeError( "The task was marked as successful, but the file was not found in the file store; this is an internal error." ) ================================================ FILE: backend/ee/onyx/server/query_history/models.py ================================================ from datetime import datetime from uuid import UUID from pydantic import BaseModel from ee.onyx.background.task_name_builders import QUERY_HISTORY_TASK_NAME_PREFIX from onyx.auth.users import get_display_email from onyx.background.task_utils import extract_task_id_from_query_history_report_name from onyx.configs.constants import MessageType from onyx.configs.constants import QAFeedbackType from onyx.configs.constants import SessionType from onyx.db.enums import TaskStatus from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import FileRecord from onyx.db.models import TaskQueueState class AbridgedSearchDoc(BaseModel): """A subset of the info present in `SearchDoc`""" document_id: str semantic_identifier: str link: str | None class MessageSnapshot(BaseModel): id: int message: str message_type: MessageType documents: list[AbridgedSearchDoc] feedback_type: QAFeedbackType | None feedback_text: str | None time_created: datetime @classmethod def build(cls, message: ChatMessage) -> 
"MessageSnapshot":
latest_messages_feedback_obj = ( message.chat_message_feedbacks[-1] if len(message.chat_message_feedbacks) > 0 else None ) feedback_type = ( ( QAFeedbackType.LIKE if latest_messages_feedback_obj.is_positive else QAFeedbackType.DISLIKE ) if latest_messages_feedback_obj else None ) feedback_text = ( latest_messages_feedback_obj.feedback_text if latest_messages_feedback_obj else None ) return cls( id=message.id, message=message.message, message_type=message.message_type, documents=[ AbridgedSearchDoc( document_id=document.document_id, semantic_identifier=document.semantic_id, link=document.link, ) for document in message.search_docs ], feedback_type=feedback_type, feedback_text=feedback_text, time_created=message.time_sent, ) class ChatSessionMinimal(BaseModel): id: UUID user_email: str name: str | None first_user_message: str first_ai_message: str assistant_id: int | None assistant_name: str | None time_created: datetime feedback_type: QAFeedbackType | None flow_type: SessionType conversation_length: int @classmethod def from_chat_session(cls, chat_session: ChatSession) -> "ChatSessionMinimal": first_user_message = next( ( message.message for message in chat_session.messages if message.message_type == MessageType.USER ), "", ) first_ai_message = next( ( message.message for message in chat_session.messages if message.message_type == MessageType.ASSISTANT ), "", ) list_of_message_feedbacks = [ feedback.is_positive for message in chat_session.messages for feedback in message.chat_message_feedbacks ] session_feedback_type = None if list_of_message_feedbacks: if all(list_of_message_feedbacks): session_feedback_type = QAFeedbackType.LIKE elif not any(list_of_message_feedbacks): session_feedback_type = QAFeedbackType.DISLIKE else: session_feedback_type = QAFeedbackType.MIXED return cls( id=chat_session.id, user_email=get_display_email( chat_session.user.email if chat_session.user else None ), name=chat_session.description, first_user_message=first_user_message, 
            first_ai_message=first_ai_message,
            assistant_id=chat_session.persona_id,
            assistant_name=(
                chat_session.persona.name if chat_session.persona else None
            ),
            time_created=chat_session.time_created,
            feedback_type=session_feedback_type,
            flow_type=(
                SessionType.SLACK if chat_session.onyxbot_flow else SessionType.CHAT
            ),
            conversation_length=len(
                [
                    message
                    for message in chat_session.messages
                    if message.message_type != MessageType.SYSTEM
                ]
            ),
        )


class ChatSessionSnapshot(BaseModel):
    id: UUID
    user_email: str
    name: str | None
    messages: list[MessageSnapshot]
    assistant_id: int | None
    assistant_name: str | None
    time_created: datetime
    flow_type: SessionType


class QuestionAnswerPairSnapshot(BaseModel):
    chat_session_id: UUID
    # 1-indexed message number in the chat_session
    # e.g. the first message pair in the chat_session is 1, the second is 2, etc.
    message_pair_num: int
    user_message: str
    ai_response: str
    retrieved_documents: list[AbridgedSearchDoc]
    feedback_type: QAFeedbackType | None
    feedback_text: str | None
    persona_name: str | None
    user_email: str
    time_created: datetime
    flow_type: SessionType

    @classmethod
    def from_chat_session_snapshot(
        cls,
        chat_session_snapshot: ChatSessionSnapshot,
    ) -> list["QuestionAnswerPairSnapshot"]:
        message_pairs: list[tuple[MessageSnapshot, MessageSnapshot]] = []
        for ind in range(1, len(chat_session_snapshot.messages), 2):
            message_pairs.append(
                (
                    chat_session_snapshot.messages[ind - 1],
                    chat_session_snapshot.messages[ind],
                )
            )

        return [
            cls(
                chat_session_id=chat_session_snapshot.id,
                message_pair_num=ind + 1,
                user_message=user_message.message,
                ai_response=ai_message.message,
                retrieved_documents=ai_message.documents,
                feedback_type=ai_message.feedback_type,
                feedback_text=ai_message.feedback_text,
                persona_name=chat_session_snapshot.assistant_name,
                user_email=get_display_email(chat_session_snapshot.user_email),
                time_created=user_message.time_created,
                flow_type=chat_session_snapshot.flow_type,
            )
            for ind, (user_message, ai_message) in enumerate(message_pairs)
        ]

    def to_json(self) -> dict[str, str | None]:
        return {
            "chat_session_id": str(self.chat_session_id),
            "message_pair_num": str(self.message_pair_num),
            "user_message": self.user_message,
            "ai_response": self.ai_response,
            "retrieved_documents": "|".join(
                [
                    doc.link or doc.semantic_identifier
                    for doc in self.retrieved_documents
                ]
            ),
            "feedback_type": self.feedback_type.value if self.feedback_type else "",
            "feedback_text": self.feedback_text or "",
            "persona_name": self.persona_name,
            "user_email": self.user_email,
            "time_created": str(self.time_created),
            "flow_type": self.flow_type,
        }


class QueryHistoryExport(BaseModel):
    task_id: str
    status: TaskStatus
    start: datetime
    end: datetime
    start_time: datetime

    @classmethod
    def from_task(
        cls,
        task_queue_state: TaskQueueState,
    ) -> "QueryHistoryExport":
        start_end = task_queue_state.task_name.removeprefix(
            f"{QUERY_HISTORY_TASK_NAME_PREFIX}_"
        )
        start, end = start_end.split("_")

        if not task_queue_state.start_time:
            raise RuntimeError("The start time of the task must always be present")

        return cls(
            task_id=task_queue_state.task_id,
            status=task_queue_state.status,
            start=datetime.fromisoformat(start),
            end=datetime.fromisoformat(end),
            start_time=task_queue_state.start_time,
        )

    @classmethod
    def from_file(
        cls,
        file: FileRecord,
    ) -> "QueryHistoryExport":
        if not file.file_metadata or not isinstance(file.file_metadata, dict):
            raise RuntimeError(
                "The file metadata must be non-null, and must be of type `dict[str, str]`"
            )

        metadata = QueryHistoryFileMetadata.model_validate(dict(file.file_metadata))
        task_id = extract_task_id_from_query_history_report_name(file.file_id)

        return cls(
            task_id=task_id,
            status=TaskStatus.SUCCESS,
            start=metadata.start,
            end=metadata.end,
            start_time=metadata.start_time,
        )


class QueryHistoryFileMetadata(BaseModel):
    start: datetime
    end: datetime
    start_time: datetime


================================================
FILE: backend/ee/onyx/server/reporting/usage_export_api.py
================================================
from collections.abc import Generator
from datetime import datetime

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session

from ee.onyx.db.usage_export import get_all_usage_reports
from ee.onyx.db.usage_export import get_usage_report_data
from ee.onyx.db.usage_export import UsageReportMetadata
from onyx.auth.users import current_admin_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.file_store.constants import STANDARD_CHUNK_SIZE
from shared_configs.contextvars import get_current_tenant_id

router = APIRouter()


class GenerateUsageReportParams(BaseModel):
    period_from: str | None = None
    period_to: str | None = None


@router.post("/admin/usage-report", status_code=204)
def generate_report(
    params: GenerateUsageReportParams,
    user: User = Depends(current_admin_user),
) -> None:
    # Validate period parameters
    if params.period_from and params.period_to:
        try:
            datetime.fromisoformat(params.period_from)
            datetime.fromisoformat(params.period_to)
        except ValueError as e:
            raise HTTPException(status_code=400, detail=str(e))

    tenant_id = get_current_tenant_id()
    client_app.send_task(
        OnyxCeleryTask.GENERATE_USAGE_REPORT_TASK,
        kwargs={
            "tenant_id": tenant_id,
            "user_id": str(user.id) if user else None,
            "period_from": params.period_from,
            "period_to": params.period_to,
        },
    )

    return None


@router.get("/admin/usage-report/{report_name}")
def read_usage_report(
    report_name: str,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),  # noqa: ARG001
) -> Response:
    try:
        file = get_usage_report_data(report_name)
    except (ValueError, RuntimeError) as e:
        raise HTTPException(status_code=404, detail=str(e))

    def iterfile() -> Generator[bytes, None, None]:
        while True:
            chunk = file.read(STANDARD_CHUNK_SIZE)
            if not chunk:
                break
            yield chunk

    return StreamingResponse(
        content=iterfile(),
        media_type="application/zip",
        headers={"Content-Disposition": f"attachment; filename={report_name}"},
    )


@router.get("/admin/usage-report")
def fetch_usage_reports(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[UsageReportMetadata]:
    try:
        return get_all_usage_reports(db_session)
    except ValueError as e:
        raise HTTPException(status_code=404, detail=str(e))


================================================
FILE: backend/ee/onyx/server/reporting/usage_export_generation.py
================================================
import csv
import tempfile
import uuid
import zipfile
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from fastapi_users_db_sqlalchemy import UUID_ID
from sqlalchemy import cast
from sqlalchemy.dialects.postgresql import UUID
from sqlalchemy.orm import Session

from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
from ee.onyx.db.usage_export import write_usage_report
from ee.onyx.server.reporting.usage_export_models import UsageReportMetadata
from ee.onyx.server.reporting.usage_export_models import UserSkeleton
from onyx.configs.constants import FileOrigin
from onyx.db.models import User
from onyx.db.users import get_all_users
from onyx.file_store.constants import MAX_IN_MEMORY_SIZE
from onyx.file_store.file_store import FileStore
from onyx.file_store.file_store import get_default_file_store


def generate_chat_messages_report(
    db_session: Session,
    file_store: FileStore,
    report_id: str,
    period: tuple[datetime, datetime] | None,
) -> str:
    file_name = f"{report_id}_chat_sessions"

    if period is None:
        period = (
            datetime.fromtimestamp(0, tz=timezone.utc),
            datetime.now(tz=timezone.utc),
        )
    else:
        # time-picker sends a time which is at the beginning of the day
        # so we need to add one day to the end time to make it inclusive
        period = (
            period[0],
            period[1] + timedelta(days=1),
        )

    with tempfile.SpooledTemporaryFile(
        max_size=MAX_IN_MEMORY_SIZE, mode="w+"
    ) as temp_file:
        csvwriter = csv.writer(temp_file, delimiter=",")
        csvwriter.writerow(
            [
                "session_id",
                "user_id",
                "flow_type",
                "time_sent",
                "assistant_name",
                "user_email",
                "number_of_tokens",
            ]
        )
        for chat_message_skeleton_batch in get_all_empty_chat_message_entries(
            db_session, period
        ):
            for chat_message_skeleton in chat_message_skeleton_batch:
                csvwriter.writerow(
                    [
                        chat_message_skeleton.chat_session_id,
                        chat_message_skeleton.user_id,
                        chat_message_skeleton.flow_type,
                        chat_message_skeleton.time_sent.isoformat(),
                        chat_message_skeleton.assistant_name,
                        chat_message_skeleton.user_email,
                        chat_message_skeleton.number_of_tokens,
                    ]
                )

        # after writing seek to beginning of buffer
        temp_file.seek(0)
        file_id = file_store.save_file(
            content=temp_file,
            display_name=file_name,
            file_origin=FileOrigin.GENERATED_REPORT,
            file_type="text/csv",
        )

    return file_id


def generate_user_report(
    db_session: Session,
    file_store: FileStore,
    report_id: str,
) -> str:
    file_name = f"{report_id}_users"

    with tempfile.SpooledTemporaryFile(
        max_size=MAX_IN_MEMORY_SIZE, mode="w+"
    ) as temp_file:
        csvwriter = csv.writer(temp_file, delimiter=",")
        csvwriter.writerow(["user_id", "is_active"])

        users = get_all_users(db_session)
        for user in users:
            user_skeleton = UserSkeleton(
                user_id=str(user.id),
                is_active=user.is_active,
            )
            csvwriter.writerow([user_skeleton.user_id, user_skeleton.is_active])

        temp_file.seek(0)
        file_id = file_store.save_file(
            content=temp_file,
            display_name=file_name,
            file_origin=FileOrigin.GENERATED_REPORT,
            file_type="text/csv",
        )

    return file_id


def create_new_usage_report(
    db_session: Session,
    user_id: UUID_ID | None,  # None = auto-generated
    period: tuple[datetime, datetime] | None,
) -> UsageReportMetadata:
    report_id = str(uuid.uuid4())
    file_store = get_default_file_store()

    messages_file_id = generate_chat_messages_report(
        db_session,
        file_store,
        report_id,
        period
    )
    users_file_id = generate_user_report(db_session, file_store, report_id)

    with tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) as zip_buffer:
        with zipfile.ZipFile(zip_buffer, "a", zipfile.ZIP_DEFLATED) as zip_file:
            # write messages
            chat_messages_tmpfile = file_store.read_file(
                messages_file_id, mode="b", use_tempfile=True
            )
            zip_file.writestr(
                "chat_messages.csv",
                chat_messages_tmpfile.read(),
            )

            # write users
            users_tmpfile = file_store.read_file(
                users_file_id, mode="b", use_tempfile=True
            )
            zip_file.writestr("users.csv", users_tmpfile.read())

        zip_buffer.seek(0)

        # store zip blob to file_store
        report_name = f"{datetime.now(tz=timezone.utc).strftime('%Y-%m-%d')}_{report_id}_usage_report.zip"
        file_store.save_file(
            content=zip_buffer,
            display_name=report_name,
            file_origin=FileOrigin.GENERATED_REPORT,
            file_type="application/zip",
            file_id=report_name,
        )

    # add report after zip file is written
    new_report = write_usage_report(db_session, report_name, user_id, period)

    # get user email
    requestor_user = (
        db_session.query(User)
        .filter(cast(User.id, UUID) == new_report.requestor_user_id)
        .one_or_none()
        if new_report.requestor_user_id
        else None
    )
    requestor_email = requestor_user.email if requestor_user else None

    return UsageReportMetadata(
        report_name=new_report.report_name,
        requestor=requestor_email,
        time_created=new_report.time_created,
        period_from=new_report.period_from,
        period_to=new_report.period_to,
    )


================================================
FILE: backend/ee/onyx/server/reporting/usage_export_models.py
================================================
from datetime import datetime
from enum import Enum
from uuid import UUID

from pydantic import BaseModel


class FlowType(str, Enum):
    CHAT = "chat"
    SLACK = "slack"


class ChatMessageSkeleton(BaseModel):
    message_id: int
    chat_session_id: UUID
    user_id: str | None
    flow_type: FlowType
    time_sent: datetime
    assistant_name: str | None
    user_email: str | None
    number_of_tokens: int


class UserSkeleton(BaseModel):
    user_id: str
    is_active: bool


class UsageReportMetadata(BaseModel):
    report_name: str
    requestor: str | None
    time_created: datetime
    period_from: datetime | None  # None = All time
    period_to: datetime | None


================================================
FILE: backend/ee/onyx/server/scim/__init__.py
================================================


================================================
FILE: backend/ee/onyx/server/scim/api.py
================================================
"""SCIM 2.0 API endpoints (RFC 7644).

This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.

Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""

from __future__ import annotations

from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import FastAPI
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi.responses import JSONResponse
from fastapi_users.password import PasswordHelper
from sqlalchemy import func
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session

from ee.onyx.db.scim import ScimDAL
from ee.onyx.server.scim.auth import ScimAuthError
from ee.onyx.server.scim.auth import verify_scim_token
from ee.onyx.server.scim.filtering import parse_scim_filter
from ee.onyx.server.scim.models import SCIM_LIST_RESPONSE_SCHEMA
from ee.onyx.server.scim.models import ScimError
from ee.onyx.server.scim.models import ScimGroupMember
from ee.onyx.server.scim.models import ScimGroupResource
from ee.onyx.server.scim.models import ScimListResponse
from ee.onyx.server.scim.models import ScimMappingFields
from ee.onyx.server.scim.models import ScimName
from ee.onyx.server.scim.models import ScimPatchRequest
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.models import ScimUserResource
from ee.onyx.server.scim.patch import apply_group_patch
from ee.onyx.server.scim.patch import apply_user_patch
from ee.onyx.server.scim.patch import ScimPatchError
from ee.onyx.server.scim.providers.base import get_default_provider
from ee.onyx.server.scim.providers.base import ScimProvider
from ee.onyx.server.scim.providers.base import serialize_emails
from ee.onyx.server.scim.schema_definitions import ENTERPRISE_USER_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()

# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})


class ScimJSONResponse(JSONResponse):
    """JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""

    media_type = "application/scim+json"


# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])

_pw_helper = PasswordHelper()


def register_scim_exception_handlers(app: FastAPI) -> None:
    """Register SCIM-specific exception handlers on the FastAPI app.

    Call this after ``app.include_router(scim_router)`` so that auth failures
    from ``verify_scim_token`` return RFC 7644 §3.12 error envelopes (with
    ``schemas`` and ``status`` fields) instead of FastAPI's default
    ``{"detail": "..."}`` format.
    """

    @app.exception_handler(ScimAuthError)
    async def _handle_scim_auth_error(
        _request: Request, exc: ScimAuthError
    ) -> ScimJSONResponse:
        return _scim_error_response(exc.status_code, exc.detail)


def _get_provider(
    _token: ScimToken = Depends(verify_scim_token),
) -> ScimProvider:
    """Resolve the SCIM provider for the current request.

    Currently returns OktaProvider for all requests. When multi-provider
    support is added (ENG-3652), this will resolve based on token metadata
    or tenant configuration — no endpoint changes required.
    """
    return get_default_provider()


# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------


@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
    """Advertise supported SCIM features (RFC 7643 §5)."""
    return SERVICE_PROVIDER_CONFIG


@scim_router.get("/ResourceTypes")
def get_resource_types() -> ScimJSONResponse:
    """List available SCIM resource types (RFC 7643 §6).

    Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
    like Entra ID expect a JSON object, not a bare array.
    """
    resources = [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]
    return ScimJSONResponse(
        content={
            "schemas": [SCIM_LIST_RESPONSE_SCHEMA],
            "totalResults": len(resources),
            "Resources": [
                r.model_dump(exclude_none=True, by_alias=True) for r in resources
            ],
        }
    )


@scim_router.get("/Schemas")
def get_schemas() -> ScimJSONResponse:
    """Return SCIM schema definitions (RFC 7643 §7).

    Wrapped in a ListResponse envelope (RFC 7644 §3.4.2) because IdPs
    like Entra ID expect a JSON object, not a bare array.
    """
    schemas = [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF, ENTERPRISE_USER_SCHEMA_DEF]
    return ScimJSONResponse(
        content={
            "schemas": [SCIM_LIST_RESPONSE_SCHEMA],
            "totalResults": len(schemas),
            "Resources": [s.model_dump(exclude_none=True) for s in schemas],
        }
    )


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _scim_error_response(status: int, detail: str) -> ScimJSONResponse:
    """Build a SCIM-compliant error response (RFC 7644 §3.12)."""
    logger.warning("SCIM error response: status=%s detail=%s", status, detail)
    body = ScimError(status=str(status), detail=detail)
    return ScimJSONResponse(
        status_code=status,
        content=body.model_dump(exclude_none=True),
    )


def _parse_excluded_attributes(raw: str | None) -> set[str]:
    """Parse the ``excludedAttributes`` query parameter (RFC 7644 §3.4.2.5).

    Returns a set of lowercased attribute names to omit from responses.
    """
    if not raw:
        return set()
    return {attr.strip().lower() for attr in raw.split(",") if attr.strip()}


def _apply_exclusions(
    resource: ScimUserResource | ScimGroupResource,
    excluded: set[str],
) -> dict:
    """Serialize a SCIM resource, omitting attributes the IdP excluded.

    RFC 7644 §3.4.2.5 lets the IdP pass ``?excludedAttributes=groups,emails``
    to reduce response payload size. We strip those fields after serialization
    so the rest of the pipeline doesn't need to know about them.
    """
    data = resource.model_dump(exclude_none=True, by_alias=True)
    for attr in excluded:
        # Match case-insensitively against the camelCase field names
        keys_to_remove = [k for k in data if k.lower() == attr]
        for k in keys_to_remove:
            del data[k]
    return data


def _check_seat_availability(dal: ScimDAL) -> str | None:
    """Return an error message if seat limit is reached, else None."""
    check_fn = fetch_ee_implementation_or_noop(
        "onyx.db.license", "check_seat_availability", None
    )
    if check_fn is None:
        return None
    result = check_fn(dal.session, seats_needed=1)
    if not result.available:
        return result.error_message or "Seat limit reached"
    return None


def _fetch_user_or_404(user_id: str, dal: ScimDAL) -> User | ScimJSONResponse:
    """Parse *user_id* as UUID, look up the user, or return a 404 error."""
    try:
        uid = UUID(user_id)
    except ValueError:
        return _scim_error_response(404, f"User {user_id} not found")
    user = dal.get_user(uid)
    if not user:
        return _scim_error_response(404, f"User {user_id} not found")
    return user


def _scim_name_to_str(name: ScimName | None) -> str | None:
    """Extract a display name string from a SCIM name object.

    Returns None if no name is provided, so the caller can decide whether
    to update the user's personal_name.
    """
    if not name:
        return None
    # If the client explicitly provides ``formatted``, prefer it — the client
    # knows what display string it wants. Otherwise build from components.
    if name.formatted:
        return name.formatted
    parts = " ".join(part for part in [name.givenName, name.familyName] if part)
    return parts or None


def _scim_resource_response(
    resource: ScimUserResource | ScimGroupResource | ScimListResponse,
    status_code: int = 200,
) -> ScimJSONResponse:
    """Serialize a SCIM resource as ``application/scim+json``."""
    content = resource.model_dump(exclude_none=True, by_alias=True)
    return ScimJSONResponse(
        status_code=status_code,
        content=content,
    )


def _build_list_response(
    resources: list[ScimUserResource | ScimGroupResource],
    total: int,
    start_index: int,
    count: int,
    excluded: set[str] | None = None,
) -> ScimListResponse | ScimJSONResponse:
    """Build a SCIM list response, optionally applying attribute exclusions.

    RFC 7644 §3.4.2.5 — IdPs may request certain attributes be omitted via
    the ``excludedAttributes`` query parameter.
    """
    if excluded:
        envelope = ScimListResponse(
            totalResults=total,
            startIndex=start_index,
            itemsPerPage=count,
        )
        data = envelope.model_dump(exclude_none=True)
        data["Resources"] = [_apply_exclusions(r, excluded) for r in resources]
        return ScimJSONResponse(content=data)

    return _scim_resource_response(
        ScimListResponse(
            totalResults=total,
            startIndex=start_index,
            itemsPerPage=count,
            Resources=resources,
        )
    )


def _extract_enterprise_fields(
    resource: ScimUserResource,
) -> tuple[str | None, str | None]:
    """Extract department and manager from enterprise extension."""
    ext = resource.enterprise_extension
    if not ext:
        return None, None
    department = ext.department
    manager = ext.manager.value if ext.manager else None
    return department, manager


def _mapping_to_fields(
    mapping: ScimUserMapping | None,
) -> ScimMappingFields | None:
    """Extract round-trip fields from a SCIM user mapping."""
    if not mapping:
        return None
    return ScimMappingFields(
        department=mapping.department,
        manager=mapping.manager,
        given_name=mapping.given_name,
        family_name=mapping.family_name,
        scim_emails_json=mapping.scim_emails_json,
    )


def _fields_from_resource(resource: ScimUserResource) -> ScimMappingFields:
    """Build mapping fields from an incoming SCIM user resource."""
    department, manager = _extract_enterprise_fields(resource)
    return ScimMappingFields(
        department=department,
        manager=manager,
        given_name=resource.name.givenName if resource.name else None,
        family_name=resource.name.familyName if resource.name else None,
        scim_emails_json=serialize_emails(resource.emails),
    )


# ---------------------------------------------------------------------------
# User CRUD (RFC 7644 §3)
# ---------------------------------------------------------------------------


@scim_router.get("/Users", response_model=None)
def list_users(
    filter: str | None = Query(None),
    excludedAttributes: str | None = None,
    startIndex: int = Query(1, ge=1),
    count: int = Query(100, ge=0, le=500),
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimListResponse | ScimJSONResponse:
    """List users with optional SCIM filter and pagination."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)
    dal.commit()

    try:
        scim_filter = parse_scim_filter(filter)
    except ValueError as e:
        return _scim_error_response(400, str(e))

    try:
        users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
    except ValueError as e:
        return _scim_error_response(400, str(e))

    user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
    resources: list[ScimUserResource | ScimGroupResource] = [
        provider.build_user_resource(
            user,
            mapping.external_id if mapping else None,
            groups=user_groups_map.get(user.id, []),
            scim_username=mapping.scim_username if mapping else None,
            fields=_mapping_to_fields(mapping),
        )
        for user, mapping in users_with_mappings
    ]

    return _build_list_response(
        resources,
        total,
        startIndex,
        count,
        excluded=_parse_excluded_attributes(excludedAttributes),
    )


@scim_router.get("/Users/{user_id}", response_model=None)
def get_user(
    user_id: str,
    excludedAttributes: str | None = None,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Get a single user by ID."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)
    dal.commit()

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    mapping = dal.get_user_mapping_by_user_id(user.id)

    resource = provider.build_user_resource(
        user,
        mapping.external_id if mapping else None,
        groups=dal.get_user_groups(user.id),
        scim_username=mapping.scim_username if mapping else None,
        fields=_mapping_to_fields(mapping),
    )

    # RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted
    excluded = _parse_excluded_attributes(excludedAttributes)
    if excluded:
        return ScimJSONResponse(content=_apply_exclusions(resource, excluded))

    return _scim_resource_response(resource)


@scim_router.post("/Users", status_code=201, response_model=None)
def create_user(
    user_resource: ScimUserResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Create a new user from a SCIM provisioning request."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    email = user_resource.userName.strip()

    # Check for existing user — if they exist but aren't SCIM-managed yet,
    # link them to the IdP rather than rejecting with 409.
    external_id: str | None = user_resource.externalId
    scim_username: str = user_resource.userName.strip()
    fields: ScimMappingFields = _fields_from_resource(user_resource)

    existing_user = dal.get_user_by_email(email)
    if existing_user:
        existing_mapping = dal.get_user_mapping_by_user_id(existing_user.id)
        if existing_mapping:
            return _scim_error_response(409, f"User with email {email} already exists")

        # Adopt pre-existing user into SCIM management.
        # Reactivating a deactivated user consumes a seat, so enforce the
        # seat limit the same way replace_user does.
        if user_resource.active and not existing_user.is_active:
            seat_error = _check_seat_availability(dal)
            if seat_error:
                return _scim_error_response(403, seat_error)

        personal_name = _scim_name_to_str(user_resource.name)
        dal.update_user(
            existing_user,
            is_active=user_resource.active,
            **({"personal_name": personal_name} if personal_name else {}),
        )

        try:
            dal.create_user_mapping(
                external_id=external_id,
                user_id=existing_user.id,
                scim_username=scim_username,
                fields=fields,
            )
            dal.commit()
        except IntegrityError:
            dal.rollback()
            return _scim_error_response(
                409, f"User with email {email} already has a SCIM mapping"
            )

        return _scim_resource_response(
            provider.build_user_resource(
                existing_user,
                external_id,
                scim_username=scim_username,
                fields=fields,
            ),
            status_code=201,
        )

    # Only enforce seat limit for net-new users — adopting a pre-existing
    # user doesn't consume a new seat.
    seat_error = _check_seat_availability(dal)
    if seat_error:
        return _scim_error_response(403, seat_error)

    # Create user with a random password (SCIM users authenticate via IdP)
    personal_name = _scim_name_to_str(user_resource.name)
    user = User(
        email=email,
        hashed_password=_pw_helper.hash(_pw_helper.generate()),
        role=UserRole.BASIC,
        account_type=AccountType.STANDARD,
        is_active=user_resource.active,
        is_verified=True,
        personal_name=personal_name,
    )

    try:
        dal.add_user(user)
    except IntegrityError:
        dal.rollback()
        return _scim_error_response(409, f"User with email {email} already exists")

    # Always create a SCIM mapping so that the user is marked as
    # SCIM-managed. externalId may be None (RFC 7643 says it's optional).
    try:
        dal.create_user_mapping(
            external_id=external_id,
            user_id=user.id,
            scim_username=scim_username,
            fields=fields,
        )
    except IntegrityError:
        dal.rollback()
        return _scim_error_response(
            409, f"User with email {email} already has a SCIM mapping"
        )

    # Assign user to default group BEFORE commit so everything is atomic.
    # If this fails, the entire user creation rolls back and IdP can retry.
    try:
        assign_user_to_default_groups__no_commit(db_session, user)
    except Exception:
        dal.rollback()
        logger.exception(f"Failed to assign SCIM user {email} to default groups")
        return _scim_error_response(
            500, f"Failed to assign user {email} to default group"
        )

    dal.commit()

    return _scim_resource_response(
        provider.build_user_resource(
            user,
            external_id,
            scim_username=scim_username,
            fields=fields,
        ),
        status_code=201,
    )


@scim_router.put("/Users/{user_id}", response_model=None)
def replace_user(
    user_id: str,
    user_resource: ScimUserResource,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Replace a user entirely (RFC 7644 §3.5.1)."""
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    # Handle activation (need seat check) / deactivation
    is_reactivation = user_resource.active and not user.is_active
    if is_reactivation:
        seat_error = _check_seat_availability(dal)
        if seat_error:
            return _scim_error_response(403, seat_error)

    personal_name = _scim_name_to_str(user_resource.name)

    dal.update_user(
        user,
        email=user_resource.userName.strip(),
        is_active=user_resource.active,
        personal_name=personal_name,
    )

    # Reconcile default-group membership on reactivation
    if is_reactivation:
        assign_user_to_default_groups__no_commit(
            db_session, user, is_admin=(user.role == UserRole.ADMIN)
        )

    new_external_id = user_resource.externalId
    scim_username = user_resource.userName.strip()
    fields = _fields_from_resource(user_resource)

    dal.sync_user_external_id(
        user.id,
        new_external_id,
        scim_username=scim_username,
        fields=fields,
    )

    dal.commit()

    return _scim_resource_response(
        provider.build_user_resource(
            user,
            new_external_id,
            groups=dal.get_user_groups(user.id),
            scim_username=scim_username,
            fields=fields,
        )
    )


@scim_router.patch("/Users/{user_id}", response_model=None)
def patch_user(
    user_id: str,
    patch_request: ScimPatchRequest,
    _token: ScimToken = Depends(verify_scim_token),
    provider: ScimProvider = Depends(_get_provider),
    db_session: Session = Depends(get_session),
) -> ScimUserResource | ScimJSONResponse:
    """Partially update a user (RFC 7644 §3.5.2).

    This is the primary endpoint for user deprovisioning — Okta sends
    ``PATCH {"active": false}`` rather than DELETE.
    """
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    result = _fetch_user_or_404(user_id, dal)
    if isinstance(result, ScimJSONResponse):
        return result
    user = result

    mapping = dal.get_user_mapping_by_user_id(user.id)
    external_id = mapping.external_id if mapping else None
    current_scim_username = mapping.scim_username if mapping else None
    current_fields = _mapping_to_fields(mapping)

    current = provider.build_user_resource(
        user,
        external_id,
        groups=dal.get_user_groups(user.id),
        scim_username=current_scim_username,
        fields=current_fields,
    )

    try:
        patched, ent_data = apply_user_patch(
            patch_request.Operations, current, provider.ignored_patch_paths
        )
    except ScimPatchError as e:
        return _scim_error_response(e.status, e.detail)

    # Apply changes back to the DB model
    is_reactivation = patched.active and not user.is_active
    if patched.active != user.is_active:
        if patched.active:
            seat_error = _check_seat_availability(dal)
            if seat_error:
                return _scim_error_response(403, seat_error)

    # Track the scim_username — if userName was patched, update it
    new_scim_username = patched.userName.strip() if patched.userName else None

    # If displayName was explicitly patched (different from the original), use
    # it as personal_name directly. Otherwise, derive from name components.
personal_name: str | None if patched.displayName and patched.displayName != current.displayName: personal_name = patched.displayName else: personal_name = _scim_name_to_str(patched.name) dal.update_user( user, email=( patched.userName.strip() if patched.userName.strip().lower() != user.email.lower() else None ), is_active=patched.active if patched.active != user.is_active else None, personal_name=personal_name, ) # Reconcile default-group membership on reactivation if is_reactivation: assign_user_to_default_groups__no_commit( db_session, user, is_admin=(user.role == UserRole.ADMIN) ) # Build updated fields by merging PATCH enterprise data with current values cf = current_fields or ScimMappingFields() fields = ScimMappingFields( department=ent_data.get("department", cf.department), manager=ent_data.get("manager", cf.manager), given_name=patched.name.givenName if patched.name else cf.given_name, family_name=patched.name.familyName if patched.name else cf.family_name, scim_emails_json=( serialize_emails(patched.emails) if patched.emails is not None else cf.scim_emails_json ), ) dal.sync_user_external_id( user.id, patched.externalId, scim_username=new_scim_username, fields=fields, ) dal.commit() return _scim_resource_response( provider.build_user_resource( user, patched.externalId, groups=dal.get_user_groups(user.id), scim_username=new_scim_username, fields=fields, ) ) @scim_router.delete("/Users/{user_id}", status_code=204, response_model=None) def delete_user( user_id: str, _token: ScimToken = Depends(verify_scim_token), db_session: Session = Depends(get_session), ) -> Response | ScimJSONResponse: """Delete a user (RFC 7644 §3.6). Deactivates the user and removes the SCIM mapping. Note that Okta typically uses PATCH active=false instead of DELETE. A second DELETE returns 404 per RFC 7644 §3.6. 
""" dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) result = _fetch_user_or_404(user_id, dal) if isinstance(result, ScimJSONResponse): return result user = result # If no SCIM mapping exists, the user was already deleted from # SCIM's perspective — return 404 per RFC 7644 §3.6. mapping = dal.get_user_mapping_by_user_id(user.id) if not mapping: return _scim_error_response(404, f"User {user_id} not found") dal.deactivate_user(user) dal.delete_user_mapping(mapping.id) dal.commit() return Response(status_code=204) # --------------------------------------------------------------------------- # Group helpers # --------------------------------------------------------------------------- def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | ScimJSONResponse: """Parse *group_id* as int, look up the group, or return a 404 error.""" try: gid = int(group_id) except ValueError: return _scim_error_response(404, f"Group {group_id} not found") group = dal.get_group(gid) if not group: return _scim_error_response(404, f"Group {group_id} not found") return group def _parse_member_uuids( members: list[ScimGroupMember], ) -> tuple[list[UUID], str | None]: """Parse member value strings to UUIDs. Returns (uuid_list, error_message). error_message is None on success. """ uuids: list[UUID] = [] for m in members: try: uuids.append(UUID(m.value)) except ValueError: return [], f"Invalid member ID: {m.value}" return uuids, None def _validate_and_parse_members( members: list[ScimGroupMember], dal: ScimDAL ) -> tuple[list[UUID], str | None]: """Parse and validate member UUIDs exist in the database. Returns (uuid_list, error_message). error_message is None on success. 
""" uuids, err = _parse_member_uuids(members) if err: return [], err if uuids: missing = dal.validate_member_ids(uuids) if missing: return [], f"Member(s) not found: {', '.join(str(u) for u in missing)}" return uuids, None # --------------------------------------------------------------------------- # Group CRUD (RFC 7644 §3) # --------------------------------------------------------------------------- @scim_router.get("/Groups", response_model=None) def list_groups( filter: str | None = Query(None), excludedAttributes: str | None = None, startIndex: int = Query(1, ge=1), count: int = Query(100, ge=0, le=500), _token: ScimToken = Depends(verify_scim_token), provider: ScimProvider = Depends(_get_provider), db_session: Session = Depends(get_session), ) -> ScimListResponse | ScimJSONResponse: """List groups with optional SCIM filter and pagination.""" dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) dal.commit() try: scim_filter = parse_scim_filter(filter) except ValueError as e: return _scim_error_response(400, str(e)) try: groups_with_ext_ids, total = dal.list_groups(scim_filter, startIndex, count) except ValueError as e: return _scim_error_response(400, str(e)) resources: list[ScimUserResource | ScimGroupResource] = [ provider.build_group_resource(group, dal.get_group_members(group.id), ext_id) for group, ext_id in groups_with_ext_ids ] return _build_list_response( resources, total, startIndex, count, excluded=_parse_excluded_attributes(excludedAttributes), ) @scim_router.get("/Groups/{group_id}", response_model=None) def get_group( group_id: str, excludedAttributes: str | None = None, _token: ScimToken = Depends(verify_scim_token), provider: ScimProvider = Depends(_get_provider), db_session: Session = Depends(get_session), ) -> ScimGroupResource | ScimJSONResponse: """Get a single group by ID.""" dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) dal.commit() result = _fetch_group_or_404(group_id, dal) if isinstance(result, 
ScimJSONResponse): return result group = result mapping = dal.get_group_mapping_by_group_id(group.id) members = dal.get_group_members(group.id) resource = provider.build_group_resource( group, members, mapping.external_id if mapping else None ) # RFC 7644 §3.4.2.5 — IdP may request certain attributes be omitted excluded = _parse_excluded_attributes(excludedAttributes) if excluded: return ScimJSONResponse(content=_apply_exclusions(resource, excluded)) return _scim_resource_response(resource) @scim_router.post("/Groups", status_code=201, response_model=None) def create_group( group_resource: ScimGroupResource, _token: ScimToken = Depends(verify_scim_token), provider: ScimProvider = Depends(_get_provider), db_session: Session = Depends(get_session), ) -> ScimGroupResource | ScimJSONResponse: """Create a new group from a SCIM provisioning request.""" dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) if group_resource.displayName in _RESERVED_GROUP_NAMES: return _scim_error_response( 409, f"'{group_resource.displayName}' is a reserved group name." ) if dal.get_group_by_name(group_resource.displayName): return _scim_error_response( 409, f"Group with name '{group_resource.displayName}' already exists" ) member_uuids, err = _validate_and_parse_members(group_resource.members, dal) if err: return _scim_error_response(400, err) db_group = UserGroup( name=group_resource.displayName, is_up_to_date=True, time_last_modified_by_user=func.now(), ) try: dal.add_group(db_group) except IntegrityError: dal.rollback() return _scim_error_response( 409, f"Group with name '{group_resource.displayName}' already exists" ) # Every group gets the "basic" permission by default. dal.add_permission_grant_to_group( group_id=db_group.id, permission=Permission.BASIC_ACCESS, grant_source=GrantSource.SYSTEM, ) dal.upsert_group_members(db_group.id, member_uuids) # Recompute permissions for initial members. 
recompute_user_permissions__no_commit(member_uuids, db_session) external_id = group_resource.externalId if external_id: dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id) dal.commit() members = dal.get_group_members(db_group.id) return _scim_resource_response( provider.build_group_resource(db_group, members, external_id), status_code=201, ) @scim_router.put("/Groups/{group_id}", response_model=None) def replace_group( group_id: str, group_resource: ScimGroupResource, _token: ScimToken = Depends(verify_scim_token), provider: ScimProvider = Depends(_get_provider), db_session: Session = Depends(get_session), ) -> ScimGroupResource | ScimJSONResponse: """Replace a group entirely (RFC 7644 §3.5.1).""" dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) result = _fetch_group_or_404(group_id, dal) if isinstance(result, ScimJSONResponse): return result group = result if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name: return _scim_error_response( 409, f"'{group.name}' is a reserved group name and cannot be renamed." ) if ( group_resource.displayName in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name ): return _scim_error_response( 409, f"'{group_resource.displayName}' is a reserved group name." ) member_uuids, err = _validate_and_parse_members(group_resource.members, dal) if err: return _scim_error_response(400, err) # Capture old member IDs before replacing so we can recompute their # permissions after they are removed from the group. old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)} dal.update_group(group, name=group_resource.displayName) dal.replace_group_members(group.id, member_uuids) dal.sync_group_external_id(group.id, group_resource.externalId) # Recompute permissions for current members (batch) and removed members. 
recompute_permissions_for_group__no_commit(group.id, db_session) removed_ids = list(old_member_ids - set(member_uuids)) recompute_user_permissions__no_commit(removed_ids, db_session) dal.commit() members = dal.get_group_members(group.id) return _scim_resource_response( provider.build_group_resource(group, members, group_resource.externalId) ) @scim_router.patch("/Groups/{group_id}", response_model=None) def patch_group( group_id: str, patch_request: ScimPatchRequest, _token: ScimToken = Depends(verify_scim_token), provider: ScimProvider = Depends(_get_provider), db_session: Session = Depends(get_session), ) -> ScimGroupResource | ScimJSONResponse: """Partially update a group (RFC 7644 §3.5.2). Handles member add/remove operations from Okta and Azure AD. """ dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) result = _fetch_group_or_404(group_id, dal) if isinstance(result, ScimJSONResponse): return result group = result mapping = dal.get_group_mapping_by_group_id(group.id) external_id = mapping.external_id if mapping else None current_members = dal.get_group_members(group.id) current = provider.build_group_resource(group, current_members, external_id) try: patched, added_ids, removed_ids = apply_group_patch( patch_request.Operations, current, provider.ignored_patch_paths ) except ScimPatchError as e: return _scim_error_response(e.status, e.detail) new_name = patched.displayName if patched.displayName != group.name else None if group.name in _RESERVED_GROUP_NAMES and new_name: return _scim_error_response( 409, f"'{group.name}' is a reserved group name and cannot be renamed." 
) if new_name and new_name in _RESERVED_GROUP_NAMES: return _scim_error_response(409, f"'{new_name}' is a reserved group name.") dal.update_group(group, name=new_name) affected_uuids: list[UUID] = [] if added_ids: add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)] if add_uuids: missing = dal.validate_member_ids(add_uuids) if missing: return _scim_error_response( 400, f"Member(s) not found: {', '.join(str(u) for u in missing)}", ) dal.upsert_group_members(group.id, add_uuids) affected_uuids.extend(add_uuids) if removed_ids: remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)] dal.remove_group_members(group.id, remove_uuids) affected_uuids.extend(remove_uuids) # Recompute permissions for all users whose group membership changed. recompute_user_permissions__no_commit(affected_uuids, db_session) dal.sync_group_external_id(group.id, patched.externalId) dal.commit() members = dal.get_group_members(group.id) return _scim_resource_response( provider.build_group_resource(group, members, patched.externalId) ) @scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None) def delete_group( group_id: str, _token: ScimToken = Depends(verify_scim_token), db_session: Session = Depends(get_session), ) -> Response | ScimJSONResponse: """Delete a group (RFC 7644 §3.6).""" dal = ScimDAL(db_session) dal.update_token_last_used(_token.id) result = _fetch_group_or_404(group_id, dal) if isinstance(result, ScimJSONResponse): return result group = result if group.name in _RESERVED_GROUP_NAMES: return _scim_error_response(409, f"'{group.name}' is a reserved group name.") # Capture member IDs before deletion so we can recompute their permissions. affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)] mapping = dal.get_group_mapping_by_group_id(group.id) if mapping: dal.delete_group_mapping(mapping.id) dal.delete_group_with_members(group) # Recompute permissions for users who lost this group membership. 
recompute_user_permissions__no_commit(affected_user_ids, db_session) dal.commit() return Response(status_code=204) def _is_valid_uuid(value: str) -> bool: """Check if a string is a valid UUID.""" try: UUID(value) return True except ValueError: return False ================================================ FILE: backend/ee/onyx/server/scim/auth.py ================================================ """SCIM bearer token authentication. SCIM endpoints are authenticated via bearer tokens that admins create in the Onyx UI. This module provides: - ``verify_scim_token``: FastAPI dependency that extracts, hashes, and validates the token from the Authorization header. - ``generate_scim_token``: Creates a new cryptographically random token and returns the raw value, its SHA-256 hash, and a masked display string. Token format: ``onyx_scim_`` followed by the output of ``secrets.token_urlsafe(48)``, i.e. 48 random bytes encoded as URL-safe base64 (64 characters). The hash is stored in the ``scim_token`` table; the raw value is shown to the admin exactly once at creation time. """ import hashlib import secrets from fastapi import Depends from fastapi import Request from sqlalchemy.orm import Session from ee.onyx.db.scim import ScimDAL from onyx.auth.utils import get_hashed_bearer_token_from_request from onyx.db.engine.sql_engine import get_session from onyx.db.models import ScimToken class ScimAuthError(Exception): """Raised when SCIM bearer token authentication fails. Unlike HTTPException, this carries the status and detail so the SCIM exception handler can wrap them in an RFC 7644 §3.12 error envelope with ``schemas`` and ``status`` fields. """ def __init__(self, status_code: int, detail: str) -> None: self.status_code = status_code self.detail = detail super().__init__(detail) SCIM_TOKEN_PREFIX = "onyx_scim_" SCIM_TOKEN_LENGTH = 48 def _hash_scim_token(token: str) -> str: """SHA-256 hash a SCIM token.
No salt needed — tokens are random.""" return hashlib.sha256(token.encode("utf-8")).hexdigest() def generate_scim_token() -> tuple[str, str, str]: """Generate a new SCIM bearer token. Returns: A tuple of ``(raw_token, hashed_token, token_display)`` where ``token_display`` is a masked version showing only the prefix and the last 4 characters. """ raw_token = SCIM_TOKEN_PREFIX + secrets.token_urlsafe(SCIM_TOKEN_LENGTH) hashed_token = _hash_scim_token(raw_token) token_display = SCIM_TOKEN_PREFIX + "****" + raw_token[-4:] return raw_token, hashed_token, token_display def _get_hashed_scim_token_from_request(request: Request) -> str | None: """Extract and hash a SCIM token from the request Authorization header.""" return get_hashed_bearer_token_from_request( request, valid_prefixes=[SCIM_TOKEN_PREFIX], hash_fn=_hash_scim_token, ) def _get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL: return ScimDAL(db_session) def verify_scim_token( request: Request, dal: ScimDAL = Depends(_get_scim_dal), ) -> ScimToken: """FastAPI dependency that authenticates SCIM requests. Extracts the bearer token from the Authorization header, hashes it, looks it up in the database, and verifies it is active. Note: This dependency does NOT update ``last_used_at`` — the endpoint should do that via ``ScimDAL.update_token_last_used()`` so the timestamp write is part of the endpoint's transaction. Raises: ScimAuthError(401): If the token is missing, invalid, or revoked. """ hashed = _get_hashed_scim_token_from_request(request) if not hashed: raise ScimAuthError(401, "Missing or invalid SCIM bearer token") token = dal.get_token_by_hash(hashed) if not token: raise ScimAuthError(401, "Invalid SCIM bearer token") if not token.is_active: raise ScimAuthError(401, "SCIM token has been revoked") return token ================================================ FILE: backend/ee/onyx/server/scim/filtering.py ================================================ """SCIM filter expression parser (RFC 7644 §3.4.2.2).
Identity providers (Okta, Azure AD, OneLogin, etc.) use filters to look up resources before deciding whether to create or update them. For example, when an admin assigns a user to the Onyx app, the IdP first checks whether that user already exists:: GET /scim/v2/Users?filter=userName eq "john@example.com" If zero results come back the IdP creates the user (``POST``); if a match is found it links to the existing record and uses ``PUT``/``PATCH`` going forward. The same pattern applies to groups (``displayName eq "Engineering"``). This module parses the subset of the SCIM filter grammar that identity providers actually send in practice: attribute SP operator SP value Supported operators: ``eq``, ``co`` (contains), ``sw`` (starts with). Compound filters (``and`` / ``or``) are not supported; if an IdP sends one, ``parse_scim_filter`` raises ``ValueError`` and callers such as ``list_groups`` respond with a 400 error. """ from __future__ import annotations import re from dataclasses import dataclass from enum import Enum class ScimFilterOperator(str, Enum): """Supported SCIM filter operators.""" EQUAL = "eq" CONTAINS = "co" STARTS_WITH = "sw" @dataclass(frozen=True, slots=True) class ScimFilter: """Parsed SCIM filter expression.""" attribute: str operator: ScimFilterOperator value: str # Matches: attribute operator "value" or 'value' (the value must be quoted) # Groups: (attribute) (operator) ("double-quoted value" | 'single-quoted value') _FILTER_RE = re.compile( r"^(\S+)\s+(eq|co|sw)\s+" # attribute + operator r'(?:"([^"]*)"' # quoted value r"|'([^']*)')" # or single-quoted value r"$", re.IGNORECASE, ) def parse_scim_filter(filter_string: str | None) -> ScimFilter | None: """Parse a simple SCIM filter expression. Args: filter_string: Raw filter query parameter value, e.g. ``'userName eq "john@example.com"'`` Returns: A ``ScimFilter`` if the expression is valid and uses a supported operator, or ``None`` if the input is empty / missing.
Raises: ValueError: If the filter string is present but malformed or uses an unsupported operator. """ if not filter_string or not filter_string.strip(): return None match = _FILTER_RE.match(filter_string.strip()) if not match: raise ValueError(f"Unsupported or malformed SCIM filter: {filter_string}") return _build_filter(match, filter_string) def _build_filter(match: re.Match[str], raw: str) -> ScimFilter: """Extract fields from a regex match and construct a ScimFilter.""" attribute = match.group(1) op_str = match.group(2).lower() # Value is in group 3 (double-quoted) or group 4 (single-quoted) value = match.group(3) if match.group(3) is not None else match.group(4) if value is None: raise ValueError(f"Unsupported or malformed SCIM filter: {raw}") operator = ScimFilterOperator(op_str) return ScimFilter(attribute=attribute, operator=operator, value=value) ================================================ FILE: backend/ee/onyx/server/scim/models.py ================================================ """Pydantic schemas for SCIM 2.0 provisioning (RFC 7643 / RFC 7644). SCIM protocol schemas follow the wire format defined in: - Core Schema: https://datatracker.ietf.org/doc/html/rfc7643 - Protocol: https://datatracker.ietf.org/doc/html/rfc7644 Admin API schemas are internal to Onyx and used for SCIM token management. """ from dataclasses import dataclass from datetime import datetime from enum import Enum from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field from pydantic import field_validator # --------------------------------------------------------------------------- # SCIM Schema URIs (RFC 7643 §8) # Every SCIM JSON payload includes a "schemas" array identifying its type. # IdPs like Okta/Azure AD use these URIs to determine how to parse responses. 
# --------------------------------------------------------------------------- SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group" SCIM_LIST_RESPONSE_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:ListResponse" SCIM_PATCH_OP_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp" SCIM_ERROR_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:Error" SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = ( "urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig" ) SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType" SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema" SCIM_ENTERPRISE_USER_SCHEMA = ( "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" ) # --------------------------------------------------------------------------- # SCIM Protocol Schemas # --------------------------------------------------------------------------- class ScimName(BaseModel): """User name components (RFC 7643 §4.1.1).""" givenName: str | None = None familyName: str | None = None formatted: str | None = None class ScimEmail(BaseModel): """Email sub-attribute (RFC 7643 §4.1.2).""" value: str type: str | None = None primary: bool = False class ScimMeta(BaseModel): """Resource metadata (RFC 7643 §3.1).""" resourceType: str | None = None created: datetime | None = None lastModified: datetime | None = None location: str | None = None class ScimUserGroupRef(BaseModel): """Group reference within a User resource (RFC 7643 §4.1.2, read-only).""" value: str display: str | None = None class ScimManagerRef(BaseModel): """Manager sub-attribute for the enterprise extension (RFC 7643 §4.3).""" value: str | None = None class ScimEnterpriseExtension(BaseModel): """Enterprise User extension attributes (RFC 7643 §4.3).""" department: str | None = None manager: ScimManagerRef | None = None @dataclass class ScimMappingFields: """Stored SCIM mapping fields that need to round-trip through the IdP. 
Entra ID sends structured name components, email metadata, and enterprise extension attributes that must be returned verbatim in subsequent GET responses. These fields are persisted on ScimUserMapping and threaded through the DAL, provider, and endpoint layers. """ department: str | None = None manager: str | None = None given_name: str | None = None family_name: str | None = None scim_emails_json: str | None = None class ScimUserResource(BaseModel): """SCIM User resource representation (RFC 7643 §4.1). This is the JSON shape that IdPs send when creating/updating a user via SCIM, and the shape we return in GET responses. Field names use camelCase to match the SCIM wire format (not Python convention). """ model_config = ConfigDict(populate_by_name=True) schemas: list[str] = Field(default_factory=lambda: [SCIM_USER_SCHEMA]) id: str | None = None # Onyx's internal user ID, set on responses externalId: str | None = None # IdP's identifier for this user userName: str # Typically the user's email address name: ScimName | None = None displayName: str | None = None emails: list[ScimEmail] = Field(default_factory=list) active: bool = True groups: list[ScimUserGroupRef] = Field(default_factory=list) meta: ScimMeta | None = None enterprise_extension: ScimEnterpriseExtension | None = Field( default=None, alias="urn:ietf:params:scim:schemas:extension:enterprise:2.0:User", ) class ScimGroupMember(BaseModel): """Group member reference (RFC 7643 §4.2). Represents a user within a SCIM group. The IdP sends these when adding or removing users from groups. ``value`` is the Onyx user ID. 
""" value: str # User ID of the group member display: str | None = None class ScimGroupResource(BaseModel): """SCIM Group resource representation (RFC 7643 §4.2).""" schemas: list[str] = Field(default_factory=lambda: [SCIM_GROUP_SCHEMA]) id: str | None = None externalId: str | None = None displayName: str members: list[ScimGroupMember] = Field(default_factory=list) meta: ScimMeta | None = None class ScimListResponse(BaseModel): """Paginated list response (RFC 7644 §3.4.2).""" schemas: list[str] = Field(default_factory=lambda: [SCIM_LIST_RESPONSE_SCHEMA]) totalResults: int startIndex: int = 1 itemsPerPage: int = 100 Resources: list[ScimUserResource | ScimGroupResource] = Field(default_factory=list) class ScimPatchOperationType(str, Enum): """Supported PATCH operations (RFC 7644 §3.5.2).""" ADD = "add" REPLACE = "replace" REMOVE = "remove" class ScimPatchResourceValue(BaseModel): """Partial resource dict for path-less PATCH replace operations. When an IdP sends a PATCH without a ``path``, the ``value`` is a dict of resource attributes to set. IdPs may include read-only fields (``id``, ``schemas``, ``meta``) alongside actual changes — these are stripped by the provider's ``ignored_patch_paths`` before processing. ``extra="allow"`` lets unknown attributes pass through so the patch handler can decide what to do with them (ignore or reject). 
""" model_config = ConfigDict(extra="allow") active: bool | None = None userName: str | None = None displayName: str | None = None externalId: str | None = None name: ScimName | None = None members: list[ScimGroupMember] | None = None id: str | None = None schemas: list[str] | None = None meta: ScimMeta | None = None ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None class ScimPatchOperation(BaseModel): """Single PATCH operation (RFC 7644 §3.5.2).""" op: ScimPatchOperationType path: str | None = None value: ScimPatchValue = None @field_validator("op", mode="before") @classmethod def normalize_operation(cls, v: object) -> object: """Normalize op to lowercase for case-insensitive matching. Some IdPs (e.g. Entra ID) send capitalized ops like ``"Replace"`` instead of ``"replace"``. This is safe for all providers since the enum values are lowercase. If a future provider requires other pre-processing quirks, move patch deserialization into the provider subclass instead of adding more special cases here. """ return v.lower() if isinstance(v, str) else v class ScimPatchRequest(BaseModel): """PATCH request body (RFC 7644 §3.5.2). IdPs use PATCH to make incremental changes — e.g. deactivating a user (replace active=false) or adding/removing group members — instead of replacing the entire resource with PUT. 
""" schemas: list[str] = Field(default_factory=lambda: [SCIM_PATCH_OP_SCHEMA]) Operations: list[ScimPatchOperation] class ScimError(BaseModel): """SCIM error response (RFC 7644 §3.12).""" schemas: list[str] = Field(default_factory=lambda: [SCIM_ERROR_SCHEMA]) status: str detail: str | None = None scimType: str | None = None # --------------------------------------------------------------------------- # Service Provider Configuration (RFC 7643 §5) # --------------------------------------------------------------------------- class ScimSupported(BaseModel): """Generic supported/not-supported flag used in ServiceProviderConfig.""" supported: bool class ScimFilterConfig(BaseModel): """Filter configuration within ServiceProviderConfig (RFC 7643 §5).""" supported: bool maxResults: int = 100 class ScimServiceProviderConfig(BaseModel): """SCIM ServiceProviderConfig resource (RFC 7643 §5). Served at GET /scim/v2/ServiceProviderConfig. IdPs fetch this during initial setup to discover which SCIM features our server supports (e.g. PATCH yes, bulk no, filtering yes). 
""" schemas: list[str] = Field( default_factory=lambda: [SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA] ) patch: ScimSupported = ScimSupported(supported=True) bulk: ScimSupported = ScimSupported(supported=False) filter: ScimFilterConfig = ScimFilterConfig(supported=True) changePassword: ScimSupported = ScimSupported(supported=False) sort: ScimSupported = ScimSupported(supported=False) etag: ScimSupported = ScimSupported(supported=False) authenticationSchemes: list[dict[str, str]] = Field( default_factory=lambda: [ { "type": "oauthbearertoken", "name": "OAuth Bearer Token", "description": "Authentication scheme using a SCIM bearer token", } ] ) class ScimSchemaAttribute(BaseModel): """Attribute definition within a SCIM Schema (RFC 7643 §7).""" name: str type: str multiValued: bool = False required: bool = False description: str = "" caseExact: bool = False mutability: str = "readWrite" returned: str = "default" uniqueness: str = "none" subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list) class ScimSchemaDefinition(BaseModel): """SCIM Schema definition (RFC 7643 §7). Served at GET /scim/v2/Schemas. Describes the attributes available on each resource type so IdPs know which fields they can provision. """ schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA]) id: str name: str description: str attributes: list[ScimSchemaAttribute] = Field(default_factory=list) class ScimSchemaExtension(BaseModel): """Schema extension reference within ResourceType (RFC 7643 §6).""" model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True) schema_: str = Field(alias="schema") required: bool class ScimResourceType(BaseModel): """SCIM ResourceType resource (RFC 7643 §6). Served at GET /scim/v2/ResourceTypes. Tells the IdP which resource types are available (Users, Groups) and their respective endpoints. 
""" model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True) schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA]) id: str name: str endpoint: str description: str | None = None schema_: str = Field(alias="schema") schemaExtensions: list[ScimSchemaExtension] = Field(default_factory=list) # --------------------------------------------------------------------------- # Admin API Schemas (Onyx-internal, for SCIM token management) # These are NOT part of the SCIM protocol. They power the Onyx admin UI # where admins create/revoke the bearer tokens that IdPs use to authenticate. # --------------------------------------------------------------------------- class ScimTokenCreate(BaseModel): """Request to create a new SCIM bearer token.""" name: str class ScimTokenResponse(BaseModel): """SCIM token metadata returned in list/get responses.""" id: int name: str token_display: str is_active: bool created_at: datetime last_used_at: datetime | None = None idp_domain: str | None = None class ScimTokenCreatedResponse(ScimTokenResponse): """Response returned when a new SCIM token is created. Includes the raw token value which is only available at creation time. """ raw_token: str ================================================ FILE: backend/ee/onyx/server/scim/patch.py ================================================ """SCIM PATCH operation handler (RFC 7644 §3.5.2). Identity providers use PATCH to make incremental changes to SCIM resources instead of replacing the entire resource with PUT. Common operations include: - Deactivating a user: ``replace`` ``active`` with ``false`` - Adding group members: ``add`` to ``members`` - Removing group members: ``remove`` from ``members[value eq "..."]`` This module applies PATCH operations to Pydantic SCIM resource objects and returns the modified result. It does NOT touch the database — the caller is responsible for persisting changes. 
""" from __future__ import annotations import logging import re from dataclasses import dataclass from dataclasses import field from typing import Any from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA from ee.onyx.server.scim.models import ScimGroupMember from ee.onyx.server.scim.models import ScimGroupResource from ee.onyx.server.scim.models import ScimPatchOperation from ee.onyx.server.scim.models import ScimPatchOperationType from ee.onyx.server.scim.models import ScimPatchResourceValue from ee.onyx.server.scim.models import ScimPatchValue from ee.onyx.server.scim.models import ScimUserResource logger = logging.getLogger(__name__) # Lowercased enterprise extension URN for case-insensitive matching _ENTERPRISE_URN_LOWER = SCIM_ENTERPRISE_USER_SCHEMA.lower() # Pattern for email filter paths, e.g.: # emails[primary eq true].value (Okta) # emails[type eq "work"].value (Azure AD / Entra ID) _EMAIL_FILTER_RE = re.compile( r"^emails\[.+\]\.value$", re.IGNORECASE, ) # Pattern for member removal path: members[value eq "user-id"] _MEMBER_FILTER_RE = re.compile( r'^members\[value\s+eq\s+"([^"]+)"\]$', re.IGNORECASE, ) # --------------------------------------------------------------------------- # Dispatch tables for user PATCH paths # # Maps lowercased SCIM path → (camelCase key, target dict name). # "data" writes to the top-level resource dict, "name" writes to the # name sub-object dict. This replaces the elif chains for simple fields. 
# --------------------------------------------------------------------------- _USER_REPLACE_PATHS: dict[str, tuple[str, str]] = { "active": ("active", "data"), "username": ("userName", "data"), "externalid": ("externalId", "data"), "name.givenname": ("givenName", "name"), "name.familyname": ("familyName", "name"), "name.formatted": ("formatted", "name"), } _USER_REMOVE_PATHS: dict[str, tuple[str, str]] = { "externalid": ("externalId", "data"), "name.givenname": ("givenName", "name"), "name.familyname": ("familyName", "name"), "name.formatted": ("formatted", "name"), "displayname": ("displayName", "data"), } _GROUP_REPLACE_PATHS: dict[str, tuple[str, str]] = { "displayname": ("displayName", "data"), "externalid": ("externalId", "data"), } class ScimPatchError(Exception): """Raised when a PATCH operation cannot be applied.""" def __init__(self, detail: str, status: int = 400) -> None: self.detail = detail self.status = status super().__init__(detail) @dataclass class _UserPatchCtx: """Bundles the mutable state for user PATCH operations.""" data: dict[str, Any] name_data: dict[str, Any] ent_data: dict[str, str | None] = field(default_factory=dict) # --------------------------------------------------------------------------- # User PATCH # --------------------------------------------------------------------------- def apply_user_patch( operations: list[ScimPatchOperation], current: ScimUserResource, ignored_paths: frozenset[str] = frozenset(), ) -> tuple[ScimUserResource, dict[str, str | None]]: """Apply SCIM PATCH operations to a user resource. Args: operations: The PATCH operations to apply. current: The current user resource state. ignored_paths: SCIM attribute paths to silently skip (from provider). Returns: A tuple of (modified user resource, enterprise extension data dict). The enterprise dict has keys ``"department"`` and ``"manager"`` with values set only when a PATCH operation touched them. Raises: ScimPatchError: If an operation targets an unsupported path. 
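Example: a simplified, standalone sketch of the dispatch-table lookup that ``_set_user_field`` performs for simple fields. ``REPLACE_PATHS`` and ``set_field`` are hypothetical stand-ins for ``_USER_REPLACE_PATHS`` and ``_set_user_field``. They omit ignored paths, ``displayName``, emails, and enterprise URNs.

```python
from typing import Any

# Hypothetical subset of the dispatch table:
# lowercased SCIM path -> (camelCase key, target dict name)
REPLACE_PATHS: dict[str, tuple[str, str]] = {
    'active': ('active', 'data'),
    'username': ('userName', 'data'),
    'name.givenname': ('givenName', 'name'),
    'name.familyname': ('familyName', 'name'),
}


def set_field(path: str, value: Any, data: dict, name_data: dict) -> None:
    # SCIM attribute names are case-insensitive, so normalize before lookup
    entry = REPLACE_PATHS.get(path.lower())
    if entry is None:
        raise KeyError(f'unsupported path {path!r}')
    key, target = entry
    # 'data' targets the top-level dict, 'name' targets the name sub-dict
    target_dict = data if target == 'data' else name_data
    target_dict[key] = value


data = {'userName': 'ada@example.com', 'active': True}
name_data = {'givenName': 'Ada'}
set_field('name.givenName', 'Augusta', data, name_data)
set_field('userName', 'augusta@example.com', data, name_data)
data['name'] = name_data  # re-attach the sub-dict, as apply_user_patch does
print(data)
# {'userName': 'augusta@example.com', 'active': True, 'name': {'givenName': 'Augusta'}}
```

Keeping the path-to-key mapping in a table means supporting a new simple attribute is a one-line change rather than another ``elif`` branch.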
""" data = current.model_dump() ctx = _UserPatchCtx(data=data, name_data=data.get("name") or {}) for op in operations: if op.op in (ScimPatchOperationType.REPLACE, ScimPatchOperationType.ADD): _apply_user_replace(op, ctx, ignored_paths) elif op.op == ScimPatchOperationType.REMOVE: _apply_user_remove(op, ctx, ignored_paths) else: raise ScimPatchError( f"Unsupported operation '{op.op.value}' on User resource" ) ctx.data["name"] = ctx.name_data return ScimUserResource.model_validate(ctx.data), ctx.ent_data def _apply_user_replace( op: ScimPatchOperation, ctx: _UserPatchCtx, ignored_paths: frozenset[str], ) -> None: """Apply a replace/add operation to user data.""" path = (op.path or "").lower() if not path: # No path — value is a resource dict of top-level attributes to set. if isinstance(op.value, ScimPatchResourceValue): for key, val in op.value.model_dump(exclude_unset=True).items(): _set_user_field(key.lower(), val, ctx, ignored_paths, strict=False) else: raise ScimPatchError("Replace without path requires a dict value") return _set_user_field(path, op.value, ctx, ignored_paths) def _apply_user_remove( op: ScimPatchOperation, ctx: _UserPatchCtx, ignored_paths: frozenset[str], ) -> None: """Apply a remove operation to user data — clears the target field.""" path = (op.path or "").lower() if not path: raise ScimPatchError("Remove operation requires a path") if path in ignored_paths: return entry = _USER_REMOVE_PATHS.get(path) if entry: key, target = entry target_dict = ctx.data if target == "data" else ctx.name_data target_dict[key] = None return raise ScimPatchError(f"Unsupported remove path '{path}' for User PATCH") def _set_user_field( path: str, value: ScimPatchValue, ctx: _UserPatchCtx, ignored_paths: frozenset[str], *, strict: bool = True, ) -> None: """Set a single field on user data by SCIM path. Args: strict: When ``False`` (path-less replace), unknown attributes are silently skipped. When ``True`` (explicit path), they raise. 
""" if path in ignored_paths: return # Simple field writes handled by the dispatch table entry = _USER_REPLACE_PATHS.get(path) if entry: key, target = entry target_dict = ctx.data if target == "data" else ctx.name_data target_dict[key] = value return # displayName sets both the top-level field and the name.formatted sub-field if path == "displayname": ctx.data["displayName"] = value ctx.name_data["formatted"] = value elif path == "name": if isinstance(value, dict): for k, v in value.items(): ctx.name_data[k] = v elif path == "emails": if isinstance(value, list): ctx.data["emails"] = value elif _EMAIL_FILTER_RE.match(path): _update_primary_email(ctx.data, value) elif path.startswith(_ENTERPRISE_URN_LOWER): _set_enterprise_field(path, value, ctx.ent_data) elif not strict: return else: raise ScimPatchError(f"Unsupported path '{path}' for User PATCH") def _update_primary_email(data: dict[str, Any], value: ScimPatchValue) -> None: """Update the primary email entry via an email filter path.""" emails: list[dict] = data.get("emails") or [] for email_entry in emails: if email_entry.get("primary"): email_entry["value"] = value break else: emails.append({"value": value, "type": "work", "primary": True}) data["emails"] = emails def _to_dict(value: ScimPatchValue) -> dict | None: """Coerce a SCIM patch value to a plain dict if possible. Pydantic may parse raw dicts as ``ScimPatchResourceValue`` (which uses ``extra="allow"``), so we also dump those back to a dict. """ if isinstance(value, dict): return value if isinstance(value, ScimPatchResourceValue): return value.model_dump(exclude_unset=True) return None def _set_enterprise_field( path: str, value: ScimPatchValue, ent_data: dict[str, str | None], ) -> None: """Handle enterprise extension URN paths or value dicts.""" # Full URN as key with dict value (path-less PATCH) # e.g. 
key="urn:...:user", value={"department": "Eng", "manager": {...}} if path == _ENTERPRISE_URN_LOWER: d = _to_dict(value) if d is not None: if "department" in d: ent_data["department"] = d["department"] if "manager" in d: mgr = d["manager"] if isinstance(mgr, dict): ent_data["manager"] = mgr.get("value") return # Dotted URN path, e.g. "urn:...:user:department" suffix = path[len(_ENTERPRISE_URN_LOWER) :].lstrip(":").lower() if suffix == "department": ent_data["department"] = str(value) if value is not None else None elif suffix == "manager": d = _to_dict(value) if d is not None: ent_data["manager"] = d.get("value") elif isinstance(value, str): ent_data["manager"] = value else: # Unknown enterprise attributes are silently ignored rather than # rejected — IdPs may send attributes we don't model yet. logger.warning("Ignoring unknown enterprise extension attribute '%s'", suffix) # --------------------------------------------------------------------------- # Group PATCH # --------------------------------------------------------------------------- def apply_group_patch( operations: list[ScimPatchOperation], current: ScimGroupResource, ignored_paths: frozenset[str] = frozenset(), ) -> tuple[ScimGroupResource, list[str], list[str]]: """Apply SCIM PATCH operations to a group resource. Args: operations: The PATCH operations to apply. current: The current group resource state. ignored_paths: SCIM attribute paths to silently skip (from provider). Returns: A tuple of (modified group, added member IDs, removed member IDs). The caller uses the member ID lists to update the database. Raises: ScimPatchError: If an operation targets an unsupported path. 
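Example: for a full ``members`` replace, the added and removed ID lists are two set differences between the old and new member IDs. Below is a standalone sketch. ``diff_members`` is a hypothetical helper that mirrors ``_replace_members``. It sorts its results for deterministic output, whereas ``_replace_members`` extends the lists in set-iteration order.

```python
def diff_members(current: list[dict], new: list[dict]) -> tuple[list[str], list[str]]:
    old_ids = {m['value'] for m in current}
    new_ids = {m['value'] for m in new}
    added = sorted(new_ids - old_ids)    # present only in the new list
    removed = sorted(old_ids - new_ids)  # present only in the old list
    return added, removed


current = [{'value': 'u1'}, {'value': 'u2'}]
new = [{'value': 'u2'}, {'value': 'u3'}]
print(diff_members(current, new))  # (['u3'], ['u1'])
```

Members present in both lists (``u2`` here) appear in neither output list, so the caller touches only the rows that actually changed.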
""" data = current.model_dump() current_members: list[dict] = list(data.get("members") or []) added_ids: list[str] = [] removed_ids: list[str] = [] for op in operations: if op.op == ScimPatchOperationType.REPLACE: _apply_group_replace( op, data, current_members, added_ids, removed_ids, ignored_paths ) elif op.op == ScimPatchOperationType.ADD: _apply_group_add(op, current_members, added_ids) elif op.op == ScimPatchOperationType.REMOVE: _apply_group_remove(op, current_members, removed_ids) else: raise ScimPatchError( f"Unsupported operation '{op.op.value}' on Group resource" ) data["members"] = current_members group = ScimGroupResource.model_validate(data) return group, added_ids, removed_ids def _apply_group_replace( op: ScimPatchOperation, data: dict, current_members: list[dict], added_ids: list[str], removed_ids: list[str], ignored_paths: frozenset[str], ) -> None: """Apply a replace operation to group data.""" path = (op.path or "").lower() if not path: if isinstance(op.value, ScimPatchResourceValue): dumped = op.value.model_dump(exclude_unset=True) for key, val in dumped.items(): if key.lower() == "members": _replace_members(val, current_members, added_ids, removed_ids) else: _set_group_field(key.lower(), val, data, ignored_paths) else: raise ScimPatchError("Replace without path requires a dict value") return if path == "members": _replace_members( _members_to_dicts(op.value), current_members, added_ids, removed_ids ) return _set_group_field(path, op.value, data, ignored_paths) def _members_to_dicts( value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None, ) -> list[dict]: """Convert a member list value to a list of dicts for internal processing.""" if not isinstance(value, list): raise ScimPatchError("Replace members requires a list value") return [m.model_dump(exclude_none=True) for m in value] def _replace_members( value: list[dict], current_members: list[dict], added_ids: list[str], removed_ids: list[str], ) -> None: """Replace the entire 
group member list.""" old_ids = {m["value"] for m in current_members} new_ids = {m.get("value", "") for m in value} removed_ids.extend(old_ids - new_ids) added_ids.extend(new_ids - old_ids) current_members[:] = value def _set_group_field( path: str, value: ScimPatchValue, data: dict, ignored_paths: frozenset[str], ) -> None: """Set a single field on group data by SCIM path.""" if path in ignored_paths: return entry = _GROUP_REPLACE_PATHS.get(path) if entry: key, _ = entry data[key] = value return raise ScimPatchError(f"Unsupported path '{path}' for Group PATCH") def _apply_group_add( op: ScimPatchOperation, members: list[dict], added_ids: list[str], ) -> None: """Add members to a group.""" path = (op.path or "").lower() if path and path != "members": raise ScimPatchError(f"Unsupported add path '{op.path}' for Group") if not isinstance(op.value, list): raise ScimPatchError("Add members requires a list value") member_dicts = [m.model_dump(exclude_none=True) for m in op.value] existing_ids = {m["value"] for m in members} for member_data in member_dicts: member_id = member_data.get("value", "") if member_id and member_id not in existing_ids: members.append(member_data) added_ids.append(member_id) existing_ids.add(member_id) def _apply_group_remove( op: ScimPatchOperation, members: list[dict], removed_ids: list[str], ) -> None: """Remove members from a group.""" if not op.path: raise ScimPatchError("Remove operation requires a path") match = _MEMBER_FILTER_RE.match(op.path) if not match: raise ScimPatchError( f"Unsupported remove path '{op.path}'. 
Expected: members[value eq \"user-id\"]" ) target_id = match.group(1) original_len = len(members) members[:] = [m for m in members if m.get("value") != target_id] if len(members) < original_len: removed_ids.append(target_id) ================================================ FILE: backend/ee/onyx/server/scim/providers/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/scim/providers/base.py ================================================ """Base SCIM provider abstraction.""" from __future__ import annotations import json import logging from abc import ABC from abc import abstractmethod from uuid import UUID from pydantic import ValidationError from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA from ee.onyx.server.scim.models import SCIM_USER_SCHEMA from ee.onyx.server.scim.models import ScimEmail from ee.onyx.server.scim.models import ScimEnterpriseExtension from ee.onyx.server.scim.models import ScimGroupMember from ee.onyx.server.scim.models import ScimGroupResource from ee.onyx.server.scim.models import ScimManagerRef from ee.onyx.server.scim.models import ScimMappingFields from ee.onyx.server.scim.models import ScimMeta from ee.onyx.server.scim.models import ScimName from ee.onyx.server.scim.models import ScimUserGroupRef from ee.onyx.server.scim.models import ScimUserResource from onyx.db.models import User from onyx.db.models import UserGroup logger = logging.getLogger(__name__) COMMON_IGNORED_PATCH_PATHS: frozenset[str] = frozenset( { "id", "schemas", "meta", } ) class ScimProvider(ABC): """Base class for provider-specific SCIM behavior. Subclass this to handle IdP-specific quirks. The base class provides RFC 7643-compliant response builders that populate all standard fields. """ @property @abstractmethod def name(self) -> str: """Short identifier for this provider (e.g. ``"okta"``).""" ... 
@property @abstractmethod def ignored_patch_paths(self) -> frozenset[str]: """SCIM attribute paths to silently skip in PATCH value-object dicts. IdPs may include read-only or meta fields alongside actual changes (e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed here are silently dropped instead of raising an error. """ ... @property def user_schemas(self) -> list[str]: """Schema URIs to include in User resource responses. Override in subclasses to advertise additional schemas (e.g. the enterprise extension for Entra ID). """ return [SCIM_USER_SCHEMA] def build_user_resource( self, user: User, external_id: str | None = None, groups: list[tuple[int, str]] | None = None, scim_username: str | None = None, fields: ScimMappingFields | None = None, ) -> ScimUserResource: """Build a SCIM User response from an Onyx User. Args: user: The Onyx user model. external_id: The IdP's external identifier for this user. groups: List of ``(group_id, group_name)`` tuples for the ``groups`` read-only attribute. Pass ``None`` or ``[]`` for newly-created users. scim_username: The original-case userName from the IdP. Falls back to ``user.email`` (lowercase) when not available. fields: Stored mapping fields that the IdP expects round-tripped. """ f = fields or ScimMappingFields() group_refs = [ ScimUserGroupRef(value=str(gid), display=gname) for gid, gname in (groups or []) ] username = scim_username or user.email # Build enterprise extension when at least one value is present. # Dynamically add the enterprise URN to schemas per RFC 7643 §3.0. 
enterprise_ext: ScimEnterpriseExtension | None = None schemas = list(self.user_schemas) if f.department is not None or f.manager is not None: manager_ref = ( ScimManagerRef(value=f.manager) if f.manager is not None else None ) enterprise_ext = ScimEnterpriseExtension( department=f.department, manager=manager_ref, ) if SCIM_ENTERPRISE_USER_SCHEMA not in schemas: schemas.append(SCIM_ENTERPRISE_USER_SCHEMA) name = self.build_scim_name(user, f) emails = _deserialize_emails(f.scim_emails_json, username) resource = ScimUserResource( schemas=schemas, id=str(user.id), externalId=external_id, userName=username, name=name, displayName=user.personal_name, emails=emails, active=user.is_active, groups=group_refs, meta=ScimMeta(resourceType="User"), ) resource.enterprise_extension = enterprise_ext return resource def build_group_resource( self, group: UserGroup, members: list[tuple[UUID, str | None]], external_id: str | None = None, ) -> ScimGroupResource: """Build a SCIM Group response from an Onyx UserGroup.""" scim_members = [ ScimGroupMember(value=str(uid), display=email) for uid, email in members ] return ScimGroupResource( id=str(group.id), externalId=external_id, displayName=group.name, members=scim_members, meta=ScimMeta(resourceType="Group"), ) def build_scim_name( self, user: User, fields: ScimMappingFields, ) -> ScimName: """Build SCIM name components for the response. Round-trips stored ``given_name``/``family_name`` when available (so the IdP gets back what it sent). Falls back to splitting ``personal_name`` for users provisioned before we stored components. Always returns a ScimName — Okta's spec tests expect ``name`` (with ``givenName``/``familyName``) on every user resource. Providers may override for custom behavior. 
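Example: the two fallbacks can be sketched standalone. ``derive_name`` is hypothetical and returns a plain dict rather than ``ScimName``. It skips the stored ``given_name``/``family_name`` branch and covers only two cases: splitting ``personal_name`` on the first space, and deriving a name from the email local part.

```python
from typing import Optional


def derive_name(personal_name: Optional[str], email: Optional[str]) -> dict[str, str]:
    if not personal_name:
        # No display name stored: fall back to the email local part
        local = email.split('@')[0] if email else ''
        return {'givenName': local, 'familyName': '', 'formatted': local}
    # Split on the first space only, so multi-word surnames stay together
    parts = personal_name.split(' ', 1)
    return {
        'givenName': parts[0],
        'familyName': parts[1] if len(parts) > 1 else '',
        'formatted': personal_name,
    }


print(derive_name('Ada King Lovelace', None))
# {'givenName': 'Ada', 'familyName': 'King Lovelace', 'formatted': 'Ada King Lovelace'}
print(derive_name(None, 'ada@example.com'))
# {'givenName': 'ada', 'familyName': '', 'formatted': 'ada'}
```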
""" if fields.given_name is not None or fields.family_name is not None: return ScimName( givenName=fields.given_name or "", familyName=fields.family_name or "", formatted=user.personal_name or "", ) if not user.personal_name: # Derive a reasonable name from the email so that SCIM spec tests # see non-empty givenName / familyName for every user resource. local = user.email.split("@")[0] if user.email else "" return ScimName(givenName=local, familyName="", formatted=local) parts = user.personal_name.split(" ", 1) return ScimName( givenName=parts[0], familyName=parts[1] if len(parts) > 1 else "", formatted=user.personal_name, ) def _deserialize_emails(stored_json: str | None, username: str) -> list[ScimEmail]: """Deserialize stored email entries or build a default work email.""" if stored_json: try: entries = json.loads(stored_json) if isinstance(entries, list) and entries: return [ScimEmail(**e) for e in entries] except (json.JSONDecodeError, TypeError, ValidationError): logger.warning( "Corrupt scim_emails_json, falling back to default: %s", stored_json ) return [ScimEmail(value=username, type="work", primary=True)] def serialize_emails(emails: list[ScimEmail]) -> str | None: """Serialize SCIM email entries to JSON for storage.""" if not emails: return None return json.dumps([e.model_dump(exclude_none=True) for e in emails]) def get_default_provider() -> ScimProvider: """Return the default SCIM provider. Currently returns ``OktaProvider`` since Okta is the primary supported IdP. When provider detection is added (via token metadata or tenant config), this can be replaced with dynamic resolution. 
""" from ee.onyx.server.scim.providers.okta import OktaProvider return OktaProvider() ================================================ FILE: backend/ee/onyx/server/scim/providers/entra.py ================================================ """Entra ID (Azure AD) SCIM provider.""" from __future__ import annotations from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA from ee.onyx.server.scim.models import SCIM_USER_SCHEMA from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS from ee.onyx.server.scim.providers.base import ScimProvider _ENTRA_IGNORED_PATCH_PATHS = COMMON_IGNORED_PATCH_PATHS class EntraProvider(ScimProvider): """Entra ID (Azure AD) SCIM provider. Entra behavioral notes: - Sends capitalized PATCH ops (``"Add"``, ``"Replace"``, ``"Remove"``) — handled by ``ScimPatchOperation.normalize_op`` validator. - Sends the enterprise extension URN as a key in path-less PATCH value dicts — handled by ``_set_enterprise_field`` in ``patch.py`` to store department/manager values. - Expects the enterprise extension schema in ``schemas`` arrays and ``/Schemas`` + ``/ResourceTypes`` discovery endpoints. """ @property def name(self) -> str: return "entra" @property def ignored_patch_paths(self) -> frozenset[str]: return _ENTRA_IGNORED_PATCH_PATHS @property def user_schemas(self) -> list[str]: return [SCIM_USER_SCHEMA, SCIM_ENTERPRISE_USER_SCHEMA] ================================================ FILE: backend/ee/onyx/server/scim/providers/okta.py ================================================ """Okta SCIM provider.""" from __future__ import annotations from ee.onyx.server.scim.providers.base import COMMON_IGNORED_PATCH_PATHS from ee.onyx.server.scim.providers.base import ScimProvider class OktaProvider(ScimProvider): """Okta SCIM provider. 
Okta behavioral notes: - Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE) - Sends path-less PATCH with value dicts containing extra fields (``id``, ``schemas``) - Expects ``displayName`` and ``groups`` in user responses - Only uses ``eq`` operator for ``userName`` filter """ @property def name(self) -> str: return "okta" @property def ignored_patch_paths(self) -> frozenset[str]: return COMMON_IGNORED_PATCH_PATHS ================================================ FILE: backend/ee/onyx/server/scim/schema_definitions.py ================================================ """Static SCIM service discovery responses (RFC 7643 §5, §6, §7). Pre-built at import time — these never change at runtime. Separated from api.py to keep the endpoint module focused on request handling. """ from ee.onyx.server.scim.models import SCIM_ENTERPRISE_USER_SCHEMA from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA from ee.onyx.server.scim.models import SCIM_USER_SCHEMA from ee.onyx.server.scim.models import ScimResourceType from ee.onyx.server.scim.models import ScimSchemaAttribute from ee.onyx.server.scim.models import ScimSchemaDefinition from ee.onyx.server.scim.models import ScimServiceProviderConfig SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig() USER_RESOURCE_TYPE = ScimResourceType.model_validate( { "id": "User", "name": "User", "endpoint": "/scim/v2/Users", "description": "SCIM User resource", "schema": SCIM_USER_SCHEMA, "schemaExtensions": [ {"schema": SCIM_ENTERPRISE_USER_SCHEMA, "required": False} ], } ) GROUP_RESOURCE_TYPE = ScimResourceType.model_validate( { "id": "Group", "name": "Group", "endpoint": "/scim/v2/Groups", "description": "SCIM Group resource", "schema": SCIM_GROUP_SCHEMA, } ) USER_SCHEMA_DEF = ScimSchemaDefinition( id=SCIM_USER_SCHEMA, name="User", description="SCIM core User schema", attributes=[ ScimSchemaAttribute( name="userName", type="string", required=True, uniqueness="server", description="Unique identifier for the user, typically an 
email address.", ), ScimSchemaAttribute( name="name", type="complex", description="The components of the user's name.", subAttributes=[ ScimSchemaAttribute( name="givenName", type="string", description="The user's first name.", ), ScimSchemaAttribute( name="familyName", type="string", description="The user's last name.", ), ScimSchemaAttribute( name="formatted", type="string", description="The full name, including all middle names and titles.", ), ], ), ScimSchemaAttribute( name="emails", type="complex", multiValued=True, description="Email addresses for the user.", subAttributes=[ ScimSchemaAttribute( name="value", type="string", description="Email address value.", ), ScimSchemaAttribute( name="type", type="string", description="Label for this email (e.g. 'work').", ), ScimSchemaAttribute( name="primary", type="boolean", description="Whether this is the primary email.", ), ], ), ScimSchemaAttribute( name="active", type="boolean", description="Whether the user account is active.", ), ScimSchemaAttribute( name="externalId", type="string", description="Identifier from the provisioning client (IdP).", caseExact=True, ), ], ) ENTERPRISE_USER_SCHEMA_DEF = ScimSchemaDefinition( id=SCIM_ENTERPRISE_USER_SCHEMA, name="EnterpriseUser", description="Enterprise User extension (RFC 7643 §4.3)", attributes=[ ScimSchemaAttribute( name="department", type="string", description="Department.", ), ScimSchemaAttribute( name="manager", type="complex", description="The user's manager.", subAttributes=[ ScimSchemaAttribute( name="value", type="string", description="Manager user ID.", ), ], ), ], ) GROUP_SCHEMA_DEF = ScimSchemaDefinition( id=SCIM_GROUP_SCHEMA, name="Group", description="SCIM core Group schema", attributes=[ ScimSchemaAttribute( name="displayName", type="string", required=True, description="Human-readable name for the group.", ), ScimSchemaAttribute( name="members", type="complex", multiValued=True, description="Members of the group.", subAttributes=[ ScimSchemaAttribute( 
name="value", type="string", description="User ID of the group member.", ), ScimSchemaAttribute( name="display", type="string", mutability="readOnly", description="Display name of the group member.", ), ], ), ScimSchemaAttribute( name="externalId", type="string", description="Identifier from the provisioning client (IdP).", caseExact=True, ), ], ) ================================================ FILE: backend/ee/onyx/server/seeding.py ================================================ import json import os from copy import deepcopy from typing import List from typing import Optional from pydantic import BaseModel from sqlalchemy.orm import Session from ee.onyx.db.standard_answer import ( create_initial_default_standard_answer_category, ) from ee.onyx.server.enterprise_settings.models import AnalyticsScriptUpload from ee.onyx.server.enterprise_settings.models import EnterpriseSettings from ee.onyx.server.enterprise_settings.models import NavigationItem from ee.onyx.server.enterprise_settings.store import store_analytics_script from ee.onyx.server.enterprise_settings.store import ( store_settings as store_ee_settings, ) from ee.onyx.server.enterprise_settings.store import upload_logo from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import update_default_provider from onyx.db.llm import upsert_llm_provider from onyx.db.models import Tool from onyx.db.persona import upsert_persona from onyx.server.features.persona.models import PersonaUpsertRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderView from onyx.server.settings.models import Settings from onyx.server.settings.store import store_settings as store_base_settings from onyx.utils.logger import setup_logger class CustomToolSeed(BaseModel): name: str description: str definition_path: str custom_headers: Optional[List[dict]] = None display_name: 
Optional[str] = None in_code_tool_id: Optional[str] = None user_id: Optional[str] = None logger = setup_logger() _SEED_CONFIG_ENV_VAR_NAME = "ENV_SEED_CONFIGURATION" class NavigationItemSeed(BaseModel): link: str title: str # NOTE: SVG at this path must not have a width / height specified svg_path: str class SeedConfiguration(BaseModel): llms: list[LLMProviderUpsertRequest] | None = None admin_user_emails: list[str] | None = None seeded_logo_path: str | None = None personas: list[PersonaUpsertRequest] | None = None settings: Settings | None = None enterprise_settings: EnterpriseSettings | None = None # allows for specifying custom navigation items that have your own custom SVG logos nav_item_overrides: list[NavigationItemSeed] | None = None # Use existing `CUSTOM_ANALYTICS_SECRET_KEY` for reference analytics_script_path: str | None = None custom_tools: List[CustomToolSeed] | None = None def _parse_env() -> SeedConfiguration | None: seed_config_str = os.getenv(_SEED_CONFIG_ENV_VAR_NAME) if not seed_config_str: return None seed_config = SeedConfiguration.model_validate_json(seed_config_str) return seed_config def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None: if tools: logger.notice("Seeding Custom Tools") seeded_count = 0 for tool in tools: try: logger.debug(f"Attempting to seed tool: {tool.name}") logger.debug(f"Reading definition from: {tool.definition_path}") with open(tool.definition_path, "r") as file: file_content = file.read() if not file_content.strip(): raise ValueError("File is empty") openapi_schema = json.loads(file_content) db_tool = Tool( name=tool.name, description=tool.description, openapi_schema=openapi_schema, custom_headers=tool.custom_headers, display_name=tool.display_name, in_code_tool_id=tool.in_code_tool_id, user_id=tool.user_id, ) db_session.add(db_tool) seeded_count += 1 logger.debug(f"Successfully added tool: {tool.name}") except FileNotFoundError: logger.error( f"Definition file not found for tool {tool.name}: {tool.definition_path}" ) 
except
json.JSONDecodeError as e: logger.error( f"Invalid JSON in definition file for tool {tool.name}: {str(e)}" ) except Exception as e: logger.error(f"Failed to seed tool {tool.name}: {str(e)}") db_session.commit() logger.notice(f"Successfully seeded {seeded_count} of {len(tools)} Custom Tools") def _seed_llms( db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest] ) -> None: if not llm_upsert_requests: return logger.notice("Seeding LLMs") for request in llm_upsert_requests: existing = fetch_existing_llm_provider(name=request.name, db_session=db_session) if existing: request.id = existing.id seeded_providers: list[LLMProviderView] = [] for llm_upsert_request in llm_upsert_requests: try: seeded_providers.append(upsert_llm_provider(llm_upsert_request, db_session)) except ValueError as e: logger.warning( "Failed to upsert LLM provider '%s' during seeding: %s", llm_upsert_request.name, e, ) default_provider = next( (p for p in seeded_providers if p.model_configurations), None ) if not default_provider: return visible_configs = [ mc for mc in default_provider.model_configurations if mc.is_visible ] default_config = ( visible_configs[0] if visible_configs else default_provider.model_configurations[0] ) update_default_provider( provider_id=default_provider.id, model_name=default_config.name, db_session=db_session, ) def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None: if personas: logger.notice("Seeding Personas") try: for persona in personas: upsert_persona( user=None, # Seeding is done as admin name=persona.name, description=persona.description, document_set_ids=persona.document_set_ids, llm_model_provider_override=persona.llm_model_provider_override, llm_model_version_override=persona.llm_model_version_override, starter_messages=persona.starter_messages, is_public=persona.is_public, db_session=db_session, tool_ids=persona.tool_ids, display_priority=persona.display_priority, system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt, datetime_aware=persona.datetime_aware, is_featured=persona.is_featured, commit=False, ) db_session.commit() except Exception: logger.exception("Failed to seed personas.") raise def _seed_settings(settings: Settings) -> None: logger.notice("Seeding Settings") try: store_base_settings(settings) logger.notice("Successfully seeded Settings") except ValueError as e: logger.error(f"Failed to seed Settings: {str(e)}") def _seed_enterprise_settings(seed_config: SeedConfiguration) -> None: if ( seed_config.enterprise_settings is not None or seed_config.nav_item_overrides is not None ): final_enterprise_settings = ( deepcopy(seed_config.enterprise_settings) if seed_config.enterprise_settings else EnterpriseSettings() ) final_nav_items = final_enterprise_settings.custom_nav_items if seed_config.nav_item_overrides is not None: final_nav_items = [] for item in seed_config.nav_item_overrides: with open(item.svg_path, "r") as file: svg_content = file.read().strip() final_nav_items.append( NavigationItem( link=item.link, title=item.title, svg_logo=svg_content, ) ) final_enterprise_settings.custom_nav_items = final_nav_items logger.notice("Seeding enterprise settings") store_ee_settings(final_enterprise_settings) def _seed_logo(logo_path: str | None) -> None: if logo_path: logger.notice("Uploading logo") upload_logo(file=logo_path) def _seed_analytics_script(seed_config: SeedConfiguration) -> None: custom_analytics_secret_key = os.environ.get("CUSTOM_ANALYTICS_SECRET_KEY") if seed_config.analytics_script_path and custom_analytics_secret_key: logger.notice("Seeding analytics script") try: with open(seed_config.analytics_script_path, "r") as file: script_content = file.read() analytics_script = AnalyticsScriptUpload( script=script_content, secret_key=custom_analytics_secret_key ) store_analytics_script(analytics_script) except FileNotFoundError: logger.error( f"Analytics script file not found: {seed_config.analytics_script_path}" ) except ValueError 
as e: logger.error(f"Failed to seed analytics script: {str(e)}") def get_seed_config() -> SeedConfiguration | None: return _parse_env() def seed_db() -> None: seed_config = _parse_env() if seed_config is None: logger.debug(f"No seeding configuration provided ({_SEED_CONFIG_ENV_VAR_NAME} is unset or empty)") return with get_session_with_current_tenant() as db_session: if seed_config.llms is not None: _seed_llms(db_session, seed_config.llms) if seed_config.personas is not None: _seed_personas(db_session, seed_config.personas) if seed_config.settings is not None: _seed_settings(seed_config.settings) if seed_config.custom_tools is not None: _seed_custom_tools(db_session, seed_config.custom_tools) _seed_logo(seed_config.seeded_logo_path) _seed_enterprise_settings(seed_config) _seed_analytics_script(seed_config) logger.notice("Verifying default standard answer category exists.") create_initial_default_standard_answer_category(db_session) ================================================ FILE: backend/ee/onyx/server/settings/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/settings/api.py ================================================ """EE Settings API - provides license-aware settings override.""" from redis.exceptions import RedisError from sqlalchemy.exc import SQLAlchemyError from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED from ee.onyx.db.license import get_cached_license_metadata from ee.onyx.db.license import refresh_license_cache from onyx.cache.interface import CACHE_TRANSIENT_ERRORS from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.server.settings.models import ApplicationStatus from onyx.server.settings.models import Settings from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import 
get_current_tenant_id logger = setup_logger() # Only
GATED_ACCESS actually blocks access - other statuses are for notifications _BLOCKING_STATUS = ApplicationStatus.GATED_ACCESS def check_ee_features_enabled() -> bool: """EE version: checks if EE features should be available. Returns True if: - LICENSE_ENFORCEMENT_ENABLED is False (legacy/rollout mode) - Cloud mode (MULTI_TENANT) - cloud handles its own gating - Self-hosted with a valid (non-expired) license Returns False if: - Self-hosted with no license (never subscribed) - Self-hosted with expired license """ if not LICENSE_ENFORCEMENT_ENABLED: # License enforcement disabled - allow EE features (legacy behavior) return True if MULTI_TENANT: # Cloud mode - EE features always available (gating handled by is_tenant_gated) return True # Self-hosted with enforcement - check for valid license tenant_id = get_current_tenant_id() try: metadata = get_cached_license_metadata(tenant_id) if not metadata: # Cache miss — warm from DB so cold-start doesn't block EE features try: with get_session_with_current_tenant() as db_session: metadata = refresh_license_cache(db_session, tenant_id) except SQLAlchemyError as db_error: logger.warning(f"Failed to load license from DB: {db_error}") if metadata and metadata.status != _BLOCKING_STATUS: # Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features) return True except RedisError as e: logger.warning(f"Failed to check license for EE features: {e}") # Fail closed - if Redis is down, other things will break anyway return False # No license or GATED_ACCESS - no EE features return False def apply_license_status_to_settings(settings: Settings) -> Settings: """EE version: checks license status for self-hosted deployments. For self-hosted, looks up license metadata and overrides application_status if the license indicates GATED_ACCESS (fully expired). Also sets ee_features_enabled based on license status to control visibility of EE features in the UI. 
For multi-tenant (cloud), the settings already have the correct status from the control plane, so no override is needed. If LICENSE_ENFORCEMENT_ENABLED is false, ee_features_enabled is set to True (since EE code was loaded via ENABLE_PAID_ENTERPRISE_EDITION_FEATURES). """ if not LICENSE_ENFORCEMENT_ENABLED: # License enforcement disabled - EE code is loaded via # ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so EE features are on settings.ee_features_enabled = True return settings if MULTI_TENANT: # Cloud mode - EE features always available (gating handled by is_tenant_gated) settings.ee_features_enabled = True return settings tenant_id = get_current_tenant_id() try: metadata = get_cached_license_metadata(tenant_id) if not metadata: # Cache miss (e.g. after TTL expiry). Fall back to DB so # the /settings request doesn't falsely return GATED_ACCESS # while the cache is cold. try: with get_session_with_current_tenant() as db_session: metadata = refresh_license_cache(db_session, tenant_id) except SQLAlchemyError as db_error: logger.warning( f"Failed to load license from DB for settings: {db_error}" ) if metadata: if metadata.status == _BLOCKING_STATUS: settings.application_status = metadata.status settings.ee_features_enabled = False elif metadata.used_seats > metadata.seats: # License is valid but seat limit exceeded settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED settings.seat_count = metadata.seats settings.used_seats = metadata.used_seats settings.ee_features_enabled = True else: # Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features) settings.ee_features_enabled = True else: # No license found in cache or DB. if ENTERPRISE_EDITION_ENABLED: # Legacy EE flag is set → prior EE usage (e.g. permission # syncing) means indexed data may need protection. 
settings.application_status = _BLOCKING_STATUS settings.ee_features_enabled = False except CACHE_TRANSIENT_ERRORS as e: logger.warning(f"Failed to check license metadata for settings: {e}") # Fail closed - disable EE features if we can't verify license settings.ee_features_enabled = False return settings ================================================ FILE: backend/ee/onyx/server/tenant_usage_limits.py ================================================ """Tenant-specific usage limit overrides from the control plane (EE version).""" import time import requests from ee.onyx.server.tenants.access import generate_data_plane_token from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.configs.app_configs import DEV_MODE from onyx.server.tenant_usage_limits import TenantUsageLimitOverrides from onyx.server.usage_limits import NO_LIMIT from onyx.utils.logger import setup_logger logger = setup_logger() # In-memory storage for tenant overrides (populated at startup) _tenant_usage_limit_overrides: dict[str, TenantUsageLimitOverrides] | None = None _last_fetch_time: float = 0.0 _FETCH_INTERVAL = 60 * 60 * 24 # 24 hours _ERROR_FETCH_INTERVAL = 30 * 60 # 30 minutes (if the last fetch failed) def fetch_usage_limit_overrides() -> dict[str, TenantUsageLimitOverrides] | None: """ Fetch tenant-specific usage limit overrides from the control plane. Returns: Dictionary mapping tenant_id to their specific limit overrides, or None on any error or if no overrides could be parsed (callers then keep any previously loaded overrides or apply no limits).
""" try: token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } url = f"{CONTROL_PLANE_API_BASE_URL}/usage-limit-overrides" response = requests.get(url, headers=headers, timeout=30) response.raise_for_status() tenant_overrides = response.json() # Parse each tenant's overrides result: dict[str, TenantUsageLimitOverrides] = {} for override_data in tenant_overrides: tenant_id = override_data["tenant_id"] try: result[tenant_id] = TenantUsageLimitOverrides(**override_data) except Exception as e: logger.warning( f"Failed to parse usage limit overrides for tenant {tenant_id}: {e}" ) return ( result or None ) # if empty dictionary, something went wrong and we shouldn't enforce limits except requests.exceptions.RequestException as e: logger.warning(f"Failed to fetch usage limit overrides from control plane: {e}") return None except Exception as e: logger.error(f"Error parsing usage limit overrides: {e}") return None def load_usage_limit_overrides() -> None: """ Load tenant usage limit overrides from the control plane. 
""" global _tenant_usage_limit_overrides global _last_fetch_time logger.info("Loading tenant usage limit overrides from control plane...") overrides = fetch_usage_limit_overrides() _last_fetch_time = time.time() # use the new result if it exists, otherwise use the old result # (prevents us from updating to a failed fetch result) _tenant_usage_limit_overrides = overrides or _tenant_usage_limit_overrides if overrides: logger.info(f"Loaded usage limit overrides for {len(overrides)} tenants") else: logger.info("No tenant-specific usage limit overrides found") def unlimited(tenant_id: str) -> TenantUsageLimitOverrides: return TenantUsageLimitOverrides( tenant_id=tenant_id, llm_cost_cents_trial=NO_LIMIT, llm_cost_cents_paid=NO_LIMIT, chunks_indexed_trial=NO_LIMIT, chunks_indexed_paid=NO_LIMIT, api_calls_trial=NO_LIMIT, api_calls_paid=NO_LIMIT, non_streaming_calls_trial=NO_LIMIT, non_streaming_calls_paid=NO_LIMIT, ) def get_tenant_usage_limit_overrides( tenant_id: str, ) -> TenantUsageLimitOverrides | None: """ Get the usage limit overrides for a specific tenant. Args: tenant_id: The tenant ID to look up Returns: TenantUsageLimitOverrides if the tenant has overrides, None otherwise. """ if DEV_MODE: # in dev mode, we return unlimited limits for all tenants return unlimited(tenant_id) global _tenant_usage_limit_overrides time_since = time.time() - _last_fetch_time if ( _tenant_usage_limit_overrides is None and time_since > _ERROR_FETCH_INTERVAL ) or (time_since > _FETCH_INTERVAL): logger.debug( f"Last fetch time: {_last_fetch_time}, time since last fetch: {time_since}" ) load_usage_limit_overrides() # If we have failed to fetch from the control plane or we're in dev mode, don't usage limit anyone. 
if _tenant_usage_limit_overrides is None or DEV_MODE: return unlimited(tenant_id) return _tenant_usage_limit_overrides.get(tenant_id) ================================================ FILE: backend/ee/onyx/server/tenants/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/server/tenants/access.py ================================================ from datetime import datetime from datetime import timedelta from datetime import timezone import jwt from fastapi import HTTPException from fastapi import Request from onyx.configs.app_configs import DATA_PLANE_SECRET from onyx.configs.app_configs import EXPECTED_API_KEY from onyx.configs.app_configs import JWT_ALGORITHM from onyx.utils.logger import setup_logger logger = setup_logger() def generate_data_plane_token() -> str: if DATA_PLANE_SECRET is None: raise ValueError("DATA_PLANE_SECRET is not set") # Use a single timezone-aware timestamp so "iat" and "exp" are consistent # (datetime.utcnow() is deprecated as of Python 3.12) now = datetime.now(timezone.utc) payload = { "iss": "data_plane", "exp": now + timedelta(minutes=5), "iat": now, "scope": "api_access", } token = jwt.encode(payload, DATA_PLANE_SECRET, algorithm=JWT_ALGORITHM) return token async def control_plane_dep(request: Request) -> None: api_key = request.headers.get("X-API-KEY") if api_key != EXPECTED_API_KEY: logger.warning("Invalid API key") raise HTTPException(status_code=401, detail="Invalid API key") auth_header = request.headers.get("Authorization") if not auth_header or not auth_header.startswith("Bearer "): logger.warning("Invalid authorization header") raise HTTPException(status_code=401, detail="Invalid authorization header") token = auth_header.split(" ")[1] try: payload = jwt.decode(token, DATA_PLANE_SECRET, algorithms=[JWT_ALGORITHM]) if payload.get("scope") != "tenant:create": logger.warning("Insufficient permissions") raise 
HTTPException(status_code=403, detail="Insufficient permissions") except jwt.ExpiredSignatureError: logger.warning("Token has expired") raise HTTPException(status_code=401, detail="Token has expired")
except jwt.InvalidTokenError: logger.warning("Invalid token") raise HTTPException(status_code=401, detail="Invalid token") ================================================ FILE: backend/ee/onyx/server/tenants/admin_api.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Response from fastapi_users import exceptions from ee.onyx.auth.users import current_cloud_superuser from ee.onyx.server.tenants.models import ImpersonateRequest from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email from onyx.auth.users import auth_backend from onyx.auth.users import get_redis_strategy from onyx.auth.users import User from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.users import get_user_by_email from onyx.utils.logger import setup_logger logger = setup_logger() router = APIRouter(prefix="/tenants") @router.post("/impersonate") async def impersonate_user( impersonate_request: ImpersonateRequest, _: User = Depends(current_cloud_superuser), ) -> Response: """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token""" try: tenant_id = get_tenant_id_for_email(impersonate_request.email) except exceptions.UserNotExists: detail = f"User has no tenant mapping: {impersonate_request.email=}" logger.warning(detail) raise HTTPException(status_code=422, detail=detail) with get_session_with_tenant(tenant_id=tenant_id) as tenant_session: user_to_impersonate = get_user_by_email( impersonate_request.email, tenant_session ) if user_to_impersonate is None: detail = ( f"User not found in tenant: {impersonate_request.email=} {tenant_id=}" ) logger.warning(detail) raise HTTPException(status_code=422, detail=detail) token = await get_redis_strategy().write_token(user_to_impersonate) response = await auth_backend.transport.get_login_response(token) response.set_cookie( key="fastapiusersauth", value=token, 
httponly=True, secure=True, samesite="lax", ) return response ================================================ FILE: backend/ee/onyx/server/tenants/anonymous_user_path.py ================================================ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.models import TenantAnonymousUserPath def get_anonymous_user_path(tenant_id: str, db_session: Session) -> str | None: result = db_session.execute( select(TenantAnonymousUserPath).where( TenantAnonymousUserPath.tenant_id == tenant_id ) ) result_scalar = result.scalar_one_or_none() if result_scalar: return result_scalar.anonymous_user_path else: return None def modify_anonymous_user_path( tenant_id: str, anonymous_user_path: str, db_session: Session ) -> None: # Enforce lowercase path at DB operation level anonymous_user_path = anonymous_user_path.lower() existing_entry = ( db_session.query(TenantAnonymousUserPath).filter_by(tenant_id=tenant_id).first() ) if existing_entry: existing_entry.anonymous_user_path = anonymous_user_path else: new_entry = TenantAnonymousUserPath( tenant_id=tenant_id, anonymous_user_path=anonymous_user_path ) db_session.add(new_entry) db_session.commit() def get_tenant_id_for_anonymous_user_path( anonymous_user_path: str, db_session: Session ) -> str | None: result = db_session.execute( select(TenantAnonymousUserPath).where( TenantAnonymousUserPath.anonymous_user_path == anonymous_user_path ) ) result_scalar = result.scalar_one_or_none() if result_scalar: return result_scalar.tenant_id else: return None def validate_anonymous_user_path(path: str) -> None: if not path or "/" in path or not path.replace("-", "").isalnum(): raise ValueError("Invalid path. 
Use only letters, numbers, and hyphens.") ================================================ FILE: backend/ee/onyx/server/tenants/anonymous_users_api.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Response from sqlalchemy.exc import IntegrityError from ee.onyx.auth.users import generate_anonymous_user_jwt_token from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path from ee.onyx.server.tenants.anonymous_user_path import ( get_tenant_id_for_anonymous_user_path, ) from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path from ee.onyx.server.tenants.models import AnonymousUserPath from onyx.auth.users import anonymous_user_enabled from onyx.auth.users import current_admin_user from onyx.auth.users import User from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME from onyx.db.engine.sql_engine import get_session_with_shared_schema from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/tenants") @router.get("/anonymous-user-path") async def get_anonymous_user_path_api( _: User = Depends(current_admin_user), ) -> AnonymousUserPath: tenant_id = get_current_tenant_id() if tenant_id is None: raise HTTPException(status_code=404, detail="Tenant not found") with get_session_with_shared_schema() as db_session: current_path = get_anonymous_user_path(tenant_id, db_session) return AnonymousUserPath(anonymous_user_path=current_path) @router.post("/anonymous-user-path") async def set_anonymous_user_path_api( anonymous_user_path: str, _: User = Depends(current_admin_user), ) -> None: tenant_id = get_current_tenant_id() try: 
validate_anonymous_user_path(anonymous_user_path) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) with get_session_with_shared_schema() as db_session: try: modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session) except IntegrityError: raise HTTPException( status_code=409, detail="The anonymous user path is already in use. Please choose a different path.", ) except Exception as e: logger.exception(f"Failed to modify anonymous user path: {str(e)}") raise HTTPException( status_code=500, detail="An unexpected error occurred while modifying the anonymous user path", ) @router.post("/anonymous-user") async def login_as_anonymous_user( anonymous_user_path: str, ) -> Response: with get_session_with_shared_schema() as db_session: tenant_id = get_tenant_id_for_anonymous_user_path( anonymous_user_path, db_session ) if not tenant_id: raise HTTPException(status_code=404, detail="Tenant not found") if not anonymous_user_enabled(tenant_id=tenant_id): raise HTTPException(status_code=403, detail="Anonymous user is not enabled") token = generate_anonymous_user_jwt_token(tenant_id) response = Response() response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME) response.set_cookie( key=ANONYMOUS_USER_COOKIE_NAME, value=token, httponly=True, secure=True, samesite="strict", ) return response ================================================ FILE: backend/ee/onyx/server/tenants/api.py ================================================ from fastapi import APIRouter from ee.onyx.server.tenants.admin_api import router as admin_router from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router from ee.onyx.server.tenants.billing_api import router as billing_router from ee.onyx.server.tenants.proxy import router as proxy_router from ee.onyx.server.tenants.team_membership_api import router as team_membership_router from ee.onyx.server.tenants.tenant_management_api import ( router as tenant_management_router, ) from 
ee.onyx.server.tenants.user_invitations_api import ( router as user_invitations_router, ) # Create a main router to include all sub-routers # Note: We don't add a prefix here as each router already has the /tenants prefix router = APIRouter() # Include all the individual routers router.include_router(admin_router) router.include_router(anonymous_users_router) router.include_router(billing_router) router.include_router(team_membership_router) router.include_router(tenant_management_router) router.include_router(user_invitations_router) router.include_router(proxy_router) ================================================ FILE: backend/ee/onyx/server/tenants/billing.py ================================================ from typing import cast from typing import Literal import requests import stripe from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY from ee.onyx.server.tenants.access import generate_data_plane_token from ee.onyx.server.tenants.models import BillingInformation from ee.onyx.server.tenants.models import SubscriptionStatusResponse from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.utils.logger import setup_logger stripe.api_key = STRIPE_SECRET_KEY logger = setup_logger() def fetch_stripe_checkout_session( tenant_id: str, billing_period: Literal["monthly", "annual"] = "monthly", seats: int | None = None, ) -> str: token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } url = f"{CONTROL_PLANE_API_BASE_URL}/create-checkout-session" payload = { "tenant_id": tenant_id, "billing_period": billing_period, "seats": seats, } response = requests.post(url, headers=headers, json=payload) if not response.ok: try: data = response.json() error_msg = ( data.get("error") or f"Request failed with status {response.status_code}" ) except (ValueError, requests.exceptions.JSONDecodeError): error_msg = f"Request failed with status {response.status_code}: {response.text[:200]}" raise 
Exception(error_msg) data = response.json() if data.get("error"): raise Exception(data["error"]) return data["sessionId"] def fetch_tenant_stripe_information(tenant_id: str) -> dict: token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } url = f"{CONTROL_PLANE_API_BASE_URL}/tenant-stripe-information" params = {"tenant_id": tenant_id} response = requests.get(url, headers=headers, params=params) response.raise_for_status() return response.json() def fetch_billing_information( tenant_id: str, ) -> BillingInformation | SubscriptionStatusResponse: token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } url = f"{CONTROL_PLANE_API_BASE_URL}/billing-information" params = {"tenant_id": tenant_id} response = requests.get(url, headers=headers, params=params) response.raise_for_status() response_data = response.json() # Check if the response indicates no subscription if ( isinstance(response_data, dict) and "subscribed" in response_data and not response_data["subscribed"] ): return SubscriptionStatusResponse(**response_data) # Otherwise, parse as BillingInformation return BillingInformation(**response_data) def fetch_customer_portal_session(tenant_id: str, return_url: str | None = None) -> str: """ Fetch a Stripe customer portal session URL from the control plane. NOTE: This is currently only used for multi-tenant (cloud) deployments. Self-hosted proxy endpoints will be added in a future phase. 
""" token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } url = f"{CONTROL_PLANE_API_BASE_URL}/create-customer-portal-session" payload = {"tenant_id": tenant_id} if return_url: payload["return_url"] = return_url response = requests.post(url, headers=headers, json=payload) response.raise_for_status() return response.json()["url"] def register_tenant_users(tenant_id: str, number_of_users: int) -> stripe.Subscription: """ Update the number of seats for a tenant's subscription. Preserves the existing price (monthly, annual, or grandfathered). """ response = fetch_tenant_stripe_information(tenant_id) stripe_subscription_id = cast(str, response.get("stripe_subscription_id")) subscription = stripe.Subscription.retrieve(stripe_subscription_id) subscription_item = subscription["items"]["data"][0] # Use existing price to preserve the customer's current plan current_price_id = subscription_item.price.id updated_subscription = stripe.Subscription.modify( stripe_subscription_id, items=[ { "id": subscription_item.id, "price": current_price_id, "quantity": number_of_users, } ], metadata={"tenant_id": str(tenant_id)}, ) return updated_subscription ================================================ FILE: backend/ee/onyx/server/tenants/billing_api.py ================================================ """Billing API endpoints for cloud multi-tenant deployments. DEPRECATED: These /tenants/* billing endpoints are being replaced by /admin/billing/* which provides a unified API for both self-hosted and cloud deployments. TODO(ENG-3533): Migrate frontend to use /admin/billing/* endpoints and remove this file. 
https://linear.app/onyx-app/issue/ENG-3533/migrate-tenantsbilling-adminbilling Current endpoints to migrate: - GET /tenants/billing-information -> GET /admin/billing/information - POST /tenants/create-customer-portal-session -> POST /admin/billing/portal-session - POST /tenants/create-subscription-session -> POST /admin/billing/checkout-session - GET /tenants/stripe-publishable-key -> (keep as-is, shared endpoint) Note: /tenants/product-gating/* endpoints are control-plane-to-data-plane calls and are NOT part of this migration - they stay here. """ import asyncio import httpx from fastapi import APIRouter from fastapi import Depends from ee.onyx.auth.users import current_admin_user from ee.onyx.server.tenants.access import control_plane_dep from ee.onyx.server.tenants.billing import fetch_billing_information from ee.onyx.server.tenants.billing import fetch_customer_portal_session from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session from ee.onyx.server.tenants.models import BillingInformation from ee.onyx.server.tenants.models import CreateCheckoutSessionRequest from ee.onyx.server.tenants.models import CreateSubscriptionSessionRequest from ee.onyx.server.tenants.models import ProductGatingFullSyncRequest from ee.onyx.server.tenants.models import ProductGatingRequest from ee.onyx.server.tenants.models import ProductGatingResponse from ee.onyx.server.tenants.models import StripePublishableKeyResponse from ee.onyx.server.tenants.models import SubscriptionSessionResponse from ee.onyx.server.tenants.models import SubscriptionStatusResponse from ee.onyx.server.tenants.product_gating import overwrite_full_gated_set from ee.onyx.server.tenants.product_gating import store_product_gating from onyx.auth.users import User from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_OVERRIDE from onyx.configs.app_configs import STRIPE_PUBLISHABLE_KEY_URL from onyx.configs.app_configs import WEB_DOMAIN from onyx.error_handling.error_codes import OnyxErrorCode 
from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/tenants") # Cache for Stripe publishable key to avoid hitting S3 on every request _stripe_publishable_key_cache: str | None = None _stripe_key_lock = asyncio.Lock() @router.post("/product-gating") def gate_product( product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep) ) -> ProductGatingResponse: """ Gating the product means that the product is not available to the tenant. They will be directed to the billing page. We gate the product when their subscription has ended. """ try: store_product_gating( product_gating_request.tenant_id, product_gating_request.application_status ) return ProductGatingResponse(updated=True, error=None) except Exception as e: logger.exception("Failed to gate product") return ProductGatingResponse(updated=False, error=str(e)) @router.post("/product-gating/full-sync") def gate_product_full_sync( product_gating_request: ProductGatingFullSyncRequest, _: None = Depends(control_plane_dep), ) -> ProductGatingResponse: """ Bulk operation to overwrite the entire gated tenant set. This replaces all currently gated tenants with the provided list. Gated tenants are not available to access the product and will be directed to the billing page when their subscription has ended. 
""" try: overwrite_full_gated_set(product_gating_request.gated_tenant_ids) return ProductGatingResponse(updated=True, error=None) except Exception as e: logger.exception("Failed to gate products during full sync") return ProductGatingResponse(updated=False, error=str(e)) @router.get("/billing-information") async def billing_information( _: User = Depends(current_admin_user), ) -> BillingInformation | SubscriptionStatusResponse: logger.info("Fetching billing information") tenant_id = get_current_tenant_id() return fetch_billing_information(tenant_id) @router.post("/create-customer-portal-session") async def create_customer_portal_session( _: User = Depends(current_admin_user), ) -> dict: """Create a Stripe customer portal session via the control plane.""" tenant_id = get_current_tenant_id() return_url = f"{WEB_DOMAIN}/admin/billing" try: portal_url = fetch_customer_portal_session(tenant_id, return_url) return {"stripe_customer_portal_url": portal_url} except OnyxError: raise except Exception: logger.exception("Failed to create customer portal session") raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Failed to create customer portal session", ) @router.post("/create-checkout-session") async def create_checkout_session( request: CreateCheckoutSessionRequest | None = None, _: User = Depends(current_admin_user), ) -> dict: """Create a Stripe checkout session via the control plane.""" tenant_id = get_current_tenant_id() billing_period = request.billing_period if request else "monthly" seats = request.seats if request else None try: checkout_url = fetch_stripe_checkout_session(tenant_id, billing_period, seats) return {"stripe_checkout_url": checkout_url} except OnyxError: raise except Exception: logger.exception("Failed to create checkout session") raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Failed to create checkout session", ) @router.post("/create-subscription-session") async def create_subscription_session( request: CreateSubscriptionSessionRequest | None = None, 
_: User = Depends(current_admin_user), ) -> SubscriptionSessionResponse: try: tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if not tenant_id: raise OnyxError(OnyxErrorCode.VALIDATION_ERROR, "Tenant ID not found") billing_period = request.billing_period if request else "monthly" session_id = fetch_stripe_checkout_session(tenant_id, billing_period) return SubscriptionSessionResponse(sessionId=session_id) except OnyxError: raise except Exception: logger.exception("Failed to create subscription session") raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Failed to create subscription session", ) @router.get("/stripe-publishable-key") async def get_stripe_publishable_key() -> StripePublishableKeyResponse: """ Fetch the Stripe publishable key. Priority: env var override (for testing) > S3 bucket (production). This endpoint is public (no auth required) since publishable keys are safe to expose. The key is cached in memory to avoid hitting S3 on every request. """ global _stripe_publishable_key_cache # Fast path: return cached value without lock if _stripe_publishable_key_cache: return StripePublishableKeyResponse( publishable_key=_stripe_publishable_key_cache ) # Use lock to prevent concurrent S3 requests async with _stripe_key_lock: # Double-check after acquiring lock (another request may have populated cache) if _stripe_publishable_key_cache: return StripePublishableKeyResponse( publishable_key=_stripe_publishable_key_cache ) # Check for env var override first (for local testing with pk_test_* keys) if STRIPE_PUBLISHABLE_KEY_OVERRIDE: key = STRIPE_PUBLISHABLE_KEY_OVERRIDE.strip() if not key.startswith("pk_"): raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) # Fall back to S3 bucket if not STRIPE_PUBLISHABLE_KEY_URL: raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Stripe publishable key is not configured", ) try: async with 
httpx.AsyncClient() as client: response = await client.get(STRIPE_PUBLISHABLE_KEY_URL) response.raise_for_status() key = response.text.strip() # Validate key format if not key.startswith("pk_"): raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Invalid Stripe publishable key format", ) _stripe_publishable_key_cache = key return StripePublishableKeyResponse(publishable_key=key) except httpx.HTTPError: raise OnyxError( OnyxErrorCode.INTERNAL_ERROR, "Failed to fetch Stripe publishable key", ) ================================================ FILE: backend/ee/onyx/server/tenants/models.py ================================================ from datetime import datetime from typing import Literal from pydantic import BaseModel from onyx.server.settings.models import ApplicationStatus class CheckoutSessionCreationRequest(BaseModel): quantity: int class CreateTenantRequest(BaseModel): tenant_id: str initial_admin_email: str class ProductGatingRequest(BaseModel): tenant_id: str application_status: ApplicationStatus class ProductGatingFullSyncRequest(BaseModel): gated_tenant_ids: list[str] class SubscriptionStatusResponse(BaseModel): subscribed: bool class BillingInformation(BaseModel): stripe_subscription_id: str status: str current_period_start: datetime current_period_end: datetime number_of_seats: int cancel_at_period_end: bool canceled_at: datetime | None trial_start: datetime | None trial_end: datetime | None seats: int payment_method_enabled: bool class CreateCheckoutSessionRequest(BaseModel): billing_period: Literal["monthly", "annual"] = "monthly" seats: int | None = None email: str | None = None class CheckoutSessionCreationResponse(BaseModel): id: str class ImpersonateRequest(BaseModel): email: str class TenantCreationPayload(BaseModel): tenant_id: str email: str referral_source: str | None = None class TenantDeletionPayload(BaseModel): tenant_id: str email: str class AnonymousUserPath(BaseModel): anonymous_user_path: str | None class ProductGatingResponse(BaseModel): 
updated: bool error: str | None class SubscriptionSessionResponse(BaseModel): sessionId: str class CreateSubscriptionSessionRequest(BaseModel): """Request to create a subscription checkout session.""" billing_period: Literal["monthly", "annual"] = "monthly" class TenantByDomainResponse(BaseModel): tenant_id: str number_of_users: int creator_email: str class TenantByDomainRequest(BaseModel): email: str class RequestInviteRequest(BaseModel): tenant_id: str class RequestInviteResponse(BaseModel): success: bool message: str class PendingUserSnapshot(BaseModel): email: str class ApproveUserRequest(BaseModel): email: str class StripePublishableKeyResponse(BaseModel): publishable_key: str ================================================ FILE: backend/ee/onyx/server/tenants/product_gating.py ================================================ from typing import cast from ee.onyx.configs.app_configs import GATED_TENANTS_KEY from onyx.configs.constants import ONYX_CLOUD_TENANT_ID from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.server.settings.models import ApplicationStatus from onyx.server.settings.store import load_settings from onyx.server.settings.store import store_settings from onyx.utils.logger import setup_logger from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() def update_tenant_gating(tenant_id: str, status: ApplicationStatus) -> None: redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID) # Maintain the GATED_ACCESS set if status == ApplicationStatus.GATED_ACCESS: redis_client.sadd(GATED_TENANTS_KEY, tenant_id) else: redis_client.srem(GATED_TENANTS_KEY, tenant_id) def store_product_gating(tenant_id: str, application_status: ApplicationStatus) -> None: token = None try: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) settings = load_settings() settings.application_status = application_status store_settings(settings) # Store gated tenant information in
Redis update_tenant_gating(tenant_id, application_status) except Exception: logger.exception("Failed to gate product") raise finally: if token is not None: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def overwrite_full_gated_set(tenant_ids: list[str]) -> None: redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID) pipeline = redis_client.pipeline() # using pipeline doesn't automatically add the tenant_id prefix full_gated_set_key = f"{ONYX_CLOUD_TENANT_ID}:{GATED_TENANTS_KEY}" # Clear the existing set pipeline.delete(full_gated_set_key) # Add all tenant IDs to the set for tenant_id in tenant_ids: pipeline.sadd(full_gated_set_key, tenant_id) # Execute all commands at once pipeline.execute() def get_gated_tenants() -> set[str]: redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID) gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY)) return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes} def is_tenant_gated(tenant_id: str) -> bool: """Fast O(1) check if tenant is in gated set (multi-tenant only).""" redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID) return bool(redis_client.sismember(GATED_TENANTS_KEY, tenant_id)) ================================================ FILE: backend/ee/onyx/server/tenants/provisioning.py ================================================ import asyncio import uuid import aiohttp # Async HTTP client import httpx import requests from fastapi import HTTPException from fastapi import Request from sqlalchemy import select from sqlalchemy.orm import Session from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL from ee.onyx.server.tenants.access import generate_data_plane_token from ee.onyx.server.tenants.models import TenantByDomainResponse from ee.onyx.server.tenants.models import TenantCreationPayload from ee.onyx.server.tenants.models import TenantDeletionPayload from ee.onyx.server.tenants.schema_management import
create_schema_if_not_exists from ee.onyx.server.tenants.schema_management import drop_schema from ee.onyx.server.tenants.schema_management import run_alembic_migrations from ee.onyx.server.tenants.user_mapping import add_users_to_tenant from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant from onyx.auth.users import exceptions from onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY from onyx.configs.app_configs import COHERE_DEFAULT_API_KEY from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.configs.app_configs import DEV_MODE from onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY from onyx.configs.app_configs import OPENROUTER_DEFAULT_API_KEY from onyx.configs.app_configs import VERTEXAI_DEFAULT_CREDENTIALS from onyx.configs.app_configs import VERTEXAI_DEFAULT_LOCATION from onyx.db.engine.sql_engine import get_session_with_shared_schema from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.image_generation import create_default_image_gen_config_from_api_key from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import update_default_provider from onyx.db.llm import upsert_cloud_embedding_provider from onyx.db.llm import upsert_llm_provider from onyx.db.models import AvailableTenant from onyx.db.models import IndexModelStatus from onyx.db.models import SearchSettings from onyx.db.models import UserTenantMapping from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG from onyx.llm.well_known_providers.constants 
import VERTEXAI_PROVIDER_NAME from onyx.llm.well_known_providers.llm_provider_options import ( get_recommendations, ) from onyx.llm.well_known_providers.llm_provider_options import ( model_configurations_for_provider, ) from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest from onyx.setup import setup_onyx from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.enums import EmbeddingProvider logger = setup_logger() async def get_or_provision_tenant( email: str, referral_source: str | None = None, request: Request | None = None, ) -> str: """ Get existing tenant ID for an email or create a new tenant if none exists. This function should only be called after we have verified we want this user's tenant to exist. It returns the tenant ID associated with the email, creating a new tenant if necessary. """ # Early return for non-multi-tenant mode if not MULTI_TENANT: return POSTGRES_DEFAULT_SCHEMA if referral_source and request: await submit_to_hubspot(email, referral_source, request) # First, check if the user already has a tenant tenant_id: str | None = None try: tenant_id = get_tenant_id_for_email(email) return tenant_id except exceptions.UserNotExists: # User doesn't exist, so we need to create a new tenant or assign an existing one pass try: # Try to get a pre-provisioned tenant tenant_id = await get_available_tenant() if tenant_id: # Run migrations to ensure the pre-provisioned tenant schema is current. # Pool tenants may have been created before a new migration was deployed. # Capture as a non-optional local so mypy can type the lambda correctly. 
_tenant_id: str = tenant_id loop = asyncio.get_running_loop() try: await loop.run_in_executor( None, lambda: run_alembic_migrations(_tenant_id) ) except Exception: # The tenant was already dequeued from the pool — roll it back so # it doesn't end up orphaned (schema exists, but not assigned to anyone). logger.exception( f"Migration failed for pre-provisioned tenant {_tenant_id}; rolling back" ) try: await rollback_tenant_provisioning(_tenant_id) except Exception: logger.exception(f"Failed to rollback orphaned tenant {_tenant_id}") raise # If we have a pre-provisioned tenant, assign it to the user await assign_tenant_to_user(tenant_id, email, referral_source) logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}") else: # If no pre-provisioned tenant is available, create a new one on-demand tenant_id = await create_tenant(email, referral_source) # Notify control plane if we have created / assigned a new tenant if not DEV_MODE: await notify_control_plane(tenant_id, email, referral_source) return tenant_id except Exception as e: # If we've encountered an error, log and raise an exception error_msg = "Failed to provision tenant" logger.error(error_msg, exc_info=e) raise HTTPException( status_code=500, detail="Failed to provision tenant. Please try again later.", ) async def create_tenant( email: str, referral_source: str | None = None, # noqa: ARG001 ) -> str: """ Create a new tenant on-demand when no pre-provisioned tenants are available. This is the fallback method when we can't use a pre-provisioned tenant. 
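Illustrative sketch of where this fallback sits in the provisioning order. The order is: existing mapping, then the pre-provisioned pool, then on-demand creation with rollback. `InMemoryTenantStore`, `get_or_provision`, and the "tenant_" prefix are hypothetical stand-ins for the database tables, get_or_provision_tenant, and TENANT_ID_PREFIX. The sketch omits migrations, control plane notification, and async I/O, so it is a simplified model rather than this module's behavior.

```python
import uuid

# Hypothetical prefix; the real value is TENANT_ID_PREFIX from shared_configs.
TENANT_ID_PREFIX = "tenant_"


class InMemoryTenantStore:
    # Hypothetical stand-in for the user-tenant mapping, the pool table,
    # and schema creation.
    def __init__(self, pool, fail_setup=False):
        self.mapping = {}  # email -> tenant_id
        self.pool = list(pool)  # pre-provisioned tenant IDs, oldest first
        self.schemas = set()  # tenant IDs whose schema exists
        self.fail_setup = fail_setup

    def rollback(self, tenant_id):
        # Same idea as rollback_tenant_provisioning: drop schema and mappings.
        self.schemas.discard(tenant_id)
        self.mapping = {e: t for e, t in self.mapping.items() if t != tenant_id}


def get_or_provision(store, email):
    # 1. Reuse an existing assignment.
    if email in store.mapping:
        return store.mapping[email]
    # 2. Prefer a pre-provisioned tenant from the pool.
    if store.pool:
        tenant_id = store.pool.pop(0)
        store.mapping[email] = tenant_id
        return tenant_id
    # 3. Fallback: create a tenant on demand with a UUID-based ID.
    tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
    try:
        store.schemas.add(tenant_id)
        if store.fail_setup:
            raise RuntimeError("tenant setup failed")
        store.mapping[email] = tenant_id
    except Exception:
        # Undo partial work so no orphaned schema is left behind.
        store.rollback(tenant_id)
        raise
    return tenant_id
```

The rollback-then-reraise shape mirrors create_tenant: cleanup runs first, and the original failure still propagates to the caller.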
""" tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4()) logger.info(f"Creating new tenant {tenant_id} for user {email}") try: # Provision tenant on data plane await provision_tenant(tenant_id, email) except Exception as e: logger.exception(f"Tenant provisioning failed: {str(e)}") # Attempt to rollback the tenant provisioning try: await rollback_tenant_provisioning(tenant_id) except Exception: logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}") raise HTTPException(status_code=500, detail="Failed to provision tenant.") return tenant_id async def provision_tenant(tenant_id: str, email: str) -> None: if not MULTI_TENANT: raise HTTPException(status_code=403, detail="Multi-tenancy is not enabled") if user_owns_a_tenant(email): raise HTTPException( status_code=409, detail="User already belongs to an organization" ) logger.debug(f"Provisioning tenant {tenant_id} for user {email}") try: # Create the schema for the tenant if not create_schema_if_not_exists(tenant_id): logger.debug(f"Created schema for tenant {tenant_id}") else: logger.debug(f"Schema already exists for tenant {tenant_id}") # Set up the tenant with all necessary configurations await setup_tenant(tenant_id) # Assign the tenant to the user await assign_tenant_to_user(tenant_id, email) except Exception as e: logger.exception(f"Failed to create tenant {tenant_id}") raise HTTPException( status_code=500, detail=f"Failed to create tenant: {str(e)}" ) async def notify_control_plane( tenant_id: str, email: str, referral_source: str | None = None ) -> None: logger.info("Fetching billing information") token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } payload = TenantCreationPayload( tenant_id=tenant_id, email=email, referral_source=referral_source ) async with aiohttp.ClientSession() as session: async with session.post( f"{CONTROL_PLANE_API_BASE_URL}/tenants/create", headers=headers, json=payload.model_dump(), ) as response: if 
response.status != 200: error_text = await response.text() logger.error(f"Control plane tenant creation failed: {error_text}") raise Exception( f"Failed to create tenant on control plane: {error_text}" ) async def rollback_tenant_provisioning(tenant_id: str) -> None: """ Logic to rollback tenant provisioning on data plane. Handles each step independently to ensure maximum cleanup even if some steps fail. """ logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}") # Track if any part of the rollback fails rollback_errors = [] # 1. Try to drop the tenant's schema try: drop_schema(tenant_id) logger.info(f"Successfully dropped schema for tenant {tenant_id}") except Exception as e: error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}" logger.error(error_msg) rollback_errors.append(error_msg) # 2. Try to remove tenant mapping try: with get_session_with_shared_schema() as db_session: db_session.begin() try: db_session.query(UserTenantMapping).filter( UserTenantMapping.tenant_id == tenant_id ).delete() db_session.commit() logger.info( f"Successfully removed user mappings for tenant {tenant_id}" ) except Exception as e: db_session.rollback() raise e except Exception as e: error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}" logger.error(error_msg) rollback_errors.append(error_msg) # 3. 
If this tenant was in the available tenants table, remove it try: with get_session_with_shared_schema() as db_session: db_session.begin() try: available_tenant = ( db_session.query(AvailableTenant) .filter(AvailableTenant.tenant_id == tenant_id) .first() ) if available_tenant: db_session.delete(available_tenant) db_session.commit() logger.info( f"Removed tenant {tenant_id} from available tenants table" ) except Exception as e: db_session.rollback() raise e except Exception as e: error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}" logger.error(error_msg) rollback_errors.append(error_msg) # Log summary of rollback operation if rollback_errors: logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors") else: logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}") def _build_model_configuration_upsert_requests( provider_name: str, recommendations: LLMRecommendations, ) -> list[ModelConfigurationUpsertRequest]: model_configurations = model_configurations_for_provider( provider_name, recommendations ) return [ ModelConfigurationUpsertRequest( name=model_configuration.name, is_visible=model_configuration.is_visible, max_input_tokens=model_configuration.max_input_tokens, supports_image_input=model_configuration.supports_image_input, ) for model_configuration in model_configurations ] def configure_default_api_keys(db_session: Session) -> None: """Configure default LLM providers using recommended-models.json for model selection.""" # Load recommendations from JSON config recommendations = get_recommendations() has_set_default_provider = False def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None: nonlocal has_set_default_provider try: existing = fetch_existing_llm_provider( name=request.name, db_session=db_session ) if existing: request.id = existing.id provider = upsert_llm_provider(request, db_session) if not has_set_default_provider: update_default_provider(provider.id, 
default_model, db_session) has_set_default_provider = True except Exception as e: logger.error(f"Failed to configure {request.provider} provider: {e}") # Configure OpenAI provider if OPENAI_DEFAULT_API_KEY: default_model = recommendations.get_default_model(OPENAI_PROVIDER_NAME) if default_model is None: logger.error( f"No default model found for {OPENAI_PROVIDER_NAME} in recommendations" ) default_model_name = default_model.name if default_model else "gpt-5.2" openai_provider = LLMProviderUpsertRequest( name="OpenAI", provider=OPENAI_PROVIDER_NAME, api_key=OPENAI_DEFAULT_API_KEY, model_configurations=_build_model_configuration_upsert_requests( OPENAI_PROVIDER_NAME, recommendations ), api_key_changed=True, is_auto_mode=True, ) _upsert(openai_provider, default_model_name) # Create default image generation config using the OpenAI API key try: create_default_image_gen_config_from_api_key( db_session, OPENAI_DEFAULT_API_KEY ) except Exception as e: logger.error(f"Failed to create default image gen config: {e}") else: logger.info( "OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration" ) # Configure Anthropic provider if ANTHROPIC_DEFAULT_API_KEY: default_model = recommendations.get_default_model(ANTHROPIC_PROVIDER_NAME) if default_model is None: logger.error( f"No default model found for {ANTHROPIC_PROVIDER_NAME} in recommendations" ) default_model_name = ( default_model.name if default_model else "claude-sonnet-4-5" ) anthropic_provider = LLMProviderUpsertRequest( name="Anthropic", provider=ANTHROPIC_PROVIDER_NAME, api_key=ANTHROPIC_DEFAULT_API_KEY, model_configurations=_build_model_configuration_upsert_requests( ANTHROPIC_PROVIDER_NAME, recommendations ), api_key_changed=True, is_auto_mode=True, ) _upsert(anthropic_provider, default_model_name) else: logger.info( "ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration" ) # Configure Vertex AI provider if VERTEXAI_DEFAULT_CREDENTIALS: default_model = 
recommendations.get_default_model(VERTEXAI_PROVIDER_NAME) if default_model is None: logger.error( f"No default model found for {VERTEXAI_PROVIDER_NAME} in recommendations" ) default_model_name = default_model.name if default_model else "gemini-2.5-pro" # Vertex AI uses custom_config for credentials and location custom_config = { VERTEX_CREDENTIALS_FILE_KWARG: VERTEXAI_DEFAULT_CREDENTIALS, VERTEX_LOCATION_KWARG: VERTEXAI_DEFAULT_LOCATION, } vertexai_provider = LLMProviderUpsertRequest( name="Google Vertex AI", provider=VERTEXAI_PROVIDER_NAME, custom_config=custom_config, model_configurations=_build_model_configuration_upsert_requests( VERTEXAI_PROVIDER_NAME, recommendations ), api_key_changed=True, is_auto_mode=True, ) _upsert(vertexai_provider, default_model_name) else: logger.info( "VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration" ) # Configure OpenRouter provider if OPENROUTER_DEFAULT_API_KEY: default_model = recommendations.get_default_model(OPENROUTER_PROVIDER_NAME) if default_model is None: logger.error( f"No default model found for {OPENROUTER_PROVIDER_NAME} in recommendations" ) default_model_name = default_model.name if default_model else "z-ai/glm-4.7" # For OpenRouter, we use the visible models from recommendations as model_configurations # since OpenRouter models are dynamic (fetched from their API) visible_models = recommendations.get_visible_models(OPENROUTER_PROVIDER_NAME) model_configurations = [ ModelConfigurationUpsertRequest( name=model.name, is_visible=True, max_input_tokens=None, display_name=model.display_name, ) for model in visible_models ] openrouter_provider = LLMProviderUpsertRequest( name="OpenRouter", provider=OPENROUTER_PROVIDER_NAME, api_key=OPENROUTER_DEFAULT_API_KEY, model_configurations=model_configurations, api_key_changed=True, is_auto_mode=True, ) _upsert(openrouter_provider, default_model_name) else: logger.info( "OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration" ) # 
Configure Cohere embedding provider if COHERE_DEFAULT_API_KEY: cloud_embedding_provider = CloudEmbeddingProviderCreationRequest( provider_type=EmbeddingProvider.COHERE, api_key=COHERE_DEFAULT_API_KEY, ) try: logger.info("Attempting to upsert Cohere cloud embedding provider") upsert_cloud_embedding_provider(db_session, cloud_embedding_provider) logger.info("Successfully upserted Cohere cloud embedding provider") logger.info("Updating search settings with Cohere embedding model details") query = ( select(SearchSettings) .where(SearchSettings.status == IndexModelStatus.FUTURE) .order_by(SearchSettings.id.desc()) ) result = db_session.execute(query) current_search_settings = result.scalars().first() if current_search_settings: current_search_settings.model_name = ( "embed-english-v3.0" # Cohere's latest model as of now ) current_search_settings.model_dim = ( 1024 # Cohere's embed-english-v3.0 dimension ) current_search_settings.provider_type = EmbeddingProvider.COHERE current_search_settings.index_name = ( "danswer_chunk_cohere_embed_english_v3_0" ) current_search_settings.query_prefix = "" current_search_settings.passage_prefix = "" db_session.commit() else: raise RuntimeError( "No search settings specified, DB is not in a valid state" ) logger.info("Fetching updated search settings to verify changes") updated_query = ( select(SearchSettings) .where(SearchSettings.status == IndexModelStatus.PRESENT) .order_by(SearchSettings.id.desc()) ) updated_result = db_session.execute(updated_query) updated_result.scalars().first() except Exception: logger.exception("Failed to configure Cohere embedding provider") else: logger.info( "COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration" ) async def submit_to_hubspot( email: str, referral_source: str | None, request: Request ) -> None: if not HUBSPOT_TRACKING_URL: logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission") return # HubSpot tracking cookie hubspot_cookie = 
request.cookies.get("hubspotutk") # IP address ip_address = request.client.host if request.client else None data = { "fields": [ {"name": "email", "value": email}, {"name": "referral_source", "value": referral_source or ""}, ], "context": { "hutk": hubspot_cookie, "ipAddress": ip_address, "pageUri": str(request.url), "pageName": "User Registration", }, } async with httpx.AsyncClient() as client: response = await client.post(HUBSPOT_TRACKING_URL, json=data) if response.status_code != 200: logger.error(f"Failed to submit to HubSpot: {response.text}") async def delete_user_from_control_plane(tenant_id: str, email: str) -> None: token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } payload = TenantDeletionPayload(tenant_id=tenant_id, email=email) async with aiohttp.ClientSession() as session: async with session.delete( f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete", headers=headers, json=payload.model_dump(), ) as response: if response.status != 200: error_text = await response.text() logger.error(f"Control plane tenant deletion failed: {error_text}") raise Exception( f"Failed to delete tenant on control plane: {error_text}" ) def get_tenant_by_domain_from_control_plane( domain: str, tenant_id: str, ) -> TenantByDomainResponse | None: """ Fetches tenant information from the control plane based on the email domain.
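Illustrative sketch of how a caller might derive the domain argument from an email address, and how the lookup response maps onto TenantByDomainResponse. `email_domain`, `parse_lookup`, and the `TenantByDomain` dataclass are hypothetical helpers. Lowercasing the domain is an assumption about what the control plane expects.

```python
from dataclasses import dataclass


@dataclass
class TenantByDomain:
    # Mirrors the fields of TenantByDomainResponse in models.py.
    tenant_id: str
    number_of_users: int
    creator_email: str


def email_domain(email):
    # Hypothetical helper: keep only the part after the last '@'.
    local, sep, domain = email.rpartition("@")
    if not sep or not local or not domain:
        raise ValueError(f"not an email address: {email!r}")
    return domain.lower()


def parse_lookup(status_code, data):
    # A non-200 status and an empty body both mean "no tenant found".
    if status_code != 200 or not data:
        return None
    return TenantByDomain(
        tenant_id=data.get("tenant_id"),
        number_of_users=data.get("number_of_users"),
        creator_email=data.get("creator_email"),
    )
```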
Args: domain: The email domain to search for (e.g., "example.com") tenant_id: Tenant ID sent to the control plane along with the domain Returns: A TenantByDomainResponse if a matching tenant is found, None otherwise (including when the request fails) """ token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } try: response = requests.get( f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain", headers=headers, json={"domain": domain, "tenant_id": tenant_id}, ) if response.status_code != 200: logger.error(f"Control plane tenant lookup failed: {response.text}") return None response_data = response.json() if not response_data: return None return TenantByDomainResponse( tenant_id=response_data.get("tenant_id"), number_of_users=response_data.get("number_of_users"), creator_email=response_data.get("creator_email"), ) except Exception as e: logger.error(f"Error fetching tenant by domain: {str(e)}") return None async def get_available_tenant() -> str | None: """ Get an available pre-provisioned tenant from the AvailableTenant table. Returns the tenant_id if one is available, None otherwise. Uses row-level locking to prevent race conditions when multiple processes try to get an available tenant simultaneously.
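Illustrative sketch of the FOR UPDATE SKIP LOCKED semantics this query relies on, modelled with threads instead of Postgres. Each row gets its own non-blocking lock, and a worker skips any row another worker currently holds instead of waiting. A claimed row is deleted, so it is never handed out twice. `Row` and `Pool` are hypothetical; the real code delegates all of this to the database via with_for_update(skip_locked=True).

```python
import threading


class Row:
    def __init__(self, tenant_id, date_created):
        self.tenant_id = tenant_id
        self.date_created = date_created
        self.lock = threading.Lock()  # stands in for the Postgres row lock
        self.deleted = False


class Pool:
    def __init__(self, rows):
        self.rows = rows
        self.table_lock = threading.Lock()  # guards the list itself

    def claim_oldest(self):
        # Roughly: ORDER BY date_created FOR UPDATE SKIP LOCKED LIMIT 1
        with self.table_lock:
            candidates = sorted(self.rows, key=lambda r: r.date_created)
        for row in candidates:
            if not row.lock.acquire(blocking=False):
                continue  # another worker holds this row: skip, don't wait
            try:
                if row.deleted:
                    continue  # already claimed by a worker that committed
                row.deleted = True  # DELETE, then COMMIT
                with self.table_lock:
                    self.rows.remove(row)
                return row.tenant_id
            finally:
                row.lock.release()
        return None  # pool empty: caller falls back to on-demand creation
```

Skipping instead of blocking is what lets many signups run concurrently: no worker ever waits on another worker's row, and an empty result simply triggers the on-demand path.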
""" if not MULTI_TENANT: return None with get_session_with_shared_schema() as db_session: try: db_session.begin() # Get the oldest available tenant with FOR UPDATE lock to prevent race conditions available_tenant = ( db_session.query(AvailableTenant) .order_by(AvailableTenant.date_created) .with_for_update(skip_locked=True) # Skip locked rows to avoid blocking .first() ) if available_tenant: tenant_id = available_tenant.tenant_id # Remove the tenant from the available tenants table db_session.delete(available_tenant) db_session.commit() logger.info(f"Using pre-provisioned tenant {tenant_id}") return tenant_id else: db_session.rollback() return None except Exception: logger.exception("Error getting available tenant") db_session.rollback() return None async def setup_tenant(tenant_id: str) -> None: """ Set up a tenant with all necessary configurations. This is a centralized function that handles all tenant setup logic. """ token = None try: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) # Run Alembic migrations in a way that isolates it from the current event loop # Create a new event loop for this synchronous operation loop = asyncio.get_event_loop() # Use run_in_executor which properly isolates the thread execution await loop.run_in_executor(None, lambda: run_alembic_migrations(tenant_id)) # Configure the tenant with default settings with get_session_with_tenant(tenant_id=tenant_id) as db_session: # Configure default API keys configure_default_api_keys(db_session) # Set up Onyx with appropriate settings current_search_settings = ( db_session.query(SearchSettings) .filter_by(status=IndexModelStatus.FUTURE) .first() ) cohere_enabled = ( current_search_settings is not None and current_search_settings.provider_type == EmbeddingProvider.COHERE ) setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled) except Exception as e: logger.exception(f"Failed to set up tenant {tenant_id}") raise e finally: if token is not None: 
CURRENT_TENANT_ID_CONTEXTVAR.reset(token) async def assign_tenant_to_user( tenant_id: str, email: str, referral_source: str | None = None, # noqa: ARG001 ) -> None: """ Assign a tenant to a user by adding the user to the tenant's user mapping. This function does not notify the control plane; callers such as get_or_provision_tenant do that separately. """ # Add the user to the tenant try: add_users_to_tenant([email], tenant_id) except Exception: logger.exception(f"Failed to assign tenant {tenant_id} to user {email}") raise Exception("Failed to assign tenant to user") ================================================ FILE: backend/ee/onyx/server/tenants/proxy.py ================================================ """Proxy endpoints for billing operations. These endpoints run on the CLOUD DATA PLANE (cloud.onyx.app) and serve as a proxy for self-hosted instances to reach the control plane. Flow: Self-hosted backend → Cloud DP /proxy/* (license auth) → Control plane (JWT auth) Self-hosted instances call these endpoints with their license in the Authorization header. The cloud data plane validates the license signature and forwards the request to the control plane using JWT authentication.
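Illustrative sketch of the header rewrite described above. The license arrives as a Bearer token and is verified on the data plane; only the data plane's own JWT is sent upstream. `rewrite_headers` is hypothetical, and the `verify_license` and `make_jwt` callables stand in for verify_license_signature and generate_data_plane_token.

```python
def rewrite_headers(incoming, verify_license, make_jwt):
    # incoming: request headers from the self-hosted instance (a plain dict).
    auth = incoming.get("Authorization", "")
    if not auth.startswith("Bearer "):
        raise PermissionError("missing or invalid authorization header")
    license_blob = auth.split(" ", 1)[1]
    payload = verify_license(license_blob)  # expected to raise on a bad signature
    # The license is consumed here; only the data plane JWT goes upstream.
    outgoing = {
        "Authorization": f"Bearer {make_jwt()}",
        "Content-Type": "application/json",
    }
    return payload, outgoing
```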
Auth levels by endpoint: - /create-checkout-session: No auth (new customer) or expired license OK (renewal) - /claim-license: Session ID based (one-time after Stripe payment) - /create-customer-portal-session: Expired license OK (need portal to fix payment) - /billing-information: Valid license required - /license/{tenant_id}: Valid license required - /seats/update: Valid license required """ from typing import Literal import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import Header from fastapi import HTTPException from pydantic import BaseModel from ee.onyx.configs.app_configs import LICENSE_ENFORCEMENT_ENABLED from ee.onyx.server.billing.models import SeatUpdateRequest from ee.onyx.server.billing.models import SeatUpdateResponse from ee.onyx.server.license.models import LicensePayload from ee.onyx.server.tenants.access import generate_data_plane_token from ee.onyx.utils.license import is_license_valid from ee.onyx.utils.license import verify_license_signature from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL from onyx.utils.logger import setup_logger logger = setup_logger() router = APIRouter(prefix="/proxy") def _check_license_enforcement_enabled() -> None: """Ensure LICENSE_ENFORCEMENT_ENABLED is true (proxy endpoints only work on cloud DP).""" if not LICENSE_ENFORCEMENT_ENABLED: raise HTTPException( status_code=501, detail="Proxy endpoints are only available on cloud data plane", ) def _extract_license_from_header( authorization: str | None, required: bool = True, ) -> str | None: """Extract license data from Authorization header. Self-hosted instances authenticate to these proxy endpoints by sending their license as a Bearer token: `Authorization: Bearer <license_data>`. We use the Bearer scheme (RFC 6750) because: 1. It's the standard HTTP auth scheme for token-based authentication 2. The license blob is cryptographically signed (RSA), so it's self-validating 3. No other auth schemes (Basic, Digest, etc.)
are supported for license auth The license data is the base64-encoded signed blob that contains tenant_id, seats, expiration, etc. We verify the signature to authenticate the caller. Args: authorization: The Authorization header value (e.g., "Bearer <license_data>") required: If True, raise 401 when header is missing/invalid Returns: License data string (base64-encoded), or None if the header is missing/invalid and required is False Raises: HTTPException: 401 if required and header is missing/invalid """ if not authorization or not authorization.startswith("Bearer "): if required: raise HTTPException( status_code=401, detail="Missing or invalid authorization header" ) return None return authorization.split(" ", 1)[1] def verify_license_auth( license_data: str, allow_expired: bool = False, ) -> LicensePayload: """Verify license signature and optionally check expiry. Args: license_data: Base64-encoded signed license blob allow_expired: If True, accept expired licenses (for renewal flows) Returns: LicensePayload if valid Raises: HTTPException: If license is invalid or expired (when not allowed) """ _check_license_enforcement_enabled() try: payload = verify_license_signature(license_data) except ValueError as e: raise HTTPException(status_code=401, detail=f"Invalid license: {e}") if not allow_expired and not is_license_valid(payload): raise HTTPException(status_code=401, detail="License has expired") return payload async def get_license_payload( authorization: str | None = Header(None, alias="Authorization"), ) -> LicensePayload: """Dependency: Require valid (non-expired) license. Used for endpoints that require an active subscription.
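Illustrative sketch of the three auth modes used by the dependencies in this module: valid license required, expired license accepted, and license optional. The names `resolve`, `check`, `AuthError`, and `Payload.expires_at` are hypothetical. Treating a license as expired once now >= expires_at is an assumption about is_license_valid.

```python
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone


class AuthError(Exception):
    def __init__(self, status, detail):
        super().__init__(detail)
        self.status = status


@dataclass
class Payload:
    tenant_id: str
    expires_at: datetime  # hypothetical field name


def check(payload, allow_expired, now):
    # Signature verification is assumed to have produced `payload` already.
    if not allow_expired and now >= payload.expires_at:
        raise AuthError(401, "License has expired")
    return payload


def resolve(header, mode, decode, now):
    # mode: "valid" (active subscription), "expired_ok" (portal, renewal),
    # or "optional" (checkout for new customers who have no license yet).
    if not header or not header.startswith("Bearer "):
        if mode == "optional":
            return None
        raise AuthError(401, "Missing or invalid authorization header")
    payload = decode(header.split(" ", 1)[1])
    return check(payload, allow_expired=(mode != "valid"), now=now)
```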
""" license_data = _extract_license_from_header(authorization, required=True) # license_data is guaranteed non-None when required=True assert license_data is not None return verify_license_auth(license_data, allow_expired=False) async def get_license_payload_allow_expired( authorization: str | None = Header(None, alias="Authorization"), ) -> LicensePayload: """Dependency: Require license with valid signature, expired OK. Used for endpoints needed to fix payment issues (portal, renewal checkout). """ license_data = _extract_license_from_header(authorization, required=True) # license_data is guaranteed non-None when required=True assert license_data is not None return verify_license_auth(license_data, allow_expired=True) async def get_optional_license_payload( authorization: str | None = Header(None, alias="Authorization"), ) -> LicensePayload | None: """Dependency: Optional license auth (for checkout - new customers have none). Returns None if no license provided, otherwise validates and returns payload. Expired licenses are allowed for renewal flows. 
""" _check_license_enforcement_enabled() license_data = _extract_license_from_header(authorization, required=False) if license_data is None: return None return verify_license_auth(license_data, allow_expired=True) async def forward_to_control_plane( method: str, path: str, body: dict | None = None, params: dict | None = None, ) -> dict: """Forward a request to the control plane with proper authentication.""" token = generate_data_plane_token() headers = { "Authorization": f"Bearer {token}", "Content-Type": "application/json", } url = f"{CONTROL_PLANE_API_BASE_URL}{path}" try: async with httpx.AsyncClient(timeout=30.0, follow_redirects=True) as client: if method == "GET": response = await client.get(url, headers=headers, params=params) elif method == "POST": response = await client.post(url, headers=headers, json=body) else: raise ValueError(f"Unsupported HTTP method: {method}") response.raise_for_status() return response.json() except httpx.HTTPStatusError as e: status_code = e.response.status_code detail = "Control plane request failed" try: error_data = e.response.json() detail = error_data.get("detail", detail) except Exception: pass logger.error(f"Control plane returned {status_code}: {detail}") raise HTTPException(status_code=status_code, detail=detail) except httpx.RequestError: logger.exception("Failed to connect to control plane") raise HTTPException( status_code=502, detail="Failed to connect to control plane" ) # ----------------------------------------------------------------------------- # Endpoints # ----------------------------------------------------------------------------- class CreateCheckoutSessionRequest(BaseModel): billing_period: Literal["monthly", "annual"] = "monthly" seats: int | None = None email: str | None = None # Redirect URL after successful checkout - self-hosted passes their instance URL redirect_url: str | None = None # Cancel URL when user exits checkout - returns to upgrade page cancel_url: str | None = None class 
CreateCheckoutSessionResponse(BaseModel): url: str @router.post("/create-checkout-session") async def proxy_create_checkout_session( request_body: CreateCheckoutSessionRequest, license_payload: LicensePayload | None = Depends(get_optional_license_payload), ) -> CreateCheckoutSessionResponse: """Proxy checkout session creation to control plane. Auth: Optional license (new customers don't have one yet). If license provided, expired is OK (for renewals). """ # license_payload is None for new customers who don't have a license yet. # In that case, tenant_id is omitted from the request body and the control # plane will create a new tenant during checkout completion. tenant_id = license_payload.tenant_id if license_payload else None body: dict = { "billing_period": request_body.billing_period, } if tenant_id: body["tenant_id"] = tenant_id if request_body.seats is not None: body["seats"] = request_body.seats if request_body.email: body["email"] = request_body.email if request_body.redirect_url: body["redirect_url"] = request_body.redirect_url if request_body.cancel_url: body["cancel_url"] = request_body.cancel_url result = await forward_to_control_plane( "POST", "/create-checkout-session", body=body ) return CreateCheckoutSessionResponse(url=result["url"]) class ClaimLicenseRequest(BaseModel): session_id: str class ClaimLicenseResponse(BaseModel): tenant_id: str license: str message: str | None = None @router.post("/claim-license") async def proxy_claim_license( request_body: ClaimLicenseRequest, ) -> ClaimLicenseResponse: """Claim a license after successful Stripe checkout. Auth: Session ID based (one-time use after payment). The control plane verifies the session_id is valid and unclaimed. Returns the license to the caller. For self-hosted instances, they will store the license locally. The cloud DP doesn't need to store it. 
""" _check_license_enforcement_enabled() result = await forward_to_control_plane( "POST", "/claim-license", body={"session_id": request_body.session_id}, ) tenant_id = result.get("tenant_id") license_data = result.get("license") if not tenant_id or not license_data: logger.error(f"Control plane returned incomplete claim response: {result}") raise HTTPException( status_code=502, detail="Control plane returned incomplete license data", ) return ClaimLicenseResponse( tenant_id=tenant_id, license=license_data, message="License claimed successfully", ) class CreateCustomerPortalSessionRequest(BaseModel): return_url: str | None = None class CreateCustomerPortalSessionResponse(BaseModel): url: str @router.post("/create-customer-portal-session") async def proxy_create_customer_portal_session( request_body: CreateCustomerPortalSessionRequest | None = None, license_payload: LicensePayload = Depends(get_license_payload_allow_expired), ) -> CreateCustomerPortalSessionResponse: """Proxy customer portal session creation to control plane. Auth: License required, expired OK (need portal to fix payment issues). 
""" # tenant_id is a required field in LicensePayload (Pydantic validates this), # but we check explicitly for defense in depth if not license_payload.tenant_id: raise HTTPException(status_code=401, detail="License missing tenant_id") tenant_id = license_payload.tenant_id body: dict = {"tenant_id": tenant_id} if request_body and request_body.return_url: body["return_url"] = request_body.return_url result = await forward_to_control_plane( "POST", "/create-customer-portal-session", body=body ) return CreateCustomerPortalSessionResponse(url=result["url"]) class BillingInformationResponse(BaseModel): tenant_id: str status: str | None = None plan_type: str | None = None seats: int | None = None billing_period: str | None = None current_period_start: str | None = None current_period_end: str | None = None cancel_at_period_end: bool = False canceled_at: str | None = None trial_start: str | None = None trial_end: str | None = None payment_method_enabled: bool = False stripe_subscription_id: str | None = None @router.get("/billing-information") async def proxy_billing_information( license_payload: LicensePayload = Depends(get_license_payload), ) -> BillingInformationResponse: """Proxy billing information request to control plane. Auth: Valid (non-expired) license required. 
""" # tenant_id is a required field in LicensePayload (Pydantic validates this), # but we check explicitly for defense in depth if not license_payload.tenant_id: raise HTTPException(status_code=401, detail="License missing tenant_id") tenant_id = license_payload.tenant_id result = await forward_to_control_plane( "GET", "/billing-information", params={"tenant_id": tenant_id} ) # Add tenant_id from license if not in response (control plane may not include it) if "tenant_id" not in result: result["tenant_id"] = tenant_id return BillingInformationResponse(**result) class LicenseFetchResponse(BaseModel): license: str tenant_id: str @router.get("/license/{tenant_id}") async def proxy_license_fetch( tenant_id: str, license_payload: LicensePayload = Depends(get_license_payload), ) -> LicenseFetchResponse: """Proxy license fetch to control plane. Auth: Valid license required. The tenant_id in path must match the authenticated tenant. """ # tenant_id is a required field in LicensePayload (Pydantic validates this), # but we check explicitly for defense in depth if not license_payload.tenant_id: raise HTTPException(status_code=401, detail="License missing tenant_id") if tenant_id != license_payload.tenant_id: raise HTTPException( status_code=403, detail="Cannot fetch license for a different tenant", ) result = await forward_to_control_plane("GET", f"/license/{tenant_id}") license_data = result.get("license") if not license_data: logger.error(f"Control plane returned incomplete license response: {result}") raise HTTPException( status_code=502, detail="Control plane returned incomplete license data", ) # Return license to caller - self-hosted instance stores it via /api/license/claim return LicenseFetchResponse(license=license_data, tenant_id=tenant_id) @router.post("/seats/update") async def proxy_seat_update( request_body: SeatUpdateRequest, license_payload: LicensePayload = Depends(get_license_payload), ) -> SeatUpdateResponse: """Proxy seat update to control plane. 
Auth: Valid (non-expired) license required. Handles Stripe proration and license regeneration. Returns the regenerated license in the response for the caller to store. """ if not license_payload.tenant_id: raise HTTPException(status_code=401, detail="License missing tenant_id") tenant_id = license_payload.tenant_id result = await forward_to_control_plane( "POST", "/seats/update", body={ "tenant_id": tenant_id, "new_seat_count": request_body.new_seat_count, }, ) # Return license in response - self-hosted instance stores it via /api/license/claim return SeatUpdateResponse( success=result.get("success", False), current_seats=result.get("current_seats", 0), used_seats=result.get("used_seats", 0), message=result.get("message"), license=result.get("license"), ) ================================================ FILE: backend/ee/onyx/server/tenants/schema_management.py ================================================ import logging import os import re from types import SimpleNamespace from sqlalchemy import text from sqlalchemy.orm import Session from sqlalchemy.schema import CreateSchema from alembic import command from alembic.config import Config from onyx.db.engine.sql_engine import build_connection_string from onyx.db.engine.sql_engine import get_sqlalchemy_engine from shared_configs.configs import TENANT_ID_PREFIX logger = logging.getLogger(__name__) # Regex pattern for valid tenant IDs: # - UUID format: tenant_xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx # - AWS instance ID format: tenant_i-xxxxxxxxxxxxxxxxx # The strict pattern also prevents accidentally dropping the `public` schema TENANT_ID_PATTERN = re.compile( rf"^{re.escape(TENANT_ID_PREFIX)}(" r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}" # UUID r"|i-[a-f0-9]+" # AWS instance ID r")$" ) def validate_tenant_id(tenant_id: str) -> bool: """Validate that tenant_id matches expected format. This is important for SQL injection prevention since schema names cannot be parameterized in SQL and must be formatted directly.
""" return bool(TENANT_ID_PATTERN.match(tenant_id)) def run_alembic_migrations(schema_name: str) -> None: logger.info(f"Starting Alembic migrations for schema: {schema_name}") try: current_dir = os.path.dirname(os.path.abspath(__file__)) root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", "..")) alembic_ini_path = os.path.join(root_dir, "alembic.ini") # Configure Alembic alembic_cfg = Config(alembic_ini_path) alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string()) alembic_cfg.set_main_option( "script_location", os.path.join(root_dir, "alembic") ) # Ensure that logging isn't broken alembic_cfg.attributes["configure_logger"] = False # Mimic command-line options by adding 'cmd_opts' to the config alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore alembic_cfg.cmd_opts.x = [f"schemas={schema_name}"] # type: ignore # Run migrations programmatically command.upgrade(alembic_cfg, "head") logger.info( f"Alembic migrations completed successfully for schema: {schema_name}" ) except Exception as e: logger.exception(f"Alembic migration failed for schema {schema_name}: {str(e)}") raise def create_schema_if_not_exists(tenant_id: str) -> bool: with Session(get_sqlalchemy_engine()) as db_session: with db_session.begin(): result = db_session.execute( text( "SELECT schema_name FROM information_schema.schemata WHERE schema_name = :schema_name" ), {"schema_name": tenant_id}, ) schema_exists = result.scalar() is not None if not schema_exists: stmt = CreateSchema(tenant_id) db_session.execute(stmt) return True return False def drop_schema(tenant_id: str) -> None: """Drop a tenant's schema. Uses strict regex validation to reject unexpected formats early, preventing SQL injection since schema names cannot be parameterized.
""" if not validate_tenant_id(tenant_id): raise ValueError(f"Invalid tenant_id format: {tenant_id}") with get_sqlalchemy_engine().connect() as connection: with connection.begin(): # Use string formatting with validated tenant_id (safe after validation) connection.execute(text(f'DROP SCHEMA IF EXISTS "{tenant_id}" CASCADE')) def get_current_alembic_version(tenant_id: str) -> str: """Get the current Alembic version for a tenant.""" from alembic.runtime.migration import MigrationContext engine = get_sqlalchemy_engine() # Set the search path to the tenant's schema with engine.connect() as connection: connection.execute(text(f'SET search_path TO "{tenant_id}"')) # Get the current version from the alembic_version table context = MigrationContext.configure(connection) current_rev = context.get_current_revision() return current_rev or "head" ================================================ FILE: backend/ee/onyx/server/tenants/team_membership_api.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.orm import Session from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant from onyx.auth.users import current_admin_user from onyx.auth.users import User from onyx.db.auth import get_user_count from onyx.db.engine.sql_engine import get_session from onyx.db.users import delete_user_from_db from onyx.db.users import get_user_by_email from onyx.server.manage.models import UserByEmail from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/tenants") @router.post("/leave-team") async def leave_organization( user_email: UserByEmail, current_user: User = Depends(current_admin_user),
db_session: Session = Depends(get_session), ) -> None: tenant_id = get_current_tenant_id() if current_user.email != user_email.user_email: raise HTTPException( status_code=403, detail="You can only leave the organization as yourself" ) user_to_delete = get_user_by_email(user_email.user_email, db_session) if user_to_delete is None: raise HTTPException(status_code=404, detail="User not found") num_admin_users = await get_user_count(only_admin_users=True) should_delete_tenant = num_admin_users == 1 if should_delete_tenant: logger.info( "Last admin user is leaving the organization. Deleting tenant from control plane." ) try: await delete_user_from_control_plane(tenant_id, user_to_delete.email) logger.debug("User deleted from control plane") except Exception as e: logger.exception( f"Failed to delete user from control plane for tenant {tenant_id}: {e}" ) raise HTTPException( status_code=500, detail=f"Failed to remove user from control plane: {str(e)}", ) db_session.expunge(user_to_delete) delete_user_from_db(user_to_delete, db_session) if should_delete_tenant: remove_all_users_from_tenant(tenant_id) else: remove_users_from_tenant([user_to_delete.email], tenant_id) ================================================ FILE: backend/ee/onyx/server/tenants/tenant_management_api.py ================================================ from fastapi import APIRouter from fastapi import Depends from ee.onyx.server.tenants.models import TenantByDomainResponse from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane from onyx.auth.users import current_user from onyx.auth.users import User from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/tenants") FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [ "gmail", "outlook", "yahoo", "hotmail", "icloud", "msn", "hotmail.co.uk", ] @router.get("/existing-team-by-domain") def get_existing_tenant_by_domain( user: User
= Depends(current_user), ) -> TenantByDomainResponse | None: domain = user.email.split("@")[1] if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS): return None tenant_id = get_current_tenant_id() return get_tenant_by_domain_from_control_plane(domain, tenant_id) ================================================ FILE: backend/ee/onyx/server/tenants/user_invitations_api.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from ee.onyx.server.tenants.models import ApproveUserRequest from ee.onyx.server.tenants.models import PendingUserSnapshot from ee.onyx.server.tenants.models import RequestInviteRequest from ee.onyx.server.tenants.user_mapping import accept_user_invite from ee.onyx.server.tenants.user_mapping import approve_user_invite from ee.onyx.server.tenants.user_mapping import deny_user_invite from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant from onyx.auth.invited_users import get_pending_users from onyx.auth.users import current_admin_user from onyx.auth.users import current_user from onyx.auth.users import User from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/tenants") @router.post("/users/invite/request") async def request_invite( invite_request: RequestInviteRequest, user: User = Depends(current_admin_user), ) -> None: try: invite_self_to_tenant(user.email, invite_request.tenant_id) except Exception as e: logger.exception( f"Failed to invite self to tenant {invite_request.tenant_id}: {e}" ) raise HTTPException(status_code=500, detail=str(e)) @router.get("/users/pending") def list_pending_users( _: User = Depends(current_admin_user), ) -> list[PendingUserSnapshot]: pending_emails = get_pending_users() return [PendingUserSnapshot(email=email) for email in pending_emails] @router.post("/users/invite/approve") async def 
approve_user( approve_user_request: ApproveUserRequest, _: User = Depends(current_admin_user), ) -> None: tenant_id = get_current_tenant_id() approve_user_invite(approve_user_request.email, tenant_id) @router.post("/users/invite/accept") async def accept_invite( invite_request: RequestInviteRequest, user: User = Depends(current_user), ) -> None: """ Accept an invitation to join a tenant. """ try: accept_user_invite(user.email, invite_request.tenant_id) except Exception as e: logger.exception(f"Failed to accept invite: {str(e)}") raise HTTPException(status_code=500, detail="Failed to accept invitation") @router.post("/users/invite/deny") async def deny_invite( invite_request: RequestInviteRequest, user: User = Depends(current_user), ) -> None: """ Deny an invitation to join a tenant. """ try: deny_user_invite(user.email, invite_request.tenant_id) except Exception as e: logger.exception(f"Failed to deny invite: {str(e)}") raise HTTPException(status_code=500, detail="Failed to deny invitation") ================================================ FILE: backend/ee/onyx/server/tenants/user_mapping.py ================================================ from fastapi_users import exceptions from sqlalchemy import select from onyx.auth.invited_users import get_invited_users from onyx.auth.invited_users import get_pending_users from onyx.auth.invited_users import write_invited_users from onyx.auth.invited_users import write_pending_users from onyx.db.engine.sql_engine import get_session_with_shared_schema from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.models import UserTenantMapping from onyx.server.manage.models import TenantSnapshot from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() def get_tenant_id_for_email(email: str) -> str: if not MULTI_TENANT: return 
POSTGRES_DEFAULT_SCHEMA # Look up the tenant_id in the shared user-tenant mapping table try: with get_session_with_shared_schema() as db_session: # First try to get an active tenant result = db_session.execute( select(UserTenantMapping).where( UserTenantMapping.email == email, UserTenantMapping.active == True, # noqa: E712 ) ) mapping = result.scalar_one_or_none() tenant_id = mapping.tenant_id if mapping else None # If no active tenant found, try to get the first inactive one if tenant_id is None: result = db_session.execute( select(UserTenantMapping).where( UserTenantMapping.email == email, UserTenantMapping.active == False, # noqa: E712 ) ) # A user may hold several pending invitations, so take the first one instead of requiring exactly one mapping = result.scalars().first() if mapping: # Mark this mapping as active mapping.active = True db_session.commit() tenant_id = mapping.tenant_id except Exception as e: logger.exception(f"Error getting tenant id for email {email}: {e}") raise exceptions.UserNotExists() if tenant_id is None: raise exceptions.UserNotExists() return tenant_id def user_owns_a_tenant(email: str) -> bool: with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: result = ( db_session.query(UserTenantMapping) .filter(UserTenantMapping.email == email) .first() ) return result is not None def add_users_to_tenant(emails: list[str], tenant_id: str) -> None: """ Add users to a tenant with proper transaction handling. Checks if users already have a tenant mapping to avoid duplicates. If a user already has an active mapping to a different tenant, they receive an inactive mapping (invitation) to this tenant. They can accept the invitation later to switch tenants.
""" unique_emails = set(emails) if not unique_emails: return with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: try: # Start a transaction db_session.begin() # Batch query 1: Get all existing mappings for these emails to this tenant # Lock rows to prevent concurrent modifications existing_mappings = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email.in_(unique_emails), UserTenantMapping.tenant_id == tenant_id, ) .with_for_update() .all() ) emails_with_mapping = {m.email for m in existing_mappings} # Batch query 2: Get all active mappings for these emails (any tenant) active_mappings = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email.in_(unique_emails), UserTenantMapping.active == True, # noqa: E712 ) .all() ) emails_with_active_mapping = {m.email for m in active_mappings} # Add mappings for emails that don't already have one to this tenant for email in unique_emails: if email in emails_with_mapping: continue # Create mapping: inactive if user belongs to another tenant (invitation), # active otherwise db_session.add( UserTenantMapping( email=email, tenant_id=tenant_id, active=email not in emails_with_active_mapping, ) ) # Commit the transaction db_session.commit() logger.info(f"Successfully added users {emails} to tenant {tenant_id}") except Exception: logger.exception(f"Failed to add users to tenant {tenant_id}") db_session.rollback() raise def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None: with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: try: mappings_to_delete = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email.in_(emails), UserTenantMapping.tenant_id == tenant_id, ) .all() ) for mapping in mappings_to_delete: db_session.delete(mapping) db_session.commit() except Exception as e: logger.exception( f"Failed to remove users from tenant {tenant_id}: {str(e)}" ) db_session.rollback() def remove_all_users_from_tenant(tenant_id: 
str) -> None: with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session: db_session.query(UserTenantMapping).filter( UserTenantMapping.tenant_id == tenant_id ).delete() db_session.commit() def invite_self_to_tenant(email: str, tenant_id: str) -> None: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: pending_users = get_pending_users() if email in pending_users: return write_pending_users(pending_users + [email]) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def approve_user_invite(email: str, tenant_id: str) -> None: """ Approve a user invite to a tenant. This will delete all existing records for this email and create a new mapping entry for the user in this tenant. """ with get_session_with_shared_schema() as db_session: # Delete all existing records for this email db_session.query(UserTenantMapping).filter( UserTenantMapping.email == email ).delete() # Create a new mapping entry for the user in this tenant new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True) db_session.add(new_mapping) db_session.commit() # Remove the user from the pending users list pending_users = get_pending_users() if email in pending_users: pending_users.remove(email) write_pending_users(pending_users) # Add to invited users invited_users = get_invited_users() if email not in invited_users: invited_users.append(email) write_invited_users(invited_users) def accept_user_invite(email: str, tenant_id: str) -> None: """ Accept an invitation to join a tenant. This activates the user's mapping to the tenant. """ with get_session_with_shared_schema() as db_session: try: # Lock the user's mappings first to prevent race conditions. # This ensures no concurrent request can modify this user's mappings.
active_mapping = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email == email, UserTenantMapping.active == True, # noqa: E712 ) .with_for_update() .first() ) # If an active mapping exists, delete it if active_mapping: db_session.delete(active_mapping) logger.info( f"Deleted existing active mapping for user {email} in tenant {active_mapping.tenant_id}" ) # Find the inactive mapping for this user and tenant mapping = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email == email, UserTenantMapping.tenant_id == tenant_id, UserTenantMapping.active == False, # noqa: E712 ) .first() ) if mapping: # Defensively deactivate any remaining active mappings for this user db_session.query(UserTenantMapping).filter( UserTenantMapping.email == email, UserTenantMapping.active == True, # noqa: E712 ).update({"active": False}) # Activate this mapping mapping.active = True db_session.commit() logger.info(f"User {email} accepted invitation to tenant {tenant_id}") else: logger.warning( f"No invitation found for user {email} in tenant {tenant_id}" ) except Exception as e: db_session.rollback() logger.exception( f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}" ) raise # Remove from invited users list since they've accepted token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: invited_users = get_invited_users() if email in invited_users: invited_users.remove(email) write_invited_users(invited_users) logger.info(f"Removed {email} from invited users list after acceptance") finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def deny_user_invite(email: str, tenant_id: str) -> None: """ Deny an invitation to join a tenant. This removes the user's pending (inactive) mapping to the tenant.
""" with get_session_with_shared_schema() as db_session: # Delete the pending (inactive) mapping for this user and tenant result = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email == email, UserTenantMapping.tenant_id == tenant_id, UserTenantMapping.active == False, # noqa: E712 ) .delete() ) db_session.commit() if result: logger.info(f"User {email} denied invitation to tenant {tenant_id}") else: logger.warning( f"No invitation found for user {email} in tenant {tenant_id}" ) # Remove the user from this tenant's invited users list token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: invited_users = get_invited_users() if email in invited_users: invited_users.remove(email) write_invited_users(invited_users) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def get_tenant_count(tenant_id: str) -> int: """ Get the number of active users for this tenant. A user counts toward the seat count if: 1. They have an active mapping to this tenant (UserTenantMapping.active == True) 2. AND the User is active (User.is_active == True) 3. AND the User is not the anonymous system user TODO: Exclude API key dummy users from seat counting. API keys create users with emails like `__DANSWER_API_KEY_*` that should not count toward seat limits.
See: https://linear.app/onyx-app/issue/ENG-3518 """ from onyx.configs.constants import ANONYMOUS_USER_EMAIL from onyx.db.models import User # First get all emails with active mappings to this tenant with get_session_with_shared_schema() as db_session: active_mapping_emails = ( db_session.query(UserTenantMapping.email) .filter( UserTenantMapping.tenant_id == tenant_id, UserTenantMapping.active == True, # noqa: E712 UserTenantMapping.email != ANONYMOUS_USER_EMAIL, ) .all() ) emails = [email for (email,) in active_mapping_emails] if not emails: return 0 # Now count how many of those users are actually active in the tenant's User table with get_session_with_tenant(tenant_id=tenant_id) as db_session: user_count = ( db_session.query(User) .filter( User.email.in_(emails), # type: ignore User.is_active == True, # type: ignore # noqa: E712 ) .count() ) return user_count def get_tenant_invitation(email: str) -> TenantSnapshot | None: """ Get the first tenant invitation for this user """ with get_session_with_shared_schema() as db_session: # Get the first tenant invitation for this user invitation = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.email == email, UserTenantMapping.active == False, # noqa: E712 ) .first() ) if invitation: # Get the user count for this tenant user_count = ( db_session.query(UserTenantMapping) .filter( UserTenantMapping.tenant_id == invitation.tenant_id, UserTenantMapping.active == True, # noqa: E712 ) .count() ) return TenantSnapshot( tenant_id=invitation.tenant_id, number_of_users=user_count ) return None ================================================ FILE: backend/ee/onyx/server/token_rate_limits/api.py ================================================ from collections import defaultdict from fastapi import APIRouter from fastapi import Depends from sqlalchemy.orm import Session from ee.onyx.db.token_limit import fetch_all_user_group_token_rate_limits_by_group from ee.onyx.db.token_limit import 
fetch_user_group_token_rate_limits_for_user from ee.onyx.db.token_limit import insert_user_group_token_rate_limit from onyx.auth.users import current_admin_user from onyx.auth.users import current_curator_or_admin_user from onyx.configs.constants import PUBLIC_API_TAGS from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.db.token_limit import fetch_all_user_token_rate_limits from onyx.db.token_limit import insert_user_token_rate_limit from onyx.server.query_and_chat.token_limit import any_rate_limit_exists from onyx.server.token_rate_limits.models import TokenRateLimitArgs from onyx.server.token_rate_limits.models import TokenRateLimitDisplay router = APIRouter(prefix="/admin/token-rate-limits", tags=PUBLIC_API_TAGS) """ Group Token Limit Settings """ @router.get("/user-groups") def get_all_group_token_limit_settings( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> dict[str, list[TokenRateLimitDisplay]]: user_groups_to_token_rate_limits = fetch_all_user_group_token_rate_limits_by_group( db_session ) token_rate_limits_by_group = defaultdict(list) for token_rate_limit, group_name in user_groups_to_token_rate_limits: token_rate_limits_by_group[group_name].append( TokenRateLimitDisplay.from_db(token_rate_limit) ) return dict(token_rate_limits_by_group) @router.get("/user-group/{group_id}") def get_group_token_limit_settings( group_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[TokenRateLimitDisplay]: return [ TokenRateLimitDisplay.from_db(token_rate_limit) for token_rate_limit in fetch_user_group_token_rate_limits_for_user( db_session=db_session, group_id=group_id, user=user, ) ] @router.post("/user-group/{group_id}") def create_group_token_limit_settings( group_id: int, token_limit_settings: TokenRateLimitArgs, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> TokenRateLimitDisplay: 
rate_limit_display = TokenRateLimitDisplay.from_db( insert_user_group_token_rate_limit( db_session=db_session, token_rate_limit_settings=token_limit_settings, group_id=group_id, ) ) # clear cache in case this was the first rate limit created any_rate_limit_exists.cache_clear() return rate_limit_display """ User Token Limit Settings """ @router.get("/users") def get_user_token_limit_settings( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[TokenRateLimitDisplay]: return [ TokenRateLimitDisplay.from_db(token_rate_limit) for token_rate_limit in fetch_all_user_token_rate_limits(db_session) ] @router.post("/users") def create_user_token_limit_settings( token_limit_settings: TokenRateLimitArgs, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> TokenRateLimitDisplay: rate_limit_display = TokenRateLimitDisplay.from_db( insert_user_token_rate_limit(db_session, token_limit_settings) ) # clear cache in case this was the first rate limit created any_rate_limit_exists.cache_clear() return rate_limit_display ================================================ FILE: backend/ee/onyx/server/usage_limits.py ================================================ """EE Usage limits - trial detection via billing information.""" from ee.onyx.server.tenants.billing import fetch_billing_information from ee.onyx.server.tenants.models import BillingInformation from ee.onyx.server.tenants.models import SubscriptionStatusResponse from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() def is_tenant_on_trial(tenant_id: str) -> bool: """ Determine if a tenant is currently on a trial subscription. In multi-tenant mode, we fetch billing information from the control plane to determine if the tenant has an active trial. 
""" if not MULTI_TENANT: return False try: billing_info = fetch_billing_information(tenant_id) # A SubscriptionStatusResponse means the tenant has no subscription yet; treat it as a trial (new tenant) if isinstance(billing_info, SubscriptionStatusResponse): return True if isinstance(billing_info, BillingInformation): return billing_info.status == "trialing" return False except Exception as e: logger.warning(f"Failed to fetch billing info for trial check: {e}") # Default to trial limits on error (more restrictive = safer) return True ================================================ FILE: backend/ee/onyx/server/user_group/api.py ================================================ from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from ee.onyx.db.persona import update_persona_access from ee.onyx.db.user_group import add_users_to_user_group from ee.onyx.db.user_group import delete_user_group as db_delete_user_group from ee.onyx.db.user_group import fetch_user_group from ee.onyx.db.user_group import fetch_user_groups from ee.onyx.db.user_group import fetch_user_groups_for_user from ee.onyx.db.user_group import insert_user_group from ee.onyx.db.user_group import prepare_user_group_for_deletion from ee.onyx.db.user_group import rename_user_group from ee.onyx.db.user_group import update_user_curator_relationship from ee.onyx.db.user_group import update_user_group from ee.onyx.server.user_group.models import AddUsersToUserGroupRequest from ee.onyx.server.user_group.models import MinimalUserGroupSnapshot from ee.onyx.server.user_group.models import SetCuratorRequest from ee.onyx.server.user_group.models import UpdateGroupAgentsRequest from ee.onyx.server.user_group.models import UserGroup from ee.onyx.server.user_group.models import UserGroupCreate from ee.onyx.server.user_group.models import UserGroupRename from
ee.onyx.server.user_group.models import UserGroupUpdate from onyx.auth.users import current_admin_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_user from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.constants import PUBLIC_API_TAGS from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.db.models import UserRole from onyx.db.persona import get_persona_by_id from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError from onyx.utils.logger import setup_logger logger = setup_logger() router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS) @router.get("/admin/user-group") def list_user_groups( include_default: bool = False, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[UserGroup]: if user.role == UserRole.ADMIN: user_groups = fetch_user_groups( db_session, only_up_to_date=False, eager_load_for_snapshot=True, include_default=include_default, ) else: user_groups = fetch_user_groups_for_user( db_session=db_session, user_id=user.id, only_curator_groups=user.role == UserRole.CURATOR, eager_load_for_snapshot=True, include_default=include_default, ) return [UserGroup.from_model(user_group) for user_group in user_groups] @router.get("/user-groups/minimal") def list_minimal_user_groups( include_default: bool = False, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> list[MinimalUserGroupSnapshot]: if user.role == UserRole.ADMIN: user_groups = fetch_user_groups( db_session, only_up_to_date=False, include_default=include_default, ) else: user_groups = fetch_user_groups_for_user( db_session=db_session, user_id=user.id, include_default=include_default, ) return [ MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups ] @router.get("/admin/user-group/{user_group_id}/permissions") def 
get_user_group_permissions( user_group_id: int, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[str]: group = fetch_user_group(db_session, user_group_id) if group is None: raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found") return [ grant.permission.value for grant in group.permission_grants if not grant.is_deleted ] @router.post("/admin/user-group") def create_user_group( user_group: UserGroupCreate, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> UserGroup: try: db_user_group = insert_user_group(db_session, user_group) except IntegrityError: raise HTTPException( 400, f"User group with name '{user_group.name}' already exists. Please " + "choose a different name.", ) return UserGroup.from_model(db_user_group) @router.patch("/admin/user-group/rename") def rename_user_group_endpoint( rename_request: UserGroupRename, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> UserGroup: group = fetch_user_group(db_session, rename_request.id) if group and group.is_default: raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.") try: return UserGroup.from_model( rename_user_group( db_session=db_session, user_group_id=rename_request.id, new_name=rename_request.name, ) ) except IntegrityError: raise OnyxError( OnyxErrorCode.DUPLICATE_RESOURCE, f"User group with name '{rename_request.name}' already exists.", ) except ValueError as e: msg = str(e) if "not found" in msg.lower(): raise OnyxError(OnyxErrorCode.NOT_FOUND, msg) raise OnyxError(OnyxErrorCode.CONFLICT, msg) @router.patch("/admin/user-group/{user_group_id}") def patch_user_group( user_group_id: int, user_group_update: UserGroupUpdate, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> UserGroup: try: return UserGroup.from_model( update_user_group( db_session=db_session, user=user, user_group_id=user_group_id, 
user_group_update=user_group_update, ) ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @router.post("/admin/user-group/{user_group_id}/add-users") def add_users( user_group_id: int, add_users_request: AddUsersToUserGroupRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> UserGroup: try: return UserGroup.from_model( add_users_to_user_group( db_session=db_session, user=user, user_group_id=user_group_id, user_ids=add_users_request.user_ids, ) ) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) @router.post("/admin/user-group/{user_group_id}/set-curator") def set_user_curator( user_group_id: int, set_curator_request: SetCuratorRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> None: try: update_user_curator_relationship( db_session=db_session, user_group_id=user_group_id, set_curator_request=set_curator_request, user_making_change=user, ) except ValueError as e: logger.error(f"Error setting user curator: {e}") raise HTTPException(status_code=404, detail=str(e)) @router.delete("/admin/user-group/{user_group_id}") def delete_user_group( user_group_id: int, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: group = fetch_user_group(db_session, user_group_id) if group and group.is_default: raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.") try: prepare_user_group_for_deletion(db_session, user_group_id) except ValueError as e: raise HTTPException(status_code=404, detail=str(e)) if DISABLE_VECTOR_DB: user_group = fetch_user_group(db_session, user_group_id) if user_group: db_delete_user_group(db_session, user_group) @router.patch("/admin/user-group/{user_group_id}/agents") def update_group_agents( user_group_id: int, request: UpdateGroupAgentsRequest, user: User = Depends(current_admin_user), db_session: Session = 
Depends(get_session), ) -> None: for agent_id in request.added_agent_ids: persona = get_persona_by_id( persona_id=agent_id, user=user, db_session=db_session ) current_group_ids = [g.id for g in persona.groups] if user_group_id not in current_group_ids: update_persona_access( persona_id=agent_id, creator_user_id=user.id, db_session=db_session, group_ids=current_group_ids + [user_group_id], ) for agent_id in request.removed_agent_ids: persona = get_persona_by_id( persona_id=agent_id, user=user, db_session=db_session ) current_group_ids = [g.id for g in persona.groups] update_persona_access( persona_id=agent_id, creator_user_id=user.id, db_session=db_session, group_ids=[gid for gid in current_group_ids if gid != user_group_id], ) db_session.commit() ================================================ FILE: backend/ee/onyx/server/user_group/models.py ================================================ from uuid import UUID from pydantic import BaseModel from onyx.db.models import UserGroup as UserGroupModel from onyx.server.documents.models import ConnectorCredentialPairDescriptor from onyx.server.documents.models import ConnectorSnapshot from onyx.server.documents.models import CredentialSnapshot from onyx.server.features.document_set.models import DocumentSet from onyx.server.features.persona.models import PersonaSnapshot from onyx.server.manage.models import UserInfo from onyx.server.manage.models import UserPreferences class UserGroup(BaseModel): id: int name: str users: list[UserInfo] curator_ids: list[UUID] cc_pairs: list[ConnectorCredentialPairDescriptor] document_sets: list[DocumentSet] personas: list[PersonaSnapshot] is_up_to_date: bool is_up_for_deletion: bool is_default: bool @classmethod def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup": return cls( id=user_group_model.id, name=user_group_model.name, users=[ UserInfo( id=str(user.id), email=user.email, is_active=user.is_active, is_superuser=user.is_superuser, is_verified=user.is_verified, 
role=user.role, preferences=UserPreferences( default_model=user.default_model, chosen_assistants=user.chosen_assistants, ), ) for user in user_group_model.users ], curator_ids=[ user.user_id for user in user_group_model.user_group_relationships if user.is_curator and user.user_id is not None ], cc_pairs=[ ConnectorCredentialPairDescriptor( id=cc_pair_relationship.cc_pair.id, name=cc_pair_relationship.cc_pair.name, connector=ConnectorSnapshot.from_connector_db_model( cc_pair_relationship.cc_pair.connector, credential_ids=[cc_pair_relationship.cc_pair.credential_id], ), credential=CredentialSnapshot.from_credential_db_model( cc_pair_relationship.cc_pair.credential ), access_type=cc_pair_relationship.cc_pair.access_type, ) for cc_pair_relationship in user_group_model.cc_pair_relationships if cc_pair_relationship.is_current ], document_sets=[ DocumentSet.from_model(ds) for ds in user_group_model.document_sets ], personas=[ PersonaSnapshot.from_model(persona) for persona in user_group_model.personas if not persona.deleted ], is_up_to_date=user_group_model.is_up_to_date, is_up_for_deletion=user_group_model.is_up_for_deletion, is_default=user_group_model.is_default, ) class MinimalUserGroupSnapshot(BaseModel): id: int name: str is_default: bool @classmethod def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot": return cls( id=user_group_model.id, name=user_group_model.name, is_default=user_group_model.is_default, ) class UserGroupCreate(BaseModel): name: str user_ids: list[UUID] cc_pair_ids: list[int] class UserGroupUpdate(BaseModel): user_ids: list[UUID] cc_pair_ids: list[int] class AddUsersToUserGroupRequest(BaseModel): user_ids: list[UUID] class UserGroupRename(BaseModel): id: int name: str class SetCuratorRequest(BaseModel): user_id: UUID is_curator: bool class UpdateGroupAgentsRequest(BaseModel): added_agent_ids: list[int] removed_agent_ids: list[int] ================================================ FILE: 
backend/ee/onyx/utils/__init__.py ================================================ ================================================ FILE: backend/ee/onyx/utils/encryption.py ================================================ from functools import lru_cache from os import urandom from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives import padding from cryptography.hazmat.primitives.ciphers import algorithms from cryptography.hazmat.primitives.ciphers import Cipher from cryptography.hazmat.primitives.ciphers import modes from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() @lru_cache(maxsize=2) def _get_trimmed_key(key: str) -> bytes: encoded_key = key.encode() key_length = len(encoded_key) if key_length < 16: raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short") # Trim to the largest valid AES key size that fits valid_lengths = [32, 24, 16] for size in valid_lengths: if key_length >= size: return encoded_key[:size] raise AssertionError("unreachable") def _encrypt_string(input_str: str, key: str | None = None) -> bytes: effective_key = key if key is not None else ENCRYPTION_KEY_SECRET if not effective_key: return input_str.encode() trimmed = _get_trimmed_key(effective_key) iv = urandom(16) padder = padding.PKCS7(algorithms.AES.block_size).padder() padded_data = padder.update(input_str.encode()) + padder.finalize() cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()) encryptor = cipher.encryptor() encrypted_data = encryptor.update(padded_data) + encryptor.finalize() return iv + encrypted_data def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: effective_key = key if key is not None else ENCRYPTION_KEY_SECRET if not effective_key: return input_bytes.decode() trimmed = _get_trimmed_key(effective_key) try: iv = input_bytes[:16] 
encrypted_data = input_bytes[16:] cipher = Cipher( algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend() ) decryptor = cipher.decryptor() decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize() unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder() decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize() return decrypted_data.decode() except (ValueError, UnicodeDecodeError): if key is not None: # Explicit key was provided — don't fall back silently raise # Read path: attempt raw UTF-8 decode as a fallback for legacy data. # Does NOT handle data encrypted with a different key — that # ciphertext is not valid UTF-8 and will raise below. logger.warning( "AES decryption failed — falling back to raw decode. Run the re-encrypt secrets script to rotate to the current key." ) try: return input_bytes.decode() except UnicodeDecodeError: raise ValueError( "Data is not valid UTF-8 — likely encrypted with a different key. " "Run the re-encrypt secrets script to rotate to the current key." ) from None def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes: versioned_encryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_encrypt_string" ) return versioned_encryption_fn(input_str, key=key) def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str: versioned_decryption_fn = fetch_versioned_implementation( "onyx.utils.encryption", "_decrypt_bytes" ) return versioned_decryption_fn(input_bytes, key=key) def test_encryption() -> None: test_string = "Onyx is the BEST!" 
encrypted_bytes = encrypt_string_to_bytes(test_string) decrypted_string = decrypt_bytes_to_string(encrypted_bytes) if test_string != decrypted_string: raise RuntimeError("Encryption decryption test failed") ================================================ FILE: backend/ee/onyx/utils/license.py ================================================ """RSA-4096 license signature verification utilities.""" import base64 import json import os from datetime import datetime from datetime import timezone from pathlib import Path from cryptography.exceptions import InvalidSignature from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.asymmetric import padding from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey from ee.onyx.server.license.models import LicenseData from ee.onyx.server.license.models import LicensePayload from onyx.server.settings.models import ApplicationStatus from onyx.utils.logger import setup_logger logger = setup_logger() # Path to the license public key file _LICENSE_PUBLIC_KEY_PATH = ( Path(__file__).parent.parent.parent.parent / "keys" / "license_public_key.pem" ) def _get_public_key() -> RSAPublicKey: """Load the public key from file, with env var override.""" # Allow env var override for flexibility key_pem = os.environ.get("LICENSE_PUBLIC_KEY_PEM") if not key_pem: # Read from file if not _LICENSE_PUBLIC_KEY_PATH.exists(): raise ValueError( f"License public key not found at {_LICENSE_PUBLIC_KEY_PATH}. " "License verification requires the control plane public key." ) key_pem = _LICENSE_PUBLIC_KEY_PATH.read_text() key = serialization.load_pem_public_key(key_pem.encode()) if not isinstance(key, RSAPublicKey): raise ValueError("Expected RSA public key") return key def verify_license_signature(license_data: str) -> LicensePayload: """ Verify RSA-4096 signature and return payload if valid. 
Args: license_data: Base64-encoded JSON containing payload and signature Returns: LicensePayload if signature is valid Raises: ValueError: If license data is invalid or signature verification fails """ try: decoded = json.loads(base64.b64decode(license_data)) # Parse into LicenseData to validate structure license_obj = LicenseData(**decoded) # IMPORTANT: Use the ORIGINAL payload JSON for signature verification, # not re-serialized through Pydantic. Pydantic may format fields differently # (e.g., datetime "+00:00" vs "Z") which would break signature verification. original_payload = decoded.get("payload", {}) payload_json = json.dumps(original_payload, sort_keys=True) signature_bytes = base64.b64decode(license_obj.signature) # Verify signature using PSS padding (modern standard) public_key = _get_public_key() public_key.verify( signature_bytes, payload_json.encode(), padding.PSS( mgf=padding.MGF1(hashes.SHA256()), salt_length=padding.PSS.MAX_LENGTH, ), hashes.SHA256(), ) return license_obj.payload except InvalidSignature: logger.error("[verify_license] FAILED: Signature verification failed") raise ValueError("Invalid license signature") except json.JSONDecodeError as e: logger.error(f"[verify_license] FAILED: JSON decode error: {e}") raise ValueError("Invalid license format: not valid JSON") except (ValueError, KeyError, TypeError) as e: logger.error( f"[verify_license] FAILED: Validation error: {type(e).__name__}: {e}" ) raise ValueError(f"Invalid license format: {type(e).__name__}: {e}") except Exception: logger.exception("[verify_license] FAILED: Unexpected error") raise ValueError("License verification failed: unexpected error") def get_license_status( payload: LicensePayload, grace_period_end: datetime | None = None, ) -> ApplicationStatus: """ Determine current license status based on expiry. 
Args: payload: The verified license payload grace_period_end: Optional grace period end datetime Returns: ApplicationStatus indicating current license state """ now = datetime.now(timezone.utc) # Check if grace period has expired if grace_period_end and now > grace_period_end: return ApplicationStatus.GATED_ACCESS # Check if license has expired if now > payload.expires_at: if grace_period_end and now <= grace_period_end: return ApplicationStatus.GRACE_PERIOD return ApplicationStatus.GATED_ACCESS # License is valid return ApplicationStatus.ACTIVE def is_license_valid(payload: LicensePayload) -> bool: """Check if a license is currently valid (not expired).""" now = datetime.now(timezone.utc) return now <= payload.expires_at ================================================ FILE: backend/ee/onyx/utils/posthog_client.py ================================================ import json from typing import Any from urllib.parse import unquote from posthog import Posthog from ee.onyx.configs.app_configs import MARKETING_POSTHOG_API_KEY from ee.onyx.configs.app_configs import POSTHOG_API_KEY from ee.onyx.configs.app_configs import POSTHOG_DEBUG_LOGS_ENABLED from ee.onyx.configs.app_configs import POSTHOG_HOST from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() def posthog_on_error(error: Any, items: Any) -> None: """Log any PostHog delivery errors.""" logger.error(f"PostHog error: {error}, items: {items}") posthog: Posthog | None = None if POSTHOG_API_KEY: posthog = Posthog( project_api_key=POSTHOG_API_KEY, host=POSTHOG_HOST, debug=POSTHOG_DEBUG_LOGS_ENABLED, on_error=posthog_on_error, ) elif MULTI_TENANT: logger.warning( "POSTHOG_API_KEY is not set but MULTI_TENANT is enabled — " "PostHog telemetry and feature flags will be disabled" ) # For cross referencing between cloud and www Onyx sites # NOTE: These clients are separate because they are separate posthog projects. 
# We should eventually unify them into a single posthog project, # which would no longer require this workaround marketing_posthog = None if MARKETING_POSTHOG_API_KEY: marketing_posthog = Posthog( project_api_key=MARKETING_POSTHOG_API_KEY, host=POSTHOG_HOST, debug=POSTHOG_DEBUG_LOGS_ENABLED, on_error=posthog_on_error, ) def capture_and_sync_with_alternate_posthog( alternate_distinct_id: str, event: str, properties: dict[str, Any] ) -> None: """ Identify in both PostHog projects and capture the event in marketing. - Marketing keeps the marketing distinct_id (for feature flags). - Cloud identify uses the cloud distinct_id """ if not marketing_posthog: return props = properties.copy() try: marketing_posthog.identify(distinct_id=alternate_distinct_id, properties=props) marketing_posthog.capture(alternate_distinct_id, event, props) marketing_posthog.flush() except Exception as e: logger.error(f"Error capturing marketing posthog event: {e}") try: if posthog and (cloud_user_id := props.get("onyx_cloud_user_id")): cloud_props = props.copy() cloud_props.pop("onyx_cloud_user_id", None) posthog.identify( distinct_id=cloud_user_id, properties=cloud_props, ) except Exception as e: logger.error(f"Error identifying cloud posthog user: {e}") def alias_user(distinct_id: str, anonymous_id: str) -> None: """Link an anonymous distinct_id to an identified user, merging person profiles. No-ops when the IDs match (e.g. returning users whose PostHog cookie already contains their identified user ID). 
""" if not posthog or anonymous_id == distinct_id: return try: posthog.alias(previous_id=anonymous_id, distinct_id=distinct_id) posthog.flush() except Exception as e: logger.error(f"Error aliasing PostHog user: {e}") def get_anon_id_from_request(request: Any) -> str | None: """Extract the anonymous distinct_id from the app PostHog cookie on a request.""" if not POSTHOG_API_KEY: return None cookie_name = f"ph_{POSTHOG_API_KEY}_posthog" if (cookie_value := request.cookies.get(cookie_name)) and ( parsed := parse_posthog_cookie(cookie_value) ): return parsed.get("distinct_id") return None def get_marketing_posthog_cookie_name() -> str | None: if not MARKETING_POSTHOG_API_KEY: return None return f"onyx_custom_ph_{MARKETING_POSTHOG_API_KEY}_posthog" def parse_posthog_cookie(cookie_value: str) -> dict[str, Any] | None: """ Parse a URL-encoded JSON PostHog cookie Expected format (URL-encoded): {"distinct_id":"...", "featureFlags":{"landing_page_variant":"..."}, ...} Returns: Dict with 'distinct_id' explicitly required and all other cookie values passed through as-is, or None if parsing fails or distinct_id is missing. 
""" try: decoded_cookie = unquote(cookie_value) cookie_data = json.loads(decoded_cookie) distinct_id = cookie_data.get("distinct_id") if not distinct_id or not isinstance(distinct_id, str): return None return cookie_data except (json.JSONDecodeError, KeyError, TypeError, AttributeError) as e: logger.warning(f"Failed to parse cookie: {e}") return None ================================================ FILE: backend/ee/onyx/utils/telemetry.py ================================================ from typing import Any from ee.onyx.utils.posthog_client import posthog from onyx.utils.logger import setup_logger logger = setup_logger() def event_telemetry( distinct_id: str, event: str, properties: dict[str, Any] | None = None ) -> None: """Capture and send an event to PostHog, flushing immediately.""" if not posthog: return logger.info(f"Capturing PostHog event: {distinct_id} {event} {properties}") try: posthog.capture(distinct_id, event, properties) posthog.flush() except Exception as e: logger.error(f"Error capturing PostHog event: {e}") def identify_user(distinct_id: str, properties: dict[str, Any] | None = None) -> None: """Create/update a PostHog person profile, flushing immediately.""" if not posthog: return try: posthog.identify(distinct_id, properties) posthog.flush() except Exception as e: logger.error(f"Error identifying PostHog user: {e}") ================================================ FILE: backend/generated/README.md ================================================ - Generated Files * Generated files live here. This directory should be git ignored. 
================================================ FILE: backend/keys/license_public_key.pem ================================================ -----BEGIN PUBLIC KEY----- MIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA5DpchQujdxjCwpc4/RQP Hej6rc3SS/5ENCXL0I8NAfMogel0fqG6PKRhonyEh/Bt3P4q18y8vYzAShwf4b6Q aS0WwshbvnkjyWlsK0BY4HLBKPkTpes7kaz8MwmPZDeelvGJ7SNv3FvyJR4QsoSQ GSoB5iTH7hi63TjzdxtckkXoNG+GdVd/koxVDUv2uWcAoWIFTTcbKWyuq2SS/5Sf xdVaIArqfAhLpnNbnM9OS7lZ1xP+29ZXpHxDoeluz35tJLMNBYn9u0y+puo1kW1E TOGizlAq5kmEMsTJ55e9ZuyIV3gZAUaUKe8CxYJPkOGt0Gj6e1jHoHZCBJmaq97Y stKj//84HNBzajaryEZuEfRecJ94ANEjkD8u9cGmW+9VxRe5544zWguP5WMT/nv1 0Q+jkOBW2hkY5SS0Rug4cblxiB7bDymWkaX6+sC0VWd5g6WXp36EuP2T0v3mYuHU GDEiWbD44ToREPVwE/M07ny8qhLo/HYk2l8DKFt83hXe7ePBnyQdcsrVbQWOO1na j43OkoU5gOFyOkrk2RmmtCjA8jSnw+tGCTpRaRcshqoWC1MjZyU+8/kDteXNkmv9 /B5VxzYSyX+abl7yAu5wLiUPW8l+mOazzWu0nPkmiA160ArxnRyxbGnmp4dUIrt5 azYku4tQYLSsSabfhcpeiCsCAwEAAQ== -----END PUBLIC KEY----- ================================================ FILE: backend/model_server/__init__.py ================================================ ================================================ FILE: backend/model_server/constants.py ================================================ MODEL_WARM_UP_STRING = "hi " * 512 class GPUStatus: CUDA = "cuda" MAC_MPS = "mps" NONE = "none" ================================================ FILE: backend/model_server/encoders.py ================================================ import asyncio import time from typing import Any from typing import TYPE_CHECKING from fastapi import APIRouter from fastapi import HTTPException from fastapi import Request from model_server.utils import simple_log_function_time from onyx.utils.logger import setup_logger from shared_configs.enums import EmbedTextType from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse if TYPE_CHECKING: from sentence_transformers import 
SentenceTransformer logger = setup_logger() router = APIRouter(prefix="/encoder") _GLOBAL_MODELS_DICT: dict[str, "SentenceTransformer"] = {} def get_embedding_model( model_name: str, max_context_length: int, ) -> "SentenceTransformer": """ Loads or returns a cached SentenceTransformer, sets max_seq_length, and pre-warms the rotary (RoPE) caches whenever the requested context length exceeds the length already pre-warmed. Device placement and encode() locking are not handled here; transient encode() races are retried in _concurrent_embedding. """ from sentence_transformers import SentenceTransformer def _prewarm_rope(st_model: "SentenceTransformer", target_len: int) -> None: """ Build RoPE cos/sin caches once on the final device/dtype so later forwards only read. Works by encoding one dummy input long enough to reach target_len tokens, then records target_len on the model as _rope_prewarmed_to. """ try: # Ensure the input exceeds max_seq_length after tokenization. Ideally we # would use the saved tokenizer; instead we assume each "x " yields at # least one token. long_text = "x " * (target_len * 2) _ = st_model.encode( [long_text], batch_size=1, convert_to_tensor=True, show_progress_bar=False, normalize_embeddings=False, ) setattr(st_model, "_rope_prewarmed_to", target_len) logger.info("RoPE pre-warm successful") except Exception as e: logger.warning(f"RoPE pre-warm skipped/failed: {e}") global _GLOBAL_MODELS_DICT if model_name not in _GLOBAL_MODELS_DICT: logger.notice(f"Loading {model_name}") model = SentenceTransformer( model_name_or_path=model_name, trust_remote_code=True, ) model.max_seq_length = max_context_length _prewarm_rope(model, max_context_length) _GLOBAL_MODELS_DICT[model_name] = model else: model = _GLOBAL_MODELS_DICT[model_name] if max_context_length != model.max_seq_length: model.max_seq_length = max_context_length prev = getattr(model, "_rope_prewarmed_to", 0) if max_context_length > int(prev or 0): _prewarm_rope(model, max_context_length) return _GLOBAL_MODELS_DICT[model_name] ENCODING_RETRIES = 3 ENCODING_RETRY_DELAY = 0.1 def _concurrent_embedding( texts: list[str], model: "SentenceTransformer", normalize_embeddings: bool ) -> Any: """Synchronous wrapper for concurrent_embedding to use with
run_in_executor.""" for _ in range(ENCODING_RETRIES): try: return model.encode(texts, normalize_embeddings=normalize_embeddings) except RuntimeError as e: # There is a concurrency bug in the SentenceTransformer library that causes # the model to fail to encode texts. It's pretty rare and we want to allow # concurrent embedding, hence we retry (the specific error is # "RuntimeError: Already borrowed" and occurs in the transformers library) logger.warning(f"Error encoding texts, retrying: {e}") time.sleep(ENCODING_RETRY_DELAY) return model.encode(texts, normalize_embeddings=normalize_embeddings) @simple_log_function_time() async def embed_text( texts: list[str], model_name: str | None, max_context_length: int, normalize_embeddings: bool, prefix: str | None, gpu_type: str = "UNKNOWN", ) -> list[Embedding]: if not all(texts): logger.error("Empty strings provided for embedding") raise ValueError("Empty strings are not allowed for embedding.") if not texts: logger.error("No texts provided for embedding") raise ValueError("No texts provided for embedding.") start = time.monotonic() total_chars = 0 for text in texts: total_chars += len(text) # Only local models should call this function now # API providers should go directly to API server if model_name is not None: logger.info( f"Embedding {len(texts)} texts with {total_chars} total characters with local model: {model_name}" ) prefixed_texts = [f"{prefix}{text}" for text in texts] if prefix else texts local_model = get_embedding_model( model_name=model_name, max_context_length=max_context_length ) # Run CPU-bound embedding in a thread pool embeddings_vectors = await asyncio.get_event_loop().run_in_executor( None, lambda: _concurrent_embedding( prefixed_texts, local_model, normalize_embeddings ), ) embeddings = [ embedding if isinstance(embedding, list) else embedding.tolist() for embedding in embeddings_vectors ] elapsed = time.monotonic() - start logger.info( f"Successfully embedded {len(texts)} texts with {total_chars} 
total characters " f"with local model {model_name} in {elapsed:.2f}s" ) logger.info( f"event=embedding_model " f"texts={len(texts)} " f"chars={total_chars} " f"model={model_name} " f"gpu={gpu_type} " f"elapsed={elapsed:.2f}" ) else: logger.error("Model name not specified for embedding") raise ValueError("Model name must be provided to run embeddings.") return embeddings @router.post("/bi-encoder-embed") async def route_bi_encoder_embed( request: Request, embed_request: EmbedRequest, ) -> EmbedResponse: return await process_embed_request(embed_request, request.app.state.gpu_type) async def process_embed_request( embed_request: EmbedRequest, gpu_type: str = "UNKNOWN" ) -> EmbedResponse: from litellm.exceptions import RateLimitError # Only local models should use this endpoint - API providers should make direct API calls if embed_request.provider_type is not None: raise ValueError( f"Model server embedding endpoint should only be used for local models. " f"API provider '{embed_request.provider_type}' should make direct API calls instead."
        )

    if not embed_request.texts:
        raise HTTPException(status_code=400, detail="No texts to be embedded")

    if not all(embed_request.texts):
        raise ValueError("Empty strings are not allowed for embedding.")

    try:
        if embed_request.text_type == EmbedTextType.QUERY:
            prefix = embed_request.manual_query_prefix
        elif embed_request.text_type == EmbedTextType.PASSAGE:
            prefix = embed_request.manual_passage_prefix
        else:
            prefix = None

        embeddings = await embed_text(
            texts=embed_request.texts,
            model_name=embed_request.model_name,
            max_context_length=embed_request.max_context_length,
            normalize_embeddings=embed_request.normalize_embeddings,
            prefix=prefix,
            gpu_type=gpu_type,
        )
        return EmbedResponse(embeddings=embeddings)
    except RateLimitError as e:
        raise HTTPException(
            status_code=429,
            detail=str(e),
        )
    except Exception as e:
        logger.exception(
            f"Error during embedding process: provider={embed_request.provider_type} model={embed_request.model_name}"
        )
        raise HTTPException(
            status_code=500, detail=f"Error during embedding process: {e}"
        )


================================================
FILE: backend/model_server/legacy/README.md
================================================
This directory contains code that was useful and may become useful again in the future.

We stopped using rerankers because state-of-the-art rerankers are not significantly better than bi-encoders and are much worse than LLMs, which can also act on a small set of documents for filtering, reranking, etc.

We stopped using the internal query classifier because that work is now offloaded to the LLM, which performs query expansion, so we know ahead of time whether a query is keyword or semantic.
================================================
FILE: backend/model_server/legacy/__init__.py
================================================


================================================
FILE: backend/model_server/legacy/custom_models.py
================================================
# from typing import cast
# from typing import Optional
# from typing import TYPE_CHECKING

# import numpy as np
# import torch
# import torch.nn.functional as F
# from fastapi import APIRouter
# from huggingface_hub import snapshot_download
# from pydantic import BaseModel

# from model_server.constants import MODEL_WARM_UP_STRING
# from model_server.legacy.onyx_torch_model import ConnectorClassifier
# from model_server.legacy.onyx_torch_model import HybridClassifier
# from model_server.utils import simple_log_function_time
# from onyx.utils.logger import setup_logger
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
# from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
# from shared_configs.configs import INDEXING_ONLY
# from shared_configs.configs import INTENT_MODEL_TAG
# from shared_configs.configs import INTENT_MODEL_VERSION
# from shared_configs.model_server_models import IntentRequest
# from shared_configs.model_server_models import IntentResponse

# if TYPE_CHECKING:
#     from setfit import SetFitModel  # type: ignore[import-untyped]
#     from transformers import PreTrainedTokenizer, BatchEncoding


# INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi" * 50

# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = 1.0
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = 0.7
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = 4.0
# INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = 10
# INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
# INFORMATION_CONTENT_MODEL_TAG: str | None = None


# class ConnectorClassificationRequest(BaseModel):
#     available_connectors: list[str]
#     query: str


# class ConnectorClassificationResponse(BaseModel):
#     connectors: list[str]


# class ContentClassificationPrediction(BaseModel):
#     predicted_label: int
#     content_boost_factor: float


# logger = setup_logger()
# router = APIRouter(prefix="/custom")

# _CONNECTOR_CLASSIFIER_TOKENIZER: Optional["PreTrainedTokenizer"] = None
# _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None

# _INTENT_TOKENIZER: Optional["PreTrainedTokenizer"] = None
# _INTENT_MODEL: HybridClassifier | None = None

# _INFORMATION_CONTENT_MODEL: Optional["SetFitModel"] = None

# _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = ""  # spec to model version!


# def get_connector_classifier_tokenizer() -> "PreTrainedTokenizer":
#     global _CONNECTOR_CLASSIFIER_TOKENIZER
#     from transformers import AutoTokenizer, PreTrainedTokenizer

#     if _CONNECTOR_CLASSIFIER_TOKENIZER is None:
#         # The tokenizer details are not uploaded to the HF hub since it's just the
#         # unmodified distilbert tokenizer.
#         _CONNECTOR_CLASSIFIER_TOKENIZER = cast(
#             PreTrainedTokenizer,
#             AutoTokenizer.from_pretrained("distilbert-base-uncased"),
#         )
#     return _CONNECTOR_CLASSIFIER_TOKENIZER


# def get_local_connector_classifier(
#     model_name_or_path: str = CONNECTOR_CLASSIFIER_MODEL_REPO,
#     tag: str = CONNECTOR_CLASSIFIER_MODEL_TAG,
# ) -> ConnectorClassifier:
#     global _CONNECTOR_CLASSIFIER_MODEL
#     if _CONNECTOR_CLASSIFIER_MODEL is None:
#         try:
#             # Calculate where the cache should be, then load from local if available
#             local_path = snapshot_download(
#                 repo_id=model_name_or_path, revision=tag, local_files_only=True
#             )
#             _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
#                 local_path
#             )
#         except Exception as e:
#             logger.warning(f"Failed to load model directly: {e}")
#             try:
#                 # Attempt to download the model snapshot
#                 logger.info(f"Downloading model snapshot for {model_name_or_path}")
#                 local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
#                 _CONNECTOR_CLASSIFIER_MODEL = ConnectorClassifier.from_pretrained(
#                     local_path
#                 )
#             except Exception as e:
#                 logger.error(
#                     f"Failed to load model even after attempted snapshot download: {e}"
#                 )
#                 raise
#     return _CONNECTOR_CLASSIFIER_MODEL


# def get_intent_model_tokenizer() -> "PreTrainedTokenizer":
#     from transformers import AutoTokenizer, PreTrainedTokenizer

#     global _INTENT_TOKENIZER
#     if _INTENT_TOKENIZER is None:
#         # The tokenizer details are not uploaded to the HF hub since it's just the
#         # unmodified distilbert tokenizer.
#         _INTENT_TOKENIZER = cast(
#             PreTrainedTokenizer,
#             AutoTokenizer.from_pretrained("distilbert-base-uncased"),
#         )
#     return _INTENT_TOKENIZER


# def get_local_intent_model(
#     model_name_or_path: str = INTENT_MODEL_VERSION,
#     tag: str | None = INTENT_MODEL_TAG,
# ) -> HybridClassifier:
#     global _INTENT_MODEL
#     if _INTENT_MODEL is None:
#         try:
#             # Calculate where the cache should be, then load from local if available
#             logger.notice(f"Loading model from local cache: {model_name_or_path}")
#             local_path = snapshot_download(
#                 repo_id=model_name_or_path, revision=tag, local_files_only=True
#             )
#             _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
#             logger.notice(f"Loaded model from local cache: {local_path}")
#         except Exception as e:
#             logger.warning(f"Failed to load model directly: {e}")
#             try:
#                 # Attempt to download the model snapshot
#                 logger.notice(f"Downloading model snapshot for {model_name_or_path}")
#                 local_path = snapshot_download(
#                     repo_id=model_name_or_path, revision=tag, local_files_only=False
#                 )
#                 _INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
#             except Exception as e:
#                 logger.error(
#                     f"Failed to load model even after attempted snapshot download: {e}"
#                 )
#                 raise
#     return _INTENT_MODEL


# def get_local_information_content_model(
#     model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
#     tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
# ) -> "SetFitModel":
#     from setfit import SetFitModel

#     global _INFORMATION_CONTENT_MODEL
#     if _INFORMATION_CONTENT_MODEL is None:
#         try:
#             # Calculate where the cache should be, then load from local if available
#             logger.notice(
#                 f"Loading content information model from local cache: {model_name_or_path}"
#             )
#             local_path = snapshot_download(
#                 repo_id=model_name_or_path, revision=tag, local_files_only=True
#             )
#             _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
#             logger.notice(
#                 f"Loaded content information model from local cache: {local_path}"
#             )
#         except Exception as e:
#             logger.warning(f"Failed to load content information model directly: {e}")
#             try:
#                 # Attempt to download the model snapshot
#                 logger.notice(
#                     f"Downloading content information model snapshot for {model_name_or_path}"
#                 )
#                 local_path = snapshot_download(
#                     repo_id=model_name_or_path, revision=tag, local_files_only=False
#                 )
#                 _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
#             except Exception as e:
#                 logger.error(
#                     f"Failed to load content information model even after attempted snapshot download: {e}"
#                 )
#                 raise

#     return _INFORMATION_CONTENT_MODEL


# def tokenize_connector_classification_query(
#     connectors: list[str],
#     query: str,
#     tokenizer: "PreTrainedTokenizer",
#     connector_token_end_id: int,
# ) -> tuple[torch.Tensor, torch.Tensor]:
#     """
#     Tokenize the connectors & user query into one prompt for the forward pass of ConnectorClassifier models.
#     The attention mask is all 1s. The prompt is CLS + each connector name suffixed with the connector end
#     token, then the user query followed by SEP.
#     """
#     input_ids = torch.tensor([tokenizer.cls_token_id], dtype=torch.long)

#     for connector in connectors:
#         connector_token_ids = tokenizer(
#             connector,
#             add_special_tokens=False,
#             return_tensors="pt",
#         )

#         input_ids = torch.cat(
#             (
#                 input_ids,
#                 connector_token_ids["input_ids"].squeeze(dim=0),
#                 torch.tensor([connector_token_end_id], dtype=torch.long),
#             ),
#             dim=-1,
#         )

#     query_token_ids = tokenizer(
#         query,
#         add_special_tokens=False,
#         return_tensors="pt",
#     )

#     input_ids = torch.cat(
#         (
#             input_ids,
#             query_token_ids["input_ids"].squeeze(dim=0),
#             torch.tensor([tokenizer.sep_token_id], dtype=torch.long),
#         ),
#         dim=-1,
#     )
#     attention_mask = torch.ones(input_ids.numel(), dtype=torch.long)

#     return input_ids.unsqueeze(0), attention_mask.unsqueeze(0)


# def warm_up_connector_classifier_model() -> None:
#     logger.info(
#         f"Warming up connector_classifier model {CONNECTOR_CLASSIFIER_MODEL_TAG}"
#     )
#     connector_classifier_tokenizer = get_connector_classifier_tokenizer()
#     connector_classifier = get_local_connector_classifier()

#     input_ids, attention_mask = tokenize_connector_classification_query(
#         ["GitHub"],
#         "onyx classifier query google doc",
#         connector_classifier_tokenizer,
#         connector_classifier.connector_end_token_id,
#     )
#     input_ids = input_ids.to(connector_classifier.device)
#     attention_mask = attention_mask.to(connector_classifier.device)

#     connector_classifier(input_ids, attention_mask)


# def warm_up_intent_model() -> None:
#     logger.notice(f"Warming up Intent Model: {INTENT_MODEL_VERSION}")
#     intent_tokenizer = get_intent_model_tokenizer()
#     tokens = intent_tokenizer(
#         MODEL_WARM_UP_STRING, return_tensors="pt", truncation=True, padding=True
#     )

#     intent_model = get_local_intent_model()
#     device = intent_model.device
#     intent_model(
#         query_ids=tokens["input_ids"].to(device),
#         query_mask=tokens["attention_mask"].to(device),
#     )


# def warm_up_information_content_model() -> None:
#     logger.notice("Warming up Content Model")  # TODO: add version if needed

#     information_content_model = get_local_information_content_model()
#     information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)


# @simple_log_function_time()
# def run_inference(tokens: "BatchEncoding") -> tuple[list[float], list[float]]:
#     intent_model = get_local_intent_model()
#     device = intent_model.device

#     outputs = intent_model(
#         query_ids=tokens["input_ids"].to(device),
#         query_mask=tokens["attention_mask"].to(device),
#     )

#     token_logits = outputs["token_logits"]
#     intent_logits = outputs["intent_logits"]

#     # Move tensors to CPU before applying softmax and converting to numpy
#     intent_probabilities = F.softmax(intent_logits.cpu(), dim=-1).numpy()[0]
#     token_probabilities = F.softmax(token_logits.cpu(), dim=-1).numpy()[0]

#     # Extract the probabilities for the positive class (index 1) for each token
#     token_positive_probs = token_probabilities[:, 1].tolist()

#     return intent_probabilities.tolist(), token_positive_probs


# @simple_log_function_time()
# def run_content_classification_inference(
#     text_inputs: list[str],
# ) -> list[ContentClassificationPrediction]:
#     """
#     Assign a score to the segments in question. The model stored in get_local_information_content_model()
#     creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
#     In the code outside of the model/inference model servers, that score will be converted into the actual
#     boost factor.
#     """

#     def _prob_to_score(prob: float) -> float:
#         """
#         Convert the base probability to a 0.0-1.0 raw score, then map it linearly onto the
#         [INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN, INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX]
#         range. Note that the min/max values depend on the model!
#         """
#         _MIN_BASE_SCORE = 0.25
#         _MAX_BASE_SCORE = 0.75
#         if prob < _MIN_BASE_SCORE:
#             raw_score = 0.0
#         elif prob < _MAX_BASE_SCORE:
#             raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
#         else:
#             raw_score = 1.0
#         return (
#             INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
#             + (
#                 INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
#                 - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
#             )
#             * raw_score
#         )

#     _BATCH_SIZE = 32
#     content_model = get_local_information_content_model()

#     # Process inputs in batches
#     all_output_classes: list[int] = []
#     all_base_output_probabilities: list[float] = []

#     for i in range(0, len(text_inputs), _BATCH_SIZE):
#         batch = text_inputs[i : i + _BATCH_SIZE]
#         batch_with_prefix = []
#         batch_indices = []

#         # Pre-allocate results for this batch
#         batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
#         batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)

#         # Pre-process batch to handle long input exceptions
#         for j, text in enumerate(batch):
#             if len(text) == 0:
#                 # if no input, treat as non-informative from the model's perspective
#                 batch_output_classes[j] = np.array(0)
#                 batch_probabilities[j] = np.array(0.0)
#                 logger.warning("Input for Content Information Model is empty")

#             elif (
#                 len(text.split())
#                 <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
#             ):
#                 # if input is short, use the model
#                 batch_with_prefix.append(
#                     _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
#                 )
#                 batch_indices.append(j)
#             else:
#                 # if longer than cutoff, treat as informative (stay with default), but issue warning
#                 logger.warning("Input for Content Information Model too long")

#         if batch_with_prefix:  # Only run model if we have valid inputs
#             # Get predictions for the batch
#             model_output_classes = content_model(batch_with_prefix)
#             model_output_probabilities = content_model.predict_proba(batch_with_prefix)

#             # Place results in the correct positions
#             for idx, batch_idx in enumerate(batch_indices):
#                 batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
#                 batch_probabilities[batch_idx] = model_output_probabilities[idx][
#                     1
#                 ].numpy()  # x[1] is prob of the positive class

#         all_output_classes.extend([int(x) for x in batch_output_classes])
#         all_base_output_probabilities.extend([float(x) for x in batch_probabilities])

#     logits = [
#         np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
#         for p in all_base_output_probabilities
#     ]
#     scaled_logits = [
#         logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
#         for logit in logits
#     ]
#     output_probabilities_with_temp = [
#         np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
#         for scaled_logit in scaled_logits
#     ]

#     prediction_scores = [
#         _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
#     ]

#     content_classification_predictions = [
#         ContentClassificationPrediction(
#             predicted_label=predicted_label, content_boost_factor=output_score
#         )
#         for predicted_label, output_score in zip(all_output_classes, prediction_scores)
#     ]

#     return content_classification_predictions


# def map_keywords(
#     input_ids: torch.Tensor, tokenizer: "PreTrainedTokenizer", is_keyword: list[bool]
# ) -> list[str]:
#     tokens = tokenizer.convert_ids_to_tokens(input_ids)  # type: ignore

#     if not len(tokens) == len(is_keyword):
#         raise ValueError("Length of tokens and keyword predictions must match")

#     if input_ids[0] == tokenizer.cls_token_id:
#         tokens = tokens[1:]
#         is_keyword = is_keyword[1:]

#     if input_ids[-1] == tokenizer.sep_token_id:
#         tokens = tokens[:-1]
#         is_keyword = is_keyword[:-1]

#     unk_token = tokenizer.unk_token
#     if unk_token in tokens:
#         raise ValueError("Unknown token detected in the input")

#     keywords = []
#     current_keyword = ""

#     for ind, token in enumerate(tokens):
#         if is_keyword[ind]:
#             if token.startswith("##"):
#                 current_keyword += token[2:]
#             else:
#                 if current_keyword:
#                     keywords.append(current_keyword)
#                 current_keyword = token
#         else:
#             # If mispredicted a later token of a keyword, add it to the current keyword
#             # to complete it
#             if current_keyword:
#                 if len(current_keyword) > 2 and current_keyword.startswith("##"):
#                     current_keyword = current_keyword[2:]

#                 else:
#                     keywords.append(current_keyword)
#                     current_keyword = ""

#     if current_keyword:
#         keywords.append(current_keyword)

#     return keywords


# def clean_keywords(keywords: list[str]) -> list[str]:
#     cleaned_words = []
#     for word in keywords:
#         word = word[:-2] if word.endswith("'s") else word
#         word = word.replace("/", " ")
#         word = word.replace("'", "").replace('"', "")
#         cleaned_words.extend([w for w in word.strip().split() if w and not w.isspace()])
#     return cleaned_words


# def run_connector_classification(req: ConnectorClassificationRequest) -> list[str]:
#     tokenizer = get_connector_classifier_tokenizer()
#     model = get_local_connector_classifier()

#     connector_names = req.available_connectors

#     input_ids, attention_mask = tokenize_connector_classification_query(
#         connector_names,
#         req.query,
#         tokenizer,
#         model.connector_end_token_id,
#     )
#     input_ids = input_ids.to(model.device)
#     attention_mask = attention_mask.to(model.device)

#     global_confidence, classifier_confidence = model(input_ids, attention_mask)

#     if global_confidence.item() < 0.5:
#         return []

#     passed_connectors = []

#     for i, connector_name in enumerate(connector_names):
#         if classifier_confidence.view(-1)[i].item() > 0.5:
#             passed_connectors.append(connector_name)

#     return passed_connectors


# def run_analysis(intent_req: IntentRequest) -> tuple[bool, list[str]]:
#     tokenizer = get_intent_model_tokenizer()

#     model_input = tokenizer(
#         intent_req.query, return_tensors="pt", truncation=False, padding=False
#     )

#     if len(model_input.input_ids[0]) > 512:
#         # If the user text is too long, assume it is semantic and keep all words
#         return True, intent_req.query.split()

#     intent_probs, token_probs = run_inference(model_input)

#     is_keyword_sequence = intent_probs[0] >= intent_req.keyword_percent_threshold

#     keyword_preds = [
#         token_prob >= intent_req.keyword_percent_threshold for token_prob in token_probs
#     ]

#     try:
#         keywords = map_keywords(model_input.input_ids[0], tokenizer, keyword_preds)
#     except Exception as e:
#         logger.warning(
#             f"Failed to extract keywords for query: {intent_req.query} due to {e}"
#         )
#         # Fallback to keeping all words
#         keywords = intent_req.query.split()

#     cleaned_keywords = clean_keywords(keywords)

#     return is_keyword_sequence, cleaned_keywords


# @router.post("/connector-classification")
# async def process_connector_classification_request(
#     classification_request: ConnectorClassificationRequest,
# ) -> ConnectorClassificationResponse:
#     if INDEXING_ONLY:
#         raise RuntimeError(
#             "Indexing model server should not call connector classification endpoint"
#         )

#     if len(classification_request.available_connectors) == 0:
#         return ConnectorClassificationResponse(connectors=[])

#     connectors = run_connector_classification(classification_request)
#     return ConnectorClassificationResponse(connectors=connectors)


# @router.post("/query-analysis")
# async def process_analysis_request(
#     intent_request: IntentRequest,
# ) -> IntentResponse:
#     if INDEXING_ONLY:
#         raise RuntimeError("Indexing model server should not call intent endpoint")

#     is_keyword, keywords = run_analysis(intent_request)
#     return IntentResponse(is_keyword=is_keyword, keywords=keywords)


# @router.post("/content-classification")
# async def process_content_classification_request(
#     content_classification_requests: list[str],
# ) -> list[ContentClassificationPrediction]:
#     return run_content_classification_inference(content_classification_requests)


================================================
FILE: backend/model_server/legacy/onyx_torch_model.py
================================================
# import json
# import os
# from typing import cast
# from typing import TYPE_CHECKING

# import torch
# import torch.nn as nn


# if TYPE_CHECKING:
#     from transformers import DistilBertConfig


# class HybridClassifier(nn.Module):
#     def __init__(self) -> None:
#         from transformers import DistilBertConfig, DistilBertModel

#         super().__init__()
#         config = DistilBertConfig()
#         self.distilbert = DistilBertModel(config)
#         config = self.distilbert.config  # type: ignore

#         # Keyword tokenwise binary classification layer
#         self.keyword_classifier = nn.Linear(config.dim, 2)

#         # Intent Classifier layers
#         self.pre_classifier = nn.Linear(config.dim, config.dim)
#         self.intent_classifier = nn.Linear(config.dim, 2)

#         self.device = torch.device("cpu")

#     def forward(
#         self,
#         query_ids: torch.Tensor,
#         query_mask: torch.Tensor,
#     ) -> dict[str, torch.Tensor]:
#         outputs = self.distilbert(input_ids=query_ids, attention_mask=query_mask)
#         sequence_output = outputs.last_hidden_state

#         # Intent classification on the CLS token
#         cls_token_state = sequence_output[:, 0, :]
#         pre_classifier_out = self.pre_classifier(cls_token_state)
#         intent_logits = self.intent_classifier(pre_classifier_out)

#         # Keyword classification on all tokens
#         token_logits = self.keyword_classifier(sequence_output)

#         return {"intent_logits": intent_logits, "token_logits": token_logits}

#     @classmethod
#     def from_pretrained(cls, load_directory: str) -> "HybridClassifier":
#         model_path = os.path.join(load_directory, "pytorch_model.bin")
#         config_path = os.path.join(load_directory, "config.json")

#         with open(config_path, "r") as f:
#             config = json.load(f)
#         model = cls(**config)

#         if torch.backends.mps.is_available():
#             # Apple silicon GPU
#             device = torch.device("mps")
#         elif torch.cuda.is_available():
#             device = torch.device("cuda")
#         else:
#             device = torch.device("cpu")

#         model.load_state_dict(torch.load(model_path, map_location=device))
#         model = model.to(device)

#         model.device = device

#         model.eval()
#         # Eval doesn't set requires_grad to False, do it manually to save memory and have faster inference
#         for param in model.parameters():
#             param.requires_grad = False

#         return model


# class ConnectorClassifier(nn.Module):
#     def __init__(self, config: "DistilBertConfig") -> None:
#         from transformers import DistilBertTokenizer, DistilBertModel

#         super().__init__()

#         self.config = config
#         self.distilbert = DistilBertModel(config)
#         config = self.distilbert.config  # type: ignore
#         self.connector_global_classifier = nn.Linear(config.dim, 1)
#         self.connector_match_classifier = nn.Linear(config.dim, 1)
#         self.tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")

#         # Token indicating end of connector name, and on which classifier is used
#         self.connector_end_token_id = self.tokenizer.get_vocab()[
#             self.config.connector_end_token
#         ]

#         self.device = torch.device("cpu")

#     def forward(
#         self,
#         input_ids: torch.Tensor,
#         attention_mask: torch.Tensor,
#     ) -> tuple[torch.Tensor, torch.Tensor]:
#         hidden_states = self.distilbert(
#             input_ids=input_ids, attention_mask=attention_mask
#         ).last_hidden_state

#         cls_hidden_states = hidden_states[
#             :, 0, :
#         ]  # Take leap of faith that first token is always [CLS]
#         global_logits = self.connector_global_classifier(cls_hidden_states).view(-1)
#         global_confidence = torch.sigmoid(global_logits).view(-1)

#         connector_end_position_ids = input_ids == self.connector_end_token_id
#         connector_end_hidden_states = hidden_states[connector_end_position_ids]
#         classifier_output = self.connector_match_classifier(connector_end_hidden_states)
#         classifier_confidence = torch.nn.functional.sigmoid(classifier_output).view(-1)

#         return global_confidence, classifier_confidence

#     @classmethod
#     def from_pretrained(cls, repo_dir: str) -> "ConnectorClassifier":
#         from transformers import DistilBertConfig

#         config = cast(
#             DistilBertConfig,
#             DistilBertConfig.from_pretrained(os.path.join(repo_dir, "config.json")),
#         )
#         device = (
#             torch.device("cuda")
#             if torch.cuda.is_available()
#             else (
#                 torch.device("mps")
#                 if torch.backends.mps.is_available()
#                 else torch.device("cpu")
#             )
#         )
#         state_dict = torch.load(
#             os.path.join(repo_dir, "pytorch_model.pt"),
#             map_location=device,
#             weights_only=True,
#         )

#         model = cls(config)
#         model.load_state_dict(state_dict)
#         model.to(device)
#         model.device = device
#         model.eval()

#         for param in model.parameters():
#             param.requires_grad = False

#         return model


================================================
FILE: backend/model_server/legacy/reranker.py
================================================
# import asyncio
# from typing import Optional
# from typing import TYPE_CHECKING

# from fastapi import APIRouter
# from fastapi import HTTPException

# from model_server.utils import simple_log_function_time
# from onyx.utils.logger import setup_logger
# from shared_configs.configs import INDEXING_ONLY
# from shared_configs.model_server_models import RerankRequest
# from shared_configs.model_server_models import RerankResponse

# if TYPE_CHECKING:
#     from sentence_transformers import CrossEncoder

# logger = setup_logger()

# router = APIRouter(prefix="/encoder")

# _RERANK_MODEL: Optional["CrossEncoder"] = None


# def get_local_reranking_model(
#     model_name: str,
# ) -> "CrossEncoder":
#     global _RERANK_MODEL
#     from sentence_transformers import CrossEncoder

#     if _RERANK_MODEL is None:
#         logger.notice(f"Loading {model_name}")
#         model = CrossEncoder(model_name)
#         _RERANK_MODEL = model
#     return _RERANK_MODEL


# @simple_log_function_time()
# async def local_rerank(query: str, docs: list[str], model_name: str) -> list[float]:
#     cross_encoder = get_local_reranking_model(model_name)
#     # Run CPU-bound reranking in a thread pool
#     return await asyncio.get_event_loop().run_in_executor(
#         None,
#         lambda: cross_encoder.predict([(query, doc) for doc in docs]).tolist(),
#     )


# @router.post("/cross-encoder-scores")
# async def process_rerank_request(rerank_request: RerankRequest) -> RerankResponse:
#     """Cross encoders can be purely black box from the app perspective"""
#     # Only local models should use this endpoint - API providers should make direct API calls
#     if rerank_request.provider_type is not None:
#         raise ValueError(
#             f"Model server reranking endpoint should only be used for local models. "
#             f"API provider '{rerank_request.provider_type}' should make direct API calls instead."
#         )

#     if INDEXING_ONLY:
#         raise RuntimeError("Indexing model server should not call reranking endpoint")

#     if not rerank_request.documents or not rerank_request.query:
#         raise HTTPException(
#             status_code=400, detail="Missing documents or query for reranking"
#         )
#     if not all(rerank_request.documents):
#         raise ValueError("Empty documents cannot be reranked.")

#     try:
#         # At this point, provider_type is None, so handle local reranking
#         sim_scores = await local_rerank(
#             query=rerank_request.query,
#             docs=rerank_request.documents,
#             model_name=rerank_request.model_name,
#         )
#         return RerankResponse(scores=sim_scores)

#     except Exception as e:
#         logger.exception(f"Error during reranking process:\n{str(e)}")
#         raise HTTPException(
#             status_code=500, detail="Failed to run Cross-Encoder reranking"
#         )


================================================
FILE: backend/model_server/main.py
================================================
import logging
import os
import shutil
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path

import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging

from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.middleware import add_onyx_tenant_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.configs import SENTRY_DSN

os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"

HF_CACHE_PATH = Path(".cache/huggingface")
TEMP_HF_CACHE_PATH = Path(".cache/temp_huggingface")

transformer_logging.set_verbosity_error()

logger = setup_logger()

file_handlers = [
    h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]

setup_uvicorn_logger(shared_file_handlers=file_handlers)


def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
    """
    This moves the files from the temp huggingface cache to the huggingface cache.

    We have to move each file individually because the directories might
    have the same name but not the same contents, and we don't want to remove
    the files in the existing huggingface cache that don't exist in the temp
    huggingface cache.
    """
    for item in source.iterdir():
        target_path = dest / item.relative_to(source)
        if item.is_dir():
            _move_files_recursively(item, target_path, overwrite)
        else:
            target_path.parent.mkdir(parents=True, exist_ok=True)
            if target_path.exists() and not overwrite:
                continue
            shutil.move(str(item), str(target_path))


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator:
    gpu_type = get_gpu_type()
    logger.notice(f"Torch GPU Detection: gpu_type={gpu_type}")

    app.state.gpu_type = gpu_type

    try:
        if TEMP_HF_CACHE_PATH.is_dir():
            logger.notice("Moving contents of temp_huggingface to huggingface cache.")
            _move_files_recursively(TEMP_HF_CACHE_PATH, HF_CACHE_PATH)
            shutil.rmtree(TEMP_HF_CACHE_PATH, ignore_errors=True)
            logger.notice("Moved contents of temp_huggingface to huggingface cache.")
    except Exception as e:
        logger.warning(
            f"Error moving contents of temp_huggingface to huggingface cache: {e}. "
            "This is not a critical error and the model server will continue to run."
        )

    torch.set_num_threads(max(MIN_THREADS_ML_MODELS, torch.get_num_threads()))
    logger.notice(f"Torch Threads: {torch.get_num_threads()}")

    yield


def get_model_app() -> FastAPI:
    application = FastAPI(
        title="Onyx Model Server", version=__version__, lifespan=lifespan
    )
    if SENTRY_DSN:
        sentry_sdk.init(
            dsn=SENTRY_DSN,
            integrations=[StarletteIntegration(), FastApiIntegration()],
            traces_sample_rate=0.1,
            release=__version__,
        )
        logger.info("Sentry initialized")
    else:
        logger.debug("Sentry DSN not provided, skipping Sentry initialization")

    application.include_router(management_router)
    application.include_router(encoders_router)

    request_id_prefix = "INF"
    if INDEXING_ONLY:
        request_id_prefix = "IDX"

    add_onyx_tenant_id_middleware(application, logger)
    add_onyx_request_id_middleware(application, request_id_prefix, logger)

    # Initialize and instrument the app
    Instrumentator().instrument(application).expose(application)

    return application


app = get_model_app()


if __name__ == "__main__":
    logger.notice(
        f"Starting Onyx Model Server on http://{MODEL_SERVER_ALLOWED_HOST}:{str(MODEL_SERVER_PORT)}/"
    )
    logger.notice(f"Model Server Version: {__version__}")
    uvicorn.run(app, host=MODEL_SERVER_ALLOWED_HOST, port=MODEL_SERVER_PORT)


================================================
FILE: backend/model_server/management_endpoints.py
================================================
from fastapi import APIRouter
from fastapi import Response

from model_server.constants import GPUStatus
from model_server.utils import get_gpu_type

router = APIRouter(prefix="/api")


@router.get("/health")
async def healthcheck() -> Response:
    return Response(status_code=200)


@router.get("/gpu-status")
async def route_gpu_status() -> dict[str, bool | str]:
    gpu_type = get_gpu_type()
    gpu_available = gpu_type != GPUStatus.NONE
    return {"gpu_available": gpu_available, "type": gpu_type}


================================================
FILE: backend/model_server/utils.py
================================================
import asyncio
import time
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar

import torch

from model_server.constants import GPUStatus
from onyx.utils.logger import setup_logger

logger = setup_logger()

F = TypeVar("F", bound=Callable)
FG = TypeVar("FG", bound=Callable[..., Generator | Iterator])


def simple_log_function_time(
    func_name: str | None = None,
    debug_only: bool = False,
    include_args: bool = False,
) -> Callable[[F], F]:
    def decorator(func: F) -> F:
        if asyncio.iscoroutinefunction(func):

            @wraps(func)
            async def wrapped_async_func(*args: Any, **kwargs: Any) -> Any:
                start_time = time.time()
                result = await func(*args, **kwargs)
                elapsed_time_str = str(time.time() - start_time)
                log_name = func_name or func.__name__
                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
                if debug_only:
                    logger.debug(final_log)
                else:
                    logger.notice(final_log)
                return result

            return cast(F, wrapped_async_func)
        else:

            @wraps(func)
            def wrapped_sync_func(*args: Any, **kwargs: Any) -> Any:
                start_time = time.time()
                result = func(*args, **kwargs)
                elapsed_time_str = str(time.time() - start_time)
                log_name = func_name or func.__name__
                args_str = f" args={args} kwargs={kwargs}" if include_args else ""
                final_log = f"{log_name}{args_str} took {elapsed_time_str} seconds"
                if debug_only:
                    logger.debug(final_log)
                else:
                    logger.notice(final_log)
                return result

            return cast(F, wrapped_sync_func)

    return decorator


def get_gpu_type() -> str:
    if torch.cuda.is_available():
        return GPUStatus.CUDA
    if torch.backends.mps.is_available():
        return GPUStatus.MAC_MPS

    return GPUStatus.NONE


================================================
FILE: backend/onyx/__init__.py
================================================
import os

__version__ = os.environ.get("ONYX_VERSION", "") or "Development"


================================================
FILE: backend/onyx/access/__init__.py
================================================


================================================
FILE: backend/onyx/access/access.py
================================================
from collections.abc import Callable
from typing import cast

from sqlalchemy.orm import Session

from onyx.access.models import DocumentAccess
from onyx.access.utils import prefix_user_email
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import PUBLIC_DOC_PAT
from onyx.db.document import get_access_info_for_document
from onyx.db.document import get_access_info_for_documents
from onyx.db.models import User
from onyx.db.models import UserFile
from onyx.db.user_file import fetch_user_files_with_access_relationships
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation


def _get_access_for_document(
    document_id: str,
    db_session: Session,
) -> DocumentAccess:
    info = get_access_info_for_document(
        db_session=db_session,
        document_id=document_id,
    )

    doc_access = DocumentAccess.build(
        user_emails=info[1] if info and info[1] else [],
        user_groups=[],
        external_user_emails=[],
        external_user_group_ids=[],
        is_public=info[2] if info else False,
    )

    return doc_access


def get_access_for_document(
    document_id: str,
    db_session: Session,
) -> DocumentAccess:
    versioned_get_access_for_document_fn = fetch_versioned_implementation(
        "onyx.access.access", "_get_access_for_document"
    )
    return versioned_get_access_for_document_fn(document_id, db_session)


def get_null_document_access() -> DocumentAccess:
    return DocumentAccess.build(
        user_emails=[],
        user_groups=[],
        is_public=False,
        external_user_emails=[],
        external_user_group_ids=[],
    )


def _get_access_for_documents(
    document_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    document_access_info = get_access_info_for_documents(
        db_session=db_session,
        document_ids=document_ids,
    )
    doc_access = {}
    for document_id, user_emails, is_public in document_access_info:
        doc_access[document_id] = DocumentAccess.build(
            user_emails=[email for email in user_emails if email],
            # MIT version will wipe all groups and external groups on update
            user_groups=[],
            is_public=is_public,
            external_user_emails=[],
            external_user_group_ids=[],
        )

    # Sometimes the indexing job has not indexed the document yet; in that case the
    # document does not exist, so we fall back to the least permissive access.
    # Specifically, the EE version checks the MIT version permissions and creates a
    # superset. This ensures that this flow does not fail even if the Document has
    # not yet been indexed.
    for doc_id in document_ids:
        if doc_id not in doc_access:
            doc_access[doc_id] = get_null_document_access()
    return doc_access


def get_access_for_documents(
    document_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    """Fetches all access information for the given documents."""
    versioned_get_access_for_documents_fn = fetch_versioned_implementation(
        "onyx.access.access", "_get_access_for_documents"
    )
    return versioned_get_access_for_documents_fn(document_ids, db_session)


def _get_acl_for_user(
    user: User, db_session: Session  # noqa: ARG001
) -> set[str]:  # noqa: ARG001
    """Returns a set of ACL entries that the user has access to. This is meant to be
    used downstream to filter out documents that the user does not have access to. The
    user should have access to a document if at least one entry in the document's ACL
    matches one entry in the returned set.

    Anonymous users only have access to public documents.
    """
    if user.is_anonymous:
        return {PUBLIC_DOC_PAT}
    return {prefix_user_email(user.email), PUBLIC_DOC_PAT}


def get_acl_for_user(user: User, db_session: Session | None = None) -> set[str]:
    versioned_acl_for_user_fn = fetch_versioned_implementation(
        "onyx.access.access", "_get_acl_for_user"
    )
    return versioned_acl_for_user_fn(user, db_session)


def source_should_fetch_permissions_during_indexing(source: DocumentSource) -> bool:
    _source_should_fetch_permissions_during_indexing_func = cast(
        Callable[[DocumentSource], bool],
        fetch_ee_implementation_or_noop(
            "onyx.external_permissions.sync_params",
            "source_should_fetch_permissions_during_indexing",
            False,
        ),
    )
    return _source_should_fetch_permissions_during_indexing_func(source)


def get_access_for_user_files(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    versioned_fn = fetch_versioned_implementation(
        "onyx.access.access", "get_access_for_user_files_impl"
    )
    return versioned_fn(user_file_ids, db_session)


def get_access_for_user_files_impl(
    user_file_ids: list[str],
    db_session: Session,
) -> dict[str, DocumentAccess]:
    user_files = fetch_user_files_with_access_relationships(user_file_ids, db_session)
    return build_access_for_user_files_impl(user_files)


def build_access_for_user_files(
    user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
    """Compute access from pre-loaded UserFile objects (with relationships).
    Callers must ensure UserFile.user, Persona.users, and Persona.user are
    eagerly loaded (and Persona.groups for the EE path)."""
    versioned_fn = fetch_versioned_implementation(
        "onyx.access.access", "build_access_for_user_files_impl"
    )
    return versioned_fn(user_files)


def build_access_for_user_files_impl(
    user_files: list[UserFile],
) -> dict[str, DocumentAccess]:
    result: dict[str, DocumentAccess] = {}
    for user_file in user_files:
        emails, is_public = collect_user_file_access(user_file)
        result[str(user_file.id)] = DocumentAccess.build(
            user_emails=list(emails),
            user_groups=[],
            is_public=is_public,
            external_user_emails=[],
            external_user_group_ids=[],
        )
    return result


def collect_user_file_access(user_file: UserFile) -> tuple[set[str], bool]:
    """Collect all user emails that should have access to this user file.
    Includes the owner plus any users who have access via shared personas.
    Returns (emails, is_public)."""
    emails: set[str] = {user_file.user.email}
    is_public = False
    for persona in user_file.assistants:
        if persona.deleted:
            continue
        if persona.is_public:
            is_public = True
        if persona.user_id is not None and persona.user:
            emails.add(persona.user.email)
        for shared_user in persona.users:
            emails.add(shared_user.email)
    return emails, is_public


================================================
FILE: backend/onyx/access/hierarchy_access.py
================================================
from sqlalchemy.orm import Session

from onyx.db.models import User
from onyx.utils.variable_functionality import fetch_versioned_implementation


def _get_user_external_group_ids(
    db_session: Session,  # noqa: ARG001
    user: User,  # noqa: ARG001
) -> list[str]:
    return []


def get_user_external_group_ids(db_session: Session, user: User) -> list[str]:
    versioned_get_user_external_group_ids = fetch_versioned_implementation(
        "onyx.access.hierarchy_access", "_get_user_external_group_ids"
    )
    return versioned_get_user_external_group_ids(db_session, user)


================================================
FILE: backend/onyx/access/models.py
================================================
from dataclasses import dataclass

from onyx.access.utils import prefix_external_group
from onyx.access.utils import prefix_user_email
from onyx.access.utils import prefix_user_group
from onyx.configs.constants import PUBLIC_DOC_PAT


@dataclass(frozen=True)
class ExternalAccess:
    # arbitrary limit to prevent excessively large permissions sets
    # not internally enforced ...
    # the caller can check this before using the instance
    MAX_NUM_ENTRIES = 5000

    # Emails of external users with access to the doc externally
    external_user_emails: set[str]
    # Names or external IDs of groups with access to the doc
    external_user_group_ids: set[str]
    # Whether the document is public in the external system or Onyx
    is_public: bool

    def __str__(self) -> str:
        """Prevent extremely long logs"""

        def truncate_set(s: set[str], max_len: int = 100) -> str:
            s_str = str(s)
            if len(s_str) > max_len:
                return f"{s_str[:max_len]}... ({len(s)} items)"
            return s_str

        return (
            f"ExternalAccess("
            f"external_user_emails={truncate_set(self.external_user_emails)}, "
            f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
            f"is_public={self.is_public})"
        )

    @property
    def num_entries(self) -> int:
        return len(self.external_user_emails) + len(self.external_user_group_ids)

    @classmethod
    def public(cls) -> "ExternalAccess":
        return cls(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=True,
        )

    @classmethod
    def empty(cls) -> "ExternalAccess":
        """
        A helper function that returns an *empty* set of external user-emails and group-ids, and sets `is_public` to `False`.
        This effectively makes the document in question "private" or inaccessible to anyone else.

        This is especially helpful to use when you are performing permission-syncing, and some document's permissions aren't able
        to be determined (for whatever reason). Setting its `ExternalAccess` to "private" is a feasible fallback.
        """
        return cls(
            external_user_emails=set(),
            external_user_group_ids=set(),
            is_public=False,
        )


@dataclass(frozen=True)
class DocExternalAccess:
    """
    This is just a class to wrap the external access and the document ID
    together. It's used for syncing document permissions to Vespa.
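
The wrapper's serialization round trip can be sketched in isolation. This is a
minimal sketch, not the Onyx implementation: `MiniExternalAccess` and
`MiniDocExternalAccess` are hypothetical stand-ins (the real classes import Onyx
modules), and it assumes the payload only needs to be JSON-serializable, which is
why the sets become lists. It uses `sorted()` where the real `to_dict` uses
`list()`, purely to make the output deterministic.

```python
import json
from dataclasses import dataclass


# Hypothetical stand-ins mirroring ExternalAccess / DocExternalAccess
@dataclass(frozen=True)
class MiniExternalAccess:
    external_user_emails: set[str]
    external_user_group_ids: set[str]
    is_public: bool


@dataclass(frozen=True)
class MiniDocExternalAccess:
    external_access: MiniExternalAccess
    doc_id: str

    def to_dict(self) -> dict:
        # Sets are not JSON-serializable, so they are emitted as (sorted) lists
        ea = self.external_access
        return {
            "external_access": {
                "external_user_emails": sorted(ea.external_user_emails),
                "external_user_group_ids": sorted(ea.external_user_group_ids),
                "is_public": ea.is_public,
            },
            "doc_id": self.doc_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "MiniDocExternalAccess":
        # Missing lists default to empty, matching the .get(..., []) calls above
        ea = data["external_access"]
        return cls(
            external_access=MiniExternalAccess(
                external_user_emails=set(ea.get("external_user_emails", [])),
                external_user_group_ids=set(ea.get("external_user_group_ids", [])),
                is_public=ea["is_public"],
            ),
            doc_id=data["doc_id"],
        )


original = MiniDocExternalAccess(
    MiniExternalAccess({"a@example.com"}, {"eng"}, False), "doc-1"
)
payload = json.dumps(original.to_dict())
restored = MiniDocExternalAccess.from_dict(json.loads(payload))
```

Because both dataclasses compare by field values, `restored == original` holds
after the round trip even though the intermediate form uses lists.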
    """

    external_access: ExternalAccess
    # The document ID
    doc_id: str

    def to_dict(self) -> dict:
        return {
            "external_access": {
                "external_user_emails": list(self.external_access.external_user_emails),
                "external_user_group_ids": list(
                    self.external_access.external_user_group_ids
                ),
                "is_public": self.external_access.is_public,
            },
            "doc_id": self.doc_id,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "DocExternalAccess":
        external_access = ExternalAccess(
            external_user_emails=set(
                data["external_access"].get("external_user_emails", [])
            ),
            external_user_group_ids=set(
                data["external_access"].get("external_user_group_ids", [])
            ),
            is_public=data["external_access"]["is_public"],
        )
        return cls(
            external_access=external_access,
            doc_id=data["doc_id"],
        )


@dataclass(frozen=True)
class NodeExternalAccess:
    """
    Wraps external access with a hierarchy node's raw ID.
    Used for syncing hierarchy node permissions (e.g., folder permissions).
    """

    external_access: ExternalAccess
    # The raw node ID from the source system (e.g., Google Drive folder ID)
    raw_node_id: str
    # The source type (e.g., "google_drive")
    source: str

    def to_dict(self) -> dict:
        return {
            "external_access": {
                "external_user_emails": list(self.external_access.external_user_emails),
                "external_user_group_ids": list(
                    self.external_access.external_user_group_ids
                ),
                "is_public": self.external_access.is_public,
            },
            "raw_node_id": self.raw_node_id,
            "source": self.source,
        }

    @classmethod
    def from_dict(cls, data: dict) -> "NodeExternalAccess":
        external_access = ExternalAccess(
            external_user_emails=set(
                data["external_access"].get("external_user_emails", [])
            ),
            external_user_group_ids=set(
                data["external_access"].get("external_user_group_ids", [])
            ),
            is_public=data["external_access"]["is_public"],
        )
        return cls(
            external_access=external_access,
            raw_node_id=data["raw_node_id"],
            source=data["source"],
        )


# Union type for elements that can have permissions synced
ElementExternalAccess = DocExternalAccess | NodeExternalAccess


# TODO(andrei): First
# refactor this into a pydantic model, then get rid of
# duplicate fields.
@dataclass(frozen=True, init=False)
class DocumentAccess(ExternalAccess):
    # User emails for Onyx users, None indicates admin
    user_emails: set[str | None]
    # Names of user groups associated with this document
    user_groups: set[str]

    external_user_emails: set[str]
    external_user_group_ids: set[str]
    is_public: bool

    def __init__(self) -> None:
        raise TypeError(
            "Use `DocumentAccess.build(...)` instead of creating an instance directly."
        )

    def to_acl(self) -> set[str]:
        """Converts the access state to a set of formatted ACL strings.

        NOTE: When querying for documents, the supplied ACL filter strings must
        be formatted in the same way as this function.
        """
        acl_set: set[str] = set()
        for user_email in self.user_emails:
            if user_email:
                acl_set.add(prefix_user_email(user_email))

        for group_name in self.user_groups:
            acl_set.add(prefix_user_group(group_name))

        for external_user_email in self.external_user_emails:
            acl_set.add(prefix_user_email(external_user_email))

        for external_group_id in self.external_user_group_ids:
            acl_set.add(prefix_external_group(external_group_id))

        if self.is_public:
            acl_set.add(PUBLIC_DOC_PAT)

        return acl_set

    @classmethod
    def build(
        cls,
        user_emails: list[str | None],
        user_groups: list[str],
        external_user_emails: list[str],
        external_user_group_ids: list[str],
        is_public: bool,
    ) -> "DocumentAccess":
        """Don't prefix incoming data with the ACL type; prefix on read in to_acl!"""
        obj = object.__new__(cls)
        object.__setattr__(
            obj, "user_emails", {user_email for user_email in user_emails if user_email}
        )
        object.__setattr__(obj, "user_groups", set(user_groups))
        object.__setattr__(
            obj,
            "external_user_emails",
            {external_email for external_email in external_user_emails},
        )
        object.__setattr__(
            obj,
            "external_user_group_ids",
            {external_group_id for external_group_id in external_user_group_ids},
        )
        object.__setattr__(obj, "is_public", is_public)

        return obj


default_public_access = DocumentAccess.build(
    external_user_emails=[],
    external_user_group_ids=[],
    user_emails=[],
    user_groups=[],
    is_public=True,
)


================================================
FILE: backend/onyx/access/utils.py
================================================
from onyx.configs.constants import DocumentSource


def prefix_user_email(user_email: str) -> str:
    """Prefixes a user email to eliminate collision with group names.
    This applies to both Onyx users and external users, which makes queries
    more efficient."""
    return f"user_email:{user_email}"


def prefix_user_group(user_group_name: str) -> str:
    """Prefixes a user group name to eliminate collision with user emails.
    This assumes that user ids are prefixed with a different prefix."""
    return f"group:{user_group_name}"


def prefix_external_group(ext_group_name: str) -> str:
    """Prefixes an external group name to eliminate collision with user emails / Onyx groups."""
    return f"external_group:{ext_group_name}"


def build_ext_group_name_for_onyx(ext_group_name: str, source: DocumentSource) -> str:
    """
    External groups may collide across sources, so every source needs its own prefix.
    NOTE: the name is lowercased to handle case sensitivity for group names
    """
    return f"{source.value}_{ext_group_name}".lower()


================================================
FILE: backend/onyx/auth/__init__.py
================================================

================================================
FILE: backend/onyx/auth/anonymous_user.py
================================================
from collections.abc import Mapping
from typing import Any
from typing import cast

from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import ANONYMOUS_USER_INFO_ID
from onyx.configs.constants import KV_ANONYMOUS_USER_PERSONALIZATION_KEY
from onyx.configs.constants import KV_ANONYMOUS_USER_PREFERENCES_KEY
from onyx.key_value_store.store import KeyValueStore
from onyx.key_value_store.store import KvKeyNotFoundError
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPersonalization
from onyx.server.manage.models import UserPreferences


def set_anonymous_user_preferences(
    store: KeyValueStore, preferences: UserPreferences
) -> None:
    store.store(KV_ANONYMOUS_USER_PREFERENCES_KEY, preferences.model_dump())


def set_anonymous_user_personalization(
    store: KeyValueStore, personalization: UserPersonalization
) -> None:
    store.store(KV_ANONYMOUS_USER_PERSONALIZATION_KEY, personalization.model_dump())


def load_anonymous_user_preferences(store: KeyValueStore) -> UserPreferences:
    try:
        preferences_data = cast(
            Mapping[str, Any], store.load(KV_ANONYMOUS_USER_PREFERENCES_KEY)
        )
        return UserPreferences(**preferences_data)
    except KvKeyNotFoundError:
        return UserPreferences(
            chosen_assistants=None, default_model=None, auto_scroll=True
        )


def fetch_anonymous_user_info(store: KeyValueStore) -> UserInfo:
    """Fetch a UserInfo object for anonymous users (used for API responses)."""
    personalization = UserPersonalization()
    try:
        personalization_data = cast(
            Mapping[str, Any],
            store.load(KV_ANONYMOUS_USER_PERSONALIZATION_KEY)
        )
        personalization = UserPersonalization(**personalization_data)
    except KvKeyNotFoundError:
        pass

    return UserInfo(
        id=ANONYMOUS_USER_INFO_ID,
        email=ANONYMOUS_USER_EMAIL,
        is_active=True,
        is_superuser=False,
        is_verified=True,
        role=UserRole.LIMITED,
        preferences=load_anonymous_user_preferences(store),
        personalization=personalization,
        is_anonymous_user=True,
        password_configured=False,
    )


================================================
FILE: backend/onyx/auth/api_key.py
================================================
import hashlib
import secrets
import uuid
from urllib.parse import quote

from fastapi import Request
from passlib.hash import sha256_crypt
from pydantic import BaseModel

from onyx.auth.constants import API_KEY_LENGTH
from onyx.auth.constants import API_KEY_PREFIX
from onyx.auth.constants import DEPRECATED_API_KEY_PREFIX
from onyx.auth.schemas import UserRole
from onyx.auth.utils import get_hashed_bearer_token_from_request
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
from shared_configs.configs import MULTI_TENANT


class ApiKeyDescriptor(BaseModel):
    api_key_id: int
    api_key_display: str
    api_key: str | None = None  # only present on initial creation
    api_key_name: str | None = None
    api_key_role: UserRole

    user_id: uuid.UUID


def generate_api_key(tenant_id: str | None = None) -> str:
    if not MULTI_TENANT or not tenant_id:
        return API_KEY_PREFIX + secrets.token_urlsafe(API_KEY_LENGTH)

    encoded_tenant = quote(tenant_id)  # URL encode the tenant ID
    return f"{API_KEY_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(API_KEY_LENGTH)}"


def _deprecated_hash_api_key(api_key: str) -> str:
    return sha256_crypt.hash(api_key, salt="", rounds=API_KEY_HASH_ROUNDS)


def hash_api_key(api_key: str) -> str:
    # NOTE: no salt is needed, as the API key is randomly generated
    # and overlaps are impossible
    if api_key.startswith(API_KEY_PREFIX):
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()

    if api_key.startswith(DEPRECATED_API_KEY_PREFIX):
        return _deprecated_hash_api_key(api_key)

    raise ValueError(f"Invalid API key prefix: {api_key[:3]}")


def build_displayable_api_key(api_key: str) -> str:
    if api_key.startswith(API_KEY_PREFIX):
        api_key = api_key[len(API_KEY_PREFIX) :]

    return API_KEY_PREFIX + api_key[:4] + "********" + api_key[-4:]


def get_hashed_api_key_from_request(request: Request) -> str | None:
    """Extract and hash API key from Authorization header.

    Accepts both "Bearer " and raw key formats.
    """
    return get_hashed_bearer_token_from_request(
        request,
        valid_prefixes=[API_KEY_PREFIX, DEPRECATED_API_KEY_PREFIX],
        hash_fn=hash_api_key,
        allow_non_bearer=True,  # API keys historically support both formats
    )


================================================
FILE: backend/onyx/auth/captcha.py
================================================
"""Captcha verification for user registration."""

import httpx
from pydantic import BaseModel
from pydantic import Field

from onyx.configs.app_configs import CAPTCHA_ENABLED
from onyx.configs.app_configs import RECAPTCHA_SCORE_THRESHOLD
from onyx.configs.app_configs import RECAPTCHA_SECRET_KEY
from onyx.utils.logger import setup_logger

logger = setup_logger()

RECAPTCHA_VERIFY_URL = "https://www.google.com/recaptcha/api/siteverify"


class CaptchaVerificationError(Exception):
    """Raised when captcha verification fails."""


class RecaptchaResponse(BaseModel):
    """Response from Google reCAPTCHA verification API."""

    success: bool
    score: float | None = None  # Only present for reCAPTCHA v3
    action: str | None = None
    challenge_ts: str | None = None
    hostname: str | None = None
    error_codes: list[str] | None = Field(default=None, alias="error-codes")


def is_captcha_enabled() -> bool:
    """Check if captcha verification is enabled."""
    return CAPTCHA_ENABLED and bool(RECAPTCHA_SECRET_KEY)


async def verify_captcha_token(
    token: str,
    expected_action: str = "signup",
) -> None:
    """
    Verify a reCAPTCHA token with Google's API.
    Args:
        token: The reCAPTCHA response token from the client
        expected_action: Expected action name for v3 verification

    Raises:
        CaptchaVerificationError: If verification fails
    """
    if not is_captcha_enabled():
        return

    if not token:
        raise CaptchaVerificationError("Captcha token is required")

    try:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                RECAPTCHA_VERIFY_URL,
                data={
                    "secret": RECAPTCHA_SECRET_KEY,
                    "response": token,
                },
                timeout=10.0,
            )
            response.raise_for_status()

            data = response.json()
            result = RecaptchaResponse(**data)

            if not result.success:
                error_codes = result.error_codes or ["unknown-error"]
                logger.warning(f"Captcha verification failed: {error_codes}")
                raise CaptchaVerificationError(
                    f"Captcha verification failed: {', '.join(error_codes)}"
                )

            # For reCAPTCHA v3, also check the score
            if result.score is not None:
                if result.score < RECAPTCHA_SCORE_THRESHOLD:
                    logger.warning(
                        f"Captcha score too low: {result.score} < {RECAPTCHA_SCORE_THRESHOLD}"
                    )
                    raise CaptchaVerificationError(
                        "Captcha verification failed: suspicious activity detected"
                    )

                # Optionally verify the action matches
                if result.action and result.action != expected_action:
                    logger.warning(
                        f"Captcha action mismatch: {result.action} != {expected_action}"
                    )
                    raise CaptchaVerificationError(
                        "Captcha verification failed: action mismatch"
                    )

            logger.debug(
                f"Captcha verification passed: score={result.score}, action={result.action}"
            )

    except httpx.HTTPError as e:
        logger.error(f"Captcha API request failed: {e}")
        # In case of API errors, we might want to allow registration
        # to prevent blocking legitimate users. This is a policy decision.
        raise CaptchaVerificationError("Captcha verification service unavailable")


================================================
FILE: backend/onyx/auth/constants.py
================================================
"""Authentication constants shared across auth modules."""

# API Key constants
API_KEY_PREFIX = "on_"
DEPRECATED_API_KEY_PREFIX = "dn_"
API_KEY_LENGTH = 192

# PAT constants
PAT_PREFIX = "onyx_pat_"
PAT_LENGTH = 192

# Shared header constants
API_KEY_HEADER_NAME = "Authorization"
API_KEY_HEADER_ALTERNATIVE_NAME = "X-Onyx-Authorization"
BEARER_PREFIX = "Bearer "


================================================
FILE: backend/onyx/auth/disposable_email_validator.py
================================================
"""
Utility to validate and block disposable/temporary email addresses.

This module fetches a list of known disposable email domains from a remote source
and caches them for performance. It's used during user registration to prevent
abuse from temporary email services.
"""

import threading
import time
from typing import Set

import httpx

from onyx.configs.app_configs import DISPOSABLE_EMAIL_DOMAINS_URL
from onyx.utils.logger import setup_logger

logger = setup_logger()


class DisposableEmailValidator:
    """
    Thread-safe singleton validator for disposable email domains.

    Fetches and caches the list of disposable domains, with periodic refresh.
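
The singleton construction below (double-checked locking in `__new__`, plus an
`_initialized` guard because `__init__` runs on every call) can be sketched in
isolation. This is a hypothetical illustration: `Singleton` is not an Onyx name,
it uses `getattr` where the real class uses try/except, and, unlike the real
class, it also holds the lock inside `__init__` so the guard itself is race-free.

```python
import threading


class Singleton:
    # Hypothetical stand-in for DisposableEmailValidator's construction logic
    _instance: "Singleton | None" = None
    _lock = threading.Lock()

    def __new__(cls) -> "Singleton":
        # Double-checked locking: the unlocked check keeps the common path cheap,
        # and the locked re-check stops two threads from both creating an instance
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # __init__ runs on every Singleton() call, so skip re-initialization;
        # the lock is released by __new__ before __init__ runs, so no deadlock
        with self._lock:
            if getattr(self, "_initialized", False):
                return
            self.init_count = getattr(self, "init_count", 0) + 1
            self._initialized = True


instances: list[Singleton] = []
threads = [
    threading.Thread(target=lambda: instances.append(Singleton())) for _ in range(8)
]
for t in threads:
    t.start()
for t in threads:
    t.join()
```

After the threads finish, every entry in `instances` is the same object and its
initialization body has run exactly once.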
    """

    _instance: "DisposableEmailValidator | None" = None
    _lock = threading.Lock()

    def __new__(cls) -> "DisposableEmailValidator":
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self) -> None:
        # Check if already initialized using a try/except to avoid type issues
        try:
            if self._initialized:
                return
        except AttributeError:
            pass

        self._domains: Set[str] = set()
        self._last_fetch_time: float = 0
        self._fetch_lock = threading.Lock()
        # Cache for 1 hour
        self._cache_duration = 3600

        # Hardcoded fallback list of common disposable domains
        # This ensures we block at least these even if the remote fetch fails
        self._fallback_domains = {
            "trashlify.com",
            "10minutemail.com",
            "guerrillamail.com",
            "mailinator.com",
            "tempmail.com",
            "chat-tempmail.com",
            "throwaway.email",
            "yopmail.com",
            "temp-mail.org",
            "getnada.com",
            "maildrop.cc",
        }

        # Set initialized flag last to prevent race conditions
        self._initialized: bool = True

    def _should_refresh(self) -> bool:
        """Check if the cached domains should be refreshed."""
        return (time.time() - self._last_fetch_time) > self._cache_duration

    def _fetch_domains(self) -> Set[str]:
        """
        Fetch disposable email domains from the configured URL.
        Returns:
            Set of domain strings (lowercased)
        """
        if not DISPOSABLE_EMAIL_DOMAINS_URL:
            logger.debug("DISPOSABLE_EMAIL_DOMAINS_URL not configured")
            return self._fallback_domains.copy()

        try:
            logger.info(
                f"Fetching disposable email domains from {DISPOSABLE_EMAIL_DOMAINS_URL}"
            )
            with httpx.Client(timeout=10.0) as client:
                response = client.get(DISPOSABLE_EMAIL_DOMAINS_URL)
                response.raise_for_status()

                domains_list = response.json()

                if not isinstance(domains_list, list):
                    logger.error(
                        f"Expected list from disposable domains URL, got {type(domains_list)}"
                    )
                    return self._fallback_domains.copy()

                # Convert all to lowercase and create set
                domains = {domain.lower().strip() for domain in domains_list if domain}

                # Always include fallback domains
                domains.update(self._fallback_domains)

                logger.info(
                    f"Successfully fetched {len(domains)} disposable email domains"
                )
                return domains

        except httpx.HTTPError as e:
            logger.warning(f"Failed to fetch disposable domains (HTTP error): {e}")
        except Exception as e:
            logger.warning(f"Failed to fetch disposable domains: {e}")

        # On error, return fallback domains
        return self._fallback_domains.copy()

    def get_domains(self) -> Set[str]:
        """
        Get the cached set of disposable email domains.
        Refreshes the cache if needed.

        Returns:
            Set of disposable domain strings (lowercased)
        """
        # Fast path: return cached domains if still fresh
        if self._domains and not self._should_refresh():
            return self._domains.copy()

        # Slow path: need to refresh
        with self._fetch_lock:
            # Double-check after acquiring lock
            if self._domains and not self._should_refresh():
                return self._domains.copy()

            self._domains = self._fetch_domains()
            self._last_fetch_time = time.time()

            return self._domains.copy()

    def is_disposable(self, email: str) -> bool:
        """
        Check if an email address uses a disposable domain.
        Args:
            email: The email address to check

        Returns:
            True if the email domain is disposable, False otherwise
        """
        if not email or "@" not in email:
            return False

        parts = email.split("@")
        if len(parts) != 2 or not parts[0]:  # Must have user@domain with non-empty user
            return False

        domain = parts[1].lower().strip()
        if not domain:  # Domain part must not be empty
            return False

        disposable_domains = self.get_domains()
        return domain in disposable_domains


# Global singleton instance
_validator = DisposableEmailValidator()


def is_disposable_email(email: str) -> bool:
    """
    Check if an email address uses a disposable/temporary domain.

    This is a convenience function that uses the global validator instance.

    Args:
        email: The email address to check

    Returns:
        True if the email uses a disposable domain, False otherwise
    """
    return _validator.is_disposable(email)


def refresh_disposable_domains() -> None:
    """
    Force a refresh of the disposable domains list.

    This can be called manually if you want to update the list
    without waiting for the cache to expire.
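
Because this function only zeroes `_last_fetch_time`, the forced refresh relies on
the same TTL check that `get_domains()` performs. A minimal sketch of that
pattern, with hypothetical names (`TtlSetCache`, `fake_fetch`) and an injected
clock standing in for `time.time()` so it can run without the network:

```python
import threading
from typing import Callable


class TtlSetCache:
    # Hypothetical sketch of the get_domains()/refresh pattern
    def __init__(
        self,
        fetch: Callable[[], set[str]],
        ttl_seconds: float,
        clock: Callable[[], float],
    ) -> None:
        self._fetch = fetch
        self._ttl_seconds = ttl_seconds
        self._clock = clock
        self._value: set[str] = set()
        self._last_fetch_time = 0.0
        self._lock = threading.Lock()

    def _should_refresh(self) -> bool:
        return (self._clock() - self._last_fetch_time) > self._ttl_seconds

    def get(self) -> set[str]:
        # Fast path without the lock while the cached value is fresh
        if self._value and not self._should_refresh():
            return self._value.copy()
        with self._lock:
            # Re-check: another thread may have refreshed while we waited
            if self._value and not self._should_refresh():
                return self._value.copy()
            self._value = self._fetch()
            self._last_fetch_time = self._clock()
            return self._value.copy()

    def force_refresh(self) -> set[str]:
        # Same trick as refresh_disposable_domains(): zeroing the timestamp makes
        # _should_refresh() true, provided clock() is epoch-based (far above ttl)
        self._last_fetch_time = 0.0
        return self.get()


now = [1_700_000_000.0]
fetch_calls: list[float] = []


def fake_fetch() -> set[str]:
    fetch_calls.append(now[0])
    return {"mailinator.com"}


cache = TtlSetCache(fake_fetch, ttl_seconds=3600, clock=lambda: now[0])
cache.get()
cache.get()  # served from cache
now[0] += 10
cache.force_refresh()  # refetches even though the cache is still fresh
now[0] += 3601
cache.get()  # TTL expired, refetches
```

Zeroing the timestamp works only because the clock returns epoch seconds; with a
clock that starts near zero, `force_refresh()` would silently keep serving the
cached value.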
    """
    _validator._last_fetch_time = 0
    _validator.get_domains()


================================================
FILE: backend/onyx/auth/email_utils.py
================================================
import base64
import smtplib
from datetime import datetime
from email.mime.image import MIMEImage
from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText
from email.utils import formatdate
from email.utils import make_msgid

import sendgrid  # type: ignore
from sendgrid.helpers.mail import Attachment  # type: ignore
from sendgrid.helpers.mail import Content
from sendgrid.helpers.mail import ContentId
from sendgrid.helpers.mail import Disposition
from sendgrid.helpers.mail import Email
from sendgrid.helpers.mail import FileContent
from sendgrid.helpers.mail import FileName
from sendgrid.helpers.mail import FileType
from sendgrid.helpers.mail import Mail
from sendgrid.helpers.mail import To

from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import EMAIL_FROM
from onyx.configs.app_configs import SENDGRID_API_KEY
from onyx.configs.app_configs import SMTP_PASS
from onyx.configs.app_configs import SMTP_PORT
from onyx.configs.app_configs import SMTP_SERVER
from onyx.configs.app_configs import SMTP_USER
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import ONYX_DEFAULT_APPLICATION_NAME
from onyx.configs.constants import ONYX_DISCORD_URL
from onyx.db.models import User
from onyx.server.runtime.onyx_runtime import OnyxRuntime
from onyx.utils.logger import setup_logger
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()

HTML_EMAIL_TEMPLATE = """\
{title}
"""


def build_html_email(
    application_name: str | None,
    heading: str,
    message: str,
    cta_text: str | None = None,
    cta_link: str | None = None,
) -> str:
    community_link_fragment = ""
    if application_name == ONYX_DEFAULT_APPLICATION_NAME:
        community_link_fragment = f'
Have questions? Join our Discord community here.'
    if cta_text and cta_link:
        cta_block = f'{cta_text}'
    else:
        cta_block = ""
    return HTML_EMAIL_TEMPLATE.format(
        application_name=application_name,
        title=heading,
        heading=heading,
        message=message,
        cta_block=cta_block,
        community_link_fragment=community_link_fragment,
        year=datetime.now().year,
    )


def send_email(
    user_email: str,
    subject: str,
    html_body: str,
    text_body: str,
    mail_from: str = EMAIL_FROM,
    inline_png: tuple[str, bytes] | None = None,
) -> None:
    if not EMAIL_CONFIGURED:
        raise ValueError("Email is not configured.")

    if SENDGRID_API_KEY:
        send_email_with_sendgrid(
            user_email, subject, html_body, text_body, mail_from, inline_png
        )
        return

    send_email_with_smtplib(
        user_email, subject, html_body, text_body, mail_from, inline_png
    )


def send_email_with_sendgrid(
    user_email: str,
    subject: str,
    html_body: str,
    text_body: str,
    mail_from: str = EMAIL_FROM,
    inline_png: tuple[str, bytes] | None = None,
) -> None:
    from_email = Email(mail_from) if mail_from else Email("noreply@onyx.app")
    to_email = To(user_email)

    mail = Mail(
        from_email=from_email,
        to_emails=to_email,
        subject=subject,
        plain_text_content=Content("text/plain", text_body),
    )

    # Add HTML content
    mail.add_content(Content("text/html", html_body))

    if inline_png:
        image_name, image_data = inline_png

        # Create attachment
        encoded_image = base64.b64encode(image_data).decode()
        attachment = Attachment()
        attachment.file_content = FileContent(encoded_image)
        attachment.file_name = FileName(image_name)
        attachment.file_type = FileType("image/png")
        attachment.disposition = Disposition("inline")
        attachment.content_id = ContentId(image_name)

        mail.add_attachment(attachment)

    # Get a JSON-ready representation of the Mail object
    mail_json = mail.get()

    sg = sendgrid.SendGridAPIClient(api_key=SENDGRID_API_KEY)
    response = sg.client.mail.send.post(request_body=mail_json)  # can raise
    if response.status_code != 202:
        logger.warning(f"Unexpected status code {response.status_code}")


def send_email_with_smtplib(
    user_email: str,
    subject: str,
    html_body: str,
    text_body: str,
    mail_from: str = EMAIL_FROM,
    inline_png: tuple[str, bytes] | None = None,
) -> None:
    # Create a multipart/alternative message - this indicates these are alternative versions of the same content
    msg = MIMEMultipart("alternative")
    msg["Subject"] = subject
    msg["To"] = user_email
    if mail_from:
        msg["From"] = mail_from
    msg["Date"] = formatdate(localtime=True)
    msg["Message-ID"] = make_msgid(domain="onyx.app")

    # Add text part first (lowest priority)
    text_part = MIMEText(text_body, "plain")
    msg.attach(text_part)

    if inline_png:
        # For HTML with images, create a multipart/related container
        related = MIMEMultipart("related")

        # Add the HTML part to the related container
        html_part = MIMEText(html_body, "html")
        related.attach(html_part)

        # Add image with proper Content-ID to the related container
        img = MIMEImage(inline_png[1], _subtype="png")
        img.add_header("Content-ID", f"<{inline_png[0]}>")
        img.add_header("Content-Disposition", "inline", filename=inline_png[0])
        related.attach(img)

        # Add the related part to the message (higher priority than text)
        msg.attach(related)
    else:
        # No images, just add HTML directly (higher priority than text)
        html_part = MIMEText(html_body, "html")
        msg.attach(html_part)

    with smtplib.SMTP(SMTP_SERVER, SMTP_PORT) as s:
        s.starttls()
        s.login(SMTP_USER, SMTP_PASS)
        s.send_message(msg)


def send_subscription_cancellation_email(user_email: str) -> None:
    """This is templated but isn't meaningful for whitelabeling."""

    # Example usage of the reusable HTML
    try:
        load_runtime_settings_fn = fetch_versioned_implementation(
            "onyx.server.enterprise_settings.store", "load_runtime_settings"
        )
        settings = load_runtime_settings_fn()
        application_name = settings.application_name
    except ModuleNotFoundError:
        application_name = ONYX_DEFAULT_APPLICATION_NAME

    onyx_file = OnyxRuntime.get_emailable_logo()

    subject = f"Your {application_name} Subscription Has Been Canceled"
    heading = "Subscription Canceled"
    message = (
        "

We're sorry to see you go.

" "

Your subscription has been canceled and will end on your next billing date.

" "

If you change your mind, you can always come back!

" ) cta_text = "Renew Subscription" cta_link = "https://www.onyx.app/pricing" html_content = build_html_email( application_name, heading, message, cta_text, cta_link, ) text_content = ( "We're sorry to see you go.\n" "Your subscription has been canceled and will end on your next billing date.\n" "If you change your mind, visit https://www.onyx.app/pricing" ) send_email( user_email, subject, html_content, text_content, inline_png=("logo.png", onyx_file.data), ) def build_user_email_invite( from_email: str, to_email: str, application_name: str, auth_type: AuthType ) -> tuple[str, str]: heading = "You've Been Invited!" # the exact action taken by the user, and thus the message, depends on the auth type message = f"

You have been invited by {from_email} to join an organization on {application_name}.

" if auth_type == AuthType.CLOUD: message += ( "

To join the organization, please click the button below to set a password " "or login with Google and complete your registration.

" ) elif auth_type == AuthType.BASIC: message += "

To join the organization, please click the button below to set a password and complete your registration.

" elif auth_type == AuthType.GOOGLE_OAUTH: message += "

To join the organization, please click the button below to login with Google and complete your registration.

" elif auth_type == AuthType.OIDC or auth_type == AuthType.SAML: message += "

To join the organization, please click the button below to complete your registration.

" else: raise ValueError(f"Invalid auth type: {auth_type}") cta_text = "Join Organization" cta_link = f"{WEB_DOMAIN}/auth/signup?email={to_email}" html_content = build_html_email( application_name, heading, message, cta_text, cta_link, ) # text content is the fallback for clients that don't support HTML # not as critical, so not having special cases for each auth type text_content = ( f"You have been invited by {from_email} to join an organization on {application_name}.\n" "To join the organization, please visit the following link:\n" f"{WEB_DOMAIN}/auth/signup?email={to_email}\n" ) if auth_type == AuthType.CLOUD: text_content += "You'll be asked to set a password or login with Google to complete your registration." return text_content, html_content def send_user_email_invite( user_email: str, current_user: User, auth_type: AuthType ) -> None: try: load_runtime_settings_fn = fetch_versioned_implementation( "onyx.server.enterprise_settings.store", "load_runtime_settings" ) settings = load_runtime_settings_fn() application_name = settings.application_name except ModuleNotFoundError: application_name = ONYX_DEFAULT_APPLICATION_NAME onyx_file = OnyxRuntime.get_emailable_logo() subject = f"Invitation to Join {application_name} Organization" text_content, html_content = build_user_email_invite( current_user.email, user_email, application_name, auth_type ) send_email( user_email, subject, html_content, text_content, inline_png=("logo.png", onyx_file.data), ) def send_forgot_password_email( user_email: str, token: str, tenant_id: str, mail_from: str = EMAIL_FROM, ) -> None: # Builds a forgot password email with or without fancy HTML try: load_runtime_settings_fn = fetch_versioned_implementation( "onyx.server.enterprise_settings.store", "load_runtime_settings" ) settings = load_runtime_settings_fn() application_name = settings.application_name except ModuleNotFoundError: application_name = ONYX_DEFAULT_APPLICATION_NAME onyx_file = OnyxRuntime.get_emailable_logo() subject = 
f"Reset Your {application_name} Password" heading = "Reset Your Password" tenant_param = f"&tenant={tenant_id}" if tenant_id and MULTI_TENANT else "" message = "

Please click the button below to reset your password. This link will expire in 24 hours.

" cta_text = "Reset Password" cta_link = f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}" html_content = build_html_email( application_name, heading, message, cta_text, cta_link, ) text_content = ( f"Please click the following link to reset your password. This link will expire in 24 hours.\n" f"{WEB_DOMAIN}/auth/reset-password?token={token}{tenant_param}" ) send_email( user_email, subject, html_content, text_content, mail_from, inline_png=("logo.png", onyx_file.data), ) def send_user_verification_email( user_email: str, token: str, new_organization: bool = False, mail_from: str = EMAIL_FROM, ) -> None: # Builds a verification email try: load_runtime_settings_fn = fetch_versioned_implementation( "onyx.server.enterprise_settings.store", "load_runtime_settings" ) settings = load_runtime_settings_fn() application_name = settings.application_name except ModuleNotFoundError: application_name = ONYX_DEFAULT_APPLICATION_NAME onyx_file = OnyxRuntime.get_emailable_logo() subject = f"{application_name} Email Verification" link = f"{WEB_DOMAIN}/auth/verify-email?token={token}" if new_organization: link = add_url_params(link, {"first_user": "true"}) message = ( f"

Click the following link to verify your email address:

{link}

" ) html_content = build_html_email( application_name, "Verify Your Email", message, ) text_content = f"Click the following link to verify your email address: {link}" send_email( user_email, subject, html_content, text_content, mail_from, inline_png=("logo.png", onyx_file.data), ) ================================================ FILE: backend/onyx/auth/invited_users.py ================================================ from typing import cast from onyx.configs.constants import KV_PENDING_USERS_KEY from onyx.configs.constants import KV_USER_STORE_KEY from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.utils.special_types import JSON_ro def remove_user_from_invited_users(email: str) -> int: try: store = get_kv_store() user_emails = cast(list, store.load(KV_USER_STORE_KEY)) remaining_users = [user for user in user_emails if user != email] store.store(KV_USER_STORE_KEY, cast(JSON_ro, remaining_users)) return len(remaining_users) except KvKeyNotFoundError: return 0 def get_invited_users() -> list[str]: try: store = get_kv_store() return cast(list, store.load(KV_USER_STORE_KEY)) except KvKeyNotFoundError: return list() def write_invited_users(emails: list[str]) -> int: store = get_kv_store() store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails)) return len(emails) def get_pending_users() -> list[str]: try: store = get_kv_store() return cast(list, store.load(KV_PENDING_USERS_KEY)) except KvKeyNotFoundError: return list() def write_pending_users(emails: list[str]) -> int: store = get_kv_store() store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails)) return len(emails) ================================================ FILE: backend/onyx/auth/jwt.py ================================================ import json from enum import Enum from functools import lru_cache from typing import Any from typing import cast import jwt import requests from cryptography.hazmat.primitives.asymmetric.rsa import RSAPublicKey 
from jwt import decode as jwt_decode
from jwt import InvalidTokenError
from jwt import PyJWTError
from jwt.algorithms import RSAAlgorithm

from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.utils.logger import setup_logger

logger = setup_logger()

_PUBLIC_KEY_FETCH_ATTEMPTS = 2


class PublicKeyFormat(Enum):
    JWKS = "jwks"
    PEM = "pem"


@lru_cache()
def _fetch_public_key_payload() -> tuple[str | dict[str, Any], PublicKeyFormat] | None:
    """Fetch and cache the raw JWT verification material."""
    if JWT_PUBLIC_KEY_URL is None:
        logger.error("JWT_PUBLIC_KEY_URL is not set")
        return None

    try:
        response = requests.get(JWT_PUBLIC_KEY_URL)
        response.raise_for_status()
    except requests.RequestException as exc:
        logger.error(f"Failed to fetch JWT public key: {str(exc)}")
        return None

    content_type = response.headers.get("Content-Type", "").lower()
    raw_body = response.text
    body_lstripped = raw_body.lstrip()

    if "application/json" in content_type or body_lstripped.startswith("{"):
        try:
            data = response.json()
        except ValueError:
            logger.error("JWT public key URL returned invalid JSON")
            return None

        if isinstance(data, dict) and "keys" in data:
            return data, PublicKeyFormat.JWKS

        logger.error(
            "JWT public key URL returned JSON but no JWKS 'keys' field was found"
        )
        return None

    body = raw_body.strip()
    if not body:
        logger.error("JWT public key URL returned an empty response")
        return None

    return body, PublicKeyFormat.PEM


def get_public_key(token: str) -> RSAPublicKey | str | None:
    """Return the concrete public key used to verify the provided JWT token."""
    payload = _fetch_public_key_payload()
    if payload is None:
        logger.error("Failed to retrieve public key payload")
        return None

    key_material, key_format = payload

    if key_format is PublicKeyFormat.JWKS:
        jwks_data = cast(dict[str, Any], key_material)
        return _resolve_public_key_from_jwks(token, jwks_data)

    return cast(str, key_material)


def _resolve_public_key_from_jwks(
    token: str, jwks_payload: dict[str, Any]
) -> RSAPublicKey | None:
    try:
        header = jwt.get_unverified_header(token)
    except PyJWTError as e:
        logger.error(f"Unable to parse JWT header: {str(e)}")
        return None

    keys = jwks_payload.get("keys", []) if isinstance(jwks_payload, dict) else []
    if not keys:
        logger.error("JWKS payload did not contain any keys")
        return None

    kid = header.get("kid")
    thumbprint = header.get("x5t")

    candidates = []
    if kid:
        candidates = [k for k in keys if k.get("kid") == kid]
    if not candidates and thumbprint:
        candidates = [k for k in keys if k.get("x5t") == thumbprint]
    if not candidates and len(keys) == 1:
        candidates = keys

    if not candidates:
        logger.warning(
            "No matching JWK found for token header (kid=%s, x5t=%s)", kid, thumbprint
        )
        return None

    if len(candidates) > 1:
        logger.warning(
            "Multiple JWKs matched token header kid=%s; selecting the first occurrence",
            kid,
        )

    jwk = candidates[0]
    try:
        return cast(RSAPublicKey, RSAAlgorithm.from_jwk(json.dumps(jwk)))
    except ValueError as e:
        logger.error(f"Failed to construct RSA key from JWK: {str(e)}")
        return None


async def verify_jwt_token(token: str) -> dict[str, Any] | None:
    for attempt in range(_PUBLIC_KEY_FETCH_ATTEMPTS):
        public_key = get_public_key(token)
        if public_key is None:
            logger.error("Unable to resolve a public key for JWT verification")
            if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
                _fetch_public_key_payload.cache_clear()
                continue
            return None

        try:
            payload = jwt_decode(
                token,
                public_key,
                algorithms=["RS256"],
                options={"verify_aud": False},
            )
        except InvalidTokenError as e:
            logger.error(f"Invalid JWT token: {str(e)}")
            if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
                _fetch_public_key_payload.cache_clear()
                continue
            return None
        except PyJWTError as e:
            logger.error(f"JWT decoding error: {str(e)}")
            if attempt < _PUBLIC_KEY_FETCH_ATTEMPTS - 1:
                _fetch_public_key_payload.cache_clear()
                continue
            return None

        return payload
    return None

================================================
FILE: backend/onyx/auth/oauth_refresher.py
================================================
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Optional

import httpx
from fastapi_users.manager import BaseUserManager
from sqlalchemy.ext.asyncio import AsyncSession

from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.db.models import OAuthAccount
from onyx.db.models import User
from onyx.utils.logger import setup_logger

logger = setup_logger()

# Standard OAuth refresh token endpoints
REFRESH_ENDPOINTS = {
    "google": "https://oauth2.googleapis.com/token",
}


# NOTE: Keeping this as a utility function for potential future debugging,
# but not using it in production code
async def _test_expire_oauth_token(
    user: User,
    oauth_account: OAuthAccount,
    db_session: AsyncSession,  # noqa: ARG001
    user_manager: BaseUserManager[User, Any],
    expire_in_seconds: int = 10,
) -> bool:
    """
    Utility function for testing - Sets an OAuth token to expire in a short time
    to facilitate testing of the refresh flow.
    Not used in production code.
    """
    try:
        new_expires_at = int(
            (datetime.now(timezone.utc).timestamp() + expire_in_seconds)
        )

        updated_data: Dict[str, Any] = {"expires_at": new_expires_at}

        await user_manager.user_db.update_oauth_account(
            user, cast(Any, oauth_account), updated_data
        )

        return True
    except Exception as e:
        logger.exception(f"Error setting artificial expiration: {str(e)}")
        return False


async def refresh_oauth_token(
    user: User,
    oauth_account: OAuthAccount,
    db_session: AsyncSession,  # noqa: ARG001
    user_manager: BaseUserManager[User, Any],
) -> bool:
    """
    Attempt to refresh an OAuth token that's about to expire or has expired.
    Returns True if successful, False otherwise.
    """
    if not oauth_account.refresh_token:
        logger.warning(
            f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
        )
        return False

    provider = oauth_account.oauth_name
    if provider not in REFRESH_ENDPOINTS:
        logger.warning(f"Refresh endpoint not configured for provider: {provider}")
        return False

    try:
        logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")

        async with httpx.AsyncClient() as client:
            response = await client.post(
                REFRESH_ENDPOINTS[provider],
                data={
                    "client_id": OAUTH_CLIENT_ID,
                    "client_secret": OAUTH_CLIENT_SECRET,
                    "refresh_token": oauth_account.refresh_token,
                    "grant_type": "refresh_token",
                },
                headers={"Content-Type": "application/x-www-form-urlencoded"},
            )

            if response.status_code != 200:
                logger.error(
                    f"Failed to refresh OAuth token: Status {response.status_code}"
                )
                return False

            token_data = response.json()

            new_access_token = token_data.get("access_token")
            new_refresh_token = token_data.get(
                "refresh_token", oauth_account.refresh_token
            )
            expires_in = token_data.get("expires_in")

            # Calculate new expiry time if provided
            new_expires_at: Optional[int] = None
            if expires_in:
                new_expires_at = int(
                    (datetime.now(timezone.utc).timestamp() + expires_in)
                )

            # Update the OAuth account
            updated_data: Dict[str, Any] = {
                "access_token": new_access_token,
                "refresh_token": new_refresh_token,
            }

            if new_expires_at:
                updated_data["expires_at"] = new_expires_at

                # Update oidc_expiry in user model if we're tracking it
                if TRACK_EXTERNAL_IDP_EXPIRY:
                    oidc_expiry = datetime.fromtimestamp(
                        new_expires_at, tz=timezone.utc
                    )
                    await user_manager.user_db.update(
                        user, {"oidc_expiry": oidc_expiry}
                    )

            # Update the OAuth account
            await user_manager.user_db.update_oauth_account(
                user, cast(Any, oauth_account), updated_data
            )

            logger.info(f"Successfully refreshed OAuth token for {user.email}")
            return True

    except Exception as e:
        logger.exception(f"Error refreshing OAuth token: {str(e)}")
        return False


async def check_and_refresh_oauth_tokens(
    user: User,
    db_session: AsyncSession,
    user_manager: BaseUserManager[User, Any],
) -> None:
    """
    Check if any OAuth tokens are expired or about to expire and refresh them.
    """
    if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
        return

    now_timestamp = datetime.now(timezone.utc).timestamp()

    # Buffer time to refresh tokens before they expire (in seconds)
    buffer_seconds = 300  # 5 minutes

    for oauth_account in user.oauth_accounts:
        # Skip accounts without refresh tokens
        if not oauth_account.refresh_token:
            continue

        # If token is about to expire, refresh it
        if (
            oauth_account.expires_at
            and oauth_account.expires_at - now_timestamp < buffer_seconds
        ):
            logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
            success = await refresh_oauth_token(
                user, oauth_account, db_session, user_manager
            )

            if not success:
                logger.warning(
                    "Failed to refresh OAuth token. User may need to re-authenticate."
                )


async def check_oauth_account_has_refresh_token(
    user: User,  # noqa: ARG001
    oauth_account: OAuthAccount,
) -> bool:
    """
    Check if an OAuth account has a refresh token.
    Returns True if a refresh token exists, False otherwise.
    """
    return bool(oauth_account.refresh_token)


async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
    """
    Returns a list of OAuth accounts for a user that are missing refresh tokens.
    These accounts will need re-authentication to get refresh tokens.
    """
    if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
        return []

    accounts_needing_refresh = []
    for oauth_account in user.oauth_accounts:
        has_refresh_token = await check_oauth_account_has_refresh_token(
            user, oauth_account
        )
        if not has_refresh_token:
            accounts_needing_refresh.append(oauth_account)

    return accounts_needing_refresh

================================================
FILE: backend/onyx/auth/oauth_token_manager.py
================================================
import time
from typing import Any
from urllib.parse import urlencode
from uuid import UUID

import requests
from sqlalchemy.orm import Session

from onyx.db.models import OAuthConfig
from onyx.db.models import OAuthUserToken
from onyx.db.oauth_config import get_user_oauth_token
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue


logger = setup_logger()


class OAuthTokenManager:
    """Manages OAuth token retrieval, refresh, and validation"""

    def __init__(self, oauth_config: OAuthConfig, user_id: UUID, db_session: Session):
        self.oauth_config = oauth_config
        self.user_id = user_id
        self.db_session = db_session

    def get_valid_access_token(self) -> str | None:
        """Get valid access token, refreshing if necessary"""
        user_token = get_user_oauth_token(
            self.oauth_config.id, self.user_id, self.db_session
        )

        if not user_token:
            return None

        if not user_token.token_data:
            return None

        token_data = self._unwrap_token_data(user_token.token_data)

        # Check if token is expired
        if OAuthTokenManager.is_token_expired(token_data):
            # Try to refresh if we have a refresh token
            if "refresh_token" in token_data:
                try:
                    return self.refresh_token(user_token)
                except Exception as e:
                    logger.warning(f"Failed to refresh token: {e}")
                    return None
            else:
                return None

        return token_data.get("access_token")

    def refresh_token(self, user_token: OAuthUserToken) -> str:
        """Refresh access token using refresh token"""
        if not user_token.token_data:
            raise ValueError("No token data available for refresh")

        if (
            self.oauth_config.client_id is None
            or self.oauth_config.client_secret is None
        ):
            raise ValueError(
                "OAuth client_id and client_secret are required for token refresh"
            )

        token_data = self._unwrap_token_data(user_token.token_data)

        data: dict[str, str] = {
            "grant_type": "refresh_token",
            "refresh_token": token_data["refresh_token"],
            "client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
            "client_secret": self._unwrap_sensitive_str(
                self.oauth_config.client_secret
            ),
        }
        response = requests.post(
            self.oauth_config.token_url,
            data=data,
            headers={"Accept": "application/json"},
        )
        response.raise_for_status()

        new_token_data = response.json()

        # Calculate expires_at if expires_in is present
        if "expires_in" in new_token_data:
            new_token_data["expires_at"] = (
                int(time.time()) + new_token_data["expires_in"]
            )

        # Preserve refresh_token if not returned (some providers don't return it)
        if "refresh_token" not in new_token_data and "refresh_token" in token_data:
            new_token_data["refresh_token"] = token_data["refresh_token"]

        # Update token in DB
        upsert_user_oauth_token(
            self.oauth_config.id,
            self.user_id,
            new_token_data,
            self.db_session,
        )

        return new_token_data["access_token"]

    @classmethod
    def token_expiration_time(cls, token_data: dict[str, Any]) -> int | None:
        """Get the token expiration time"""
        expires_at = token_data.get("expires_at")
        if not expires_at:
            return None

        return expires_at

    @classmethod
    def is_token_expired(cls, token_data: dict[str, Any]) -> bool:
        """Check if token is expired (with 60 second buffer)"""
        expires_at = cls.token_expiration_time(token_data)
        if not expires_at:
            return False  # No expiration data, assume valid

        # Add 60 second buffer to avoid race conditions
        return int(time.time()) + 60 >= expires_at

    def exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]:
        """Exchange authorization code for access token"""
        if (
            self.oauth_config.client_id is None
            or self.oauth_config.client_secret is None
        ):
            raise ValueError(
                "OAuth client_id and client_secret are required for code exchange"
            )

        data: dict[str, str] = {
            "grant_type": "authorization_code",
            "code": code,
            "client_id": self._unwrap_sensitive_str(self.oauth_config.client_id),
            "client_secret": self._unwrap_sensitive_str(
                self.oauth_config.client_secret
            ),
            "redirect_uri": redirect_uri,
        }
        response = requests.post(
            self.oauth_config.token_url,
            data=data,
            headers={"Accept": "application/json"},
        )
        response.raise_for_status()

        token_data = response.json()

        # Calculate expires_at if expires_in is present
        if "expires_in" in token_data:
            token_data["expires_at"] = int(time.time()) + token_data["expires_in"]

        return token_data

    @staticmethod
    def build_authorization_url(
        oauth_config: OAuthConfig, redirect_uri: str, state: str
    ) -> str:
        """Build OAuth authorization URL"""
        if oauth_config.client_id is None:
            raise ValueError("OAuth client_id is required to build authorization URL")
        params: dict[str, Any] = {
            "client_id": OAuthTokenManager._unwrap_sensitive_str(
                oauth_config.client_id
            ),
            "redirect_uri": redirect_uri,
            "response_type": "code",
            "state": state,
        }

        # Add scopes if configured
        if oauth_config.scopes:
            params["scope"] = " ".join(oauth_config.scopes)

        # Add any additional provider-specific parameters
        if oauth_config.additional_params:
            params.update(oauth_config.additional_params)

        # Check if URL already has query parameters
        separator = "&" if "?" in oauth_config.authorization_url else "?"
        return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"

    @staticmethod
    def _unwrap_sensitive_str(value: SensitiveValue[str] | str) -> str:
        if isinstance(value, SensitiveValue):
            return value.get_value(apply_mask=False)
        return value

    @staticmethod
    def _unwrap_token_data(
        token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
    ) -> dict[str, Any]:
        if isinstance(token_data, SensitiveValue):
            return token_data.get_value(apply_mask=False)
        return token_data

================================================
FILE: backend/onyx/auth/pat.py
================================================
"""Personal Access Token generation and validation."""

import hashlib
import secrets
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from urllib.parse import quote

from fastapi import Request

from onyx.auth.constants import PAT_LENGTH
from onyx.auth.constants import PAT_PREFIX
from onyx.auth.utils import get_hashed_bearer_token_from_request
from shared_configs.configs import MULTI_TENANT


def generate_pat(tenant_id: str | None = None) -> str:
    """Generate cryptographically secure PAT."""
    if MULTI_TENANT and tenant_id:
        encoded_tenant = quote(tenant_id)
        return f"{PAT_PREFIX}{encoded_tenant}.{secrets.token_urlsafe(PAT_LENGTH)}"
    return PAT_PREFIX + secrets.token_urlsafe(PAT_LENGTH)


def hash_pat(token: str) -> str:
    """Hash PAT using SHA256 (no salt needed due to cryptographic randomness)."""
    return hashlib.sha256(token.encode("utf-8")).hexdigest()


def build_displayable_pat(token: str) -> str:
    """Create masked display version: show prefix + first 4 random chars, mask middle, show last 4.

    Example: onyx_pat_abc1****xyz9
    """
    # Show first 12 chars (onyx_pat_ + 4 random chars) and last 4 chars
    return f"{token[:12]}****{token[-4:]}"


def get_hashed_pat_from_request(request: Request) -> str | None:
    """Extract and hash PAT from Authorization header.

    Only accepts "Bearer " format (unlike API keys which support raw format).
    """
    return get_hashed_bearer_token_from_request(
        request,
        valid_prefixes=[PAT_PREFIX],
        hash_fn=hash_pat,
        allow_non_bearer=False,  # PATs require Bearer prefix
    )


def calculate_expiration(days: int | None) -> datetime | None:
    """Calculate expiration at 23:59:59.999999 UTC on the target date. None = no expiration."""
    if days is None:
        return None
    expiry_date = datetime.now(timezone.utc).date() + timedelta(days=days)
    return datetime.combine(expiry_date, datetime.max.time()).replace(
        tzinfo=timezone.utc
    )

================================================
FILE: backend/onyx/auth/permissions.py
================================================
"""
Permission resolution for group-based authorization.

Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time; only directly granted permissions are persisted.
"""

from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any

from fastapi import Depends

from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger

logger = setup_logger()

ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)

# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
    Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
    Permission.MANAGE_AGENTS.value: {
        Permission.ADD_AGENTS.value,
        Permission.READ_AGENTS.value,
    },
    Permission.MANAGE_DOCUMENT_SETS.value: {
        Permission.READ_DOCUMENT_SETS.value,
        Permission.READ_CONNECTORS.value,
    },
    Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
    Permission.MANAGE_CONNECTORS.value: {
        Permission.ADD_CONNECTORS.value,
        Permission.READ_CONNECTORS.value,
    },
    Permission.MANAGE_USER_GROUPS.value: {
        Permission.READ_CONNECTORS.value,
        Permission.READ_DOCUMENT_SETS.value,
        Permission.READ_AGENTS.value,
        Permission.READ_USERS.value,
    },
}


def resolve_effective_permissions(granted: set[str]) -> set[str]:
    """Expand granted permissions with their implied permissions.

    If "admin" is present, returns all 19 permissions.
    """
    if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
        return set(ALL_PERMISSIONS)

    effective = set(granted)
    changed = True
    while changed:
        changed = False
        for perm in list(effective):
            implied = IMPLIED_PERMISSIONS.get(perm)
            if implied and not implied.issubset(effective):
                effective |= implied
                changed = True
    return effective


def get_effective_permissions(user: User) -> set[Permission]:
    """Read granted permissions from the column and expand implied permissions."""
    granted: set[Permission] = set()
    for p in user.effective_permissions:
        try:
            granted.add(Permission(p))
        except ValueError:
            logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
    if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
        return set(Permission)
    expanded = resolve_effective_permissions({p.value for p in granted})
    return {Permission(p) for p in expanded}


def require_permission(
    required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
    """FastAPI dependency factory for permission-based access control.

    Usage:
        @router.get("/endpoint")
        def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
            ...
    """

    async def dependency(user: User = Depends(current_user)) -> User:
        effective = get_effective_permissions(user)

        if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
            return user

        if required not in effective:
            raise OnyxError(
                OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
                "You do not have the required permissions for this action.",
            )
        return user

    return dependency

================================================
FILE: backend/onyx/auth/schemas.py
================================================
import uuid
from enum import Enum
from typing import Any

from fastapi_users import schemas
from typing_extensions import override

from onyx.db.enums import AccountType


class UserRole(str, Enum):
    """
    User roles
    - Basic can't perform any admin actions
    - Admin can perform all admin actions
    - Curator can perform admin actions for
        groups they are curators of
    - Global Curator can perform admin actions
        for all groups they are a member of
    - Limited can access a limited set of basic api endpoints
    - Slack are users that have used onyx via slack but don't have a web login
    - External permissioned users that have been picked up during the external permissions sync process but don't have a web login
    """

    LIMITED = "limited"
    BASIC = "basic"
    ADMIN = "admin"
    CURATOR = "curator"
    GLOBAL_CURATOR = "global_curator"
    SLACK_USER = "slack_user"
    EXT_PERM_USER = "ext_perm_user"

    def is_web_login(self) -> bool:
        return self not in [
            UserRole.SLACK_USER,
            UserRole.EXT_PERM_USER,
        ]


class UserRead(schemas.BaseUser[uuid.UUID]):
    role: UserRole


class UserCreate(schemas.BaseUserCreate):
    role: UserRole = UserRole.BASIC
    account_type: AccountType = AccountType.STANDARD
    tenant_id: str | None = None
    # Captcha token for cloud signup protection (optional, only used when captcha is enabled)
    # Excluded from create_update_dict so it never reaches the DB layer
    captcha_token: str | None = None

    @override
    def create_update_dict(self) -> dict[str, Any]:
        d = super().create_update_dict()
        d.pop("captcha_token", None)
        # Force STANDARD for self-registration; only trusted paths
        # (SCIM, API key creation) supply a different account_type directly.
        d["account_type"] = AccountType.STANDARD
        return d

    @override
    def create_update_dict_superuser(self) -> dict[str, Any]:
        d = super().create_update_dict_superuser()
        d.pop("captcha_token", None)
        d.setdefault("account_type", self.account_type)
        return d


class UserUpdate(schemas.BaseUserUpdate):
    """
    Role updates are not allowed through the user update endpoint for security reasons.
    Role changes should be handled through a separate, admin-only process.
    """


class AuthBackend(str, Enum):
    REDIS = "redis"
    POSTGRES = "postgres"
    JWT = "jwt"

================================================
FILE: backend/onyx/auth/users.py
================================================
import base64
import hashlib
import json
import os
import random
import secrets
import string
import uuid
from collections.abc import AsyncGenerator
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import Any
from typing import cast
from typing import Dict
from typing import List
from typing import Literal
from typing import Optional
from typing import Protocol
from typing import Tuple
from typing import TypeVar
from urllib.parse import urlparse

import jwt
from email_validator import EmailNotValidError
from email_validator import EmailUndeliverableError
from email_validator import validate_email
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from fastapi import Response
from fastapi import status
from fastapi import WebSocket
from fastapi.responses import JSONResponse
from fastapi.responses import RedirectResponse
from fastapi.security import OAuth2PasswordRequestForm
from fastapi_users import BaseUserManager
from fastapi_users import exceptions
from fastapi_users import FastAPIUsers
from fastapi_users import models
from fastapi_users import schemas
from fastapi_users import UUIDIDMixin
from fastapi_users.authentication import AuthenticationBackend
from fastapi_users.authentication import CookieTransport
from fastapi_users.authentication import JWTStrategy
from fastapi_users.authentication import RedisStrategy
from fastapi_users.authentication import Strategy
from fastapi_users.authentication.strategy.db import AccessTokenDatabase
from fastapi_users.authentication.strategy.db import DatabaseStrategy
from fastapi_users.exceptions import UserAlreadyExists
from fastapi_users.jwt import decode_jwt
from fastapi_users.jwt import generate_jwt
from fastapi_users.jwt import SecretType
from fastapi_users.manager import UserManagerDependency
from fastapi_users.openapi import OpenAPIResponseType
from fastapi_users.router.common import ErrorCode
from fastapi_users.router.common import ErrorModel
from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase
from httpx_oauth.integrations.fastapi import OAuth2AuthorizeCallback
from httpx_oauth.oauth2 import BaseOAuth2
from httpx_oauth.oauth2 import GetAccessTokenError
from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy import nulls_last
from sqlalchemy import select
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session

from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.disposable_email_validator import is_disposable_email
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.jwt import verify_jwt_token
from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import EMAIL_CONFIGURED
from onyx.configs.app_configs import JWT_PUBLIC_KEY_URL
from onyx.configs.app_configs import PASSWORD_MAX_LENGTH
from onyx.configs.app_configs import PASSWORD_MIN_LENGTH
from onyx.configs.app_configs import PASSWORD_REQUIRE_DIGIT
from onyx.configs.app_configs import PASSWORD_REQUIRE_LOWERCASE
from onyx.configs.app_configs import PASSWORD_REQUIRE_SPECIAL_CHAR
from onyx.configs.app_configs import PASSWORD_REQUIRE_UPPERCASE
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import ANONYMOUS_USER_COOKIE_NAME
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import ANONYMOUS_USER_UUID
from onyx.configs.constants import AuthType
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import PASSWORD_SPECIAL_CHARS
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.api_key import fetch_user_for_api_key
from onyx.db.auth import get_access_token_db
from onyx.db.auth import get_default_admin_user_emails
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.auth import SQLAlchemyUserAdminDB
from onyx.db.engine.async_sql_engine import get_async_session
from 
onyx.db.engine.async_sql_engine import get_async_session_context_manager from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.enums import AccountType from onyx.db.models import AccessToken from onyx.db.models import OAuthAccount from onyx.db.models import Persona from onyx.db.models import User from onyx.db.pat import fetch_user_for_pat from onyx.db.users import assign_user_to_default_groups__no_commit from onyx.db.users import get_user_by_email from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import log_onyx_error from onyx.error_handling.exceptions import onyx_error_to_json_response from onyx.error_handling.exceptions import OnyxError from onyx.redis.redis_pool import get_async_redis_connection from onyx.redis.redis_pool import retrieve_ws_token_data from onyx.server.settings.store import load_settings from onyx.server.utils import BasicAuthenticationError from onyx.utils.logger import setup_logger from onyx.utils.telemetry import mt_cloud_alias from onyx.utils.telemetry import mt_cloud_get_anon_id from onyx.utils.telemetry import mt_cloud_identify from onyx.utils.telemetry import mt_cloud_telemetry from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType from onyx.utils.timing import log_function_time from onyx.utils.url import add_url_params from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from shared_configs.configs import async_return_default_schema from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY" def is_user_admin(user: User) -> bool: return user.role == UserRole.ADMIN def verify_auth_setting() -> 
None: """Log warnings for AUTH_TYPE issues. This only runs on app startup not during migrations/scripts. """ raw_auth_type = (os.environ.get("AUTH_TYPE") or "").lower() if raw_auth_type == "cloud": raise ValueError( "'cloud' is not a valid auth type for self-hosted deployments." ) if raw_auth_type == "disabled": logger.warning( "AUTH_TYPE='disabled' is no longer supported. Using 'basic' instead. Please update your configuration." ) logger.notice(f"Using Auth Type: {AUTH_TYPE.value}") def get_display_email(email: str | None, space_less: bool = False) -> str: if email and email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN): name = email.split("@")[0] if name == DANSWER_API_KEY_PREFIX + UNNAMED_KEY_PLACEHOLDER: return "Unnamed API Key" if space_less: return name return name.replace("API_KEY__", "API Key: ") return email or "" def generate_password() -> str: lowercase_letters = string.ascii_lowercase uppercase_letters = string.ascii_uppercase digits = string.digits special_characters = string.punctuation # Ensure at least one of each required character type password = [ secrets.choice(uppercase_letters), secrets.choice(digits), secrets.choice(special_characters), ] # Fill the rest with a mix of characters remaining_length = 12 - len(password) all_characters = lowercase_letters + uppercase_letters + digits + special_characters password.extend(secrets.choice(all_characters) for _ in range(remaining_length)) # Shuffle the password to randomize the position of the required characters random.shuffle(password) return "".join(password) def user_needs_to_be_verified() -> bool: if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD: return REQUIRE_EMAIL_VERIFICATION # For other auth types, if the user is authenticated it's assumed that # the user is already verified via the external IDP return False def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool: from onyx.cache.factory import get_cache_backend cache = get_cache_backend(tenant_id=tenant_id) value = 
cache.get(OnyxRedisLocks.ANONYMOUS_USER_ENABLED) if value is None: return False return int(value.decode("utf-8")) == 1 def workspace_invite_only_enabled() -> bool: settings = load_settings() return settings.invite_only_enabled def verify_email_is_invited(email: str) -> None: if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}: # SSO providers manage membership; allow JIT provisioning regardless of invites return if not workspace_invite_only_enabled(): return whitelist = get_invited_users() if not email: raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email must be specified") try: email_info = validate_email(email, check_deliverability=False) except EmailUndeliverableError: raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid") for email_whitelist in whitelist: try: # normalized emails are now being inserted into the db # we can remove this normalization on read after some time has passed email_info_whitelist = validate_email( email_whitelist, check_deliverability=False ) except EmailNotValidError: continue # oddly, normalization does not include lowercasing the user part of the # email address ... which we want to allow if email_info.normalized.lower() == email_info_whitelist.normalized.lower(): return raise OnyxError( OnyxErrorCode.UNAUTHORIZED, "This workspace is invite-only. 
Please ask your admin to invite you.", ) def verify_email_in_whitelist(email: str, tenant_id: str) -> None: with get_session_with_tenant(tenant_id=tenant_id) as db_session: if not get_user_by_email(email, db_session): verify_email_is_invited(email) def verify_email_domain(email: str, *, is_registration: bool = False) -> None: if email.count("@") != 1: raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email is not valid") local_part, domain = email.split("@") domain = domain.lower() local_part = local_part.lower() if AUTH_TYPE == AuthType.CLOUD: # Normalize googlemail.com to gmail.com (they deliver to the same inbox) if domain == "googlemail.com": raise OnyxError( OnyxErrorCode.INVALID_INPUT, "Please use @gmail.com instead of @googlemail.com.", ) # Only block dotted Gmail on new signups — existing users must still be # able to sign in with the address they originally registered with. if is_registration and domain == "gmail.com" and "." in local_part: raise OnyxError( OnyxErrorCode.INVALID_INPUT, "Gmail addresses with '.' are not allowed. Please use your base email address.", ) if "+" in local_part and domain != "onyx.app": raise OnyxError( OnyxErrorCode.INVALID_INPUT, "Email addresses with '+' are not allowed. Please use your base email address.", ) # Check if email uses a disposable/temporary domain if is_disposable_email(email): raise OnyxError( OnyxErrorCode.INVALID_INPUT, "Disposable email addresses are not allowed. Please use a permanent email address.", ) # Check domain whitelist if configured if VALID_EMAIL_DOMAINS: if domain not in VALID_EMAIL_DOMAINS: raise OnyxError(OnyxErrorCode.INVALID_INPUT, "Email domain is not valid") def enforce_seat_limit(db_session: Session, seats_needed: int = 1) -> None: """Raise HTTPException(402) if adding users would exceed the seat limit. No-op for multi-tenant or CE deployments. 
""" if MULTI_TENANT: return result = fetch_ee_implementation_or_noop( "onyx.db.license", "check_seat_availability", None )(db_session, seats_needed=seats_needed) if result is not None and not result.available: raise OnyxError(OnyxErrorCode.SEAT_LIMIT_EXCEEDED, result.error_message) class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]): reset_password_token_secret = USER_AUTH_SECRET verification_token_secret = USER_AUTH_SECRET verification_token_lifetime_seconds = AUTH_COOKIE_EXPIRE_TIME_SECONDS user_db: SQLAlchemyUserDatabase[User, uuid.UUID] async def get_by_email(self, user_email: str) -> User: tenant_id = fetch_ee_implementation_or_noop( "onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None )(user_email) async with get_async_session_context_manager(tenant_id) as db_session: if MULTI_TENANT: tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID]( db_session, User, OAuthAccount ) user = await tenant_user_db.get_by_email(user_email) else: user = await self.user_db.get_by_email(user_email) if not user: raise exceptions.UserNotExists() return user async def create( self, user_create: schemas.UC | UserCreate, safe: bool = False, request: Optional[Request] = None, ) -> User: # Verify captcha if enabled (for cloud signup protection) from onyx.auth.captcha import CaptchaVerificationError from onyx.auth.captcha import is_captcha_enabled from onyx.auth.captcha import verify_captcha_token if is_captcha_enabled() and request is not None: # Get captcha token from request body or headers captcha_token = None if hasattr(user_create, "captcha_token"): captcha_token = getattr(user_create, "captcha_token", None) # Also check headers as a fallback if not captcha_token: captcha_token = request.headers.get("X-Captcha-Token") try: await verify_captcha_token( captcha_token or "", expected_action="signup" ) except CaptchaVerificationError as e: raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e)) # We verify the password here to make sure it's valid before we 
proceed await self.validate_password( user_create.password, cast(schemas.UC, user_create) ) # Check for disposable emails BEFORE provisioning tenant # This prevents creating tenants for throwaway email addresses try: verify_email_domain(user_create.email, is_registration=True) except OnyxError as e: # Log blocked disposable email attempts if "Disposable email" in e.detail: domain = ( user_create.email.split("@")[-1] if "@" in user_create.email else "unknown" ) logger.warning( f"Blocked disposable email registration attempt: {domain}", extra={"email_domain": domain}, ) raise user_count: int | None = None referral_source = ( request.cookies.get("referral_source", None) if request is not None else None ) tenant_id = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, )( email=user_create.email, referral_source=referral_source, request=request, ) user: User token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: async with get_async_session_context_manager(tenant_id) as db_session: # Check invite list based on deployment mode if MULTI_TENANT: # Multi-tenant: Only require invite for existing tenants # New tenant creation (first user) doesn't require an invite user_count = await get_user_count() if user_count > 0: # Tenant already has users - require invite for new users verify_email_is_invited(user_create.email) else: # Single-tenant: Check invite list (skips if SAML/OIDC or no list configured) verify_email_is_invited(user_create.email) if MULTI_TENANT: tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID]( db_session, User, OAuthAccount ) self.user_db = tenant_user_db if hasattr(user_create, "role"): user_create.role = UserRole.BASIC user_count = await get_user_count() if ( user_count == 0 or user_create.email in get_default_admin_user_emails() ): user_create.role = UserRole.ADMIN # Check seat availability for new users (single-tenant only) with get_session_with_current_tenant() as 
sync_db: existing = get_user_by_email(user_create.email, sync_db) if existing is None: enforce_seat_limit(sync_db) user_created = False try: user = await super().create(user_create, safe=safe, request=request) user_created = True except IntegrityError as error: # Race condition: another request created the same user after the # pre-insert existence check but before our commit. await self.user_db.session.rollback() logger.warning( "IntegrityError while creating user %s, assuming duplicate: %s", user_create.email, str(error), ) try: user = await self.get_by_email(user_create.email) except exceptions.UserNotExists: # Unexpected integrity error, surface it for handling upstream. raise error if MULTI_TENANT: user_by_session = await db_session.get(User, user.id) if user_by_session: user = user_by_session if ( user.account_type.is_web_login() or not isinstance(user_create, UserCreate) or not user_create.account_type.is_web_login() ): raise exceptions.UserAlreadyExists() # Cache id before expire — accessing attrs on an expired # object triggers a sync lazy-load which raises MissingGreenlet # in this async context. user_id = user.id self._upgrade_user_to_standard__sync(user_id, user_create) # Expire so the async session re-fetches the row updated by # the sync session above. self.user_db.session.expire(user) user = await self.user_db.get(user_id) # type: ignore[assignment] except exceptions.UserAlreadyExists: user = await self.get_by_email(user_create.email) # we must use the existing user in the session if it matches # the user we just got by email. 
Note that this only applies # to multi-tenant, due to the overwriting of the user_db if MULTI_TENANT: user_by_session = await db_session.get(User, user.id) if user_by_session: user = user_by_session # Handle case where user has used product outside of web and is now creating an account through web if ( user.account_type.is_web_login() or not isinstance(user_create, UserCreate) or not user_create.account_type.is_web_login() ): raise exceptions.UserAlreadyExists() # Cache id before expire — accessing attrs on an expired # object triggers a sync lazy-load which raises MissingGreenlet # in this async context. user_id = user.id self._upgrade_user_to_standard__sync(user_id, user_create) # Expire so the async session re-fetches the row updated by # the sync session above. self.user_db.session.expire(user) user = await self.user_db.get(user_id) # type: ignore[assignment] if user_created: await self._assign_default_pinned_assistants(user, db_session) remove_user_from_invited_users(user_create.email) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) return user async def _assign_default_pinned_assistants( self, user: User, db_session: AsyncSession ) -> None: if user.pinned_assistants is not None: return result = await db_session.execute( select(Persona.id) .where( Persona.is_featured.is_(True), Persona.is_public.is_(True), Persona.is_listed.is_(True), Persona.deleted.is_(False), ) .order_by( nulls_last(Persona.display_priority.asc()), Persona.id.asc(), ) ) default_persona_ids = list(result.scalars().all()) if not default_persona_ids: return await self.user_db.update( user, {"pinned_assistants": default_persona_ids}, ) user.pinned_assistants = default_persona_ids def _upgrade_user_to_standard__sync( self, user_id: uuid.UUID, user_create: UserCreate, ) -> None: """Upgrade a non-web user to STANDARD and assign default groups atomically. All writes happen in a single sync transaction so neither the field update nor the group assignment is visible without the other. 
""" with get_session_with_current_tenant() as sync_db: sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type] if sync_user: sync_user.hashed_password = self.password_helper.hash( user_create.password ) sync_user.is_verified = user_create.is_verified or False sync_user.role = user_create.role sync_user.account_type = AccountType.STANDARD assign_user_to_default_groups__no_commit( sync_db, sync_user, is_admin=(user_create.role == UserRole.ADMIN), ) sync_db.commit() else: logger.warning( "User %s not found in sync session during upgrade to standard; " "skipping upgrade", user_id, ) async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None: # Validate password according to configurable security policy (defined via environment variables) if len(password) < PASSWORD_MIN_LENGTH: raise exceptions.InvalidPasswordException( reason=f"Password must be at least {PASSWORD_MIN_LENGTH} characters long." ) if len(password) > PASSWORD_MAX_LENGTH: raise exceptions.InvalidPasswordException( reason=f"Password must not exceed {PASSWORD_MAX_LENGTH} characters." ) if PASSWORD_REQUIRE_UPPERCASE and not any(char.isupper() for char in password): raise exceptions.InvalidPasswordException( reason="Password must contain at least one uppercase letter." ) if PASSWORD_REQUIRE_LOWERCASE and not any(char.islower() for char in password): raise exceptions.InvalidPasswordException( reason="Password must contain at least one lowercase letter." ) if PASSWORD_REQUIRE_DIGIT and not any(char.isdigit() for char in password): raise exceptions.InvalidPasswordException( reason="Password must contain at least one number." ) if PASSWORD_REQUIRE_SPECIAL_CHAR and not any( char in PASSWORD_SPECIAL_CHARS for char in password ): raise exceptions.InvalidPasswordException( reason=f"Password must contain at least one special character from the following set: {PASSWORD_SPECIAL_CHARS}." 
) return @log_function_time(print_only=True) async def oauth_callback( self, oauth_name: str, access_token: str, account_id: str, account_email: str, expires_at: Optional[int] = None, refresh_token: Optional[str] = None, request: Optional[Request] = None, *, associate_by_email: bool = False, is_verified_by_default: bool = False, ) -> User: referral_source = ( getattr(request.state, "referral_source", None) if request else None ) tenant_id = await fetch_ee_implementation_or_noop( "onyx.server.tenants.provisioning", "get_or_provision_tenant", async_return_default_schema, )( email=account_email, referral_source=referral_source, request=request, ) if not tenant_id: raise HTTPException(status_code=401, detail="User not found") # Proceed with the tenant context token = None async with get_async_session_context_manager(tenant_id) as db_session: token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) verify_email_in_whitelist(account_email, tenant_id) verify_email_domain(account_email) # NOTE(rkuo): If this UserManager is instantiated per connection # should we even be doing this here? 
            if MULTI_TENANT:
                tenant_user_db = SQLAlchemyUserAdminDB[User, uuid.UUID](
                    db_session, User, OAuthAccount
                )
                self.user_db = tenant_user_db

            oauth_account_dict = {
                "oauth_name": oauth_name,
                "access_token": access_token,
                "account_id": account_id,
                "account_email": account_email,
                "expires_at": expires_at,
                "refresh_token": refresh_token,
            }

            user: User | None = None

            try:
                # Attempt to get user by OAuth account
                user = await self.get_by_oauth_account(oauth_name, account_id)

            except exceptions.UserNotExists:
                try:
                    # Attempt to get user by email
                    user = await self.user_db.get_by_email(account_email)
                    if not associate_by_email:
                        raise exceptions.UserAlreadyExists()

                    # Make sure user is not None before adding OAuth account
                    if user is not None:
                        user = await self.user_db.add_oauth_account(
                            user, oauth_account_dict
                        )
                    else:
                        # This shouldn't happen since get_by_email would raise UserNotExists
                        # but adding as a safeguard
                        raise exceptions.UserNotExists()

                except exceptions.UserNotExists:
                    verify_email_domain(account_email, is_registration=True)

                    # Check seat availability before creating (single-tenant only)
                    with get_session_with_current_tenant() as sync_db:
                        enforce_seat_limit(sync_db)

                    password = self.password_helper.generate()
                    user_dict = {
                        "email": account_email,
                        "hashed_password": self.password_helper.hash(password),
                        "is_verified": is_verified_by_default,
                        "account_type": AccountType.STANDARD,
                    }

                    user = await self.user_db.create(user_dict)
                    await self.user_db.add_oauth_account(user, oauth_account_dict)
                    await self._assign_default_pinned_assistants(user, db_session)
                    await self.on_after_register(user, request)

            else:
                # User exists, update OAuth account if needed
                if user is not None:  # Add explicit check
                    for existing_oauth_account in user.oauth_accounts:
                        if (
                            existing_oauth_account.account_id == account_id
                            and existing_oauth_account.oauth_name == oauth_name
                        ):
                            user = await self.user_db.update_oauth_account(
                                user,
                                # NOTE: OAuthAccount DOES implement the OAuthAccountProtocol
                                # but the type checker doesn't know that :(
                                existing_oauth_account,  # type: ignore
                                oauth_account_dict,
                            )

            # NOTE: Most IdPs have very short expiry times, and we don't want to force the user to
            # re-authenticate that frequently, so by default this is disabled
            if expires_at and TRACK_EXTERNAL_IDP_EXPIRY:
                oidc_expiry = datetime.fromtimestamp(expires_at, tz=timezone.utc)
                await self.user_db.update(
                    user, update_dict={"oidc_expiry": oidc_expiry}
                )

            # Handle case where user has used product outside of web and is now creating an account through web
            if not user.account_type.is_web_login():
                # We must use the existing user in the session if it matches
                # the user we just got by email/oauth. Note that this only applies
                # to multi-tenant, due to the overwriting of the user_db
                if MULTI_TENANT:
                    if user.id:
                        user_by_session = await db_session.get(User, user.id)
                        if user_by_session:
                            user = user_by_session

                # If the user is inactive, check seat availability before
                # upgrading role; otherwise they'd become an inactive BASIC
                # user who still can't log in.
                if not user.is_active:
                    with get_session_with_current_tenant() as sync_db:
                        enforce_seat_limit(sync_db)

                # Upgrade the user and assign default groups in a single
                # transaction so neither change is visible without the other.
                was_inactive = not user.is_active
                with get_session_with_current_tenant() as sync_db:
                    sync_user = sync_db.query(User).filter(User.id == user.id).first()  # type: ignore[arg-type]
                    if sync_user:
                        sync_user.is_verified = is_verified_by_default
                        sync_user.role = UserRole.BASIC
                        sync_user.account_type = AccountType.STANDARD
                        if was_inactive:
                            sync_user.is_active = True
                        assign_user_to_default_groups__no_commit(sync_db, sync_user)
                        sync_db.commit()

                # Refresh the async user object so downstream code
                # (e.g. oidc_expiry check) sees the updated fields.
                self.user_db.session.expire(user)
                user = await self.user_db.get(user.id)
                assert user is not None

            # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
            # otherwise, the oidc expiry will always be old, and the user will never be able to login
            if user.oidc_expiry is not None and not TRACK_EXTERNAL_IDP_EXPIRY:
                await self.user_db.update(user, {"oidc_expiry": None})
                user.oidc_expiry = None  # type: ignore
            remove_user_from_invited_users(user.email)
            if token:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

            return user

    async def on_after_login(
        self,
        user: User,
        request: Optional[Request] = None,
        response: Optional[Response] = None,
    ) -> None:
        try:
            if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
                response.delete_cookie(
                    ANONYMOUS_USER_COOKIE_NAME,
                    # Ensure cookie deletion doesn't override other cookies by setting the same path/domain
                    path="/",
                    domain=None,
                    secure=WEB_DOMAIN.startswith("https"),
                )
                logger.debug(f"Deleted anonymous user cookie for user {user.email}")
        except Exception:
            logger.exception("Error deleting anonymous user cookie")

        tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()

        # Link the anonymous PostHog session to the identified user so that
        # pre-login session recordings and events merge into one person profile.
        if anon_id := mt_cloud_get_anon_id(request):
            mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)

        mt_cloud_identify(
            distinct_id=str(user.id),
            properties={"email": user.email, "tenant_id": tenant_id},
        )

    async def on_after_register(
        self, user: User, request: Optional[Request] = None
    ) -> None:
        tenant_id = await fetch_ee_implementation_or_noop(
            "onyx.server.tenants.provisioning",
            "get_or_provision_tenant",
            async_return_default_schema,
        )(
            email=user.email,
            request=request,
        )

        user_count = None
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
        try:
            user_count = await get_user_count()
            logger.debug(f"Current tenant user count: {user_count}")

            # Link the anonymous PostHog session to the identified user so
            # that pre-signup session recordings merge into one person profile.
            if anon_id := mt_cloud_get_anon_id(request):
                mt_cloud_alias(distinct_id=str(user.id), anonymous_id=anon_id)

            # Ensure a PostHog person profile exists for this user.
            mt_cloud_identify(
                distinct_id=str(user.id),
                properties={"email": user.email, "tenant_id": tenant_id},
            )

            mt_cloud_telemetry(
                tenant_id=tenant_id,
                distinct_id=str(user.id),
                event=MilestoneRecordType.USER_SIGNED_UP,
            )

            if user_count == 1:
                mt_cloud_telemetry(
                    tenant_id=tenant_id,
                    distinct_id=str(user.id),
                    event=MilestoneRecordType.TENANT_CREATED,
                )

            # Assign user to the appropriate default group (Admin or Basic).
            # Must happen inside the try block while tenant context is active,
            # otherwise get_session_with_current_tenant() targets the wrong schema.
            is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
            with get_session_with_current_tenant() as db_session:
                assign_user_to_default_groups__no_commit(
                    db_session, user, is_admin=is_admin
                )
                db_session.commit()
        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

        # Fetch EE PostHog functions if available
        get_marketing_posthog_cookie_name = fetch_ee_implementation_or_noop(
            module="onyx.utils.posthog_client",
            attribute="get_marketing_posthog_cookie_name",
            noop_return_value=None,
        )
        parse_posthog_cookie = fetch_ee_implementation_or_noop(
            module="onyx.utils.posthog_client",
            attribute="parse_posthog_cookie",
            noop_return_value=None,
        )
        capture_and_sync_with_alternate_posthog = fetch_ee_implementation_or_noop(
            module="onyx.utils.posthog_client",
            attribute="capture_and_sync_with_alternate_posthog",
            noop_return_value=None,
        )

        if (
            request
            and user_count is not None
            and (marketing_cookie_name := get_marketing_posthog_cookie_name())
            and (marketing_cookie_value := request.cookies.get(marketing_cookie_name))
            and (parsed_cookie := parse_posthog_cookie(marketing_cookie_value))
        ):
            marketing_anonymous_id = parsed_cookie["distinct_id"]

            # Technically, USER_SIGNED_UP is only fired from the cloud site when
            # it is the first user in a tenant. However, it is semantically correct
            # for the marketing site and should probably be refactored for the cloud site
            # to also be semantically correct.
            properties = {
                "email": user.email,
                "onyx_cloud_user_id": str(user.id),
                "tenant_id": str(tenant_id) if tenant_id else None,
                "role": user.role.value,
                "is_first_user": user_count == 1,
                "source": "marketing_site_signup",
                "conversion_timestamp": datetime.now(timezone.utc).isoformat(),
            }

            # Add all other values from the marketing cookie (featureFlags, etc.)
            for key, value in parsed_cookie.items():
                if key != "distinct_id":
                    properties.setdefault(key, value)

            capture_and_sync_with_alternate_posthog(
                alternate_distinct_id=marketing_anonymous_id,
                event=MilestoneRecordType.USER_SIGNED_UP,
                properties=properties,
            )

        logger.debug(f"User {user.id} has registered.")
        optional_telemetry(
            record_type=RecordType.SIGN_UP,
            data={"action": "create"},
            user_id=str(user.id),
        )

    async def on_after_forgot_password(
        self,
        user: User,
        token: str,
        request: Optional[Request] = None,  # noqa: ARG002
    ) -> None:
        if not EMAIL_CONFIGURED:
            logger.error(
                "Email is not configured. Please configure email in the admin panel"
            )
            raise HTTPException(
                status.HTTP_500_INTERNAL_SERVER_ERROR,
                "Your admin has not enabled this feature.",
            )
        tenant_id = await fetch_ee_implementation_or_noop(
            "onyx.server.tenants.provisioning",
            "get_or_provision_tenant",
            async_return_default_schema,
        )(email=user.email)
        send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)

    async def on_after_request_verify(
        self,
        user: User,
        token: str,
        request: Optional[Request] = None,  # noqa: ARG002
    ) -> None:
        verify_email_domain(user.email)

        logger.notice(
            f"Verification requested for user {user.id}. Verification token: {token}"
        )
        user_count = await get_user_count()
        send_user_verification_email(
            user.email, token, new_organization=user_count == 1
        )

    @log_function_time(print_only=True)
    async def authenticate(
        self, credentials: OAuth2PasswordRequestForm
    ) -> Optional[User]:
        email = credentials.username

        tenant_id: str | None = None
        try:
            tenant_id = fetch_ee_implementation_or_noop(
                "onyx.server.tenants.provisioning",
                "get_tenant_id_for_email",
                POSTGRES_DEFAULT_SCHEMA,
            )(
                email=email,
            )
        except Exception as e:
            logger.warning(
                f"User attempted to login with invalid credentials: {str(e)}"
            )

        if not tenant_id:
            # User not found in mapping
            self.password_helper.hash(credentials.password)
            return None

        # Create a tenant-specific session
        async with get_async_session_context_manager(tenant_id) as tenant_session:
            tenant_user_db: SQLAlchemyUserDatabase = SQLAlchemyUserDatabase(
                tenant_session, User
            )
            self.user_db = tenant_user_db

            # Proceed with authentication
            try:
                user = await self.get_by_email(email)

            except exceptions.UserNotExists:
                self.password_helper.hash(credentials.password)
                return None

            if not user.account_type.is_web_login():
                raise BasicAuthenticationError(
                    detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
                )

            verified, updated_password_hash = self.password_helper.verify_and_update(
                credentials.password, user.hashed_password
            )
            if not verified:
                return None

            if updated_password_hash is not None:
                await self.user_db.update(
                    user, {"hashed_password": updated_password_hash}
                )

            return user

    async def reset_password_as_admin(self, user_id: uuid.UUID) -> str:
        """Admin-only. Generate a random password for a user and return it."""
        user = await self.get(user_id)
        new_password = generate_password()
        await self._update(user, {"password": new_password})
        return new_password

    async def change_password_if_old_matches(
        self, user: User, old_password: str, new_password: str
    ) -> None:
        """
        Let a user change their own password if `old_password` matches the current one.
        Raises an HTTP 400 error if the old password doesn't match.
        """
        verified, updated_password_hash = self.password_helper.verify_and_update(
            old_password, user.hashed_password
        )
        if not verified:
            # Reject the change when the current password is wrong
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Invalid current password",
            )

        # If the hash was upgraded during verification, keep it before setting the new password
        if updated_password_hash:
            user.hashed_password = updated_password_hash

        # Now apply and validate the new password
        await self._update(user, {"password": new_password})


async def get_user_manager(
    user_db: SQLAlchemyUserDatabase = Depends(get_user_db),
) -> AsyncGenerator[UserManager, None]:
    yield UserManager(user_db)


cookie_transport = CookieTransport(
    cookie_max_age=SESSION_EXPIRE_TIME_SECONDS,
    cookie_secure=WEB_DOMAIN.startswith("https"),
    cookie_name=FASTAPI_USERS_AUTH_COOKIE_NAME,
)


T = TypeVar("T", covariant=True)
ID = TypeVar("ID", contravariant=True)


# Protocol for strategies that support token refreshing without inheritance.
class RefreshableStrategy(Protocol):
    """Protocol for authentication strategies that support token refreshing."""

    async def refresh_token(self, token: Optional[str], user: Any) -> str:
        """
        Refresh an existing token by extending its lifetime.
        Returns either the same token with extended expiration or a new token.
        """
        ...


class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
    """
    A custom strategy that fetches the actual async Redis connection inside each method.
    We do NOT pass a synchronous or "coroutine" redis object to the constructor.
    """

    def __init__(
        self,
        lifetime_seconds: Optional[int] = SESSION_EXPIRE_TIME_SECONDS,
        key_prefix: str = REDIS_AUTH_KEY_PREFIX,
    ):
        self.lifetime_seconds = lifetime_seconds
        self.key_prefix = key_prefix

    async def write_token(self, user: User) -> str:
        redis = await get_async_redis_connection()

        tenant_id = await fetch_ee_implementation_or_noop(
            "onyx.server.tenants.provisioning",
            "get_or_provision_tenant",
            async_return_default_schema,
        )(email=user.email)

        token_data = {
            "sub": str(user.id),
            "tenant_id": tenant_id,
        }
        token = secrets.token_urlsafe()
        await redis.set(
            f"{self.key_prefix}{token}",
            json.dumps(token_data),
            ex=self.lifetime_seconds,
        )
        return token

    async def read_token(
        self, token: Optional[str], user_manager: BaseUserManager[User, uuid.UUID]
    ) -> Optional[User]:
        redis = await get_async_redis_connection()
        token_data_str = await redis.get(f"{self.key_prefix}{token}")
        if not token_data_str:
            return None

        try:
            token_data = json.loads(token_data_str)
            user_id = token_data["sub"]
            parsed_id = user_manager.parse_id(user_id)
            return await user_manager.get(parsed_id)
        except (exceptions.UserNotExists, exceptions.InvalidID, KeyError):
            return None

    async def destroy_token(self, token: str, user: User) -> None:  # noqa: ARG002
        """Properly delete the token from async redis."""
        redis = await get_async_redis_connection()
        await redis.delete(f"{self.key_prefix}{token}")

    async def refresh_token(self, token: Optional[str], user: User) -> str:
        """Refresh a token by extending its expiration time in Redis."""
        if token is None:
            # If no token provided, create a new one
            return await self.write_token(user)

        redis = await get_async_redis_connection()
        token_key = f"{self.key_prefix}{token}"

        # Check if token exists
        token_data_str = await redis.get(token_key)
        if not token_data_str:
            # Token not found, create new one
            return await self.write_token(user)

        # Token exists, extend its lifetime
        token_data = json.loads(token_data_str)
        await redis.set(
            token_key,
            json.dumps(token_data),
            ex=self.lifetime_seconds,
        )

        return token


class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
    """Database strategy with token refreshing capabilities."""

    def __init__(
        self,
        access_token_db: AccessTokenDatabase[AccessToken],
        lifetime_seconds: Optional[int] = None,
    ):
        super().__init__(access_token_db, lifetime_seconds)
        self._access_token_db = access_token_db

    async def refresh_token(self, token: Optional[str], user: User) -> str:
        """Refresh a token by updating its expiration time in the database."""
        if token is None:
            return await self.write_token(user)

        # Find the token in database
        access_token = await self._access_token_db.get_by_token(token)

        if access_token is None:
            # Token not found, create new one
            return await self.write_token(user)

        # Update expiration time
        new_expires = datetime.now(timezone.utc) + timedelta(
            seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
        )
        await self._access_token_db.update(access_token, {"expires": new_expires})

        return token


class SingleTenantJWTStrategy(JWTStrategy[User, uuid.UUID]):
    """Stateless JWT strategy for single-tenant deployments.

    Tokens are self-contained and verified via signature; no Redis or DB
    lookup is required per request. An ``iat`` claim is embedded so that
    downstream code can determine when the token was created without
    querying an external store.

    Refresh is implemented by issuing a brand-new JWT (the old one remains
    valid until its natural expiry). ``destroy_token`` is a no-op because
    JWTs cannot be invalidated server-side.
    """

    def __init__(
        self,
        secret: SecretType,
        lifetime_seconds: int | None = SESSION_EXPIRE_TIME_SECONDS,
        token_audience: list[str] | None = None,
        algorithm: str = "HS256",
        public_key: SecretType | None = None,
    ):
        super().__init__(
            secret=secret,
            lifetime_seconds=lifetime_seconds,
            token_audience=token_audience or ["fastapi-users:auth"],
            algorithm=algorithm,
            public_key=public_key,
        )

    async def write_token(self, user: User) -> str:
        data = {
            "sub": str(user.id),
            "aud": self.token_audience,
            "iat": int(datetime.now(timezone.utc).timestamp()),
        }
        return generate_jwt(
            data, self.encode_key, self.lifetime_seconds, algorithm=self.algorithm
        )

    async def destroy_token(self, token: str, user: User) -> None:  # noqa: ARG002
        # JWTs are stateless: there is nothing to invalidate server-side.
        # NOTE: a compromise that makes JWT auth stateful but revocable is to
        # include a token_version claim in the JWT payload. The token_version is
        # incremented whenever the user logs out (or has their login revoked).
        # When the JWT is used, it is valid only if its token_version claim
        # matches the one stored in the DB; otherwise the JWT is rejected and
        # the user must log in again.
        return

    async def refresh_token(
        self,
        token: Optional[str],  # noqa: ARG002
        user: User,  # noqa: ARG002
    ) -> str:
        """Issue a fresh JWT with a new expiry."""
        return await self.write_token(user)


def get_redis_strategy() -> TenantAwareRedisStrategy:
    return TenantAwareRedisStrategy()


def get_database_strategy(
    access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
) -> RefreshableDatabaseStrategy:
    return RefreshableDatabaseStrategy(
        access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
    )


def get_jwt_strategy() -> SingleTenantJWTStrategy:
    return SingleTenantJWTStrategy(
        secret=USER_AUTH_SECRET,
        lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS,
    )


if AUTH_BACKEND == AuthBackend.JWT:
    if MULTI_TENANT or AUTH_TYPE == AuthType.CLOUD:
        raise ValueError(
            "JWT auth backend is only supported for single-tenant, self-hosted deployments. Use 'redis' or 'postgres' instead."
        )
    if not USER_AUTH_SECRET:
        raise ValueError("USER_AUTH_SECRET is required for JWT auth backend.")

if AUTH_BACKEND == AuthBackend.REDIS:
    auth_backend = AuthenticationBackend(
        name="redis", transport=cookie_transport, get_strategy=get_redis_strategy
    )
elif AUTH_BACKEND == AuthBackend.POSTGRES:
    auth_backend = AuthenticationBackend(
        name="postgres", transport=cookie_transport, get_strategy=get_database_strategy
    )
elif AUTH_BACKEND == AuthBackend.JWT:
    auth_backend = AuthenticationBackend(
        name="jwt", transport=cookie_transport, get_strategy=get_jwt_strategy
    )
else:
    raise ValueError(f"Invalid auth backend: {AUTH_BACKEND}")


class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
    def get_logout_router(
        self,
        backend: AuthenticationBackend,
        requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
    ) -> APIRouter:
        """
        Provide a router for logout only for OAuth/OIDC Flows.
        This way the login router does not need to be included.
        """
        router = APIRouter()

        get_current_user_token = self.authenticator.current_user_token(
            active=True, verified=requires_verification
        )

        logout_responses: OpenAPIResponseType = {
            **{
                status.HTTP_401_UNAUTHORIZED: {
                    "description": "Missing token or inactive user."
                }
            },
            **backend.transport.get_openapi_logout_responses_success(),
        }

        @router.post(
            "/logout", name=f"auth:{backend.name}.logout", responses=logout_responses
        )
        async def logout(
            user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
            strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
        ) -> Response:
            user, token = user_token
            return await backend.logout(strategy, user, token)

        return router

    def get_refresh_router(
        self,
        backend: AuthenticationBackend,
        requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
    ) -> APIRouter:
        """
        Provide a router for session token refreshing.
        """
        # Import the oauth_refresher here to avoid circular imports
        from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens

        router = APIRouter()

        get_current_user_token = self.authenticator.current_user_token(
            active=True, verified=requires_verification
        )

        refresh_responses: OpenAPIResponseType = {
            **{
                status.HTTP_401_UNAUTHORIZED: {
                    "description": "Missing token or inactive user."
                }
            },
            **backend.transport.get_openapi_login_responses_success(),
        }

        @router.post(
            "/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
        )
        async def refresh(
            user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
            strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
            user_manager: BaseUserManager[models.UP, models.ID] = Depends(
                get_user_manager
            ),
            db_session: AsyncSession = Depends(get_async_session),
        ) -> Response:
            try:
                user, token = user_token
                logger.info(f"Processing token refresh request for user {user.email}")

                # Check if user has OAuth accounts that need refreshing
                await check_and_refresh_oauth_tokens(
                    user=cast(User, user),
                    db_session=db_session,
                    user_manager=cast(Any, user_manager),
                )

                # Check if strategy supports refreshing
                supports_refresh = hasattr(strategy, "refresh_token") and callable(
                    getattr(strategy, "refresh_token")
                )

                if supports_refresh:
                    try:
                        refresh_method = getattr(strategy, "refresh_token")
                        new_token = await refresh_method(token, user)
                        logger.info(
                            f"Successfully refreshed session token for user {user.email}"
                        )
                        return await backend.transport.get_login_response(new_token)
                    except Exception as e:
                        logger.error(f"Error refreshing session token: {str(e)}")
                        # Fallback to logout and login if refresh fails
                        await backend.logout(strategy, user, token)
                        return await backend.login(strategy, user)

                # Fallback: logout and login again
                logger.info(
                    "Strategy doesn't support refresh - using logout/login flow"
                )
                await backend.logout(strategy, user, token)
                return await backend.login(strategy, user)
            except Exception as e:
                logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
                raise HTTPException(
                    status_code=status.HTTP_400_BAD_REQUEST,
                    detail=f"Token refresh failed: {str(e)}",
                )

        return router


fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
    get_user_manager, [auth_backend]
)


# NOTE: verified=REQUIRE_EMAIL_VERIFICATION is not used here since we
# take care of that in `double_check_user` ourselves. This is needed, since
# we want the /me endpoint to still return a user even if they are not
# yet verified, so that the frontend knows they exist
optional_fastapi_current_user = fastapi_users.current_user(active=True, optional=True)


_JWT_EMAIL_CLAIM_KEYS = ("email", "preferred_username", "upn")


def _extract_email_from_jwt(payload: dict[str, Any]) -> str | None:
    """Return the best-effort email/username from a decoded JWT payload."""
    for key in _JWT_EMAIL_CLAIM_KEYS:
        value = payload.get(key)
        if isinstance(value, str) and value:
            try:
                email_info = validate_email(value, check_deliverability=False)
            except EmailNotValidError:
                continue
            normalized_email = email_info.normalized or email_info.email
            return normalized_email.lower()
    return None


async def _sync_jwt_oidc_expiry(
    user_manager: UserManager, user: User, payload: dict[str, Any]
) -> None:
    if TRACK_EXTERNAL_IDP_EXPIRY:
        expires_at = payload.get("exp")
        if expires_at is None:
            return
        try:
            expiry_timestamp = int(expires_at)
        except (TypeError, ValueError):
            logger.warning("Invalid exp claim on JWT for user %s", user.email)
            return
        oidc_expiry = datetime.fromtimestamp(expiry_timestamp, tz=timezone.utc)
        if user.oidc_expiry == oidc_expiry:
            return
        await user_manager.user_db.update(user, {"oidc_expiry": oidc_expiry})
        user.oidc_expiry = oidc_expiry
        return

    if user.oidc_expiry is not None:
        await user_manager.user_db.update(user, {"oidc_expiry": None})
        user.oidc_expiry = None  # type: ignore


async def _get_or_create_user_from_jwt(
    payload: dict[str, Any],
    request: Request,
    async_db_session: AsyncSession,
) -> User | None:
    email = _extract_email_from_jwt(payload)
    if email is None:
        logger.warning(
            "JWT token decoded successfully but no email claim found; skipping auth"
        )
        return None

    # Enforce the same allowlist/domain policies as other auth flows
    verify_email_is_invited(email)
    verify_email_domain(email)

    user_db: SQLAlchemyUserAdminDB[User, uuid.UUID] = SQLAlchemyUserAdminDB(
        async_db_session, User, OAuthAccount
    )
    user_manager = UserManager(user_db)

    try:
        user = await user_manager.get_by_email(email)
        if not user.is_active:
            logger.warning("Inactive user %s attempted JWT login; skipping", email)
            return None
        if not user.account_type.is_web_login():
            raise exceptions.UserNotExists()
    except exceptions.UserNotExists:
        logger.info("Provisioning user %s from JWT login", email)
        try:
            user = await user_manager.create(
                UserCreate(
                    email=email,
                    password=generate_password(),
                    is_verified=True,
                ),
                request=request,
            )
        except exceptions.UserAlreadyExists:
            user = await user_manager.get_by_email(email)
            if not user.is_active:
                logger.warning(
                    "Inactive user %s attempted JWT login during provisioning race; skipping",
                    email,
                )
                return None
            if not user.account_type.is_web_login():
                logger.warning(
                    "Non-web-login user %s attempted JWT login during provisioning race; skipping",
                    email,
                )
                return None

    await _sync_jwt_oidc_expiry(user_manager, user, payload)
    return user


async def _check_for_saml_and_jwt(
    request: Request,
    user: User | None,
    async_db_session: AsyncSession,
) -> User | None:
    # If user is None, check for JWT in Authorization header
    if user is None and JWT_PUBLIC_KEY_URL is not None:
        auth_header = request.headers.get("Authorization")
        if auth_header and auth_header.startswith("Bearer "):
            token = auth_header[len("Bearer ") :].strip()
            payload = await verify_jwt_token(token)
            if payload is not None:
                user = await _get_or_create_user_from_jwt(
                    payload, request, async_db_session
                )

    return user


async def optional_user(
    request: Request,
    async_db_session: AsyncSession = Depends(get_async_session),
    user: User | None = Depends(optional_fastapi_current_user),
) -> User | None:
    if user := await _check_for_saml_and_jwt(request, user, async_db_session):
        # If user is already set, _check_for_saml_and_jwt returns the same user object
        return user

    try:
        if hashed_pat := get_hashed_pat_from_request(request):
            user = await fetch_user_for_pat(hashed_pat, async_db_session)
        elif hashed_api_key := get_hashed_api_key_from_request(request):
            user = await fetch_user_for_api_key(hashed_api_key, async_db_session)
    except ValueError:
        logger.warning("Issue with validating authentication token")
        return None

    return user


def get_anonymous_user() -> User:
    """Create anonymous user object."""
    user = User(
        id=uuid.UUID(ANONYMOUS_USER_UUID),
        email=ANONYMOUS_USER_EMAIL,
        hashed_password="",
        is_active=True,
        is_verified=True,
        is_superuser=False,
        role=UserRole.LIMITED,
        account_type=AccountType.ANONYMOUS,
        use_memories=False,
        enable_memory_tool=False,
    )
    return user


async def double_check_user(
    user: User | None,
    include_expired: bool = False,
    allow_anonymous_access: bool = False,
) -> User:
    if user is not None:
        # If the user attempted to authenticate, verify them; do not fall back
        # to anonymous access if verification fails.
        if user_needs_to_be_verified() and not user.is_verified:
            raise BasicAuthenticationError(
                detail="Access denied. User is not verified.",
            )

        if (
            user.oidc_expiry
            and user.oidc_expiry < datetime.now(timezone.utc)
            and not include_expired
        ):
            raise BasicAuthenticationError(
                detail="Access denied. User's OIDC token has expired.",
            )

        return user

    if allow_anonymous_access:
        return get_anonymous_user()

    raise BasicAuthenticationError(
        detail="Access denied. User is not authenticated.",
    )


async def current_user_with_expired_token(
    user: User | None = Depends(optional_user),
) -> User:
    return await double_check_user(user, include_expired=True)


async def current_limited_user(
    user: User | None = Depends(optional_user),
) -> User:
    return await double_check_user(user)


async def current_chat_accessible_user(
    user: User | None = Depends(optional_user),
) -> User:
    tenant_id = get_current_tenant_id()

    return await double_check_user(
        user, allow_anonymous_access=anonymous_user_enabled(tenant_id=tenant_id)
    )


async def current_user(
    user: User | None = Depends(optional_user),
) -> User:
    user = await double_check_user(user)

    if user.role == UserRole.LIMITED:
        raise BasicAuthenticationError(
            detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
        )
    return user


async def current_curator_or_admin_user(
    user: User = Depends(current_user),
) -> User:
    allowed_roles = {UserRole.GLOBAL_CURATOR, UserRole.CURATOR, UserRole.ADMIN}
    if user.role not in allowed_roles:
        raise BasicAuthenticationError(
            detail="Access denied. User is not a curator or admin.",
        )

    return user


async def current_admin_user(user: User = Depends(current_user)) -> User:
    if user.role != UserRole.ADMIN:
        raise BasicAuthenticationError(
            detail="Access denied. User must be an admin to perform this action.",
        )

    return user


async def _get_user_from_token_data(token_data: dict) -> User | None:
    """Shared logic: token data dict → User object.

    Args:
        token_data: Decoded token data containing 'sub' (user ID).

    Returns:
        User object if found and active, None otherwise.
    """
    user_id = token_data.get("sub")
    if not user_id:
        return None

    try:
        user_uuid = uuid.UUID(user_id)
    except ValueError:
        return None

    async with get_async_session_context_manager() as async_db_session:
        user = await async_db_session.get(User, user_uuid)
        if user is None or not user.is_active:
            return None
        return user


_LOOPBACK_HOSTNAMES = frozenset({"localhost", "127.0.0.1", "::1"})


def _is_same_origin(actual: str, expected: str) -> bool:
    """Compare two origins for the WebSocket CSWSH check.

    Scheme and hostname must match exactly. Port must also match, except
    when the hostname is a loopback address (localhost / 127.0.0.1 / ::1),
    where port is ignored. On loopback, all ports belong to the same
    operator, so port differences carry no security significance: the
    CSWSH threat comes from remote origins, not local ones.
    """
    a = urlparse(actual.rstrip("/"))
    e = urlparse(expected.rstrip("/"))
    if a.scheme != e.scheme or a.hostname != e.hostname:
        return False
    if a.hostname in _LOOPBACK_HOSTNAMES:
        return True
    actual_port = a.port or (443 if a.scheme == "https" else 80)
    expected_port = e.port or (443 if e.scheme == "https" else 80)
    return actual_port == expected_port


async def current_user_from_websocket(
    websocket: WebSocket,
    token: str = Query(..., description="WebSocket authentication token"),
) -> User:
    """
    WebSocket authentication dependency using query parameter.

    Validates the WS token from query param and returns the User.
    Raises BasicAuthenticationError if authentication fails.

    The token must be obtained from POST /voice/ws-token before connecting.
    Tokens are single-use and expire after 60 seconds.

    Usage:
    1. POST /voice/ws-token -> {"token": "xxx"}
    2. Connect to ws://host/path?token=xxx

    This applies the same auth checks as current_user() for HTTP endpoints.
    """
    # Check Origin header to prevent Cross-Site WebSocket Hijacking (CSWSH).
    # Browsers always send Origin on WebSocket connections.
    origin = websocket.headers.get("origin")
    if not origin:
        logger.warning("WS auth: missing Origin header")
        raise BasicAuthenticationError(detail="Access denied. Missing origin.")

    if not _is_same_origin(origin, WEB_DOMAIN):
        logger.warning(f"WS auth: origin mismatch. Expected {WEB_DOMAIN}, got {origin}")
        raise BasicAuthenticationError(detail="Access denied. Invalid origin.")

    # Validate WS token in Redis (single-use, deleted after retrieval)
    try:
        token_data = await retrieve_ws_token_data(token)
        if token_data is None:
            raise BasicAuthenticationError(
                detail="Access denied. Invalid or expired authentication token."
            )
    except BasicAuthenticationError:
        raise
    except Exception as e:
        logger.error(f"WS auth: error during token validation: {e}")
        raise BasicAuthenticationError(
            detail="Authentication verification failed."
        ) from e

    # Get user from token data
    user = await _get_user_from_token_data(token_data)
    if user is None:
        logger.warning(f"WS auth: user not found for id={token_data.get('sub')}")
        raise BasicAuthenticationError(
            detail="Access denied. User not found or inactive."
        )

    # Apply same checks as HTTP auth (verification, OIDC expiry, role)
    user = await double_check_user(user)

    # Block LIMITED users (same as current_user)
    if user.role == UserRole.LIMITED:
        logger.warning(f"WS auth: user {user.email} has LIMITED role")
        raise BasicAuthenticationError(
            detail="Access denied. User role is LIMITED. BASIC or higher permissions are required.",
        )

    logger.debug(f"WS auth: authenticated {user.email}")
    return user


def get_default_admin_user_emails_() -> list[str]:
    # No default seeding available for Onyx MIT
    return []


STATE_TOKEN_AUDIENCE = "fastapi-users:oauth-state"
STATE_TOKEN_LIFETIME_SECONDS = 3600
CSRF_TOKEN_KEY = "csrftoken"
CSRF_TOKEN_COOKIE_NAME = "fastapiusersoauthcsrf"
PKCE_COOKIE_NAME_PREFIX = "fastapiusersoauthpkce"


class OAuth2AuthorizeResponse(BaseModel):
    authorization_url: str


def generate_state_token(
    data: Dict[str, str],
    secret: SecretType,
    lifetime_seconds: int = STATE_TOKEN_LIFETIME_SECONDS,
) -> str:
    data["aud"] = STATE_TOKEN_AUDIENCE

    return generate_jwt(data, secret, lifetime_seconds)


def generate_csrf_token() -> str:
    return secrets.token_urlsafe(32)


def _base64url_encode(data: bytes) -> str:
    return base64.urlsafe_b64encode(data).rstrip(b"=").decode("ascii")


def generate_pkce_pair() -> tuple[str, str]:
    verifier = secrets.token_urlsafe(64)
    challenge = _base64url_encode(hashlib.sha256(verifier.encode("ascii")).digest())
    return verifier, challenge


def get_pkce_cookie_name(state: str) -> str:
    state_hash = hashlib.sha256(state.encode("utf-8")).hexdigest()
    return f"{PKCE_COOKIE_NAME_PREFIX}_{state_hash}"


# refer to https://github.com/fastapi-users/fastapi-users/blob/42ddc241b965475390e2bce887b084152ae1a2cd/fastapi_users/fastapi_users.py#L91
def create_onyx_oauth_router(
    oauth_client: BaseOAuth2,
    backend: AuthenticationBackend,
    state_secret: SecretType,
    redirect_url: Optional[str] = None,
    associate_by_email: bool = False,
    is_verified_by_default: bool = False,
    enable_pkce: bool = False,
) -> APIRouter:
    return get_oauth_router(
        oauth_client,
        backend,
        get_user_manager,
        state_secret,
        redirect_url,
        associate_by_email,
        is_verified_by_default,
        enable_pkce=enable_pkce,
    )


def get_oauth_router(
    oauth_client: BaseOAuth2,
    backend: AuthenticationBackend,
    get_user_manager: UserManagerDependency[models.UP, models.ID],
    state_secret: SecretType,
    redirect_url: Optional[str] = None,
    associate_by_email: bool = False,
    is_verified_by_default: bool = False,
    *,
    csrf_token_cookie_name: str = CSRF_TOKEN_COOKIE_NAME,
    csrf_token_cookie_path: str = "/",
    csrf_token_cookie_domain: Optional[str] = None,
    csrf_token_cookie_secure: Optional[bool] = None,
    csrf_token_cookie_httponly: bool = True,
    csrf_token_cookie_samesite: Optional[Literal["lax", "strict", "none"]] = "lax",
    enable_pkce: bool = False,
) -> APIRouter:
    """Generate a router with the OAuth routes."""
    router = APIRouter()
    callback_route_name = f"oauth:{oauth_client.name}.{backend.name}.callback"

    if redirect_url is not None:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            redirect_url=redirect_url,
        )
    else:
        oauth2_authorize_callback = OAuth2AuthorizeCallback(
            oauth_client,
            route_name=callback_route_name,
        )

    async def null_access_token_state() -> tuple[OAuth2Token, Optional[str]] | None:
        return None

    access_token_state_dependency = (
        oauth2_authorize_callback if not enable_pkce else null_access_token_state
    )

    if csrf_token_cookie_secure is None:
        csrf_token_cookie_secure = WEB_DOMAIN.startswith("https")

    @router.get(
        "/authorize",
        name=f"oauth:{oauth_client.name}.{backend.name}.authorize",
        response_model=OAuth2AuthorizeResponse,
    )
    async def authorize(
        request: Request,
        response: Response,
        redirect: bool = Query(False),
        scopes: List[str] = Query(None),
    ) -> Response | OAuth2AuthorizeResponse:
        referral_source = request.cookies.get("referral_source", None)

        if redirect_url is not None:
            authorize_redirect_url = redirect_url
        else:
            # Use WEB_DOMAIN instead of request.url_for() to prevent host
            # header poisoning, since request.url_for() trusts the Host header.
            callback_path = request.app.url_path_for(callback_route_name)
            authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"

        next_url = request.query_params.get("next", "/")

        csrf_token = generate_csrf_token()
        state_data: Dict[str, str] = {
            "next_url": next_url,
            "referral_source": referral_source or "default_referral",
            CSRF_TOKEN_KEY: csrf_token,
        }
        state = generate_state_token(state_data, state_secret)
        pkce_cookie: tuple[str, str] | None = None

        if enable_pkce:
            code_verifier, code_challenge = generate_pkce_pair()
            pkce_cookie_name = get_pkce_cookie_name(state)
            pkce_cookie = (pkce_cookie_name, code_verifier)
            authorization_url = await oauth_client.get_authorization_url(
                authorize_redirect_url,
                state,
                scopes,
                code_challenge=code_challenge,
                code_challenge_method="S256",
            )
        else:
            # Get the basic authorization URL
            authorization_url = await oauth_client.get_authorization_url(
                authorize_redirect_url,
                state,
                scopes,
            )

        # For Google OAuth, add parameters to request refresh tokens
        if oauth_client.name == "google":
            authorization_url = add_url_params(
                authorization_url, {"access_type": "offline", "prompt": "consent"}
            )

        def set_oauth_cookie(
            target_response: Response,
            *,
            key: str,
            value: str,
        ) -> None:
            target_response.set_cookie(
                key=key,
                value=value,
                max_age=STATE_TOKEN_LIFETIME_SECONDS,
                path=csrf_token_cookie_path,
                domain=csrf_token_cookie_domain,
                secure=csrf_token_cookie_secure,
                httponly=csrf_token_cookie_httponly,
                samesite=csrf_token_cookie_samesite,
            )

        response_with_cookies: Response
        if redirect:
            response_with_cookies = RedirectResponse(authorization_url, status_code=302)
        else:
            response_with_cookies = response

        set_oauth_cookie(
            response_with_cookies,
            key=csrf_token_cookie_name,
            value=csrf_token,
        )
        if pkce_cookie is not None:
            pkce_cookie_name, code_verifier = pkce_cookie
            set_oauth_cookie(
                response_with_cookies,
                key=pkce_cookie_name,
                value=code_verifier,
            )

        if redirect:
            return response_with_cookies

        return OAuth2AuthorizeResponse(authorization_url=authorization_url)

    @log_function_time(print_only=True)
    @router.get(
        "/callback",
        name=callback_route_name,
        description="The response varies based on the authentication backend used.",
        responses={
            status.HTTP_400_BAD_REQUEST: {
                "model": ErrorModel,
                "content": {
                    "application/json": {
                        "examples": {
                            "INVALID_STATE_TOKEN": {
                                "summary": "Invalid state token.",
                                "value": None,
                            },
                            ErrorCode.LOGIN_BAD_CREDENTIALS: {
                                "summary": "User is inactive.",
                                "value": {"detail": ErrorCode.LOGIN_BAD_CREDENTIALS},
                            },
                        }
                    }
                },
            },
        },
    )
    async def callback(
        request: Request,
        access_token_state: Tuple[OAuth2Token, Optional[str]] | None = Depends(
            access_token_state_dependency
        ),
        code: Optional[str] = None,
        state: Optional[str] = None,
        error: Optional[str] = None,
        user_manager: BaseUserManager[models.UP, models.ID] = Depends(get_user_manager),
        strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
    ) -> Response:
        pkce_cookie_name: str | None = None

        def delete_pkce_cookie(response: Response) -> None:
            if enable_pkce and pkce_cookie_name:
                response.delete_cookie(
                    key=pkce_cookie_name,
                    path=csrf_token_cookie_path,
                    domain=csrf_token_cookie_domain,
                    secure=csrf_token_cookie_secure,
                    httponly=csrf_token_cookie_httponly,
                    samesite=csrf_token_cookie_samesite,
                )

        def build_error_response(exc: OnyxError) -> JSONResponse:
            log_onyx_error(exc)
            error_response = onyx_error_to_json_response(exc)
            delete_pkce_cookie(error_response)
            return error_response

        def decode_and_validate_state(state_value: str) -> Dict[str, str]:
            try:
                state_data = decode_jwt(
                    state_value, state_secret, [STATE_TOKEN_AUDIENCE]
                )
            except jwt.DecodeError:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    getattr(
                        ErrorCode,
                        "ACCESS_TOKEN_DECODE_ERROR",
                        "ACCESS_TOKEN_DECODE_ERROR",
                    ),
                )
            except jwt.ExpiredSignatureError:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    getattr(
                        ErrorCode,
                        "ACCESS_TOKEN_ALREADY_EXPIRED",
                        "ACCESS_TOKEN_ALREADY_EXPIRED",
                    ),
                )
            except jwt.PyJWTError:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    getattr(
                        ErrorCode,
                        "ACCESS_TOKEN_DECODE_ERROR",
                        "ACCESS_TOKEN_DECODE_ERROR",
                    ),
                )

            cookie_csrf_token = request.cookies.get(csrf_token_cookie_name)
            state_csrf_token = state_data.get(CSRF_TOKEN_KEY)
            if (
                not cookie_csrf_token
                or not state_csrf_token
                or not secrets.compare_digest(cookie_csrf_token, state_csrf_token)
            ):
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    getattr(ErrorCode, "OAUTH_INVALID_STATE", "OAUTH_INVALID_STATE"),
                )

            return state_data

        token: OAuth2Token
        state_data: Dict[str, str]

        # `code`, `state`, and `error` are read directly only in the PKCE path.
        # In the non-PKCE path, `oauth2_authorize_callback` consumes them.
        if enable_pkce:
            if state is not None:
                pkce_cookie_name = get_pkce_cookie_name(state)

            if error is not None:
                return build_error_response(
                    OnyxError(
                        OnyxErrorCode.VALIDATION_ERROR,
                        "Authorization request failed or was denied",
                    )
                )
            if code is None:
                return build_error_response(
                    OnyxError(
                        OnyxErrorCode.VALIDATION_ERROR,
                        "Missing authorization code in OAuth callback",
                    )
                )
            if state is None:
                return build_error_response(
                    OnyxError(
                        OnyxErrorCode.VALIDATION_ERROR,
                        "Missing state parameter in OAuth callback",
                    )
                )

            state_value = state

            if redirect_url is not None:
                callback_redirect_url = redirect_url
            else:
                callback_path = request.app.url_path_for(callback_route_name)
                callback_redirect_url = f"{WEB_DOMAIN}{callback_path}"

            code_verifier = request.cookies.get(cast(str, pkce_cookie_name))
            if not code_verifier:
                return build_error_response(
                    OnyxError(
                        OnyxErrorCode.VALIDATION_ERROR,
                        "Missing PKCE verifier cookie in OAuth callback",
                    )
                )

            try:
                state_data = decode_and_validate_state(state_value)
            except OnyxError as e:
                return build_error_response(e)

            try:
                token = await oauth_client.get_access_token(
                    code,
                    callback_redirect_url,
                    code_verifier,
                )
            except GetAccessTokenError:
                return build_error_response(
                    OnyxError(
                        OnyxErrorCode.VALIDATION_ERROR,
                        "Authorization code exchange failed",
                    )
                )
        else:
            if access_token_state is None:
                raise OnyxError(
                    OnyxErrorCode.INTERNAL_ERROR, "Missing OAuth callback state"
                )

            token, callback_state = access_token_state
            if callback_state is None:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    "Missing state parameter in OAuth callback",
                )
            state_data = decode_and_validate_state(callback_state)

        async def complete_login_flow(
            token: OAuth2Token, state_data: Dict[str, str]
        ) -> RedirectResponse:
            account_id, account_email = await oauth_client.get_id_email(
                token["access_token"]
            )

            if account_email is None:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    ErrorCode.OAUTH_NOT_AVAILABLE_EMAIL,
                )

            next_url = state_data.get("next_url", "/")
            referral_source = state_data.get("referral_source", None)
            try:
                tenant_id = fetch_ee_implementation_or_noop(
                    "onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
                )(account_email)
            except exceptions.UserNotExists:
                tenant_id = None

            request.state.referral_source = referral_source

            # Proceed to authenticate or create the user
            try:
                user = await user_manager.oauth_callback(
                    oauth_client.name,
                    token["access_token"],
                    account_id,
                    account_email,
                    token.get("expires_at"),
                    token.get("refresh_token"),
                    request,
                    associate_by_email=associate_by_email,
                    is_verified_by_default=is_verified_by_default,
                )
            except UserAlreadyExists:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    ErrorCode.OAUTH_USER_ALREADY_EXISTS,
                )

            if not user.is_active:
                raise OnyxError(
                    OnyxErrorCode.VALIDATION_ERROR,
                    ErrorCode.LOGIN_BAD_CREDENTIALS,
                )

            # Login user
            response = await backend.login(strategy, user)
            await user_manager.on_after_login(user, request, response)

            # Prepare redirect response
            if tenant_id is None:
                # Use URL utility to add parameters
                redirect_destination = add_url_params(next_url, {"new_team": "true"})
                redirect_response = RedirectResponse(
                    redirect_destination,
status_code=302 ) else: # No parameters to add redirect_response = RedirectResponse(next_url, status_code=302) # Copy headers from auth response to redirect response, with special handling for Set-Cookie for header_name, header_value in response.headers.items(): header_name_lower = header_name.lower() if header_name_lower == "set-cookie": redirect_response.headers.append(header_name, header_value) continue if header_name_lower in {"location", "content-length"}: continue redirect_response.headers[header_name] = header_value return redirect_response if enable_pkce: try: redirect_response = await complete_login_flow(token, state_data) except OnyxError as e: return build_error_response(e) delete_pkce_cookie(redirect_response) return redirect_response return await complete_login_flow(token, state_data) return router ================================================ FILE: backend/onyx/auth/utils.py ================================================ """Shared authentication utilities for bearer token extraction and validation.""" from collections.abc import Callable from urllib.parse import unquote from fastapi import Request from onyx.auth.constants import API_KEY_HEADER_ALTERNATIVE_NAME from onyx.auth.constants import API_KEY_HEADER_NAME from onyx.auth.constants import API_KEY_PREFIX from onyx.auth.constants import BEARER_PREFIX from onyx.auth.constants import DEPRECATED_API_KEY_PREFIX from onyx.auth.constants import PAT_PREFIX def get_hashed_bearer_token_from_request( request: Request, valid_prefixes: list[str], hash_fn: Callable[[str], str], allow_non_bearer: bool = False, ) -> str | None: """Generic extraction and hashing of bearer tokens from request headers. 
Args: request: The FastAPI request valid_prefixes: List of valid token prefixes (e.g., ["on_", "onyx_pat_"]) hash_fn: Function to hash the token (e.g., hash_api_key or hash_pat) allow_non_bearer: If True, accept raw tokens without "Bearer " prefix Returns: Hashed token if valid format, else None """ auth_header = request.headers.get( API_KEY_HEADER_ALTERNATIVE_NAME ) or request.headers.get(API_KEY_HEADER_NAME) if not auth_header: return None # Handle bearer format if auth_header.startswith(BEARER_PREFIX): token = auth_header[len(BEARER_PREFIX) :].strip() elif allow_non_bearer: token = auth_header else: return None # Check if token starts with any valid prefix if valid_prefixes: valid = any(token.startswith(prefix) for prefix in valid_prefixes) if not valid: return None return hash_fn(token) def _extract_tenant_from_bearer_token( request: Request, valid_prefixes: list[str] ) -> str | None: """Generic tenant extraction from bearer token. Returns None if invalid format. Args: request: The FastAPI request valid_prefixes: List of valid token prefixes (e.g., ["on_", "dn_"]) Returns: Tenant ID if found in format <prefix><tenant_id>.<rest>, else None """ auth_header = request.headers.get( API_KEY_HEADER_ALTERNATIVE_NAME ) or request.headers.get(API_KEY_HEADER_NAME) if not auth_header or not auth_header.startswith(BEARER_PREFIX): return None token = auth_header[len(BEARER_PREFIX) :].strip() # Check if token starts with any valid prefix matched_prefix = None for prefix in valid_prefixes: if token.startswith(prefix): matched_prefix = prefix break if not matched_prefix: return None # Parse tenant from token format: <prefix><tenant_id>.<rest> parts = token[len(matched_prefix) :].split(".", 1) if len(parts) != 2: return None tenant_id = parts[0] return unquote(tenant_id) if tenant_id else None def extract_tenant_from_auth_header(request: Request) -> str | None: """Extract tenant ID from API key or PAT header. 
Unified function for extracting tenant from any bearer token (API key or PAT). Checks all known token prefixes in order.
Returns: Tenant ID if found, else None """ return _extract_tenant_from_bearer_token( request, [API_KEY_PREFIX, DEPRECATED_API_KEY_PREFIX, PAT_PREFIX] ) ================================================ FILE: backend/onyx/background/README.md ================================================
# Overview of Onyx Background Jobs

The background jobs take care of:
1. Pulling/indexing documents (from connectors)
2. Updating document metadata (from connectors)
3. Cleaning up checkpoints and logic around indexing work (indexing checkpoints and index attempt metadata)
4. Handling user-uploaded files and deletions (from the Projects feature and uploads via the Chat)
5. Reporting metrics, such as queue lengths, for monitoring purposes

## Worker → Queue Mapping

| Worker | File | Queues |
|--------|------|--------|
| Primary | `apps/primary.py` | `celery` |
| Light | `apps/light.py` | `vespa_metadata_sync`, `connector_deletion`, `doc_permissions_upsert`, `checkpoint_cleanup`, `index_attempt_cleanup` |
| Heavy | `apps/heavy.py` | `connector_pruning`, `connector_doc_permissions_sync`, `connector_external_group_sync`, `csv_generation`, `sandbox` |
| Docprocessing | `apps/docprocessing.py` | `docprocessing` |
| Docfetching | `apps/docfetching.py` | `connector_doc_fetching` |
| User File Processing | `apps/user_file_processing.py` | `user_file_processing`, `user_file_project_sync`, `user_file_delete` |
| Monitoring | `apps/monitoring.py` | `monitoring` |
| Background (consolidated) | `apps/background.py` | All queues above except `celery` |

## Non-Worker Apps

| App | File | Purpose |
|-----|------|---------|
| **Beat** | `beat.py` | Celery beat scheduler with `DynamicTenantScheduler` that generates per-tenant periodic task schedules |
| **Client** | `client.py` | Minimal app for task submission from non-worker processes (e.g., API server) |

### Shared Module

`app_base.py` provides:
- `TenantAwareTask` - Base task class that sets tenant context
- Signal handlers for
logging, cleanup, and lifecycle events
- Readiness probes and health checks

## Worker Details

### Primary (Coordinator and task dispatcher)

It is the single worker that handles tasks from the default `celery` queue. It is kept a singleton by the `PRIMARY_WORKER` Redis lock, which it refreshes every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds (using Celery Bootsteps).

On startup, it:
- waits for Redis, Postgres, and the document index to all be healthy
- acquires the singleton lock
- cleans all the Redis state associated with background jobs
- marks orphaned index attempts as failed

Then it cycles through its tasks as scheduled by Celery Beat:

| Task | Frequency | Description |
|------|-----------|-------------|
| `check_for_indexing` | 15s | Scans for connectors needing indexing → dispatches to `DOCFETCHING` queue |
| `check_for_vespa_sync_task` | 20s | Finds stale documents/document sets → dispatches sync tasks to `VESPA_METADATA_SYNC` queue |
| `check_for_pruning` | 20s | Finds connectors due for pruning → dispatches to `CONNECTOR_PRUNING` queue |
| `check_for_connector_deletion` | 20s | Processes deletion requests → dispatches to `CONNECTOR_DELETION` queue |
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |

The watchdog is a separate Python process, managed by supervisord, that runs alongside the Celery workers. It checks the `ONYX_CELERY_BEAT_HEARTBEAT_KEY` key in Redis to confirm that Celery Beat is still alive. Beat schedules `celery_beat_heartbeat`, and Primary runs it to refresh that key. See `supervisord.conf` for the watchdog config.
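
The singleton lock and the Beat heartbeat both rely on the same idea: a Redis key that expires unless something keeps refreshing it. The sketch below models only that expiry behaviour with an in-memory store. `InMemoryTTLStore` is hypothetical and is not part of Onyx or redis-py. The key names and timeout values are also illustrative assumptions, not Onyx's real constants. The real primary lock is a Redis lock refreshed from a bootstep timer, which this sketch does not reproduce.

```python
import time
from typing import Callable


class InMemoryTTLStore:
    """Hypothetical stand-in for the two Redis behaviours used here:
    writing a key with an expiry, and checking whether it still exists."""

    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._deadlines: dict[str, float] = {}

    def set_with_ttl(self, key: str, ttl_seconds: float) -> None:
        # Rewriting the key pushes its expiry forward (like SET key value EX ttl).
        self._deadlines[key] = self._clock() + ttl_seconds

    def exists(self, key: str) -> bool:
        # An expired key counts as absent (like EXISTS after the TTL runs out).
        deadline = self._deadlines.get(key)
        return deadline is not None and self._clock() < deadline


# Hypothetical values chosen for illustration only.
PRIMARY_WORKER_KEY = "primary_worker"  # stands in for OnyxRedisLocks.PRIMARY_WORKER
LOCK_TIMEOUT = 120.0  # stands in for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
LOCK_RENEW_INTERVAL = LOCK_TIMEOUT / 8  # refresh well before the key expires
HEARTBEAT_KEY = "celery_beat_heartbeat"  # stands in for ONYX_CELERY_BEAT_HEARTBEAT_KEY
HEARTBEAT_TTL = 600.0  # assumed to be much longer than the 1m heartbeat period


def primary_renew_lock(store: InMemoryTTLStore) -> None:
    """One refresh tick: Primary re-writes its singleton key."""
    store.set_with_ttl(PRIMARY_WORKER_KEY, LOCK_TIMEOUT)


def celery_beat_heartbeat(store: InMemoryTTLStore) -> None:
    """Runs on Primary whenever Beat dispatches the heartbeat task."""
    store.set_with_ttl(HEARTBEAT_KEY, HEARTBEAT_TTL)


def watchdog_beat_is_alive(store: InMemoryTTLStore) -> bool:
    """The watchdog only needs to know whether the heartbeat key survives."""
    return store.exists(HEARTBEAT_KEY)
```

In this model, Primary refreshes at one eighth of the timeout. A briefly stalled Primary can therefore miss several consecutive refreshes before its key expires. If Primary dies without releasing the lock, the key disappears within one timeout, and waiting secondaries or the watchdog notice.
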
### Light

Fast, short-lived tasks that are not resource intensive. High concurrency: it can run 24 concurrent workers, each with a prefetch of 8, for a total of 192 tasks in flight at once.

Tasks it handles:
- Syncs access/permissions, document sets, boosts, and hidden state
- Deletes documents that are marked for deletion in Postgres
- Cleans up checkpoints and index attempts

### Heavy

Long-running, resource-intensive tasks, including pruning and sandbox operations. Low concurrency: max concurrency of 4 with a prefetch of 1.

It does not interact with the Document Index. Instead, it handles the syncs with external systems, which involve high-volume API calls for pruning, fetching permissions, etc.

It also generates CSV exports, which may take a long time when there is significant data in Postgres.

Sandbox (new feature): runs Next.js, a Python virtual env, and the OpenCode AI Agent, with access to knowledge files.

### Docprocessing, Docfetching, User File Processing

Docprocessing and Docfetching are for indexing documents:
- Docfetching runs connectors to pull documents from external APIs (Google Drive, Confluence, etc.), stores batches to file storage, and dispatches docprocessing tasks
- Docprocessing retrieves batches, runs the indexing pipeline (chunking, embedding), and indexes into the Document Index

User File Processing handles user files, which come from uploads directly via the input bar.

### Monitoring

Observability and metrics collection:
- Queue lengths, connector success/failure, connector latencies
- Memory of supervisor-managed processes (workers, beat, slack)
- Cloud and multitenant specific monitoring
================================================ FILE: backend/onyx/background/celery/apps/app_base.py ================================================ import logging import multiprocessing import os import time from typing import Any from typing import cast import sentry_sdk from celery import bootsteps # type: ignore from celery import Task from celery.app import trace from celery.exceptions import WorkerShutdown from
celery.signals import task_postrun from celery.signals import task_prerun from celery.states import READY_STATES from celery.utils.log import get_task_logger from celery.worker import strategy # type: ignore from redis.lock import Lock as RedisLock from sentry_sdk.integrations.celery import CeleryIntegration from sqlalchemy import text from sqlalchemy.orm import Session from onyx import __version__ from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter from onyx.background.celery.celery_utils import celery_is_worker_primary from onyx.background.celery.celery_utils import make_probe_path from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_PREFIX from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_TASKSET_KEY from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX from onyx.configs.constants import OnyxRedisLocks from onyx.db.engine.sql_engine import get_sqlalchemy_engine from onyx.document_index.opensearch.client import ( wait_for_opensearch_with_timeout, ) from onyx.document_index.vespa.shared_utils.utils import wait_for_vespa_with_timeout from onyx.httpx.httpx_pool import HttpxPool from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_connector_delete import RedisConnectorDelete from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from onyx.redis.redis_connector_prune import RedisConnectorPrune from onyx.redis.redis_document_set import RedisDocumentSet from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_usergroup import RedisUserGroup from onyx.tracing.setup import setup_tracing from onyx.utils.logger import 
ColoredFormatter from onyx.utils.logger import LoggerContextVars from onyx.utils.logger import PlainFormatter from onyx.utils.logger import setup_logger from shared_configs.configs import DEV_LOGGING_ENABLED from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import SENTRY_DSN from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() task_logger = get_task_logger(__name__) if SENTRY_DSN: sentry_sdk.init( dsn=SENTRY_DSN, integrations=[CeleryIntegration()], traces_sample_rate=0.1, release=__version__, ) logger.info("Sentry initialized") else: logger.debug("Sentry DSN not provided, skipping Sentry initialization") class TenantAwareTask(Task): """A custom base Task that sets tenant_id in a contextvar before running.""" abstract = True # So Celery knows not to register this as a real task. def __call__(self, *args: Any, **kwargs: Any) -> Any: # Grab tenant_id from the kwargs, or fallback to default if missing. tenant_id = kwargs.get("tenant_id", None) or POSTGRES_DEFAULT_SCHEMA # Set the context var CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) # Actually run the task now try: return super().__call__(*args, **kwargs) finally: # Clear or reset after the task runs # so it does not leak into any subsequent tasks on the same worker process CURRENT_TENANT_ID_CONTEXTVAR.set(None) @task_prerun.connect def on_task_prerun( sender: Any | None = None, # noqa: ARG001 task_id: str | None = None, # noqa: ARG001 task: Task | None = None, # noqa: ARG001 args: tuple[Any, ...] | None = None, # noqa: ARG001 kwargs: dict[str, Any] | None = None, # noqa: ARG001 **other_kwargs: Any, # noqa: ARG001 ) -> None: # Reset any per-task logging context so that prefixes (e.g. pruning_ctx) # from a previous task executed in the same worker process do not leak # into the next task's log messages. 
This fixes incorrect [CC Pair:/Index Attempt] # prefixes observed when a pruning task finishes and an indexing task # runs in the same process. LoggerContextVars.reset() def on_task_postrun( sender: Any | None = None, # noqa: ARG001 task_id: str | None = None, task: Task | None = None, args: tuple | None = None, # noqa: ARG001 kwargs: dict[str, Any] | None = None, retval: Any | None = None, # noqa: ARG001 state: str | None = None, **kwds: Any, # noqa: ARG001 ) -> None: """We handle this signal in order to remove completed tasks from their respective tasksets. This allows us to track the progress of document set and user group syncs. This function runs after any task completes (both success and failure). Note that this signal does not fire on a task that failed to complete and is going to be retried. This also does not fire if a worker with acks_late=False crashes (which all of our long running workers are). """ if not task: return task_logger.debug(f"Task {task.name} (ID: {task_id}) completed with state: {state}") if state not in READY_STATES: return if not task_id: return if task.name.startswith(ONYX_CLOUD_CELERY_TASK_PREFIX): # this is a cloud / all tenant task ... no postrun is needed return # Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg if not kwargs: logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs") tenant_id = POSTGRES_DEFAULT_SCHEMA else: tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA)) task_logger.debug( f"Task {task.name} (ID: {task_id}) completed with state: {state} {f'for tenant_id={tenant_id}' if tenant_id else ''}" ) r = get_redis_client(tenant_id=tenant_id) # NOTE: we want to remove the `Redis*` classes, prefer to just have functions to # do these things going forward.
In short, things should generally be like the doc # sync task rather than the others below if task_id.startswith(DOCUMENT_SYNC_PREFIX): r.srem(DOCUMENT_SYNC_TASKSET_KEY, task_id) return if task_id.startswith(RedisDocumentSet.PREFIX): document_set_id = RedisDocumentSet.get_id_from_task_id(task_id) if document_set_id is not None: rds = RedisDocumentSet(tenant_id, int(document_set_id)) r.srem(rds.taskset_key, task_id) return if task_id.startswith(RedisUserGroup.PREFIX): usergroup_id = RedisUserGroup.get_id_from_task_id(task_id) if usergroup_id is not None: rug = RedisUserGroup(tenant_id, int(usergroup_id)) r.srem(rug.taskset_key, task_id) return if task_id.startswith(RedisConnectorDelete.PREFIX): cc_pair_id = RedisConnector.get_id_from_task_id(task_id) if cc_pair_id is not None: RedisConnectorDelete.remove_from_taskset(int(cc_pair_id), task_id, r) return if task_id.startswith(RedisConnectorPrune.SUBTASK_PREFIX): cc_pair_id = RedisConnector.get_id_from_task_id(task_id) if cc_pair_id is not None: RedisConnectorPrune.remove_from_taskset(int(cc_pair_id), task_id, r) return if task_id.startswith(RedisConnectorPermissionSync.SUBTASK_PREFIX): cc_pair_id = RedisConnector.get_id_from_task_id(task_id) if cc_pair_id is not None: RedisConnectorPermissionSync.remove_from_taskset( int(cc_pair_id), task_id, r ) return if task_id.startswith(RedisConnectorExternalGroupSync.SUBTASK_PREFIX): cc_pair_id = RedisConnector.get_id_from_task_id(task_id) if cc_pair_id is not None: RedisConnectorExternalGroupSync.remove_from_taskset( int(cc_pair_id), task_id, r ) return def on_celeryd_init( sender: str, # noqa: ARG001 conf: Any = None, # noqa: ARG001 **kwargs: Any, # noqa: ARG001 ) -> None: """The first signal sent on celery worker startup""" # NOTE(rkuo): start method "fork" is unsafe and we really need it to be "spawn" # But something is blocking set_start_method from working in the cloud unless # force=True. so we use force=True as a fallback. 
all_start_methods: list[str] = multiprocessing.get_all_start_methods() logger.info(f"Multiprocessing all start methods: {all_start_methods}") try: multiprocessing.set_start_method("spawn") # fork is unsafe, set to spawn except Exception: logger.info( "Multiprocessing set_start_method exceptioned. Trying force=True..." ) try: multiprocessing.set_start_method( "spawn", force=True ) # fork is unsafe, set to spawn except Exception: logger.info( "Multiprocessing set_start_method force=True exceptioned even with force=True." ) logger.info( f"Multiprocessing selected start method: {multiprocessing.get_start_method()}" ) # Initialize tracing in workers if credentials are available. setup_tracing() def wait_for_redis(sender: Any, **kwargs: Any) -> None: # noqa: ARG001 """Waits for redis to become ready subject to a hardcoded timeout. Will raise WorkerShutdown to kill the celery worker if the timeout is reached.""" r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA) WAIT_INTERVAL = 5 WAIT_LIMIT = 60 ready = False time_start = time.monotonic() logger.info("Redis: Readiness probe starting.") while True: try: if r.ping(): ready = True break except Exception: pass time_elapsed = time.monotonic() - time_start if time_elapsed > WAIT_LIMIT: break logger.info( f"Redis: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" ) time.sleep(WAIT_INTERVAL) if not ready: msg = f"Redis: Readiness probe did not succeed within the timeout ({WAIT_LIMIT} seconds). Exiting..." logger.error(msg) raise WorkerShutdown(msg) logger.info("Redis: Readiness probe succeeded. Continuing...") return def wait_for_db(sender: Any, **kwargs: Any) -> None: # noqa: ARG001 """Waits for the db to become ready subject to a hardcoded timeout. 
Will raise WorkerShutdown to kill the celery worker if the timeout is reached.""" WAIT_INTERVAL = 5 WAIT_LIMIT = 60 ready = False time_start = time.monotonic() logger.info("Database: Readiness probe starting.") while True: try: with Session(get_sqlalchemy_engine()) as db_session: result = db_session.execute(text("SELECT NOW()")).scalar() if result: ready = True break except Exception: pass time_elapsed = time.monotonic() - time_start if time_elapsed > WAIT_LIMIT: break logger.info( f"Database: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" ) time.sleep(WAIT_INTERVAL) if not ready: msg = f"Database: Readiness probe did not succeed within the timeout ({WAIT_LIMIT} seconds). Exiting..." logger.error(msg) raise WorkerShutdown(msg) logger.info("Database: Readiness probe succeeded. Continuing...") return def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None: # noqa: ARG001 logger.info(f"Running as a secondary celery worker: pid={os.getpid()}") # Set up variables for waiting on primary worker WAIT_INTERVAL = 5 WAIT_LIMIT = 60 r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA) time_start = time.monotonic() logger.info("Waiting for primary worker to be ready...") while True: if r.exists(OnyxRedisLocks.PRIMARY_WORKER): break time_elapsed = time.monotonic() - time_start logger.info( f"Primary worker is not ready yet. elapsed={time_elapsed:.1f} timeout={WAIT_LIMIT:.1f}" ) if time_elapsed > WAIT_LIMIT: msg = f"Primary worker was not ready within the timeout. ({WAIT_LIMIT} seconds). Exiting..." logger.error(msg) raise WorkerShutdown(msg) time.sleep(WAIT_INTERVAL) logger.info("Wait for primary worker completed successfully. 
Continuing...") return def on_worker_ready(sender: Any, **kwargs: Any) -> None: # noqa: ARG001 task_logger.info("worker_ready signal received.") # file based way to do readiness/liveness probes # https://medium.com/ambient-innovation/health-checks-for-celery-in-kubernetes-cf3274a3e106 # https://github.com/celery/celery/issues/4079#issuecomment-1270085680 hostname: str = cast(str, sender.hostname) path = make_probe_path("readiness", hostname) path.touch() logger.info(f"Readiness signal touched at {path}.") def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: # noqa: ARG001 HttpxPool.close_all() hostname: str = cast(str, sender.hostname) path = make_probe_path("readiness", hostname) path.unlink(missing_ok=True) if not celery_is_worker_primary(sender): return if not hasattr(sender, "primary_worker_lock"): # primary_worker_lock will not exist when MULTI_TENANT is True return if not sender.primary_worker_lock: return logger.info("Releasing primary worker lock.") lock: RedisLock = sender.primary_worker_lock try: if lock.owned(): try: lock.release() sender.primary_worker_lock = None except Exception: logger.exception("Failed to release primary worker lock") except Exception: logger.exception("Failed to check if primary worker lock is owned") def on_setup_logging( loglevel: int, logfile: str | None, format: str, # noqa: ARG001 colorize: bool, # noqa: ARG001 **kwargs: Any, # noqa: ARG001 ) -> None: # TODO: could unhardcode format and colorize and accept these as options from # celery's config root_logger = logging.getLogger() root_logger.handlers = [] # Define the log format log_format = ( "%(levelname)-8s %(asctime)s %(filename)15s:%(lineno)-4d: %(name)s %(message)s" ) # Set up the root handler root_handler = logging.StreamHandler() root_formatter = ColoredFormatter( log_format, datefmt="%m/%d/%Y %I:%M:%S %p", ) root_handler.setFormatter(root_formatter) root_logger.addHandler(root_handler) if logfile: # Truncate log file if DEV_LOGGING_ENABLED (for clean dev 
experience) if DEV_LOGGING_ENABLED and os.path.exists(logfile): try: open(logfile, "w").close() # Truncate the file except Exception: pass # Ignore errors, just proceed with normal logging root_file_handler = logging.FileHandler(logfile) root_file_formatter = PlainFormatter( log_format, datefmt="%m/%d/%Y %I:%M:%S %p", ) root_file_handler.setFormatter(root_file_formatter) root_logger.addHandler(root_file_handler) root_logger.setLevel(loglevel) # Configure the task logger task_logger.handlers = [] task_handler = logging.StreamHandler() task_handler.addFilter(TenantContextFilter()) task_formatter = CeleryTaskColoredFormatter( log_format, datefmt="%m/%d/%Y %I:%M:%S %p", ) task_handler.setFormatter(task_formatter) task_logger.addHandler(task_handler) if logfile: # No need to truncate again, already done above for root logger task_file_handler = logging.FileHandler(logfile) task_file_handler.addFilter(TenantContextFilter()) task_file_formatter = CeleryTaskPlainFormatter( log_format, datefmt="%m/%d/%Y %I:%M:%S %p", ) task_file_handler.setFormatter(task_file_formatter) task_logger.addHandler(task_file_handler) task_logger.setLevel(loglevel) task_logger.propagate = False # hide celery task received spam # e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] received" strategy.logger.setLevel(logging.WARNING) # uncomment this to hide celery task succeeded/failed spam # e.g. "Task check_for_pruning[a1e96171-0ba8-4e00-887b-9fbf7442eab3] succeeded in 0.03137450001668185s: None" trace.logger.setLevel(logging.WARNING) def set_task_finished_log_level(logLevel: int) -> None: """call this to override the setLevel in on_setup_logging. 
We are interested in the task timings in the cloud but it can be spammy for self hosted.""" trace.logger.setLevel(logLevel) class TenantContextFilter(logging.Filter): """Logging filter to inject tenant ID into the logger's name.""" def filter(self, record: logging.LogRecord) -> bool: if not MULTI_TENANT: record.name = "" return True tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get() if tenant_id: # Match the 8 character tenant abbreviation used in OnyxLoggingAdapter tenant_id = tenant_id.split(TENANT_ID_PREFIX)[-1][:8] record.name = f"[t:{tenant_id}]" else: record.name = "" return True @task_postrun.connect def reset_tenant_id( sender: Any | None = None, # noqa: ARG001 task_id: str | None = None, # noqa: ARG001 task: Task | None = None, # noqa: ARG001 args: tuple[Any, ...] | None = None, # noqa: ARG001 kwargs: dict[str, Any] | None = None, # noqa: ARG001 **other_kwargs: Any, # noqa: ARG001 ) -> None: """Signal handler to reset tenant ID in context var after task ends.""" CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA) def wait_for_vespa_or_shutdown( sender: Any, # noqa: ARG001 **kwargs: Any, # noqa: ARG001 ) -> None: # noqa: ARG001 """Waits for Vespa to become ready subject to a timeout. Raises WorkerShutdown if the timeout is reached.""" if DISABLE_VECTOR_DB: logger.info( "DISABLE_VECTOR_DB is set — skipping Vespa/OpenSearch readiness check." ) return if not wait_for_vespa_with_timeout(): msg = "[Vespa] Readiness probe did not succeed within the timeout. Exiting..." logger.error(msg) raise WorkerShutdown(msg) if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX: if not wait_for_opensearch_with_timeout(): msg = "[OpenSearch] Readiness probe did not succeed within the timeout. Exiting..." 
logger.error(msg) raise WorkerShutdown(msg) # File for validating worker liveness class LivenessProbe(bootsteps.StartStopStep): requires = {"celery.worker.components:Timer"} def __init__(self, worker: Any, **kwargs: Any) -> None: super().__init__(worker, **kwargs) self.requests: list[Any] = [] self.task_tref = None self.path = make_probe_path("liveness", worker.hostname) def start(self, worker: Any) -> None: self.task_tref = worker.timer.call_repeatedly( 15.0, self.update_liveness_file, (worker,), priority=10, ) def stop(self, worker: Any) -> None: # noqa: ARG002 self.path.unlink(missing_ok=True) if self.task_tref: self.task_tref.cancel() def update_liveness_file(self, worker: Any) -> None: # noqa: ARG002 self.path.touch() def get_bootsteps() -> list[type]: return [LivenessProbe] # Task modules that require a vector DB (Vespa/OpenSearch). # When DISABLE_VECTOR_DB is True these are excluded from autodiscover lists. _VECTOR_DB_TASK_MODULES: set[str] = { "onyx.background.celery.tasks.connector_deletion", "onyx.background.celery.tasks.docprocessing", "onyx.background.celery.tasks.docfetching", "onyx.background.celery.tasks.pruning", "onyx.background.celery.tasks.vespa", "onyx.background.celery.tasks.opensearch_migration", "onyx.background.celery.tasks.doc_permission_syncing", "onyx.background.celery.tasks.hierarchyfetching", # EE modules that are vector-DB-dependent "ee.onyx.background.celery.tasks.doc_permission_syncing", "ee.onyx.background.celery.tasks.external_group_syncing", } # NOTE: "onyx.background.celery.tasks.shared" is intentionally NOT in the set # above. It contains celery_beat_heartbeat (which only writes to Redis) alongside # document cleanup tasks. The cleanup tasks won't be invoked in minimal mode # because the periodic tasks that trigger them are in other filtered modules. 
def filter_task_modules(modules: list[str]) -> list[str]: """Remove vector-DB-dependent task modules when DISABLE_VECTOR_DB is True.""" if not DISABLE_VECTOR_DB: return modules return [m for m in modules if m not in _VECTOR_DB_TASK_MODULES] ================================================ FILE: backend/onyx/background/celery/apps/beat.py ================================================ from datetime import timedelta from typing import Any from celery import Celery from celery import signals from celery.beat import PersistentScheduler # type: ignore from celery.signals import beat_init from celery.utils.log import get_task_logger import onyx.background.celery.apps.app_base as app_base from onyx.background.celery.celery_utils import make_probe_path from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.db.engine.tenant_utils import get_all_tenant_ids from onyx.server.runtime.onyx_runtime import OnyxRuntime from onyx.utils.variable_functionality import fetch_versioned_implementation from shared_configs.configs import IGNORED_SYNCING_TENANT_LIST from shared_configs.configs import MULTI_TENANT task_logger = get_task_logger(__name__) celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.beat") class DynamicTenantScheduler(PersistentScheduler): """This scheduler is useful because we can dynamically adjust task generation rates through it.""" RELOAD_INTERVAL = 60 def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.last_beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT self._reload_interval = timedelta( seconds=DynamicTenantScheduler.RELOAD_INTERVAL ) self._last_reload = self.app.now() - self._reload_interval # Let the parent class handle store initialization self.setup_schedule() task_logger.info( f"DynamicTenantScheduler initialized: 
reload_interval={self._reload_interval}" ) self._liveness_probe_path = make_probe_path("liveness", "beat@hostname") # do not set the initial schedule here because we don't have db access yet. # do it in beat_init after the db engine is initialized # An initial schedule is required ... otherwise, the scheduler will delay # for 5 minutes before calling tick() def setup_schedule(self) -> None: super().setup_schedule() def tick(self) -> float: retval = super().tick() now = self.app.now() if ( self._last_reload is None or (now - self._last_reload) > self._reload_interval ): task_logger.debug("Reload interval reached, initiating task update") self._liveness_probe_path.touch() try: self._try_updating_schedule() except (AttributeError, KeyError): task_logger.exception("Failed to process task configuration") except Exception: task_logger.exception("Unexpected error updating tasks") self._last_reload = now return retval def _generate_schedule( self, tenant_ids: list[str] | list[None], beat_multiplier: float ) -> dict[str, dict[str, Any]]: """Given a list of tenant id's, generates a new beat schedule for celery.""" new_schedule: dict[str, dict[str, Any]] = {} if MULTI_TENANT: # cloud tasks are system wide and thus only need to be on the beat schedule # once for all tenants get_cloud_tasks_to_schedule = fetch_versioned_implementation( "onyx.background.celery.tasks.beat_schedule", "get_cloud_tasks_to_schedule", ) cloud_tasks_to_schedule: list[dict[str, Any]] = get_cloud_tasks_to_schedule( beat_multiplier ) for task in cloud_tasks_to_schedule: task_name = task["name"] cloud_task = { "task": task["task"], "schedule": task["schedule"], "kwargs": task.get("kwargs", {}), } if options := task.get("options"): task_logger.debug(f"Adding options to task {task_name}: {options}") cloud_task["options"] = options new_schedule[task_name] = cloud_task # regular task beats are multiplied across all tenants # note that currently this just schedules for a single tenant in self hosted # and 
doesn't do anything in the cloud because it's much more scalable # to schedule a single cloud beat task to dispatch per tenant tasks. get_tasks_to_schedule = fetch_versioned_implementation( "onyx.background.celery.tasks.beat_schedule", "get_tasks_to_schedule" ) tasks_to_schedule: list[dict[str, Any]] = get_tasks_to_schedule() for tenant_id in tenant_ids: if IGNORED_SYNCING_TENANT_LIST and tenant_id in IGNORED_SYNCING_TENANT_LIST: task_logger.debug( f"Skipping tenant {tenant_id} as it is in the ignored syncing list" ) continue for task in tasks_to_schedule: task_name = task["name"] tenant_task_name = f"{task['name']}-{tenant_id}" task_logger.debug(f"Creating task configuration for {tenant_task_name}") tenant_task = { "task": task["task"], "schedule": task["schedule"], "kwargs": {"tenant_id": tenant_id}, } if options := task.get("options"): task_logger.debug( f"Adding options to task {tenant_task_name}: {options}" ) tenant_task["options"] = options new_schedule[tenant_task_name] = tenant_task return new_schedule def _try_updating_schedule(self) -> None: """Only updates the actual beat schedule on the celery app when it changes""" do_update = False task_logger.debug("_try_updating_schedule starting") tenant_ids = get_all_tenant_ids() task_logger.debug(f"Found {len(tenant_ids)} IDs") # get current schedule and extract current tenants current_schedule = self.schedule.items() # get potential new state try: beat_multiplier = OnyxRuntime.get_beat_multiplier() except Exception: beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT new_schedule = self._generate_schedule(tenant_ids, beat_multiplier) # if the schedule or beat multiplier has changed, update while True: if beat_multiplier != self.last_beat_multiplier: do_update = True break if not DynamicTenantScheduler._compare_schedules( current_schedule, new_schedule ): do_update = True break break if not do_update: # exit early if nothing changed task_logger.info( f"_try_updating_schedule - Schedule unchanged: 
tasks={len(new_schedule)} beat_multiplier={beat_multiplier}" ) return # schedule needs updating task_logger.debug( "Schedule update required", extra={ "new_tasks": len(new_schedule), "current_tasks": len(current_schedule), }, ) # Create schedule entries entries = {} for name, entry in new_schedule.items(): entries[name] = self.Entry( name=name, app=self.app, task=entry["task"], schedule=entry["schedule"], options=entry.get("options", {}), kwargs=entry.get("kwargs", {}), ) # Update the schedule using the scheduler's methods self.schedule.clear() self.schedule.update(entries) # Ensure changes are persisted self.sync() task_logger.info( f"_try_updating_schedule - Schedule updated: " f"prev_num_tasks={len(current_schedule)} " f"prev_beat_multiplier={self.last_beat_multiplier} " f"tasks={len(new_schedule)} " f"beat_multiplier={beat_multiplier}" ) self.last_beat_multiplier = beat_multiplier @staticmethod def _compare_schedules(schedule1: dict, schedule2: dict) -> bool: """Compare schedules by task name only to determine if an update is needed. True if equivalent, False if not.""" current_tasks = set(name for name, _ in schedule1) new_tasks = set(schedule2.keys()) return current_tasks == new_tasks @beat_init.connect def on_beat_init(sender: Any, **kwargs: Any) -> None: task_logger.info("beat_init signal received.") # Celery beat shouldn't touch the db at all. But just setting a low minimum here. 
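`_generate_schedule` expands every base beat task into one entry per tenant, named `<task name>-<tenant_id>`, and `_compare_schedules` decides whether to rebuild by comparing entry names only. The standalone sketch below mirrors that logic without Celery so the consequence is easy to see. The task names and the `generate_tenant_schedule` / `same_task_names` helpers are hypothetical, and the real method additionally merges the system-wide cloud tasks in multi-tenant mode.

```python
from typing import Any


def generate_tenant_schedule(
    tasks: list[dict[str, Any]],
    tenant_ids: list[str],
    ignored_tenants: set[str],
) -> dict[str, dict[str, Any]]:
    """Expand each base task into one entry per tenant, named '<task>-<tenant>'."""
    schedule: dict[str, dict[str, Any]] = {}
    for tenant_id in tenant_ids:
        if tenant_id in ignored_tenants:
            continue  # mirrors the IGNORED_SYNCING_TENANT_LIST skip
        for task in tasks:
            entry: dict[str, Any] = {
                "task": task["task"],
                "schedule": task["schedule"],
                "kwargs": {"tenant_id": tenant_id},
            }
            # options are copied only when present and non-empty
            if options := task.get("options"):
                entry["options"] = options
            schedule[f"{task['name']}-{tenant_id}"] = entry
    return schedule


def same_task_names(current: dict[str, Any], new: dict[str, Any]) -> bool:
    """Name-only comparison: interval or option changes are NOT detected."""
    return set(current) == set(new)


# Hypothetical base tasks (names and intervals are illustrative only)
base_tasks = [
    {"name": "check-for-pruning", "task": "check_for_pruning", "schedule": 60.0},
    {
        "name": "check-for-indexing",
        "task": "check_for_indexing",
        "schedule": 15.0,
        "options": {"priority": 3},
    },
]

old = generate_tenant_schedule(base_tasks, ["t1", "t2"], ignored_tenants={"t2"})
print(sorted(old))  # ['check-for-indexing-t1', 'check-for-pruning-t1']

# Same names, different interval: the name-only check reports "unchanged".
slower = [dict(t, schedule=t["schedule"] * 2) for t in base_tasks]
print(same_task_names(old, generate_tenant_schedule(slower, ["t1"], set())))  # True

# A new tenant adds entries, so an update is triggered.
print(same_task_names(old, generate_tenant_schedule(base_tasks, ["t1", "t3"], set())))  # False
```

Because only names are compared, changing a task's interval or options alone does not trigger a rebuild. In the scheduler, interval changes driven by the beat multiplier are instead caught by the separate `beat_multiplier != self.last_beat_multiplier` check.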
SqlEngine.set_app_name(POSTGRES_CELERY_BEAT_APP_NAME) SqlEngine.init_engine(pool_size=2, max_overflow=0) app_base.wait_for_redis(sender, **kwargs) path = make_probe_path("readiness", "beat@hostname") path.touch() task_logger.info(f"Readiness signal touched at {path}.") # first time init of the scheduler after db has been init'ed scheduler: DynamicTenantScheduler = sender.scheduler scheduler._try_updating_schedule() @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) celery_app.conf.beat_scheduler = DynamicTenantScheduler celery_app.conf.task_default_base = app_base.TenantAwareTask ================================================ FILE: backend/onyx/background/celery/apps/client.py ================================================ from celery import Celery import onyx.background.celery.apps.app_base as app_base celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.client") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] ================================================ FILE: backend/onyx/background/celery/apps/docfetching.py ================================================ from typing import Any from typing import cast from celery import Celery from celery import signals from celery import Task from celery.apps.worker import Worker from celery.signals import celeryd_init from celery.signals import worker_init from celery.signals import worker_ready from celery.signals import worker_shutdown import onyx.background.celery.apps.app_base as app_base from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun from onyx.server.metrics.celery_task_metrics 
import on_celery_task_rejected from onyx.server.metrics.celery_task_metrics import on_celery_task_retry from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun from onyx.server.metrics.metrics_server import start_metrics_server from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.docfetching") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) on_celery_task_prerun(task_id, task) on_indexing_task_prerun(task_id, task, kwargs) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, **kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) on_celery_task_postrun(task_id, task, state) on_indexing_task_postrun(task_id, task, kwargs, state) @signals.task_retry.connect def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001 # task_retry signal doesn't pass task_id in kwargs; get it from # the sender (the task instance) via sender.request.id. 
task_id = getattr(getattr(sender, "request", None), "id", None) on_celery_task_retry(task_id, sender) @signals.task_revoked.connect def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None: task_name = getattr(sender, "name", None) or str(sender) on_celery_task_revoked(kwargs.get("task_id"), task_name) @signals.task_rejected.connect def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001 # task_rejected sends the Consumer as sender, not the task instance. # The task name must be extracted from the Celery message headers. message = kwargs.get("message") task_name: str | None = None if message is not None: headers = getattr(message, "headers", None) or {} task_name = headers.get("task") if task_name is None: task_name = "unknown" on_celery_task_rejected(None, task_name) @celeryd_init.connect def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) @worker_init.connect def on_worker_init(sender: Worker, **kwargs: Any) -> None: logger.info("worker_init signal received.") SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME) pool_size = cast(int, sender.concurrency) # type: ignore SqlEngine.init_engine(pool_size=pool_size, max_overflow=8) app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) app_base.wait_for_vespa_or_shutdown(sender, **kwargs) # Fewer startup checks are needed in the multi-tenant case if MULTI_TENANT: return app_base.on_secondary_worker_init(sender, **kwargs) @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: start_metrics_server("docfetching") app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: 
app_base.on_setup_logging(loglevel, logfile,
format, colorize, **kwargs) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.docfetching", ] ) ) ================================================ FILE: backend/onyx/background/celery/apps/docprocessing.py ================================================ from typing import Any from typing import cast from celery import Celery from celery import signals from celery import Task from celery.apps.worker import Worker from celery.signals import celeryd_init from celery.signals import worker_init from celery.signals import worker_process_init from celery.signals import worker_ready from celery.signals import worker_shutdown import onyx.background.celery.apps.app_base as app_base from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected from onyx.server.metrics.celery_task_metrics import on_celery_task_retry from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun from onyx.server.metrics.metrics_server import start_metrics_server from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.docprocessing") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | 
None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) on_celery_task_prerun(task_id, task) on_indexing_task_prerun(task_id, task, kwargs) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, **kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) on_celery_task_postrun(task_id, task, state) on_indexing_task_postrun(task_id, task, kwargs, state) @signals.task_retry.connect def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001 # task_retry signal doesn't pass task_id in kwargs; get it from # the sender (the task instance) via sender.request.id. task_id = getattr(getattr(sender, "request", None), "id", None) on_celery_task_retry(task_id, sender) @signals.task_revoked.connect def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None: task_name = getattr(sender, "name", None) or str(sender) on_celery_task_revoked(kwargs.get("task_id"), task_name) @signals.task_rejected.connect def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001 # task_rejected sends the Consumer as sender, not the task instance. # The task name must be extracted from the Celery message headers. 
message = kwargs.get("message") task_name: str | None = None if message is not None: headers = getattr(message, "headers", None) or {} task_name = headers.get("task") if task_name is None: task_name = "unknown" on_celery_task_rejected(None, task_name) @celeryd_init.connect def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) @worker_init.connect def on_worker_init(sender: Worker, **kwargs: Any) -> None: logger.info("worker_init signal received.") SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME) # rkuo: Transient errors keep happening in the indexing watchdog threads. # "SSL connection has been closed unexpectedly" # Setting the spawn method in the cloud actually fixes 95% of these. # Enabling pool pre-ping might help even more, but we are not pursuing that yet. pool_size = cast(int, sender.concurrency) # type: ignore SqlEngine.init_engine(pool_size=pool_size, max_overflow=8) app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) app_base.wait_for_vespa_or_shutdown(sender, **kwargs) # Fewer startup checks are needed in the multi-tenant case if MULTI_TENANT: return app_base.on_secondary_worker_init(sender, **kwargs) @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: start_metrics_server("docprocessing") app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) # Note: worker_process_init only fires in prefork pool mode. Docprocessing uses # worker_pool="threads" (see configs/docprocessing.py), so this handler is # effectively a no-op in normal operation. It remains as a safety net in case # the pool type is ever changed to prefork. Prometheus metrics are safe in # thread-pool mode since all threads share the same process memory and can # update the same Counter/Gauge/Histogram objects directly.
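The docfetching and docprocessing workers recover task identity differently for each metrics signal because the senders differ: `task_retry` sends the task instance (so the id comes from `sender.request.id`), `task_revoked` falls back to `str(sender)` when there is no task name, and `task_rejected` sends the consumer, so the name has to be read from the message's `task` header. A minimal sketch of the same extraction rules, using `SimpleNamespace` objects as stand-ins for Celery's task, request, and message objects; the helper names are hypothetical.

```python
from __future__ import annotations

from types import SimpleNamespace
from typing import Any


def retry_task_id(sender: Any) -> str | None:
    # task_retry: the sender is the task instance; the id lives on sender.request.id.
    # The nested getattr tolerates a sender without a request object.
    return getattr(getattr(sender, "request", None), "id", None)


def revoked_task_name(sender: Any) -> str:
    # task_revoked: use the task name when available, else str(sender).
    return getattr(sender, "name", None) or str(sender)


def rejected_task_name(message: Any) -> str:
    # task_rejected: the sender is the consumer, so read the "task" header.
    if message is None:
        return "unknown"
    headers = getattr(message, "headers", None) or {}
    task_name = headers.get("task")
    return task_name if task_name is not None else "unknown"


# Stand-in objects (hypothetical task names)
task = SimpleNamespace(name="docfetching_task", request=SimpleNamespace(id="abc-123"))
print(retry_task_id(task))  # abc-123
print(retry_task_id(object()))  # None
print(rejected_task_name(SimpleNamespace(headers={"task": "docprocessing_task"})))  # docprocessing_task
print(rejected_task_name(SimpleNamespace(headers=None)))  # unknown
```

Falling back to `"unknown"` rather than raising keeps a malformed message from breaking metrics collection inside a signal handler.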
@worker_process_init.connect def init_worker(**kwargs: Any) -> None: # noqa: ARG001 SqlEngine.reset_engine() @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.docprocessing", ] ) ) ================================================ FILE: backend/onyx/background/celery/apps/heavy.py ================================================ from typing import Any from typing import cast from celery import Celery from celery import signals from celery import Task from celery.apps.worker import Worker from celery.signals import celeryd_init from celery.signals import worker_init from celery.signals import worker_ready from celery.signals import worker_shutdown import onyx.background.celery.apps.app_base as app_base from onyx.configs.constants import POSTGRES_CELERY_WORKER_HEAVY_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.heavy") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, 
**kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) @celeryd_init.connect def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) @worker_init.connect def on_worker_init(sender: Worker, **kwargs: Any) -> None: logger.info("worker_init signal received.") SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_HEAVY_APP_NAME) pool_size = cast(int, sender.concurrency) # type: ignore SqlEngine.init_engine(pool_size=pool_size, max_overflow=8) app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) app_base.wait_for_vespa_or_shutdown(sender, **kwargs) # Fewer startup checks are needed in the multi-tenant case if MULTI_TENANT: return app_base.on_secondary_worker_init(sender, **kwargs) @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.pruning", # Sandbox tasks (file sync, cleanup) "onyx.server.features.build.sandbox.tasks", "onyx.background.celery.tasks.hierarchyfetching", ] ) ) ================================================ FILE: backend/onyx/background/celery/apps/light.py ================================================ from typing import Any from celery import Celery from celery import signals from celery import Task from celery.apps.worker import Worker from celery.signals import celeryd_init from celery.signals import worker_init 
from
celery.signals import worker_ready from celery.signals import worker_shutdown import onyx.background.celery.apps.app_base as app_base from onyx.background.celery.celery_utils import httpx_init_vespa_pool from onyx.configs.app_configs import MANAGED_VESPA from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH from onyx.configs.constants import POSTGRES_CELERY_WORKER_LIGHT_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.light") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, **kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) @celeryd_init.connect def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) @worker_init.connect def on_worker_init(sender: Worker, **kwargs: Any) -> None: EXTRA_CONCURRENCY = 8 # small extra fudge factor for connection limits logger.info("worker_init signal received.") logger.info(f"Concurrency: {sender.concurrency}") # type: ignore SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_LIGHT_APP_NAME) SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=EXTRA_CONCURRENCY) # type: ignore if MANAGED_VESPA: 
httpx_init_vespa_pool( sender.concurrency + EXTRA_CONCURRENCY, # type: ignore ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH, ) else: httpx_init_vespa_pool(sender.concurrency + EXTRA_CONCURRENCY) # type: ignore app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) app_base.wait_for_vespa_or_shutdown(sender, **kwargs) # Fewer startup checks are needed in the multi-tenant case if MULTI_TENANT: return app_base.on_secondary_worker_init(sender, **kwargs) @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.shared", "onyx.background.celery.tasks.vespa", "onyx.background.celery.tasks.connector_deletion", "onyx.background.celery.tasks.doc_permission_syncing", "onyx.background.celery.tasks.docprocessing", "onyx.background.celery.tasks.opensearch_migration", # Sandbox cleanup tasks (isolated in build feature) "onyx.server.features.build.sandbox.tasks", ] ) ) ================================================ FILE: backend/onyx/background/celery/apps/monitoring.py ================================================ import multiprocessing from typing import Any from celery import Celery from celery import signals from celery import Task from celery.signals import celeryd_init from celery.signals import worker_init from celery.signals import worker_ready from celery.signals import worker_shutdown import onyx.background.celery.apps.app_base as app_base
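The light worker sizes its connection pools from the worker concurrency: the SQL engine gets `pool_size=concurrency` with `max_overflow=EXTRA_CONCURRENCY` (8), and the Vespa HTTP pool gets `concurrency + EXTRA_CONCURRENCY`. The sketch below does that arithmetic and sums it over a hypothetical fleet, which can help when checking a Postgres `max_connections` limit. It assumes `SqlEngine.init_engine` hands these values to a SQLAlchemy `QueuePool`, where peak usage per engine is `pool_size + max_overflow`; the concurrency values, the fleet, and the helper names are illustrative only.

```python
from dataclasses import dataclass

EXTRA_CONCURRENCY = 8  # same fudge factor the light worker uses


@dataclass(frozen=True)
class ConnectionBudget:
    db_pool_size: int
    db_max_overflow: int
    vespa_http_pool: int  # 0 means "not configured by this worker"

    @property
    def max_db_connections(self) -> int:
        # Assumption: a QueuePool keeps pool_size connections and may open up to
        # max_overflow extra ones at peak.
        return self.db_pool_size + self.db_max_overflow


def light_worker_budget(concurrency: int) -> ConnectionBudget:
    # Mirrors init_engine(pool_size=concurrency, max_overflow=EXTRA_CONCURRENCY)
    # and httpx_init_vespa_pool(concurrency + EXTRA_CONCURRENCY).
    return ConnectionBudget(
        db_pool_size=concurrency,
        db_max_overflow=EXTRA_CONCURRENCY,
        vespa_http_pool=concurrency + EXTRA_CONCURRENCY,
    )


def total_db_connections(budgets: list[ConnectionBudget]) -> int:
    # Worst case across every worker process that shares the database.
    return sum(b.max_db_connections for b in budgets)


budget = light_worker_budget(24)
print(budget.max_db_connections, budget.vespa_http_pool)  # 32 32

# Hypothetical fleet: heavy-style (max_overflow=8) and monitoring-style (max_overflow=3)
fleet = [
    light_worker_budget(24),
    ConnectionBudget(db_pool_size=4, db_max_overflow=8, vespa_http_pool=0),
    ConnectionBudget(db_pool_size=1, db_max_overflow=3, vespa_http_pool=0),
]
print(total_db_connections(fleet))  # 48
```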
from onyx.configs.constants import POSTGRES_CELERY_WORKER_MONITORING_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.monitoring") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, **kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) @celeryd_init.connect def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) # Set by on_worker_init so on_worker_ready knows whether to start the server. 
_prometheus_collectors_ok: bool = False @worker_init.connect def on_worker_init(sender: Any, **kwargs: Any) -> None: global _prometheus_collectors_ok logger.info("worker_init signal received.") logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}") SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_MONITORING_APP_NAME) SqlEngine.init_engine(pool_size=sender.concurrency, max_overflow=3) app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) _prometheus_collectors_ok = _setup_prometheus_collectors(sender) # Fewer startup checks are needed in the multi-tenant case if MULTI_TENANT: return app_base.on_secondary_worker_init(sender, **kwargs) def _setup_prometheus_collectors(sender: Any) -> bool: """Register Prometheus collectors that need Redis/DB access. Passes the Celery app so the queue depth collector can obtain a fresh broker Redis client on each scrape (rather than holding a stale reference). Returns True if registration succeeded, False otherwise. """ try: from onyx.server.metrics.indexing_pipeline_setup import ( setup_indexing_pipeline_metrics, ) setup_indexing_pipeline_metrics(sender.app) logger.info("Prometheus indexing pipeline collectors registered") return True except Exception: logger.exception("Failed to register Prometheus indexing pipeline collectors") return False @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: if _prometheus_collectors_ok: from onyx.server.metrics.metrics_server import start_metrics_server start_metrics_server("monitoring") else: logger.warning( "Skipping Prometheus metrics server — collector registration failed" ) app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: 
app_base.on_setup_logging(loglevel, logfile, format,
colorize, **kwargs) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.monitoring", ] ) ) ================================================ FILE: backend/onyx/background/celery/apps/primary.py ================================================ import logging import os from typing import Any from typing import cast from celery import bootsteps # type: ignore from celery import Celery from celery import signals from celery import Task from celery.apps.worker import Worker from celery.exceptions import WorkerShutdown from celery.result import AsyncResult from celery.signals import celeryd_init from celery.signals import worker_init from celery.signals import worker_ready from celery.signals import worker_shutdown from redis.lock import Lock as RedisLock import onyx.background.celery.apps.app_base as app_base from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_utils import celery_is_worker_primary from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_POOL_OVERFLOW from onyx.configs.constants import CELERY_PRIMARY_WORKER_LOCK_TIMEOUT from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import SqlEngine from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import mark_attempt_canceled from onyx.db.indexing_coordination import IndexingCoordination from onyx.redis.redis_connector_delete import RedisConnectorDelete from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_ext_group_sync import 
RedisConnectorExternalGroupSync from onyx.redis.redis_connector_prune import RedisConnectorPrune from onyx.redis.redis_connector_stop import RedisConnectorStop from onyx.redis.redis_document_set import RedisDocumentSet from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_usergroup import RedisUserGroup from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.primary") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, **kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) @celeryd_init.connect def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) @worker_init.connect def on_worker_init(sender: Worker, **kwargs: Any) -> None: logger.info("worker_init signal received.") SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME) pool_size = cast(int, sender.concurrency) # type: ignore SqlEngine.init_engine( pool_size=pool_size, max_overflow=CELERY_WORKER_PRIMARY_POOL_OVERFLOW ) app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) app_base.wait_for_vespa_or_shutdown(sender, **kwargs) logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
# Fewer startup checks are needed in the multi-tenant case if MULTI_TENANT: return # This is singleton work that should be done on startup exactly once # by the primary worker. This is unnecessary in the multi-tenant scenario r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA) # Log the role and slave count - being connected to a slave or slave count > 0 could be problematic replication_info: dict[str, Any] = cast(dict, r.info("replication")) role: str = cast(str, replication_info.get("role", "")) connected_slaves: int = replication_info.get("connected_slaves", 0) logger.info( f"Redis INFO REPLICATION: role={role} connected_slaves={connected_slaves}" ) memory_info: dict[str, Any] = cast(dict, r.info("memory")) maxmemory_policy: str = cast(str, memory_info.get("maxmemory_policy", "")) logger.info(f"Redis INFO MEMORY: maxmemory_policy={maxmemory_policy}") # For the moment, we're assuming that we are the only primary worker # that should be running. # TODO: maybe check for or clean up another zombie primary worker if we detect it r.delete(OnyxRedisLocks.PRIMARY_WORKER) # this process-wide lock is taken to help other workers start up in order. # it is planned to use this lock to enforce singleton behavior on the primary # worker, since the primary worker does redis cleanup on startup, but this isn't # implemented yet.
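On startup the primary worker takes a Redis lock with a timeout, and the `HubPeriodicTask` bootstep refreshes it every `CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8` seconds, falling back to a full acquisition when the lock is no longer owned. The sketch below models that lifecycle with a hypothetical in-memory store and lock class (not redis-py) and a manually advanced clock, to show why frequent refreshes keep a single owner and why missed refreshes let the key expire. The timeout value is an assumption; the real code uses redis-py's `Lock` with `acquire(blocking_timeout=...)`, `owned()`, and `reacquire()`.

```python
from __future__ import annotations

import uuid
from dataclasses import dataclass, field


@dataclass
class FakeLockStore:
    """Hypothetical in-memory stand-in for Redis keys with a TTL."""

    now: float = 0.0
    keys: dict[str, tuple[str, float]] = field(default_factory=dict)  # name -> (token, expires_at)

    def live(self, name: str) -> tuple[str, float] | None:
        item = self.keys.get(name)
        if item is not None and item[1] <= self.now:
            del self.keys[name]  # expired, like a Redis key whose TTL ran out
            return None
        return item


class TTLLock:
    """Mimics the acquire/owned/reacquire shape of a Redis lock, in memory."""

    def __init__(self, store: FakeLockStore, name: str, timeout: float) -> None:
        self.store, self.name, self.timeout = store, name, timeout
        self.token = uuid.uuid4().hex  # identifies this owner

    def acquire(self) -> bool:
        # Non-blocking for simplicity; the real code waits up to timeout / 2.
        if self.store.live(self.name) is not None:
            return False
        self.store.keys[self.name] = (self.token, self.store.now + self.timeout)
        return True

    def owned(self) -> bool:
        item = self.store.live(self.name)
        return item is not None and item[0] == self.token

    def reacquire(self) -> None:
        # Reset the TTL to the full timeout, as the periodic hub task does.
        if not self.owned():
            raise RuntimeError("cannot reacquire a lock we do not own")
        self.store.keys[self.name] = (self.token, self.store.now + self.timeout)


TIMEOUT = 120.0  # hypothetical value for CELERY_PRIMARY_WORKER_LOCK_TIMEOUT
store = FakeLockStore()
primary = TTLLock(store, "primary_worker", TIMEOUT)
assert primary.acquire()

# Refreshing every TIMEOUT / 8 keeps the lock alive indefinitely...
for _ in range(20):
    store.now += TIMEOUT / 8
    primary.reacquire()
print(primary.owned())  # True

# ...while a second would-be primary cannot take it.
print(TTLLock(store, "primary_worker", TIMEOUT).acquire())  # False

# If refreshes stop for longer than TIMEOUT, the key expires.
store.now += TIMEOUT + 1
print(primary.owned())  # False
```

Refreshing at one eighth of the timeout leaves room for several missed ticks (for example, a busy event loop) before the key can expire.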
# set thread_local=False since we don't control what thread the periodic task might # reacquire the lock with lock: RedisLock = r.lock( OnyxRedisLocks.PRIMARY_WORKER, timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, thread_local=False, ) logger.info("Primary worker lock: Acquire starting.") acquired = lock.acquire(blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2) if acquired: logger.info("Primary worker lock: Acquire succeeded.") else: logger.error("Primary worker lock: Acquire failed!") raise WorkerShutdown("Primary worker lock could not be acquired!") # tacking on our own user data to the sender sender.primary_worker_lock = lock # type: ignore # As currently designed, when this worker starts as "primary", we reinitialize redis # to a clean state (for our purposes, anyway) r.delete(OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK) r.delete(OnyxRedisConstants.ACTIVE_FENCES) # NOTE: we want to remove the `Redis*` classes, prefer to just have functions # This is the preferred way to do this going forward reset_document_sync(r) RedisDocumentSet.reset_all(r) RedisUserGroup.reset_all(r) RedisConnectorDelete.reset_all(r) RedisConnectorPrune.reset_all(r) RedisConnectorStop.reset_all(r) RedisConnectorPermissionSync.reset_all(r) RedisConnectorExternalGroupSync.reset_all(r) # mark orphaned index attempts as failed # This uses database coordination instead of Redis fencing with get_session_with_current_tenant() as db_session: # Get potentially orphaned attempts (those with active status and task IDs) potentially_orphaned_ids = IndexingCoordination.get_orphaned_index_attempt_ids( db_session ) for attempt_id in potentially_orphaned_ids: attempt = get_index_attempt(db_session, attempt_id) # handle case where not started or docfetching is done but indexing is not if ( not attempt or not attempt.celery_task_id or attempt.total_batches is not None ): continue # Check if the Celery task actually exists try: result: AsyncResult = AsyncResult(attempt.celery_task_id) # If the task is not in 
PENDING state, it exists in Celery if result.state != "PENDING": continue # Task is orphaned - mark as failed failure_reason = ( f"Orphaned index attempt found on startup - Celery task not found: " f"index_attempt={attempt.id} " f"cc_pair={attempt.connector_credential_pair_id} " f"search_settings={attempt.search_settings_id} " f"celery_task_id={attempt.celery_task_id}" ) logger.warning(failure_reason) mark_attempt_canceled(attempt.id, db_session, failure_reason) except Exception: # If we can't check the task status, be conservative and continue logger.warning( f"Could not verify Celery task status on startup for attempt {attempt.id}, task_id={attempt.celery_task_id}" ) @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) # this can be spammy, so just enable it in the cloud for now if MULTI_TENANT: app_base.set_task_finished_log_level(logging.INFO) class HubPeriodicTask(bootsteps.StartStopStep): """Regularly reacquires the primary worker lock outside of the task queue. Use the task_logger in this class to avoid double logging. This cannot be done inside a regular beat task because it must run on schedule and a queue of existing work would starve the task from running. 
""" # it's unclear to me whether using the hub's timer or the bootstep timer is better requires = {"celery.worker.components:Hub"} def __init__(self, worker: Any, **kwargs: Any) -> None: # noqa: ARG002 self.interval = CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 8 # Interval in seconds self.task_tref = None def start(self, worker: Any) -> None: if not celery_is_worker_primary(worker): return # Access the worker's event loop (hub) hub = worker.consumer.controller.hub # Schedule the periodic task self.task_tref = hub.call_repeatedly( self.interval, self.run_periodic_task, worker ) task_logger.info("Scheduled periodic task with hub.") def run_periodic_task(self, worker: Any) -> None: try: if not celery_is_worker_primary(worker): return if not hasattr(worker, "primary_worker_lock"): return lock: RedisLock = worker.primary_worker_lock r = get_redis_client(tenant_id=POSTGRES_DEFAULT_SCHEMA) if lock.owned(): task_logger.debug("Reacquiring primary worker lock.") lock.reacquire() else: task_logger.warning( "Full acquisition of primary worker lock. Reasons could be worker restart or lock expiration." 
) lock = r.lock( OnyxRedisLocks.PRIMARY_WORKER, timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT, ) task_logger.info("Primary worker lock: Acquire starting.") acquired = lock.acquire( blocking_timeout=CELERY_PRIMARY_WORKER_LOCK_TIMEOUT / 2 ) if acquired: task_logger.info("Primary worker lock: Acquire succeeded.") worker.primary_worker_lock = lock else: task_logger.error("Primary worker lock: Acquire failed!") raise TimeoutError("Primary worker lock could not be acquired!") except Exception: task_logger.exception("Periodic task failed.") def stop(self, worker: Any) -> None: # noqa: ARG002 # Cancel the scheduled task when the worker stops if self.task_tref: self.task_tref.cancel() task_logger.info("Canceled periodic task with hub.") celery_app.steps["worker"].add(HubPeriodicTask) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.connector_deletion", "onyx.background.celery.tasks.docprocessing", "onyx.background.celery.tasks.evals", "onyx.background.celery.tasks.hierarchyfetching", "onyx.background.celery.tasks.periodic", "onyx.background.celery.tasks.pruning", "onyx.background.celery.tasks.shared", "onyx.background.celery.tasks.vespa", "onyx.background.celery.tasks.llm_model_update", "onyx.background.celery.tasks.user_file_processing", ] ) ) ================================================ FILE: backend/onyx/background/celery/apps/task_formatters.py ================================================ import logging from celery import current_task from onyx.utils.logger import ColoredFormatter from onyx.utils.logger import PlainFormatter class CeleryTaskPlainFormatter(PlainFormatter): def format(self, record: logging.LogRecord) -> str: task = current_task if task and task.request: record.__dict__.update(task_id=task.request.id, task_name=task.name) record.msg = f"[{task.name}({task.request.id})] {record.msg}" return 
super().format(record) class CeleryTaskColoredFormatter(ColoredFormatter): def format(self, record: logging.LogRecord) -> str: task = current_task if task and task.request: record.__dict__.update(task_id=task.request.id, task_name=task.name) record.msg = f"[{task.name}({task.request.id})] {record.msg}" return super().format(record) ================================================ FILE: backend/onyx/background/celery/apps/user_file_processing.py ================================================ from typing import Any from typing import cast from celery import Celery from celery import signals from celery import Task from celery.apps.worker import Worker from celery.signals import celeryd_init from celery.signals import worker_init from celery.signals import worker_process_init from celery.signals import worker_ready from celery.signals import worker_shutdown import onyx.background.celery.apps.app_base as app_base from onyx.configs.constants import POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() celery_app = Celery(__name__) celery_app.config_from_object("onyx.background.celery.configs.user_file_processing") celery_app.Task = app_base.TenantAwareTask # type: ignore [misc] @signals.task_prerun.connect def on_task_prerun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, **kwds: Any, ) -> None: app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds) @signals.task_postrun.connect def on_task_postrun( sender: Any | None = None, task_id: str | None = None, task: Task | None = None, args: tuple | None = None, kwargs: dict | None = None, retval: Any | None = None, state: str | None = None, **kwds: Any, ) -> None: app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds) 
@celeryd_init.connect def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None: app_base.on_celeryd_init(sender, conf, **kwargs) @worker_init.connect def on_worker_init(sender: Worker, **kwargs: Any) -> None: logger.info("worker_init signal received.") SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME) # rkuo: Transient errors keep happening in the indexing watchdog threads. # "SSL connection has been closed unexpectedly" # actually setting the spawn method in the cloud fixes 95% of these. # setting pre ping might help even more, but not worrying about that yet pool_size = cast(int, sender.concurrency) # type: ignore SqlEngine.init_engine(pool_size=pool_size, max_overflow=8) app_base.wait_for_redis(sender, **kwargs) app_base.wait_for_db(sender, **kwargs) app_base.wait_for_vespa_or_shutdown(sender, **kwargs) # Fewer startup checks in multi-tenant case if MULTI_TENANT: return app_base.on_secondary_worker_init(sender, **kwargs) @worker_ready.connect def on_worker_ready(sender: Any, **kwargs: Any) -> None: app_base.on_worker_ready(sender, **kwargs) @worker_shutdown.connect def on_worker_shutdown(sender: Any, **kwargs: Any) -> None: app_base.on_worker_shutdown(sender, **kwargs) @worker_process_init.connect def init_worker(**kwargs: Any) -> None: # noqa: ARG001 SqlEngine.reset_engine() @signals.setup_logging.connect def on_setup_logging( loglevel: Any, logfile: Any, format: Any, colorize: Any, **kwargs: Any ) -> None: app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs) base_bootsteps = app_base.get_bootsteps() for bootstep in base_bootsteps: celery_app.steps["worker"].add(bootstep) celery_app.autodiscover_tasks( app_base.filter_task_modules( [ "onyx.background.celery.tasks.user_file_processing", ] ) ) ================================================ FILE: backend/onyx/background/celery/celery_k8s_probe.py ================================================ # script to use as a kubernetes readiness / liveness
probe import argparse import sys import time from pathlib import Path def main_readiness(filename: str) -> int: """Checks if the file exists.""" path = Path(filename) if not path.is_file(): return 1 return 0 def main_liveness(filename: str) -> int: """Checks if the file exists AND was recently modified.""" path = Path(filename) if not path.is_file(): return 1 stats = path.stat() liveness_timestamp = stats.st_mtime current_timestamp = time.time() time_diff = current_timestamp - liveness_timestamp if time_diff > 60: return 1 return 0 if __name__ == "__main__": exit_code: int parser = argparse.ArgumentParser(description="k8s readiness/liveness probe") parser.add_argument( "--probe", type=str, choices=["readiness", "liveness"], help="The type of probe", required=True, ) parser.add_argument("--filename", help="The filename to watch", required=True) args = parser.parse_args() if args.probe == "readiness": exit_code = main_readiness(args.filename) elif args.probe == "liveness": exit_code = main_liveness(args.filename) else: raise ValueError(f"Unknown probe type: {args.probe}") sys.exit(exit_code) ================================================ FILE: backend/onyx/background/celery/celery_redis.py ================================================ # These are helper objects for tracking the keys we need to write in redis import json import threading from typing import Any from typing import cast from celery import Celery from redis import Redis from onyx.background.celery.configs.base import CELERY_SEPARATOR from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS _broker_client: Redis | None = None _broker_url: str | None = None _broker_client_lock = threading.Lock() def celery_get_broker_client(app: Celery) -> Redis: """Return a shared Redis client connected to the Celery broker DB. 
Uses a module-level singleton so all tasks on a worker share one connection instead of creating a new one per call. The client connects directly to the broker Redis DB (parsed from the broker URL). Thread-safe via lock, so it is safe for use in Celery thread-pool workers. Usage: r_celery = celery_get_broker_client(self.app) length = celery_get_queue_length(queue, r_celery) """ global _broker_client, _broker_url with _broker_client_lock: url = app.conf.broker_url if _broker_client is not None and _broker_url == url: try: _broker_client.ping() return _broker_client except Exception: try: _broker_client.close() except Exception: pass _broker_client = None elif _broker_client is not None: try: _broker_client.close() except Exception: pass _broker_client = None _broker_url = url _broker_client = Redis.from_url( url, decode_responses=False, health_check_interval=REDIS_HEALTH_CHECK_INTERVAL, socket_keepalive=True, socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS, retry_on_timeout=True, ) return _broker_client def celery_get_unacked_length(r: Redis) -> int: """Checking the unacked queue is useful because a non-zero length tells us there may be prefetched tasks. There can be other tasks in here besides indexing tasks, so this is mostly useful just to see if the task count is non-zero. ref: https://blog.hikaru.run/2022/08/29/get-waiting-tasks-count-in-celery.html """ length = cast(int, r.hlen("unacked")) return length def celery_get_unacked_task_ids(queue: str, r: Redis) -> set[str]: """Gets the set of task IDs matching the given queue in the unacked hash. Unacked entries belonging to the indexing queues are "prefetched", so this gives us crucial visibility as to what tasks are in that state.
""" tasks: set[str] = set() for _, v in r.hscan_iter("unacked"): v_bytes = cast(bytes, v) v_str = v_bytes.decode("utf-8") task = json.loads(v_str) task_description = task[0] task_queue = task[2] if task_queue != queue: continue task_id = task_description.get("headers", {}).get("id") if not task_id: continue # if the queue matches and we see the task_id, add it tasks.add(task_id) return tasks def celery_get_queue_length(queue: str, r: Redis) -> int: """This is a redis specific way to get the length of a celery queue. It is priority aware and knows how to count across the multiple redis lists used to implement task prioritization. This operation is not atomic.""" total_length = 0 for i in range(len(OnyxCeleryPriority)): queue_name = queue if i > 0: queue_name += CELERY_SEPARATOR queue_name += str(i) length = r.llen(queue_name) total_length += cast(int, length) return total_length def celery_find_task(task_id: str, queue: str, r: Redis) -> int: """This is a redis specific way to find a task for a particular queue in redis. It is priority aware and knows how to look through the multiple redis lists used to implement task prioritization. This operation is not atomic. This is a linear search O(n) ... so be careful using it when the task queues can be larger. Returns true if the id is in the queue, False if not. """ for priority in range(len(OnyxCeleryPriority)): queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue tasks = cast(list[bytes], r.lrange(queue_name, 0, -1)) for task in tasks: task_dict: dict[str, Any] = json.loads(task.decode("utf-8")) if task_dict.get("headers", {}).get("id") == task_id: return True return False def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]: """This is a redis specific way to build a list of tasks in a queue and return them as a set. This helps us read the queue once and then efficiently look for missing tasks in the queue. 
""" task_set: set[str] = set() for priority in range(len(OnyxCeleryPriority)): queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue tasks = cast(list[bytes], r.lrange(queue_name, 0, -1)) for task in tasks: task_dict: dict[str, Any] = json.loads(task.decode("utf-8")) task_id = task_dict.get("headers", {}).get("id") if task_id: task_set.add(task_id) return task_set def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]: """Returns a list of current workers containing name_filter, or all workers if name_filter is None. We've empirically discovered that the celery inspect API is potentially unstable and may hang or return empty results when celery is under load. Suggest using this more to debug and troubleshoot than in production code. """ worker_names: list[str] = [] # filter for and create an indexing specific inspect object inspect = app.control.inspect() workers: dict[str, Any] = inspect.ping() # type: ignore if workers: for worker_name in list(workers.keys()): # if the name filter not set, return all worker names if not name_filter: worker_names.append(worker_name) continue # if the name filter is set, return only worker names that contain the name filter if name_filter not in worker_name: continue worker_names.append(worker_name) return worker_names def celery_inspect_get_reserved(worker_names: list[str], app: Celery) -> set[str]: """Returns a list of reserved tasks on the specified workers. We've empirically discovered that the celery inspect API is potentially unstable and may hang or return empty results when celery is under load. Suggest using this more to debug and troubleshoot than in production code. 
""" reserved_task_ids: set[str] = set() inspect = app.control.inspect(destination=worker_names) # get the list of reserved tasks reserved_tasks: dict[str, list] | None = inspect.reserved() # type: ignore if reserved_tasks: for _, task_list in reserved_tasks.items(): for task in task_list: reserved_task_ids.add(task["id"]) return reserved_task_ids def celery_inspect_get_active(worker_names: list[str], app: Celery) -> set[str]: """Returns a list of active tasks on the specified workers. We've empirically discovered that the celery inspect API is potentially unstable and may hang or return empty results when celery is under load. Suggest using this more to debug and troubleshoot than in production code. """ active_task_ids: set[str] = set() inspect = app.control.inspect(destination=worker_names) # get the list of reserved tasks active_tasks: dict[str, list] | None = inspect.active() # type: ignore if active_tasks: for _, task_list in active_tasks.items(): for task in task_list: active_task_ids.add(task["id"]) return active_task_ids ================================================ FILE: backend/onyx/background/celery/celery_utils.py ================================================ from collections.abc import Generator from collections.abc import Iterator from collections.abc import Sequence from datetime import datetime from datetime import timezone from pathlib import Path from typing import Any from typing import cast from typing import TypeVar import httpx from pydantic import BaseModel from onyx.configs.app_configs import MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT from onyx.connectors.connector_runner import CheckpointOutputWrapper from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import ConnectorCheckpoint from 
onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SlimConnector from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.httpx.httpx_pool import HttpxPool from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() CT = TypeVar("CT", bound=ConnectorCheckpoint) class SlimConnectorExtractionResult(BaseModel): """Result of extracting document IDs and hierarchy nodes from a connector. raw_id_to_parent maps document ID → parent_hierarchy_raw_node_id (or None). Use raw_id_to_parent.keys() wherever the old set of IDs was needed. """ raw_id_to_parent: dict[str, str | None] hierarchy_nodes: list[HierarchyNode] def _checkpointed_batched_items( connector: CheckpointedConnector[CT], start: float, end: float, ) -> Generator[list[Document | HierarchyNode | ConnectorFailure], None, None]: """Loop through all checkpoint steps and yield batched items. Some checkpointed connectors (e.g. IMAP) are multi-step: the first checkpoint call may only initialize internal state without yielding any documents. This function loops until checkpoint.has_more is False to ensure all items are collected across every step. 
""" checkpoint = connector.build_dummy_checkpoint() while True: checkpoint_output = connector.load_from_checkpoint( start=start, end=end, checkpoint=checkpoint ) wrapper: CheckpointOutputWrapper[CT] = CheckpointOutputWrapper() batch: list[Document | HierarchyNode | ConnectorFailure] = [] for document, hierarchy_node, failure, next_checkpoint in wrapper( checkpoint_output ): if document is not None: batch.append(document) elif hierarchy_node is not None: batch.append(hierarchy_node) elif failure is not None: batch.append(failure) if next_checkpoint is not None: checkpoint = next_checkpoint if batch: yield batch if not checkpoint.has_more: break def _get_failure_id(failure: ConnectorFailure) -> str | None: """Extract the document/entity ID from a ConnectorFailure.""" if failure.failed_document: return failure.failed_document.document_id if failure.failed_entity: return failure.failed_entity.entity_id return None class BatchResult(BaseModel): raw_id_to_parent: dict[str, str | None] hierarchy_nodes: list[HierarchyNode] def _extract_from_batch( doc_list: Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure], ) -> BatchResult: """Separate a batch into document IDs (with parent mapping) and hierarchy nodes. ConnectorFailure items have their failed document/entity IDs added to the ID dict so that failed-to-retrieve documents are not accidentally pruned. 
""" ids: dict[str, str | None] = {} hierarchy_nodes: list[HierarchyNode] = [] for item in doc_list: if isinstance(item, HierarchyNode): hierarchy_nodes.append(item) elif isinstance(item, ConnectorFailure): failed_id = _get_failure_id(item) if failed_id: ids[failed_id] = None logger.warning( f"Failed to retrieve document {failed_id}: {item.failure_message}" ) else: ids[item.id] = item.parent_hierarchy_raw_node_id return BatchResult(raw_id_to_parent=ids, hierarchy_nodes=hierarchy_nodes) def extract_ids_from_runnable_connector( runnable_connector: BaseConnector, callback: IndexingHeartbeatInterface | None = None, ) -> SlimConnectorExtractionResult: """ Extract document IDs and hierarchy nodes from a runnable connector. Hierarchy nodes yielded alongside documents/slim docs are collected and returned in the result. ConnectorFailure items have their IDs preserved so that failed-to-retrieve documents are not accidentally pruned. Optionally, a callback can be passed to handle the length of each document batch. """ all_raw_id_to_parent: dict[str, str | None] = {} all_hierarchy_nodes: list[HierarchyNode] = [] # Sequence (covariant) lets all the specific list[...] 
iterator types unify here raw_batch_generator: ( Iterator[Sequence[Document | SlimDocument | HierarchyNode | ConnectorFailure]] | None ) = None if isinstance(runnable_connector, SlimConnector): raw_batch_generator = runnable_connector.retrieve_all_slim_docs() elif isinstance(runnable_connector, SlimConnectorWithPermSync): raw_batch_generator = runnable_connector.retrieve_all_slim_docs_perm_sync() # If the connector isn't slim, fall back to running it normally to get ids elif isinstance(runnable_connector, LoadConnector): raw_batch_generator = runnable_connector.load_from_state() elif isinstance(runnable_connector, PollConnector): start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() raw_batch_generator = runnable_connector.poll_source(start=start, end=end) elif isinstance(runnable_connector, CheckpointedConnector): start = datetime(1970, 1, 1, tzinfo=timezone.utc).timestamp() end = datetime.now(timezone.utc).timestamp() raw_batch_generator = _checkpointed_batched_items( runnable_connector, start, end ) else: raise RuntimeError("Pruning job could not find a valid runnable_connector.") # this function is called per batch for rate limiting doc_batch_processing_func = ( rate_limit_builder( max_calls=MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE, period=60 )(lambda x: x) if MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE else lambda x: x ) # process raw batches to extract both IDs and hierarchy nodes for doc_list in raw_batch_generator: if callback and callback.should_stop(): raise RuntimeError( "extract_ids_from_runnable_connector: Stop signal detected" ) batch_result = _extract_from_batch(doc_list) batch_ids = batch_result.raw_id_to_parent batch_nodes = batch_result.hierarchy_nodes doc_batch_processing_func(batch_ids) all_raw_id_to_parent.update(batch_ids) all_hierarchy_nodes.extend(batch_nodes) if callback: callback.progress("extract_ids_from_runnable_connector", len(batch_ids)) return SlimConnectorExtractionResult( 
raw_id_to_parent=all_raw_id_to_parent, hierarchy_nodes=all_hierarchy_nodes, ) def celery_is_listening_to_queue(worker: Any, name: str) -> bool: """Checks to see if we're listening to the named queue""" # how to get a list of queues this worker is listening to # https://stackoverflow.com/questions/29790523/how-to-determine-which-queues-a-celery-worker-is-consuming-at-runtime queue_names = list(worker.app.amqp.queues.consume_from.keys()) for queue_name in queue_names: if queue_name == name: return True return False def celery_is_worker_primary(worker: Any) -> bool: """There are multiple approaches that could be taken to determine if a celery worker is 'primary', as defined by us. But the way we do it is to check the hostname set for the celery worker, which can be done on the command line with '--hostname'.""" hostname = worker.hostname if hostname.startswith("primary"): return True return False def httpx_init_vespa_pool( max_keepalive_connections: int, timeout: int = VESPA_REQUEST_TIMEOUT, ssl_cert: str | None = None, ssl_key: str | None = None, ) -> None: httpx_cert = None httpx_verify = False if ssl_cert and ssl_key: httpx_cert = cast(tuple[str, str], (ssl_cert, ssl_key)) httpx_verify = True HttpxPool.init_client( name="vespa", cert=httpx_cert, verify=httpx_verify, timeout=timeout, http2=False, limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections), ) def make_probe_path(probe: str, hostname: str) -> Path: """templates the path for a k8s probe file. e.g. /tmp/onyx_k8s_indexing_readiness.txt """ hostname_parts = hostname.split("@") if len(hostname_parts) != 2: raise ValueError(f"hostname could not be split! {hostname=}") name = hostname_parts[0] if not name: raise ValueError(f"name cannot be empty! 
{name=}") safe_name = "".join(c for c in name if c.isalnum()).rstrip() return Path(f"/tmp/onyx_k8s_{safe_name}_{probe}.txt") ================================================ FILE: backend/onyx/background/celery/configs/base.py ================================================ # docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html import urllib.parse from onyx.configs.app_configs import CELERY_BROKER_POOL_LIMIT from onyx.configs.app_configs import CELERY_RESULT_EXPIRES from onyx.configs.app_configs import REDIS_DB_NUMBER_CELERY from onyx.configs.app_configs import REDIS_DB_NUMBER_CELERY_RESULT_BACKEND from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PASSWORD from onyx.configs.app_configs import REDIS_PORT from onyx.configs.app_configs import REDIS_SSL from onyx.configs.app_configs import REDIS_SSL_CA_CERTS from onyx.configs.app_configs import REDIS_SSL_CERT_REQS from onyx.configs.app_configs import USE_REDIS_IAM_AUTH from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS CELERY_SEPARATOR = ":" CELERY_PASSWORD_PART = "" if REDIS_PASSWORD: CELERY_PASSWORD_PART = ":" + urllib.parse.quote(REDIS_PASSWORD, safe="") + "@" REDIS_SCHEME = "redis" # SSL-specific query parameters for Redis URL SSL_QUERY_PARAMS = "" if REDIS_SSL and not USE_REDIS_IAM_AUTH: REDIS_SCHEME = "rediss" SSL_QUERY_PARAMS = f"?ssl_cert_reqs={REDIS_SSL_CERT_REQS}" if REDIS_SSL_CA_CERTS: SSL_QUERY_PARAMS += f"&ssl_ca_certs={REDIS_SSL_CA_CERTS}" # region Broker settings # example celery_broker_url: "redis://:password@localhost:6379/15" broker_url = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY}{SSL_QUERY_PARAMS}" broker_connection_retry_on_startup = True broker_pool_limit = CELERY_BROKER_POOL_LIMIT # redis broker settings # 
https://docs.celeryq.dev/projects/kombu/en/stable/reference/kombu.transport.redis.html broker_transport_options = { "priority_steps": list(range(len(OnyxCeleryPriority))), "sep": CELERY_SEPARATOR, "queue_order_strategy": "priority", "retry_on_timeout": True, "health_check_interval": REDIS_HEALTH_CHECK_INTERVAL, "socket_keepalive": True, "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS, } # endregion # redis backend settings # https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings # there doesn't appear to be a way to set socket_keepalive_options on the redis result backend redis_socket_keepalive = True redis_retry_on_timeout = True redis_backend_health_check_interval = REDIS_HEALTH_CHECK_INTERVAL task_default_priority = OnyxCeleryPriority.MEDIUM task_acks_late = True # region Task result backend settings # It's possible we don't even need celery's result backend, in which case all of the optimization below # might be irrelevant result_backend = f"{REDIS_SCHEME}://{CELERY_PASSWORD_PART}{REDIS_HOST}:{REDIS_PORT}/{REDIS_DB_NUMBER_CELERY_RESULT_BACKEND}{SSL_QUERY_PARAMS}" result_expires = CELERY_RESULT_EXPIRES # 86400 seconds is the default # endregion # Leaving this to the default of True may cause double logging since both our own app # and celery think they are controlling the logger. # TODO: Configure celery's logger entirely manually and set this to False # worker_hijack_root_logger = False # region Notes on serialization performance # Option 0: Defaults (json serializer, no compression) # about 1.5 KB per queued task. 1KB in queue, 400B for result, 100 as a child entry in generator result # Option 1: Reduces generator task result sizes by roughly 20% # task_compression = "bzip2" # task_serializer = "pickle" # result_compression = "bzip2" # result_serializer = "pickle" # accept_content=["pickle"] # Option 2: this significantly reduces the size of the result for generator tasks since the list of children # can be large. 
small tasks change very little # def pickle_bz2_encoder(data): # return bz2.compress(pickle.dumps(data)) # def pickle_bz2_decoder(data): # return pickle.loads(bz2.decompress(data)) # from kombu import serialization # To register custom serialization with Celery/Kombu # serialization.register('pickle-bzip2', pickle_bz2_encoder, pickle_bz2_decoder, 'application/x-pickle-bz2', 'binary') # task_serializer = "pickle-bzip2" # result_serializer = "pickle-bzip2" # accept_content=["pickle", "pickle-bzip2"] # endregion ================================================ FILE: backend/onyx/background/celery/configs/beat.py ================================================ # docs: https://docs.celeryq.dev/en/stable/userguide/configuration.html import onyx.background.celery.configs.base as shared_config broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default ================================================ FILE: backend/onyx/background/celery/configs/client.py ================================================ import onyx.background.celery.configs.base as shared_config broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = 
shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late ================================================ FILE: backend/onyx/background/celery/configs/docfetching.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_DOCFETCHING_CONCURRENCY broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late # Docfetching worker configuration worker_concurrency = CELERY_WORKER_DOCFETCHING_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 ================================================ FILE: backend/onyx/background/celery/configs/docprocessing.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_DOCPROCESSING_CONCURRENCY broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = 
shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late # Indexing worker specific ... this lets us track the transition to STARTED in redis # We don't currently rely on this but it has the potential to be useful and # indexing tasks are not high volume # we don't turn this on yet because celery occasionally runs tasks more than once # which means a duplicate run might change the task state unexpectedly # task_track_started = True worker_concurrency = CELERY_WORKER_DOCPROCESSING_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 ================================================ FILE: backend/onyx/background/celery/configs/heavy.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_HEAVY_CONCURRENCY broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late worker_concurrency = CELERY_WORKER_HEAVY_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 
================================================ FILE: backend/onyx/background/celery/configs/light.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_LIGHT_CONCURRENCY from onyx.configs.app_configs import CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late worker_concurrency = CELERY_WORKER_LIGHT_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER ================================================ FILE: backend/onyx/background/celery/configs/monitoring.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_MONITORING_CONCURRENCY broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = 
shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late # Monitoring worker specific settings worker_concurrency = CELERY_WORKER_MONITORING_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 ================================================ FILE: backend/onyx/background/celery/configs/primary.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_PRIMARY_CONCURRENCY broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late worker_concurrency = CELERY_WORKER_PRIMARY_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 ================================================ FILE: backend/onyx/background/celery/configs/user_file_processing.py ================================================ import onyx.background.celery.configs.base as shared_config from onyx.configs.app_configs import CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY broker_url = shared_config.broker_url broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup broker_pool_limit = shared_config.broker_pool_limit broker_transport_options = shared_config.broker_transport_options redis_socket_keepalive = 
shared_config.redis_socket_keepalive redis_retry_on_timeout = shared_config.redis_retry_on_timeout redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval result_backend = shared_config.result_backend result_expires = shared_config.result_expires # 86400 seconds is the default task_default_priority = shared_config.task_default_priority task_acks_late = shared_config.task_acks_late # User file processing worker configuration worker_concurrency = CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY worker_pool = "threads" worker_prefetch_multiplier = 1 ================================================ FILE: backend/onyx/background/celery/memory_monitoring.py ================================================ # backend/onyx/background/celery/memory_monitoring.py import logging import os from logging.handlers import RotatingFileHandler import psutil from onyx.utils.logger import is_running_in_container from onyx.utils.logger import setup_logger # Regular application logger logger = setup_logger() # Only set up memory monitoring in container environment if is_running_in_container(): # Set up a dedicated memory monitoring logger MEMORY_LOG_DIR = "/var/log/onyx/memory" MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log") MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files # Ensure log directory exists os.makedirs(MEMORY_LOG_DIR, exist_ok=True) # Create a dedicated logger for memory monitoring memory_logger = logging.getLogger("memory_monitoring") memory_logger.setLevel(logging.INFO) # Create a rotating file handler memory_handler = RotatingFileHandler( MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT, ) # Create a formatter that includes all relevant information memory_formatter = logging.Formatter( "%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S" ) memory_handler.setFormatter(memory_formatter) memory_logger.addHandler(memory_handler) 
else: # Create a null logger when not in container memory_logger = logging.getLogger("memory_monitoring") memory_logger.addHandler(logging.NullHandler()) def emit_process_memory( pid: int, process_name: str, additional_metadata: dict[str, str | int] ) -> None: # Skip memory monitoring if not in container if not is_running_in_container(): return try: process = psutil.Process(pid) memory_info = process.memory_info() cpu_percent = process.cpu_percent(interval=0.1) # Build metadata string from additional_metadata dictionary metadata_str = " ".join( [f"{key}={value}" for key, value in additional_metadata.items()] ) metadata_str = f" {metadata_str}" if metadata_str else "" memory_logger.info( f"PROCESS_MEMORY process_name={process_name} pid={pid} " f"rss_mb={memory_info.rss / (1024 * 1024):.2f} " f"vms_mb={memory_info.vms / (1024 * 1024):.2f} " f"cpu={cpu_percent:.2f}{metadata_str}" ) except Exception: logger.exception("Error monitoring process memory.") ================================================ FILE: backend/onyx/background/celery/tasks/beat_schedule.py ================================================ import copy from datetime import timedelta from typing import Any from celery.schedules import crontab from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX from onyx.configs.app_configs import ENTERPRISE_EDITION_ENABLED from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from shared_configs.configs import MULTI_TENANT # choosing 15 minutes because it roughly gives us enough time to process many tasks # we might be able to reduce this 
greatly if we can run a unified # loop across all tenants rather than tasks per tenant # we set expires because it isn't necessary to queue up these tasks # it's only important that they run relatively regularly BEAT_EXPIRES_DEFAULT = 15 * 60 # 15 minutes (in seconds) # hack to slow down task dispatch in the cloud until # we have a better implementation (backpressure, etc) # Note that DynamicTenantScheduler can adjust the runtime value for this via Redis CLOUD_BEAT_MULTIPLIER_DEFAULT = 8.0 CLOUD_DOC_PERMISSION_SYNC_MULTIPLIER_DEFAULT = 1.0 # tasks that run in either self-hosted or cloud deployments beat_task_templates: list[dict] = [ { "name": "check-for-user-file-processing", "task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING, "schedule": timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-user-file-project-sync", "task": OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC, "schedule": timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-user-file-delete", "task": OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE, "schedule": timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-indexing", "task": OnyxCeleryTask.CHECK_FOR_INDEXING, "schedule": timedelta(seconds=15), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-checkpoint-cleanup", "task": OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP, "schedule": timedelta(hours=1), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-index-attempt-cleanup", "task": OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP, "schedule": timedelta(minutes=30), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-connector-deletion", "task":
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION, "schedule": timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-vespa-sync", "task": OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK, "schedule": timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-pruning", "task": OnyxCeleryTask.CHECK_FOR_PRUNING, "schedule": timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-hierarchy-fetching", "task": OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING, "schedule": timedelta(hours=1), # Check hourly, but only fetch once per day "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "monitor-background-processes", "task": OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES, "schedule": timedelta(minutes=5), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.MONITORING, }, }, # Sandbox cleanup tasks { "name": "cleanup-idle-sandboxes", "task": OnyxCeleryTask.CLEANUP_IDLE_SANDBOXES, "schedule": timedelta(minutes=1), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.SANDBOX, }, }, { "name": "cleanup-old-snapshots", "task": OnyxCeleryTask.CLEANUP_OLD_SNAPSHOTS, "schedule": timedelta(hours=24), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.SANDBOX, }, }, ] if ENTERPRISE_EDITION_ENABLED: beat_task_templates.extend( [ { "name": "check-for-doc-permissions-sync", "task": OnyxCeleryTask.CHECK_FOR_DOC_PERMISSIONS_SYNC, "schedule": timedelta(seconds=30), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": "check-for-external-group-sync", "task": OnyxCeleryTask.CHECK_FOR_EXTERNAL_GROUP_SYNC, "schedule": 
timedelta(seconds=20), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, }, }, ] ) # Add the Auto LLM update task if the config URL is set (has a default) if AUTO_LLM_CONFIG_URL: beat_task_templates.append( { "name": "check-for-auto-llm-update", "task": OnyxCeleryTask.CHECK_FOR_AUTO_LLM_UPDATE, "schedule": timedelta(seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, }, } ) # Add scheduled eval task if datasets are configured if SCHEDULED_EVAL_DATASET_NAMES: beat_task_templates.append( { "name": "scheduled-eval-pipeline", "task": OnyxCeleryTask.SCHEDULED_EVAL_TASK, # run every Sunday at midnight UTC "schedule": crontab( hour=0, minute=0, day_of_week=0, ), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, }, } ) # Add OpenSearch migration task if enabled. if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX: beat_task_templates.append( { "name": "migrate-chunks-from-vespa-to-opensearch", "task": OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK, # Try to enqueue an invocation of this task with this frequency. "schedule": timedelta(seconds=120), # 2 minutes "options": { "priority": OnyxCeleryPriority.LOW, # If the task was not dequeued in this time, revoke it. "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.OPENSEARCH_MIGRATION, }, } ) # Beat task names that require a vector DB. Filtered out when DISABLE_VECTOR_DB. 
_VECTOR_DB_BEAT_TASK_NAMES: set[str] = { "check-for-indexing", "check-for-connector-deletion", "check-for-vespa-sync", "check-for-pruning", "check-for-hierarchy-fetching", "check-for-checkpoint-cleanup", "check-for-index-attempt-cleanup", "check-for-doc-permissions-sync", "check-for-external-group-sync", "migrate-chunks-from-vespa-to-opensearch", } if DISABLE_VECTOR_DB: beat_task_templates = [ t for t in beat_task_templates if t["name"] not in _VECTOR_DB_BEAT_TASK_NAMES ] def make_cloud_generator_task(task: dict[str, Any]) -> dict[str, Any]: cloud_task: dict[str, Any] = {} # constant options for cloud beat task generators task_schedule: timedelta = task["schedule"] cloud_task["schedule"] = task_schedule cloud_task["options"] = {} cloud_task["options"]["priority"] = OnyxCeleryPriority.HIGHEST cloud_task["options"]["expires"] = BEAT_EXPIRES_DEFAULT # settings dependent on the original task cloud_task["name"] = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_{task['name']}" cloud_task["task"] = OnyxCeleryTask.CLOUD_BEAT_TASK_GENERATOR cloud_task["kwargs"] = {} cloud_task["kwargs"]["task_name"] = task["task"] optional_fields = ["queue", "priority", "expires"] for field in optional_fields: if field in task["options"]: cloud_task["kwargs"][field] = task["options"][field] return cloud_task # tasks that only run in the cloud and are system wide # the name attribute must start with ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" to be seen # by the DynamicTenantScheduler as system wide task and not a per tenant task beat_cloud_tasks: list[dict] = [ # cloud specific tasks { "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-alembic", "task": OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC, "schedule": timedelta(hours=1), "options": { "queue": OnyxCeleryQueues.MONITORING, "priority": OnyxCeleryPriority.HIGH, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-queues", "task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, "schedule": timedelta(seconds=30), "options": 
{ "queue": OnyxCeleryQueues.MONITORING, "priority": OnyxCeleryPriority.HIGH, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants", "task": OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS, "schedule": timedelta(minutes=10), "options": { "queue": OnyxCeleryQueues.MONITORING, "priority": OnyxCeleryPriority.HIGH, "expires": BEAT_EXPIRES_DEFAULT, }, }, { "name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor-celery-pidbox", "task": OnyxCeleryTask.CLOUD_MONITOR_CELERY_PIDBOX, "schedule": timedelta(hours=4), "options": { "queue": OnyxCeleryQueues.MONITORING, "priority": OnyxCeleryPriority.HIGH, "expires": BEAT_EXPIRES_DEFAULT, }, }, ] # tasks that only run self hosted tasks_to_schedule: list[dict] = [] if not MULTI_TENANT: tasks_to_schedule.extend( [ { "name": "monitor-celery-queues", "task": OnyxCeleryTask.MONITOR_CELERY_QUEUES, "schedule": timedelta(seconds=10), "options": { "priority": OnyxCeleryPriority.MEDIUM, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.MONITORING, }, }, { "name": "monitor-process-memory", "task": OnyxCeleryTask.MONITOR_PROCESS_MEMORY, "schedule": timedelta(minutes=5), "options": { "priority": OnyxCeleryPriority.LOW, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.MONITORING, }, }, { "name": "celery-beat-heartbeat", "task": OnyxCeleryTask.CELERY_BEAT_HEARTBEAT, "schedule": timedelta(minutes=1), "options": { "priority": OnyxCeleryPriority.HIGHEST, "expires": BEAT_EXPIRES_DEFAULT, "queue": OnyxCeleryQueues.PRIMARY, }, }, ] ) tasks_to_schedule.extend(beat_task_templates) def generate_cloud_tasks( beat_tasks: list[dict], beat_templates: list[dict], beat_multiplier: float ) -> list[dict[str, Any]]: """ beat_tasks: system wide tasks that can be sent as is beat_templates: task templates that will be transformed into per tenant tasks via the cloud_beat_task_generator beat_multiplier: a multiplier that can be applied on top of the task schedule to speed up or slow down the task 
generation rate. useful in production. Returns a list of cloud tasks, which consists of incoming tasks + tasks generated from incoming templates. """ if beat_multiplier <= 0: raise ValueError("beat_multiplier must be positive!") cloud_tasks: list[dict] = [] # generate our tenant aware cloud tasks from the templates for beat_template in beat_templates: cloud_task = make_cloud_generator_task(beat_template) cloud_tasks.append(cloud_task) # factor in the cloud multiplier for the above for cloud_task in cloud_tasks: cloud_task["schedule"] = cloud_task["schedule"] * beat_multiplier # add the fixed cloud/system beat tasks. No multiplier for these. cloud_tasks.extend(copy.deepcopy(beat_tasks)) return cloud_tasks def get_cloud_tasks_to_schedule(beat_multiplier: float) -> list[dict[str, Any]]: return generate_cloud_tasks(beat_cloud_tasks, beat_task_templates, beat_multiplier) def get_tasks_to_schedule() -> list[dict[str, Any]]: return tasks_to_schedule ================================================ FILE: backend/onyx/background/celery/tasks/connector_deletion/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/connector_deletion/tasks.py ================================================ import traceback from datetime import datetime from datetime import timezone from typing import Any from typing import cast from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from pydantic import ValidationError from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_get_broker_client from onyx.background.celery.celery_redis import celery_get_queue_length from onyx.background.celery.celery_redis import celery_get_queued_task_ids from 
onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisSignals from onyx.db.connector import fetch_connector_by_id from onyx.db.connector_credential_pair import add_deletion_failure_message from onyx.db.connector_credential_pair import ( delete_connector_credential_pair__no_commit, ) from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import get_connector_credential_pairs from onyx.db.document import ( delete_all_documents_by_connector_credential_pair__no_commit, ) from onyx.db.document import get_document_ids_for_connector_credential_pair from onyx.db.document_set import delete_document_set_cc_pair_relationship__no_commit from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import SyncStatus from onyx.db.enums import SyncType from onyx.db.index_attempt import delete_index_attempts from onyx.db.index_attempt import get_recent_attempts_for_cc_pair from onyx.db.permission_sync_attempt import ( delete_doc_permission_sync_attempts__no_commit, ) from onyx.db.permission_sync_attempt import ( delete_external_group_permission_sync_attempts__no_commit, ) from onyx.db.search_settings import get_all_search_settings from onyx.db.sync_record import cleanup_sync_records from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.db.tag import delete_orphan_tags__no_commit from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_connector_delete import RedisConnectorDelete from 
onyx.redis.redis_connector_delete import RedisConnectorDeletePayload from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) from onyx.utils.variable_functionality import noop_fallback class TaskDependencyError(RuntimeError): """Raised to the caller to indicate dependent tasks are running that would interfere with connector deletion.""" def revoke_tasks_blocking_deletion( redis_connector: RedisConnector, db_session: Session, app: Celery ) -> None: search_settings_list = get_all_search_settings(db_session) for search_settings in search_settings_list: try: recent_index_attempts = get_recent_attempts_for_cc_pair( cc_pair_id=redis_connector.cc_pair_id, search_settings_id=search_settings.id, limit=1, db_session=db_session, ) if ( recent_index_attempts and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS and recent_index_attempts[0].celery_task_id ): app.control.revoke(recent_index_attempts[0].celery_task_id) task_logger.info( f"Revoked indexing task {recent_index_attempts[0].celery_task_id}." ) except Exception: task_logger.exception("Exception while revoking indexing task") try: permissions_sync_payload = redis_connector.permissions.payload if permissions_sync_payload and permissions_sync_payload.celery_task_id: app.control.revoke(permissions_sync_payload.celery_task_id) task_logger.info( f"Revoked permissions sync task {permissions_sync_payload.celery_task_id}." 
) except Exception: task_logger.exception("Exception while revoking permissions sync task") try: prune_payload = redis_connector.prune.payload if prune_payload and prune_payload.celery_task_id: app.control.revoke(prune_payload.celery_task_id) task_logger.info(f"Revoked pruning task {prune_payload.celery_task_id}.") except Exception: task_logger.exception("Exception while revoking pruning task") try: external_group_sync_payload = redis_connector.external_group_sync.payload if external_group_sync_payload and external_group_sync_payload.celery_task_id: app.control.revoke(external_group_sync_payload.celery_task_id) task_logger.info( f"Revoked external group sync task {external_group_sync_payload.celery_task_id}." ) except Exception: task_logger.exception("Exception while revoking external group sync task") @shared_task( name=OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION, ignore_result=True, soft_time_limit=JOB_TIMEOUT, trail=False, bind=True, ) def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None: r = get_redis_client() r_replica = get_redis_replica_client() lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_CONNECTOR_DELETION_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # Prevent this task from overlapping with itself if not lock_beat.acquire(blocking=False): return None try: # we want to run this less frequently than the overall task lock_beat.reacquire() if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES): # clear fences that don't have associated celery tasks in progress try: r_celery = celery_get_broker_client(self.app) validate_connector_deletion_fences( tenant_id, r, r_replica, r_celery, lock_beat ) except Exception: task_logger.exception( "Exception while validating connector deletion fences" ) r.set(OnyxRedisSignals.BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES, 1, ex=300) # collect cc_pair_ids cc_pair_ids: list[int] = [] with get_session_with_current_tenant() as db_session: cc_pairs =
get_connector_credential_pairs(db_session) for cc_pair in cc_pairs: cc_pair_ids.append(cc_pair.id) # try running cleanup on the cc_pair_ids for cc_pair_id in cc_pair_ids: with get_session_with_current_tenant() as db_session: redis_connector = RedisConnector(tenant_id, cc_pair_id) try: try_generate_document_cc_pair_cleanup_tasks( self.app, cc_pair_id, db_session, lock_beat, tenant_id ) except TaskDependencyError as e: # this means we wanted to start deleting but dependent tasks were running # on the first error, we set a stop signal and revoke the dependent tasks # on subsequent errors, we hard reset blocking fences after our specified timeout # is exceeded task_logger.info(str(e)) if not redis_connector.stop.fenced: # one time revoke of celery tasks task_logger.info("Revoking any tasks blocking deletion.") revoke_tasks_blocking_deletion( redis_connector, db_session, self.app ) redis_connector.stop.set_fence(True) redis_connector.stop.set_timeout() else: # stop signal already set if redis_connector.stop.timed_out: # waiting too long, just reset blocking fences task_logger.info( "Timed out waiting for tasks blocking deletion. Resetting blocking fences." ) redis_connector.prune.reset() redis_connector.permissions.reset() redis_connector.external_group_sync.reset() else: # just wait pass else: # clear the stop signal if it exists ... no longer needed redis_connector.stop.set_fence(False) lock_beat.reacquire() keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) if not r.exists(key_bytes): r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes) continue key_str = key_bytes.decode("utf-8") if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX): monitor_connector_deletion_taskset(tenant_id, key_bytes, r) except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." 
) except Exception: task_logger.exception("Unexpected exception during connector deletion check") finally: if lock_beat.owned(): lock_beat.release() return True def try_generate_document_cc_pair_cleanup_tasks( app: Celery, cc_pair_id: int, db_session: Session, lock_beat: RedisLock, tenant_id: str, ) -> int | None: """Returns an int if syncing is needed. The int represents the number of sync tasks generated. Note that syncing can still be required even if the number of sync tasks generated is zero. Returns None if no syncing is required. Will raise TaskDependencyError if dependent tasks such as indexing and pruning are still running. In our case, the caller reacts by setting a stop signal in Redis to exit those tasks as quickly as possible. """ lock_beat.reacquire() redis_connector = RedisConnector(tenant_id, cc_pair_id) # don't generate sync tasks if tasks are still pending if redis_connector.delete.fenced: return None # we need to load the state of the object inside the fence # to avoid a race condition with db.commit/fence deletion # at the end of this taskset cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if not cc_pair: return None if cc_pair.status != ConnectorCredentialPairStatus.DELETING: # there should be no in-progress sync records if this is up to date # clean it up just in case things got into a bad state cleanup_sync_records( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.CONNECTOR_DELETION, ) return None # set a basic fence to start redis_connector.delete.set_active() fence_payload = RedisConnectorDeletePayload( num_tasks=None, submitted=datetime.now(timezone.utc), ) redis_connector.delete.set_fence(fence_payload) try: # do not proceed if connector indexing or connector pruning are running search_settings_list = get_all_search_settings(db_session) for search_settings in search_settings_list: recent_index_attempts = get_recent_attempts_for_cc_pair( cc_pair_id=cc_pair_id, 
search_settings_id=search_settings.id, limit=1, db_session=db_session, ) if ( recent_index_attempts and recent_index_attempts[0].status == IndexingStatus.IN_PROGRESS ): raise TaskDependencyError( "Connector deletion - Delayed (indexing in progress): " f"cc_pair={cc_pair_id} " f"search_settings={search_settings.id}" ) if redis_connector.prune.fenced: raise TaskDependencyError( f"Connector deletion - Delayed (pruning in progress): cc_pair={cc_pair_id}" ) if redis_connector.permissions.fenced: raise TaskDependencyError( f"Connector deletion - Delayed (permissions in progress): cc_pair={cc_pair_id}" ) # add tasks to celery and build up the task set to monitor in redis redis_connector.delete.taskset_clear() # Add all documents that need to be updated into the queue task_logger.info( f"RedisConnectorDeletion.generate_tasks starting. cc_pair={cc_pair_id}" ) tasks_generated = redis_connector.delete.generate_tasks( app, db_session, lock_beat ) if tasks_generated is None: raise ValueError("RedisConnectorDeletion.generate_tasks returned None") try: insert_sync_record( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.CONNECTOR_DELETION, ) except Exception: task_logger.exception("insert_sync_record raised an exception.") except TaskDependencyError: redis_connector.delete.set_fence(None) raise except Exception: task_logger.exception("Unexpected exception") redis_connector.delete.set_fence(None) return None else: # Currently we are allowing the sync to proceed with 0 tasks. # It's possible for sets/groups to be generated initially with no entries # and they still need to be marked as up to date. # if tasks_generated == 0: # return 0 task_logger.info( f"RedisConnectorDeletion.generate_tasks finished.
cc_pair={cc_pair_id} tasks_generated={tasks_generated}" ) # set this only after all tasks have been added fence_payload.num_tasks = tasks_generated redis_connector.delete.set_fence(fence_payload) return tasks_generated def monitor_connector_deletion_taskset( tenant_id: str, key_bytes: bytes, r: Redis, # noqa: ARG001 ) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) if cc_pair_id_str is None: task_logger.warning(f"could not parse cc_pair_id from {fence_key}") return cc_pair_id = int(cc_pair_id_str) redis_connector = RedisConnector(tenant_id, cc_pair_id) fence_data = redis_connector.delete.payload if not fence_data: task_logger.warning( f"Connector deletion - fence payload invalid: cc_pair={cc_pair_id}" ) return if fence_data.num_tasks is None: # the fence is setting up but isn't ready yet return remaining = redis_connector.delete.get_remaining() task_logger.info( f"Connector deletion progress: cc_pair={cc_pair_id} remaining={remaining} initial={fence_data.num_tasks}" ) if remaining > 0: with get_session_with_current_tenant() as db_session: update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.CONNECTOR_DELETION, sync_status=SyncStatus.IN_PROGRESS, num_docs_synced=remaining, ) return with get_session_with_current_tenant() as db_session: cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) credential_id_to_delete: int | None = None connector_id_to_delete: int | None = None if not cc_pair: task_logger.warning( f"Connector deletion - cc_pair not found: cc_pair={cc_pair_id}" ) return try: doc_ids = get_document_ids_for_connector_credential_pair( db_session, cc_pair.connector_id, cc_pair.credential_id ) if len(doc_ids) > 0: # NOTE(rkuo): if this happens, documents somehow got added while # deletion was in progress. Likely a bug gating off pruning and indexing # work before deletion starts. 
task_logger.warning( "Connector deletion - documents still found after taskset completion. " "Clearing the current deletion attempt and allowing deletion to restart: " f"cc_pair={cc_pair_id} " f"docs_deleted={fence_data.num_tasks} " f"docs_remaining={len(doc_ids)}" ) # We don't want to ignore why we get into this state, but resetting # our attempt and letting the deletion restart is a good way to recover redis_connector.delete.reset() raise RuntimeError( "Connector deletion - documents still found after taskset completion" ) # clean up the rest of the related Postgres entities # index attempts delete_index_attempts( db_session=db_session, cc_pair_id=cc_pair_id, ) # permission sync attempts delete_doc_permission_sync_attempts__no_commit( db_session=db_session, cc_pair_id=cc_pair_id, ) delete_external_group_permission_sync_attempts__no_commit( db_session=db_session, cc_pair_id=cc_pair_id, ) # document sets delete_document_set_cc_pair_relationship__no_commit( db_session=db_session, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, ) # user groups cleanup_user_groups = fetch_versioned_implementation_with_fallback( "onyx.db.user_group", "delete_user_group_cc_pair_relationship__no_commit", noop_fallback, ) cleanup_user_groups( cc_pair_id=cc_pair_id, db_session=db_session, ) # delete orphan tags delete_orphan_tags__no_commit(db_session) # Store IDs before potentially expiring cc_pair connector_id_to_delete = cc_pair.connector_id credential_id_to_delete = cc_pair.credential_id # Explicitly delete the DocumentByConnectorCredentialPair records before deleting the connector # This is needed because connector_id is a primary key in that table and cascading deletes won't work delete_all_documents_by_connector_credential_pair__no_commit( db_session=db_session, connector_id=connector_id_to_delete, credential_id=credential_id_to_delete, ) # Flush to ensure document deletion happens before connector deletion db_session.flush() # Expire the cc_pair to ensure
SQLAlchemy doesn't try to manage its state # related to the deleted DocumentByConnectorCredentialPair during commit db_session.expire(cc_pair) # finally, delete the cc-pair delete_connector_credential_pair__no_commit( db_session=db_session, connector_id=connector_id_to_delete, credential_id=credential_id_to_delete, ) # if there are no credentials left, delete the connector connector = fetch_connector_by_id( db_session=db_session, connector_id=connector_id_to_delete, ) if connector is not None and not connector.credentials: task_logger.info( "Connector deletion - Found no credentials left for connector, deleting connector" ) db_session.delete(connector) db_session.commit() update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.CONNECTOR_DELETION, sync_status=SyncStatus.SUCCESS, num_docs_synced=fence_data.num_tasks, ) except Exception as e: db_session.rollback() stack_trace = traceback.format_exc() error_message = f"Error: {str(e)}\n\nStack Trace:\n{stack_trace}" add_deletion_failure_message(db_session, cc_pair_id, error_message) update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.CONNECTOR_DELETION, sync_status=SyncStatus.FAILED, num_docs_synced=fence_data.num_tasks, ) task_logger.exception( f"Connector deletion exceptioned: " f"cc_pair={cc_pair_id} connector={connector_id_to_delete} credential={credential_id_to_delete}" ) raise e task_logger.info( f"Connector deletion succeeded: " f"cc_pair={cc_pair_id} " f"connector={connector_id_to_delete} " f"credential={credential_id_to_delete} " f"docs_deleted={fence_data.num_tasks}" ) redis_connector.delete.reset() def validate_connector_deletion_fences( tenant_id: str, r: Redis, r_replica: Redis, r_celery: Redis, lock_beat: RedisLock, ) -> None: # building lookup table can be expensive, so we won't bother # validating until the queue is small CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN = 1024 queue_len =
celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery) if queue_len > CONNECTION_DELETION_VALIDATION_MAX_QUEUE_LEN: return queued_upsert_tasks = celery_get_queued_task_ids( OnyxCeleryQueues.CONNECTOR_DELETION, r_celery ) # validate all existing connector deletion jobs lock_beat.reacquire() keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) key_str = key_bytes.decode("utf-8") if not key_str.startswith(RedisConnectorDelete.FENCE_PREFIX): continue validate_connector_deletion_fence( tenant_id, key_bytes, queued_upsert_tasks, r, ) lock_beat.reacquire() return def validate_connector_deletion_fence( tenant_id: str, key_bytes: bytes, queued_upsert_tasks: set[str], r: Redis, ) -> None: """Checks for the error condition where a connector deletion fence is set but the associated celery tasks don't exist. This can happen if the worker running the deletion tasks hard crashes or is terminated. Being in this bad state means the fence will never clear without help, so this function gives the help. How this works: 1. This function renews the active signal with a 5 minute TTL when every task in the deletion taskset is seen in the celery queue. 2. Externally, the active signal is also set when the fence is created. 3. The TTL allows us to get through the transitions on fence startup and when the task starts executing. More TTL clarification: it is seemingly impossible to exactly query Celery for whether a task is in the queue or currently executing. 1. An unknown task id is always returned as state PENDING. 2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task and the time it actually starts on the worker.
queued_upsert_tasks: the IDs of the tasks currently queued in the connector deletion celery queue """ # if the fence doesn't exist, there's nothing to do fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) if cc_pair_id_str is None: task_logger.warning( f"validate_connector_deletion_fence - could not parse id from {fence_key}" ) return cc_pair_id = int(cc_pair_id_str) # parse out metadata and initialize the helper class with it redis_connector = RedisConnector(tenant_id, cc_pair_id) # check to see if the fence/payload exists if not redis_connector.delete.fenced: return # in the cloud, the payload format may have changed ... # it's a little sloppy, but just reset the fence for now if that happens # TODO: add intentional cleanup/abort logic try: payload = redis_connector.delete.payload except ValidationError: task_logger.exception( "validate_connector_deletion_fence - " "Resetting fence because fence schema is out of date: " f"cc_pair={cc_pair_id} " f"fence={fence_key}" ) redis_connector.delete.reset() return if not payload: return # OK, there's actually something for us to validate # look up every task in the current taskset in the celery queue # every entry in the taskset should have an associated entry in the celery task queue # because we get the celery tasks first, the entries in our own deletion taskset # should be roughly a subset of the tasks in celery # this check isn't very exact, but should be sufficient over a period of time # A single successful check over some number of attempts is sufficient. # TODO: if the number of tasks in celery is much lower than the taskset length # we might be able to shortcut the lookup since by definition some of the tasks # must not exist in celery.
tasks_scanned = 0 tasks_not_in_celery = 0 # a non-zero number after completing our check is bad for member in r.sscan_iter(redis_connector.delete.taskset_key): tasks_scanned += 1 member_bytes = cast(bytes, member) member_str = member_bytes.decode("utf-8") if member_str in queued_upsert_tasks: continue tasks_not_in_celery += 1 task_logger.info( f"validate_connector_deletion_fence task check: tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}" ) # we're active if there are still tasks to run and those tasks all exist in celery if tasks_scanned > 0 and tasks_not_in_celery == 0: redis_connector.delete.set_active() return # we may want to enable this check if using the active task list somehow isn't good enough # if redis_connector_index.generator_locked(): # logger.info(f"{payload.celery_task_id} is currently executing.") # if we get here, we didn't find any direct indication that the associated celery tasks exist, # but they still might be there due to gaps in our ability to check states during transitions # Checking the active signal safeguards us against these transition periods # (which has a duration that allows us to bridge those gaps) if redis_connector.delete.active(): return # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up. 
task_logger.warning( "validate_connector_deletion_fence - " "Resetting fence because no associated celery tasks were found: " f"cc_pair={cc_pair_id} " f"fence={fence_key}" ) redis_connector.delete.reset() return ================================================ FILE: backend/onyx/background/celery/tasks/docfetching/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/docfetching/task_creation_utils.py ================================================ from uuid import uuid4 from celery import Celery from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.index_attempt import mark_attempt_failed from onyx.db.indexing_coordination import IndexingCoordination from onyx.db.models import ConnectorCredentialPair from onyx.db.models import SearchSettings def try_creating_docfetching_task( celery_app: Celery, cc_pair: ConnectorCredentialPair, search_settings: SearchSettings, reindex: bool, db_session: Session, r: Redis, tenant_id: str, ) -> int | None: """Checks for any conditions that should block the indexing task from being created, then creates the task. Does not check for scheduling related conditions as this function is used to trigger indexing immediately. Now uses database-based coordination instead of Redis fencing. 
""" LOCK_TIMEOUT = 30 # we need to serialize any attempt to trigger indexing since it can be triggered # either via celery beat or manually (API call) lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_indexing_task", timeout=LOCK_TIMEOUT, ) acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2) if not acquired: return None index_attempt_id = None try: # Basic status checks db_session.refresh(cc_pair) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: return None # Generate custom task ID for tracking custom_task_id = f"docfetching_{cc_pair.id}_{search_settings.id}_{uuid4()}" # Try to create a new index attempt using database coordination # This replaces the Redis fencing mechanism index_attempt_id = IndexingCoordination.try_create_index_attempt( db_session=db_session, cc_pair_id=cc_pair.id, search_settings_id=search_settings.id, celery_task_id=custom_task_id, from_beginning=reindex, ) if index_attempt_id is None: # Another indexing attempt is already running return None # Use higher priority for first-time indexing to ensure new connectors # get processed before re-indexing of existing connectors has_successful_attempt = cc_pair.last_successful_index_time is not None priority = ( OnyxCeleryPriority.MEDIUM if has_successful_attempt else OnyxCeleryPriority.HIGH ) # Send the task to Celery result = celery_app.send_task( OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK, kwargs=dict( index_attempt_id=index_attempt_id, cc_pair_id=cc_pair.id, search_settings_id=search_settings.id, tenant_id=tenant_id, ), queue=OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, task_id=custom_task_id, priority=priority, ) if not result: raise RuntimeError("send_task for connector_doc_fetching_task failed.") task_logger.info( f"Created docfetching task: " f"cc_pair={cc_pair.id} " f"search_settings={search_settings.id} " f"attempt_id={index_attempt_id} " f"celery_task_id={custom_task_id}" ) return index_attempt_id except Exception: task_logger.exception( 
f"try_creating_indexing_task - Unexpected exception: cc_pair={cc_pair.id} search_settings={search_settings.id}" ) # Clean up on failure if index_attempt_id is not None: mark_attempt_failed(index_attempt_id, db_session) return None finally: if lock.owned(): lock.release() return index_attempt_id ================================================ FILE: backend/onyx/background/celery/tasks/docfetching/tasks.py ================================================ import multiprocessing import os import time import traceback from time import sleep import sentry_sdk from celery import Celery from celery import shared_task from celery import Task from onyx import __version__ from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.memory_monitoring import emit_process_memory from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat from onyx.background.celery.tasks.docprocessing.tasks import ConnectorIndexingLogBuilder from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback from onyx.background.celery.tasks.models import DocProcessingContext from onyx.background.celery.tasks.models import IndexingWatchdogTerminalStatus from onyx.background.celery.tasks.models import SimpleJobResult from onyx.background.indexing.job_client import SimpleJob from onyx.background.indexing.job_client import SimpleJobClient from onyx.background.indexing.job_client import SimpleJobException from onyx.background.indexing.run_docfetching import run_docfetching_entrypoint from onyx.configs.constants import CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT from onyx.configs.constants import OnyxCeleryTask from onyx.connectors.exceptions import ConnectorValidationError from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import 
IndexingStatus from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import mark_attempt_canceled from onyx.db.index_attempt import mark_attempt_failed from onyx.db.indexing_coordination import IndexingCoordination from onyx.redis.redis_connector import RedisConnector from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import global_version from shared_configs.configs import SENTRY_DSN logger = setup_logger() def _verify_indexing_attempt( index_attempt_id: int, cc_pair_id: int, search_settings_id: int, ) -> None: """ Verify that the indexing attempt exists and is in the correct state. """ with get_session_with_current_tenant() as db_session: attempt = get_index_attempt(db_session, index_attempt_id) if not attempt: raise SimpleJobException( f"docfetching_task - IndexAttempt not found: attempt_id={index_attempt_id}", code=IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND.code, ) if attempt.connector_credential_pair_id != cc_pair_id: raise SimpleJobException( f"docfetching_task - CC pair mismatch: expected={cc_pair_id} actual={attempt.connector_credential_pair_id}", code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code, ) if attempt.search_settings_id != search_settings_id: raise SimpleJobException( f"docfetching_task - Search settings mismatch: expected={search_settings_id} actual={attempt.search_settings_id}", code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code, ) if attempt.status not in [ IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS, ]: raise SimpleJobException( f"docfetching_task - Invalid attempt status: attempt_id={index_attempt_id} status={attempt.status}", code=IndexingWatchdogTerminalStatus.FENCE_MISMATCH.code, ) # Check for cancellation if IndexingCoordination.check_cancellation_requested( db_session, index_attempt_id ): raise SimpleJobException( f"docfetching_task - Cancellation requested: attempt_id={index_attempt_id}", code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code, ) 
logger.info( f"docfetching_task - IndexAttempt verified: " f"attempt_id={index_attempt_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) def docfetching_task( app: Celery, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, is_ee: bool, tenant_id: str, ) -> None: """ This function is run in a SimpleJob as a new process. It performs some validation, then calls run_docfetching_entrypoint. NOTE: if an exception is raised out of this task, the primary worker will detect that the task transitioned to a "READY" state but the generator_complete_key doesn't exist. This will cause the primary worker to abort the indexing attempt and clean up. """ # Start heartbeat for this indexing attempt heartbeat_thread, stop_event = start_heartbeat(index_attempt_id) try: _docfetching_task( app, index_attempt_id, cc_pair_id, search_settings_id, is_ee, tenant_id ) finally: stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting def _docfetching_task( app: Celery, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, is_ee: bool, tenant_id: str, ) -> None: # Since docfetching_proxy_task spawns a new process using this function as # the entrypoint, we init Sentry here.
if SENTRY_DSN: sentry_sdk.init( dsn=SENTRY_DSN, traces_sample_rate=0.1, release=__version__, ) logger.info("Sentry initialized") else: logger.debug("Sentry DSN not provided, skipping Sentry initialization") logger.info( f"Indexing spawned task starting: " f"attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) redis_connector = RedisConnector(tenant_id, cc_pair_id) # TODO: remove all fences, cause all signals to be set in postgres if redis_connector.delete.fenced: raise SimpleJobException( f"Indexing will not start because connector deletion is in progress: " f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"fence={redis_connector.delete.fence_key}", code=IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION.code, ) if redis_connector.stop.fenced: raise SimpleJobException( f"Indexing will not start because a connector stop signal was detected: " f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"fence={redis_connector.stop.fence_key}", code=IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL.code, ) # Verify the indexing attempt exists and is valid # This replaces the Redis fence payload waiting _verify_indexing_attempt(index_attempt_id, cc_pair_id, search_settings_id) try: with get_session_with_current_tenant() as db_session: attempt = get_index_attempt(db_session, index_attempt_id) if not attempt: raise SimpleJobException( f"Index attempt not found: index_attempt={index_attempt_id}", code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code, ) cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if not cc_pair: raise SimpleJobException( f"cc_pair not found: cc_pair={cc_pair_id}", code=IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH.code, ) # define a callback class callback = IndexingCallback( redis_connector, ) logger.info( f"Indexing spawned task running entrypoint: attempt={index_attempt_id} " f"tenant={tenant_id} " 
f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) # This is where the heavy/real work happens run_docfetching_entrypoint( app, index_attempt_id, tenant_id, cc_pair_id, is_ee, callback=callback, ) except ConnectorValidationError: raise SimpleJobException( f"Indexing task failed: attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}", code=IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR.code, ) except Exception as e: logger.exception( f"Indexing spawned task failed: attempt={index_attempt_id} " f"tenant={tenant_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id}" ) # special bulletproofing ... truncate long exception messages # for exception types that require more args, this will fail # thus the try/except try: sanitized_e = type(e)(str(e)[:1024]) sanitized_e.__traceback__ = e.__traceback__ raise sanitized_e except Exception: raise e logger.info( f"Indexing spawned task finished: attempt={index_attempt_id} cc_pair={cc_pair_id} search_settings={search_settings_id}" ) os._exit(0) # ensure process exits cleanly def process_job_result( job: SimpleJob, connector_source: str | None, index_attempt_id: int, log_builder: ConnectorIndexingLogBuilder, ) -> SimpleJobResult: result = SimpleJobResult() result.connector_source = connector_source if job.process: result.exit_code = job.process.exitcode if job.status != "error": result.status = IndexingWatchdogTerminalStatus.SUCCEEDED return result ignore_exitcode = False # In EKS, there is an edge case where successful tasks return exit # code 1 in the cloud due to the set_spawn_method not sticking. 
# Workaround: check that the total number of batches is set, since this only # happens when docfetching completed successfully with get_session_with_current_tenant() as db_session: index_attempt = get_index_attempt(db_session, index_attempt_id) if index_attempt and index_attempt.total_batches is not None: ignore_exitcode = True if ignore_exitcode: result.status = IndexingWatchdogTerminalStatus.SUCCEEDED task_logger.warning( log_builder.build( "Indexing watchdog - spawned task has non-zero exit code but completion signal is OK. Continuing...", exit_code=str(result.exit_code), ) ) else: if result.exit_code is not None: result.status = IndexingWatchdogTerminalStatus.from_code(result.exit_code) job_level_exception = job.exception() result.exception_str = f"Docfetching returned exit code {result.exit_code} with exception: {job_level_exception}" return result @shared_task( name=OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK, bind=True, acks_late=False, track_started=True, ) def docfetching_proxy_task( self: Task, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, tenant_id: str, ) -> None: """ This task is the entrypoint for the full indexing pipeline, which is composed of two tasks: docfetching and docprocessing. This task is spawned by "try_creating_docfetching_task" which is called in the "check_for_indexing" task. This task spawns a new process for a new scheduled index attempt. That new process (which runs the docfetching_task function) does the following: 1) determines parameters of the indexing attempt (which connector indexing function to run, start and end time, from prev checkpoint or not), then runs that connector. Specifically, connectors are responsible for reading data from an outside source and converting it to Onyx documents. At the moment these two steps (reading external data and converting to an Onyx document) are not parallelized in most connectors; that's a subject for future work.
Each document batch produced by step 1 is stored in the file store, and a docprocessing task is spawned to process it. docprocessing involves the steps listed below. 2) upserts documents to postgres (index_doc_batch_prepare) 3) chunks each document (optionally adds context for contextual rag) 4) embeds chunks (embed_chunks_with_failure_handling) via a call to the model server 5) writes chunks to vespa (write_chunks_to_vector_db_with_backoff) 6) updates document and indexing metadata in postgres 7) pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing from the most recently pulled document ID list Some important notes: Invariants: - Docfetching proxy tasks are spawned by check_for_indexing. The proxy then runs the docfetching_task wrapped in a watchdog. The watchdog is responsible for monitoring the docfetching_task and marking the index attempt as failed if it is not making progress. - All docprocessing tasks are spawned by a docfetching task. - All docfetching tasks, docprocessing tasks, and document batches in the file store are associated with a specific index attempt. - The index attempt status is the source of truth for what is currently happening with the index attempt. It is coupled with the creation/running of docfetching and docprocessing tasks as much as possible. How we deal with failures / partial indexing: - non-checkpointed connectors / new runs in general => delete the old document batches from the file store and do the new run - checkpointed connectors + resuming from checkpoint => reissue the old document batches and do a new run Misc: - most inter-process communication is handled in postgres, some is still in redis and we're trying to remove it - Heartbeat spawned in docfetching and docprocessing is how check_for_indexing monitors liveness - progress-based liveness check: if nothing is done in 3-6 hours, mark the attempt as failed - TODO: task level timeouts (e.g.
a connector stuck in an infinite loop) Comments below are from the old version and some may no longer be valid. TODO(rkuo): refactor this so that there is a single return path where we canonically log the result of running this function. Some more Richard notes: celery out of process task execution strategy is pool=prefork, but it uses fork, and forking is inherently unstable. To work around this, we use pool=threads and proxy our work to a spawned task. acks_late must be set to False. Otherwise, celery's visibility timeout will cause any task that runs longer than the timeout to be redispatched by the broker. There appears to be no good workaround for this, so we need to handle redispatching manually. NOTE: we try/except all db access in this function because as a watchdog, this function needs to be extremely stable. """ # TODO: remove dependence on Redis start = time.monotonic() result = SimpleJobResult() ctx = DocProcessingContext( tenant_id=tenant_id, cc_pair_id=cc_pair_id, search_settings_id=search_settings_id, index_attempt_id=index_attempt_id, ) log_builder = ConnectorIndexingLogBuilder(ctx) task_logger.info( log_builder.build( "Indexing watchdog - starting", mp_start_method=str(multiprocessing.get_start_method()), ) ) if not self.request.id: task_logger.error("self.request.id is None!") client = SimpleJobClient() task_logger.info(f"submitting docfetching_task with tenant_id={tenant_id}") job = client.submit( docfetching_task, self.app, index_attempt_id, cc_pair_id, search_settings_id, global_version.is_ee_version(), tenant_id, ) if not job or not job.process: result.status = IndexingWatchdogTerminalStatus.SPAWN_FAILED task_logger.info( log_builder.build( "Indexing watchdog - finished", status=str(result.status.value), exit_code=str(result.exit_code), ) ) return # Ensure the process has moved out of the starting state num_waits = 0 while True: if num_waits > 15: result.status = IndexingWatchdogTerminalStatus.SPAWN_NOT_ALIVE task_logger.info( 
log_builder.build( "Indexing watchdog - finished", status=str(result.status.value), exit_code=str(result.exit_code), ) ) job.release() return if job.process.is_alive() or job.process.exitcode is not None: break sleep(1) num_waits += 1 task_logger.info( log_builder.build( "Indexing watchdog - spawn succeeded", pid=str(job.process.pid), ) ) # Track the last time memory info was emitted last_memory_emit_time = 0.0 try: with get_session_with_current_tenant() as db_session: index_attempt = get_index_attempt( db_session=db_session, index_attempt_id=index_attempt_id, eager_load_cc_pair=True, ) if not index_attempt: raise RuntimeError("Index attempt not found") result.connector_source = ( index_attempt.connector_credential_pair.connector.source.value ) while True: sleep(5) # if the job is done, clean up and break if job.done(): try: result = process_job_result( job, result.connector_source, index_attempt_id, log_builder ) except Exception: task_logger.exception( log_builder.build( "Indexing watchdog - spawned task exceptioned" ) ) finally: job.release() break # log the memory usage for tracking down memory leaks / connector-specific memory issues pid = job.process.pid if pid is not None: # Only emit memory info once per minute (60 seconds) current_time = time.monotonic() if current_time - last_memory_emit_time >= 60.0: emit_process_memory( pid, "indexing_worker", { "cc_pair_id": cc_pair_id, "search_settings_id": search_settings_id, "index_attempt_id": index_attempt_id, }, ) last_memory_emit_time = current_time # if the spawned task is still running, restart the check once again # if the index attempt is not in a finished status try: with get_session_with_current_tenant() as db_session: index_attempt = get_index_attempt( db_session=db_session, index_attempt_id=index_attempt_id ) if not index_attempt: continue if not index_attempt.is_finished(): continue except Exception: task_logger.exception( log_builder.build( "Indexing 
watchdog - transient exception
looking up index attempt" ) ) continue except Exception as e: result.status = IndexingWatchdogTerminalStatus.WATCHDOG_EXCEPTIONED if isinstance(e, ConnectorValidationError): # No need to expose full stack trace for validation errors result.exception_str = str(e) else: result.exception_str = traceback.format_exc() # handle exit and reporting elapsed = time.monotonic() - start if result.exception_str is not None: # print with exception try: with get_session_with_current_tenant() as db_session: attempt = get_index_attempt(db_session, ctx.index_attempt_id) # only mark failures if not already terminal, # otherwise we're overwriting potential real stack traces if attempt and not attempt.status.is_terminal(): failure_reason = ( f"Spawned task exceptioned: exit_code={result.exit_code}" ) mark_attempt_failed( ctx.index_attempt_id, db_session, failure_reason=failure_reason, full_exception_trace=result.exception_str, ) except Exception: task_logger.exception( log_builder.build( "Indexing watchdog - transient exception marking index attempt as failed" ) ) normalized_exception_str = "None" if result.exception_str: normalized_exception_str = result.exception_str.replace( "\n", "\\n" ).replace('"', '\\"') task_logger.warning( log_builder.build( "Indexing watchdog - finished", source=result.connector_source, status=result.status.value, exit_code=str(result.exit_code), exception=f'"{normalized_exception_str}"', elapsed=f"{elapsed:.2f}s", ) ) raise RuntimeError(f"Exception encountered: traceback={result.exception_str}") # print without exception if result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_SIGNAL: try: with get_session_with_current_tenant() as db_session: # not inside an except block, so log a warning rather than an exception logger.warning( f"Marking attempt {index_attempt_id} as canceled due to termination signal" ) mark_attempt_canceled( index_attempt_id, db_session, "Connector termination signal detected", ) except Exception: task_logger.exception( 
log_builder.build( "Indexing watchdog - transient exception marking index
attempt as canceled" ) ) job.cancel() elif result.status == IndexingWatchdogTerminalStatus.TERMINATED_BY_ACTIVITY_TIMEOUT: try: with get_session_with_current_tenant() as db_session: mark_attempt_failed( index_attempt_id, db_session, "Indexing watchdog - activity timeout exceeded: " f"attempt={index_attempt_id} " f"timeout={CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT}s", ) except Exception: logger.exception( log_builder.build( "Indexing watchdog - transient exception marking index attempt as failed" ) ) job.cancel() else: pass task_logger.info( log_builder.build( "Indexing watchdog - finished", source=result.connector_source, status=str(result.status.value), exit_code=str(result.exit_code), elapsed=f"{elapsed:.2f}s", ) ) ================================================ FILE: backend/onyx/background/celery/tasks/docprocessing/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/docprocessing/heartbeat.py ================================================ import contextvars import threading from sqlalchemy import update from onyx.configs.constants import INDEXING_WORKER_HEARTBEAT_INTERVAL from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import IndexAttempt from onyx.utils.logger import setup_logger logger = setup_logger() def start_heartbeat(index_attempt_id: int) -> tuple[threading.Thread, threading.Event]: """Start a heartbeat thread for the given index attempt""" stop_event = threading.Event() def heartbeat_loop() -> None: while not stop_event.wait(INDEXING_WORKER_HEARTBEAT_INTERVAL): try: with get_session_with_current_tenant() as db_session: db_session.execute( update(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .values(heartbeat_counter=IndexAttempt.heartbeat_counter + 1) ) db_session.commit() except Exception: logger.exception( "Failed to update heartbeat counter for index attempt %s", index_attempt_id, ) # Ensure 
contextvars from the outer context are available in the thread context = contextvars.copy_context() thread = threading.Thread(target=context.run, args=(heartbeat_loop,), daemon=True) thread.start() return thread, stop_event def stop_heartbeat(thread: threading.Thread, stop_event: threading.Event) -> None: """Stop the heartbeat thread""" stop_event.set() thread.join(timeout=5) # Wait up to 5 seconds for clean shutdown ================================================ FILE: backend/onyx/background/celery/tasks/docprocessing/tasks.py ================================================ import gc import os import time import traceback from collections import defaultdict from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from fastapi import HTTPException from pydantic import BaseModel from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy import exists from sqlalchemy import select from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_find_task from onyx.background.celery.celery_redis import celery_get_broker_client from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.background.celery.celery_utils import httpx_init_vespa_pool from onyx.background.celery.memory_monitoring import emit_process_memory from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT from onyx.background.celery.tasks.docfetching.task_creation_utils import ( try_creating_docfetching_task, ) from onyx.background.celery.tasks.docprocessing.heartbeat import start_heartbeat from onyx.background.celery.tasks.docprocessing.heartbeat import stop_heartbeat from onyx.background.celery.tasks.docprocessing.utils import IndexingCallback from 
onyx.background.celery.tasks.docprocessing.utils import is_in_repeated_error_state from onyx.background.celery.tasks.docprocessing.utils import should_index from onyx.background.celery.tasks.models import DocProcessingContext from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint from onyx.background.indexing.checkpointing_utils import ( get_index_attempts_with_old_checkpoints, ) from onyx.background.indexing.index_attempt_utils import cleanup_index_attempts from onyx.background.indexing.index_attempt_utils import get_old_index_attempts from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import MANAGED_VESPA from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH from onyx.configs.constants import AuthType from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_INDEXING_LOCK_TIMEOUT from onyx.configs.constants import MilestoneRecordType from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisSignals from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import IndexAttemptMetadata from onyx.db.connector import mark_ccpair_with_indexing_trigger from onyx.db.connector_credential_pair import ( fetch_indexable_standard_connector_credential_pair_ids, ) from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import set_cc_pair_repeated_error_state from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.time_utils 
import get_db_current_time from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingMode from onyx.db.enums import IndexingStatus from onyx.db.enums import SwitchoverType from onyx.db.index_attempt import create_index_attempt_error from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair from onyx.db.index_attempt import IndexAttemptError from onyx.db.index_attempt import mark_attempt_canceled from onyx.db.index_attempt import mark_attempt_failed from onyx.db.index_attempt import mark_attempt_partially_succeeded from onyx.db.index_attempt import mark_attempt_succeeded from onyx.db.indexing_coordination import CoordinationStatus from onyx.db.indexing_coordination import INDEXING_PROGRESS_TIMEOUT_HOURS from onyx.db.indexing_coordination import IndexingCoordination from onyx.db.models import IndexAttempt from onyx.db.models import SearchSettings from onyx.db.search_settings import get_current_search_settings from onyx.db.search_settings import get_secondary_search_settings from onyx.db.swap_index import check_and_perform_index_swap from onyx.document_index.factory import get_all_document_indices from onyx.file_store.document_batch_storage import DocumentBatchStorage from onyx.file_store.document_batch_storage import get_document_batch_storage from onyx.httpx.httpx_pool import HttpxPool from onyx.indexing.adapters.document_indexing_adapter import ( DocumentIndexingBatchAdapter, ) from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_pipeline import run_indexing_pipeline from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.redis.redis_pool import redis_lock_dump 
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT from onyx.redis.redis_utils import is_fence from onyx.server.runtime.onyx_runtime import OnyxRuntime from onyx.utils.logger import setup_logger from onyx.utils.middleware import make_randomized_onyx_request_id from onyx.utils.telemetry import mt_cloud_telemetry from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType from shared_configs.configs import INDEXING_MODEL_SERVER_HOST from shared_configs.configs import INDEXING_MODEL_SERVER_PORT from shared_configs.configs import MULTI_TENANT from shared_configs.configs import USAGE_LIMITS_ENABLED from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR logger = setup_logger() DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER = 4 DOCPROCESSING_HEARTBEAT_TIMEOUT_MULTIPLIER = 24 # Heartbeat timeout: if no heartbeat received for 30 minutes, consider it dead # This should be much longer than INDEXING_WORKER_HEARTBEAT_INTERVAL (30s) HEARTBEAT_TIMEOUT_SECONDS = 30 * 60 # 30 minutes INDEX_ATTEMPT_BATCH_SIZE = 500 def _get_fence_validation_block_expiration() -> int: """ Compute the expiration time for the fence validation block signal. Base expiration is 60 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode. """ base_expiration = 60 # seconds if not MULTI_TENANT: return base_expiration try: beat_multiplier = OnyxRuntime.get_beat_multiplier() except Exception: beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT return int(base_expiration * beat_multiplier) def validate_active_indexing_attempts( lock_beat: RedisLock, ) -> None: """ Validates that active indexing attempts are still alive by checking heartbeat. If no heartbeat has been received for a certain amount of time, mark the attempt as failed. This uses the heartbeat_counter field which is incremented by active worker threads every INDEXING_WORKER_HEARTBEAT_INTERVAL seconds. 
""" logger.info("Validating active indexing attempts") with get_session_with_current_tenant() as db_session: # Find all active indexing attempts active_attempts = ( db_session.execute( select(IndexAttempt).where( IndexAttempt.status.in_([IndexingStatus.IN_PROGRESS]), IndexAttempt.celery_task_id.isnot(None), ) ) .scalars() .all() ) for attempt in active_attempts: lock_beat.reacquire() # Initialize timeout for each attempt to prevent state pollution heartbeat_timeout_seconds = HEARTBEAT_TIMEOUT_SECONDS # Double-check the attempt still exists and has the same status fresh_attempt = get_index_attempt(db_session, attempt.id) if not fresh_attempt or fresh_attempt.status.is_terminal(): continue # Check if this attempt has been updated with heartbeat tracking if fresh_attempt.last_heartbeat_time is None: # First time seeing this attempt - initialize heartbeat tracking fresh_attempt.last_heartbeat_value = fresh_attempt.heartbeat_counter fresh_attempt.last_heartbeat_time = datetime.now(timezone.utc) db_session.commit() task_logger.info( f"Initialized heartbeat tracking for attempt {fresh_attempt.id}: counter={fresh_attempt.heartbeat_counter}" ) continue # Check if the heartbeat counter has advanced since last check current_counter = fresh_attempt.heartbeat_counter last_known_counter = fresh_attempt.last_heartbeat_value last_check_time = fresh_attempt.last_heartbeat_time task_logger.debug( f"Checking heartbeat for attempt {fresh_attempt.id}: " f"current_counter={current_counter} " f"last_known_counter={last_known_counter} " f"last_check_time={last_check_time}" ) if current_counter > last_known_counter: # Heartbeat has advanced - worker is alive fresh_attempt.last_heartbeat_value = current_counter fresh_attempt.last_heartbeat_time = datetime.now(timezone.utc) db_session.commit() task_logger.debug( f"Heartbeat advanced for attempt {fresh_attempt.id}: new_counter={current_counter}" ) continue if fresh_attempt.total_batches and fresh_attempt.completed_batches == 0: 
heartbeat_timeout_seconds = ( HEARTBEAT_TIMEOUT_SECONDS * DOCPROCESSING_HEARTBEAT_TIMEOUT_MULTIPLIER ) cutoff_time = datetime.now(timezone.utc) - timedelta( seconds=heartbeat_timeout_seconds ) # Heartbeat hasn't advanced - check if it's been too long if last_check_time >= cutoff_time: task_logger.debug( f"Heartbeat hasn't advanced for attempt {fresh_attempt.id} but still within timeout window" ) continue # No heartbeat for too long - mark as failed failure_reason = ( f"No heartbeat received for {heartbeat_timeout_seconds} seconds" ) task_logger.warning( f"Heartbeat timeout for attempt {fresh_attempt.id}: " f"last_heartbeat_time={last_check_time} " f"cutoff_time={cutoff_time} " f"counter={current_counter}" ) try: mark_attempt_failed( fresh_attempt.id, db_session, failure_reason=failure_reason, ) task_logger.error( f"Marked attempt {fresh_attempt.id} as failed due to heartbeat timeout" ) except Exception: task_logger.exception( f"Failed to mark attempt {fresh_attempt.id} as failed due to heartbeat timeout" ) class ConnectorIndexingLogBuilder: def __init__(self, ctx: DocProcessingContext): self.ctx = ctx def build(self, msg: str, **kwargs: Any) -> str: msg_final = ( f"{msg}: " f"tenant_id={self.ctx.tenant_id} " f"attempt={self.ctx.index_attempt_id} " f"cc_pair={self.ctx.cc_pair_id} " f"search_settings={self.ctx.search_settings_id}" ) # Append extra keyword arguments in logfmt style if kwargs: extra_logfmt = " ".join(f"{key}={value}" for key, value in kwargs.items()) msg_final = f"{msg_final} {extra_logfmt}" return msg_final def monitor_indexing_attempt_progress( attempt: IndexAttempt, tenant_id: str, db_session: Session, task: Task ) -> None: """ TODO: rewrite this docstring Monitor the progress of an indexing attempt using database coordination. This replaces the Redis fence-based monitoring. 
Race condition handling: - Uses database coordination status to track progress - Only updates CC pair status based on confirmed database state - Handles concurrent completion gracefully """ if not attempt.celery_task_id: # Attempt hasn't been assigned a task yet return cc_pair = get_connector_credential_pair_from_id( db_session, attempt.connector_credential_pair_id ) if not cc_pair: task_logger.warning(f"CC pair not found for attempt {attempt.id}") return # Check if the CC Pair should be moved to INITIAL_INDEXING if cc_pair.status == ConnectorCredentialPairStatus.SCHEDULED: cc_pair.status = ConnectorCredentialPairStatus.INITIAL_INDEXING db_session.commit() # Get coordination status to track progress coordination_status = IndexingCoordination.get_coordination_status( db_session, attempt.id ) current_db_time = get_db_current_time(db_session) total_batches: int | str = ( coordination_status.total_batches if coordination_status.total_batches is not None else "?" ) if coordination_status.found: task_logger.info( f"Indexing attempt progress: " f"attempt={attempt.id} " f"cc_pair={attempt.connector_credential_pair_id} " f"search_settings={attempt.search_settings_id} " f"completed_batches={coordination_status.completed_batches} " f"total_batches={total_batches} " f"total_docs={coordination_status.total_docs} " f"total_failures={coordination_status.total_failures} " f"elapsed={int((current_db_time - attempt.time_created).total_seconds())}" ) if coordination_status.cancellation_requested: task_logger.info(f"Indexing attempt {attempt.id} has been cancelled") mark_attempt_canceled(attempt.id, db_session) return storage = get_document_batch_storage( attempt.connector_credential_pair_id, attempt.id ) # Check task completion using Celery try: check_indexing_completion( attempt.id, coordination_status, storage, tenant_id, task ) except Exception as e: logger.exception( f"Failed to monitor document processing completion: attempt={attempt.id} error={str(e)}" ) # Mark the 
attempt as failed if
monitoring fails try: with get_session_with_current_tenant() as db_session: mark_attempt_failed( attempt.id, db_session, failure_reason=f"Processing monitoring failed: {str(e)}", full_exception_trace=traceback.format_exc(), ) except Exception: logger.exception("Failed to mark attempt as failed") # Try to clean up storage try: logger.info(f"Cleaning up storage after monitoring failure: {storage}") storage.cleanup_all_batches() except Exception: logger.exception("Failed to cleanup storage after monitoring failure") def _resolve_indexing_entity_errors( cc_pair_id: int, db_session: Session, ) -> None: unresolved_errors = get_index_attempt_errors_for_cc_pair( cc_pair_id=cc_pair_id, unresolved_only=True, db_session=db_session, ) for error in unresolved_errors: if error.entity_id: error.is_resolved = True db_session.add(error) db_session.commit() def check_indexing_completion( index_attempt_id: int, coordination_status: CoordinationStatus, storage: DocumentBatchStorage, tenant_id: str, task: Task, ) -> None: logger.info( f"Checking for indexing completion: attempt={index_attempt_id} tenant={tenant_id}" ) # Check if indexing is complete and all batches are processed batches_total = coordination_status.total_batches batches_processed = coordination_status.completed_batches indexing_completed = ( batches_total is not None and batches_processed >= batches_total ) logger.info( f"Indexing status: " f"indexing_completed={indexing_completed} " f"batches_processed={batches_processed}/{batches_total if batches_total is not None else '?'} " f"total_docs={coordination_status.total_docs} " f"total_chunks={coordination_status.total_chunks} " f"total_failures={coordination_status.total_failures}" ) # Update progress tracking and check for stalls with get_session_with_current_tenant() as db_session: stalled_timeout_hours = INDEXING_PROGRESS_TIMEOUT_HOURS # Index attempts that are waiting between docfetching and # docprocessing get a generous stalling timeout if batches_total is not None 
and batches_processed == 0: stalled_timeout_hours = ( stalled_timeout_hours * DOCPROCESSING_STALL_TIMEOUT_MULTIPLIER ) timed_out = not IndexingCoordination.update_progress_tracking( db_session, index_attempt_id, batches_processed, timeout_hours=stalled_timeout_hours, ) # Check for stalls (3-6 hour timeout). Only applies to in-progress attempts. attempt = get_index_attempt(db_session, index_attempt_id) if attempt and timed_out: if attempt.status == IndexingStatus.IN_PROGRESS: logger.error( f"Indexing attempt {index_attempt_id} has been indexing for " f"{stalled_timeout_hours // 2}-{stalled_timeout_hours} hours without progress. " f"Marking it as failed." ) mark_attempt_failed( index_attempt_id, db_session, failure_reason="Stalled indexing" ) elif ( attempt.status == IndexingStatus.NOT_STARTED and attempt.celery_task_id ): # Check if the task exists in the celery queue # This handles the case where Redis dies after task creation but before task execution redis_celery = celery_get_broker_client(task.app) task_exists = celery_find_task( attempt.celery_task_id, OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery, ) unacked_task_ids = celery_get_unacked_task_ids( OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, redis_celery ) if not task_exists and attempt.celery_task_id not in unacked_task_ids: # there is a race condition where the docfetching task has been taken off # the queues (i.e. started) but the indexing attempt still has a status of # Not Started because the switch to in progress takes like 0.1 seconds. # sleep a bit and confirm that the attempt is still not in progress. time.sleep(1) attempt = get_index_attempt(db_session, index_attempt_id) if attempt and attempt.status == IndexingStatus.NOT_STARTED: logger.error( f"Task {attempt.celery_task_id} attached to indexing attempt " f"{index_attempt_id} does not exist in the queue. " f"Marking indexing attempt as failed." 
) mark_attempt_failed( index_attempt_id, db_session, failure_reason="Task not in queue", ) else: logger.info( f"Indexing attempt {index_attempt_id} is {attempt.status}. 3-6 hours without heartbeat " "but task is in the queue. Likely underprovisioned docfetching worker." ) # Update last progress time so we won't time out again for another 3 hours IndexingCoordination.update_progress_tracking( db_session, index_attempt_id, batches_processed, force_update_progress=True, ) # check again on the next check_for_indexing task # TODO: on the cloud this is currently 25 minutes at most, which # is honestly too slow. We should either increase the frequency of # this task or change where we check for completion. if not indexing_completed: return # If processing is complete, handle completion logger.info(f"Connector indexing finished for index attempt {index_attempt_id}.") # All processing is complete total_failures = coordination_status.total_failures with get_session_with_current_tenant() as db_session: if total_failures == 0: attempt = mark_attempt_succeeded(index_attempt_id, db_session) logger.info(f"Index attempt {index_attempt_id} completed successfully") else: attempt = mark_attempt_partially_succeeded(index_attempt_id, db_session) logger.info( f"Index attempt {index_attempt_id} completed with {total_failures} failures" ) # Update CC pair status if successful cc_pair = get_connector_credential_pair_from_id( db_session, attempt.connector_credential_pair_id ) if cc_pair is None: raise RuntimeError( f"CC pair {attempt.connector_credential_pair_id} not found in database" ) if attempt.status.is_successful(): # NOTE: we define the last successful index time as the time the last successful # attempt finished. This is distinct from the poll_range_end of the last successful # attempt, which is the time up to which documents have been fetched. 
cc_pair.last_successful_index_time = attempt.time_updated if cc_pair.status in [ ConnectorCredentialPairStatus.SCHEDULED, ConnectorCredentialPairStatus.INITIAL_INDEXING, ]: # User file connectors must be paused on success # NOTE: _run_indexing doesn't update connectors if the index attempt is the future embedding model cc_pair.status = ConnectorCredentialPairStatus.ACTIVE db_session.commit() mt_cloud_telemetry( tenant_id=tenant_id, distinct_id=tenant_id, event=MilestoneRecordType.CONNECTOR_SUCCEEDED, ) # Clear repeated error state on success if cc_pair.in_repeated_error_state: cc_pair.in_repeated_error_state = False db_session.commit() if attempt.status == IndexingStatus.SUCCESS: logger.info( f"Resolving indexing entity errors for attempt {index_attempt_id}" ) _resolve_indexing_entity_errors( cc_pair_id=attempt.connector_credential_pair_id, db_session=db_session, ) # Clean up FileStore storage (still needed for document batches during transition) try: logger.info(f"Cleaning up storage after indexing completion: {storage}") storage.cleanup_all_batches() except Exception: logger.exception("Failed to clean up document batches - continuing") logger.info(f"Database coordination completed for attempt {index_attempt_id}") def active_indexing_attempt( cc_pair_id: int, search_settings_id: int, db_session: Session, ) -> bool: """ Check if there's already an active indexing attempt for this CC pair + search settings. This prevents race conditions where multiple indexing attempts could be created. We check for any non-terminal status (NOT_STARTED, IN_PROGRESS). Returns True if there's an active indexing attempt, False otherwise. 
""" active_indexing_attempt = db_session.execute( select( exists().where( IndexAttempt.connector_credential_pair_id == cc_pair_id, IndexAttempt.search_settings_id == search_settings_id, IndexAttempt.status.in_( [ IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS, ] ), ) ) ).scalar() if active_indexing_attempt: task_logger.debug( f"active_indexing_attempt - Skipping due to active indexing attempt: " f"cc_pair={cc_pair_id} search_settings={search_settings_id}" ) return bool(active_indexing_attempt) def _kickoff_indexing_tasks( celery_app: Celery, db_session: Session, search_settings: SearchSettings, cc_pair_ids: list[int], secondary_index_building: bool, redis_client: Redis, lock_beat: RedisLock, tenant_id: str, ) -> int: """Kick off indexing tasks for the given cc_pair_ids and search_settings. Returns the number of tasks successfully created. """ tasks_created = 0 for cc_pair_id in cc_pair_ids: lock_beat.reacquire() # Lightweight check prior to fetching cc pair if active_indexing_attempt( cc_pair_id=cc_pair_id, search_settings_id=search_settings.id, db_session=db_session, ): continue cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if not cc_pair: task_logger.warning( f"_kickoff_indexing_tasks - CC pair not found: cc_pair={cc_pair_id}" ) continue # Heavyweight check after fetching cc pair if not should_index( cc_pair=cc_pair, search_settings_instance=search_settings, secondary_index_building=secondary_index_building, db_session=db_session, ): task_logger.debug( f"_kickoff_indexing_tasks - Not indexing cc_pair_id: {cc_pair_id} " f"search_settings={search_settings.id}, " f"secondary_index_building={secondary_index_building}" ) continue task_logger.debug( f"_kickoff_indexing_tasks - Will index cc_pair_id: {cc_pair_id} " f"search_settings={search_settings.id}, " f"secondary_index_building={secondary_index_building}" ) reindex = False # the indexing trigger is only checked and cleared with the current search settings 
if search_settings.status.is_current() and cc_pair.indexing_trigger is not None: if cc_pair.indexing_trigger == IndexingMode.REINDEX: reindex = True task_logger.info( f"_kickoff_indexing_tasks - Connector indexing manual trigger detected: " f"cc_pair={cc_pair.id} " f"search_settings={search_settings.id} " f"indexing_mode={cc_pair.indexing_trigger}" ) mark_ccpair_with_indexing_trigger(cc_pair.id, None, db_session) # using a task queue and only allowing one task per cc_pair/search_setting # prevents us from starving out certain attempts attempt_id = try_creating_docfetching_task( celery_app, cc_pair, search_settings, reindex, db_session, redis_client, tenant_id, ) if attempt_id is not None: task_logger.info( f"Connector indexing queued: index_attempt={attempt_id} cc_pair={cc_pair.id} search_settings={search_settings.id}" ) tasks_created += 1 else: task_logger.error( f"Failed to create indexing task: cc_pair={cc_pair.id} search_settings={search_settings.id}" ) return tasks_created @shared_task( name=OnyxCeleryTask.CHECK_FOR_INDEXING, soft_time_limit=300, bind=True, ) def check_for_indexing(self: Task, *, tenant_id: str) -> int | None: """A lightweight task used to kick off the pipeline of indexing tasks. Occasionally does some validation of existing state to clear up error conditions. This task is the entrypoint for the full "indexing pipeline", which is composed of two tasks: "docfetching" and "docprocessing". More details in the docfetching task (OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK). For cc pairs that should be indexed (see should_index()), this task calls try_creating_docfetching_task, which creates a docfetching task. All the logic for determining what state the indexing pipeline is in w.r.t. previous failed attempts, checkpointing, etc. is handled in the docfetching task.
""" time_start = time.monotonic() task_logger.warning("check_for_indexing - Starting") tasks_created = 0 locked = False redis_client = get_redis_client() redis_client_replica = get_redis_replica_client() # we need to use celery's redis client to access its redis data # (which lives on a different db number) # redis_client_celery: Redis = self.app.broker_connection().channel().client # type: ignore lock_beat: RedisLock = redis_client.lock( OnyxRedisLocks.CHECK_INDEXING_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): return None try: locked = True # SPECIAL 0/3: sync lookup table for active fences # we want to run this less frequently than the overall task if not redis_client.exists(OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE): # build a lookup table of existing fences # this is just a migration concern and should be unnecessary once # lookup tables are rolled out for key_bytes in redis_client_replica.scan_iter( count=SCAN_ITER_COUNT_DEFAULT ): if is_fence(key_bytes) and not redis_client.sismember( OnyxRedisConstants.ACTIVE_FENCES, key_bytes ): logger.warning(f"Adding {key_bytes} to the lookup table.") redis_client.sadd(OnyxRedisConstants.ACTIVE_FENCES, key_bytes) redis_client.set( OnyxRedisSignals.BLOCK_BUILD_FENCE_LOOKUP_TABLE, 1, ex=OnyxRuntime.get_build_fence_lookup_table_interval(), ) # 1/3: KICKOFF # check for search settings swap with get_session_with_current_tenant() as db_session: old_search_settings = check_and_perform_index_swap(db_session=db_session) current_search_settings = get_current_search_settings(db_session) # So that the first time users aren't surprised by really slow speed of first # batch of documents indexed if current_search_settings.provider_type is None and not MULTI_TENANT: if old_search_settings: embedding_model = EmbeddingModel.from_db_model( search_settings=current_search_settings, server_host=INDEXING_MODEL_SERVER_HOST, 
server_port=INDEXING_MODEL_SERVER_PORT, ) # only warm up if search settings were changed warm_up_bi_encoder( embedding_model=embedding_model, ) # gather search settings and indexable cc_pair_ids # indexable CC pairs include everything for future model and only active cc pairs for current model lock_beat.reacquire() with get_session_with_current_tenant() as db_session: # Get CC pairs for primary search settings standard_cc_pair_ids = ( fetch_indexable_standard_connector_credential_pair_ids( db_session, active_cc_pairs_only=True ) ) primary_cc_pair_ids = standard_cc_pair_ids # Get CC pairs for secondary search settings secondary_cc_pair_ids: list[int] = [] secondary_search_settings = get_secondary_search_settings(db_session) if secondary_search_settings: # For ACTIVE_ONLY, we skip paused connectors include_paused = ( secondary_search_settings.switchover_type != SwitchoverType.ACTIVE_ONLY ) standard_cc_pair_ids = ( fetch_indexable_standard_connector_credential_pair_ids( db_session, active_cc_pairs_only=not include_paused ) ) secondary_cc_pair_ids = standard_cc_pair_ids # Flag CC pairs in repeated error state for primary/current search settings with get_session_with_current_tenant() as db_session: for cc_pair_id in primary_cc_pair_ids: lock_beat.reacquire() cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) # if already in repeated error state, don't do anything # this is important so that we don't keep pausing the connector # immediately upon a user un-pausing it to manually re-trigger and # recover. 
if ( cc_pair and not cc_pair.in_repeated_error_state and is_in_repeated_error_state( cc_pair=cc_pair, search_settings_id=current_search_settings.id, db_session=db_session, ) ): set_cc_pair_repeated_error_state( db_session=db_session, cc_pair_id=cc_pair_id, in_repeated_error_state=True, ) # When entering repeated error state, also pause the connector # to prevent continued indexing retry attempts burning through embedding credits. # NOTE: only for Cloud, since most self-hosted users use self-hosted embedding # models. Also, they are more prone to repeated failures -> eventual success. if AUTH_TYPE == AuthType.CLOUD: update_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair.id, status=ConnectorCredentialPairStatus.PAUSED, ) # NOTE: At this point, we haven't done heavy checks on whether or not the CC pairs should actually be indexed # Heavy check, should_index(), is called in _kickoff_indexing_tasks with get_session_with_current_tenant() as db_session: # Primary first tasks_created += _kickoff_indexing_tasks( celery_app=self.app, db_session=db_session, search_settings=current_search_settings, cc_pair_ids=primary_cc_pair_ids, secondary_index_building=secondary_search_settings is not None, redis_client=redis_client, lock_beat=lock_beat, tenant_id=tenant_id, ) # Secondary indexing (only if secondary search settings exist and switchover_type is not INSTANT) if ( secondary_search_settings and secondary_search_settings.switchover_type != SwitchoverType.INSTANT and secondary_cc_pair_ids ): tasks_created += _kickoff_indexing_tasks( celery_app=self.app, db_session=db_session, search_settings=secondary_search_settings, cc_pair_ids=secondary_cc_pair_ids, secondary_index_building=True, redis_client=redis_client, lock_beat=lock_beat, tenant_id=tenant_id, ) elif ( secondary_search_settings and secondary_search_settings.switchover_type == SwitchoverType.INSTANT ): task_logger.info( f"Skipping secondary indexing: switchover_type=INSTANT for 
search_settings={secondary_search_settings.id}" ) # 2/3: VALIDATE # Check for inconsistent index attempts - active attempts without task IDs # This can happen if attempt creation fails partway through lock_beat.reacquire() with get_session_with_current_tenant() as db_session: inconsistent_attempts = ( db_session.execute( select(IndexAttempt).where( IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ), IndexAttempt.celery_task_id.is_(None), ) ) .scalars() .all() ) for attempt in inconsistent_attempts: lock_beat.reacquire() # Double-check the attempt still has the inconsistent state fresh_attempt = get_index_attempt(db_session, attempt.id) if ( not fresh_attempt or fresh_attempt.celery_task_id or fresh_attempt.status.is_terminal() ): continue failure_reason = ( f"Inconsistent index attempt found - active status without Celery task: " f"index_attempt={attempt.id} " f"cc_pair={attempt.connector_credential_pair_id} " f"search_settings={attempt.search_settings_id}" ) task_logger.error(failure_reason) mark_attempt_failed( attempt.id, db_session, failure_reason=failure_reason ) lock_beat.reacquire() # we want to run this less frequently than the overall task if not redis_client.exists(OnyxRedisSignals.BLOCK_VALIDATE_INDEXING_FENCES): # Check for orphaned index attempts that have Celery task IDs but no actual running tasks # This can happen if workers crash or tasks are terminated unexpectedly # We reuse the same Redis signal name for backwards compatibility try: validate_active_indexing_attempts(lock_beat) except Exception: task_logger.exception( "Exception while validating active indexing attempts" ) redis_client.set( OnyxRedisSignals.BLOCK_VALIDATE_INDEXING_FENCES, 1, ex=_get_fence_validation_block_expiration(), ) # 3/3: FINALIZE - Monitor active indexing attempts using database lock_beat.reacquire() with get_session_with_current_tenant() as db_session: # Monitor all active indexing attempts directly from the database # This replaces the 
Redis fence-based monitoring active_attempts = ( db_session.execute( select(IndexAttempt).where( IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ) ) ) .scalars() .all() ) for attempt in active_attempts: try: monitor_indexing_attempt_progress( attempt, tenant_id, db_session, self ) except Exception: task_logger.exception(f"Error monitoring attempt {attempt.id}") lock_beat.reacquire() except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) except Exception: task_logger.exception("Unexpected exception during indexing check") finally: if locked: if lock_beat.owned(): lock_beat.release() else: task_logger.error( f"check_for_indexing - Lock not owned on completion: tenant={tenant_id}" ) redis_lock_dump(lock_beat, redis_client) time_elapsed = time.monotonic() - time_start task_logger.info(f"check_for_indexing finished: elapsed={time_elapsed:.2f}") return tasks_created # primary @shared_task( name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP, soft_time_limit=300, bind=True, ) def check_for_checkpoint_cleanup(self: Task, *, tenant_id: str) -> None: """Clean up old checkpoints that are older than 7 days.""" locked = False redis_client = get_redis_client(tenant_id=tenant_id) lock: RedisLock = redis_client.lock( OnyxRedisLocks.CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock.acquire(blocking=False): return None try: locked = True with get_session_with_current_tenant() as db_session: old_attempts = get_index_attempts_with_old_checkpoints(db_session) for attempt in old_attempts: task_logger.info( f"Cleaning up checkpoint for index attempt {attempt.id}" ) self.app.send_task( OnyxCeleryTask.CLEANUP_CHECKPOINT, kwargs={ "index_attempt_id": attempt.id, "tenant_id": tenant_id, }, queue=OnyxCeleryQueues.CHECKPOINT_CLEANUP, priority=OnyxCeleryPriority.MEDIUM, ) except Exception: task_logger.exception("Unexpected 
exception during checkpoint cleanup") return None finally: if locked: if lock.owned(): lock.release() else: task_logger.error( f"check_for_checkpoint_cleanup - Lock not owned on completion: tenant={tenant_id}" ) # light worker @shared_task( name=OnyxCeleryTask.CLEANUP_CHECKPOINT, bind=True, ) def cleanup_checkpoint_task( self: Task, # noqa: ARG001 *, index_attempt_id: int, tenant_id: str | None, ) -> None: """Clean up a checkpoint for a given index attempt""" start = time.monotonic() try: with get_session_with_current_tenant() as db_session: cleanup_checkpoint(db_session, index_attempt_id) finally: elapsed = time.monotonic() - start task_logger.info( f"cleanup_checkpoint_task completed: tenant_id={tenant_id} index_attempt_id={index_attempt_id} elapsed={elapsed:.2f}" ) # primary @shared_task( name=OnyxCeleryTask.CHECK_FOR_INDEX_ATTEMPT_CLEANUP, soft_time_limit=300, bind=True, ) def check_for_index_attempt_cleanup(self: Task, *, tenant_id: str) -> None: """Clean up old index attempts that are older than 7 days.""" locked = False redis_client = get_redis_client(tenant_id=tenant_id) lock: RedisLock = redis_client.lock( OnyxRedisLocks.CHECK_INDEX_ATTEMPT_CLEANUP_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock.acquire(blocking=False): task_logger.info( f"check_for_index_attempt_cleanup - Lock not acquired: tenant={tenant_id}" ) return None try: locked = True batch_size = INDEX_ATTEMPT_BATCH_SIZE with get_session_with_current_tenant() as db_session: old_attempts = get_old_index_attempts(db_session) # We need to batch this because during the initial run, the system might have a large number # of index attempts since they were never deleted. After that, the number will be # significantly lower. 
if len(old_attempts) == 0: task_logger.info( "check_for_index_attempt_cleanup - No index attempts to clean up" ) return for i in range(0, len(old_attempts), batch_size): batch = old_attempts[i : i + batch_size] task_logger.info( f"check_for_index_attempt_cleanup - Dispatching cleanup for {len(batch)} index attempts" ) self.app.send_task( OnyxCeleryTask.CLEANUP_INDEX_ATTEMPT, kwargs={ "index_attempt_ids": [attempt.id for attempt in batch], "tenant_id": tenant_id, }, queue=OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, priority=OnyxCeleryPriority.MEDIUM, ) except Exception: task_logger.exception("Unexpected exception during index attempt cleanup check") return None finally: if locked: if lock.owned(): lock.release() else: task_logger.error( f"check_for_index_attempt_cleanup - Lock not owned on completion: tenant={tenant_id}" ) # light worker @shared_task( name=OnyxCeleryTask.CLEANUP_INDEX_ATTEMPT, bind=True, ) def cleanup_index_attempt_task( self: Task, # noqa: ARG001 *, index_attempt_ids: list[int], tenant_id: str, ) -> None: """Clean up a batch of index attempts""" start = time.monotonic() try: with get_session_with_current_tenant() as db_session: cleanup_index_attempts(db_session, index_attempt_ids) finally: elapsed = time.monotonic() - start task_logger.info( f"cleanup_index_attempt_task completed: tenant_id={tenant_id} " f"index_attempt_ids={index_attempt_ids} " f"elapsed={elapsed:.2f}" ) class DocumentProcessingBatch(BaseModel): """Data structure for a document processing batch.""" batch_id: str index_attempt_id: int cc_pair_id: int tenant_id: str batch_num: int def _check_failure_threshold( total_failures: int, document_count: int, batch_num: int, last_failure: ConnectorFailure | None, ) -> None: """Check if we've hit the failure threshold and raise an appropriate exception if so. We consider the threshold hit if: 1. We have more than 3 failures AND 2.
Failures account for more than 10% of processed documents """ failure_ratio = total_failures / (document_count or 1) FAILURE_THRESHOLD = 3 FAILURE_RATIO_THRESHOLD = 0.1 if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD: logger.error( f"Connector run failed with '{total_failures}' errors after '{batch_num}' batches." ) if last_failure and last_failure.exception: raise last_failure.exception from last_failure.exception raise RuntimeError( f"Connector run encountered too many errors, aborting. Last error: {last_failure}" ) def _resolve_indexing_document_errors( cc_pair_id: int, failures: list[ConnectorFailure], document_batch: list[Document], ) -> None: with get_session_with_current_tenant() as db_session_temp: # get previously unresolved errors unresolved_errors = get_index_attempt_errors_for_cc_pair( cc_pair_id=cc_pair_id, unresolved_only=True, db_session=db_session_temp, ) doc_id_to_unresolved_errors: dict[str, list[IndexAttemptError]] = defaultdict( list ) for error in unresolved_errors: if error.document_id: doc_id_to_unresolved_errors[error.document_id].append(error) # resolve errors for documents that were successfully indexed failed_document_ids = [ failure.failed_document.document_id for failure in failures if failure.failed_document ] successful_document_ids = [ document.id for document in document_batch if document.id not in failed_document_ids ] for document_id in successful_document_ids: if document_id not in doc_id_to_unresolved_errors: continue logger.info(f"Resolving IndexAttemptError for document '{document_id}'") for error in doc_id_to_unresolved_errors[document_id]: error.is_resolved = True db_session_temp.add(error) db_session_temp.commit() @shared_task( name=OnyxCeleryTask.DOCPROCESSING_TASK, bind=True, ) def docprocessing_task( self: Task, # noqa: ARG001 index_attempt_id: int, cc_pair_id: int, tenant_id: str, batch_num: int, ) -> None: """Process a batch of documents through the indexing pipeline. 
This task retrieves documents from storage and processes them through the indexing pipeline (embedding + vector store indexing). """ # Start heartbeat for this indexing attempt heartbeat_thread, stop_event = start_heartbeat(index_attempt_id) try: # Cannot use the TaskSingleton approach here because the worker is multithreaded token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set((cc_pair_id, index_attempt_id)) _docprocessing_task(index_attempt_id, cc_pair_id, tenant_id, batch_num) finally: stop_heartbeat(heartbeat_thread, stop_event) # Stop heartbeat before exiting INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token) def _check_chunk_usage_limit(tenant_id: str) -> None: """Check if chunk indexing usage limit has been exceeded. Raises UsageLimitExceededError if the limit is exceeded. """ if not USAGE_LIMITS_ENABLED: return from onyx.db.usage import UsageType from onyx.server.usage_limits import check_usage_and_raise with get_session_with_current_tenant() as db_session: check_usage_and_raise( db_session=db_session, usage_type=UsageType.CHUNKS_INDEXED, tenant_id=tenant_id, pending_amount=0, # Just check current usage ) def _docprocessing_task( index_attempt_id: int, cc_pair_id: int, tenant_id: str, batch_num: int, ) -> None: start_time = time.monotonic() if tenant_id: CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) # Check if chunk indexing usage limit has been exceeded before processing if USAGE_LIMITS_ENABLED: try: _check_chunk_usage_limit(tenant_id) except HTTPException as e: # Log the error and fail the indexing attempt task_logger.error( f"Chunk indexing usage limit exceeded for tenant {tenant_id}: {e}" ) with get_session_with_current_tenant() as db_session: from onyx.db.index_attempt import mark_attempt_failed mark_attempt_failed( index_attempt_id=index_attempt_id, db_session=db_session, failure_reason=str(e), ) raise task_logger.info( f"Processing document batch: attempt={index_attempt_id} batch_num={batch_num} " ) # Get the document batch storage storage = 
get_document_batch_storage(cc_pair_id, index_attempt_id) redis_connector = RedisConnector(tenant_id, cc_pair_id) r = get_redis_client(tenant_id=tenant_id) # 20 is the documented default for httpx max_keepalive_connections if MANAGED_VESPA: httpx_init_vespa_pool( 20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH ) else: httpx_init_vespa_pool(20) # dummy lock to satisfy linter per_batch_lock: RedisLock | None = None try: # FIX: Monitor memory before loading documents to track problematic batches emit_process_memory( os.getpid(), "docprocessing", { "phase": "before_load", "tenant_id": tenant_id, "cc_pair_id": cc_pair_id, "index_attempt_id": index_attempt_id, "batch_num": batch_num, }, ) # Retrieve documents from storage documents = storage.get_batch(batch_num) if not documents: task_logger.error(f"No documents found for batch {batch_num}") return # FIX: Monitor memory after loading documents emit_process_memory( os.getpid(), "docprocessing", { "phase": "after_load", "tenant_id": tenant_id, "cc_pair_id": cc_pair_id, "index_attempt_id": index_attempt_id, "batch_num": batch_num, "doc_count": len(documents), }, ) with get_session_with_current_tenant() as db_session: # matches parts of _run_indexing index_attempt = get_index_attempt( db_session, index_attempt_id, eager_load_cc_pair=True, eager_load_search_settings=True, ) if not index_attempt: raise RuntimeError(f"Index attempt {index_attempt_id} not found") if index_attempt.search_settings is None: raise ValueError("Search settings must be set for indexing") if ( index_attempt.celery_task_id is None or index_attempt.status.is_terminal() ): raise RuntimeError( f"Index attempt {index_attempt_id} is not running, status {index_attempt.status}" ) cross_batch_db_lock: RedisLock = r.lock( redis_connector.db_lock_key(index_attempt.search_settings.id), timeout=CELERY_INDEXING_LOCK_TIMEOUT, thread_local=False, ) callback = IndexingCallback( redis_connector, ) # TODO: right now this is the only thing the callback is 
used for, # probably there is a simpler way to handle pausing if callback.should_stop(): raise RuntimeError("Docprocessing cancelled by connector pausing") # Set up indexing pipeline components embedding_model = DefaultIndexingEmbedder.from_db_search_settings( search_settings=index_attempt.search_settings, callback=callback, ) document_indices = get_all_document_indices( index_attempt.search_settings, None, httpx_client=HttpxPool.get("vespa"), ) # Set up metadata for this batch index_attempt_metadata = IndexAttemptMetadata( attempt_id=index_attempt_id, connector_id=index_attempt.connector_credential_pair.connector.id, credential_id=index_attempt.connector_credential_pair.credential.id, request_id=make_randomized_onyx_request_id("DIP"), structured_id=f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}", batch_num=batch_num, ) # Process documents through indexing pipeline connector_source = ( index_attempt.connector_credential_pair.connector.source.value ) task_logger.info( f"Processing {len(documents)} documents through indexing pipeline: " f"cc_pair_id={cc_pair_id}, source={connector_source}, " f"batch_num={batch_num}" ) adapter = DocumentIndexingBatchAdapter( db_session=db_session, connector_id=index_attempt.connector_credential_pair.connector.id, credential_id=index_attempt.connector_credential_pair.credential.id, tenant_id=tenant_id, index_attempt_metadata=index_attempt_metadata, ) # real work happens here! 
index_pipeline_result = run_indexing_pipeline( embedder=embedding_model, document_indices=document_indices, ignore_time_skip=True, # Documents are already filtered during extraction db_session=db_session, tenant_id=tenant_id, document_batch=documents, request_id=index_attempt_metadata.request_id, adapter=adapter, ) # Track chunk indexing usage for cloud usage limits if USAGE_LIMITS_ENABLED and index_pipeline_result.total_chunks > 0: try: from onyx.db.usage import increment_usage from onyx.db.usage import UsageType with get_session_with_current_tenant() as usage_db_session: increment_usage( db_session=usage_db_session, usage_type=UsageType.CHUNKS_INDEXED, amount=index_pipeline_result.total_chunks, ) usage_db_session.commit() except Exception as e: # Log but don't fail indexing if usage tracking fails task_logger.warning(f"Failed to track chunk indexing usage: {e}") # Update batch completion and document counts atomically using database coordination with get_session_with_current_tenant() as db_session, cross_batch_db_lock: IndexingCoordination.update_batch_completion_and_docs( db_session=db_session, index_attempt_id=index_attempt_id, total_docs_indexed=index_pipeline_result.total_docs, new_docs_indexed=index_pipeline_result.new_docs, total_chunks=index_pipeline_result.total_chunks, ) _resolve_indexing_document_errors( cc_pair_id, index_pipeline_result.failures, documents, ) coordination_status = None # Record failures in the database if index_pipeline_result.failures: with get_session_with_current_tenant() as db_session: for failure in index_pipeline_result.failures: create_index_attempt_error( index_attempt_id, cc_pair_id, failure, db_session, ) # Use database state instead of FileStore for failure checking with get_session_with_current_tenant() as db_session: coordination_status = IndexingCoordination.get_coordination_status( db_session, index_attempt_id ) _check_failure_threshold( coordination_status.total_failures, coordination_status.total_docs, batch_num, 
index_pipeline_result.failures[-1], ) # Add telemetry for indexing progress using database coordination status # only re-fetch coordination status if necessary if coordination_status is None: with get_session_with_current_tenant() as db_session: coordination_status = IndexingCoordination.get_coordination_status( db_session, index_attempt_id ) optional_telemetry( record_type=RecordType.INDEXING_PROGRESS, data={ "index_attempt_id": index_attempt_id, "cc_pair_id": cc_pair_id, "current_docs_indexed": coordination_status.total_docs, "current_chunks_indexed": coordination_status.total_chunks, "source": index_attempt.connector_credential_pair.connector.source.value, "completed_batches": coordination_status.completed_batches, "total_batches": coordination_status.total_batches, }, tenant_id=tenant_id, ) # Clean up this batch after successful processing storage.delete_batch_by_num(batch_num) # FIX: Explicitly clear document batch from memory and force garbage collection # This helps prevent memory accumulation across multiple batches # NOTE: Thread-local event loops in embedding threads are cleaned up automatically # via the _cleanup_thread_local decorator in search_nlp_models.py del documents gc.collect() # FIX: Log final memory usage to track problematic tenants/CC pairs emit_process_memory( os.getpid(), "docprocessing", { "phase": "after_processing", "tenant_id": tenant_id, "cc_pair_id": cc_pair_id, "index_attempt_id": index_attempt_id, "batch_num": batch_num, "chunks_processed": index_pipeline_result.total_chunks, }, ) elapsed_time = time.monotonic() - start_time task_logger.info( f"Completed document batch processing: " f"index_attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"search_settings={index_attempt.search_settings.id} " f"batch_num={batch_num} " f"docs={len(index_pipeline_result.failures) + index_pipeline_result.total_docs} " f"chunks={index_pipeline_result.total_chunks} " f"failures={len(index_pipeline_result.failures)} " f"elapsed={elapsed_time:.2f}s" ) 
except Exception: task_logger.exception( f"Document batch processing failed: batch_num={batch_num} attempt={index_attempt_id} " ) raise finally: if per_batch_lock and per_batch_lock.owned(): per_batch_lock.release() ================================================ FILE: backend/onyx/background/celery/tasks/docprocessing/utils.py ================================================ import time from datetime import datetime from datetime import timezone from redis import Redis from redis.exceptions import LockError from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.configs.app_configs import DISABLE_INDEX_UPDATE_ON_SWAP from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import DocumentSource from onyx.db.engine.time_utils import get_db_current_time from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import IndexModelStatus from onyx.db.index_attempt import get_last_attempt_for_cc_pair from onyx.db.index_attempt import get_recent_attempts_for_cc_pair from onyx.db.models import ConnectorCredentialPair from onyx.db.models import SearchSettings from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_pool import redis_lock_dump from onyx.utils.logger import setup_logger logger = setup_logger() NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE = 5 class IndexingCallbackBase(IndexingHeartbeatInterface): PARENT_CHECK_INTERVAL = 60 def __init__( self, parent_pid: int, redis_connector: RedisConnector, redis_lock: RedisLock, redis_client: Redis, timeout_seconds: int | None = None, ): super().__init__() self.parent_pid = parent_pid self.redis_connector: RedisConnector = redis_connector self.redis_lock: RedisLock = redis_lock self.redis_client = redis_client self.started: datetime = datetime.now(timezone.utc) self.redis_lock.reacquire() self.last_tag: str 
= f"{self.__class__.__name__}.__init__" self.last_lock_reacquire: datetime = datetime.now(timezone.utc) self.last_lock_monotonic = time.monotonic() self.last_parent_check = time.monotonic() self.start_monotonic = time.monotonic() self.timeout_seconds = timeout_seconds def should_stop(self) -> bool: # Check if the associated indexing attempt has been cancelled # TODO: Pass index_attempt_id to the callback and check cancellation using the db if bool(self.redis_connector.stop.fenced): return True # Check if the task has exceeded its timeout # NOTE: Celery's soft_time_limit does not work with thread pools, # so we must enforce timeouts internally. if self.timeout_seconds is not None: elapsed = time.monotonic() - self.start_monotonic if elapsed > self.timeout_seconds: logger.warning( f"IndexingCallback Docprocessing - task timeout exceeded: " f"elapsed={elapsed:.0f}s timeout={self.timeout_seconds}s " f"cc_pair={self.redis_connector.cc_pair_id}" ) return True return False def progress(self, tag: str, amount: int) -> None: # noqa: ARG002 """Amount isn't used yet.""" # rkuo: this shouldn't be necessary yet because we spawn the process this runs inside # with daemon=True. It seems likely some indexing tasks will need to spawn other processes # eventually, which daemon=True prevents, so leave this code in until we're ready to test it. 
# if self.parent_pid: # # check if the parent pid is alive so we aren't running as a zombie # now = time.monotonic() # if now - self.last_parent_check > self.PARENT_CHECK_INTERVAL: # try: # # this is unintuitive, but it checks if the parent pid is still running # # (re-enabling this also requires `import os` in this module) # os.kill(self.parent_pid, 0) # except Exception: # logger.exception("IndexingCallback - parent pid check exceptioned") # raise # self.last_parent_check = now try: current_time = time.monotonic() if current_time - self.last_lock_monotonic >= ( CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4 ): self.redis_lock.reacquire() self.last_lock_reacquire = datetime.now(timezone.utc) self.last_lock_monotonic = time.monotonic() self.last_tag = tag except LockError: logger.exception( f"{self.__class__.__name__} - lock.reacquire exceptioned: " f"lock_timeout={self.redis_lock.timeout} " f"start={self.started} " f"last_tag={self.last_tag} " f"last_reacquired={self.last_lock_reacquire} " f"now={datetime.now(timezone.utc)}" ) redis_lock_dump(self.redis_lock, self.redis_client) raise # NOTE: we're in the process of removing all fences from indexing; this will # eventually no longer be used. For now, it is used only for connector pausing. class IndexingCallback(IndexingHeartbeatInterface): def __init__( self, redis_connector: RedisConnector, ): self.redis_connector = redis_connector def should_stop(self) -> bool: # Check if the associated indexing attempt has been cancelled # TODO: Pass index_attempt_id to the callback and check cancellation using the db return bool(self.redis_connector.stop.fenced) # included to satisfy old interface def progress(self, tag: str, amount: int) -> None: pass # NOTE: The validate_indexing_fence and validate_indexing_fences functions have been removed # as they are no longer needed with database-based coordination. The new validation is # handled by validate_active_indexing_attempts in the main indexing tasks module.
def is_in_repeated_error_state( cc_pair: ConnectorCredentialPair, search_settings_id: int, db_session: Session ) -> bool: """Checks if the cc pair / search setting combination is in a repeated error state.""" # if the connector doesn't have a refresh_freq, a single failed attempt is enough number_of_failed_attempts_in_a_row_needed = ( NUM_REPEAT_ERRORS_BEFORE_REPEATED_ERROR_STATE if cc_pair.connector.refresh_freq is not None else 1 ) most_recent_index_attempts = get_recent_attempts_for_cc_pair( cc_pair_id=cc_pair.id, search_settings_id=search_settings_id, limit=number_of_failed_attempts_in_a_row_needed, db_session=db_session, ) return len( most_recent_index_attempts ) >= number_of_failed_attempts_in_a_row_needed and all( attempt.status == IndexingStatus.FAILED for attempt in most_recent_index_attempts ) def should_index( cc_pair: ConnectorCredentialPair, search_settings_instance: SearchSettings, secondary_index_building: bool, db_session: Session, ) -> bool: """Checks various global settings and past indexing attempts to determine if we should try to start indexing the cc pair / search setting combination. Note that tactical checks such as preventing overlap with a currently running task are not handled here. Return True if we should try to index, False if not. 
""" connector = cc_pair.connector last_index_attempt = get_last_attempt_for_cc_pair( cc_pair_id=cc_pair.id, search_settings_id=search_settings_instance.id, db_session=db_session, ) all_recent_errored = is_in_repeated_error_state( cc_pair=cc_pair, search_settings_id=search_settings_instance.id, db_session=db_session, ) # uncomment for debugging # task_logger.debug( # f"_should_index: " # f"cc_pair={cc_pair.id} " # f"connector={cc_pair.connector_id} " # f"refresh_freq={connector.refresh_freq}" # ) # don't kick off indexing for `NOT_APPLICABLE` sources if connector.source == DocumentSource.NOT_APPLICABLE: # print(f"Not indexing cc_pair={cc_pair.id}: NOT_APPLICABLE source") return False # User can still manually create single indexing attempts via the UI for the # currently in use index if DISABLE_INDEX_UPDATE_ON_SWAP: if ( search_settings_instance.status == IndexModelStatus.PRESENT and secondary_index_building ): # print( # f"Not indexing cc_pair={cc_pair.id}: DISABLE_INDEX_UPDATE_ON_SWAP is True and secondary index building" # ) return False # When switching over models, always index at least once if search_settings_instance.status == IndexModelStatus.FUTURE: if last_index_attempt: # No new index if the last index attempt succeeded # Once is enough. The model will never be able to swap otherwise. 
if last_index_attempt.status == IndexingStatus.SUCCESS: # print( # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with successful last index attempt={last_index_attempt.id}" # ) return False # No new index if the last index attempt is waiting to start if last_index_attempt.status == IndexingStatus.NOT_STARTED: # print( # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with NOT_STARTED last index attempt={last_index_attempt.id}" # ) return False # No new index if the last index attempt is running if last_index_attempt.status == IndexingStatus.IN_PROGRESS: # print( # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with IN_PROGRESS last index attempt={last_index_attempt.id}" # ) return False else: if ( connector.id == 0 or connector.source == DocumentSource.INGESTION_API ): # Ingestion API # print( # f"Not indexing cc_pair={cc_pair.id}: FUTURE model with Ingestion API source" # ) return False return True # If the connector is paused or is the ingestion API, don't index # NOTE: during an embedding model switchover, the following logic # is bypassed by the above check for a future model if ( not cc_pair.status.is_active() or connector.id == 0 or connector.source == DocumentSource.INGESTION_API ): # print( # f"Not indexing cc_pair={cc_pair.id}: Connector is paused or is Ingestion API" # ) return False if search_settings_instance.status.is_current(): if cc_pair.indexing_trigger is not None: # if a manual indexing trigger is on the cc pair, honor it for live search settings return True # if no attempt has ever occurred, we should index regardless of refresh_freq if not last_index_attempt: return True if connector.refresh_freq is None: # print(f"Not indexing cc_pair={cc_pair.id}: refresh_freq is None") return False # if in the "initial" phase, we should always try to kick off indexing # as soon as possible if there is no ongoing attempt. In other words, # no delay UNLESS we're repeatedly failing to index.
    if (
        cc_pair.status == ConnectorCredentialPairStatus.INITIAL_INDEXING
        and not all_recent_errored
    ):
        return True

    current_db_time = get_db_current_time(db_session)
    time_since_index = current_db_time - last_index_attempt.time_updated
    if time_since_index.total_seconds() < connector.refresh_freq:
        # print(
        #     f"Not indexing cc_pair={cc_pair.id}: Last index attempt={last_index_attempt.id} "
        #     f"too recent ({time_since_index.total_seconds()}s < {connector.refresh_freq}s)"
        # )
        return False

    return True


================================================
FILE: backend/onyx/background/celery/tasks/evals/__init__.py
================================================


================================================
FILE: backend/onyx/background/celery/tasks/evals/tasks.py
================================================
from datetime import datetime
from datetime import timezone
from typing import Any

from celery import shared_task
from celery import Task

from onyx.configs.app_configs import BRAINTRUST_API_KEY
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
from onyx.configs.constants import OnyxCeleryTask
from onyx.evals.eval import run_eval
from onyx.evals.models import EvalConfigurationOptions
from onyx.utils.logger import setup_logger

logger = setup_logger()


@shared_task(
    name=OnyxCeleryTask.EVAL_RUN_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    bind=True,
    trail=False,
)
def eval_run_task(
    self: Task,  # noqa: ARG001
    *,
    configuration_dict: dict[str, Any],
) -> None:
    """Background task to run an evaluation with the given configuration"""
    try:
        configuration = EvalConfigurationOptions.model_validate(configuration_dict)
        run_eval(configuration, remote_dataset_name=configuration.dataset_name)
        logger.info("Successfully completed eval run task")

    except Exception:
        logger.error("Failed to run eval task")
        raise


@shared_task(
    name=OnyxCeleryTask.SCHEDULED_EVAL_TASK,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT * 5,  # Allow more time for multiple datasets
    bind=True,
    trail=False,
)
def scheduled_eval_task(self: Task, **kwargs: Any) -> None:  # noqa: ARG001
    """
    Scheduled task to run evaluations on configured datasets.
    Runs weekly on Sunday at midnight UTC.

    Configure via environment variables (with defaults):
    - SCHEDULED_EVAL_DATASET_NAMES: Comma-separated list of Braintrust dataset names
    - SCHEDULED_EVAL_PERMISSIONS_EMAIL: Email for search permissions (default: roshan@onyx.app)
    - SCHEDULED_EVAL_PROJECT: Braintrust project name
    """
    if not BRAINTRUST_API_KEY:
        logger.error("BRAINTRUST_API_KEY is not configured, cannot run scheduled evals")
        return

    if not SCHEDULED_EVAL_PROJECT:
        logger.error(
            "SCHEDULED_EVAL_PROJECT is not configured, cannot run scheduled evals"
        )
        return

    if not SCHEDULED_EVAL_DATASET_NAMES:
        logger.info("No scheduled eval datasets configured, skipping")
        return

    if not SCHEDULED_EVAL_PERMISSIONS_EMAIL:
        logger.error("SCHEDULED_EVAL_PERMISSIONS_EMAIL not configured")
        return

    project_name = SCHEDULED_EVAL_PROJECT
    dataset_names = SCHEDULED_EVAL_DATASET_NAMES
    permissions_email = SCHEDULED_EVAL_PERMISSIONS_EMAIL

    # Create a timestamp for the scheduled run
    run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    logger.info(
        f"Starting scheduled eval pipeline for project '{project_name}' with {len(dataset_names)} dataset(s): {dataset_names}"
    )

    pipeline_start = datetime.now(timezone.utc)
    results: list[dict[str, Any]] = []

    for dataset_name in dataset_names:
        start_time = datetime.now(timezone.utc)
        error_message: str | None = None
        success = False

        # Create informative experiment name for scheduled runs
        experiment_name = f"{dataset_name} - {run_timestamp}"

        try:
            logger.info(
                f"Running scheduled eval for dataset: {dataset_name} (project: {project_name})"
            )

            configuration = EvalConfigurationOptions(
                search_permissions_email=permissions_email,
                dataset_name=dataset_name,
                no_send_logs=False,
                braintrust_project=project_name,
                experiment_name=experiment_name,
            )

            result = run_eval(
                configuration=configuration,
                remote_dataset_name=dataset_name,
            )
            success = result.success
            logger.info(f"Completed eval for {dataset_name}: success={success}")

        except Exception as e:
            logger.exception(f"Failed to run scheduled eval for {dataset_name}")
            error_message = str(e)
            success = False

        end_time = datetime.now(timezone.utc)

        results.append(
            {
                "dataset_name": dataset_name,
                "success": success,
                "start_time": start_time,
                "end_time": end_time,
                "error_message": error_message,
            }
        )

    pipeline_end = datetime.now(timezone.utc)
    total_duration = (pipeline_end - pipeline_start).total_seconds()

    passed_count = sum(1 for r in results if r["success"])
    logger.info(
        f"Scheduled eval pipeline completed: {passed_count}/{len(results)} passed in {total_duration:.1f}s"
    )


================================================
FILE: backend/onyx/background/celery/tasks/hierarchyfetching/__init__.py
================================================


================================================
FILE: backend/onyx/background/celery/tasks/hierarchyfetching/tasks.py
================================================
"""Celery tasks for hierarchy fetching.

This module provides tasks for fetching hierarchy node information from connectors.
Hierarchy nodes represent structural elements like folders, spaces, and pages that
can be used to filter search results.

The hierarchy fetching pipeline runs once per day per connector and fetches
structural information from the connector source.
"""

import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from uuid import uuid4

from celery import Celery
from celery import shared_task
from celery import Task
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.connectors.factory import ConnectorMissingException
from onyx.connectors.factory import identify_connector_class
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.interfaces import HierarchyConnector
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
from onyx.db.connector import mark_cc_pair_as_hierarchy_fetched
from onyx.db.connector_credential_pair import (
    fetch_indexable_standard_connector_credential_pair_ids,
)
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries
from onyx.db.hierarchy import upsert_hierarchy_nodes_batch
from onyx.db.models import ConnectorCredentialPair
from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch
from onyx.redis.redis_hierarchy import ensure_source_node_exists
from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger

logger = setup_logger()

# Hierarchy fetching runs once per day (24 hours in seconds)
HIERARCHY_FETCH_INTERVAL_SECONDS = 24 * 60 * 60


def _connector_supports_hierarchy_fetching(
    cc_pair: ConnectorCredentialPair,
) -> bool:
    """Return True only for connectors whose class implements HierarchyConnector."""
    try:
        connector_class = identify_connector_class(
            cc_pair.connector.source,
        )
    except ConnectorMissingException as e:
        task_logger.warning(
            "Skipping hierarchy fetching enqueue for source=%s input_type=%s: %s",
            cc_pair.connector.source,
            cc_pair.connector.input_type,
            str(e),
        )
        return False

    return issubclass(connector_class, HierarchyConnector)


def _is_hierarchy_fetching_due(cc_pair: ConnectorCredentialPair) -> bool:
    """Returns boolean indicating if hierarchy fetching is due for this connector.

    Hierarchy fetching should run once per day for active connectors.
    """
    # Skip if not active
    if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE:
        return False

    # Skip if connector has never successfully indexed
    if not cc_pair.last_successful_index_time:
        return False

    # Check if we've fetched hierarchy recently
    last_fetch = cc_pair.last_time_hierarchy_fetch
    if last_fetch is None:
        # Never fetched before - fetch now
        return True

    # Check if enough time has passed since last fetch
    next_fetch_time = last_fetch + timedelta(seconds=HIERARCHY_FETCH_INTERVAL_SECONDS)
    return datetime.now(timezone.utc) >= next_fetch_time


def _try_creating_hierarchy_fetching_task(
    celery_app: Celery,
    cc_pair: ConnectorCredentialPair,
    db_session: Session,
    r: Redis,
    tenant_id: str,
) -> str | None:
    """Try to create a hierarchy fetching task for a connector.

    Returns the task ID if created, None otherwise.
    """
    LOCK_TIMEOUT = 30

    # Serialize task creation attempts
    lock: RedisLock = r.lock(
        DANSWER_REDIS_FUNCTION_LOCK_PREFIX + f"hierarchy_fetching_{cc_pair.id}",
        timeout=LOCK_TIMEOUT,
    )

    acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2)
    if not acquired:
        return None

    try:
        # Refresh to get latest state
        db_session.refresh(cc_pair)
        if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
            return None

        # Generate task ID
        custom_task_id = f"hierarchy_fetching_{cc_pair.id}_{uuid4()}"

        # Send the task
        result = celery_app.send_task(
            OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
            kwargs=dict(
                cc_pair_id=cc_pair.id,
                tenant_id=tenant_id,
            ),
            queue=OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
            task_id=custom_task_id,
            priority=OnyxCeleryPriority.LOW,
        )

        if not result:
            raise RuntimeError("send_task for hierarchy_fetching_task failed.")

        task_logger.info(
            f"Created hierarchy fetching task: cc_pair={cc_pair.id} celery_task_id={custom_task_id}"
        )
        return custom_task_id

    except Exception:
        task_logger.exception(
            f"Failed to create hierarchy fetching task: cc_pair={cc_pair.id}"
        )
        return None
    finally:
        if lock.owned():
            lock.release()


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_HIERARCHY_FETCHING,
    soft_time_limit=300,
    bind=True,
)
def check_for_hierarchy_fetching(self: Task, *, tenant_id: str) -> int | None:
    """Check for connectors that need hierarchy fetching and spawn tasks.

    This task runs periodically (once per day) and checks all active connectors
    to see if they need hierarchy information fetched.
    """
    time_start = time.monotonic()
    task_logger.info("check_for_hierarchy_fetching - Starting")

    tasks_created = 0
    locked = False
    redis_client = get_redis_client()

    lock_beat: RedisLock = redis_client.lock(
        OnyxRedisLocks.CHECK_HIERARCHY_FETCHING_BEAT_LOCK,
        timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
    )

    # These tasks should never overlap
    if not lock_beat.acquire(blocking=False):
        return None

    try:
        locked = True

        with get_session_with_current_tenant() as db_session:
            # Get all active connector credential pairs
            cc_pair_ids = fetch_indexable_standard_connector_credential_pair_ids(
                db_session=db_session,
                active_cc_pairs_only=True,
            )

            for cc_pair_id in cc_pair_ids:
                lock_beat.reacquire()
                cc_pair = get_connector_credential_pair_from_id(
                    db_session=db_session,
                    cc_pair_id=cc_pair_id,
                )

                if not cc_pair or not _connector_supports_hierarchy_fetching(cc_pair):
                    continue

                if not _is_hierarchy_fetching_due(cc_pair):
                    continue

                task_id = _try_creating_hierarchy_fetching_task(
                    celery_app=self.app,
                    cc_pair=cc_pair,
                    db_session=db_session,
                    r=redis_client,
                    tenant_id=tenant_id,
                )

                if task_id:
                    tasks_created += 1

    except Exception:
        task_logger.exception("check_for_hierarchy_fetching - Unexpected error")
    finally:
        if locked:
            if lock_beat.owned():
                lock_beat.release()
            else:
                task_logger.error(
                    "check_for_hierarchy_fetching - Lock not owned on completion"
                )

    time_elapsed = time.monotonic() - time_start
    task_logger.info(
        f"check_for_hierarchy_fetching finished: tasks_created={tasks_created} elapsed={time_elapsed:.2f}s"
    )
    return tasks_created


# Batch size for hierarchy node processing
HIERARCHY_NODE_BATCH_SIZE = 100


def _run_hierarchy_extraction(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    source: DocumentSource,
    tenant_id: str,
) -> int:
    """
    Run the hierarchy extraction for a connector.

    Instantiates the connector and calls load_hierarchy() if the connector
    implements HierarchyConnector.

    Returns the total number of hierarchy nodes extracted.
    """
    connector = cc_pair.connector
    credential = cc_pair.credential

    # Instantiate the connector using its configured input type
    runnable_connector = instantiate_connector(
        db_session=db_session,
        source=source,
        input_type=connector.input_type,
        connector_specific_config=connector.connector_specific_config,
        credential=credential,
    )

    # Check if the connector supports hierarchy fetching
    if not isinstance(runnable_connector, HierarchyConnector):
        task_logger.debug(
            f"Connector {source} does not implement HierarchyConnector, skipping"
        )
        return 0

    redis_client = get_redis_client(tenant_id=tenant_id)

    # Ensure the SOURCE-type root node exists before processing hierarchy nodes.
    # This is the root of the hierarchy tree - all other nodes for this source
    # should ultimately have this as an ancestor.
    ensure_source_node_exists(redis_client, db_session, source)

    # Determine time range: start from last hierarchy fetch, end at now
    last_fetch = cc_pair.last_time_hierarchy_fetch
    start_time = last_fetch.timestamp() if last_fetch else 0
    end_time = datetime.now(timezone.utc).timestamp()

    # Check if connector is public - all hierarchy nodes from public connectors
    # should be accessible to all users
    is_connector_public = cc_pair.access_type == AccessType.PUBLIC

    total_nodes = 0
    node_batch: list[PydanticHierarchyNode] = []

    def _process_batch() -> int:
        """Process accumulated hierarchy nodes batch."""
        if not node_batch:
            return 0

        upserted_nodes = upsert_hierarchy_nodes_batch(
            db_session=db_session,
            nodes=node_batch,
            source=source,
            commit=True,
            is_connector_public=is_connector_public,
        )

        upsert_hierarchy_node_cc_pair_entries(
            db_session=db_session,
            hierarchy_node_ids=[n.id for n in upserted_nodes],
            connector_id=cc_pair.connector_id,
            credential_id=cc_pair.credential_id,
            commit=True,
        )

        # Cache in Redis for fast ancestor resolution
        cache_entries = [
            HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes
        ]
        cache_hierarchy_nodes_batch(
            redis_client=redis_client,
            source=source,
            entries=cache_entries,
        )

        count = len(node_batch)
        node_batch.clear()
        return count

    # Fetch hierarchy nodes from the connector
    for node in runnable_connector.load_hierarchy(start=start_time, end=end_time):
        node_batch.append(node)
        if len(node_batch) >= HIERARCHY_NODE_BATCH_SIZE:
            total_nodes += _process_batch()

    # Process any remaining nodes
    total_nodes += _process_batch()

    return total_nodes


@shared_task(
    name=OnyxCeleryTask.CONNECTOR_HIERARCHY_FETCHING_TASK,
    soft_time_limit=3600,  # 1 hour soft limit
    time_limit=3900,  # 1 hour 5 min hard limit
    bind=True,
)
def connector_hierarchy_fetching_task(
    self: Task,  # noqa: ARG001
    *,
    cc_pair_id: int,
    tenant_id: str,
) -> None:
    """Fetch hierarchy information from a connector.

    This task fetches structural information (folders, spaces, pages, etc.)
    from the connector source and stores it in the database.
    """
    task_logger.info(
        f"connector_hierarchy_fetching_task starting: cc_pair={cc_pair_id} tenant={tenant_id}"
    )

    try:
        with get_session_with_current_tenant() as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair_id,
            )

            if not cc_pair:
                task_logger.warning(
                    f"CC pair not found for hierarchy fetching: cc_pair={cc_pair_id}"
                )
                return

            if cc_pair.status == ConnectorCredentialPairStatus.DELETING:
                task_logger.info(
                    f"Skipping hierarchy fetching for deleting connector: cc_pair={cc_pair_id}"
                )
                return

            source = cc_pair.connector.source
            total_nodes = _run_hierarchy_extraction(
                db_session=db_session,
                cc_pair=cc_pair,
                source=source,
                tenant_id=tenant_id,
            )

            task_logger.info(
                f"connector_hierarchy_fetching_task: Extracted {total_nodes} hierarchy nodes for cc_pair={cc_pair_id}"
            )

            # Update the last fetch time to prevent re-running until next interval
            mark_cc_pair_as_hierarchy_fetched(db_session, cc_pair_id)

    except Exception:
        task_logger.exception(
            f"connector_hierarchy_fetching_task failed: cc_pair={cc_pair_id}"
        )
        raise

    task_logger.info(
        f"connector_hierarchy_fetching_task completed: cc_pair={cc_pair_id}"
    )


================================================
FILE: backend/onyx/background/celery/tasks/llm_model_update/__init__.py
================================================


================================================
FILE: backend/onyx/background/celery/tasks/llm_model_update/tasks.py
================================================
from celery import shared_task
from celery import Task

from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.llm.well_known_providers.auto_update_service import (
    sync_llm_models_from_github,
)


@shared_task(
    name=OnyxCeleryTask.CHECK_FOR_AUTO_LLM_UPDATE,
    ignore_result=True,
    soft_time_limit=300,  # 5 minute timeout
    trail=False,
    bind=True,
)
def check_for_auto_llm_updates(
    self: Task,  # noqa: ARG001
    *,
    tenant_id: str,  # noqa: ARG001
) -> bool | None:
    """Periodic task to fetch LLM model updates from GitHub
    and sync them to providers in Auto mode.

    This task checks the GitHub-hosted config file and updates all
    providers that have is_auto_mode=True.
    """
    if not AUTO_LLM_CONFIG_URL:
        task_logger.debug("AUTO_LLM_CONFIG_URL not configured, skipping")
        return None

    try:
        # Sync to database
        with get_session_with_current_tenant() as db_session:
            results = sync_llm_models_from_github(db_session)

            if results:
                task_logger.info(f"Auto mode sync results: {results}")
            else:
                task_logger.debug("No model updates applied")

    except Exception:
        task_logger.exception("Error in auto LLM update task")
        raise

    return True


================================================
FILE: backend/onyx/background/celery/tasks/models.py
================================================
from enum import Enum

from pydantic import BaseModel


class DocProcessingContext(BaseModel):
    tenant_id: str
    cc_pair_id: int
    search_settings_id: int
    index_attempt_id: int


class IndexingWatchdogTerminalStatus(str, Enum):
    """The different statuses the watchdog can finish with.

    TODO: create broader success/failure/abort categories
    """

    UNDEFINED = "undefined"

    SUCCEEDED = "succeeded"

    SPAWN_FAILED = "spawn_failed"  # connector spawn failed
    SPAWN_NOT_ALIVE = (
        "spawn_not_alive"  # spawn succeeded but process did not come alive
    )

    BLOCKED_BY_DELETION = "blocked_by_deletion"
    BLOCKED_BY_STOP_SIGNAL = "blocked_by_stop_signal"
    FENCE_NOT_FOUND = "fence_not_found"  # fence does not exist
    FENCE_READINESS_TIMEOUT = (
        "fence_readiness_timeout"  # fence exists but wasn't ready within the timeout
    )
    FENCE_MISMATCH = "fence_mismatch"  # task and fence metadata mismatch
    TASK_ALREADY_RUNNING = "task_already_running"  # task appears to be running already
    INDEX_ATTEMPT_MISMATCH = (
        "index_attempt_mismatch"  # expected index attempt metadata not found in db
    )

    CONNECTOR_VALIDATION_ERROR = (
        "connector_validation_error"  # the connector validation failed
    )
    CONNECTOR_EXCEPTIONED = "connector_exceptioned"  # the connector itself exceptioned
    WATCHDOG_EXCEPTIONED = "watchdog_exceptioned"  # the watchdog exceptioned

    # the watchdog received a termination signal
    TERMINATED_BY_SIGNAL = "terminated_by_signal"

    # the watchdog terminated the task due to no activity
    TERMINATED_BY_ACTIVITY_TIMEOUT = "terminated_by_activity_timeout"

    # NOTE: this may actually be the same as SIGKILL, but parsed differently by python
    # consolidate once we know more
    OUT_OF_MEMORY = "out_of_memory"

    PROCESS_SIGNAL_SIGKILL = "process_signal_sigkill"

    @property
    def code(self) -> int:
        _ENUM_TO_CODE: dict[IndexingWatchdogTerminalStatus, int] = {
            IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL: -9,
            IndexingWatchdogTerminalStatus.OUT_OF_MEMORY: 137,
            IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR: 247,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION: 248,
            IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL: 249,
            IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND: 250,
            IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT: 251,
            IndexingWatchdogTerminalStatus.FENCE_MISMATCH: 252,
            IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING: 253,
            IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH: 254,
            IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED: 255,
        }

        return _ENUM_TO_CODE[self]

    @classmethod
    def from_code(cls, code: int) -> "IndexingWatchdogTerminalStatus":
        _CODE_TO_ENUM: dict[int, IndexingWatchdogTerminalStatus] = {
            -9: IndexingWatchdogTerminalStatus.PROCESS_SIGNAL_SIGKILL,
            137: IndexingWatchdogTerminalStatus.OUT_OF_MEMORY,
            247: IndexingWatchdogTerminalStatus.CONNECTOR_VALIDATION_ERROR,
            248: IndexingWatchdogTerminalStatus.BLOCKED_BY_DELETION,
            249: IndexingWatchdogTerminalStatus.BLOCKED_BY_STOP_SIGNAL,
            250: IndexingWatchdogTerminalStatus.FENCE_NOT_FOUND,
            251: IndexingWatchdogTerminalStatus.FENCE_READINESS_TIMEOUT,
            252: IndexingWatchdogTerminalStatus.FENCE_MISMATCH,
            253: IndexingWatchdogTerminalStatus.TASK_ALREADY_RUNNING,
            254: IndexingWatchdogTerminalStatus.INDEX_ATTEMPT_MISMATCH,
            255: IndexingWatchdogTerminalStatus.CONNECTOR_EXCEPTIONED,
        }

        if code in _CODE_TO_ENUM:
            return _CODE_TO_ENUM[code]

        return IndexingWatchdogTerminalStatus.UNDEFINED


class SimpleJobResult:
    """The data we want to have when the watchdog finishes"""

    def __init__(self) -> None:
        self.status = IndexingWatchdogTerminalStatus.UNDEFINED
        self.connector_source = None
        self.exit_code = None
        self.exception_str = None

    status: IndexingWatchdogTerminalStatus
    connector_source: str | None
    exit_code: int | None
    exception_str: str | None


================================================
FILE: backend/onyx/background/celery/tasks/monitoring/__init__.py
================================================


================================================
FILE: backend/onyx/background/celery/tasks/monitoring/tasks.py
================================================
import json
import time
from datetime import timedelta
from itertools import islice
from typing import Any
from typing import cast
from typing import Literal

import psutil
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import BaseModel
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_broker_client
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import ONYX_CLOUD_TENANT_ID
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.tenant_utils import get_all_tenant_ids
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.enums import IndexingStatus
from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import IndexAttempt
from onyx.db.models import SyncRecord
from onyx.db.models import UserGroup
from onyx.db.search_settings import get_active_search_settings_list
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.logger import is_running_in_container
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

_MONITORING_SOFT_TIME_LIMIT = 60 * 5  # 5 minutes
_MONITORING_TIME_LIMIT = _MONITORING_SOFT_TIME_LIMIT + 60  # 6 minutes

_CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT = (
    "monitoring_connector_index_attempt_start_latency:{cc_pair_id}:{index_attempt_id}"
)

_CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT = (
    "monitoring_connector_index_attempt_run_success:{cc_pair_id}:{index_attempt_id}"
)

_FINAL_METRIC_KEY_FMT = "sync_final_metrics:{sync_type}:{entity_id}:{sync_record_id}"

_SYNC_START_LATENCY_KEY_FMT = (
    "sync_start_latency:{sync_type}:{entity_id}:{sync_record_id}"
)

_CONNECTOR_START_TIME_KEY_FMT = "connector_start_time:{cc_pair_id}:{index_attempt_id}"
_CONNECTOR_END_TIME_KEY_FMT = "connector_end_time:{cc_pair_id}:{index_attempt_id}"
_SYNC_START_TIME_KEY_FMT = "sync_start_time:{sync_type}:{entity_id}:{sync_record_id}"
_SYNC_END_TIME_KEY_FMT = "sync_end_time:{sync_type}:{entity_id}:{sync_record_id}"


def _mark_metric_as_emitted(redis_std: Redis, key: str) -> None:
    """Mark a metric as having been emitted by setting a Redis key with expiration"""
    redis_std.set(key, "1", ex=24 * 60 * 60)  # Expire after 1 day


def _has_metric_been_emitted(redis_std: Redis, key: str) -> bool:
    """Check if a metric has been emitted by checking for existence of Redis key"""
    return bool(redis_std.exists(key))


class Metric(BaseModel):
    key: (
        str | None
    )  # only required if we need to store that we have emitted this metric
    name: str
    value: Any
    tags: dict[str, str]

    def log(self) -> None:
        """Log the metric in a standardized format"""
        data = {
            "metric": self.name,
            "value": self.value,
            "tags": self.tags,
        }
        task_logger.info(json.dumps(data))

    def emit(self, tenant_id: str) -> None:
        # Convert value to appropriate type based on the input value
        bool_value = None
        float_value = None
        int_value = None
        string_value = None
        # NOTE: have to do bool first, since `isinstance(True, int)` is true
        # e.g. bool is a subclass of int
        if isinstance(self.value, bool):
            bool_value = self.value
        elif isinstance(self.value, int):
            int_value = self.value
        elif isinstance(self.value, float):
            float_value = self.value
        elif isinstance(self.value, str):
            string_value = self.value
        else:
            task_logger.error(
                f"Invalid metric value type: {type(self.value)} ({self.value}) for metric {self.name}."
            )
            return

        # don't send None values over the wire
        data = {
            k: v
            for k, v in {
                "metric_name": self.name,
                "float_value": float_value,
                "int_value": int_value,
                "string_value": string_value,
                "bool_value": bool_value,
                "tags": self.tags,
            }.items()
            if v is not None
        }
        task_logger.info(f"Emitting metric: {data}")
        optional_telemetry(
            record_type=RecordType.METRIC,
            data=data,
            tenant_id=tenant_id,
        )


def _collect_queue_metrics(redis_celery: Redis) -> list[Metric]:
    """Collect metrics about queue lengths for different Celery queues"""
    metrics = []
    queue_mappings = {
        "celery_queue_length": OnyxCeleryQueues.PRIMARY,
        "docprocessing_queue_length": OnyxCeleryQueues.DOCPROCESSING,
        "docfetching_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
        "sync_queue_length": OnyxCeleryQueues.VESPA_METADATA_SYNC,
        "deletion_queue_length": OnyxCeleryQueues.CONNECTOR_DELETION,
        "pruning_queue_length": OnyxCeleryQueues.CONNECTOR_PRUNING,
        "permissions_sync_queue_length": OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
        "external_group_sync_queue_length": OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC,
        "permissions_upsert_queue_length": OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT,
        "hierarchy_fetching_queue_length": OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING,
        "llm_model_update_queue_length": OnyxCeleryQueues.LLM_MODEL_UPDATE,
        "checkpoint_cleanup_queue_length": OnyxCeleryQueues.CHECKPOINT_CLEANUP,
        "index_attempt_cleanup_queue_length": OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP,
        "csv_generation_queue_length": OnyxCeleryQueues.CSV_GENERATION,
        "user_file_processing_queue_length": OnyxCeleryQueues.USER_FILE_PROCESSING,
        "user_file_project_sync_queue_length": OnyxCeleryQueues.USER_FILE_PROJECT_SYNC,
        "user_file_delete_queue_length": OnyxCeleryQueues.USER_FILE_DELETE,
        "monitoring_queue_length": OnyxCeleryQueues.MONITORING,
        "sandbox_queue_length": OnyxCeleryQueues.SANDBOX,
        "opensearch_migration_queue_length": OnyxCeleryQueues.OPENSEARCH_MIGRATION,
    }

    for name, queue in queue_mappings.items():
        metrics.append(
            Metric(
                key=None,
                name=name,
                value=celery_get_queue_length(queue, redis_celery),
                tags={"queue": name},
            )
        )

    return metrics


def _build_connector_start_latency_metric(
    cc_pair: ConnectorCredentialPair,
    recent_attempt: IndexAttempt,
    second_most_recent_attempt: IndexAttempt | None,
    redis_std: Redis,
) -> Metric | None:
    if not recent_attempt.time_started:
        return None

    # check if we already emitted a metric for this index attempt
    metric_key = _CONNECTOR_INDEX_ATTEMPT_START_LATENCY_KEY_FMT.format(
        cc_pair_id=cc_pair.id,
        index_attempt_id=recent_attempt.id,
    )
    if _has_metric_been_emitted(redis_std, metric_key):
        task_logger.info(
            f"Skipping metric for connector {cc_pair.connector.id} "
            f"index attempt {recent_attempt.id} because it has already been "
            "emitted"
        )
        return None

    # Connector start latency
    # first run case - we should start as soon as it's created
    if not second_most_recent_attempt:
        desired_start_time = cc_pair.connector.time_created
    else:
        if not cc_pair.connector.refresh_freq:
            task_logger.debug(
                "Connector has no refresh_freq and this is a non-initial index attempt. "
                "Assuming user manually triggered indexing, so we'll skip start latency metric."
            )
            return None

        desired_start_time = second_most_recent_attempt.time_updated + timedelta(
            seconds=cc_pair.connector.refresh_freq
        )

    start_latency = (recent_attempt.time_started - desired_start_time).total_seconds()

    task_logger.info(
        f"Start latency for index attempt {recent_attempt.id}: {start_latency:.2f}s "
        f"(desired: {desired_start_time}, actual: {recent_attempt.time_started})"
    )

    job_id = build_job_id("connector", str(cc_pair.id), str(recent_attempt.id))

    return Metric(
        key=metric_key,
        name="connector_start_latency",
        value=start_latency,
        tags={
            "job_id": job_id,
            "connector_id": str(cc_pair.connector.id),
            "source": str(cc_pair.connector.source),
        },
    )


def _build_connector_final_metrics(
    cc_pair: ConnectorCredentialPair,
    recent_attempts: list[IndexAttempt],
    redis_std: Redis,
) -> list[Metric]:
    """
    Final metrics for connector index attempts:
      - Boolean success/fail metric
      - If success, emit:
          * duration (seconds)
          * doc_count
    """
    metrics = []
    for attempt in recent_attempts:
        metric_key = _CONNECTOR_INDEX_ATTEMPT_RUN_SUCCESS_KEY_FMT.format(
            cc_pair_id=cc_pair.id,
            index_attempt_id=attempt.id,
        )
        if _has_metric_been_emitted(redis_std, metric_key):
            task_logger.info(
                f"Skipping final metrics for connector {cc_pair.connector.id} index attempt {attempt.id}, already emitted."
            )
            continue

        # We only emit final metrics if the attempt is in a terminal state
        if attempt.status not in [
            IndexingStatus.SUCCESS,
            IndexingStatus.FAILED,
            IndexingStatus.CANCELED,
        ]:
            # Not finished; skip
            continue

        job_id = build_job_id("connector", str(cc_pair.id), str(attempt.id))
        success = attempt.status == IndexingStatus.SUCCESS
        metrics.append(
            Metric(
                key=metric_key,  # We'll mark the same key for any final metrics
                name="connector_run_succeeded",
                value=success,
                tags={
                    "job_id": job_id,
                    "connector_id": str(cc_pair.connector.id),
                    "source": str(cc_pair.connector.source),
                    "status": attempt.status.value,
                },
            )
        )

        if success:
            # Make sure we have valid time_started
            if attempt.time_started and attempt.time_updated:
                duration_seconds = (
                    attempt.time_updated - attempt.time_started
                ).total_seconds()
                metrics.append(
                    Metric(
                        key=None,  # No need for a new key, or you can reuse the same if you prefer
                        name="connector_index_duration_seconds",
                        value=duration_seconds,
                        tags={
                            "job_id": job_id,
                            "connector_id": str(cc_pair.connector.id),
                            "source": str(cc_pair.connector.source),
                        },
                    )
                )
            else:
                task_logger.error(
                    f"Index attempt {attempt.id} succeeded but has missing time "
                    f"(time_started={attempt.time_started}, time_updated={attempt.time_updated})."
                )

            # For doc counts, choose whichever field is more relevant
            doc_count = attempt.total_docs_indexed or 0
            metrics.append(
                Metric(
                    key=None,
                    name="connector_index_doc_count",
                    value=doc_count,
                    tags={
                        "job_id": job_id,
                        "connector_id": str(cc_pair.connector.id),
                        "source": str(cc_pair.connector.source),
                    },
                )
            )

    return metrics


def _collect_connector_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
    """Collect metrics about connector runs from the past hour"""
    one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)

    # Get all connector credential pairs
    cc_pairs = db_session.scalars(select(ConnectorCredentialPair)).all()
    # Might be more than one search setting, or just one
    active_search_settings_list = get_active_search_settings_list(db_session)

    metrics = []

    # If you want to process each cc_pair against each search setting:
    for cc_pair in cc_pairs:
        for search_settings in active_search_settings_list:
            recent_attempts = (
                db_session.query(IndexAttempt)
                .filter(
                    IndexAttempt.connector_credential_pair_id == cc_pair.id,
                    IndexAttempt.search_settings_id == search_settings.id,
                )
                .order_by(IndexAttempt.time_created.desc())
                .limit(2)
                .all()
            )

            if not recent_attempts:
                continue

            most_recent_attempt = recent_attempts[0]
            second_most_recent_attempt = (
                recent_attempts[1] if len(recent_attempts) > 1 else None
            )

            if one_hour_ago > most_recent_attempt.time_created:
                continue

            # Build a job_id for correlation
            job_id = build_job_id(
                "connector", str(cc_pair.id), str(most_recent_attempt.id)
            )

            # Add raw start time metric if available
            if most_recent_attempt.time_started:
                start_time_key = _CONNECTOR_START_TIME_KEY_FMT.format(
                    cc_pair_id=cc_pair.id,
                    index_attempt_id=most_recent_attempt.id,
                )
                metrics.append(
                    Metric(
                        key=start_time_key,
                        name="connector_start_time",
                        value=most_recent_attempt.time_started.timestamp(),
                        tags={
                            "job_id": job_id,
                            "connector_id": str(cc_pair.connector.id),
                            "source": str(cc_pair.connector.source),
                        },
                    )
                )

            # Add raw end time metric if available and in terminal state
            if (
                most_recent_attempt.status.is_terminal()
                and most_recent_attempt.time_updated
            ):
                end_time_key = _CONNECTOR_END_TIME_KEY_FMT.format(
                    cc_pair_id=cc_pair.id,
                    index_attempt_id=most_recent_attempt.id,
                )
                metrics.append(
                    Metric(
                        key=end_time_key,
                        name="connector_end_time",
                        value=most_recent_attempt.time_updated.timestamp(),
                        tags={
                            "job_id": job_id,
                            "connector_id": str(cc_pair.connector.id),
                            "source": str(cc_pair.connector.source),
                        },
                    )
                )

            # Connector start latency
            start_latency_metric = _build_connector_start_latency_metric(
                cc_pair, most_recent_attempt, second_most_recent_attempt, redis_std
            )

            if start_latency_metric:
                metrics.append(start_latency_metric)

            # Connector run success/failure
            final_metrics = _build_connector_final_metrics(
                cc_pair, recent_attempts, redis_std
            )
            metrics.extend(final_metrics)

    return metrics


def _collect_sync_metrics(db_session: Session, redis_std: Redis) -> list[Metric]:
    """
    Collect metrics for document set and group syncing:
      - Success/failure status
      - Start latency (for doc sets / user groups)
      - Duration & doc count (only if success)
      - Throughput (docs/min) (only if success)
      - Raw start/end times for each sync
    """

    one_hour_ago = get_db_current_time(db_session) - timedelta(hours=1)

    # Get all sync records that ended in the last hour
    recent_sync_records = db_session.scalars(
        select(SyncRecord)
        .where(SyncRecord.sync_end_time.isnot(None))
        .where(SyncRecord.sync_end_time >= one_hour_ago)
        .order_by(SyncRecord.sync_end_time.desc())
    ).all()

    task_logger.info(
        f"Collecting sync metrics for {len(recent_sync_records)} sync records"
    )

    metrics = []

    for sync_record in recent_sync_records:
        # Build a job_id for correlation
        job_id = build_job_id("sync_record", str(sync_record.id))

        # Add raw start time metric
        start_time_key = _SYNC_START_TIME_KEY_FMT.format(
            sync_type=sync_record.sync_type,
            entity_id=sync_record.entity_id,
            sync_record_id=sync_record.id,
        )
        metrics.append(
            Metric(
                key=start_time_key,
                name="sync_start_time",
                value=sync_record.sync_start_time.timestamp(),
                tags={
                    "job_id": job_id,
                    "sync_type": str(sync_record.sync_type),
                },
            )
        )

        # Add raw end time metric if available
        if sync_record.sync_end_time:
            end_time_key = _SYNC_END_TIME_KEY_FMT.format(
                sync_type=sync_record.sync_type,
                entity_id=sync_record.entity_id,
                sync_record_id=sync_record.id,
            )
            metrics.append(
                Metric(
                    key=end_time_key,
                    name="sync_end_time",
                    value=sync_record.sync_end_time.timestamp(),
                    tags={
                        "job_id": job_id,
                        "sync_type": str(sync_record.sync_type),
                    },
                )
            )

        # Emit a SUCCESS/FAIL boolean metric
        # Use a single Redis key to avoid re-emitting final metrics
        final_metric_key = _FINAL_METRIC_KEY_FMT.format(
            sync_type=sync_record.sync_type,
            entity_id=sync_record.entity_id,
            sync_record_id=sync_record.id,
        )
        if not _has_metric_been_emitted(redis_std, final_metric_key):
            # Evaluate success
            sync_succeeded = sync_record.sync_status == SyncStatus.SUCCESS

            metrics.append(
                Metric(
                    key=final_metric_key,
                    name="sync_run_succeeded",
                    value=sync_succeeded,
                    tags={
                        "job_id": job_id,
                        "sync_type": str(sync_record.sync_type),
                        "status": str(sync_record.sync_status),
                    },
                )
            )

            # If successful, emit additional metrics
            if sync_succeeded:
                if sync_record.sync_end_time and sync_record.sync_start_time:
                    duration_seconds = (
                        sync_record.sync_end_time - sync_record.sync_start_time
                    ).total_seconds()
                else:
                    task_logger.error(
                        f"Invalid times for sync record {sync_record.id}: "
                        f"start={sync_record.sync_start_time}, end={sync_record.sync_end_time}"
                    )
                    duration_seconds = None

                doc_count = sync_record.num_docs_synced or 0

                sync_speed = None
                if duration_seconds and duration_seconds > 0:
                    duration_mins = duration_seconds / 60.0
                    sync_speed = (
                        doc_count / duration_mins if duration_mins > 0 else None
                    )

                # Emit duration, doc count, speed
                if duration_seconds is not None:
                    metrics.append(
                        Metric(
                            key=final_metric_key,
                            name="sync_duration_seconds",
                            value=duration_seconds,
                            tags={
                                "job_id": job_id,
                                "sync_type": str(sync_record.sync_type),
                            },
                        )
                    )
                else:
                    task_logger.error(
                        f"Invalid sync record {sync_record.id} with no duration"
                    )

                metrics.append(
                    Metric(
                        key=final_metric_key,
                        name="sync_doc_count",
                        value=doc_count,
                        tags={
                            "job_id": job_id,
                            "sync_type": str(sync_record.sync_type),
                        },
                    )
                )

                if sync_speed is not None:
                    metrics.append(
                        Metric(
                            key=final_metric_key,
                            name="sync_speed_docs_per_min",
                            value=sync_speed,
                            tags={
                                "job_id": job_id,
                                "sync_type": str(sync_record.sync_type),
                            },
                        )
                    )
                else:
                    task_logger.error(
                        f"Invalid sync record {sync_record.id} with no duration"
                    )

        # Emit start latency
        start_latency_key = _SYNC_START_LATENCY_KEY_FMT.format(
            sync_type=sync_record.sync_type,
            entity_id=sync_record.entity_id,
            sync_record_id=sync_record.id,
        )
        if not _has_metric_been_emitted(redis_std, start_latency_key):
            # Get the entity's last update time based on sync type
            entity: DocumentSet | UserGroup | None = None
            if sync_record.sync_type == SyncType.DOCUMENT_SET:
                entity = db_session.scalar(
                    select(DocumentSet).where(DocumentSet.id == sync_record.entity_id)
                )
            elif sync_record.sync_type == SyncType.USER_GROUP:
                entity = db_session.scalar(
                    select(UserGroup).where(UserGroup.id == sync_record.entity_id)
                )
            else:
                # Only user groups and document set sync records have
                # an associated entity we can use for latency metrics
                continue

            if entity is None:
                task_logger.error(
                    f"Sync record of type {sync_record.sync_type} doesn't have an entity "
                    f"associated with it (id={sync_record.entity_id}). Skipping start latency metric."
) # Calculate start latency in seconds: # (actual sync start) - (last modified time) if ( entity is not None and entity.time_last_modified_by_user and sync_record.sync_start_time ): start_latency = ( sync_record.sync_start_time - entity.time_last_modified_by_user ).total_seconds() if start_latency < 0: task_logger.error( f"Negative start latency for sync record {sync_record.id} " f"(start={sync_record.sync_start_time}, entity_modified={entity.time_last_modified_by_user})" ) continue metrics.append( Metric( key=start_latency_key, name="sync_start_latency_seconds", value=start_latency, tags={ "job_id": job_id, "sync_type": str(sync_record.sync_type), }, ) ) return metrics def build_job_id( job_type: Literal["connector", "sync_record"], primary_id: str, secondary_id: str | None = None, ) -> str: if job_type == "connector": if secondary_id is None: raise ValueError( "secondary_id (attempt_id) is required for connector job_type" ) return f"connector:{primary_id}:attempt:{secondary_id}" elif job_type == "sync_record": return f"sync_record:{primary_id}" @shared_task( name=OnyxCeleryTask.MONITOR_BACKGROUND_PROCESSES, ignore_result=True, soft_time_limit=_MONITORING_SOFT_TIME_LIMIT, time_limit=_MONITORING_TIME_LIMIT, queue=OnyxCeleryQueues.MONITORING, bind=True, ) def monitor_background_processes(self: Task, *, tenant_id: str) -> None: """Collect and emit metrics about background processes. 
This task runs periodically to gather metrics about: - Queue lengths for different Celery queues - Connector run metrics (start latency, success rate) - Syncing speed metrics - Worker status and task counts """ if tenant_id is not None: CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) task_logger.info("Starting background monitoring") r = get_redis_client() lock_monitoring: RedisLock = r.lock( OnyxRedisLocks.MONITOR_BACKGROUND_PROCESSES_LOCK, timeout=_MONITORING_SOFT_TIME_LIMIT, ) # these tasks should never overlap if not lock_monitoring.acquire(blocking=False): task_logger.info("Skipping monitoring task because it is already running") return None try: redis_std = get_redis_client() # Collect queue metrics with broker connection r_celery = celery_get_broker_client(self.app) queue_metrics = _collect_queue_metrics(r_celery) # Collect remaining metrics (no broker connection needed) with get_session_with_current_tenant() as db_session: all_metrics: list[Metric] = queue_metrics all_metrics.extend(_collect_connector_metrics(db_session, redis_std)) all_metrics.extend(_collect_sync_metrics(db_session, redis_std)) for metric in all_metrics: if metric.key is None or not _has_metric_been_emitted( redis_std, metric.key ): metric.log() metric.emit(tenant_id) if metric.key is not None: _mark_metric_as_emitted(redis_std, metric.key) task_logger.info("Successfully collected background metrics") except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) except Exception as e: task_logger.exception("Error collecting background process metrics") raise e finally: if lock_monitoring.owned(): lock_monitoring.release() task_logger.info("Background monitoring task finished") @shared_task( name=OnyxCeleryTask.CLOUD_MONITOR_ALEMBIC, ) def cloud_check_alembic() -> bool | None: """A task to verify that all tenants are on the same alembic revision. 
This check is expected to fail if a cloud alembic migration is currently running across all tenants. TODO: have the cloud migration script set an activity signal that this check uses to know it doesn't make sense to run a check at the present time. """ # Used as a placeholder if the alembic revision cannot be retrieved ALEMBIC_NULL_REVISION = "000000000000" time_start = time.monotonic() redis_client = get_redis_client(tenant_id=ONYX_CLOUD_TENANT_ID) lock_beat: RedisLock = redis_client.lock( OnyxRedisLocks.CLOUD_CHECK_ALEMBIC_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): return None last_lock_time = time.monotonic() tenant_to_revision: dict[str, str] = {} revision_counts: dict[str, int] = {} out_of_date_tenants: dict[str, str] = {} top_revision: str = "" tenant_ids: list[str] | list[None] = [] try: # map tenant_id to revision (or ALEMBIC_NULL_REVISION if the query fails) tenant_ids = get_all_tenant_ids() for tenant_id in tenant_ids: current_time = time.monotonic() if current_time - last_lock_time >= (CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4): lock_beat.reacquire() last_lock_time = current_time if tenant_id is None: continue with get_session_with_shared_schema() as session: try: result = session.execute( text(f'SELECT * FROM "{tenant_id}".alembic_version LIMIT 1') ) result_scalar: str | None = result.scalar_one_or_none() if result_scalar is None: raise ValueError("Alembic version should not be None.") tenant_to_revision[tenant_id] = result_scalar except Exception: task_logger.error(f"Tenant {tenant_id} has no revision!") tenant_to_revision[tenant_id] = ALEMBIC_NULL_REVISION # get the total count of each revision for k, v in tenant_to_revision.items(): revision_counts[v] = revision_counts.get(v, 0) + 1 # error if any null revision tenants are found if ALEMBIC_NULL_REVISION in revision_counts: num_null_revisions = revision_counts[ALEMBIC_NULL_REVISION] raise ValueError(f"No revision was 
found for {num_null_revisions} tenants!") # get the revision with the most counts sorted_revision_counts = sorted( revision_counts.items(), key=lambda item: item[1], reverse=True ) if len(sorted_revision_counts) == 0: raise ValueError( f"cloud_check_alembic - No revisions found for {len(tenant_ids)} tenant ids!" ) top_revision, _ = sorted_revision_counts[0] # build a list of out of date tenants for k, v in tenant_to_revision.items(): if v == top_revision: continue out_of_date_tenants[k] = v except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) raise except Exception: task_logger.exception("Unexpected exception during cloud alembic check") raise finally: if lock_beat.owned(): lock_beat.release() else: task_logger.error("cloud_check_alembic - Lock not owned on completion") redis_lock_dump(lock_beat, redis_client) if len(out_of_date_tenants) > 0: task_logger.error( f"Found out of date tenants: " f"num_out_of_date_tenants={len(out_of_date_tenants)} " f"num_tenants={len(tenant_ids)} " f"revision={top_revision}" ) num_to_log = min(5, len(out_of_date_tenants)) task_logger.info( f"Logging {num_to_log}/{len(out_of_date_tenants)} out of date tenants." 
) for k, v in islice(out_of_date_tenants.items(), 5): task_logger.info(f"Out of date tenant: tenant={k} revision={v}") else: task_logger.info( f"All tenants are up to date: num_tenants={len(tenant_ids)} revision={top_revision}" ) time_elapsed = time.monotonic() - time_start task_logger.info( f"cloud_check_alembic finished: num_tenants={len(tenant_ids)} elapsed={time_elapsed:.2f}" ) return True @shared_task( name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_QUEUES, ignore_result=True, bind=True ) def cloud_monitor_celery_queues( self: Task, ) -> None: return monitor_celery_queues_helper(self) @shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True) def monitor_celery_queues(self: Task, *, tenant_id: str) -> None: # noqa: ARG001 return monitor_celery_queues_helper(self) def monitor_celery_queues_helper( task: Task, ) -> None: """A task to monitor all celery queue lengths.""" r_celery = celery_get_broker_client(task.app) n_celery = celery_get_queue_length(OnyxCeleryQueues.PRIMARY, r_celery) n_docfetching = celery_get_queue_length( OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery ) n_docprocessing = celery_get_queue_length(OnyxCeleryQueues.DOCPROCESSING, r_celery) n_user_file_processing = celery_get_queue_length( OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery ) n_user_file_project_sync = celery_get_queue_length( OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, r_celery ) n_user_file_delete = celery_get_queue_length( OnyxCeleryQueues.USER_FILE_DELETE, r_celery ) n_sync = celery_get_queue_length(OnyxCeleryQueues.VESPA_METADATA_SYNC, r_celery) n_deletion = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery) n_pruning = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery) n_permissions_sync = celery_get_queue_length( OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery ) n_external_group_sync = celery_get_queue_length( OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery ) n_permissions_upsert = 
celery_get_queue_length( OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery ) n_hierarchy_fetching = celery_get_queue_length( OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING, r_celery ) n_llm_model_update = celery_get_queue_length( OnyxCeleryQueues.LLM_MODEL_UPDATE, r_celery ) n_checkpoint_cleanup = celery_get_queue_length( OnyxCeleryQueues.CHECKPOINT_CLEANUP, r_celery ) n_index_attempt_cleanup = celery_get_queue_length( OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP, r_celery ) n_csv_generation = celery_get_queue_length( OnyxCeleryQueues.CSV_GENERATION, r_celery ) n_monitoring = celery_get_queue_length(OnyxCeleryQueues.MONITORING, r_celery) n_sandbox = celery_get_queue_length(OnyxCeleryQueues.SANDBOX, r_celery) n_opensearch_migration = celery_get_queue_length( OnyxCeleryQueues.OPENSEARCH_MIGRATION, r_celery ) n_docfetching_prefetched = celery_get_unacked_task_ids( OnyxCeleryQueues.CONNECTOR_DOC_FETCHING, r_celery ) n_docprocessing_prefetched = celery_get_unacked_task_ids( OnyxCeleryQueues.DOCPROCESSING, r_celery ) task_logger.info( f"Queue lengths: celery={n_celery} " f"docfetching={n_docfetching} " f"docfetching_prefetched={len(n_docfetching_prefetched)} " f"docprocessing={n_docprocessing} " f"docprocessing_prefetched={len(n_docprocessing_prefetched)} " f"user_file_processing={n_user_file_processing} " f"user_file_project_sync={n_user_file_project_sync} " f"user_file_delete={n_user_file_delete} " f"sync={n_sync} " f"deletion={n_deletion} " f"pruning={n_pruning} " f"permissions_sync={n_permissions_sync} " f"external_group_sync={n_external_group_sync} " f"permissions_upsert={n_permissions_upsert} " f"hierarchy_fetching={n_hierarchy_fetching} " f"llm_model_update={n_llm_model_update} " f"checkpoint_cleanup={n_checkpoint_cleanup} " f"index_attempt_cleanup={n_index_attempt_cleanup} " f"csv_generation={n_csv_generation} " f"monitoring={n_monitoring} " f"sandbox={n_sandbox} " f"opensearch_migration={n_opensearch_migration} " ) """Memory monitoring""" def 
_get_cmdline_for_process(process: psutil.Process) -> str | None: try: return " ".join(process.cmdline()) except (psutil.NoSuchProcess, psutil.AccessDenied, psutil.ZombieProcess): return None @shared_task( name=OnyxCeleryTask.MONITOR_PROCESS_MEMORY, ignore_result=True, soft_time_limit=_MONITORING_SOFT_TIME_LIMIT, time_limit=_MONITORING_TIME_LIMIT, queue=OnyxCeleryQueues.MONITORING, bind=True, ) def monitor_process_memory(self: Task, *, tenant_id: str) -> None: # noqa: ARG001 """ Task to monitor memory usage of supervisor-managed processes. It periodically finds the processes managed by supervisor and logs their memory usage statistics in a standardized format. This is useful for monitoring memory consumption over time and identifying potential memory leaks. """ # Don't run this task in multi-tenant mode; we have other, better means of monitoring there if MULTI_TENANT: return # Skip memory monitoring if not running in a container if not is_running_in_container(): return try: # Get all supervisor-managed processes supervisor_processes: dict[int, str] = {} # Map command-line fragments to more readable process type names process_type_mapping = { "--hostname=primary": "primary", "--hostname=light": "light", "--hostname=heavy": "heavy", "--hostname=indexing": "indexing", "--hostname=monitoring": "monitoring", "beat": "beat", "slack/listener.py": "slack", } # Find the Python processes that are likely supervisor-managed workers for proc in psutil.process_iter(): cmdline = _get_cmdline_for_process(proc) if not cmdline: continue # Match supervisor-managed processes for process_name, process_type in process_type_mapping.items(): if process_name in cmdline: if process_type in supervisor_processes.values(): task_logger.error( f"Duplicate process type for type {process_type} with cmd {cmdline} with pid={proc.pid}."
) continue supervisor_processes[proc.pid] = process_type break if len(supervisor_processes) != len(process_type_mapping): task_logger.error( f"Missing processes: {set(process_type_mapping.values()) - set(supervisor_processes.values())}" ) # Log memory usage for each process for pid, process_type in supervisor_processes.items(): try: emit_process_memory(pid, process_type, {}) except psutil.NoSuchProcess: # Process may have terminated since we obtained the list continue except Exception as e: task_logger.exception(f"Error monitoring process {pid}: {str(e)}") except Exception: task_logger.exception("Error in monitor_process_memory task") @shared_task( name=OnyxCeleryTask.CLOUD_MONITOR_CELERY_PIDBOX, ignore_result=True, bind=True ) def cloud_monitor_celery_pidbox( self: Task, ) -> None: """ Celery can leave behind orphaned pidboxes from old workers that are idle and never cleaned up. This task removes them based on idle time to avoid Redis clutter and overflowing the instance. This is a real issue we've observed in production. Note: - Setting CELERY_ENABLE_REMOTE_CONTROL = False would prevent pidbox keys entirely, but might also disable features like inspect, broadcast, and worker remote control. Use with caution.
""" num_deleted = 0 MAX_PIDBOX_IDLE = 24 * 3600 # 1 day in seconds r_celery = celery_get_broker_client(self.app) for key in r_celery.scan_iter("*.reply.celery.pidbox"): key_bytes = cast(bytes, key) key_str = key_bytes.decode("utf-8") if key_str.startswith("_kombu"): continue idletime_raw = r_celery.object("idletime", key) if idletime_raw is None: continue idletime = cast(int, idletime_raw) if idletime < MAX_PIDBOX_IDLE: continue r_celery.delete(key) task_logger.info( f"Deleted idle pidbox: pidbox={key_str} idletime={idletime} max_idletime={MAX_PIDBOX_IDLE}" ) num_deleted += 1 # Enable later in case we want some aggregate metrics # task_logger.info(f"Deleted idle pidbox: pidbox={key_str}") ================================================ FILE: backend/onyx/background/celery/tasks/opensearch_migration/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/opensearch_migration/constants.py ================================================ # Tasks are expected to cease execution and do cleanup after the soft time # limit. In principle they are also forcibly terminated after the hard time # limit; in practice this does not happen because we use thread pools for Celery # task execution, so we simply hope that the total task time plus cleanup does # not exceed this. Therefore tasks should regularly check their timeout and lock # status. The lock timeout is the maximum time the lock manager (Redis in this # case) will enforce the lock, independent of what is happening in the task. To # reduce the chances that a task is still doing work while a lock has expired, # make the lock timeout well above the task timeouts. In practice we should # never see locks be held for this long anyway because a task should release the # lock after its cleanup, which happens at most after its soft timeout. 
# Constants corresponding to migrate_chunks_from_vespa_to_opensearch_task.
from onyx.configs.app_configs import OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE MIGRATION_TASK_SOFT_TIME_LIMIT_S = 60 * 5 # 5 minutes. MIGRATION_TASK_TIME_LIMIT_S = 60 * 6 # 6 minutes. # The maximum time the lock can be held for. Will automatically be released # after this time. MIGRATION_TASK_LOCK_TIMEOUT_S = 60 * 7 # 7 minutes. assert ( MIGRATION_TASK_SOFT_TIME_LIMIT_S < MIGRATION_TASK_TIME_LIMIT_S ), "The soft time limit must be less than the time limit." assert ( MIGRATION_TASK_TIME_LIMIT_S < MIGRATION_TASK_LOCK_TIMEOUT_S ), "The time limit must be less than the lock timeout." # Time to wait to acquire the lock. MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S = 60 * 2 # 2 minutes. # Constants corresponding to check_for_documents_for_opensearch_migration_task. CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S = 60 # 60 seconds / 1 minute. CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S = 90 # 90 seconds. # The maximum time the lock can be held for. Will automatically be released # after this time. CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S = 120 # 120 seconds / 2 minutes. assert ( CHECK_FOR_DOCUMENTS_TASK_SOFT_TIME_LIMIT_S < CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S ), "The soft time limit must be less than the time limit." assert ( CHECK_FOR_DOCUMENTS_TASK_TIME_LIMIT_S < CHECK_FOR_DOCUMENTS_TASK_LOCK_TIMEOUT_S ), "The time limit must be less than the lock timeout." # Time to wait to acquire the lock. CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds. TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15 # WARNING: Do not change these values without knowing what changes also need to # be made to OpenSearchTenantMigrationRecord. GET_VESPA_CHUNKS_PAGE_SIZE = OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE GET_VESPA_CHUNKS_SLICE_COUNT = 4 # String used to indicate in the vespa_visit_continuation_token mapping that the # slice has finished and there is nothing left to visit. 
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = ( "FINISHED_VISITING_SLICE_CONTINUATION_TOKEN" ) ================================================ FILE: backend/onyx/background/celery/tasks/opensearch_migration/tasks.py ================================================ """Celery tasks for migrating documents from Vespa to OpenSearch.""" import time import traceback from celery import shared_task from celery import Task from redis.lock import Lock as RedisLock from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.tasks.opensearch_migration.constants import ( FINISHED_VISITING_SLICE_CONTINUATION_TOKEN, ) from onyx.background.celery.tasks.opensearch_migration.constants import ( GET_VESPA_CHUNKS_PAGE_SIZE, ) from onyx.background.celery.tasks.opensearch_migration.constants import ( MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S, ) from onyx.background.celery.tasks.opensearch_migration.constants import ( MIGRATION_TASK_LOCK_TIMEOUT_S, ) from onyx.background.celery.tasks.opensearch_migration.constants import ( MIGRATION_TASK_SOFT_TIME_LIMIT_S, ) from onyx.background.celery.tasks.opensearch_migration.constants import ( MIGRATION_TASK_TIME_LIMIT_S, ) from onyx.background.celery.tasks.opensearch_migration.transformer import ( transform_vespa_chunks_to_opensearch_chunks, ) from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping from onyx.db.opensearch_migration import get_vespa_visit_state from onyx.db.opensearch_migration import is_migration_completed from onyx.db.opensearch_migration import ( mark_migration_completed_time_if_not_set_with_commit, ) from onyx.db.opensearch_migration import ( 
try_insert_opensearch_tenant_migration_record_with_commit, ) from onyx.db.opensearch_migration import update_vespa_visit_progress_with_commit from onyx.db.search_settings import get_current_search_settings from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.opensearch_document_index import ( OpenSearchDocumentIndex, ) from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex from onyx.indexing.models import IndexingSetting from onyx.redis.redis_pool import get_redis_client from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id def is_continuation_token_done_for_all_slices( continuation_token_map: dict[int, str | None], ) -> bool: return all( continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN for continuation_token in continuation_token_map.values() ) # shared_task allows this task to be shared across celery app instances. @shared_task( name=OnyxCeleryTask.MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK, # Does not store the task's return value in the result backend. ignore_result=True, # WARNING: This is here just for rigor but since we use threads for Celery # this config is not respected and timeout logic must be implemented in the # task. soft_time_limit=MIGRATION_TASK_SOFT_TIME_LIMIT_S, # WARNING: This is here just for rigor but since we use threads for Celery # this config is not respected and timeout logic must be implemented in the # task. time_limit=MIGRATION_TASK_TIME_LIMIT_S, # Passed in self to the task to get task metadata. bind=True, ) def migrate_chunks_from_vespa_to_opensearch_task( self: Task, # noqa: ARG001 *, tenant_id: str, ) -> bool | None: """ Periodic task to migrate chunks from Vespa to OpenSearch via the Visit API. 
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not per-document), transform them, and index them into OpenSearch. Progress is tracked via a continuation token map stored in the OpenSearchTenantMigrationRecord. We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices, and the map tracks progress for each slice. Once every slice's continuation token equals FINISHED_VISITING_SLICE_CONTINUATION_TOKEN, we consider the migration complete, stamp the completion time, and all subsequent invocations are no-ops. Returns: None if OpenSearch migration is not enabled, or if the lock could not be acquired; effectively a no-op. True if the task completed successfully. False if the task errored. """ # 1. Check if we should run the task. # 1.a. If OpenSearch indexing is disabled, we don't run the task. if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX: task_logger.warning( "OpenSearch migration is not enabled, skipping chunk migration task." ) return None task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.") task_start_time = time.monotonic() # 1.b. Only one instance of this task may run per tenant at a # time. If we fail to acquire a lock, we assume it is because another task # has one and we exit. r = get_redis_client() lock: RedisLock = r.lock( name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK, # The maximum time the lock can be held for. Will automatically be # released after this time. timeout=MIGRATION_TASK_LOCK_TIMEOUT_S, # .acquire will block until the lock is acquired or blocking_timeout elapses. blocking=True, # Time to wait to acquire the lock. blocking_timeout=MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S, ) if not lock.acquire(): task_logger.warning( "The OpenSearch migration task timed out waiting for the lock." ) return None else: task_logger.info( f"Acquired the OpenSearch migration lock. Took {time.monotonic() - task_start_time:.3f} seconds. " f"Token: {lock.local.token}" ) # 2. Prepare to migrate. 
total_chunks_migrated_this_task = 0 total_chunks_errored_this_task = 0 try: # 2.a.
Double-check that tenant info is correct. if tenant_id != get_current_tenant_id(): err_str = ( f"Tenant ID mismatch in the OpenSearch migration task: " f"{tenant_id} != {get_current_tenant_id()}. This should never happen." ) task_logger.error(err_str) return False # Do as much as we can with a DB session in one spot to not hold a # session during a migration batch. with get_session_with_current_tenant() as db_session: # 2.b. Immediately check to see if this tenant is done, to save # having to do any other work. This function does not require a # migration record to necessarily exist. if is_migration_completed(db_session): return True # 2.c. Try to insert this tenant's OpenSearchTenantMigrationRecord if it # does not already exist. try_insert_opensearch_tenant_migration_record_with_commit(db_session) # 2.d. Get search settings. search_settings = get_current_search_settings(db_session) indexing_setting = IndexingSetting.from_db_model(search_settings) # 2.e. Build the sanitized-to-original doc ID mapping to check for # conflicts in the event we sanitize a doc ID to an # already-existing doc ID. # We reconstruct this mapping for every task invocation because # a document may have been added in the time between two tasks. sanitized_doc_start_time = time.monotonic() sanitized_to_original_doc_id_mapping = ( build_sanitized_to_original_doc_id_mapping(db_session) ) task_logger.debug( f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries " f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds." ) # 2.f. Get the current migration state. continuation_token_map, total_chunks_migrated = get_vespa_visit_state( db_session ) # 2.f.1. Double-check that the migration state does not imply # completion. Really we should never have to enter this block as we # would expect is_migration_completed to return True, but in the # strange event that the migration is complete but the migration # completed time was never stamped, we do so here.
if is_continuation_token_done_for_all_slices(continuation_token_map): task_logger.info( f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}." ) mark_migration_completed_time_if_not_set_with_commit(db_session) return True task_logger.debug( f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. " f"Continuation token map: {continuation_token_map}" ) with get_vespa_http_client( timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S ) as vespa_client: # 2.g. Create the OpenSearch and Vespa document indexes. tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT) opensearch_document_index = OpenSearchDocumentIndex( tenant_state=tenant_state, index_name=search_settings.index_name, embedding_dim=indexing_setting.final_embedding_dim, embedding_precision=indexing_setting.embedding_precision, ) vespa_document_index = VespaDocumentIndex( index_name=search_settings.index_name, tenant_state=tenant_state, large_chunks_enabled=False, httpx_client=vespa_client, ) # 2.h. Get the approximate chunk count in Vespa as of this time to # update the migration record. approx_chunk_count_in_vespa: int | None = None get_chunk_count_start_time = time.monotonic() try: approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count() except Exception: # This failure should not be blocking. task_logger.exception( "Error getting approximate chunk count in Vespa. Moving on..." ) task_logger.debug( f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get " f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}." ) # 3. Do the actual migration in batches until we run out of time. while ( time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S and lock.owned() ): # 3.a. Get the next batch of raw chunks from Vespa. 
get_vespa_chunks_start_time = time.monotonic() raw_vespa_chunks, next_continuation_token_map = ( vespa_document_index.get_all_raw_document_chunks_paginated( continuation_token_map=continuation_token_map, page_size=GET_VESPA_CHUNKS_PAGE_SIZE, ) ) task_logger.debug( f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} " f"seconds. Next continuation token map: {next_continuation_token_map}" ) # 3.b. Transform the raw chunks to OpenSearch chunks in memory. opensearch_document_chunks, errored_chunks = ( transform_vespa_chunks_to_opensearch_chunks( raw_vespa_chunks, tenant_state, sanitized_to_original_doc_id_mapping, ) ) if len(opensearch_document_chunks) != len(raw_vespa_chunks): task_logger.error( f"Migration task error: Number of candidate chunks to migrate ({len(opensearch_document_chunks)}) does " f"not match number of chunks in Vespa ({len(raw_vespa_chunks)}). {len(errored_chunks)} chunks " "errored." ) # 3.c. Index the OpenSearch chunks into OpenSearch. index_opensearch_chunks_start_time = time.monotonic() opensearch_document_index.index_raw_chunks( chunks=opensearch_document_chunks ) task_logger.debug( f"Indexed {len(opensearch_document_chunks)} chunks into OpenSearch in " f"{time.monotonic() - index_opensearch_chunks_start_time:.3f} seconds." ) total_chunks_migrated_this_task += len(opensearch_document_chunks) total_chunks_errored_this_task += len(errored_chunks) # Do as much as we can with a DB session in one spot to not hold a # session during a migration batch. with get_session_with_current_tenant() as db_session: # 3.d. Update the migration state. update_vespa_visit_progress_with_commit( db_session, continuation_token_map=next_continuation_token_map, chunks_processed=len(opensearch_document_chunks), chunks_errored=len(errored_chunks), approx_chunk_count_in_vespa=approx_chunk_count_in_vespa, ) # 3.e. Get the current migration state. 
Even though we # technically have it in-memory since we just wrote it, we # want to reference the DB as the source of truth at all # times. continuation_token_map, total_chunks_migrated = ( get_vespa_visit_state(db_session) ) # 3.e.1. Check if the migration is done. if is_continuation_token_done_for_all_slices( continuation_token_map ): task_logger.info( f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}." ) mark_migration_completed_time_if_not_set_with_commit(db_session) return True task_logger.debug( f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. " f"Continuation token map: {continuation_token_map}" ) except Exception: traceback.print_exc() task_logger.exception("Error in the OpenSearch migration task.") return False finally: if lock.owned(): lock.release() else: task_logger.warning( "The OpenSearch migration lock was not owned on completion of the migration task." ) task_logger.info( f"OpenSearch chunk migration task pausing (time limit reached). " f"Total chunks migrated this task: {total_chunks_migrated_this_task}. " f"Total chunks errored this task: {total_chunks_errored_this_task}. " f"Elapsed: {time.monotonic() - task_start_time:.3f}s. " "Will resume from continuation token on next invocation."
) return True ================================================ FILE: backend/onyx/background/celery/tasks/opensearch_migration/transformer.py ================================================ import traceback from datetime import datetime from datetime import timezone from typing import Any from onyx.configs.constants import PUBLIC_DOC_PAT from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.schema import DocumentChunk from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST from onyx.document_index.vespa_constants import CHUNK_CONTEXT from onyx.document_index.vespa_constants import CHUNK_ID from onyx.document_index.vespa_constants import CONTENT from onyx.document_index.vespa_constants import DOC_SUMMARY from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_SETS from onyx.document_index.vespa_constants import EMBEDDINGS from onyx.document_index.vespa_constants import FULL_CHUNK_EMBEDDING_KEY from onyx.document_index.vespa_constants import HIDDEN from onyx.document_index.vespa_constants import IMAGE_FILE_NAME from onyx.document_index.vespa_constants import METADATA_LIST from onyx.document_index.vespa_constants import METADATA_SUFFIX from onyx.document_index.vespa_constants import PERSONAS from onyx.document_index.vespa_constants import PRIMARY_OWNERS from onyx.document_index.vespa_constants import SECONDARY_OWNERS from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER from onyx.document_index.vespa_constants import SOURCE_LINKS from onyx.document_index.vespa_constants import SOURCE_TYPE from onyx.document_index.vespa_constants import TENANT_ID from onyx.document_index.vespa_constants import TITLE from onyx.document_index.vespa_constants import TITLE_EMBEDDING from 
onyx.document_index.vespa_constants import USER_PROJECT from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger(__name__) FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [ DOCUMENT_ID, CHUNK_ID, TITLE, TITLE_EMBEDDING, CONTENT, EMBEDDINGS, SOURCE_TYPE, METADATA_LIST, DOC_UPDATED_AT, HIDDEN, BOOST, SEMANTIC_IDENTIFIER, IMAGE_FILE_NAME, SOURCE_LINKS, BLURB, DOC_SUMMARY, CHUNK_CONTEXT, METADATA_SUFFIX, DOCUMENT_SETS, USER_PROJECT, PERSONAS, PRIMARY_OWNERS, SECONDARY_OWNERS, ACCESS_CONTROL_LIST, ] if MULTI_TENANT: FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID) def _extract_content_vector(embeddings: Any) -> list[float]: """Extracts the full chunk embedding vector from Vespa's embeddings tensor. Vespa stores embeddings as a tensor(t{},x[dim]) where 't' maps embedding names (like "full_chunk") to vectors. The API can return this in different formats: 1. Direct list: {"full_chunk": [...]} 2. Blocks format: {"blocks": {"full_chunk": [0.1, 0.2, ...]}} 3. Possibly other formats. We only support formats 1 and 2. Any other supplied format will raise an error. Raises: ValueError: If the embeddings format is not supported. Returns: The full chunk content embedding vector as a list of floats. """ if isinstance(embeddings, dict): # Handle format 1. full_chunk_embedding = embeddings.get(FULL_CHUNK_EMBEDDING_KEY) if isinstance(full_chunk_embedding, list): # Double check that within the list we have floats and not another # list or dict. if not full_chunk_embedding: raise ValueError("Full chunk embedding is empty.") if isinstance(full_chunk_embedding[0], float): return full_chunk_embedding # Handle format 2. blocks = embeddings.get("blocks") if isinstance(blocks, dict): full_chunk_embedding = blocks.get(FULL_CHUNK_EMBEDDING_KEY) if isinstance(full_chunk_embedding, list): # Double check that within the list we have floats and not another # list or dict. 
if not full_chunk_embedding: raise ValueError("Full chunk embedding is empty.") if isinstance(full_chunk_embedding[0], float): return full_chunk_embedding raise ValueError(f"Unknown embedding format: {type(embeddings)}") def _extract_title_vector(title_embedding: Any | None) -> list[float] | None: """Extract the title embedding vector. Returns None if no title embedding exists. Vespa returns title_embedding as tensor(x[dim]) which can be in formats: 1. Direct list: [0.1, 0.2, ...] 2. Values format: {"values": [0.1, 0.2, ...]} 3. Possibly other formats. Only formats 1 and 2 are supported. Any other supplied format will raise an error. Raises: ValueError: If the title embedding format is not supported. Returns: The title embedding vector as a list of floats. """ if title_embedding is None: return None # Handle format 1. if isinstance(title_embedding, list): # Double check that within the list we have floats and not another # list or dict. if not title_embedding: return None if isinstance(title_embedding[0], float): return title_embedding # Handle format 2. if isinstance(title_embedding, dict): # Try values format. values = title_embedding.get("values") if values is not None and isinstance(values, list): # Double check that within the list we have floats and not another # list or dict. 
if not values: return None if isinstance(values[0], float): return values raise ValueError(f"Unknown title embedding format: {type(title_embedding)}") def _transform_vespa_document_sets_to_opensearch_document_sets( vespa_document_sets: dict[str, int] | None, ) -> list[str] | None: if not vespa_document_sets: return None return list(vespa_document_sets.keys()) def _transform_vespa_acl_to_opensearch_acl( vespa_acl: dict[str, int] | None, ) -> tuple[bool, list[str]]: if not vespa_acl: return False, [] acl_list = list(vespa_acl.keys()) is_public = PUBLIC_DOC_PAT in acl_list if is_public: acl_list.remove(PUBLIC_DOC_PAT) return is_public, acl_list def transform_vespa_chunks_to_opensearch_chunks( vespa_chunks: list[dict[str, Any]], tenant_state: TenantState, sanitized_to_original_doc_id_mapping: dict[str, str], ) -> tuple[list[DocumentChunk], list[dict[str, Any]]]: result: list[DocumentChunk] = [] errored_chunks: list[dict[str, Any]] = [] for vespa_chunk in vespa_chunks: try: # This should exist; fail loudly if it does not. vespa_document_id: str = vespa_chunk[DOCUMENT_ID] if not vespa_document_id: raise ValueError("Missing document_id in Vespa chunk.") # Vespa doc IDs were sanitized using # replace_invalid_doc_id_characters. This was a poor design choice # and we don't want this in OpenSearch; whatever restrictions there # may be on indexed chunk ID should have no bearing on the chunk's # document ID field, even if document ID is an argument to the chunk # ID. Deliberately choose to use the real doc ID supplied to this # function. if vespa_document_id in sanitized_to_original_doc_id_mapping: logger.warning( f"Migration warning: Vespa document ID {vespa_document_id} does not match the document ID supplied " f"{sanitized_to_original_doc_id_mapping[vespa_document_id]}. " "The Vespa ID will be discarded." ) document_id = sanitized_to_original_doc_id_mapping.get( vespa_document_id, vespa_document_id ) # This should exist; fail loudly if it does not. 
chunk_index: int = vespa_chunk[CHUNK_ID] title: str | None = vespa_chunk.get(TITLE) # WARNING: Should supply format.tensors=short-value to the Vespa # client in order to get a supported format for the tensors. title_vector: list[float] | None = _extract_title_vector( vespa_chunk.get(TITLE_EMBEDDING) ) # This should exist; fail loudly if it does not. content: str = vespa_chunk[CONTENT] if not content: raise ValueError( f"Missing content in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}." ) # This should exist; fail loudly if it does not. # WARNING: Should supply format.tensors=short-value to the Vespa # client in order to get a supported format for the tensors. content_vector: list[float] = _extract_content_vector( vespa_chunk[EMBEDDINGS] ) if not content_vector: raise ValueError( f"Missing content_vector in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}." ) # This should exist; fail loudly if it does not. source_type: str = vespa_chunk[SOURCE_TYPE] if not source_type: raise ValueError( f"Missing source_type in Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index}." ) metadata_list: list[str] | None = vespa_chunk.get(METADATA_LIST) _raw_doc_updated_at: int | None = vespa_chunk.get(DOC_UPDATED_AT) last_updated: datetime | None = ( datetime.fromtimestamp(_raw_doc_updated_at, tz=timezone.utc) if _raw_doc_updated_at is not None else None ) hidden: bool = vespa_chunk.get(HIDDEN, False) # This should exist; fail loudly if it does not. global_boost: int = vespa_chunk[BOOST] # This should exist; fail loudly if it does not. semantic_identifier: str = vespa_chunk[SEMANTIC_IDENTIFIER] if not semantic_identifier: raise ValueError( f"Missing semantic_identifier in Vespa chunk with document ID {vespa_document_id} and chunk " f"index {chunk_index}." 
) image_file_id: str | None = vespa_chunk.get(IMAGE_FILE_NAME) source_links: str | None = vespa_chunk.get(SOURCE_LINKS) blurb: str = vespa_chunk.get(BLURB, "") doc_summary: str = vespa_chunk.get(DOC_SUMMARY, "") chunk_context: str = vespa_chunk.get(CHUNK_CONTEXT, "") metadata_suffix: str | None = vespa_chunk.get(METADATA_SUFFIX) document_sets: list[str] | None = ( _transform_vespa_document_sets_to_opensearch_document_sets( vespa_chunk.get(DOCUMENT_SETS) ) ) user_projects: list[int] | None = vespa_chunk.get(USER_PROJECT) personas: list[int] | None = vespa_chunk.get(PERSONAS) primary_owners: list[str] | None = vespa_chunk.get(PRIMARY_OWNERS) secondary_owners: list[str] | None = vespa_chunk.get(SECONDARY_OWNERS) is_public, acl_list = _transform_vespa_acl_to_opensearch_acl( vespa_chunk.get(ACCESS_CONTROL_LIST) ) if not is_public and not acl_list: logger.warning( f"Migration warning: Vespa chunk with document ID {vespa_document_id} and chunk index {chunk_index} has no " "public ACL and no access control list. This does not make sense as it implies the document is never " "searchable. Continuing with the migration..." ) chunk_tenant_id: str | None = vespa_chunk.get(TENANT_ID) if MULTI_TENANT: if not chunk_tenant_id: raise ValueError( "Missing tenant_id in Vespa chunk in a multi-tenant environment." ) if chunk_tenant_id != tenant_state.tenant_id: raise ValueError( f"Chunk tenant_id {chunk_tenant_id} does not match expected tenant_id {tenant_state.tenant_id}" ) opensearch_chunk = DocumentChunk( # We deliberately choose to use the doc ID supplied to this function # over the Vespa doc ID. 
document_id=document_id, chunk_index=chunk_index, title=title, title_vector=title_vector, content=content, content_vector=content_vector, source_type=source_type, metadata_list=metadata_list, last_updated=last_updated, public=is_public, access_control_list=acl_list, hidden=hidden, global_boost=global_boost, semantic_identifier=semantic_identifier, image_file_id=image_file_id, source_links=source_links, blurb=blurb, doc_summary=doc_summary, chunk_context=chunk_context, metadata_suffix=metadata_suffix, document_sets=document_sets, user_projects=user_projects, personas=personas, primary_owners=primary_owners, secondary_owners=secondary_owners, tenant_id=tenant_state, ) result.append(opensearch_chunk) except Exception: traceback.print_exc() logger.exception( f"Migration error: Error transforming Vespa chunk with document ID {vespa_chunk.get(DOCUMENT_ID)} " f"and chunk index {vespa_chunk.get(CHUNK_ID)} into an OpenSearch chunk. Continuing with " "the migration..." ) errored_chunks.append(vespa_chunk) return result, errored_chunks ================================================ FILE: backend/onyx/background/celery/tasks/periodic/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/periodic/tasks.py ================================================ ##### # Periodic Tasks ##### import json from typing import Any from celery import shared_task from celery.contrib.abortable import AbortableTask # type: ignore from celery.exceptions import TaskRevokedError from sqlalchemy import inspect from sqlalchemy import text from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import PostgresAdvisoryLocks from onyx.db.engine.sql_engine import get_session_with_current_tenant @shared_task( 
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK, soft_time_limit=JOB_TIMEOUT, bind=True, base=AbortableTask, ) def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001 """Runs periodically to clean up the kombu_message table""" # we will select messages older than this amount to clean up KOMBU_MESSAGE_CLEANUP_AGE = 7 # days KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000 ctx = {} ctx["last_processed_id"] = 0 ctx["deleted"] = 0 ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT with get_session_with_current_tenant() as db_session: # Exit the task if we can't take the advisory lock result = db_session.execute( text("SELECT pg_try_advisory_lock(:id)"), {"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value}, ).scalar() if not result: return 0 while True: if self.is_aborted(): raise TaskRevokedError("kombu_message_cleanup_task was aborted.") b = kombu_message_cleanup_task_helper(ctx, db_session) if not b: break db_session.commit() if ctx["deleted"] > 0: task_logger.info( f"Deleted {ctx['deleted']} orphaned messages from kombu_message." ) return ctx["deleted"] def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool: """ Helper function to clean up old messages from the `kombu_message` table that are no longer relevant. This function retrieves messages from the `kombu_message` table that are no longer visible and older than a specified interval. It checks if the corresponding task_id exists in the `celery_taskmeta` table. If the task_id does not exist, the message is deleted. Args: ctx (dict): A context dictionary containing configuration parameters such as: - 'cleanup_age' (int): The age in days after which messages are considered old. - 'page_limit' (int): The maximum number of messages to process in one batch. - 'last_processed_id' (int): The ID of the last processed message to handle pagination. - 'deleted' (int): A counter to track the number of deleted messages. 
db_session (Session): The SQLAlchemy database session for executing queries. Returns: bool: Returns True if there are more rows to process, False if not. """ inspector = inspect(db_session.bind) if not inspector: return False # With the move to redis as celery's broker and backend, kombu tables may not even exist. # We can fail silently. if not inspector.has_table("kombu_message"): return False query = text( """ SELECT id, timestamp, payload FROM kombu_message WHERE visible = 'false' AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days AND id > :last_processed_id ORDER BY id LIMIT :page_limit """ ) kombu_messages = db_session.execute( query, { "interval_days": f"{ctx['cleanup_age']} days", "page_limit": ctx["page_limit"], "last_processed_id": ctx["last_processed_id"], }, ).fetchall() if len(kombu_messages) == 0: return False for msg in kombu_messages: payload = json.loads(msg[2]) task_id = payload["headers"]["id"] # Check if task_id exists in celery_taskmeta task_exists = db_session.execute( text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"), {"task_id": task_id}, ).fetchone() # If task_id does not exist, delete the message if not task_exists: result = db_session.execute( text("DELETE FROM kombu_message WHERE id = :message_id"), {"message_id": msg[0]}, ) if result.rowcount > 0: # type: ignore ctx["deleted"] += 1 ctx["last_processed_id"] = msg[0] return True ================================================ FILE: backend/onyx/background/celery/tasks/pruning/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/pruning/tasks.py ================================================ import time from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from typing import cast from uuid import uuid4 from celery import Celery from celery import shared_task from celery import Task from celery.exceptions 
import SoftTimeLimitExceeded from pydantic import ValidationError from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_find_task from onyx.background.celery.celery_redis import celery_get_broker_client from onyx.background.celery.celery_redis import celery_get_queue_length from onyx.background.celery.celery_redis import celery_get_queued_task_ids from onyx.background.celery.celery_redis import celery_get_unacked_task_ids from onyx.background.celery.celery_utils import extract_ids_from_runnable_connector from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT from onyx.background.celery.tasks.docprocessing.utils import IndexingCallbackBase from onyx.configs.app_configs import ALLOW_SIMULTANEOUS_PRUNING from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX from onyx.configs.constants import DocumentSource from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import OnyxRedisSignals from onyx.connectors.factory import instantiate_connector from onyx.connectors.models import InputType from onyx.db.connector import mark_ccpair_as_pruned from onyx.db.connector_credential_pair import get_connector_credential_pair from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import get_connector_credential_pairs from onyx.db.document 
import get_documents_for_connector_credential_pair from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import SyncStatus from onyx.db.enums import SyncType from onyx.db.hierarchy import delete_orphaned_hierarchy_nodes from onyx.db.hierarchy import link_hierarchy_nodes_to_documents from onyx.db.hierarchy import remove_stale_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import reparent_orphaned_hierarchy_nodes from onyx.db.hierarchy import update_document_parent_hierarchy_nodes from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import upsert_hierarchy_nodes_batch from onyx.db.models import ConnectorCredentialPair from onyx.db.models import HierarchyNode as DBHierarchyNode from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.db.tag import delete_orphan_tags__no_commit from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_connector_prune import RedisConnectorPrune from onyx.redis.redis_connector_prune import RedisConnectorPrunePayload from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch from onyx.redis.redis_hierarchy import ensure_source_node_exists from onyx.redis.redis_hierarchy import evict_hierarchy_nodes_from_cache from onyx.redis.redis_hierarchy import get_node_id_from_raw_id from onyx.redis.redis_hierarchy import get_source_node_id_from_cache from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.server.runtime.onyx_runtime import OnyxRuntime from onyx.server.utils import make_short_id from onyx.utils.logger import format_error_for_logging from onyx.utils.logger import LoggerContextVars from onyx.utils.logger import pruning_ctx from onyx.utils.logger import 
setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() def _get_pruning_block_expiration() -> int: """ Compute the expiration time for the pruning block signal. Base expiration is 60 seconds (1 minute), multiplied by the beat multiplier only in MULTI_TENANT mode. """ base_expiration = 60 # seconds if not MULTI_TENANT: return base_expiration try: beat_multiplier = OnyxRuntime.get_beat_multiplier() except Exception: beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT return int(base_expiration * beat_multiplier) def _get_fence_validation_block_expiration() -> int: """ Compute the expiration time for the fence validation block signal. Base expiration is 300 seconds, multiplied by the beat multiplier only in MULTI_TENANT mode. """ base_expiration = 300 # seconds if not MULTI_TENANT: return base_expiration try: beat_multiplier = OnyxRuntime.get_beat_multiplier() except Exception: beat_multiplier = CLOUD_BEAT_MULTIPLIER_DEFAULT return int(base_expiration * beat_multiplier) class PruneCallback(IndexingCallbackBase): def progress(self, tag: str, amount: int) -> None: self.redis_connector.prune.set_active() super().progress(tag, amount) def _resolve_and_update_document_parents( db_session: Session, redis_client: Redis, source: DocumentSource, raw_id_to_parent: dict[str, str | None], ) -> None: """Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id for each document and bulk-update the DB. 
Mirrors the resolution logic in run_docfetching.py.""" source_node_id = get_source_node_id_from_cache(redis_client, db_session, source) resolved: dict[str, int | None] = {} for doc_id, raw_parent_id in raw_id_to_parent.items(): if raw_parent_id is None: continue node_id, found = get_node_id_from_raw_id(redis_client, source, raw_parent_id) resolved[doc_id] = node_id if found else source_node_id if not resolved: return update_document_parent_hierarchy_nodes( db_session=db_session, doc_parent_map=resolved, commit=True, ) task_logger.info( f"Pruning: resolved and updated parent hierarchy for {len(resolved)} documents (source={source.value})" ) """Jobs / utils for kicking off pruning tasks.""" def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool: """Returns boolean indicating if pruning is due. Next pruning time is calculated as a delta from the last successful prune, or from the connector's creation time if pruning has never succeeded (a cc pair that has never been successfully indexed is never due). TODO(rkuo): consider whether we should allow pruning to be immediately rescheduled if pruning fails (which is what it does now). A backoff could be reasonable. """ # skip pruning if no prune frequency is set # pruning can still be forced via the API which will run a pruning task directly if not cc_pair.connector.prune_freq: return False # skip pruning if not active if cc_pair.status != ConnectorCredentialPairStatus.ACTIVE: return False # skip pruning if the next scheduled prune time hasn't been reached yet last_pruned = cc_pair.last_pruned if not last_pruned: if not cc_pair.last_successful_index_time: # if we've never indexed, we can't prune return False # if never pruned, use the connector creation time. We could also # compute the completion time of the first successful index attempt, but # that is a reasonably heavy operation. This is a reasonable approximation; # in the worst case, we'll prune a little bit earlier than we should.
last_pruned = cc_pair.connector.time_created next_prune = last_pruned + timedelta(seconds=cc_pair.connector.prune_freq) return datetime.now(timezone.utc) >= next_prune @shared_task( name=OnyxCeleryTask.CHECK_FOR_PRUNING, ignore_result=True, soft_time_limit=JOB_TIMEOUT, bind=True, ) def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None: r = get_redis_client() r_replica = get_redis_replica_client() lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_PRUNE_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): return None try: # the entire task needs to run frequently in order to finalize pruning # but pruning only kicks off once per hour if not r.exists(OnyxRedisSignals.BLOCK_PRUNING): task_logger.info("Checking for pruning due") cc_pair_ids: list[int] = [] with get_session_with_current_tenant() as db_session: cc_pairs = get_connector_credential_pairs(db_session) for cc_pair_entry in cc_pairs: cc_pair_ids.append(cc_pair_entry.id) for cc_pair_id in cc_pair_ids: lock_beat.reacquire() with get_session_with_current_tenant() as db_session: cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if not cc_pair: logger.error(f"CC pair not found: {cc_pair_id}") continue if not _is_pruning_due(cc_pair): logger.info(f"CC pair not due for pruning: {cc_pair_id}") continue payload_id = try_creating_prune_generator_task( self.app, cc_pair, db_session, r, tenant_id ) if not payload_id: logger.info(f"Pruning not created: {cc_pair_id}") continue task_logger.info( f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}" ) r.set(OnyxRedisSignals.BLOCK_PRUNING, 1, ex=_get_pruning_block_expiration()) # we want to run this less frequently than the overall task lock_beat.reacquire() if not r.exists(OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES): # clear any pruning fences that don't have associated celery tasks in progress # tasks can be in the queue in
redis, in reserved tasks (prefetched by the worker), # or be currently executing try: r_celery = celery_get_broker_client(self.app) validate_pruning_fences(tenant_id, r, r_replica, r_celery, lock_beat) except Exception: task_logger.exception("Exception while validating pruning fences") r.set( OnyxRedisSignals.BLOCK_VALIDATE_PRUNING_FENCES, 1, ex=_get_fence_validation_block_expiration(), ) # use a lookup table to find active fences. We still have to verify the fence # exists since it is an optimization and not the source of truth. lock_beat.reacquire() keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) if not r.exists(key_bytes): r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes) continue key_str = key_bytes.decode("utf-8") if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX): with get_session_with_current_tenant() as db_session: monitor_ccpair_pruning_taskset(tenant_id, key_bytes, r, db_session) except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." ) except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning(f"Unexpected pruning check exception: {error_msg}") task_logger.exception("Unexpected exception during pruning check") finally: if lock_beat.owned(): lock_beat.release() task_logger.info(f"check_for_pruning finished: tenant={tenant_id}") return True def try_creating_prune_generator_task( celery_app: Celery, cc_pair: ConnectorCredentialPair, db_session: Session, r: Redis, tenant_id: str, ) -> str | None: """Checks for any conditions that should block the pruning generator task from being created, then creates the task. Does not check for scheduling related conditions as this function is used to trigger prunes immediately, e.g. via the web ui. 
""" logger.info(f"try_creating_prune_generator_task: cc_pair={cc_pair.id}") redis_connector = RedisConnector(tenant_id, cc_pair.id) if not ALLOW_SIMULTANEOUS_PRUNING: count = redis_connector.prune.get_active_task_count() if count > 0: logger.info( f"try_creating_prune_generator_task: cc_pair={cc_pair.id} no simultaneous pruning allowed" ) return None LOCK_TIMEOUT = 30 # we need to serialize starting pruning since it can be triggered either via # celery beat or manually (API call) lock: RedisLock = r.lock( DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_creating_prune_generator_task", timeout=LOCK_TIMEOUT, ) acquired = lock.acquire(blocking_timeout=LOCK_TIMEOUT / 2) if not acquired: logger.info( f"try_creating_prune_generator_task: cc_pair={cc_pair.id} lock not acquired" ) return None try: # skip pruning if already pruning if redis_connector.prune.fenced: logger.info( f"try_creating_prune_generator_task: cc_pair={cc_pair.id} already pruning" ) return None # skip pruning if the cc_pair is deleting if redis_connector.delete.fenced: logger.info( f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting" ) return None # skip pruning if doc permissions sync is running if redis_connector.permissions.fenced: logger.info( f"try_creating_prune_generator_task: cc_pair={cc_pair.id} permissions sync running" ) return None db_session.refresh(cc_pair) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: logger.info( f"try_creating_prune_generator_task: cc_pair={cc_pair.id} deleting" ) return None # add a long running generator task to the queue redis_connector.prune.generator_clear() redis_connector.prune.taskset_clear() custom_task_id = f"{redis_connector.prune.generator_task_key}_{uuid4()}" # create before setting fence to avoid race condition where the monitoring # task updates the sync record before it is created try: insert_sync_record( db_session=db_session, entity_id=cc_pair.id, sync_type=SyncType.PRUNING, ) except Exception: 
task_logger.exception("insert_sync_record exceptioned.") # signal active before the fence is set redis_connector.prune.set_active() # set a basic fence to start payload = RedisConnectorPrunePayload( id=make_short_id(), submitted=datetime.now(timezone.utc), started=None, celery_task_id=None, ) redis_connector.prune.set_fence(payload) result = celery_app.send_task( OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK, kwargs=dict( cc_pair_id=cc_pair.id, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, tenant_id=tenant_id, ), queue=OnyxCeleryQueues.CONNECTOR_PRUNING, task_id=custom_task_id, priority=OnyxCeleryPriority.LOW, ) # fill in the celery task id payload.celery_task_id = result.id redis_connector.prune.set_fence(payload) payload_id = payload.id except Exception as e: error_msg = format_error_for_logging(e) task_logger.warning( f"Unexpected try_creating_prune_generator_task exception: cc_pair={cc_pair.id} {error_msg}" ) task_logger.exception(f"Unexpected exception: cc_pair={cc_pair.id}") return None finally: if lock.owned(): lock.release() task_logger.info( f"try_creating_prune_generator_task finished: cc_pair={cc_pair.id} payload_id={payload_id}" ) return payload_id @shared_task( name=OnyxCeleryTask.CONNECTOR_PRUNING_GENERATOR_TASK, acks_late=False, soft_time_limit=JOB_TIMEOUT, track_started=True, trail=False, bind=True, ) def connector_pruning_generator_task( self: Task, cc_pair_id: int, connector_id: int, credential_id: int, tenant_id: str, ) -> None: """connector pruning task. 
For a cc pair, this task pulls all document IDs from the source and compares those IDs to locally stored documents and deletes all locally stored IDs missing from the most recently pulled document ID list""" payload_id: str | None = None LoggerContextVars.reset() pruning_ctx_dict = pruning_ctx.get() pruning_ctx_dict["cc_pair_id"] = cc_pair_id pruning_ctx_dict["request_id"] = self.request.id pruning_ctx.set(pruning_ctx_dict) task_logger.info(f"Pruning generator starting: cc_pair={cc_pair_id}") redis_connector = RedisConnector(tenant_id, cc_pair_id) r = get_redis_client() # this wait is needed to avoid a race condition where # the primary worker sends the task and it is immediately executed # before the primary worker can finalize the fence start = time.monotonic() while True: if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT: raise ValueError( f"connector_prune_generator_task - timed out waiting for fence to be ready: " f"fence={redis_connector.prune.fence_key}" ) if not redis_connector.prune.fenced: # The fence must exist raise ValueError( f"connector_prune_generator_task - fence not found: fence={redis_connector.prune.fence_key}" ) payload = redis_connector.prune.payload # The payload must exist if not payload: raise ValueError( "connector_prune_generator_task: payload invalid or not found" ) if payload.celery_task_id is None: logger.info( f"connector_prune_generator_task - Waiting for fence: fence={redis_connector.prune.fence_key}" ) time.sleep(1) continue payload_id = payload.id logger.info( f"connector_prune_generator_task - Fence found, continuing...: " f"fence={redis_connector.prune.fence_key} " f"payload_id={payload.id}" ) break # set thread_local=False since we don't control what thread the indexing/pruning # might run our callback with lock: RedisLock = r.lock( OnyxRedisLocks.PRUNING_LOCK_PREFIX + f"_{redis_connector.cc_pair_id}", timeout=CELERY_PRUNING_LOCK_TIMEOUT, thread_local=False, ) acquired = lock.acquire(blocking=False) if not 
acquired: task_logger.warning( f"Pruning task already running, exiting...: cc_pair={cc_pair_id}" ) return None try: with get_session_with_current_tenant() as db_session: cc_pair = get_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, ) if not cc_pair: task_logger.warning( f"cc_pair not found for {connector_id} {credential_id}" ) return payload = redis_connector.prune.payload if not payload: raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}") new_payload = RedisConnectorPrunePayload( id=payload.id, submitted=payload.submitted, started=datetime.now(timezone.utc), celery_task_id=payload.celery_task_id, ) redis_connector.prune.set_fence(new_payload) task_logger.info( f"Pruning generator running connector: cc_pair={cc_pair_id} connector_source={cc_pair.connector.source}" ) runnable_connector = instantiate_connector( db_session, cc_pair.connector.source, InputType.SLIM_RETRIEVAL, cc_pair.connector.connector_specific_config, cc_pair.credential, ) callback = PruneCallback( 0, redis_connector, lock, r, timeout_seconds=JOB_TIMEOUT, ) # Extract docs and hierarchy nodes from the source extraction_result = extract_ids_from_runnable_connector( runnable_connector, callback ) all_connector_doc_ids = extraction_result.raw_id_to_parent # Process hierarchy nodes (same as docfetching): # upsert to Postgres and cache in Redis source = cc_pair.connector.source redis_client = get_redis_client(tenant_id=tenant_id) ensure_source_node_exists(redis_client, db_session, source) upserted_nodes: list[DBHierarchyNode] = [] if extraction_result.hierarchy_nodes: is_connector_public = cc_pair.access_type == AccessType.PUBLIC upserted_nodes = upsert_hierarchy_nodes_batch( db_session=db_session, nodes=extraction_result.hierarchy_nodes, source=source, commit=True, is_connector_public=is_connector_public, ) upsert_hierarchy_node_cc_pair_entries( db_session=db_session, hierarchy_node_ids=[n.id for n in upserted_nodes], 
connector_id=connector_id, credential_id=credential_id, commit=True, ) cache_entries = [ HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes ] cache_hierarchy_nodes_batch( redis_client=redis_client, source=source, entries=cache_entries, ) task_logger.info( f"Pruning: persisted and cached {len(extraction_result.hierarchy_nodes)} " f"hierarchy nodes for cc_pair={cc_pair_id}" ) # Resolve parent_hierarchy_raw_node_id → parent_hierarchy_node_id # and bulk-update documents, mirroring the docfetching resolution _resolve_and_update_document_parents( db_session=db_session, redis_client=redis_client, source=source, raw_id_to_parent=all_connector_doc_ids, ) # Link hierarchy nodes to documents for sources where pages can be # both hierarchy nodes AND documents (e.g. Notion, Confluence) all_doc_id_list = list(all_connector_doc_ids.keys()) link_hierarchy_nodes_to_documents( db_session=db_session, document_ids=all_doc_id_list, source=source, commit=True, ) # a list of docs in our local index all_indexed_document_ids = { doc.id for doc in get_documents_for_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, ) } # generate list of docs to remove (no longer in the source) doc_ids_to_remove = list( all_indexed_document_ids - all_connector_doc_ids.keys() ) task_logger.info( "Pruning set collected: " f"cc_pair={cc_pair_id} " f"connector_source={cc_pair.connector.source} " f"docs_to_remove={len(doc_ids_to_remove)}" ) task_logger.info( f"RedisConnector.prune.generate_tasks starting. cc_pair={cc_pair_id}" ) tasks_generated = redis_connector.prune.generate_tasks( set(doc_ids_to_remove), self.app, db_session, None ) if tasks_generated is None: return None task_logger.info( f"RedisConnector.prune.generate_tasks finished. 
cc_pair={cc_pair_id} tasks_generated={tasks_generated}" ) redis_connector.prune.generator_complete = tasks_generated # --- Hierarchy node pruning --- live_node_ids = {n.id for n in upserted_nodes} stale_removed = remove_stale_hierarchy_node_cc_pair_entries( db_session=db_session, connector_id=connector_id, credential_id=credential_id, live_hierarchy_node_ids=live_node_ids, commit=True, ) deleted_raw_ids = delete_orphaned_hierarchy_nodes( db_session=db_session, source=source, commit=True, ) reparented_nodes = reparent_orphaned_hierarchy_nodes( db_session=db_session, source=source, commit=True, ) if deleted_raw_ids: evict_hierarchy_nodes_from_cache(redis_client, source, deleted_raw_ids) if reparented_nodes: reparented_cache_entries = [ HierarchyNodeCacheEntry.from_db_model(node) for node in reparented_nodes ] cache_hierarchy_nodes_batch( redis_client, source, reparented_cache_entries ) if stale_removed or deleted_raw_ids or reparented_nodes: task_logger.info( f"Hierarchy node pruning: cc_pair={cc_pair_id} " f"stale_entries_removed={stale_removed} " f"nodes_deleted={len(deleted_raw_ids)} " f"nodes_reparented={len(reparented_nodes)}" ) except Exception as e: task_logger.exception( f"Pruning exceptioned: cc_pair={cc_pair_id} connector={connector_id} payload_id={payload_id}" ) redis_connector.prune.reset() raise e finally: if lock.owned(): lock.release() task_logger.info( f"Pruning generator finished: cc_pair={cc_pair_id} payload_id={payload_id}" ) """Monitoring pruning utils""" def monitor_ccpair_pruning_taskset( tenant_id: str, key_bytes: bytes, r: Redis, # noqa: ARG001 db_session: Session, ) -> None: fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) if cc_pair_id_str is None: task_logger.warning( f"monitor_ccpair_pruning_taskset: could not parse cc_pair_id from {fence_key}" ) return cc_pair_id = int(cc_pair_id_str) redis_connector = RedisConnector(tenant_id, cc_pair_id) if not redis_connector.prune.fenced: return 
initial = redis_connector.prune.generator_complete if initial is None: return remaining = redis_connector.prune.get_remaining() task_logger.info( f"Connector pruning progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}" ) if remaining > 0: return mark_ccpair_as_pruned(int(cc_pair_id), db_session) task_logger.info( f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}" ) update_sync_record_status( db_session=db_session, entity_id=cc_pair_id, sync_type=SyncType.PRUNING, sync_status=SyncStatus.SUCCESS, num_docs_synced=initial, ) delete_orphan_tags__no_commit(db_session) redis_connector.prune.taskset_clear() redis_connector.prune.generator_clear() redis_connector.prune.set_fence(None) def validate_pruning_fences( tenant_id: str, r: Redis, r_replica: Redis, r_celery: Redis, lock_beat: RedisLock, ) -> None: # building the lookup table can be expensive, so we won't bother # validating until the queue is small PRUNING_VALIDATION_MAX_QUEUE_LEN = 1024 queue_len = celery_get_queue_length(OnyxCeleryQueues.CONNECTOR_DELETION, r_celery) if queue_len > PRUNING_VALIDATION_MAX_QUEUE_LEN: return # reserved (prefetched) tasks from the queue for the single pruning generator task reserved_generator_tasks = celery_get_unacked_task_ids( OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery ) # the queue for a reasonably large set of lightweight deletion tasks queued_upsert_tasks = celery_get_queued_task_ids( OnyxCeleryQueues.CONNECTOR_DELETION, r_celery ) # Use replica for this because the worst thing that happens # is that we don't run the validation on this pass keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) key_str = key_bytes.decode("utf-8") if not key_str.startswith(RedisConnectorPrune.FENCE_PREFIX): continue validate_pruning_fence( tenant_id, key_bytes, reserved_generator_tasks, queued_upsert_tasks, r, r_celery, ) lock_beat.reacquire() return def validate_pruning_fence( tenant_id: str, key_bytes: bytes,
reserved_tasks: set[str], queued_tasks: set[str], r: Redis, r_celery: Redis, ) -> None: """See validate_indexing_fence for an overall idea of validation flows. queued_tasks: the celery queue of lightweight document deletion tasks reserved_tasks: prefetched (reserved) pruning generator tasks """ # if the fence doesn't exist, there's nothing to do fence_key = key_bytes.decode("utf-8") cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key) if cc_pair_id_str is None: task_logger.warning( f"validate_pruning_fence - could not parse id from {fence_key}" ) return cc_pair_id = int(cc_pair_id_str) # parse out metadata and initialize the helper class with it redis_connector = RedisConnector(tenant_id, int(cc_pair_id)) # check to see if the fence/payload exists if not redis_connector.prune.fenced: return # in the cloud, the payload format may have changed ... # it's a little sloppy, but just reset the fence for now if that happens # TODO: add intentional cleanup/abort logic try: payload = redis_connector.prune.payload except ValidationError: task_logger.exception( "validate_pruning_fence - " "Resetting fence because fence schema is out of date: " f"cc_pair={cc_pair_id} " f"fence={fence_key}" ) redis_connector.prune.reset() return if not payload: return if not payload.celery_task_id: return # OK, there's actually something for us to validate # either the generator task must be in flight or its subtasks must be found = celery_find_task( payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_PRUNING, r_celery, ) if found: # the celery task exists in the redis queue redis_connector.prune.set_active() return if payload.celery_task_id in reserved_tasks: # the celery task was prefetched and is reserved within a worker redis_connector.prune.set_active() return # look up every task in the current taskset in the celery queue # every entry in the taskset should have an associated entry in the celery task queue # because we get the celery tasks first, the entries in our own pruning taskset
# should be roughly a subset of the tasks in celery # this check isn't very exact, but should be sufficient over a period of time # A single successful check over some number of attempts is sufficient. # TODO: if the number of tasks in celery is much lower than the taskset length # we might be able to shortcut the lookup since by definition some of the tasks # must not exist in celery. tasks_scanned = 0 tasks_not_in_celery = 0 # a non-zero number after completing our check is bad for member in r.sscan_iter(redis_connector.prune.taskset_key): tasks_scanned += 1 member_bytes = cast(bytes, member) member_str = member_bytes.decode("utf-8") if member_str in queued_tasks: continue if member_str in reserved_tasks: continue tasks_not_in_celery += 1 task_logger.info( f"validate_pruning_fence task check: tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}" ) # we're active if there are still tasks to run and those tasks all exist in celery if tasks_scanned > 0 and tasks_not_in_celery == 0: redis_connector.prune.set_active() return # we may want to enable this check if using the active task list somehow isn't good enough # if redis_connector_index.generator_locked(): # logger.info(f"{payload.celery_task_id} is currently executing.") # if we get here, we didn't find any direct indication that the associated celery tasks exist, # but they still might be there due to gaps in our ability to check states during transitions # Checking the active signal safeguards us against these transition periods # (the signal's duration is long enough to bridge those gaps) if redis_connector.prune.active(): return # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning( "validate_pruning_fence - " "Resetting fence because no associated celery tasks were found: " f"cc_pair={cc_pair_id} " f"fence={fence_key} " f"payload_id={payload.id}" ) redis_connector.prune.reset() return ================================================ FILE: backend/onyx/background/celery/tasks/shared/RetryDocumentIndex.py ================================================ import httpx from tenacity import retry from tenacity import retry_if_exception_type from tenacity import stop_after_delay from tenacity import wait_random_exponential from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.interfaces import VespaDocumentFields from onyx.document_index.interfaces import VespaDocumentUserFields class RetryDocumentIndex: """A wrapper class to help with specific retries against Vespa involving read timeouts. wait_random_exponential implements full jitter as per this article: https://aws.amazon.com/blogs/architecture/exponential-backoff-and-jitter/""" MAX_WAIT = 30 # STOP_AFTER + MAX_WAIT should be slightly less (5?) 
than the celery soft_time_limit STOP_AFTER = 70 def __init__(self, index: DocumentIndex): self.index: DocumentIndex = index @retry( retry=retry_if_exception_type(httpx.ReadTimeout), wait=wait_random_exponential(multiplier=1, max=MAX_WAIT), stop=stop_after_delay(STOP_AFTER), ) def delete_single( self, doc_id: str, *, tenant_id: str, chunk_count: int | None, ) -> int: return self.index.delete_single( doc_id, tenant_id=tenant_id, chunk_count=chunk_count, ) @retry( retry=retry_if_exception_type(httpx.ReadTimeout), wait=wait_random_exponential(multiplier=1, max=MAX_WAIT), stop=stop_after_delay(STOP_AFTER), ) def update_single( self, doc_id: str, *, tenant_id: str, chunk_count: int | None, fields: VespaDocumentFields | None, user_fields: VespaDocumentUserFields | None, ) -> None: self.index.update_single( doc_id, tenant_id=tenant_id, chunk_count=chunk_count, fields=fields, user_fields=user_fields, ) ================================================ FILE: backend/onyx/background/celery/tasks/shared/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/shared/tasks.py ================================================ import time from enum import Enum from http import HTTPStatus import httpx from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis from tenacity import RetryError from onyx.access.access import get_access_for_document from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex from onyx.configs.constants import ONYX_CELERY_BEAT_HEARTBEAT_KEY from onyx.configs.constants import OnyxCeleryTask from onyx.db.document import delete_document_by_connector_credential_pair__no_commit from onyx.db.document import delete_documents_complete__no_commit from onyx.db.document import fetch_chunk_count_for_document from 
onyx.db.document import get_document from onyx.db.document import get_document_connector_count from onyx.db.document import mark_document_as_modified from onyx.db.document import mark_document_as_synced from onyx.db.document_set import fetch_document_sets_for_document from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.relationships import delete_document_references_from_kg from onyx.db.search_settings import get_active_search_settings from onyx.document_index.factory import get_all_document_indices from onyx.document_index.interfaces import VespaDocumentFields from onyx.httpx.httpx_pool import HttpxPool from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import ConnectorCredentialPairIdentifier DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES = 3 # 5 seconds more than RetryDocumentIndex STOP_AFTER+MAX_WAIT LIGHT_SOFT_TIME_LIMIT = 105 LIGHT_TIME_LIMIT = LIGHT_SOFT_TIME_LIMIT + 15 class OnyxCeleryTaskCompletionStatus(str, Enum): """The different statuses the watchdog can finish with. TODO: create broader success/failure/abort categories """ UNDEFINED = "undefined" SUCCEEDED = "succeeded" SKIPPED = "skipped" SOFT_TIME_LIMIT = "soft_time_limit" NON_RETRYABLE_EXCEPTION = "non_retryable_exception" RETRYABLE_EXCEPTION = "retryable_exception" @shared_task( name=OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK, soft_time_limit=LIGHT_SOFT_TIME_LIMIT, time_limit=LIGHT_TIME_LIMIT, max_retries=DOCUMENT_BY_CC_PAIR_CLEANUP_MAX_RETRIES, bind=True, ) def document_by_cc_pair_cleanup_task( self: Task, document_id: str, connector_id: int, credential_id: int, tenant_id: str, ) -> bool: """A lightweight subtask used to clean up document to cc pair relationships. 
Created by connector deletion and connector pruning parent tasks.""" """ To delete a connector / credential pair: (1) find all documents associated with connector / credential pair where this is the only connector / credential pair that has indexed it (2) delete all documents from document stores (3) delete all entries from postgres (4) find all documents associated with connector / credential pair where there are multiple connector / credential pairs that have indexed it (5) update document store entries to remove access associated with the connector / credential pair from the access list (6) delete all relevant entries from postgres """ task_logger.debug(f"Task start: doc={document_id}") start = time.monotonic() completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED try: with get_session_with_current_tenant() as db_session: action = "skip" active_search_settings = get_active_search_settings(db_session) # This flow is for updates and deletion so we get all indices. document_indices = get_all_document_indices( active_search_settings.primary, active_search_settings.secondary, httpx_client=HttpxPool.get("vespa"), ) retry_document_indices: list[RetryDocumentIndex] = [ RetryDocumentIndex(document_index) for document_index in document_indices ] count = get_document_connector_count(db_session, document_id) if count == 1: # count == 1 means this is the only remaining cc_pair reference to the doc # delete it from vespa and the db action = "delete" chunk_count = fetch_chunk_count_for_document(document_id, db_session) for retry_document_index in retry_document_indices: _ = retry_document_index.delete_single( document_id, tenant_id=tenant_id, chunk_count=chunk_count, ) delete_document_references_from_kg( db_session=db_session, document_id=document_id, ) delete_documents_complete__no_commit( db_session=db_session, document_ids=[document_id], ) db_session.commit() completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED elif count > 1: action = "update" # count
> 1 means the document still has cc_pair references doc = get_document(document_id, db_session) if not doc: return False # the below functions do not include cc_pairs being deleted. # i.e. they will correctly omit access for the current cc_pair doc_access = get_access_for_document( document_id=document_id, db_session=db_session ) doc_sets = fetch_document_sets_for_document(document_id, db_session) update_doc_sets: set[str] = set(doc_sets) fields = VespaDocumentFields( document_sets=update_doc_sets, access=doc_access, boost=doc.boost, hidden=doc.hidden, ) for retry_document_index in retry_document_indices: # TODO(andrei): Previously there was a comment here saying # it was ok if a doc did not exist in the document index. I # don't agree with that claim, so keep an eye on this task # to see if this raises. retry_document_index.update_single( document_id, tenant_id=tenant_id, chunk_count=doc.chunk_count, fields=fields, user_fields=None, ) # there are still other cc_pair references to the doc, so just resync to Vespa delete_document_by_connector_credential_pair__no_commit( db_session=db_session, document_id=document_id, connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( connector_id=connector_id, credential_id=credential_id, ), ) mark_document_as_synced(document_id, db_session) db_session.commit() completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED else: completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED elapsed = time.monotonic() - start task_logger.info( f"doc={document_id} action={action} refcount={count} elapsed={elapsed:.2f}" ) except SoftTimeLimitExceeded: task_logger.info(f"SoftTimeLimitExceeded exception. 
doc={document_id}") completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT except Exception as ex: e: Exception | None = None while True: if isinstance(ex, RetryError): task_logger.warning( f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}" ) # only set the inner exception if it is of type Exception e_temp = ex.last_attempt.exception() if isinstance(e_temp, Exception): e = e_temp else: e = ex if isinstance(e, httpx.HTTPStatusError): if e.response.status_code == HTTPStatus.BAD_REQUEST: task_logger.exception( f"Non-retryable HTTPStatusError: doc={document_id} status={e.response.status_code}" ) completion_status = ( OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION ) break task_logger.exception( f"document_by_cc_pair_cleanup_task exceptioned: doc={document_id}" ) completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION if ( self.max_retries is not None and self.request.retries >= self.max_retries ): # This is the last attempt! mark the document as dirty in the db so that it # eventually gets fixed out of band via stale document reconciliation task_logger.warning( f"Max celery task retries reached. Marking doc as dirty for reconciliation: doc={document_id}" ) with get_session_with_current_tenant() as db_session: # delete the cc pair relationship now and let reconciliation clean it up # in vespa delete_document_by_connector_credential_pair__no_commit( db_session=db_session, document_id=document_id, connector_credential_pair_identifier=ConnectorCredentialPairIdentifier( connector_id=connector_id, credential_id=credential_id, ), ) mark_document_as_modified(document_id, db_session) completion_status = ( OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION ) break # Exponential backoff from 2^4 to 2^6 ... i.e. 
16, 32, 64 countdown = 2 ** (self.request.retries + 4) self.retry(exc=e, countdown=countdown) # this will raise a celery exception break # we won't hit this, but it looks weird not to have it finally: task_logger.info( f"document_by_cc_pair_cleanup_task completed: status={completion_status.value} doc={document_id}" ) if completion_status != OnyxCeleryTaskCompletionStatus.SUCCEEDED: return False task_logger.info(f"document_by_cc_pair_cleanup_task finished: doc={document_id}") return True @shared_task(name=OnyxCeleryTask.CELERY_BEAT_HEARTBEAT, ignore_result=True, bind=True) def celery_beat_heartbeat(self: Task, *, tenant_id: str) -> None: # noqa: ARG001 """When this task runs, it writes a key to Redis with a TTL. An external observer can check this key to figure out if the celery beat is still running. """ time_start = time.monotonic() r: Redis = get_redis_client() r.set(ONYX_CELERY_BEAT_HEARTBEAT_KEY, 1, ex=600) time_elapsed = time.monotonic() - time_start task_logger.info(f"celery_beat_heartbeat finished: elapsed={time_elapsed:.2f}") ================================================ FILE: backend/onyx/background/celery/tasks/user_file_processing/__init__.py ================================================ ================================================ FILE: backend/onyx/background/celery/tasks/user_file_processing/tasks.py ================================================ import datetime import time from typing import Any from uuid import UUID import httpx import sqlalchemy as sa from celery import Celery from celery import shared_task from celery import Task from redis import Redis from redis.lock import Lock as RedisLock from retry import retry from sqlalchemy import select from sqlalchemy.orm import Session from onyx.access.access import build_access_for_user_files from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.celery_redis import celery_get_broker_client from onyx.background.celery.celery_redis import 
celery_get_queue_length from onyx.background.celery.celery_utils import httpx_init_vespa_pool from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import MANAGED_VESPA from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_USER_FILE_DELETE_TASK_EXPIRES from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES from onyx.configs.constants import DocumentSource from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisLocks from onyx.configs.constants import USER_FILE_DELETE_MAX_QUEUE_DEPTH from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH from onyx.configs.constants import USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH from onyx.connectors.file.connector import LocalFileConnector from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import UserFileStatus from onyx.db.models import UserFile from onyx.db.search_settings import get_active_search_settings from onyx.db.search_settings import get_active_search_settings_list from onyx.db.user_file import fetch_user_files_with_access_relationships from onyx.document_index.factory import get_all_document_indices from onyx.document_index.interfaces import VespaDocumentFields from onyx.document_index.interfaces import 
VespaDocumentUserFields from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.file_store.file_store import get_default_file_store from onyx.file_store.utils import store_user_file_plaintext from onyx.file_store.utils import user_file_id_to_plaintext_file_name from onyx.httpx.httpx_pool import HttpxPool from onyx.indexing.adapters.user_file_indexing_adapter import UserFileIndexingAdapter from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.indexing_pipeline import run_indexing_pipeline from onyx.redis.redis_pool import get_redis_client from onyx.utils.variable_functionality import global_version def _as_uuid(value: str | UUID) -> UUID: """Return a UUID, accepting either a UUID or a string-like value.""" return value if isinstance(value, UUID) else UUID(str(value)) def _user_file_lock_key(user_file_id: str | UUID) -> str: return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}" def _user_file_queued_key(user_file_id: str | UUID) -> str: """Key that exists while a process_single_user_file task is sitting in the queue. The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES before enqueuing and the worker deletes it as its first action. This prevents the beat from adding duplicate tasks for files that already have a live task in flight. 
""" return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}" def user_file_project_sync_lock_key(user_file_id: str | UUID) -> str: return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}" def _user_file_project_sync_queued_key(user_file_id: str | UUID) -> str: return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_QUEUED_PREFIX}:{user_file_id}" def _user_file_delete_lock_key(user_file_id: str | UUID) -> str: return f"{OnyxRedisLocks.USER_FILE_DELETE_LOCK_PREFIX}:{user_file_id}" def _user_file_delete_queued_key(user_file_id: str | UUID) -> str: """Key that exists while a delete_single_user_file task is sitting in the queue. The beat generator sets this with a TTL equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES before enqueuing and the worker deletes it as its first action. This prevents the beat from adding duplicate tasks for files that already have a live task in flight. """ return f"{OnyxRedisLocks.USER_FILE_DELETE_QUEUED_PREFIX}:{user_file_id}" def get_user_file_project_sync_queue_depth(celery_app: Celery) -> int: redis_celery = celery_get_broker_client(celery_app) return celery_get_queue_length( OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, redis_celery ) def enqueue_user_file_project_sync_task( *, celery_app: Celery, redis_client: Redis, user_file_id: str | UUID, tenant_id: str, priority: OnyxCeleryPriority = OnyxCeleryPriority.HIGH, ) -> bool: """Enqueue a project-sync task if no matching queued task already exists.""" queued_key = _user_file_project_sync_queued_key(user_file_id) # NX+EX gives us atomic dedupe and a self-healing TTL. 
queued_guard_set = redis_client.set( queued_key, 1, nx=True, ex=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES, ) if not queued_guard_set: return False try: celery_app.send_task( OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC, kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id}, queue=OnyxCeleryQueues.USER_FILE_PROJECT_SYNC, priority=priority, expires=CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES, ) except Exception: # Roll back the queued guard if task publish fails. redis_client.delete(queued_key) raise return True @retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0)) def _visit_chunks( *, http_client: httpx.Client, index_name: str, selection: str, continuation: str | None = None, ) -> tuple[list[dict[str, Any]], str | None]: task_logger.info( f"Visiting chunks for index={index_name} with selection={selection}" ) base_url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) params: dict[str, str] = { "selection": selection, "wantedDocumentCount": "100", # Use smaller batch size to avoid timeouts } if continuation: params["continuation"] = continuation resp = http_client.get(base_url, params=params, timeout=None) resp.raise_for_status() payload = resp.json() return payload.get("documents", []), payload.get("continuation") def _get_document_chunk_count( *, index_name: str, selection: str, ) -> int: chunk_count = 0 continuation = None while True: docs, continuation = _visit_chunks( http_client=HttpxPool.get("vespa"), index_name=index_name, selection=selection, continuation=continuation, ) if not docs: break chunk_count += len(docs) if not continuation: break return chunk_count @shared_task( name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROCESSING, soft_time_limit=300, bind=True, ignore_result=True, ) def check_user_file_processing(self: Task, *, tenant_id: str) -> None: """Scan for user files with PROCESSING status and enqueue per-file tasks. Three mechanisms prevent queue runaway: 1. 
**Queue depth backpressure** – if the broker queue already has more than USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle entirely. Workers are clearly behind; adding more tasks would only make the backlog worse. 2. **Per-file queued guard** – before enqueuing a task we set a short-lived Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key already exists the file already has a live task in the queue, so we skip it. The worker deletes the key the moment it picks up the task so the next beat cycle can re-enqueue if the file is still PROCESSING. 3. **Task expiry** – every enqueued task carries an `expires` value equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in the queue after that deadline, Celery discards it without touching the DB. This is a belt-and-suspenders defence: even if the guard key is lost (e.g. Redis restart), stale tasks evict themselves rather than piling up forever. """ task_logger.info("check_user_file_processing - Starting") redis_client = get_redis_client(tenant_id=tenant_id) lock: RedisLock = redis_client.lock( OnyxRedisLocks.USER_FILE_PROCESSING_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) # Do not overlap generator runs if not lock.acquire(blocking=False): return None enqueued = 0 skipped_guard = 0 try: # --- Protection 1: queue depth backpressure --- r_celery = celery_get_broker_client(self.app) queue_len = celery_get_queue_length( OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery ) if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH: task_logger.warning( f"check_user_file_processing - Queue depth {queue_len} exceeds " f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for " f"tenant={tenant_id}" ) return None with get_session_with_current_tenant() as db_session: user_file_ids = ( db_session.execute( select(UserFile.id).where( UserFile.status == UserFileStatus.PROCESSING ) ) .scalars() .all() ) for user_file_id in user_file_ids: # --- Protection 2: per-file 
queued guard --- queued_key = _user_file_queued_key(user_file_id) guard_set = redis_client.set( queued_key, 1, ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, nx=True, ) if not guard_set: skipped_guard += 1 continue # --- Protection 3: task expiry --- # If task submission fails, clear the guard immediately so the # next beat cycle can retry enqueuing this file. try: self.app.send_task( OnyxCeleryTask.PROCESS_SINGLE_USER_FILE, kwargs={ "user_file_id": str(user_file_id), "tenant_id": tenant_id, }, queue=OnyxCeleryQueues.USER_FILE_PROCESSING, priority=OnyxCeleryPriority.HIGH, expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, ) except Exception: redis_client.delete(queued_key) raise enqueued += 1 finally: if lock.owned(): lock.release() task_logger.info( f"check_user_file_processing - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}" ) return None def _process_user_file_without_vector_db( uf: UserFile, documents: list[Document], db_session: Session, ) -> None: """Process a user file when the vector DB is disabled. Extracts the raw text, computes a token count, stores the plaintext in the file store, and marks the file as COMPLETED. Skips embedding and the indexing pipeline entirely.
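
A minimal, self-contained sketch of the text-combining and token-count fallback step described above. It assumes the tokenizer is any callable that returns a sequence of tokens. `combine_and_count` is a hypothetical helper, not part of this module, and it takes a flat list of section texts rather than `Document` objects:

```python
from typing import Any, Callable, Optional, Sequence


def combine_and_count(
    section_texts: Sequence[Optional[str]],
    encode: Callable[[str], Sequence[Any]],
) -> tuple[str, Optional[int]]:
    # Join non-empty section texts with single spaces, skipping None and
    # empty strings.
    combined = " ".join(text for text in section_texts if text)
    try:
        token_count: Optional[int] = len(encode(combined))
    except Exception:
        # A tokenizer failure degrades to an unknown count instead of
        # failing the whole file.
        token_count = None
    return combined, token_count
```

For example, with `str.split` standing in for the tokenizer, the texts `a b`, `None` and `c` combine to `a b c` with a count of 3.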
""" from onyx.llm.factory import get_default_llm from onyx.llm.factory import get_llm_tokenizer_encode_func # Combine section text from all document sections combined_text = " ".join( section.text for doc in documents for section in doc.sections if section.text ) # Compute token count using the user's default LLM tokenizer try: llm = get_default_llm() encode = get_llm_tokenizer_encode_func(llm) token_count: int | None = len(encode(combined_text)) except Exception: task_logger.warning( f"_process_user_file_without_vector_db - Failed to compute token count for {uf.id}, falling back to None" ) token_count = None # Persist plaintext for fast FileReaderTool loads store_user_file_plaintext( user_file_id=uf.id, plaintext_content=combined_text, ) # Update the DB record if uf.status != UserFileStatus.DELETING: uf.status = UserFileStatus.COMPLETED uf.token_count = token_count uf.chunk_count = 0 # no chunks without vector DB uf.last_project_sync_at = datetime.datetime.now(datetime.timezone.utc) db_session.add(uf) db_session.commit() task_logger.info( f"_process_user_file_without_vector_db - Completed id={uf.id} tokens={token_count}" ) def _process_user_file_with_indexing( uf: UserFile, user_file_id: str, documents: list[Document], tenant_id: str, db_session: Session, ) -> None: """Process a user file through the full indexing pipeline (vector DB path).""" # 20 is the documented default for httpx max_keepalive_connections if MANAGED_VESPA: httpx_init_vespa_pool( 20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH ) else: httpx_init_vespa_pool(20) search_settings_list = get_active_search_settings_list(db_session) current_search_settings = next( (ss for ss in search_settings_list if ss.status.is_current()), None, ) if current_search_settings is None: raise RuntimeError( f"_process_user_file_with_indexing - No current search settings found for tenant={tenant_id}" ) adapter = UserFileIndexingAdapter( tenant_id=tenant_id, db_session=db_session, ) embedding_model = 
DefaultIndexingEmbedder.from_db_search_settings( search_settings=current_search_settings, ) document_indices = get_all_document_indices( current_search_settings, None, httpx_client=HttpxPool.get("vespa"), ) index_pipeline_result = run_indexing_pipeline( embedder=embedding_model, document_indices=document_indices, ignore_time_skip=True, db_session=db_session, tenant_id=tenant_id, document_batch=documents, request_id=None, adapter=adapter, ) task_logger.info( f"_process_user_file_with_indexing - Indexing pipeline completed result={index_pipeline_result}" ) if ( index_pipeline_result.failures or index_pipeline_result.total_docs != len(documents) or index_pipeline_result.total_chunks == 0 ): task_logger.error( f"_process_user_file_with_indexing - Indexing pipeline failed id={user_file_id}" ) if uf.status != UserFileStatus.DELETING: uf.status = UserFileStatus.FAILED db_session.add(uf) db_session.commit() raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}") def process_user_file_impl( *, user_file_id: str, tenant_id: str, redis_locking: bool ) -> None: """Core implementation for processing a single user file. When redis_locking=True, acquires a per-file Redis lock and clears the queued-key guard (Celery path). When redis_locking=False, skips all Redis operations (BackgroundTask path).
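
A minimal sketch of the optional-lock pattern described above. It uses `threading.Lock` as a stand-in for the per-file Redis lock and passes `lock=None` to model the BackgroundTask path. `run_with_optional_lock` is hypothetical. Unlike `threading.Lock`, a Redis lock can expire through its timeout, which is why the real code checks `file_lock.owned()` before releasing:

```python
import threading
from typing import Callable, Optional


def run_with_optional_lock(
    work: Callable[[], None],
    lock: Optional[threading.Lock],
) -> bool:
    # lock=None models the BackgroundTask path: no locking at all.
    # A non-blocking acquire means a concurrent worker simply skips the file.
    if lock is not None and not lock.acquire(blocking=False):
        return False
    try:
        work()
        return True
    finally:
        # Always release what we acquired, even if work() raised.
        if lock is not None:
            lock.release()
```

Returning `False` instead of waiting mirrors the "Lock held, skipping" branch. A duplicate delivery of the same task exits immediately rather than blocking a worker slot.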
""" task_logger.info(f"process_user_file_impl - Starting id={user_file_id}") start = time.monotonic() file_lock: RedisLock | None = None if redis_locking: redis_client = get_redis_client(tenant_id=tenant_id) redis_client.delete(_user_file_queued_key(user_file_id)) file_lock = redis_client.lock( _user_file_lock_key(user_file_id), timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT, ) if file_lock is not None and not file_lock.acquire(blocking=False): task_logger.info( f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}" ) return documents: list[Document] = [] try: with get_session_with_current_tenant() as db_session: uf = db_session.get(UserFile, _as_uuid(user_file_id)) if not uf: task_logger.warning( f"process_user_file_impl - UserFile not found id={user_file_id}" ) return if uf.status not in ( UserFileStatus.PROCESSING, UserFileStatus.INDEXING, ): task_logger.info( f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}" ) return connector = LocalFileConnector( file_locations=[uf.file_id], file_names=[uf.name] if uf.name else None, ) connector.load_credentials({}) try: for batch in connector.load_from_state(): documents.extend( [doc for doc in batch if not isinstance(doc, HierarchyNode)] ) for document in documents: document.id = str(user_file_id) document.source = DocumentSource.USER_FILE if DISABLE_VECTOR_DB: _process_user_file_without_vector_db( uf=uf, documents=documents, db_session=db_session, ) else: _process_user_file_with_indexing( uf=uf, user_file_id=user_file_id, documents=documents, tenant_id=tenant_id, db_session=db_session, ) except Exception as e: task_logger.exception( f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" ) current_user_file = db_session.get(UserFile, _as_uuid(user_file_id)) if ( current_user_file and current_user_file.status != UserFileStatus.DELETING ): uf.status = UserFileStatus.FAILED db_session.add(uf) db_session.commit() return elapsed = 
time.monotonic() - start task_logger.info( f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s" ) except Exception as e: with get_session_with_current_tenant() as db_session: uf = db_session.get(UserFile, _as_uuid(user_file_id)) if uf: if uf.status != UserFileStatus.DELETING: uf.status = UserFileStatus.FAILED db_session.add(uf) db_session.commit() task_logger.exception( f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" ) raise finally: if file_lock is not None and file_lock.owned(): file_lock.release() @shared_task( name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE, bind=True, ignore_result=True, ) def process_single_user_file( self: Task, # noqa: ARG001 *, user_file_id: str, tenant_id: str, ) -> None: process_user_file_impl( user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True ) @shared_task( name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE, soft_time_limit=300, bind=True, ignore_result=True, ) def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None: """Scan for user files with DELETING status and enqueue per-file tasks. Three mechanisms prevent queue runaway (mirrors check_user_file_processing): 1. **Queue depth backpressure** – if the broker queue already has more than USER_FILE_DELETE_MAX_QUEUE_DEPTH items we skip this beat cycle entirely. 2. **Per-file queued guard** – before enqueuing a task we set a short-lived Redis key (TTL = CELERY_USER_FILE_DELETE_TASK_EXPIRES). If that key already exists the file already has a live task in the queue, so we skip it. The worker deletes the key the moment it picks up the task so the next beat cycle can re-enqueue if the file is still DELETING. 3. **Task expiry** – every enqueued task carries an `expires` value equal to CELERY_USER_FILE_DELETE_TASK_EXPIRES. If a task is still sitting in the queue after that deadline, Celery discards it without touching the DB. 
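
A minimal sketch of mechanism 2 (the per-file queued guard) with rollback on publish failure. It assumes an in-memory `FakeRedis` stand-in that models only SET NX EX and DEL. Real redis-py `set(..., nx=True)` returns `True` or `None` rather than `False`. `enqueue_with_guard` and its key format are hypothetical:

```python
import time
from typing import Callable, Optional


class FakeRedis:
    # In-memory stand-in for the two Redis commands the guard relies on:
    # SET key value NX EX ttl, and DEL key. Only key presence and expiry are
    # tracked; the stored value is ignored.
    def __init__(self, clock: Callable[[], float] = time.monotonic) -> None:
        self._clock = clock
        self._expiry: dict[str, float] = {}  # key -> absolute expiry time

    def set(
        self, key: str, value: object, nx: bool = False, ex: Optional[int] = None
    ) -> bool:
        now = self._clock()
        current = self._expiry.get(key)
        if nx and current is not None and current > now:
            return False  # key exists and has not expired yet
        self._expiry[key] = now + ex if ex is not None else float("inf")
        return True

    def delete(self, key: str) -> int:
        # Returns 1 only if a live (unexpired) key was removed, like DEL.
        current = self._expiry.pop(key, None)
        return 1 if current is not None and current > self._clock() else 0


def enqueue_with_guard(
    redis: FakeRedis,
    publish: Callable[[str], None],
    file_id: str,
    ttl_seconds: int,
) -> bool:
    key = f"user_file_delete_queued:{file_id}"  # hypothetical key format
    if not redis.set(key, 1, nx=True, ex=ttl_seconds):
        return False  # a task for this file is still queued
    try:
        publish(file_id)
    except Exception:
        redis.delete(key)  # roll back so the next beat cycle can retry
        raise
    return True
```

The rollback in the `except` branch matters. Without it, a failed publish would leave the guard in place and block re-enqueueing for a full TTL.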
""" task_logger.info("check_for_user_file_delete - Starting") redis_client = get_redis_client(tenant_id=tenant_id) lock: RedisLock = redis_client.lock( OnyxRedisLocks.USER_FILE_DELETE_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) if not lock.acquire(blocking=False): return None enqueued = 0 skipped_guard = 0 try: # --- Protection 1: queue depth backpressure --- # NOTE: must use the broker's Redis client (not redis_client) because # Celery queues live on a separate Redis DB with CELERY_SEPARATOR keys. r_celery = celery_get_broker_client(self.app) queue_len = celery_get_queue_length(OnyxCeleryQueues.USER_FILE_DELETE, r_celery) if queue_len > USER_FILE_DELETE_MAX_QUEUE_DEPTH: task_logger.warning( f"check_for_user_file_delete - Queue depth {queue_len} exceeds " f"{USER_FILE_DELETE_MAX_QUEUE_DEPTH}, skipping enqueue for " f"tenant={tenant_id}" ) return None with get_session_with_current_tenant() as db_session: user_file_ids = ( db_session.execute( select(UserFile.id).where( UserFile.status == UserFileStatus.DELETING ) ) .scalars() .all() ) for user_file_id in user_file_ids: # --- Protection 2: per-file queued guard --- queued_key = _user_file_delete_queued_key(user_file_id) guard_set = redis_client.set( queued_key, 1, ex=CELERY_USER_FILE_DELETE_TASK_EXPIRES, nx=True, ) if not guard_set: skipped_guard += 1 continue # --- Protection 3: task expiry --- try: self.app.send_task( OnyxCeleryTask.DELETE_SINGLE_USER_FILE, kwargs={ "user_file_id": str(user_file_id), "tenant_id": tenant_id, }, queue=OnyxCeleryQueues.USER_FILE_DELETE, priority=OnyxCeleryPriority.HIGH, expires=CELERY_USER_FILE_DELETE_TASK_EXPIRES, ) except Exception: redis_client.delete(queued_key) raise enqueued += 1 finally: if lock.owned(): lock.release() task_logger.info( f"check_for_user_file_delete - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}" ) return None def delete_user_file_impl( *, user_file_id: str, tenant_id: str, redis_locking: bool ) -> None: """Core 
implementation for deleting a single user file. When redis_locking=True, acquires a per-file Redis lock and clears the queued-key guard (Celery path). When redis_locking=False, skips Redis operations (BackgroundTask path). """ task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}") file_lock: RedisLock | None = None if redis_locking: redis_client = get_redis_client(tenant_id=tenant_id) # Clear the queued guard so the beat can re-enqueue if deletion fails # and the file remains in DELETING status. redis_client.delete(_user_file_delete_queued_key(user_file_id)) file_lock = redis_client.lock( _user_file_delete_lock_key(user_file_id), timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) if file_lock is not None and not file_lock.acquire(blocking=False): task_logger.info( f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}" ) return try: with get_session_with_current_tenant() as db_session: user_file = db_session.get(UserFile, _as_uuid(user_file_id)) if not user_file: task_logger.info( f"delete_user_file_impl - User file not found id={user_file_id}" ) return if not DISABLE_VECTOR_DB: if MANAGED_VESPA: httpx_init_vespa_pool( 20, ssl_cert=VESPA_CLOUD_CERT_PATH, ssl_key=VESPA_CLOUD_KEY_PATH ) else: httpx_init_vespa_pool(20) active_search_settings = get_active_search_settings(db_session) document_indices = get_all_document_indices( search_settings=active_search_settings.primary, secondary_search_settings=active_search_settings.secondary, httpx_client=HttpxPool.get("vespa"), ) retry_document_indices: list[RetryDocumentIndex] = [ RetryDocumentIndex(document_index) for document_index in document_indices ] index_name = active_search_settings.primary.index_name selection = f"{index_name}.document_id=='{user_file_id}'" chunk_count = 0 if user_file.chunk_count is None or user_file.chunk_count == 0: chunk_count = _get_document_chunk_count( index_name=index_name, selection=selection, ) else: chunk_count = user_file.chunk_count for retry_document_index in retry_document_indices:
retry_document_index.delete_single( doc_id=user_file_id, tenant_id=tenant_id, chunk_count=chunk_count, ) file_store = get_default_file_store() try: file_store.delete_file(user_file.file_id) file_store.delete_file( user_file_id_to_plaintext_file_name(user_file.id) ) except Exception as e: task_logger.exception( f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}" ) db_session.delete(user_file) db_session.commit() task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}") except Exception as e: task_logger.exception( f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}" ) raise finally: if file_lock is not None and file_lock.owned(): file_lock.release() @shared_task( name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE, bind=True, ignore_result=True, ) def process_single_user_file_delete( self: Task, # noqa: ARG001 *, user_file_id: str, tenant_id: str, ) -> None: delete_user_file_impl( user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True ) @shared_task( name=OnyxCeleryTask.CHECK_FOR_USER_FILE_PROJECT_SYNC, soft_time_limit=300, bind=True, ignore_result=True, ) def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None: """Scan for user files needing project sync and enqueue per-file tasks.""" task_logger.info("check_for_user_file_project_sync - Starting") redis_client = get_redis_client(tenant_id=tenant_id) lock: RedisLock = redis_client.lock( OnyxRedisLocks.USER_FILE_PROJECT_SYNC_BEAT_LOCK, timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT, ) if not lock.acquire(blocking=False): return None enqueued = 0 skipped_guard = 0 try: queue_depth = get_user_file_project_sync_queue_depth(self.app) if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH: task_logger.warning( f"check_for_user_file_project_sync - Queue depth {queue_depth} exceeds " f"{USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH}, skipping enqueue for tenant={tenant_id}" ) return None with get_session_with_current_tenant() as db_session: user_file_ids = ( db_session.execute(
select(UserFile.id).where( sa.and_( sa.or_( UserFile.needs_project_sync.is_(True), UserFile.needs_persona_sync.is_(True), ), UserFile.status == UserFileStatus.COMPLETED, ) ) ) .scalars() .all() ) for user_file_id in user_file_ids: if not enqueue_user_file_project_sync_task( celery_app=self.app, redis_client=redis_client, user_file_id=user_file_id, tenant_id=tenant_id, priority=OnyxCeleryPriority.HIGH, ): skipped_guard += 1 continue enqueued += 1 finally: if lock.owned(): lock.release() task_logger.info( f"check_for_user_file_project_sync - Enqueued {enqueued} tasks, skipped_guard={skipped_guard} for tenant={tenant_id}" ) return None def project_sync_user_file_impl( *, user_file_id: str, tenant_id: str, redis_locking: bool ) -> None: """Core implementation for syncing a user file's project/persona metadata. When redis_locking=True, acquires a per-file Redis lock and clears the queued-key guard (Celery path). When redis_locking=False, skips Redis operations (BackgroundTask path). """ task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}") file_lock: RedisLock | None = None if redis_locking: redis_client = get_redis_client(tenant_id=tenant_id) redis_client.delete(_user_file_project_sync_queued_key(user_file_id)) file_lock = redis_client.lock( user_file_project_sync_lock_key(user_file_id), timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT, ) if file_lock is not None and not file_lock.acquire(blocking=False): task_logger.info( f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}" ) return try: with get_session_with_current_tenant() as db_session: user_files = fetch_user_files_with_access_relationships( [user_file_id], db_session, eager_load_groups=global_version.is_ee_version(), ) user_file = user_files[0] if user_files else None if not user_file: task_logger.info( f"project_sync_user_file_impl - User file not found id={user_file_id}" ) return if not DISABLE_VECTOR_DB: if MANAGED_VESPA: httpx_init_vespa_pool( 20, ssl_cert=VESPA_CLOUD_CERT_PATH,
ssl_key=VESPA_CLOUD_KEY_PATH ) else: httpx_init_vespa_pool(20) active_search_settings = get_active_search_settings(db_session) document_indices = get_all_document_indices( search_settings=active_search_settings.primary, secondary_search_settings=active_search_settings.secondary, httpx_client=HttpxPool.get("vespa"), ) retry_document_indices: list[RetryDocumentIndex] = [ RetryDocumentIndex(document_index) for document_index in document_indices ] project_ids = [project.id for project in user_file.projects] persona_ids = [p.id for p in user_file.assistants if not p.deleted] file_id_str = str(user_file.id) access_map = build_access_for_user_files([user_file]) access = access_map.get(file_id_str) for retry_document_index in retry_document_indices: retry_document_index.update_single( doc_id=file_id_str, tenant_id=tenant_id, chunk_count=user_file.chunk_count, fields=( VespaDocumentFields(access=access) if access is not None else None ), user_fields=VespaDocumentUserFields( user_projects=project_ids, personas=persona_ids, ), ) task_logger.info( f"project_sync_user_file_impl - Updated document index for user file id={user_file_id}" ) user_file.needs_project_sync = False user_file.needs_persona_sync = False user_file.last_project_sync_at = datetime.datetime.now( datetime.timezone.utc ) db_session.add(user_file) db_session.commit() except Exception as e: task_logger.exception( f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}" ) raise finally: if file_lock is not None and file_lock.owned(): file_lock.release() @shared_task( name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC, bind=True, ignore_result=True, ) def process_single_user_file_project_sync( self: Task, # noqa: ARG001 *, user_file_id: str, tenant_id: str, ) -> None: project_sync_user_file_impl( user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True ) ================================================ FILE: backend/onyx/background/celery/tasks/vespa/__init__.py
================================================ ================================================ FILE: backend/onyx/background/celery/tasks/vespa/document_sync.py ================================================ import time from typing import cast from uuid import uuid4 from celery import Celery from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.db.document import construct_document_id_select_by_needs_sync from onyx.db.document import count_documents_by_needs_sync from onyx.utils.logger import setup_logger # Redis keys for document sync tracking DOCUMENT_SYNC_PREFIX = "documentsync" DOCUMENT_SYNC_FENCE_KEY = f"{DOCUMENT_SYNC_PREFIX}_fence" DOCUMENT_SYNC_TASKSET_KEY = f"{DOCUMENT_SYNC_PREFIX}_taskset" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks TASKSET_TTL = FENCE_TTL logger = setup_logger() def is_document_sync_fenced(r: Redis) -> bool: """Check if document sync tasks are currently in progress.""" return bool(r.exists(DOCUMENT_SYNC_FENCE_KEY)) def get_document_sync_payload(r: Redis) -> int | None: """Get the initial number of tasks that were created.""" bytes_result = r.get(DOCUMENT_SYNC_FENCE_KEY) if bytes_result is None: return None return int(cast(int, bytes_result)) def get_document_sync_remaining(r: Redis) -> int: """Get the number of tasks still pending completion.""" return cast(int, r.scard(DOCUMENT_SYNC_TASKSET_KEY)) def set_document_sync_fence(r: Redis, payload: int | None) -> None: """Set up the fence and register with active fences.""" if payload is None: r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY) 
r.delete(DOCUMENT_SYNC_FENCE_KEY) return r.set(DOCUMENT_SYNC_FENCE_KEY, payload, ex=FENCE_TTL) r.sadd(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY) def delete_document_sync_taskset(r: Redis) -> None: """Clear the document sync taskset.""" r.delete(DOCUMENT_SYNC_TASKSET_KEY) def reset_document_sync(r: Redis) -> None: """Reset all document sync tracking data.""" r.srem(OnyxRedisConstants.ACTIVE_FENCES, DOCUMENT_SYNC_FENCE_KEY) r.delete(DOCUMENT_SYNC_TASKSET_KEY) r.delete(DOCUMENT_SYNC_FENCE_KEY) def generate_document_sync_tasks( r: Redis, max_tasks: int, celery_app: Celery, db_session: Session, lock: RedisLock, tenant_id: str, ) -> tuple[int, int]: """Generate sync tasks for all documents that need syncing. Args: r: Redis client max_tasks: Maximum number of tasks to generate celery_app: Celery application instance db_session: Database session lock: Redis lock for coordination tenant_id: Tenant identifier Returns: tuple[int, int]: (tasks_generated, total_docs_found) """ last_lock_time = time.monotonic() num_tasks_sent = 0 num_docs = 0 # Get all documents that need syncing stmt = construct_document_id_select_by_needs_sync() for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc_id = cast(str, doc_id) current_time = time.monotonic() # Reacquire lock periodically to prevent timeout if current_time - last_lock_time >= (CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4): lock.reacquire() last_lock_time = current_time num_docs += 1 # Create a unique task ID custom_task_id = f"{DOCUMENT_SYNC_PREFIX}_{uuid4()}" # Add to the tracking taskset in Redis BEFORE creating the celery task r.sadd(DOCUMENT_SYNC_TASKSET_KEY, custom_task_id) r.expire(DOCUMENT_SYNC_TASKSET_KEY, TASKSET_TTL) # Create the Celery task celery_app.send_task( OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, kwargs=dict(document_id=doc_id, tenant_id=tenant_id), queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ignore_result=True, ) 
num_tasks_sent += 1 if num_tasks_sent >= max_tasks: break return num_tasks_sent, num_docs def try_generate_stale_document_sync_tasks( celery_app: Celery, max_tasks: int, db_session: Session, r: Redis, lock_beat: RedisLock, tenant_id: str, ) -> int | None: # the fence is up, do nothing if is_document_sync_fenced(r): return None # add tasks to celery and build up the task set to monitor in redis stale_doc_count = count_documents_by_needs_sync(db_session) if stale_doc_count == 0: logger.info("No stale documents found. Skipping sync tasks generation.") return None logger.info( f"Stale documents found (at least {stale_doc_count}). Generating sync tasks in one batch." ) logger.info("generate_document_sync_tasks starting for all documents.") # Generate all tasks in one pass result = generate_document_sync_tasks( r, max_tasks, celery_app, db_session, lock_beat, tenant_id ) if result is None: return None tasks_generated, total_docs = result if tasks_generated >= max_tasks: logger.info( f"generate_document_sync_tasks reached the task generation limit: " f"tasks_generated={tasks_generated} max_tasks={max_tasks}" ) else: logger.info( f"generate_document_sync_tasks finished for all documents. 
" f"tasks_generated={tasks_generated} total_docs_found={total_docs}" ) set_document_sync_fence(r, tasks_generated) return tasks_generated ================================================ FILE: backend/onyx/background/celery/tasks/vespa/tasks.py ================================================ import time from collections.abc import Callable from http import HTTPStatus from typing import Any from typing import cast import httpx from celery import Celery from celery import shared_task from celery import Task from celery.exceptions import SoftTimeLimitExceeded from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from tenacity import RetryError from onyx.access.access import get_access_for_document from onyx.background.celery.apps.app_base import task_logger from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex from onyx.background.celery.tasks.shared.tasks import LIGHT_SOFT_TIME_LIMIT from onyx.background.celery.tasks.shared.tasks import LIGHT_TIME_LIMIT from onyx.background.celery.tasks.shared.tasks import OnyxCeleryTaskCompletionStatus from onyx.background.celery.tasks.vespa.document_sync import DOCUMENT_SYNC_FENCE_KEY from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_payload from onyx.background.celery.tasks.vespa.document_sync import get_document_sync_remaining from onyx.background.celery.tasks.vespa.document_sync import reset_document_sync from onyx.background.celery.tasks.vespa.document_sync import ( try_generate_stale_document_sync_tasks, ) from onyx.configs.app_configs import JOB_TIMEOUT from onyx.configs.app_configs import VESPA_SYNC_MAX_TASKS from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.configs.constants import OnyxRedisLocks from onyx.db.document import get_document from onyx.db.document import 
mark_document_as_synced from onyx.db.document_set import delete_document_set from onyx.db.document_set import fetch_document_sets from onyx.db.document_set import fetch_document_sets_for_document from onyx.db.document_set import get_document_set_by_id from onyx.db.document_set import mark_document_set_as_synced from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import SyncStatus from onyx.db.enums import SyncType from onyx.db.models import DocumentSet from onyx.db.models import UserGroup from onyx.db.search_settings import get_active_search_settings from onyx.db.sync_record import cleanup_sync_records from onyx.db.sync_record import insert_sync_record from onyx.db.sync_record import update_sync_record_status from onyx.document_index.factory import get_all_document_indices from onyx.document_index.interfaces import VespaDocumentFields from onyx.httpx.httpx_pool import HttpxPool from onyx.redis.redis_document_set import RedisDocumentSet from onyx.redis.redis_pool import get_redis_client from onyx.redis.redis_pool import get_redis_replica_client from onyx.redis.redis_pool import redis_lock_dump from onyx.redis.redis_usergroup import RedisUserGroup from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) from onyx.utils.variable_functionality import global_version from onyx.utils.variable_functionality import noop_fallback logger = setup_logger() # celery auto associates tasks created inside another task, # which bloats the result metadata considerably. trail=False prevents this. # TODO(andrei): Rename all these kinds of functions from *vespa* to a more # generic *document_index*. 
@shared_task( name=OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK, ignore_result=True, soft_time_limit=JOB_TIMEOUT, trail=False, bind=True, ) def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None: """Runs periodically to check if any document needs syncing. Generates sets of tasks for Celery if syncing is needed.""" # Useful for debugging timing issues with reacquisitions. # TODO: remove once more generalized logging is in place task_logger.info("check_for_vespa_sync_task started") time_start = time.monotonic() r = get_redis_client() r_replica = get_redis_replica_client() lock_beat: RedisLock = r.lock( OnyxRedisLocks.CHECK_VESPA_SYNC_BEAT_LOCK, timeout=CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT, ) # these tasks should never overlap if not lock_beat.acquire(blocking=False): return None try: # 1/3: KICKOFF with get_session_with_current_tenant() as db_session: try_generate_stale_document_sync_tasks( self.app, VESPA_SYNC_MAX_TASKS, db_session, r, lock_beat, tenant_id ) # region document set scan lock_beat.reacquire() document_set_ids: list[int] = [] with get_session_with_current_tenant() as db_session: # check if any document sets are not synced document_set_info = fetch_document_sets( user_id=None, db_session=db_session, include_outdated=True ) for document_set, _ in document_set_info: document_set_ids.append(document_set.id) for document_set_id in document_set_ids: lock_beat.reacquire() with get_session_with_current_tenant() as db_session: try_generate_document_set_sync_tasks( self.app, document_set_id, db_session, r, lock_beat, tenant_id ) # endregion # check if any user groups are not synced lock_beat.reacquire() if global_version.is_ee_version(): try: fetch_user_groups = fetch_versioned_implementation( "onyx.db.user_group", "fetch_user_groups" ) except ModuleNotFoundError: # This always raises on the MIT version, which is expected. # We shouldn't actually get here if the EE version check works. pass else: usergroup_ids: list[int] = [] with
get_session_with_current_tenant() as db_session: user_groups = fetch_user_groups( db_session=db_session, only_up_to_date=False ) for usergroup in user_groups: usergroup_ids.append(usergroup.id) for usergroup_id in usergroup_ids: lock_beat.reacquire() with get_session_with_current_tenant() as db_session: try_generate_user_group_sync_tasks( self.app, usergroup_id, db_session, r, lock_beat, tenant_id ) # 2/3: VALIDATE: TODO # 3/3: FINALIZE lock_beat.reacquire() keys = cast(set[Any], r_replica.smembers(OnyxRedisConstants.ACTIVE_FENCES)) for key in keys: key_bytes = cast(bytes, key) if not r.exists(key_bytes): r.srem(OnyxRedisConstants.ACTIVE_FENCES, key_bytes) continue key_str = key_bytes.decode("utf-8") # NOTE: removing the "Redis*" classes, prefer to just have functions to # do these things going forward. In short, things should generally be like the doc # sync task rather than the others if key_str == DOCUMENT_SYNC_FENCE_KEY: monitor_document_sync_taskset(r) elif key_str.startswith(RedisDocumentSet.FENCE_PREFIX): with get_session_with_current_tenant() as db_session: monitor_document_set_taskset(tenant_id, key_bytes, r, db_session) elif key_str.startswith(RedisUserGroup.FENCE_PREFIX): monitor_usergroup_taskset = ( fetch_versioned_implementation_with_fallback( "onyx.background.celery.tasks.vespa.tasks", "monitor_usergroup_taskset", noop_fallback, ) ) with get_session_with_current_tenant() as db_session: monitor_usergroup_taskset(tenant_id, key_bytes, r, db_session) except SoftTimeLimitExceeded: task_logger.info( "Soft time limit exceeded, task is being terminated gracefully." 
) except Exception: task_logger.exception("Unexpected exception during vespa metadata sync") finally: if lock_beat.owned(): lock_beat.release() else: task_logger.error( f"check_for_vespa_sync_task - Lock not owned on completion: tenant={tenant_id}" ) redis_lock_dump(lock_beat, r) time_elapsed = time.monotonic() - time_start task_logger.debug(f"check_for_vespa_sync_task finished: elapsed={time_elapsed:.2f}") return True def try_generate_document_set_sync_tasks( celery_app: Celery, document_set_id: int, db_session: Session, r: Redis, lock_beat: RedisLock, tenant_id: str, ) -> int | None: lock_beat.reacquire() rds = RedisDocumentSet(tenant_id, document_set_id) # don't generate document set sync tasks if tasks are still pending if rds.fenced: return None # don't generate sync tasks if we're up to date # race condition with the monitor/cleanup function if we use a cached result! document_set = get_document_set_by_id( db_session=db_session, document_set_id=document_set_id, ) if not document_set: return None if document_set.is_up_to_date: # there should be no in-progress sync records if this is up to date # clean it up just in case things got into a bad state cleanup_sync_records( db_session=db_session, entity_id=document_set_id, sync_type=SyncType.DOCUMENT_SET, ) return None # add tasks to celery and build up the task set to monitor in redis r.delete(rds.taskset_key) task_logger.info( f"RedisDocumentSet.generate_tasks starting. document_set_id={document_set.id}" ) # Add all documents that need to be updated into the queue result = rds.generate_tasks( VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id ) if result is None: return None tasks_generated = result[0] # Currently we are allowing the sync to proceed with 0 tasks. # It's possible for sets/groups to be generated initially with no entries # and they still need to be marked as up to date. # if tasks_generated == 0: # return 0 task_logger.info( f"RedisDocumentSet.generate_tasks finished. 
document_set={document_set.id} tasks_generated={tasks_generated}" ) # create before setting fence to avoid race condition where the monitoring # task updates the sync record before it is created try: insert_sync_record( db_session=db_session, entity_id=document_set_id, sync_type=SyncType.DOCUMENT_SET, ) except Exception: task_logger.exception("insert_sync_record exceptioned.") # set this only after all tasks have been added rds.set_fence(tasks_generated) return tasks_generated def try_generate_user_group_sync_tasks( celery_app: Celery, usergroup_id: int, db_session: Session, r: Redis, lock_beat: RedisLock, tenant_id: str, ) -> int | None: lock_beat.reacquire() rug = RedisUserGroup(tenant_id, usergroup_id) if rug.fenced: # don't generate sync tasks if tasks are still pending return None # race condition with the monitor/cleanup function if we use a cached result! fetch_user_group = cast( Callable[[Session, int], UserGroup | None], fetch_versioned_implementation("onyx.db.user_group", "fetch_user_group"), ) usergroup = fetch_user_group(db_session, usergroup_id) if not usergroup: return None if usergroup.is_up_to_date: # there should be no in-progress sync records if this is up to date # clean it up just in case things got into a bad state cleanup_sync_records( db_session=db_session, entity_id=usergroup_id, sync_type=SyncType.USER_GROUP, ) return None # add tasks to celery and build up the task set to monitor in redis r.delete(rug.taskset_key) # Add all documents that need to be updated into the queue task_logger.info( f"RedisUserGroup.generate_tasks starting. usergroup_id={usergroup.id}" ) result = rug.generate_tasks( VESPA_SYNC_MAX_TASKS, celery_app, db_session, r, lock_beat, tenant_id ) if result is None: return None tasks_generated = result[0] # Currently we are allowing the sync to proceed with 0 tasks. # It's possible for sets/groups to be generated initially with no entries # and they still need to be marked as up to date. 
# if tasks_generated == 0: # return 0 task_logger.info( f"RedisUserGroup.generate_tasks finished. usergroup={usergroup.id} tasks_generated={tasks_generated}" ) # create before setting fence to avoid race condition where the monitoring # task updates the sync record before it is created try: insert_sync_record( db_session=db_session, entity_id=usergroup_id, sync_type=SyncType.USER_GROUP, ) except Exception: task_logger.exception("insert_sync_record exceptioned.") # set this only after all tasks have been added rug.set_fence(tasks_generated) return tasks_generated def monitor_document_sync_taskset(r: Redis) -> None: initial_count = get_document_sync_payload(r) if initial_count is None: return remaining = get_document_sync_remaining(r) task_logger.info( f"Document sync progress: remaining={remaining} initial={initial_count}" ) if remaining == 0: reset_document_sync(r) task_logger.info(f"Successfully synced all documents. count={initial_count}") def monitor_document_set_taskset( tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session ) -> None: fence_key = key_bytes.decode("utf-8") document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key) if document_set_id_str is None: task_logger.warning(f"could not parse document set id from {fence_key}") return document_set_id = int(document_set_id_str) rds = RedisDocumentSet(tenant_id, document_set_id) if not rds.fenced: return initial_count = rds.payload if initial_count is None: return count = cast(int, r.scard(rds.taskset_key)) task_logger.info( f"Document set sync progress: document_set={document_set_id} remaining={count} initial={initial_count}" ) if count > 0: update_sync_record_status( db_session=db_session, entity_id=document_set_id, sync_type=SyncType.DOCUMENT_SET, sync_status=SyncStatus.IN_PROGRESS, num_docs_synced=count, ) return document_set = cast( DocumentSet, get_document_set_by_id(db_session=db_session, document_set_id=document_set_id), ) # casting since we "know" a document set with this ID 
exists if document_set: has_connector_pairs = bool(document_set.connector_credential_pairs) # Federated connectors should keep a document set alive even without cc pairs. has_federated_connectors = bool( getattr(document_set, "federated_connectors", []) ) if not has_connector_pairs and not has_federated_connectors: # If there are no connectors of any kind, delete the document set. delete_document_set(document_set_row=document_set, db_session=db_session) task_logger.info( f"Successfully deleted document set: document_set={document_set_id}" ) else: mark_document_set_as_synced(document_set_id, db_session) task_logger.info( f"Successfully synced document set: document_set={document_set_id}" ) try: update_sync_record_status( db_session=db_session, entity_id=document_set_id, sync_type=SyncType.DOCUMENT_SET, sync_status=SyncStatus.SUCCESS, num_docs_synced=initial_count, ) except Exception: task_logger.exception( f"update_sync_record_status exceptioned. document_set_id={document_set_id} Resetting document set regardless." ) rds.reset() @shared_task( name=OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, bind=True, soft_time_limit=LIGHT_SOFT_TIME_LIMIT, time_limit=LIGHT_TIME_LIMIT, max_retries=3, ) def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool: start = time.monotonic() completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED try: with get_session_with_current_tenant() as db_session: active_search_settings = get_active_search_settings(db_session) # This flow is for updates so we get all indices. 
document_indices = get_all_document_indices( search_settings=active_search_settings.primary, secondary_search_settings=active_search_settings.secondary, httpx_client=HttpxPool.get("vespa"), ) retry_document_indices: list[RetryDocumentIndex] = [ RetryDocumentIndex(document_index) for document_index in document_indices ] doc = get_document(document_id, db_session) if not doc: elapsed = time.monotonic() - start task_logger.info( f"doc={document_id} action=no_operation elapsed={elapsed:.2f}" ) completion_status = OnyxCeleryTaskCompletionStatus.SKIPPED else: # document set sync doc_sets = fetch_document_sets_for_document(document_id, db_session) update_doc_sets: set[str] = set(doc_sets) # User group sync doc_access = get_access_for_document( document_id=document_id, db_session=db_session ) fields = VespaDocumentFields( document_sets=update_doc_sets, access=doc_access, boost=doc.boost, hidden=doc.hidden, # aggregated_boost_factor=doc.aggregated_boost_factor, ) for retry_document_index in retry_document_indices: # TODO(andrei): Previously there was a comment here saying # it was ok if a doc did not exist in the document index. I # don't agree with that claim, so keep an eye on this task # to see if this raises. retry_document_index.update_single( document_id, tenant_id=tenant_id, chunk_count=doc.chunk_count, fields=fields, user_fields=None, ) # update db last. Worst case = we crash right before this and # the sync might repeat again later mark_document_as_synced(document_id, db_session) elapsed = time.monotonic() - start task_logger.info(f"doc={document_id} action=sync elapsed={elapsed:.2f}") completion_status = OnyxCeleryTaskCompletionStatus.SUCCEEDED except SoftTimeLimitExceeded: task_logger.info(f"SoftTimeLimitExceeded exception. 
doc={document_id}") completion_status = OnyxCeleryTaskCompletionStatus.SOFT_TIME_LIMIT except Exception as ex: e: Exception | None = None while True: if isinstance(ex, RetryError): task_logger.warning( f"Tenacity retry failed: num_attempts={ex.last_attempt.attempt_number}" ) # only set the inner exception if it is of type Exception e_temp = ex.last_attempt.exception() if isinstance(e_temp, Exception): e = e_temp else: e = ex if isinstance(e, httpx.HTTPStatusError): if e.response.status_code == HTTPStatus.BAD_REQUEST: task_logger.exception( f"Non-retryable HTTPStatusError: doc={document_id} status={e.response.status_code}" ) completion_status = ( OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION ) break task_logger.exception( f"vespa_metadata_sync_task exceptioned: doc={document_id}" ) completion_status = OnyxCeleryTaskCompletionStatus.RETRYABLE_EXCEPTION if ( self.max_retries is not None and self.request.retries >= self.max_retries ): completion_status = ( OnyxCeleryTaskCompletionStatus.NON_RETRYABLE_EXCEPTION ) # Exponential backoff from 2^4 to 2^6 ... i.e. 
16, 32, 64 countdown = 2 ** (self.request.retries + 4) self.retry(exc=e, countdown=countdown) # this will raise a celery exception break # we won't hit this, but it looks weird not to have it finally: task_logger.info( f"vespa_metadata_sync_task completed: status={completion_status.value} doc={document_id}" ) return completion_status == OnyxCeleryTaskCompletionStatus.SUCCEEDED ================================================ FILE: backend/onyx/background/celery/versioned_apps/beat.py ================================================ """Factory stub for running celery worker / celery beat.""" from celery import Celery from onyx.background.celery.apps.beat import celery_app from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() app: Celery = celery_app ================================================ FILE: backend/onyx/background/celery/versioned_apps/client.py ================================================ """Factory stub for running celery worker / celery beat. This code is different from the primary/beat stubs because there is no EE version to fetch. Port over the code in those files if we add an EE version of this worker. This is an app stub purely for sending tasks as a client. """ from celery import Celery from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() def get_app() -> Celery: from onyx.background.celery.apps.client import celery_app return celery_app app = get_app() ================================================ FILE: backend/onyx/background/celery/versioned_apps/docfetching.py ================================================ """Factory stub for running celery worker / celery beat. This code is different from the primary/beat stubs because there is no EE version to fetch. 
Port over the code in those files if we add an EE version of this worker.""" from celery import Celery from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() def get_app() -> Celery: from onyx.background.celery.apps.docfetching import celery_app return celery_app app = get_app() ================================================ FILE: backend/onyx/background/celery/versioned_apps/docprocessing.py ================================================ """Factory stub for running celery worker / celery beat. This code is different from the primary/beat stubs because there is no EE version to fetch. Port over the code in those files if we add an EE version of this worker.""" from celery import Celery from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() def get_app() -> Celery: from onyx.background.celery.apps.docprocessing import celery_app return celery_app app = get_app() ================================================ FILE: backend/onyx/background/celery/versioned_apps/heavy.py ================================================ """Factory stub for running celery worker / celery beat. This code is different from the primary/beat stubs because there is no EE version to fetch. Port over the code in those files if we add an EE version of this worker.""" from celery import Celery from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() app: Celery = fetch_versioned_implementation( "onyx.background.celery.apps.heavy", "celery_app", ) ================================================ FILE: backend/onyx/background/celery/versioned_apps/light.py ================================================ """Factory stub for running celery worker / celery beat. This code is different from the primary/beat stubs because there is no EE version to fetch. 
Port over the code in those files if we add an EE version of this worker.""" from celery import Celery from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() app: Celery = fetch_versioned_implementation( "onyx.background.celery.apps.light", "celery_app", ) ================================================ FILE: backend/onyx/background/celery/versioned_apps/monitoring.py ================================================ """Factory stub for running celery worker / celery beat.""" from celery import Celery from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() app: Celery = fetch_versioned_implementation( "onyx.background.celery.apps.monitoring", "celery_app", ) ================================================ FILE: backend/onyx/background/celery/versioned_apps/primary.py ================================================ """Factory stub for running celery worker / celery beat.""" from celery import Celery from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() app: Celery = fetch_versioned_implementation( "onyx.background.celery.apps.primary", "celery_app", ) ================================================ FILE: backend/onyx/background/celery/versioned_apps/user_file_processing.py ================================================ """Factory stub for running the user file processing Celery worker.""" from celery import Celery from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable set_is_ee_based_on_env_variable() def get_app() -> Celery: from onyx.background.celery.apps.user_file_processing import celery_app return celery_app app = get_app() 
================================================ FILE: backend/onyx/background/error_logging.py ================================================ from sqlalchemy.exc import IntegrityError from onyx.db.background_error import create_background_error from onyx.db.engine.sql_engine import get_session_with_current_tenant def emit_background_error( message: str, cc_pair_id: int | None = None, ) -> None: """Currently just saves a row in the background_errors table. In the future, could create notifications based on the severity.""" error_message = "" # try to write to the db, but handle IntegrityError specifically try: with get_session_with_current_tenant() as db_session: create_background_error(db_session, message, cc_pair_id) except IntegrityError as e: # Raised e.g. if the cc_pair_id was deleted; build a fallback message to save without it error_message = ( f"Failed to create background error: {str(e)}. Original message: {message}" ) except Exception: pass if not error_message: return # if we get here from an IntegrityError, try to write the error message to the db # we need a new session because the first session is now invalid try: with get_session_with_current_tenant() as db_session: create_background_error(db_session, error_message, None) except Exception: pass ================================================ FILE: backend/onyx/background/indexing/checkpointing_utils.py ================================================ from datetime import datetime from datetime import timedelta from io import BytesIO from sqlalchemy import and_ from sqlalchemy.orm import Session from onyx.configs.constants import FileOrigin from onyx.configs.constants import NUM_DAYS_TO_KEEP_CHECKPOINTS from onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.models import ConnectorCheckpoint from onyx.db.engine.time_utils import get_db_current_time from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import
get_recent_completed_attempts_for_cc_pair from onyx.db.models import IndexAttempt from onyx.db.models import IndexingStatus from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger from onyx.utils.object_size_check import deep_getsizeof logger = setup_logger() _NUM_RECENT_ATTEMPTS_TO_CONSIDER = 50 def _build_checkpoint_pointer(index_attempt_id: int) -> str: return f"checkpoint_{index_attempt_id}.json" def save_checkpoint( db_session: Session, index_attempt_id: int, checkpoint: ConnectorCheckpoint ) -> str: """Save a checkpoint for a given index attempt to the file store""" checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id) file_store = get_default_file_store() file_store.save_file( content=BytesIO(checkpoint.model_dump_json().encode()), display_name=checkpoint_pointer, file_origin=FileOrigin.INDEXING_CHECKPOINT, file_type="application/json", file_id=checkpoint_pointer, ) index_attempt = get_index_attempt(db_session, index_attempt_id) if not index_attempt: raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") index_attempt.checkpoint_pointer = checkpoint_pointer db_session.add(index_attempt) db_session.commit() return checkpoint_pointer def load_checkpoint( index_attempt_id: int, connector: BaseConnector ) -> ConnectorCheckpoint: """Load a checkpoint for a given index attempt from the file store""" checkpoint_pointer = _build_checkpoint_pointer(index_attempt_id) file_store = get_default_file_store() checkpoint_io = file_store.read_file(checkpoint_pointer, mode="rb") checkpoint_data = checkpoint_io.read().decode("utf-8") if isinstance(connector, CheckpointedConnector): return connector.validate_checkpoint_json(checkpoint_data) return ConnectorCheckpoint.model_validate_json(checkpoint_data) def get_latest_valid_checkpoint( db_session: Session, cc_pair_id: int, search_settings_id: int, window_start: datetime, window_end: datetime, connector: BaseConnector, ) -> tuple[ConnectorCheckpoint, 
bool]: """Get the latest valid checkpoint for a given connector credential pair""" checkpoint_candidates = get_recent_completed_attempts_for_cc_pair( cc_pair_id=cc_pair_id, search_settings_id=search_settings_id, db_session=db_session, limit=_NUM_RECENT_ATTEMPTS_TO_CONSIDER, ) # don't keep using checkpoints if we've had a bunch of failed attempts in a row # where we make no progress. Only do this if we have had at least # _NUM_RECENT_ATTEMPTS_TO_CONSIDER completed attempts. if len(checkpoint_candidates) >= _NUM_RECENT_ATTEMPTS_TO_CONSIDER: had_any_progress = False for candidate in checkpoint_candidates: if ( candidate.total_docs_indexed is not None and candidate.total_docs_indexed > 0 ) or candidate.status.is_successful(): had_any_progress = True break if not had_any_progress: logger.warning( f"{_NUM_RECENT_ATTEMPTS_TO_CONSIDER} consecutive failed attempts without progress " f"found for cc_pair={cc_pair_id}. Ignoring checkpoint to let the run start " "from scratch." ) return connector.build_dummy_checkpoint(), False # filter out any candidates that don't meet the criteria checkpoint_candidates = [ candidate for candidate in checkpoint_candidates if ( candidate.poll_range_start == window_start and candidate.poll_range_end == window_end and ( candidate.status == IndexingStatus.FAILED # if the background job was killed (and thus the attempt was canceled) # we still want to use the checkpoint so that we can pick up where we left off or candidate.status == IndexingStatus.CANCELED ) and candidate.checkpoint_pointer is not None # NOTE: There are a couple connectors that may make progress but not have # any "total_docs_indexed". E.g. they are going through # Slack channels, and tons of them don't have any updates. # Leaving the below in as historical context / in-case we want to use it again. # we want to make sure that the checkpoint is actually useful # if it's only gone through a few docs, it's probably not worth # using. 
This also avoids weird cases where a connector is basically # non-functional but still "makes progress" by slowly moving the # checkpoint forward run after run # and candidate.total_docs_indexed # and candidate.total_docs_indexed > 100 ) ] # assumes latest checkpoint is the furthest along. This only isn't true # if something else has gone wrong. latest_valid_checkpoint_candidate = ( checkpoint_candidates[0] if checkpoint_candidates else None ) checkpoint = connector.build_dummy_checkpoint() if latest_valid_checkpoint_candidate is None: logger.info( f"No valid checkpoint found for cc_pair={cc_pair_id}. Starting from scratch." ) return checkpoint, False try: previous_checkpoint = load_checkpoint( index_attempt_id=latest_valid_checkpoint_candidate.id, connector=connector, ) except Exception: logger.exception( f"Failed to load checkpoint from previous failed attempt with ID " f"{latest_valid_checkpoint_candidate.id}. Falling back to default checkpoint." ) return checkpoint, False logger.info( f"Using checkpoint from previous failed attempt with ID " f"{latest_valid_checkpoint_candidate.id}. Previous checkpoint: " f"{previous_checkpoint}" ) return previous_checkpoint, True def get_index_attempts_with_old_checkpoints( db_session: Session, days_to_keep: int = NUM_DAYS_TO_KEEP_CHECKPOINTS ) -> list[IndexAttempt]: """Get all index attempts with checkpoints older than the specified number of days. 
Args: db_session: The database session days_to_keep: Number of days to keep checkpoints for (default: NUM_DAYS_TO_KEEP_CHECKPOINTS) Returns: List of IndexAttempt objects with old checkpoints """ cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep) # Find all index attempts with checkpoints older than cutoff_date old_attempts = ( db_session.query(IndexAttempt) .filter( and_( IndexAttempt.checkpoint_pointer.isnot(None), IndexAttempt.time_created < cutoff_date, ) ) .all() ) return old_attempts def cleanup_checkpoint(db_session: Session, index_attempt_id: int) -> None: """Clean up a checkpoint for a given index attempt""" index_attempt = get_index_attempt(db_session, index_attempt_id) if not index_attempt: raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") if not index_attempt.checkpoint_pointer: return None file_store = get_default_file_store() file_store.delete_file(index_attempt.checkpoint_pointer) index_attempt.checkpoint_pointer = None db_session.add(index_attempt) db_session.commit() return None def check_checkpoint_size(checkpoint: ConnectorCheckpoint) -> None: """Check if the checkpoint content size exceeds the limit (200MB)""" content_size = deep_getsizeof(checkpoint.model_dump()) if content_size > 200_000_000: # 200MB in bytes raise ValueError( f"Checkpoint content size ({content_size} bytes) exceeds 200MB limit" ) ================================================ FILE: backend/onyx/background/indexing/dask_utils.py ================================================ import asyncio import psutil from dask.distributed import WorkerPlugin from distributed import Worker from onyx.utils.logger import setup_logger logger = setup_logger() class ResourceLogger(WorkerPlugin): def __init__(self, log_interval: int = 60 * 5): self.log_interval = log_interval def setup(self, worker: Worker) -> None: """This method will be called when the plugin is attached to a worker.""" self.worker = worker 
worker.loop.add_callback(self.log_resources) async def log_resources(self) -> None: """Periodically log CPU and memory usage. NOTE: must be async or else will clog up the worker indefinitely due to the fact that Dask uses Tornado under the hood (which is async)""" while True: cpu_percent = psutil.cpu_percent(interval=None) memory_available_gb = psutil.virtual_memory().available / (1024.0**3) # You can now log these values or send them to a monitoring service logger.debug( f"Worker {self.worker.address}: CPU usage {cpu_percent}%, Memory available {memory_available_gb}GB" ) await asyncio.sleep(self.log_interval) ================================================ FILE: backend/onyx/background/indexing/index_attempt_utils.py ================================================ from datetime import timedelta from sqlalchemy import func from sqlalchemy.orm import Session from onyx.configs.constants import NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS from onyx.db.engine.time_utils import get_db_current_time from onyx.db.models import IndexAttempt from onyx.db.models import IndexAttemptError # Always retain at least this many attempts per connector/search settings pair NUM_RECENT_INDEX_ATTEMPTS_TO_KEEP = 10 def get_old_index_attempts( db_session: Session, days_to_keep: int = NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS ) -> list[IndexAttempt]: """ Get index attempts older than the specified number of days while retaining the latest NUM_RECENT_INDEX_ATTEMPTS_TO_KEEP per connector/search settings pair. 
""" cutoff_date = get_db_current_time(db_session) - timedelta(days=days_to_keep) ranked_attempts = ( db_session.query( IndexAttempt.id.label("attempt_id"), IndexAttempt.time_created.label("time_created"), func.row_number() .over( partition_by=( IndexAttempt.connector_credential_pair_id, IndexAttempt.search_settings_id, ), order_by=IndexAttempt.time_created.desc(), ) .label("attempt_rank"), ) ).subquery() return ( db_session.query(IndexAttempt) .join( ranked_attempts, IndexAttempt.id == ranked_attempts.c.attempt_id, ) .filter( ranked_attempts.c.time_created < cutoff_date, ranked_attempts.c.attempt_rank > NUM_RECENT_INDEX_ATTEMPTS_TO_KEEP, ) .all() ) def cleanup_index_attempts(db_session: Session, index_attempt_ids: list[int]) -> None: """Clean up multiple index attempts""" db_session.query(IndexAttemptError).filter( IndexAttemptError.index_attempt_id.in_(index_attempt_ids) ).delete(synchronize_session=False) db_session.query(IndexAttempt).filter( IndexAttempt.id.in_(index_attempt_ids) ).delete(synchronize_session=False) db_session.commit() ================================================ FILE: backend/onyx/background/indexing/job_client.py ================================================ """Custom client that works similarly to Dask, but simpler and more lightweight. Dask jobs behaved very strangely - they would die all the time, retries would not follow the expected behavior, etc. 
NOTE: cannot use Celery directly due to https://github.com/celery/celery/issues/7007#issuecomment-1740139367""" import multiprocessing as mp import sys import traceback from collections.abc import Callable from dataclasses import dataclass from multiprocessing.context import SpawnProcess from typing import Any from typing import Literal from typing import Optional from onyx.configs.constants import POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.logger import setup_logger from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import TENANT_ID_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() class SimpleJobException(Exception): """lets us raise an exception that will return a specific error code""" def __init__(self, *args: Any, **kwargs: Any) -> None: code: int | None = kwargs.pop("code", None) self.code = code super().__init__(*args, **kwargs) JobStatusType = ( Literal["error"] | Literal["finished"] | Literal["pending"] | Literal["running"] | Literal["cancelled"] ) def _initializer( func: Callable, queue: mp.Queue, args: list | tuple, kwargs: dict[str, Any] | None = None, ) -> Any: """Initialize the child process with a fresh SQLAlchemy Engine. Based on SQLAlchemy's recommendations to handle multiprocessing: https://docs.sqlalchemy.org/en/20/core/pooling.html#using-connection-pools-with-multiprocessing-or-os-fork """ if kwargs is None: kwargs = {} logger.info("Initializing spawned worker child process.") # 1. Get tenant_id from args or fallback to default tenant_id = POSTGRES_DEFAULT_SCHEMA for arg in reversed(args): if isinstance(arg, str) and arg.startswith(TENANT_ID_PREFIX): tenant_id = arg break # 2. 
Set the tenant context before running anything token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) # Reset the engine in the child process SqlEngine.reset_engine() # Optionally set a custom app name for database logging purposes SqlEngine.set_app_name(POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME) # Initialize a new engine with desired parameters SqlEngine.init_engine( pool_size=4, max_overflow=12, pool_recycle=60, pool_pre_ping=True ) # Proceed with executing the target function try: return func(*args, **kwargs) except SimpleJobException as e: logger.exception("SimpleJob raised a SimpleJobException") error_msg = traceback.format_exc() queue.put(error_msg) # Send the exception to the parent process sys.exit(e.code) # use the given exit code except Exception: logger.exception("SimpleJob raised an exception") error_msg = traceback.format_exc() queue.put(error_msg) # Send the exception to the parent process sys.exit(255) # use 255 to indicate a generic exception finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def _run_in_process( func: Callable, queue: mp.Queue, args: list | tuple, kwargs: dict[str, Any] | None = None, ) -> None: _initializer(func, queue, args, kwargs) @dataclass class SimpleJob: """Drop in replacement for `dask.distributed.Future`""" id: int process: Optional["SpawnProcess"] = None queue: Optional[mp.Queue] = None _exception: Optional[str] = None def cancel(self) -> bool: return self.release() def release(self) -> bool: if self.process is not None and self.process.is_alive(): self.process.terminate() return True return False @property def status(self) -> JobStatusType: if not self.process: return "pending" elif self.process.is_alive(): return "running" elif self.process.exitcode is None: return "cancelled" elif self.process.exitcode != 0: return "error" else: return "finished" def done(self) -> bool: return ( self.status == "finished" or self.status == "cancelled" or self.status == "error" ) def exception(self) -> str: """Needed to match the Dask 
API. Retrieves the exception traceback that the child process put on the multiprocessing queue, if available.""" if self._exception is None and self.queue and not self.queue.empty(): self._exception = self.queue.get() # Get exception from queue return ( self._exception or f"Job with ID '{self.id}' did not report an exception." ) class SimpleJobClient: """Drop in replacement for `dask.distributed.Client`""" def __init__(self, n_workers: int = 1) -> None: self.n_workers = n_workers self.job_id_counter = 0 self.jobs: dict[int, SimpleJob] = {} def _cleanup_completed_jobs(self) -> None: current_job_ids = list(self.jobs.keys()) for job_id in current_job_ids: job = self.jobs.get(job_id) if job and job.done(): logger.debug(f"Cleaning up job with id: '{job.id}'") del self.jobs[job.id] def submit( self, func: Callable, *args: Any, pure: bool = True, # noqa: ARG002 ) -> SimpleJob | None: """NOTE: `pure` arg is needed so this can be a drop in replacement for Dask""" self._cleanup_completed_jobs() if len(self.jobs) >= self.n_workers: logger.debug( f"No available workers to run job. Currently running '{len(self.jobs)}' jobs, with a limit of '{self.n_workers}'."
) return None job_id = self.job_id_counter self.job_id_counter += 1 # this approach allows us to always "spawn" a new process regardless of # get_start_method's current setting ctx = mp.get_context("spawn") queue = ctx.Queue() process = ctx.Process( target=_run_in_process, args=(func, queue, args), daemon=True ) job = SimpleJob(id=job_id, process=process, queue=queue) process.start() self.jobs[job_id] = job return job ================================================ FILE: backend/onyx/background/indexing/memory_tracer.py ================================================ import tracemalloc from onyx.utils.logger import setup_logger logger = setup_logger() DANSWER_TRACEMALLOC_FRAMES = 10 class MemoryTracer: def __init__(self, interval: int = 0, num_print_entries: int = 5): self.interval = interval self.num_print_entries = num_print_entries self.snapshot_first: tracemalloc.Snapshot | None = None self.snapshot_prev: tracemalloc.Snapshot | None = None self.snapshot: tracemalloc.Snapshot | None = None self.counter = 0 def start(self) -> None: """Start the memory tracer if interval is greater than 0.""" if self.interval > 0: logger.debug(f"Memory tracer starting: interval={self.interval}") tracemalloc.start(DANSWER_TRACEMALLOC_FRAMES) self._take_snapshot() def stop(self) -> None: """Stop the memory tracer if it's running.""" if self.interval > 0: self.log_final_diff() tracemalloc.stop() logger.debug("Memory tracer stopped.") def _take_snapshot(self) -> None: """Take a snapshot and update internal snapshot states.""" snapshot = tracemalloc.take_snapshot() # Filter out irrelevant frames snapshot = snapshot.filter_traces( ( tracemalloc.Filter(False, tracemalloc.__file__), tracemalloc.Filter(False, ""), tracemalloc.Filter(False, ""), ) ) if not self.snapshot_first: self.snapshot_first = snapshot if self.snapshot: self.snapshot_prev = self.snapshot self.snapshot = snapshot def _log_diff( self, current: tracemalloc.Snapshot, previous: tracemalloc.Snapshot ) -> None: """Log the 
memory difference between two snapshots.""" stats = current.compare_to(previous, "traceback") for s in stats[: self.num_print_entries]: logger.debug(f"Tracer diff: {s}") for line in s.traceback.format(): logger.debug(f"* {line}") def increment_and_maybe_trace(self) -> None: """Increment counter and perform trace if interval is hit.""" if self.interval <= 0: return self.counter += 1 if self.counter % self.interval == 0: logger.debug( f"Running trace comparison for batch {self.counter}. interval={self.interval}" ) self._take_snapshot() if self.snapshot and self.snapshot_prev: self._log_diff(self.snapshot, self.snapshot_prev) def log_final_diff(self) -> None: """Log the final memory diff between start and end of indexing.""" if self.interval <= 0: return logger.debug( f"Running trace comparison between start and end of indexing. {self.counter} batches processed." ) self._take_snapshot() if self.snapshot and self.snapshot_first: self._log_diff(self.snapshot, self.snapshot_first) ================================================ FILE: backend/onyx/background/indexing/models.py ================================================ from datetime import datetime from pydantic import BaseModel from onyx.db.models import IndexAttemptError class IndexAttemptErrorPydantic(BaseModel): id: int connector_credential_pair_id: int document_id: str | None document_link: str | None entity_id: str | None failed_time_range_start: datetime | None failed_time_range_end: datetime | None failure_message: str is_resolved: bool = False time_created: datetime index_attempt_id: int @classmethod def from_model(cls, model: IndexAttemptError) -> "IndexAttemptErrorPydantic": return cls( id=model.id, connector_credential_pair_id=model.connector_credential_pair_id, document_id=model.document_id, document_link=model.document_link, entity_id=model.entity_id, failed_time_range_start=model.failed_time_range_start, failed_time_range_end=model.failed_time_range_end, failure_message=model.failure_message, 
is_resolved=model.is_resolved, time_created=model.time_created, index_attempt_id=model.index_attempt_id, ) ================================================ FILE: backend/onyx/background/indexing/run_docfetching.py ================================================ import sys import time import traceback from datetime import datetime from datetime import timedelta from datetime import timezone from celery import Celery from sqlalchemy.orm import Session from onyx.access.access import source_should_fetch_permissions_during_indexing from onyx.background.indexing.checkpointing_utils import check_checkpoint_size from onyx.background.indexing.checkpointing_utils import get_latest_valid_checkpoint from onyx.background.indexing.checkpointing_utils import save_checkpoint from onyx.background.indexing.memory_tracer import MemoryTracer from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import INDEXING_SIZE_WARNING_THRESHOLD from onyx.configs.app_configs import INDEXING_TRACER_INTERVAL from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.app_configs import LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES from onyx.configs.app_configs import POLL_CONNECTOR_OFFSET from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.connectors.connector_runner import ConnectorRunner from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.factory import instantiate_connector from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorStopSignal from onyx.connectors.models import Document from onyx.connectors.models import IndexAttemptMetadata from onyx.connectors.models import 
TextSection from onyx.db.connector import mark_ccpair_with_indexing_trigger from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.connector_credential_pair import get_last_successful_attempt_poll_range_end from onyx.db.connector_credential_pair import update_connector_credential_pair from onyx.db.constants import CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import IndexModelStatus from onyx.db.enums import ProcessingMode from onyx.db.hierarchy import upsert_hierarchy_node_cc_pair_entries from onyx.db.hierarchy import upsert_hierarchy_nodes_batch from onyx.db.index_attempt import create_index_attempt_error from onyx.db.index_attempt import get_index_attempt from onyx.db.index_attempt import get_recent_completed_attempts_for_cc_pair from onyx.db.index_attempt import mark_attempt_canceled from onyx.db.index_attempt import mark_attempt_failed from onyx.db.index_attempt import transition_attempt_to_in_progress from onyx.db.indexing_coordination import IndexingCoordination from onyx.db.models import IndexAttempt from onyx.file_store.document_batch_storage import DocumentBatchStorage from onyx.file_store.document_batch_storage import get_document_batch_storage from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.indexing_pipeline import index_doc_batch_prepare from onyx.redis.redis_hierarchy import cache_hierarchy_nodes_batch from onyx.redis.redis_hierarchy import ensure_source_node_exists from onyx.redis.redis_hierarchy import get_node_id_from_raw_id from onyx.redis.redis_hierarchy import get_source_node_id_from_cache from onyx.redis.redis_hierarchy import HierarchyNodeCacheEntry from 
onyx.redis.redis_pool import get_redis_client from onyx.server.features.build.indexing.persistent_document_writer import ( get_persistent_document_writer, ) from onyx.utils.logger import setup_logger from onyx.utils.middleware import make_randomized_onyx_request_id from onyx.utils.postgres_sanitization import sanitize_document_for_postgres from onyx.utils.postgres_sanitization import sanitize_hierarchy_nodes_for_postgres from onyx.utils.variable_functionality import global_version from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import INDEX_ATTEMPT_INFO_CONTEXTVAR logger = setup_logger(propagate=False) INDEXING_TRACER_NUM_PRINT_ENTRIES = 5 def _get_connector_runner( db_session: Session, attempt: IndexAttempt, batch_size: int, start_time: datetime, end_time: datetime, include_permissions: bool, leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE, ) -> ConnectorRunner: """ NOTE: `start_time` and `end_time` are only used for poll connectors. Returns a ConnectorRunner that yields document batches for the connector. If the connector's input type is LOAD_STATE, the yielded documents are considered the complete list of existing documents; otherwise they are considered incomplete. """ task = attempt.connector_credential_pair.connector.input_type try: runnable_connector = instantiate_connector( db_session=db_session, source=attempt.connector_credential_pair.connector.source, input_type=task, connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config, credential=attempt.connector_credential_pair.credential, ) # validate the connector settings if not INTEGRATION_TESTS_MODE: runnable_connector.validate_connector_settings() if attempt.connector_credential_pair.access_type == AccessType.SYNC: runnable_connector.validate_perm_sync() except UnexpectedValidationError as e: logger.exception( "Unable to instantiate connector due to an unexpected temporary issue."
) raise e except Exception as e: logger.exception("Unable to instantiate connector. Pausing until fixed.") # since we failed to even instantiate the connector, we pause the CCPair since # it will never succeed # Sometimes there are cases where the connector will # intermittently fail to initialize in which case we should pass in # leave_connector_active=True to allow it to continue. # For example, if there is nightly maintenance on a Confluence Server instance, # the connector will fail to initialize every night. if not leave_connector_active: cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=attempt.connector_credential_pair.id, ) if cc_pair and cc_pair.status == ConnectorCredentialPairStatus.ACTIVE: update_connector_credential_pair( db_session=db_session, connector_id=attempt.connector_credential_pair.connector.id, credential_id=attempt.connector_credential_pair.credential.id, status=ConnectorCredentialPairStatus.PAUSED, ) raise e return ConnectorRunner( connector=runnable_connector, batch_size=batch_size, include_permissions=include_permissions, time_range=(start_time, end_time), ) def strip_null_characters(doc_batch: list[Document]) -> list[Document]: cleaned_batch = [] for doc in doc_batch: if sys.getsizeof(doc) > MAX_FILE_SIZE_BYTES: logger.warning( f"doc {doc.id} too large, Document size: {sys.getsizeof(doc)}" ) cleaned_batch.append(sanitize_document_for_postgres(doc)) return cleaned_batch def _check_connector_and_attempt_status( db_session_temp: Session, cc_pair_id: int, search_settings_status: IndexModelStatus, index_attempt_id: int, ) -> None: """ Checks the status of the connector credential pair and index attempt. Raises ConnectorStopSignal if the connector is paused (unless a FUTURE index is being built) or being deleted, or if the attempt was canceled. Raises RuntimeError if any other condition is not met.
""" cc_pair_loop = get_connector_credential_pair_from_id( db_session_temp, cc_pair_id, ) if not cc_pair_loop: raise RuntimeError(f"CC pair {cc_pair_id} not found in DB.") if ( cc_pair_loop.status == ConnectorCredentialPairStatus.PAUSED and search_settings_status != IndexModelStatus.FUTURE ) or cc_pair_loop.status == ConnectorCredentialPairStatus.DELETING: raise ConnectorStopSignal(f"Connector {cc_pair_loop.status.value.lower()}") index_attempt_loop = get_index_attempt(db_session_temp, index_attempt_id) if not index_attempt_loop: raise RuntimeError(f"Index attempt {index_attempt_id} not found in DB.") if index_attempt_loop.status == IndexingStatus.CANCELED: raise ConnectorStopSignal(f"Index attempt {index_attempt_id} was canceled") if index_attempt_loop.status != IndexingStatus.IN_PROGRESS: error_str = "" if index_attempt_loop.error_msg: error_str = f" Original error: {index_attempt_loop.error_msg}" raise RuntimeError( f"Index Attempt is not running, status is {index_attempt_loop.status}.{error_str}" ) if index_attempt_loop.celery_task_id is None: raise RuntimeError(f"Index attempt {index_attempt_id} has no celery task id") # TODO: delete from here if ends up unused def _check_failure_threshold( total_failures: int, document_count: int, batch_num: int, last_failure: ConnectorFailure | None, ) -> None: """Check if we've hit the failure threshold and raise an appropriate exception if so. We consider the threshold hit if: 1. We have more than 3 failures AND 2. Failures account for more than 10% of processed documents """ failure_ratio = total_failures / (document_count or 1) FAILURE_THRESHOLD = 3 FAILURE_RATIO_THRESHOLD = 0.1 if total_failures > FAILURE_THRESHOLD and failure_ratio > FAILURE_RATIO_THRESHOLD: logger.error( f"Connector run failed with '{total_failures}' errors after '{batch_num}' batches." 
) if last_failure and last_failure.exception: raise last_failure.exception from last_failure.exception raise RuntimeError( f"Connector run encountered too many errors, aborting. Last error: {last_failure}" ) def run_docfetching_entrypoint( app: Celery, index_attempt_id: int, tenant_id: str, connector_credential_pair_id: int, is_ee: bool = False, callback: IndexingHeartbeatInterface | None = None, ) -> None: """Don't swallow exceptions here ... propagate them up.""" if is_ee: global_version.set_ee() # set the indexing attempt ID so that all log messages from this process # will have it added as a prefix token = INDEX_ATTEMPT_INFO_CONTEXTVAR.set( (connector_credential_pair_id, index_attempt_id) ) try: with get_session_with_current_tenant() as db_session: attempt = transition_attempt_to_in_progress(index_attempt_id, db_session) tenant_str = "" if MULTI_TENANT: tenant_str = f" for tenant {tenant_id}" connector_name = attempt.connector_credential_pair.connector.name connector_config = ( attempt.connector_credential_pair.connector.connector_specific_config ) credential_id = attempt.connector_credential_pair.credential_id logger.info( f"Docfetching starting{tenant_str}: " f"connector='{connector_name}' " f"config='{connector_config}' " f"credentials='{credential_id}'" ) connector_document_extraction( app, index_attempt_id, attempt.connector_credential_pair_id, attempt.search_settings_id, tenant_id, callback, ) logger.info( f"Docfetching finished{tenant_str}: " f"connector='{connector_name}' " f"config='{connector_config}' " f"credentials='{credential_id}'" ) finally: # reset the context var even when extraction raises, so later tasks in this worker don't inherit the log prefix INDEX_ATTEMPT_INFO_CONTEXTVAR.reset(token) def connector_document_extraction( app: Celery, index_attempt_id: int, cc_pair_id: int, search_settings_id: int, tenant_id: str, callback: IndexingHeartbeatInterface | None = None, ) -> None: """Extract documents from connector and queue them for indexing pipeline processing.
This is the first part of the split indexing process that runs the connector and extracts documents, storing them in the filestore for later processing. """ start_time = time.monotonic() logger.info( f"Document extraction starting: " f"attempt={index_attempt_id} " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " f"tenant={tenant_id}" ) # Get batch storage (transition to IN_PROGRESS is handled by run_docfetching_entrypoint) batch_storage = get_document_batch_storage(cc_pair_id, index_attempt_id) # Initialize memory tracer. NOTE: won't actually do anything if # `INDEXING_TRACER_INTERVAL` is 0. memory_tracer = MemoryTracer(interval=INDEXING_TRACER_INTERVAL) memory_tracer.start() index_attempt = None last_batch_num = 0 # used to continue from checkpointing # comes from _run_indexing with get_session_with_current_tenant() as db_session: index_attempt = get_index_attempt( db_session, index_attempt_id, eager_load_cc_pair=True, eager_load_search_settings=True, ) if not index_attempt: raise RuntimeError(f"Index attempt {index_attempt_id} not found") if index_attempt.search_settings is None: raise ValueError("Search settings must be set for indexing") # Clear the indexing trigger if it was set, to prevent duplicate indexing attempts if index_attempt.connector_credential_pair.indexing_trigger is not None: logger.info( "Clearing indexing trigger: " f"cc_pair={index_attempt.connector_credential_pair.id} " f"trigger={index_attempt.connector_credential_pair.indexing_trigger}" ) mark_ccpair_with_indexing_trigger( index_attempt.connector_credential_pair.id, None, db_session ) db_connector = index_attempt.connector_credential_pair.connector db_credential = index_attempt.connector_credential_pair.credential processing_mode = index_attempt.connector_credential_pair.processing_mode is_primary = index_attempt.search_settings.status == IndexModelStatus.PRESENT is_connector_public = ( index_attempt.connector_credential_pair.access_type == AccessType.PUBLIC ) from_beginning =
index_attempt.from_beginning has_successful_attempt = ( index_attempt.connector_credential_pair.last_successful_index_time is not None ) # Use higher priority for first-time indexing to ensure new connectors # get processed before re-indexing of existing connectors docprocessing_priority = ( OnyxCeleryPriority.MEDIUM if has_successful_attempt else OnyxCeleryPriority.HIGH ) earliest_index_time = ( db_connector.indexing_start.timestamp() if db_connector.indexing_start else 0 ) should_fetch_permissions_during_indexing = ( index_attempt.connector_credential_pair.access_type == AccessType.SYNC and source_should_fetch_permissions_during_indexing(db_connector.source) and is_primary # if we've already successfully indexed, let the doc_sync job # take care of doc-level permissions and (from_beginning or not has_successful_attempt) ) # Set up time windows for polling last_successful_index_poll_range_end = ( earliest_index_time if from_beginning else get_last_successful_attempt_poll_range_end( cc_pair_id=cc_pair_id, earliest_index=earliest_index_time, search_settings=index_attempt.search_settings, db_session=db_session, ) ) if last_successful_index_poll_range_end > POLL_CONNECTOR_OFFSET: window_start = datetime.fromtimestamp( last_successful_index_poll_range_end, tz=timezone.utc ) - timedelta(minutes=POLL_CONNECTOR_OFFSET) else: # don't go into "negative" time if we've never indexed before window_start = datetime.fromtimestamp(0, tz=timezone.utc) most_recent_attempt = next( iter( get_recent_completed_attempts_for_cc_pair( cc_pair_id=cc_pair_id, search_settings_id=index_attempt.search_settings_id, db_session=db_session, limit=1, ) ), None, ) # if the last attempt failed, try and use the same window. This is necessary # to ensure correctness with checkpointing. If we don't do this, things like # new slack channels could be missed (since existing slack channels are # cached as part of the checkpoint). 
if ( most_recent_attempt and most_recent_attempt.poll_range_end and ( most_recent_attempt.status == IndexingStatus.FAILED or most_recent_attempt.status == IndexingStatus.CANCELED ) ): window_end = most_recent_attempt.poll_range_end else: window_end = datetime.now(tz=timezone.utc) # set time range in db index_attempt.poll_range_start = window_start index_attempt.poll_range_end = window_end db_session.commit() # TODO: maybe memory tracer here # Set up connector runner connector_runner = _get_connector_runner( db_session=db_session, attempt=index_attempt, batch_size=INDEX_BATCH_SIZE, start_time=window_start, end_time=window_end, include_permissions=should_fetch_permissions_during_indexing, ) # don't use a checkpoint if we're explicitly indexing from # the beginning in order to avoid weird interactions between # checkpointing / failure handling # OR # if the last attempt was successful if index_attempt.from_beginning or ( most_recent_attempt and most_recent_attempt.status.is_successful() ): logger.info( f"Cleaning up all old batches for index attempt {index_attempt_id} before starting new run" ) batch_storage.cleanup_all_batches() checkpoint = connector_runner.connector.build_dummy_checkpoint() else: logger.info( f"Getting latest valid checkpoint for index attempt {index_attempt_id}" ) checkpoint, resuming_from_checkpoint = get_latest_valid_checkpoint( db_session=db_session, cc_pair_id=cc_pair_id, search_settings_id=index_attempt.search_settings_id, window_start=window_start, window_end=window_end, connector=connector_runner.connector, ) # checkpoint resumption OR the connector already finished. 
if ( isinstance(connector_runner.connector, CheckpointedConnector) and resuming_from_checkpoint ) or ( most_recent_attempt and most_recent_attempt.total_batches is not None and not checkpoint.has_more ): reissued_batch_count, completed_batches = reissue_old_batches( batch_storage, index_attempt_id, cc_pair_id, tenant_id, app, most_recent_attempt, docprocessing_priority, ) last_batch_num = reissued_batch_count + completed_batches index_attempt.completed_batches = completed_batches db_session.commit() else: logger.info( f"Cleaning up all batches for index attempt {index_attempt_id} before starting new run" ) # for non-checkpointed connectors, throw out batches from previous unsuccessful attempts # because we'll be getting those documents again anyways. batch_storage.cleanup_all_batches() # Save initial checkpoint save_checkpoint( db_session=db_session, index_attempt_id=index_attempt_id, checkpoint=checkpoint, ) try: batch_num = last_batch_num # starts at 0 if no last batch total_doc_batches_queued = 0 total_failures = 0 document_count = 0 # Ensure the SOURCE-type root hierarchy node exists before processing. # This is the root of the hierarchy tree for this source - all other # hierarchy nodes should ultimately have this as an ancestor. redis_client = get_redis_client(tenant_id=tenant_id) with get_session_with_current_tenant() as db_session: ensure_source_node_exists(redis_client, db_session, db_connector.source) # Main extraction loop while checkpoint.has_more: logger.info( f"Running '{db_connector.source.value}' connector with checkpoint: {checkpoint}" ) for ( document_batch, hierarchy_node_batch, failure, next_checkpoint, ) in connector_runner.run(checkpoint): # Check if connector is disabled mid run and stop if so unless it's the secondary # index being built. We want to populate it even for paused connectors # Often paused connectors are sources that aren't updated frequently but the # contents still need to be initially pulled. 
if callback and callback.should_stop(): raise ConnectorStopSignal("Connector stop signal detected") # will exception if the connector/index attempt is marked as paused/failed with get_session_with_current_tenant() as db_session_tmp: _check_connector_and_attempt_status( db_session_tmp, cc_pair_id, index_attempt.search_settings.status, index_attempt_id, ) # save record of any failures at the connector level if failure is not None: total_failures += 1 with get_session_with_current_tenant() as db_session: create_index_attempt_error( index_attempt_id, cc_pair_id, failure, db_session, ) _check_failure_threshold( total_failures, document_count, batch_num, failure ) # Save checkpoint if provided if next_checkpoint: checkpoint = next_checkpoint # Process hierarchy nodes batch - upsert to Postgres and cache in Redis if hierarchy_node_batch: hierarchy_node_batch_cleaned = ( sanitize_hierarchy_nodes_for_postgres(hierarchy_node_batch) ) with get_session_with_current_tenant() as db_session: upserted_nodes = upsert_hierarchy_nodes_batch( db_session=db_session, nodes=hierarchy_node_batch_cleaned, source=db_connector.source, commit=True, is_connector_public=is_connector_public, ) upsert_hierarchy_node_cc_pair_entries( db_session=db_session, hierarchy_node_ids=[n.id for n in upserted_nodes], connector_id=db_connector.id, credential_id=db_credential.id, commit=True, ) # Cache in Redis for fast ancestor resolution during doc processing redis_client = get_redis_client(tenant_id=tenant_id) cache_entries = [ HierarchyNodeCacheEntry.from_db_model(node) for node in upserted_nodes ] cache_hierarchy_nodes_batch( redis_client=redis_client, source=db_connector.source, entries=cache_entries, ) logger.debug( f"Persisted and cached {len(hierarchy_node_batch_cleaned)} hierarchy nodes for attempt={index_attempt_id}" ) # below is all document processing task, so if no batch we can just continue if not document_batch: continue # Clean documents and create batch doc_batch_cleaned = 
strip_null_characters(document_batch) # count extracted documents so _check_failure_threshold computes a real failure ratio document_count += len(doc_batch_cleaned) # Resolve parent_hierarchy_raw_node_id to parent_hierarchy_node_id # using the Redis cache (just populated from hierarchy nodes batch) with get_session_with_current_tenant() as db_session_tmp: source_node_id = get_source_node_id_from_cache( redis_client, db_session_tmp, db_connector.source ) for doc in doc_batch_cleaned: if doc.parent_hierarchy_raw_node_id is not None: node_id, found = get_node_id_from_raw_id( redis_client, db_connector.source, doc.parent_hierarchy_raw_node_id, ) doc.parent_hierarchy_node_id = ( node_id if found else source_node_id ) else: doc.parent_hierarchy_node_id = source_node_id batch_description = [] for doc in doc_batch_cleaned: batch_description.append(doc.to_short_descriptor()) doc_size = 0 for section in doc.sections: if ( isinstance(section, TextSection) and section.text is not None ): doc_size += len(section.text) if doc_size > INDEXING_SIZE_WARNING_THRESHOLD: logger.warning( f"Document size: doc='{doc.to_short_descriptor()}' " f"size={doc_size} " f"threshold={INDEXING_SIZE_WARNING_THRESHOLD}" ) logger.debug(f"Indexing batch of documents: {batch_description}") memory_tracer.increment_and_maybe_trace() if processing_mode == ProcessingMode.FILE_SYSTEM: # File system only - write directly to persistent storage, # skip chunking/embedding/Vespa but still track documents in DB # IMPORTANT: Write to S3 FIRST, before marking as indexed in DB.
# Write documents to persistent file system # Use creator_id for user-segregated storage paths (sandbox isolation) creator_id = index_attempt.connector_credential_pair.creator_id if creator_id is None: raise ValueError( f"ConnectorCredentialPair {index_attempt.connector_credential_pair.id} " "must have a creator_id for persistent document storage" ) user_id_str: str = str(creator_id) writer = get_persistent_document_writer( user_id=user_id_str, tenant_id=tenant_id, ) written_paths = writer.write_documents(doc_batch_cleaned) # Only after successful S3 write, mark documents as indexed in DB with get_session_with_current_tenant() as db_session: # Create metadata for the batch index_attempt_metadata = IndexAttemptMetadata( attempt_id=index_attempt_id, connector_id=db_connector.id, credential_id=db_credential.id, request_id=make_randomized_onyx_request_id("FSI"), structured_id=f"{tenant_id}:{cc_pair_id}:{index_attempt_id}:{batch_num}", batch_num=batch_num, ) # Upsert documents to PostgreSQL (document table + cc_pair relationship) # This is a subset of what docprocessing does - just DB tracking, no chunking/embedding index_doc_batch_prepare( documents=doc_batch_cleaned, index_attempt_metadata=index_attempt_metadata, db_session=db_session, ignore_time_skip=True, # Documents already filtered during extraction ) # Mark documents as indexed for the CC pair mark_document_as_indexed_for_cc_pair__no_commit( connector_id=db_connector.id, credential_id=db_credential.id, document_ids=[doc.id for doc in doc_batch_cleaned], db_session=db_session, ) db_session.commit() # Update coordination directly (no docprocessing task) with get_session_with_current_tenant() as db_session: IndexingCoordination.update_batch_completion_and_docs( db_session=db_session, index_attempt_id=index_attempt_id, total_docs_indexed=len(doc_batch_cleaned), new_docs_indexed=len(doc_batch_cleaned), total_chunks=0, # No chunks for file system mode ) batch_num += 1 total_doc_batches_queued += 1 logger.info( f"Wrote 
documents to file system: " f"batch_num={batch_num} " f"docs={len(written_paths)} " f"attempt={index_attempt_id}" ) else: # REGULAR mode (default): Full pipeline - store and queue docprocessing batch_storage.store_batch(batch_num, doc_batch_cleaned) # Create processing task data processing_batch_data = { "index_attempt_id": index_attempt_id, "cc_pair_id": cc_pair_id, "tenant_id": tenant_id, "batch_num": batch_num, # 0-indexed } # Queue document processing task app.send_task( OnyxCeleryTask.DOCPROCESSING_TASK, kwargs=processing_batch_data, queue=OnyxCeleryQueues.DOCPROCESSING, priority=docprocessing_priority, ) batch_num += 1 total_doc_batches_queued += 1 logger.info( f"Queued document processing batch: " f"batch_num={batch_num} " f"docs={len(doc_batch_cleaned)} " f"attempt={index_attempt_id}" ) # Check checkpoint size periodically CHECKPOINT_SIZE_CHECK_INTERVAL = 100 if batch_num % CHECKPOINT_SIZE_CHECK_INTERVAL == 0: check_checkpoint_size(checkpoint) # Save latest checkpoint # NOTE: checkpointing is used to track which batches have # been sent to the filestore, NOT which batches have been fully indexed # as it used to be. with get_session_with_current_tenant() as db_session: save_checkpoint( db_session=db_session, index_attempt_id=index_attempt_id, checkpoint=checkpoint, ) elapsed_time = time.monotonic() - start_time logger.info( f"Document extraction completed: " f"attempt={index_attempt_id} " f"batches_queued={total_doc_batches_queued} " f"elapsed={elapsed_time:.2f}s" ) # Set total batches in database to signal extraction completion. # Used by check_for_indexing to determine if the index attempt is complete. 
with get_session_with_current_tenant() as db_session: IndexingCoordination.set_total_batches( db_session=db_session, index_attempt_id=index_attempt_id, total_batches=batch_num, ) # Trigger file sync to user's sandbox (if running) - only for FILE_SYSTEM mode # This syncs the newly written documents from S3 to any running sandbox pod if processing_mode == ProcessingMode.FILE_SYSTEM: creator_id = index_attempt.connector_credential_pair.creator_id if creator_id: source_value = db_connector.source.value app.send_task( OnyxCeleryTask.SANDBOX_FILE_SYNC, kwargs={ "user_id": str(creator_id), "tenant_id": tenant_id, "source": source_value, }, queue=OnyxCeleryQueues.SANDBOX, ) logger.info( f"Triggered sandbox file sync for user {creator_id} source={source_value} after indexing complete" ) except Exception as e: logger.exception( f"Document extraction failed: attempt={index_attempt_id} error={str(e)}" ) # Do NOT clean up batches on failure; future runs will use those batches # while docfetching will continue from the saved checkpoint if one exists if isinstance(e, ConnectorValidationError): # On validation errors during indexing, we want to cancel the indexing attempt # and mark the CCPair as invalid. This prevents the connector from being # used in the future until the credentials are updated. with get_session_with_current_tenant() as db_session_temp: logger.exception( f"Marking attempt {index_attempt_id} as canceled due to validation error." 
) mark_attempt_canceled( index_attempt_id, db_session_temp, reason=f"{CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX}{str(e)}", ) if is_primary: if not index_attempt: # should always be set by now raise RuntimeError("Should never happen.") VALIDATION_ERROR_THRESHOLD = 5 recent_index_attempts = get_recent_completed_attempts_for_cc_pair( cc_pair_id=cc_pair_id, search_settings_id=index_attempt.search_settings_id, limit=VALIDATION_ERROR_THRESHOLD, db_session=db_session_temp, ) num_validation_errors = len( [ index_attempt for index_attempt in recent_index_attempts if index_attempt.error_msg and index_attempt.error_msg.startswith( CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX ) ] ) if num_validation_errors >= VALIDATION_ERROR_THRESHOLD: logger.warning( f"Connector {db_connector.id} has {num_validation_errors} consecutive validation" f" errors. Marking the CC Pair as invalid." ) update_connector_credential_pair( db_session=db_session_temp, connector_id=db_connector.id, credential_id=db_credential.id, status=ConnectorCredentialPairStatus.INVALID, ) raise e elif isinstance(e, ConnectorStopSignal): with get_session_with_current_tenant() as db_session_temp: logger.exception( f"Marking attempt {index_attempt_id} as canceled due to stop signal." ) mark_attempt_canceled( index_attempt_id, db_session_temp, reason=str(e), ) else: with get_session_with_current_tenant() as db_session_temp: # don't overwrite attempts that are already failed/canceled for another reason index_attempt = get_index_attempt(db_session_temp, index_attempt_id) if index_attempt and index_attempt.status in [ IndexingStatus.CANCELED, IndexingStatus.FAILED, ]: logger.info( f"Attempt {index_attempt_id} is already failed/canceled, skipping marking as failed." 
) raise e mark_attempt_failed( index_attempt_id, db_session_temp, failure_reason=str(e), full_exception_trace=traceback.format_exc(), ) raise e finally: memory_tracer.stop() def reissue_old_batches( batch_storage: DocumentBatchStorage, index_attempt_id: int, cc_pair_id: int, tenant_id: str, app: Celery, most_recent_attempt: IndexAttempt | None, priority: OnyxCeleryPriority, ) -> tuple[int, int]: # When loading from a checkpoint, we need to start new docprocessing tasks # tied to the new index attempt for any batches left over in the file store old_batches = batch_storage.get_all_batches_for_cc_pair() batch_storage.update_old_batches_to_new_index_attempt(old_batches) for batch_id in old_batches: logger.info( f"Re-issuing docprocessing task for batch {batch_id} for index attempt {index_attempt_id}" ) path_info = batch_storage.extract_path_info(batch_id) if path_info is None: logger.warning( f"Could not extract path info from batch {batch_id}, skipping" ) continue if path_info.cc_pair_id != cc_pair_id: raise RuntimeError(f"Batch {batch_id} is not for cc pair {cc_pair_id}") app.send_task( OnyxCeleryTask.DOCPROCESSING_TASK, kwargs={ "index_attempt_id": index_attempt_id, "cc_pair_id": cc_pair_id, "tenant_id": tenant_id, "batch_num": path_info.batch_num, # use same batch num as previously }, queue=OnyxCeleryQueues.DOCPROCESSING, priority=priority, ) recent_batches = most_recent_attempt.completed_batches if most_recent_attempt else 0 # resume from the batch num of the last attempt. This should be one more # than the last batch created by docfetching regardless of whether the batch # is still in the filestore waiting for processing or not. 
last_batch_num = len(old_batches) + recent_batches logger.info( f"Starting from batch {last_batch_num} due to re-issued batches: {old_batches}, completed batches: {recent_batches}" ) return len(old_batches), recent_batches ================================================ FILE: backend/onyx/background/periodic_poller.py ================================================ """Periodic poller for NO_VECTOR_DB deployments. Replaces Celery Beat and background workers with a lightweight daemon thread that runs from the API server process. Two responsibilities: 1. Recovery polling (every 30 s): re-processes user files stuck in PROCESSING / DELETING / needs_sync states via the drain loops defined in ``task_utils.py``. 2. Periodic task execution (configurable intervals): runs LLM model updates, expired cache-entry cleanup (only when ``CACHE_BACKEND`` is ``postgres``), and scheduled evals at their configured cadences, deduplicated across multiple API server instances via a Postgres advisory lock plus a ``KVStore`` last-claimed timestamp. """ import threading import time from collections.abc import Callable from dataclasses import dataclass from dataclasses import field from onyx.utils.logger import setup_logger logger = setup_logger() RECOVERY_INTERVAL_SECONDS = 30 PERIODIC_TASK_LOCK_BASE = 20_000 PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:" # ------------------------------------------------------------------ # Periodic task definitions # ------------------------------------------------------------------ _NEVER_RAN: float = -1e18 @dataclass class _PeriodicTaskDef: name: str interval_seconds: float lock_id: int run_fn: Callable[[], None] last_run_at: float = field(default=_NEVER_RAN) def _run_auto_llm_update() -> None: from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL if not AUTO_LLM_CONFIG_URL: return from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.llm.well_known_providers.auto_update_service import ( sync_llm_models_from_github, ) with get_session_with_current_tenant() as db_session: sync_llm_models_from_github(db_session) def
_run_cache_cleanup() -> None: from onyx.cache.postgres_backend import cleanup_expired_cache_entries cleanup_expired_cache_entries() def _run_scheduled_eval() -> None: from onyx.configs.app_configs import BRAINTRUST_API_KEY from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT if not all( [ BRAINTRUST_API_KEY, SCHEDULED_EVAL_PROJECT, SCHEDULED_EVAL_DATASET_NAMES, SCHEDULED_EVAL_PERMISSIONS_EMAIL, ] ): return from datetime import datetime from datetime import timezone from onyx.evals.eval import run_eval from onyx.evals.models import EvalConfigurationOptions run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d") for dataset_name in SCHEDULED_EVAL_DATASET_NAMES: try: run_eval( configuration=EvalConfigurationOptions( search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL, dataset_name=dataset_name, no_send_logs=False, braintrust_project=SCHEDULED_EVAL_PROJECT, experiment_name=f"{dataset_name} - {run_timestamp}", ), remote_dataset_name=dataset_name, ) except Exception: logger.exception( f"Periodic poller - Failed scheduled eval for dataset {dataset_name}" ) _CACHE_CLEANUP_INTERVAL_SECONDS = 300 def _build_periodic_tasks() -> list[_PeriodicTaskDef]: from onyx.cache.interface import CacheBackendType from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS from onyx.configs.app_configs import CACHE_BACKEND from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES tasks: list[_PeriodicTaskDef] = [] if CACHE_BACKEND == CacheBackendType.POSTGRES: tasks.append( _PeriodicTaskDef( name="cache-cleanup", interval_seconds=_CACHE_CLEANUP_INTERVAL_SECONDS, lock_id=PERIODIC_TASK_LOCK_BASE + 2, run_fn=_run_cache_cleanup, ) ) if AUTO_LLM_CONFIG_URL: tasks.append( _PeriodicTaskDef( name="auto-llm-update", interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS, 
lock_id=PERIODIC_TASK_LOCK_BASE, run_fn=_run_auto_llm_update, ) ) if SCHEDULED_EVAL_DATASET_NAMES: tasks.append( _PeriodicTaskDef( name="scheduled-eval", interval_seconds=7 * 24 * 3600, lock_id=PERIODIC_TASK_LOCK_BASE + 1, run_fn=_run_scheduled_eval, ) ) return tasks # ------------------------------------------------------------------ # Periodic task runner with advisory-lock-guarded claim # ------------------------------------------------------------------ def _try_claim_task(task_def: _PeriodicTaskDef) -> bool: """Atomically check whether *task_def* should run and record a claim. Uses a transaction-scoped advisory lock for atomicity combined with a ``KVStore`` timestamp for cross-instance dedup. The DB session is held only for this brief claim transaction, not during task execution. """ from datetime import datetime from datetime import timezone from sqlalchemy import text from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import KVStore kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name with get_session_with_current_tenant() as db_session: acquired = db_session.execute( text("SELECT pg_try_advisory_xact_lock(:id)"), {"id": task_def.lock_id}, ).scalar() if not acquired: return False row = db_session.query(KVStore).filter_by(key=kv_key).first() if row and row.value is not None: last_claimed = datetime.fromisoformat(str(row.value)) elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds() if elapsed < task_def.interval_seconds: return False now_ts = datetime.now(timezone.utc).isoformat() if row: row.value = now_ts else: db_session.add(KVStore(key=kv_key, value=now_ts)) db_session.commit() return True def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None: """Run *task_def* if its interval has elapsed and no peer holds the lock.""" now = time.monotonic() if now - task_def.last_run_at < task_def.interval_seconds: return if not _try_claim_task(task_def): return try: task_def.run_fn() task_def.last_run_at = now 
except Exception: logger.exception( f"Periodic poller - Error running periodic task {task_def.name}" ) # ------------------------------------------------------------------ # Recovery / drain loop runner # ------------------------------------------------------------------ def _run_drain_loops(tenant_id: str) -> None: from onyx.background.task_utils import drain_delete_loop from onyx.background.task_utils import drain_processing_loop from onyx.background.task_utils import drain_project_sync_loop drain_processing_loop(tenant_id) drain_delete_loop(tenant_id) drain_project_sync_loop(tenant_id) # ------------------------------------------------------------------ # Startup recovery (10g) # ------------------------------------------------------------------ def recover_stuck_user_files(tenant_id: str) -> None: """Run all drain loops once to re-process files left in intermediate states. Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set. """ logger.info("recover_stuck_user_files - Checking for stuck user files") try: _run_drain_loops(tenant_id) except Exception: logger.exception("recover_stuck_user_files - Error during recovery") # ------------------------------------------------------------------ # Daemon thread (10f) # ------------------------------------------------------------------ _shutdown_event = threading.Event() _poller_thread: threading.Thread | None = None def _poller_loop(tenant_id: str) -> None: from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) periodic_tasks = _build_periodic_tasks() logger.info( f"Periodic poller started with {len(periodic_tasks)} periodic task(s): {[t.name for t in periodic_tasks]}" ) while not _shutdown_event.is_set(): try: _run_drain_loops(tenant_id) except Exception: logger.exception("Periodic poller - Error in recovery polling") for task_def in periodic_tasks: try: _try_run_periodic_task(task_def) except Exception: logger.exception( f"Periodic poller - 
Unhandled error checking task {task_def.name}" ) _shutdown_event.wait(RECOVERY_INTERVAL_SECONDS) def start_periodic_poller(tenant_id: str) -> None: """Start the periodic poller daemon thread.""" global _poller_thread # noqa: PLW0603 _shutdown_event.clear() _poller_thread = threading.Thread( target=_poller_loop, args=(tenant_id,), daemon=True, name="no-vectordb-periodic-poller", ) _poller_thread.start() logger.info("Periodic poller thread started") def stop_periodic_poller() -> None: """Signal the periodic poller to stop and wait for it to exit.""" global _poller_thread # noqa: PLW0603 if _poller_thread is None: return _shutdown_event.set() _poller_thread.join(timeout=10) if _poller_thread.is_alive(): logger.warning("Periodic poller thread did not stop within timeout") _poller_thread = None logger.info("Periodic poller thread stopped") ================================================ FILE: backend/onyx/background/task_utils.py ================================================ """Background task utilities. Contains query-history report helpers (used by all deployment modes) and in-process background task execution helpers for NO_VECTOR_DB mode: - Atomic claim-and-mark helpers that prevent duplicate processing - Drain loops that process all pending user file work Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT. After the commit the row lock is released, but the row is no longer eligible for re-claiming. No long-lived sessions or advisory locks. 
""" from uuid import UUID import sqlalchemy as sa from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.enums import UserFileStatus from onyx.db.models import UserFile from onyx.utils.logger import setup_logger logger = setup_logger() # ------------------------------------------------------------------ # Query-history report helpers (pre-existing, used by all modes) # ------------------------------------------------------------------ QUERY_REPORT_NAME_PREFIX = "query-history" def construct_query_history_report_name( task_id: str, ) -> str: return f"{QUERY_REPORT_NAME_PREFIX}-{task_id}.csv" def extract_task_id_from_query_history_report_name(name: str) -> str: return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv") # ------------------------------------------------------------------ # Atomic claim-and-mark helpers # ------------------------------------------------------------------ # Each function runs inside a single short-lived session/transaction: # 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row) # 2. UPDATE the row so it is no longer eligible (PROCESSING claims only: # the DELETING and sync claims skip this step and rely on the impl to # make the row ineligible on success, plus *exclude_ids* for failures) # 3. COMMIT (releases the row lock) # After the commit, no other drain loop can claim a PROCESSING row that # was moved to INDEXING in step 2. def _claim_next_processing_file(db_session: Session) -> UUID | None: """Claim the next PROCESSING file by transitioning it to INDEXING. Returns the file id, or None when no eligible files remain. """ file_id = db_session.execute( select(UserFile.id) .where(UserFile.status == UserFileStatus.PROCESSING) .order_by(UserFile.created_at) .limit(1) .with_for_update(skip_locked=True) ).scalar_one_or_none() if file_id is None: return None db_session.execute( sa.update(UserFile) .where(UserFile.id == file_id) .values(status=UserFileStatus.INDEXING) ) db_session.commit() return file_id def _claim_next_deleting_file( db_session: Session, exclude_ids: set[UUID] | None = None, ) -> UUID | None: """Claim the next DELETING file.
No status transition needed — the impl deletes the row on success. The short-lived FOR UPDATE lock prevents concurrent claims. *exclude_ids* prevents re-processing the same file if the impl fails. """ stmt = ( select(UserFile.id) .where(UserFile.status == UserFileStatus.DELETING) .order_by(UserFile.created_at) .limit(1) .with_for_update(skip_locked=True) ) if exclude_ids: stmt = stmt.where(UserFile.id.notin_(exclude_ids)) file_id = db_session.execute(stmt).scalar_one_or_none() db_session.commit() return file_id def _claim_next_sync_file( db_session: Session, exclude_ids: set[UUID] | None = None, ) -> UUID | None: """Claim the next file needing project/persona sync. No status transition needed — the impl clears the sync flags on success. The short-lived FOR UPDATE lock prevents concurrent claims. *exclude_ids* prevents re-processing the same file if the impl fails. """ stmt = ( select(UserFile.id) .where( sa.and_( sa.or_( UserFile.needs_project_sync.is_(True), UserFile.needs_persona_sync.is_(True), ), UserFile.status == UserFileStatus.COMPLETED, ) ) .order_by(UserFile.created_at) .limit(1) .with_for_update(skip_locked=True) ) if exclude_ids: stmt = stmt.where(UserFile.id.notin_(exclude_ids)) file_id = db_session.execute(stmt).scalar_one_or_none() db_session.commit() return file_id # ------------------------------------------------------------------ # Drain loops — process *all* pending work of each type # ------------------------------------------------------------------ def drain_processing_loop(tenant_id: str) -> None: """Process all pending PROCESSING user files.""" from onyx.background.celery.tasks.user_file_processing.tasks import ( process_user_file_impl, ) from onyx.db.engine.sql_engine import get_session_with_current_tenant while True: with get_session_with_current_tenant() as session: file_id = _claim_next_processing_file(session) if file_id is None: break try: process_user_file_impl( user_file_id=str(file_id), tenant_id=tenant_id, redis_locking=False, ) 
except Exception: logger.exception(f"Failed to process user file {file_id}") def drain_delete_loop(tenant_id: str) -> None: """Delete all pending DELETING user files.""" from onyx.background.celery.tasks.user_file_processing.tasks import ( delete_user_file_impl, ) from onyx.db.engine.sql_engine import get_session_with_current_tenant failed: set[UUID] = set() while True: with get_session_with_current_tenant() as session: file_id = _claim_next_deleting_file(session, exclude_ids=failed) if file_id is None: break try: delete_user_file_impl( user_file_id=str(file_id), tenant_id=tenant_id, redis_locking=False, ) except Exception: logger.exception(f"Failed to delete user file {file_id}") failed.add(file_id) def drain_project_sync_loop(tenant_id: str) -> None: """Sync all pending project/persona metadata for user files.""" from onyx.background.celery.tasks.user_file_processing.tasks import ( project_sync_user_file_impl, ) from onyx.db.engine.sql_engine import get_session_with_current_tenant failed: set[UUID] = set() while True: with get_session_with_current_tenant() as session: file_id = _claim_next_sync_file(session, exclude_ids=failed) if file_id is None: break try: project_sync_user_file_impl( user_file_id=str(file_id), tenant_id=tenant_id, redis_locking=False, ) except Exception: logger.exception(f"Failed to sync user file {file_id}") failed.add(file_id) ================================================ FILE: backend/onyx/cache/factory.py ================================================ from collections.abc import Callable from onyx.cache.interface import CacheBackend from onyx.cache.interface import CacheBackendType from onyx.configs.app_configs import CACHE_BACKEND def _build_redis_backend(tenant_id: str) -> CacheBackend: from onyx.cache.redis_backend import RedisCacheBackend from onyx.redis.redis_pool import redis_pool return RedisCacheBackend(redis_pool.get_client(tenant_id)) def _build_postgres_backend(tenant_id: str) -> CacheBackend: from 
onyx.cache.postgres_backend import PostgresCacheBackend return PostgresCacheBackend(tenant_id) _BACKEND_BUILDERS: dict[CacheBackendType, Callable[[str], CacheBackend]] = { CacheBackendType.REDIS: _build_redis_backend, CacheBackendType.POSTGRES: _build_postgres_backend, } def get_cache_backend(*, tenant_id: str | None = None) -> CacheBackend: """Return a tenant-aware ``CacheBackend``. If *tenant_id* is ``None``, the current tenant is read from the thread-local context variable (same behaviour as ``get_redis_client``). """ if tenant_id is None: from shared_configs.contextvars import get_current_tenant_id tenant_id = get_current_tenant_id() builder = _BACKEND_BUILDERS.get(CACHE_BACKEND) if builder is None: raise ValueError( f"Unsupported CACHE_BACKEND={CACHE_BACKEND!r}. Supported values: {[t.value for t in CacheBackendType]}" ) return builder(tenant_id) def get_shared_cache_backend() -> CacheBackend: """Return a ``CacheBackend`` in the shared (cross-tenant) namespace.""" from shared_configs.configs import DEFAULT_REDIS_PREFIX return get_cache_backend(tenant_id=DEFAULT_REDIS_PREFIX) ================================================ FILE: backend/onyx/cache/interface.py ================================================ import abc from enum import Enum from redis.exceptions import RedisError from sqlalchemy.exc import SQLAlchemyError TTL_KEY_NOT_FOUND = -2 TTL_NO_EXPIRY = -1 CACHE_TRANSIENT_ERRORS: tuple[type[Exception], ...] = (RedisError, SQLAlchemyError) """Exception types that represent transient cache connectivity / operational failures. Callers that want to fail-open (or fail-closed) on cache errors should catch this tuple instead of bare ``Exception``. 
When adding a new ``CacheBackend`` implementation, add its transient error base class(es) here so all call-sites pick it up automatically.""" class CacheBackendType(str, Enum): REDIS = "redis" POSTGRES = "postgres" class CacheLock(abc.ABC): """Abstract distributed lock returned by CacheBackend.lock().""" @abc.abstractmethod def acquire( self, blocking: bool = True, blocking_timeout: float | None = None, ) -> bool: raise NotImplementedError @abc.abstractmethod def release(self) -> None: raise NotImplementedError @abc.abstractmethod def owned(self) -> bool: raise NotImplementedError def __enter__(self) -> "CacheLock": if not self.acquire(): raise RuntimeError("Failed to acquire lock") return self def __exit__(self, *args: object) -> None: self.release() class CacheBackend(abc.ABC): """Thin abstraction over a key-value cache with TTL, locks, and blocking lists. Covers the subset of Redis operations used outside of Celery. When CACHE_BACKEND=postgres, a PostgreSQL-backed implementation is used instead. """ # -- basic key/value --------------------------------------------------- @abc.abstractmethod def get(self, key: str) -> bytes | None: raise NotImplementedError @abc.abstractmethod def set( self, key: str, value: str | bytes | int | float, ex: int | None = None, ) -> None: raise NotImplementedError @abc.abstractmethod def delete(self, key: str) -> None: raise NotImplementedError @abc.abstractmethod def exists(self, key: str) -> bool: raise NotImplementedError # -- TTL --------------------------------------------------------------- @abc.abstractmethod def expire(self, key: str, seconds: int) -> None: raise NotImplementedError @abc.abstractmethod def ttl(self, key: str) -> int: """Return remaining TTL in seconds. Returns ``TTL_NO_EXPIRY`` (-1) if key exists without expiry, ``TTL_KEY_NOT_FOUND`` (-2) if key is missing or expired. 
""" raise NotImplementedError # -- distributed lock -------------------------------------------------- @abc.abstractmethod def lock(self, name: str, timeout: float | None = None) -> CacheLock: raise NotImplementedError # -- blocking list (used by MCP OAuth BLPOP pattern) ------------------- @abc.abstractmethod def rpush(self, key: str, value: str | bytes) -> None: raise NotImplementedError @abc.abstractmethod def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None: """Block until a value is available on one of *keys*, or *timeout* expires. Returns ``(key, value)`` or ``None`` on timeout. """ raise NotImplementedError ================================================ FILE: backend/onyx/cache/postgres_backend.py ================================================ """PostgreSQL-backed ``CacheBackend`` for NO_VECTOR_DB deployments. Uses the ``cache_store`` table for key-value storage, PostgreSQL advisory locks for distributed locking, and a polling loop for the BLPOP pattern. """ import hashlib import struct import time import uuid from contextlib import AbstractContextManager from datetime import datetime from datetime import timedelta from datetime import timezone from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy import update from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session from onyx.cache.interface import CacheBackend from onyx.cache.interface import CacheLock from onyx.cache.interface import TTL_KEY_NOT_FOUND from onyx.cache.interface import TTL_NO_EXPIRY from onyx.db.models import CacheStore _LIST_KEY_PREFIX = "_q:" # ASCII: ':' (0x3A) < ';' (0x3B). Upper bound for range queries so [prefix+, prefix;) # captures all list-item keys (e.g. _q:mylist:123:uuid) without including other # lists whose names share a prefix (e.g. _q:mylist2:...). 
_LIST_KEY_RANGE_TERMINATOR = ";" _LIST_ITEM_TTL_SECONDS = 3600 _LOCK_POLL_INTERVAL = 0.1 _BLPOP_POLL_INTERVAL = 0.25 def _list_item_key(key: str) -> str: """Unique key for a list item. Timestamp for FIFO ordering; UUID prevents collision when concurrent rpush calls occur within the same nanosecond. """ return f"{_LIST_KEY_PREFIX}{key}:{time.time_ns()}:{uuid.uuid4().hex}" def _to_bytes(value: str | bytes | int | float) -> bytes: if isinstance(value, bytes): return value return str(value).encode() # ------------------------------------------------------------------ # Lock # ------------------------------------------------------------------ class PostgresCacheLock(CacheLock): """Advisory-lock-based distributed lock. Uses ``get_session_with_tenant`` for connection lifecycle. The lock is tied to the session's connection; releasing or closing the session frees it. NOTE: Unlike Redis locks, advisory locks do not auto-expire after ``timeout`` seconds. They are released when ``release()`` is called or when the session is closed. 
""" def __init__(self, lock_id: int, timeout: float | None, tenant_id: str) -> None: self._lock_id = lock_id self._timeout = timeout self._tenant_id = tenant_id self._session_cm: AbstractContextManager[Session] | None = None self._session: Session | None = None self._acquired = False def acquire( self, blocking: bool = True, blocking_timeout: float | None = None, ) -> bool: from onyx.db.engine.sql_engine import get_session_with_tenant self._session_cm = get_session_with_tenant(tenant_id=self._tenant_id) self._session = self._session_cm.__enter__() try: if not blocking: return self._try_lock() effective_timeout = blocking_timeout or self._timeout deadline = ( (time.monotonic() + effective_timeout) if effective_timeout else None ) while True: if self._try_lock(): return True if deadline is not None and time.monotonic() >= deadline: return False time.sleep(_LOCK_POLL_INTERVAL) finally: if not self._acquired: self._close_session() def release(self) -> None: if not self._acquired or self._session is None: return try: self._session.execute(select(func.pg_advisory_unlock(self._lock_id))) finally: self._acquired = False self._close_session() def owned(self) -> bool: return self._acquired def _close_session(self) -> None: if self._session_cm is not None: try: self._session_cm.__exit__(None, None, None) finally: self._session_cm = None self._session = None def _try_lock(self) -> bool: assert self._session is not None result = self._session.execute( select(func.pg_try_advisory_lock(self._lock_id)) ).scalar() if result: self._acquired = True return True return False # ------------------------------------------------------------------ # Backend # ------------------------------------------------------------------ class PostgresCacheBackend(CacheBackend): """``CacheBackend`` backed by the ``cache_store`` table in PostgreSQL. Each operation opens and closes its own database session so the backend is safe to share across threads. 
Tenant isolation is handled by SQLAlchemy's ``schema_translate_map`` (set by ``get_session_with_tenant``). """ def __init__(self, tenant_id: str) -> None: self._tenant_id = tenant_id # -- basic key/value --------------------------------------------------- def get(self, key: str) -> bytes | None: from onyx.db.engine.sql_engine import get_session_with_tenant stmt = select(CacheStore.value).where( CacheStore.key == key, or_(CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now()), ) with get_session_with_tenant(tenant_id=self._tenant_id) as session: value = session.execute(stmt).scalar_one_or_none() if value is None: return None return bytes(value) def set( self, key: str, value: str | bytes | int | float, ex: int | None = None, ) -> None: from onyx.db.engine.sql_engine import get_session_with_tenant value_bytes = _to_bytes(value) expires_at = ( datetime.now(timezone.utc) + timedelta(seconds=ex) if ex is not None else None ) stmt = ( pg_insert(CacheStore) .values(key=key, value=value_bytes, expires_at=expires_at) .on_conflict_do_update( index_elements=[CacheStore.key], set_={"value": value_bytes, "expires_at": expires_at}, ) ) with get_session_with_tenant(tenant_id=self._tenant_id) as session: session.execute(stmt) session.commit() def delete(self, key: str) -> None: from onyx.db.engine.sql_engine import get_session_with_tenant with get_session_with_tenant(tenant_id=self._tenant_id) as session: session.execute(delete(CacheStore).where(CacheStore.key == key)) session.commit() def exists(self, key: str) -> bool: from onyx.db.engine.sql_engine import get_session_with_tenant stmt = ( select(CacheStore.key) .where( CacheStore.key == key, or_( CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now(), ), ) .limit(1) ) with get_session_with_tenant(tenant_id=self._tenant_id) as session: return session.execute(stmt).first() is not None # -- TTL --------------------------------------------------------------- def expire(self, key: str, seconds: int) -> 
None: from onyx.db.engine.sql_engine import get_session_with_tenant new_exp = datetime.now(timezone.utc) + timedelta(seconds=seconds) stmt = ( update(CacheStore).where(CacheStore.key == key).values(expires_at=new_exp) ) with get_session_with_tenant(tenant_id=self._tenant_id) as session: session.execute(stmt) session.commit() def ttl(self, key: str) -> int: from onyx.db.engine.sql_engine import get_session_with_tenant stmt = select(CacheStore.expires_at).where(CacheStore.key == key) with get_session_with_tenant(tenant_id=self._tenant_id) as session: result = session.execute(stmt).first() if result is None: return TTL_KEY_NOT_FOUND expires_at: datetime | None = result[0] if expires_at is None: return TTL_NO_EXPIRY remaining = (expires_at - datetime.now(timezone.utc)).total_seconds() if remaining <= 0: return TTL_KEY_NOT_FOUND return int(remaining) # -- distributed lock -------------------------------------------------- def lock(self, name: str, timeout: float | None = None) -> CacheLock: return PostgresCacheLock( self._lock_id_for(name), timeout, tenant_id=self._tenant_id ) # -- blocking list (MCP OAuth BLPOP pattern) --------------------------- def rpush(self, key: str, value: str | bytes) -> None: self.set(_list_item_key(key), value, ex=_LIST_ITEM_TTL_SECONDS) def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None: if timeout <= 0: raise ValueError( "PostgresCacheBackend.blpop requires timeout > 0. " "timeout=0 would block the calling thread indefinitely " "with no way to interrupt short of process termination." 
) from onyx.db.engine.sql_engine import get_session_with_tenant deadline = time.monotonic() + timeout while True: for key in keys: lower = f"{_LIST_KEY_PREFIX}{key}:" upper = f"{_LIST_KEY_PREFIX}{key}{_LIST_KEY_RANGE_TERMINATOR}" stmt = ( select(CacheStore) .where( CacheStore.key >= lower, CacheStore.key < upper, or_( CacheStore.expires_at.is_(None), CacheStore.expires_at > func.now(), ), ) .order_by(CacheStore.key) .limit(1) .with_for_update(skip_locked=True) ) with get_session_with_tenant(tenant_id=self._tenant_id) as session: row = session.execute(stmt).scalars().first() if row is not None: value = bytes(row.value) if row.value else b"" session.delete(row) session.commit() return (key.encode(), value) if time.monotonic() >= deadline: return None time.sleep(_BLPOP_POLL_INTERVAL) # -- helpers ----------------------------------------------------------- def _lock_id_for(self, name: str) -> int: """Map *name* to a 64-bit signed int for ``pg_advisory_lock``.""" h = hashlib.md5( f"{self._tenant_id}:{name}".encode(), usedforsecurity=False ).digest() return struct.unpack("q", h[:8])[0] # ------------------------------------------------------------------ # Periodic cleanup # ------------------------------------------------------------------ def cleanup_expired_cache_entries() -> None: """Delete rows whose ``expires_at`` is in the past. Called by the periodic poller every 5 minutes. 
""" from onyx.db.engine.sql_engine import get_session_with_current_tenant with get_session_with_current_tenant() as session: session.execute( delete(CacheStore).where( CacheStore.expires_at.is_not(None), CacheStore.expires_at < func.now(), ) ) session.commit() ================================================ FILE: backend/onyx/cache/redis_backend.py ================================================ from typing import cast from redis.client import Redis from redis.lock import Lock as RedisLock from onyx.cache.interface import CacheBackend from onyx.cache.interface import CacheLock class RedisCacheLock(CacheLock): """Wraps ``redis.lock.Lock`` behind the ``CacheLock`` interface.""" def __init__(self, lock: RedisLock) -> None: self._lock = lock def acquire( self, blocking: bool = True, blocking_timeout: float | None = None, ) -> bool: return bool( self._lock.acquire( blocking=blocking, blocking_timeout=blocking_timeout, ) ) def release(self) -> None: self._lock.release() def owned(self) -> bool: return bool(self._lock.owned()) class RedisCacheBackend(CacheBackend): """``CacheBackend`` implementation that delegates to a ``redis.Redis`` client. This is a thin pass-through — every method maps 1-to-1 to the underlying Redis command. ``TenantRedis`` key-prefixing is handled by the client itself (provided by ``get_redis_client``). 
""" def __init__(self, redis_client: Redis) -> None: self._r = redis_client # -- basic key/value --------------------------------------------------- def get(self, key: str) -> bytes | None: val = self._r.get(key) if val is None: return None if isinstance(val, bytes): return val return str(val).encode() def set( self, key: str, value: str | bytes | int | float, ex: int | None = None, ) -> None: self._r.set(key, value, ex=ex) def delete(self, key: str) -> None: self._r.delete(key) def exists(self, key: str) -> bool: return bool(self._r.exists(key)) # -- TTL --------------------------------------------------------------- def expire(self, key: str, seconds: int) -> None: self._r.expire(key, seconds) def ttl(self, key: str) -> int: return cast(int, self._r.ttl(key)) # -- distributed lock -------------------------------------------------- def lock(self, name: str, timeout: float | None = None) -> CacheLock: return RedisCacheLock(self._r.lock(name, timeout=timeout)) # -- blocking list (MCP OAuth BLPOP pattern) --------------------------- def rpush(self, key: str, value: str | bytes) -> None: self._r.rpush(key, value) def blpop(self, keys: list[str], timeout: int = 0) -> tuple[bytes, bytes] | None: result = cast(list[bytes] | None, self._r.blpop(keys, timeout=timeout)) if result is None: return None return (result[0], result[1]) ================================================ FILE: backend/onyx/chat/COMPRESSION.md ================================================ # Chat History Compression Compresses long chat histories by summarizing older messages while keeping recent ones verbatim. ## Architecture Decisions ### Branch-Aware via Tree Structure Summaries are stored as `ChatMessage` records with two key fields: - `parent_message_id` → last message when compression triggered (places summary in the tree) - `last_summarized_message_id` → pointer to an older message up the chain (the cutoff). Messages after this are kept verbatim. 

**Why store the summary as a separate message?** If we embedded the summary in the `last_summarized_message_id` message itself, that message would contain context from messages that came after it, context that doesn't exist in other branches. Creating the summary as a new message attached to the branch tip means it applies only to the branch where compression occurred; only that branch points back to it. This design is necessary both because we keep the last few messages verbatim and because chats can branch.

### Progressive Summarization

Subsequent compressions incorporate the existing summary text plus the new messages, preventing information loss in very long conversations.

### Cutoff Marker Prompt Strategy

The LLM receives older messages, a cutoff marker, then recent messages. It summarizes only the content before the marker, using the recent context to decide what is important.

## Token Budget

Context window breakdown:
- `max_context_tokens`: the LLM's total context window
- `reserved_tokens`: space for the system prompt, tools, files, etc.
- Available for chat history = `max_context_tokens - reserved_tokens`

Note: If many tokens are reserved, chat compression may trigger frequently, which is costly, slow, and leads to a bad user experience. This is a possible area of future improvement.

Configurable ratios:
- `COMPRESSION_TRIGGER_RATIO` (default 0.75): compress when chat history exceeds this ratio of the available space
- `RECENT_MESSAGES_RATIO` (default 0.2): portion of chat history to keep verbatim when compressing

## Flow

1. Trigger when `history_tokens > available * COMPRESSION_TRIGGER_RATIO` (0.75 by default)
2. Find the existing summary for the branch (if any)
3. Split messages: older (summarize) / recent (keep verbatim, sized by `RECENT_MESSAGES_RATIO`)
4. Generate the summary via the LLM
5. Save as `ChatMessage` with `parent_message_id` + `last_summarized_message_id`

## Key Functions

| Function | Purpose |
|----------|---------|
| `get_compression_params` | Check whether compression is needed based on token counts |
| `find_summary_for_branch` | Find the applicable summary by checking `parent_message_id` membership |
| `get_messages_to_summarize` | Split messages at the token budget boundary |
| `compress_chat_history` | Orchestrate the flow and save the summary message |

================================================
FILE: backend/onyx/chat/README.md
================================================
# Overview of Context Management

This document reviews some design decisions around the main agent loop powering Onyx's chat flow. All engineers contributing to this flow are strongly encouraged to be familiar with the concepts here.

> Note: it is assumed the reader is familiar with the Onyx product and features such as Projects, User files, Citations, etc.

## System Prompt

The system prompt is a default prompt that ships with the system. Users can edit the default prompt, and the edited version is persisted in the database. Some parts of the system prompt are dynamically updated or inserted:
- Datetime of the message sent
- Tool descriptions of when to use certain tools, depending on whether the tool is available in that cycle
- If the user has just called a search-related tool, a section about citations

## Custom Agent Prompt

The custom agent prompt is inserted as a user message above the most recent user message; it is dynamically moved in the history as the user sends more messages. If the user has opted to completely replace the System Prompt, then the Custom Agent prompt replaces the system prompt and does not move along the history.

## How Files are handled

On upload, files are processed for tokens; if a file has too many tokens to fit in the context, it is considered a failed inclusion. This is done using the LLM tokenizer.
- In many cases there is no known tokenizer for a given LLM, so a default tokenizer is used as a catch-all.
- File upload happens in 2 parts: the actual upload and the token counting.
- Files are added to the chat context as a “point in time” inclusion and move up the context window as the conversation progresses.

Every file knows its token count (model agnostic); image files are assigned an assumed number of tokens. Image files are also attached to User Messages as point-in-time inclusions.

**Future Extension**: Files selected from the search results are also counted as “point in time” inclusions. Files that are too large cannot be selected. For most connectors, the "entire file" does not exist for these files; it is pieced back together from the search engine.

## Projects

If a Project contains few enough files that they all fit in the model context, we keep them close enough to the latest messages in the history to ensure they are easy for the LLM to access. Note that the project documents are assumed to be quite useful, and that:
1. they should never be dropped from context, and
2. their use is not just a needle-in-a-haystack search where a strong keyword makes the LLM attend to them.

Project files are vectorized and stored in the Search Engine so that, if the user chooses a model with less context than the number of tokens in the project, the system can RAG over the project files.

## How documents are represented

Documents from search or uploaded Project files are represented as JSON so that the LLM can easily understand them. The JSON is preceded by a prefix string to make the context clearer to the LLM. For search results (web or internal), only the JSON is used, and it is sent as a Tool Call type of message rather than a user message.
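
The representation described above can be sketched with a small helper. This is a hypothetical function, not Onyx's actual code: the prefix string and field names mirror the documented example, and omitting `metadata` when it is empty is an assumption based on that example.

```python
import json

DOCUMENTS_PREFIX = (
    "Here are some documents provided for context, they may not all be relevant:"
)


def build_documents_message(docs: list[dict[str, str]], with_prefix: bool = True) -> str:
    """Number documents from 1 so the LLM can cite each with a single integer.

    Field order follows the example: the short citation number, title and
    metadata come first, and the long contents field comes last.
    """
    entries = []
    for i, doc in enumerate(docs, start=1):
        entry: dict[str, object] = {"document": i, "title": doc["title"]}
        if doc.get("metadata"):
            entry["metadata"] = doc["metadata"]
        entry["contents"] = doc["contents"]
        entries.append(entry)
    body = json.dumps({"documents": entries}, indent=2)
    # Search results (web or internal) would be sent as bare JSON in a tool message.
    return f"{DOCUMENTS_PREFIX}\n{body}" if with_prefix else body
```

Because `json.dumps` preserves dict insertion order, the short fields reliably precede the long `contents` field.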
```
Here are some documents provided for context, they may not all be relevant:
{
  "documents": [
    {"document": 1, "title": "Hello", "metadata": "status closed", "contents": "Foo"},
    {"document": 2, "title": "World", "contents": "Bar"}
  ]
}
```

Documents are represented with the `document` key so that the LLM can easily cite them with a single number. The tool returns have to be richer so they can be translated into links and other UI elements; what the LLM sees is kept far simpler to reduce noise and hallucinations.

Documents included in a single turn should be collapsed into a single user message.

Search tools also give URLs to the LLM so that `open_url` (a separate tool) can be called on them.

## Reminders

To ensure the LLM follows certain specific instructions, those instructions are added at the very end of the chat context as a user message. If a search-related tool is used, a citation reminder is always added; otherwise, by default, there is no reminder. If the user configures reminders, those are added to the final message. If a search-related tool just ran and the user has reminders, both appear in a single message. If a search-related tool is called at any point during the turn, the reminder remains at the end until the turn is over and the agent has responded.

## Tool Calls

Because tool call responses can get very long (an internal search can return many thousands of tokens), tool responses are currently replaced with a hardcoded string stating that the response is no longer available. Tool Call details like the search query and other arguments are kept in the history, since they are information rich and generally only a few tokens.

> Note: in the Internal Search flow with query expansion, the Tool Call that was actually run differs from what the LLM provided as arguments.
> What the LLM sees in the history (to be most informative for future calls) is the full set of expanded queries.
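
A minimal sketch of the tool-response truncation described above, assuming a simplified message shape. `HistoryMessage`, its `role` values, the `turn` field, and the placeholder text are hypothetical, not Onyx's actual models or prompt constants. The sketch also assumes that responses produced during the current turn stay intact while older ones are replaced; tool-call arguments are never touched.

```python
from dataclasses import dataclass, replace

# Hypothetical placeholder; the real string lives in Onyx's prompt constants.
TOOL_RESPONSE_PLACEHOLDER = "This tool response is no longer available."


@dataclass(frozen=True)
class HistoryMessage:
    role: str  # e.g. "user", "assistant", "tool_call", "tool_response"
    content: str
    turn: int


def truncate_old_tool_responses(
    history: list[HistoryMessage], current_turn: int
) -> list[HistoryMessage]:
    """Replace tool responses from earlier turns with a short placeholder.

    Tool calls (queries and other arguments) are kept as-is because they are
    information rich and cheap in tokens. The input list is not mutated.
    """
    return [
        replace(m, content=TOOL_RESPONSE_PLACEHOLDER)
        if m.role == "tool_response" and m.turn < current_turn
        else m
        for m in history
    ]
```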
**Possible Future Extension**: Instead of dropping the Tool Call response, we might summarize it with an LLM into 1-2 sentences that capture the main points. That said, this is of questionable value, because anything relevant and useful should already be captured in the Agent response.

## Examples

```
S  -> System Message
CA -> Custom Agent as a User Message
A  -> Agent Message response to user
U  -> User Message
TC -> Agent Message for a tool call
TR -> Tool response
R  -> Reminder
F  -> Point in time File
P  -> Project Files (not overflowed case)
1,2,3 etc. to represent turn number. A turn consists of a user input and a final response from the Agent

Flow with Custom Agent
S, U1, TC, TR, A1, CA, U2, A2
-- user sends another message, triggers tool call ->
S, U1, TC, TR, A1, U2, A2, CA, U3, TC, TR, R, A3
- Custom Agent prompt moves
- Reminder inserted after TR

Flow with Project and File Upload
S, CA, P, F, U1, A1
-- user sends another message ->
S, F, U1, A1, CA, P, U2, A2
- File stays in place, above the user message
- Project files move along the chain as new messages are sent
- Custom Agent prompt comes before project files which come before user uploaded files in each turn

Reminders during a single Turn
S, U1, TC, TR, R
-- agent calls another tool ->
S, U1, TC, TR, TC, TR, R, A1
- Reminder moved to the end
```

## Product considerations

Project files are important for the entire duration of the chat session. If the user has uploaded project files, they are likely very intent on working with those files. The LLM is much better at referencing documents close to the end of the context window, so the project files are kept there for ease of access.

User-uploaded files are considered relevant for that point in time; it is OK if the Agent forgets about them as the chat gets long. If every uploaded file were constantly moved toward the end of the chat, quality would degrade as they stack up.
Even with a single file, there is some cost to pushing the previous User Message further away. This tradeoff is accepted for Projects because of the intent of the feature.

Reminders are absolutely necessary to ensure that 1-2 specific instructions are followed with very high probability. A reminder is less detailed than the system prompt and should be very targeted, both to work reliably and to avoid interfering with the last user message.

## Reasons / Experiments

Custom Agent instructions placed in the system prompt are poorly followed. Placing them there also degrades the performance of the system, especially when the instructions are orthogonal (or even contradictory) to the system prompt. For weaker models, it causes strange artifacts in tool calls and final responses that completely ruin the user experience. Empirically, the approach described above works better across a range of models, especially as the history gets longer. Keeping the Custom Agent instructions in a fixed position would make them fade as the chat gets long, which is also not acceptable from a UX perspective.

LLMs vary here, but some now have a section called the "System Prompt" (OpenAI terminology) that cannot be set via the API layer and contains information like the model's cutoff date, identity, and other basic, unchanging information. In that convention, the System prompt described above is called the "Developer Prompt".

It seems the distribution of the System Prompt, meaning the style of wording and the terms used, can also affect behavior. This differs between models and is not rigorously scientific, so the system prompt was built from an exploration across different models. It currently starts with: "You are a highly capable, thoughtful, and precise assistant. Your goal is to deeply understand the user's intent..."

LLMs handle changes in topic best at message boundaries, for which there are special tokens under the hood. We use this property to slice up the history in the way presented above.
Reminder messages are placed at the end of the prompt because model fine-tuning generally causes LLMs to attend very strongly to the tokens at the very back of the context, closest to generation. This is the only way to keep the LLMs from missing critical information and to make the product reliable. Specifically, the built-in reminders cover citations and which tools to call in certain situations.

The document JSON includes a field for the LLM to cite (a single number) to make citations reliable and avoid strange artifacts. The field is called "document" so that the LLM does not produce artifacts in its reasoning like "I should reference citation_id: 5 for...". It is also strategically placed so that it is easy to reference: it is followed by a couple of short sections, such as the metadata and title, before the long content section. LLMs still seem to be better at local attention despite having global access.

In a similar vein, the instructions in the system prompt are organized into coherent sections for the LLM to attend to. Surprisingly, placement matters a lot: if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or need to call additional tools, you are encouraged to do this", putting it in the Tool section of the System prompt makes all the LLMs follow it well, but if it is even just a paragraph away (for example, near the beginning of the prompt), it is often ignored. Moving the same statement just a few sentences can change the follow rate from 30% to 90%.

## Other related pointers

- How messages, files, and images are stored can be found in backend/onyx/db/models.py; there is also a README.md under that directory that may be helpful.
---

# Overview of LLM flow architecture

**Concepts:**
- Turn: the user sends a message, and the AI does some set of things and responds
- Step/Cycle: 1 single LLM inference given some context and some tools

## 1. Top Level (process_message function):

This function can be thought of as the setup and validation layer. It ensures that the database is in a valid state, reads the messages in the session, and sets up everything needed to run the chat loop, including the state containers. The major things it does are:
- Validates the request
- Builds the chat history for the session
- Fetches any additional context such as files and images
- Prepares all of the tools for the LLM
- Creates the state container objects for use in the loop

### Execution (`_run_models` function):

Each model runs in its own worker thread inside a `ThreadPoolExecutor`. Workers write packets to a shared `merged_queue` via an `Emitter`; the main thread drains the queue and yields packets in arrival order. This isolates the top level from the LLM flow and lets it yield packets as soon as they are produced. If a worker fails, the main thread yields a `StreamingError` for that model and keeps the other models running. All saving and database operations are handled by the main thread after the workers complete (or by the workers themselves via self-completion if the drain loop exits early).

### Emitter

The emitter is an object that lower levels use to send packets without needing to yield them all the way back up the call stack. Each `Emitter` tags every packet with a `model_index` and places it on the shared `merged_queue` as a `(model_idx, packet)` tuple. The drain loop in `_run_models` consumes these tuples and yields the packets to the caller.

Both the emitter and the state container are mutating state objects used only to accumulate state. There should be no logic that depends on the state of these objects, especially in the lower levels.
The emitter should only take packets and should not be used for anything else.

### State Container

The state container accumulates state during the LLM flow. Like the emitter, it should not be used for logic, only for accumulating state. It gathers all of the information needed to save the chat turn to the database: answer tokens, reasoning tokens, tool calls, citation info, etc. It is used at the end of the flow, once the lower level has completed, either on its own or because the user stopped it. At that point, all of the state is read and stored in the database. Any of the underlying layers may add to the state container.

### Stopping Generation

The drain loop in `_run_models` checks `check_is_connected()` every 50 ms (on queue timeout). The signal itself is stored in Redis and is set by the user calling the stop endpoint. On disconnect, the drain loop saves partial state for every model, yields an `OverallStop(stop_reason="user_cancelled")` packet, and returns. A `drain_done` event signals emitters to stop blocking so worker threads can exit quickly. Workers that already completed successfully self-complete (persist their response) if the drain loop exited before reaching the normal completion path.

## 2. LLM Loop (run_llm_loop function)

This function handles the logic of the Turn. It's essentially a while loop in which context is added and modified (according to what is outlined in the first half of this doc). Its main responsibilities are:
- Translate and truncate the context for the LLM inference
- Add context modifiers like reminders, updates to the system prompts, etc.
- Run tool calls and gather results
- Build some of the objects stored in the state container

## 3. LLM Step (run_llm_step function)

This function is a single inference of the LLM.
It's a wrapper around the LLM stream function which handles packet translations so that the Emitter can emit individual tokens as soon as they arrive. It also keeps track of the different sections since they do not all come at once (reasoning, answers, tool calls are all built up token by token). This layer also tracks the different tool calls and returns that to the LLM Loop to execute. ## Things to know - Packets are labeled with a "turn_index" field as part of the Placement of the packet. This is not the same as the backend concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered. - There are 3 representations of a message, each scoped to a different layer: 1. **ChatMessage** — The database model. Should be converted into ChatMessageSimple early and never passed deep into the flow. 2. **ChatMessageSimple** — The canonical data model used throughout the codebase. This is the rich, full-featured representation of a message. Any modifications or additions to message structure should be made here. 3. **LanguageModelInput** — The LLM-facing representation. Intentionally minimal so the LLM interface layer stays clean and easy to maintain/extend. ================================================ FILE: backend/onyx/chat/__init__.py ================================================ ================================================ FILE: backend/onyx/chat/chat_processing_checker.py ================================================ from uuid import UUID from onyx.cache.interface import CacheBackend PREFIX = "chatprocessing" FENCE_PREFIX = f"{PREFIX}_fence" FENCE_TTL = 30 * 60 # 30 minutes def _get_fence_key(chat_session_id: UUID) -> str: """Generate the cache key for a chat session processing fence. Args: chat_session_id: The UUID of the chat session Returns: The fence key string. 
Tenant isolation is handled automatically by the cache backend (Redis key-prefixing or Postgres schema routing). """ return f"{FENCE_PREFIX}_{chat_session_id}" def set_processing_status( chat_session_id: UUID, cache: CacheBackend, value: bool ) -> None: """Set or clear the fence for a chat session processing a message. If the key exists, a message is being processed. Args: chat_session_id: The UUID of the chat session cache: Tenant-aware cache backend value: True to set the fence, False to clear it """ fence_key = _get_fence_key(chat_session_id) if value: cache.set(fence_key, 0, ex=FENCE_TTL) else: cache.delete(fence_key) def is_chat_session_processing(chat_session_id: UUID, cache: CacheBackend) -> bool: """Check if the chat session is processing a message. Args: chat_session_id: The UUID of the chat session cache: Tenant-aware cache backend Returns: True if the chat session is processing a message, False otherwise """ return cache.exists(_get_fence_key(chat_session_id)) ================================================ FILE: backend/onyx/chat/chat_state.py ================================================ import threading from collections.abc import Callable from dataclasses import dataclass from uuid import UUID from pydantic import BaseModel from onyx.cache.interface import CacheBackend from onyx.chat.citation_processor import CitationMapping from onyx.chat.models import ChatLoadedFile from onyx.chat.models import ChatMessageSimple from onyx.chat.models import ExtractedContextFiles from onyx.chat.models import FileToolMetadata from onyx.chat.models import SearchParams from onyx.context.search.models import SearchDoc from onyx.db.memory import UserMemoryContext from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import Persona from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMUserIdentity from onyx.onyxbot.slack.models import SlackContext from onyx.server.query_and_chat.models import 
SendMessageRequest from onyx.tools.models import ChatFile from onyx.tools.models import ToolCallInfo # Type alias for search doc deduplication key # Simple key: just document_id (str) # Full key: (document_id, chunk_ind, match_highlights) SearchDocKey = str | tuple[str, int, tuple[str, ...]] class ChatStateContainer: """Container for accumulating state during LLM loop execution. This container holds the partial state that can be saved to the database if the generation is stopped by the user or completes normally. Thread-safe: All write operations are protected by a lock to ensure safe concurrent access from multiple threads. For thread-safe reads, use the getter methods. Direct attribute access is not thread-safe. """ def __init__(self) -> None: self._lock = threading.Lock() # These are collected at the end after the entire tool call is completed self.tool_calls: list[ToolCallInfo] = [] # This is accumulated during the streaming self.reasoning_tokens: str | None = None # This is accumulated during the streaming of the answer self.answer_tokens: str | None = None # Store citation mapping for building citation_docs_info during partial saves self.citation_to_doc: CitationMapping = {} # True if this turn is a clarification question (deep research flow) self.is_clarification: bool = False # Pre-answer processing time (time before answer starts) in seconds self.pre_answer_processing_time: float | None = None # Note: LLM cost tracking is now handled in multi_llm.py # Search doc collection - maps dedup key to SearchDoc for all docs from tool calls self._all_search_docs: dict[SearchDocKey, SearchDoc] = {} # Track which citation numbers were actually emitted during streaming self._emitted_citations: set[int] = set() def add_tool_call(self, tool_call: ToolCallInfo) -> None: """Add a tool call to the accumulated state.""" with self._lock: self.tool_calls.append(tool_call) def set_reasoning_tokens(self, reasoning: str | None) -> None: """Set the reasoning tokens from the final 
answer generation.""" with self._lock: self.reasoning_tokens = reasoning def set_answer_tokens(self, answer: str | None) -> None: """Set the answer tokens from the final answer generation.""" with self._lock: self.answer_tokens = answer def set_citation_mapping(self, citation_to_doc: CitationMapping) -> None: """Set the citation mapping from citation processor.""" with self._lock: self.citation_to_doc = citation_to_doc def set_is_clarification(self, is_clarification: bool) -> None: """Set whether this turn is a clarification question.""" with self._lock: self.is_clarification = is_clarification def get_answer_tokens(self) -> str | None: """Thread-safe getter for answer_tokens.""" with self._lock: return self.answer_tokens def get_reasoning_tokens(self) -> str | None: """Thread-safe getter for reasoning_tokens.""" with self._lock: return self.reasoning_tokens def get_tool_calls(self) -> list[ToolCallInfo]: """Thread-safe getter for tool_calls (returns a copy).""" with self._lock: return self.tool_calls.copy() def get_citation_to_doc(self) -> CitationMapping: """Thread-safe getter for citation_to_doc (returns a copy).""" with self._lock: return self.citation_to_doc.copy() def get_is_clarification(self) -> bool: """Thread-safe getter for is_clarification.""" with self._lock: return self.is_clarification def set_pre_answer_processing_time(self, duration: float | None) -> None: """Set the pre-answer processing time (time before answer starts).""" with self._lock: self.pre_answer_processing_time = duration def get_pre_answer_processing_time(self) -> float | None: """Thread-safe getter for pre_answer_processing_time.""" with self._lock: return self.pre_answer_processing_time @staticmethod def create_search_doc_key( search_doc: SearchDoc, use_simple_key: bool = True ) -> SearchDocKey: """Create a unique key for a SearchDoc for deduplication. Args: search_doc: The SearchDoc to create a key for use_simple_key: If True (default), use only document_id for deduplication. 
If False, include chunk_ind and match_highlights so that the same document/chunk with different highlights are stored separately. """ if use_simple_key: return search_doc.document_id match_highlights_tuple = tuple(sorted(search_doc.match_highlights or [])) return (search_doc.document_id, search_doc.chunk_ind, match_highlights_tuple) def add_search_docs( self, search_docs: list[SearchDoc], use_simple_key: bool = True ) -> None: """Add search docs to the accumulated collection with deduplication. Args: search_docs: List of SearchDoc objects to add use_simple_key: If True (default), deduplicate by document_id only. If False, deduplicate by document_id + chunk_ind + match_highlights. """ with self._lock: for doc in search_docs: key = self.create_search_doc_key(doc, use_simple_key) if key not in self._all_search_docs: self._all_search_docs[key] = doc def get_all_search_docs(self) -> dict[SearchDocKey, SearchDoc]: """Thread-safe getter for all accumulated search docs (returns a copy).""" with self._lock: return self._all_search_docs.copy() def add_emitted_citation(self, citation_num: int) -> None: """Add a citation number that was actually emitted during streaming.""" with self._lock: self._emitted_citations.add(citation_num) def get_emitted_citations(self) -> set[int]: """Thread-safe getter for emitted citations (returns a copy).""" with self._lock: return self._emitted_citations.copy() class AvailableFiles(BaseModel): """Separated file IDs for the FileReaderTool so it knows which loader to use.""" # IDs from the ``user_file`` table (project / persona-attached files). user_file_ids: list[UUID] = [] # IDs from the ``file_record`` table (chat-attached files). 
chat_file_ids: list[UUID] = [] @dataclass(frozen=True) class ChatTurnSetup: """Immutable context produced by ``build_chat_turn`` and consumed by ``_run_models``.""" new_msg_req: SendMessageRequest chat_session: ChatSession persona: Persona user_message: ChatMessage user_identity: LLMUserIdentity llms: list[LLM] # length 1 for single-model, N for multi-model model_display_names: list[str] # parallel to llms simple_chat_history: list[ChatMessageSimple] extracted_context_files: ExtractedContextFiles reserved_messages: list[ChatMessage] # length 1 for single, N for multi reserved_token_count: int search_params: SearchParams all_injected_file_metadata: dict[str, FileToolMetadata] available_files: AvailableFiles tool_id_to_name_map: dict[int, str] forced_tool_id: int | None files: list[ChatLoadedFile] chat_files_for_tools: list[ChatFile] custom_agent_prompt: str | None user_memory_context: UserMemoryContext # For deep research: was the last assistant message a clarification request? skip_clarification: bool check_is_connected: Callable[[], bool] cache: CacheBackend # Execution params forwarded to per-model tool construction bypass_acl: bool slack_context: SlackContext | None custom_tool_additional_headers: dict[str, str] | None mcp_headers: dict[str, str] | None ================================================ FILE: backend/onyx/chat/chat_utils.py ================================================ import json import re from collections.abc import Callable from typing import cast from uuid import UUID from fastapi.datastructures import Headers from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.chat.models import ChatHistoryResult from onyx.chat.models import ChatLoadedFile from onyx.chat.models import ChatMessageSimple from onyx.chat.models import FileToolMetadata from onyx.chat.models import ToolCallSimple from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.configs.constants import MessageType from onyx.configs.constants import 
TMP_DRALPHA_PERSONA_NAME from onyx.db.chat import create_chat_session from onyx.db.chat import get_chat_messages_by_session from onyx.db.chat import get_or_create_root_message from onyx.db.kg_config import get_kg_config_settings from onyx.db.kg_config import is_kg_config_settings_enabled_valid from onyx.db.models import ChatMessage from onyx.db.models import ChatSession from onyx.db.models import Persona from onyx.db.models import SearchDoc as DbSearchDoc from onyx.db.models import UserFile from onyx.db.projects import check_project_ownership from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import ChatFileType from onyx.file_store.models import FileDescriptor from onyx.file_store.utils import plaintext_file_name_for_id from onyx.file_store.utils import store_plaintext from onyx.kg.models import KGException from onyx.kg.setup.kg_default_entity_definitions import ( populate_missing_default_entity_types__commit, ) from onyx.prompts.chat_prompts import ADDITIONAL_CONTEXT_PROMPT from onyx.prompts.chat_prompts import TOOL_CALL_RESPONSE_CROSS_MESSAGE from onyx.prompts.tool_prompts import TOOL_CALL_FAILURE_PROMPT from onyx.server.query_and_chat.models import ChatSessionCreationRequest from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.tools.models import ToolCallKickoff from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.timing import log_function_time logger = setup_logger() IMAGE_GENERATION_TOOL_NAME = "generate_image" class FileContextResult(BaseModel): """Result of building a file's LLM context representation.""" message: ChatMessageSimple tool_metadata: FileToolMetadata def build_file_context( tool_file_id: str, filename: str, file_type: ChatFileType, content_text: str | None = None, token_count: int = 0, approx_char_count: int | None = None, ) -> 
FileContextResult: """Build the LLM context representation for a single file. Centralises how files appear in the LLM prompt. ``tool_file_id`` must be the ID that FileReaderTool accepts (``UserFile.id`` for user files). Metadata-only file types get a short pointer message instead of their full text. """ if file_type.use_metadata_only(): message_text = ( f"File: {filename} (id={tool_file_id})\n" "Use the file_reader or python tools to access " "this file's contents." ) message = ChatMessageSimple( message=message_text, token_count=max(1, len(message_text) // 4), message_type=MessageType.USER, file_id=tool_file_id, ) else: message_text = f"File: {filename}\n{content_text or ''}\nEnd of File" message = ChatMessageSimple( message=message_text, token_count=token_count, message_type=MessageType.USER, file_id=tool_file_id, ) metadata = FileToolMetadata( file_id=tool_file_id, filename=filename, approx_char_count=( approx_char_count if approx_char_count is not None else len(content_text or "") ), ) return FileContextResult(message=message, tool_metadata=metadata) def create_chat_session_from_request( chat_session_request: ChatSessionCreationRequest, user_id: UUID | None, db_session: Session, ) -> ChatSession: """Create a chat session from a ChatSessionCreationRequest. Includes project ownership validation when project_id is provided.
Args: chat_session_request: The request containing persona_id, description, and project_id user_id: The ID of the user creating the session (can be None for anonymous) db_session: The database session Returns: The newly created ChatSession Raises: ValueError: If user lacks access to the specified project Exception: If the persona is invalid """ project_id = chat_session_request.project_id if project_id: if not check_project_ownership(project_id, user_id, db_session): raise ValueError("User does not have access to project") return create_chat_session( db_session=db_session, description=chat_session_request.description or "", user_id=user_id, persona_id=chat_session_request.persona_id, project_id=chat_session_request.project_id, ) def create_chat_history_chain( chat_session_id: UUID, db_session: Session, prefetch_top_two_level_tool_calls: bool = True, # Optional id at which we finish processing stop_at_message_id: int | None = None, ) -> list[ChatMessage]: """Build the linear chain of messages, excluding the root message.""" mainline_messages: list[ChatMessage] = [] all_chat_messages = get_chat_messages_by_session( chat_session_id=chat_session_id, user_id=None, db_session=db_session, skip_permission_check=True, prefetch_top_two_level_tool_calls=prefetch_top_two_level_tool_calls, ) if not all_chat_messages: root_message = get_or_create_root_message( chat_session_id=chat_session_id, db_session=db_session ) else: root_message = all_chat_messages[0] if root_message.parent_message is not None: raise RuntimeError( "Invalid root message, unable to fetch valid chat message sequence" ) current_message: ChatMessage | None = root_message previous_message: ChatMessage | None = None while current_message is not None: child_msg = current_message.latest_child_message # Break if at the end of the chain # or if we have reached stop_at_message_id if not child_msg or ( stop_at_message_id and current_message.id == stop_at_message_id ): break current_message =
child_msg if ( current_message.message_type == MessageType.ASSISTANT and previous_message is not None and previous_message.message_type == MessageType.ASSISTANT and mainline_messages ): # Note that two user messages in a row are fine, since that pattern is often used for # adding custom prompts and reminders raise RuntimeError( "Invalid message chain, cannot have two assistant messages in a row" ) else: mainline_messages.append(current_message) previous_message = current_message return mainline_messages def reorganize_citations( answer: str, citations: list[CitationInfo] ) -> tuple[str, list[CitationInfo]]: """For a complete, citation-aware response, renumber the citations consecutively (1, 2, ...) in the order they first appear in the answer. This looks nicer and avoids confusion ("Why is there [7] when only 2 documents are cited?").""" # Regular expression to find all instances of [[x]](LINK) pattern = r"\[\[(.*?)\]\]\((.*?)\)" all_citation_matches = re.findall(pattern, answer) new_citation_info: dict[int, CitationInfo] = {} for citation_match in all_citation_matches: try: citation_num = int(citation_match[0]) if citation_num in new_citation_info: continue matching_citation = next( iter([c for c in citations if c.citation_number == citation_num]), None, ) if matching_citation is None: continue new_citation_info[citation_num] = CitationInfo( citation_number=len(new_citation_info) + 1, document_id=matching_citation.document_id, ) except Exception: pass # Function to replace citations with their new number def slack_link_format(match: re.Match) -> str: link_text = match.group(1) try: citation_num = int(link_text) if citation_num in new_citation_info: link_text = str(new_citation_info[citation_num].citation_number) except Exception: pass link_url = match.group(2) return f"[[{link_text}]]({link_url})" # Substitute all matches in the input text new_answer = re.sub(pattern, slack_link_format, answer) # if any citations weren't parsable, just add them back to
be safe for citation in citations: if citation.citation_number not in new_citation_info: new_citation_info[citation.citation_number] = citation return new_answer, list(new_citation_info.values()) def build_citation_map_from_infos( citations_list: list[CitationInfo], db_docs: list[DbSearchDoc] ) -> dict[int, int]: """Translate a list of streaming CitationInfo objects into a mapping of citation number -> saved search doc DB id. If a document_id appears more than once in db_docs, the DB id of its first occurrence is used; db_docs is assumed to be in the order shown to the user (display order). """ doc_id_to_saved_doc_id_map: dict[str, int] = {} for db_doc in db_docs: if db_doc.document_id not in doc_id_to_saved_doc_id_map: doc_id_to_saved_doc_id_map[db_doc.document_id] = db_doc.id citation_to_saved_doc_id_map: dict[int, int] = {} for citation in citations_list: if citation.citation_number not in citation_to_saved_doc_id_map: saved_id = doc_id_to_saved_doc_id_map.get(citation.document_id) if saved_id is not None: citation_to_saved_doc_id_map[citation.citation_number] = saved_id return citation_to_saved_doc_id_map def build_citation_map_from_numbers( cited_numbers: list[int] | set[int], db_docs: list[DbSearchDoc] ) -> dict[int, int]: """Translate parsed citation numbers (e.g., from [[n]]) into a mapping of citation number -> saved search doc DB id by positional index: citation n maps to db_docs[n - 1], and numbers outside that range are skipped. """ citation_to_saved_doc_id_map: dict[int, int] = {} for num in sorted(set(cited_numbers)): idx = num - 1 if 0 <= idx < len(db_docs): citation_to_saved_doc_id_map[num] = db_docs[idx].id return citation_to_saved_doc_id_map def extract_headers( headers: dict[str, str] | Headers, pass_through_headers: list[str] | None ) -> dict[str, str]: """ Extract headers specified in pass_through_headers from input headers. Handles both dict and FastAPI Headers objects, accounting for lowercase keys. Args: headers: Input headers as dict or Headers object. 
pass_through_headers: Names of the headers to copy. If None or empty, an empty dict is returned. Returns: dict: Filtered headers based on pass_through_headers.
""" if not pass_through_headers: return {} extracted_headers: dict[str, str] = {} for key in pass_through_headers: if key in headers: extracted_headers[key] = headers[key] else: # fastapi makes all header keys lowercase, handling that here lowercase_key = key.lower() if lowercase_key in headers: extracted_headers[lowercase_key] = headers[lowercase_key] return extracted_headers def process_kg_commands( message: str, persona_name: str, tenant_id: str, # noqa: ARG001 db_session: Session, ) -> None: # Temporarily, until we have a draft UI for the KG Operations/Management # TODO: move to api endpoint once we get frontend if not persona_name.startswith(TMP_DRALPHA_PERSONA_NAME): return kg_config_settings = get_kg_config_settings() if not is_kg_config_settings_enabled_valid(kg_config_settings): return if message == "kg_setup": populate_missing_default_entity_types__commit(db_session=db_session) raise KGException("KG setup done") def _get_or_extract_plaintext( file_id: str, extract_fn: Callable[[], str], ) -> str: """Load cached plaintext for a file, or extract and store it. Tries to read pre-stored plaintext from the file store. On a miss, calls extract_fn to produce the text, then stores the result so future calls skip the expensive extraction. """ file_store = get_default_file_store() plaintext_key = plaintext_file_name_for_id(file_id) # Try cached plaintext first. try: plaintext_io = file_store.read_file(plaintext_key, mode="b") return plaintext_io.read().decode("utf-8") except Exception: logger.exception(f"Error when reading file, id={file_id}") # Cache miss — extract and store. 
content_text = extract_fn() if content_text: store_plaintext(file_id, content_text) return content_text @log_function_time(print_only=True) def load_chat_file( file_descriptor: FileDescriptor, db_session: Session ) -> ChatLoadedFile: file_io = get_default_file_store().read_file(file_descriptor["id"], mode="b") content = file_io.read() # Extract text content if it's a text file type (not an image) content_text = None # `FileDescriptor` is often JSON-roundtripped (e.g. JSONB / API), so `type` # may arrive as a raw string value instead of a `ChatFileType`. file_type = ChatFileType(file_descriptor["type"]) if file_type.is_text_file(): file_id = file_descriptor["id"] def _extract() -> str: return extract_file_text( file=file_io, file_name=file_descriptor.get("name") or "", break_on_unprocessable=False, ) # Use the user_file_id as cache key when available (matches what # the celery indexing worker stores), otherwise fall back to the # file store id (covers code-interpreter-generated files, etc.). 
user_file_id_str = file_descriptor.get("user_file_id") cache_key = user_file_id_str or file_id try: content_text = _get_or_extract_plaintext(cache_key, _extract) except Exception as e: logger.warning( f"Failed to retrieve content for file {file_descriptor['id']}: {str(e)}" ) # Get token count from UserFile if available token_count = 0 user_file_id_str = file_descriptor.get("user_file_id") if user_file_id_str: try: user_file_id = UUID(user_file_id_str) user_file = ( db_session.query(UserFile).filter(UserFile.id == user_file_id).first() ) if user_file and user_file.token_count: token_count = user_file.token_count except (ValueError, TypeError) as e: logger.warning( f"Failed to get token count for file {file_descriptor['id']}: {e}" ) return ChatLoadedFile( file_id=file_descriptor["id"], content=content, file_type=file_type, filename=file_descriptor.get("name"), content_text=content_text, token_count=token_count, ) def load_all_chat_files( chat_messages: list[ChatMessage], db_session: Session, ) -> list[ChatLoadedFile]: # TODO There is likely a more efficient/standard way to load the files here. file_descriptors_for_history: list[FileDescriptor] = [] for chat_message in chat_messages: if chat_message.files: file_descriptors_for_history.extend(chat_message.files) files = cast( list[ChatLoadedFile], run_functions_tuples_in_parallel( [ (load_chat_file, (file, db_session)) for file in file_descriptors_for_history ] ), ) return files def convert_chat_history_basic( chat_history: list[ChatMessage], token_counter: Callable[[str], int], max_individual_message_tokens: int | None = None, max_total_tokens: int | None = None, ) -> list[ChatMessageSimple]: """Convert ChatMessage history to ChatMessageSimple format with no tool calls or files included. Args: chat_history: List of ChatMessage objects to convert token_counter: Function to count tokens in a message string max_individual_message_tokens: If set, messages exceeding this number of tokens are dropped. 
If None, no messages are dropped based on individual token count. max_total_tokens: If set, maximum number of tokens allowed for the entire history. If None, the history is not trimmed based on total token count. Returns: List of ChatMessageSimple objects """ # Defensive: treat a non-positive total budget as "no history". if max_total_tokens is not None and max_total_tokens <= 0: return [] # Convert only the core USER/ASSISTANT messages; omit files and tool calls. converted: list[ChatMessageSimple] = [] for chat_message in chat_history: if chat_message.message_type not in (MessageType.USER, MessageType.ASSISTANT): continue message = chat_message.message or "" token_count = getattr(chat_message, "token_count", None) if token_count is None: token_count = token_counter(message) # Drop any single message that would dominate the context window. if ( max_individual_message_tokens is not None and token_count > max_individual_message_tokens ): continue converted.append( ChatMessageSimple( message=message, token_count=token_count, message_type=chat_message.message_type, image_files=None, ) ) if max_total_tokens is None: return converted # Enforce a max total budget by keeping a contiguous suffix of the conversation. 
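# Worked example (hypothetical token counts, for illustration only): with
# max_total_tokens=10 and converted messages costing [4, 3, 5] tokens (oldest
# first), the loop below walks newest-first: it keeps 5 (total 5), keeps 3
# (total 8), then stops at 4 because 8 + 4 > 10. The result is the last two
# messages in their original order. The loop breaks at the first message that
# does not fit instead of skipping it, so the kept history stays contiguous.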
trimmed_reversed: list[ChatMessageSimple] = [] total_tokens = 0 for msg in reversed(converted): if total_tokens + msg.token_count > max_total_tokens: break trimmed_reversed.append(msg) total_tokens += msg.token_count return list(reversed(trimmed_reversed)) def _build_tool_call_response_history_message( tool_name: str, generated_images: list[dict] | None, tool_call_response: str | None, ) -> str: if tool_name != IMAGE_GENERATION_TOOL_NAME: return TOOL_CALL_RESPONSE_CROSS_MESSAGE if generated_images: llm_image_context: list[dict[str, str]] = [] for image in generated_images: file_id = image.get("file_id") revised_prompt = image.get("revised_prompt") if not isinstance(file_id, str): continue llm_image_context.append( { "file_id": file_id, "revised_prompt": ( revised_prompt if isinstance(revised_prompt, str) else "" ), } ) if llm_image_context: return json.dumps(llm_image_context) if tool_call_response: return tool_call_response return TOOL_CALL_RESPONSE_CROSS_MESSAGE def convert_chat_history( chat_history: list[ChatMessage], files: list[ChatLoadedFile], context_image_files: list[ChatLoadedFile], additional_context: str | None, token_counter: Callable[[str], int], tool_id_to_name_map: dict[int, str], ) -> ChatHistoryResult: """Convert ChatMessage history to ChatMessageSimple format. For user messages: includes attached files (images attached to message, text files as separate messages) For assistant messages with tool calls: creates ONE ASSISTANT message with tool_calls array, followed by N TOOL_CALL_RESPONSE messages (OpenAI parallel tool calling format) For assistant messages without tool calls: creates a simple ASSISTANT message Every injected text-file message is tagged with ``file_id`` and its metadata is collected in ``ChatHistoryResult.all_injected_file_metadata``. After context-window truncation, callers compare surviving ``file_id`` tags against this map to discover "forgotten" files and provide their metadata to the FileReaderTool. 
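Illustrative output order (hypothetical content, assuming no additional_context
and a text file type whose contents are inlined): a USER message with one
attached text file, followed by an ASSISTANT message whose two tool calls share
the same turn_number, becomes
    USER                file message for the attachment (tagged with file_id)
    USER                the user's own message text
    ASSISTANT           "" with tool_calls=[call_a, call_b]
    TOOL_CALL_RESPONSE  tool_call_id=call_a
    TOOL_CALL_RESPONSE  tool_call_id=call_b
    ASSISTANT           the final answer text
Within a turn, tool calls are ordered by tool_id.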
""" simple_messages: list[ChatMessageSimple] = [] all_injected_file_metadata: dict[str, FileToolMetadata] = {} # Create a mapping of file IDs to loaded files for quick lookup file_map = {str(f.file_id): f for f in files} # Find the index of the last USER message last_user_message_idx = None for i in range(len(chat_history) - 1, -1, -1): if chat_history[i].message_type == MessageType.USER: last_user_message_idx = i break for idx, chat_message in enumerate(chat_history): if chat_message.message_type == MessageType.USER: # Process files attached to this message text_files: list[tuple[ChatLoadedFile, FileDescriptor]] = [] image_files: list[ChatLoadedFile] = [] if chat_message.files: for file_descriptor in chat_message.files: file_id = file_descriptor["id"] loaded_file = file_map.get(file_id) if loaded_file: if loaded_file.file_type == ChatFileType.IMAGE: image_files.append(loaded_file) else: # Text files (DOC, PLAIN_TEXT, TABULAR) are added as separate messages text_files.append((loaded_file, file_descriptor)) # Add text files as separate messages before the user message. # Each message is tagged with ``file_id`` so that forgotten files # can be detected after context-window truncation. for text_file, fd in text_files: # Use user_file_id as the FileReaderTool accepts that. # Fall back to the file-store path id. 
tool_id = fd.get("user_file_id") or text_file.file_id filename = text_file.filename or "unknown" ctx = build_file_context( tool_file_id=tool_id, filename=filename, file_type=text_file.file_type, content_text=text_file.content_text, token_count=text_file.token_count, ) simple_messages.append(ctx.message) all_injected_file_metadata[tool_id] = ctx.tool_metadata # Sum token counts from image files (excluding project image files) image_token_count = ( sum(img.token_count for img in image_files) if image_files else 0 ) # Add the user message with image files attached # If this is the last USER message, also include context_image_files # Note: context image file tokens are NOT counted in the token count if idx == last_user_message_idx: if context_image_files: image_files.extend(context_image_files) if additional_context: simple_messages.append( ChatMessageSimple( message=ADDITIONAL_CONTEXT_PROMPT.format( additional_context=additional_context ), token_count=token_counter(additional_context), message_type=MessageType.USER, image_files=None, ) ) simple_messages.append( ChatMessageSimple( message=chat_message.message, token_count=chat_message.token_count + image_token_count, message_type=MessageType.USER, image_files=image_files if image_files else None, ) ) elif chat_message.message_type == MessageType.ASSISTANT: # Handle tool calls if present using OpenAI parallel tool calling format: # 1. Group tool calls by turn_number # 2. For each turn: ONE ASSISTANT message with tool_calls array # 3. 
Followed by N TOOL_CALL_RESPONSE messages (one per tool call) if chat_message.tool_calls: # Group tool calls by turn number tool_calls_by_turn: dict[int, list] = {} for tool_call in chat_message.tool_calls: if tool_call.turn_number not in tool_calls_by_turn: tool_calls_by_turn[tool_call.turn_number] = [] tool_calls_by_turn[tool_call.turn_number].append(tool_call) # Sort turns and process each turn for turn_number in sorted(tool_calls_by_turn.keys()): turn_tool_calls = tool_calls_by_turn[turn_number] # Sort by tool_id within the turn for consistent ordering turn_tool_calls.sort(key=lambda tc: tc.tool_id) # Build ToolCallSimple list for this turn tool_calls_simple: list[ToolCallSimple] = [] for tool_call in turn_tool_calls: tool_name = tool_id_to_name_map.get( tool_call.tool_id, "unknown" ) tool_calls_simple.append( ToolCallSimple( tool_call_id=tool_call.tool_call_id, tool_name=tool_name, tool_arguments=tool_call.tool_call_arguments or {}, token_count=tool_call.tool_call_tokens, ) ) # Create ONE ASSISTANT message with all tool calls for this turn total_tool_call_tokens = sum( tc.token_count for tc in tool_calls_simple ) simple_messages.append( ChatMessageSimple( message="", # No text content when making tool calls token_count=total_tool_call_tokens, message_type=MessageType.ASSISTANT, tool_calls=tool_calls_simple, image_files=None, ) ) # Add TOOL_CALL_RESPONSE messages for each tool call in this turn for tool_call in turn_tool_calls: tool_name = tool_id_to_name_map.get( tool_call.tool_id, "unknown" ) tool_response_message = ( _build_tool_call_response_history_message( tool_name=tool_name, generated_images=tool_call.generated_images, tool_call_response=tool_call.tool_call_response, ) ) simple_messages.append( ChatMessageSimple( message=tool_response_message, token_count=( token_counter(tool_response_message) if tool_name == IMAGE_GENERATION_TOOL_NAME else 20 ), message_type=MessageType.TOOL_CALL_RESPONSE, tool_call_id=tool_call.tool_call_id, image_files=None, ) ) # 
Add the assistant message itself (the final answer) simple_messages.append( ChatMessageSimple( message=chat_message.message, token_count=chat_message.token_count, message_type=MessageType.ASSISTANT, image_files=None, ) ) else: raise ValueError( f"Invalid message type when constructing simple history: {chat_message.message_type}" ) return ChatHistoryResult( simple_messages=simple_messages, all_injected_file_metadata=all_injected_file_metadata, ) def get_custom_agent_prompt(persona: Persona, chat_session: ChatSession) -> str | None: """Get the custom agent prompt from persona or project instructions. If it's replacing the base system prompt, it does not count as a custom agent prompt (logic exists later also to drop it in this case). Chat Sessions in Projects that are using a custom agent will retain the custom agent prompt. Priority: persona.system_prompt (if not default Agent) > chat_session.project.instructions # NOTE: Logic elsewhere allows saving empty strings for potentially other purposes but for constructing the prompts # we never want to return an empty string for a prompt so it's translated into an explicit None. Args: persona: The Persona object chat_session: The ChatSession object Returns: The prompt to use for the custom Agent part of the prompt. """ # If using a custom Agent, always respect its prompt, even if in a Project, and even if it's an empty custom prompt. if persona.id != DEFAULT_PERSONA_ID: # Logic exists later also to drop it in this case but this is strictly correct anyhow. if persona.replace_base_system_prompt: return None return persona.system_prompt or None # If in a project and using the default Agent, respect the project instructions. if chat_session.project and chat_session.project.instructions: return chat_session.project.instructions return None def is_last_assistant_message_clarification(chat_history: list[ChatMessage]) -> bool: """Check if the last assistant message in chat history was a clarification question. 
This is used in the deep research flow to determine whether to skip the clarification step when the user has already responded to a clarification. Args: chat_history: List of ChatMessage objects in chronological order Returns: True if the last assistant message has is_clarification=True, False otherwise """ for message in reversed(chat_history): if message.message_type == MessageType.ASSISTANT: return message.is_clarification return False def create_tool_call_failure_messages( tool_calls: list[ToolCallKickoff], token_counter: Callable[[str], int] ) -> list[ChatMessageSimple]: """Create ChatMessageSimple objects for failed tool calls. Creates messages using OpenAI parallel tool calling format: 1. An ASSISTANT message with tool_calls field containing all failed tool calls 2. A TOOL_CALL_RESPONSE failure message for each tool call Args: tool_calls: List of ToolCallKickoff objects representing the failed tool calls token_counter: Function to count tokens in a message string Returns: List containing ChatMessageSimple objects: one assistant message with all tool calls followed by a failure response for each tool call """ if not tool_calls: return [] # Create ToolCallSimple for each failed tool call tool_calls_simple: list[ToolCallSimple] = [] for tool_call in tool_calls: tool_call_token_count = token_counter(tool_call.to_msg_str()) tool_calls_simple.append( ToolCallSimple( tool_call_id=tool_call.tool_call_id, tool_name=tool_call.tool_name, tool_arguments=tool_call.tool_args, token_count=tool_call_token_count, ) ) total_token_count = sum(tc.token_count for tc in tool_calls_simple) # Create ONE ASSISTANT message with all tool_calls (OpenAI format) assistant_msg = ChatMessageSimple( message="", # No text content when making tool calls token_count=total_token_count, message_type=MessageType.ASSISTANT, tool_calls=tool_calls_simple, image_files=None, ) messages: list[ChatMessageSimple] = [assistant_msg] # Create a TOOL_CALL_RESPONSE failure message for each tool call for 
tool_call in tool_calls: failure_response_msg = ChatMessageSimple( message=TOOL_CALL_FAILURE_PROMPT, token_count=50, # Tiny overestimate message_type=MessageType.TOOL_CALL_RESPONSE, tool_call_id=tool_call.tool_call_id, image_files=None, ) messages.append(failure_response_msg) return messages ================================================ FILE: backend/onyx/chat/citation_processor.py ================================================ """ Dynamic Citation Processor for LLM Responses This module provides a citation processor that can: - Accept citation number to SearchDoc mappings dynamically - Process token streams from LLMs to extract citations - Handle citations in three modes: REMOVE, KEEP_MARKERS, or HYPERLINK - Emit CitationInfo objects for detected citations (in HYPERLINK mode) - Track all seen citations regardless of mode - Maintain a list of cited documents in order of first citation """ import re from collections.abc import Generator from enum import Enum from typing import TypeAlias from onyx.configs.chat_configs import STOP_STREAM_PAT from onyx.context.search.models import SearchDoc from onyx.prompts.constants import TRIPLE_BACKTICK from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.utils.logger import setup_logger logger = setup_logger() class CitationMode(Enum): """Defines how citations should be handled in the output. REMOVE: Citations are completely removed from output text. No CitationInfo objects are emitted. Use case: When you need to remove citations from the output if they are not shared with the user (e.g. in discord bot, public slack bot). KEEP_MARKERS: Original citation markers like [1], [2] are preserved unchanged. No CitationInfo objects are emitted. Use case: When you need to track citations in research agent and later process them with collapse_citations() to renumber. HYPERLINK: Citations are replaced with markdown links like [[1]](url). CitationInfo objects are emitted for UI tracking. 
Use case: Final reports shown to users with clickable links. """ REMOVE = "remove" KEEP_MARKERS = "keep_markers" HYPERLINK = "hyperlink" CitationMapping: TypeAlias = dict[int, SearchDoc] # ============================================================================ # Utility functions # ============================================================================ def in_code_block(llm_text: str) -> bool: """Check if we're currently inside a code block by counting triple backticks.""" count = llm_text.count(TRIPLE_BACKTICK) return count % 2 != 0 # ============================================================================ # Main Citation Processor with Dynamic Mapping # ============================================================================ class DynamicCitationProcessor: """ A citation processor that accepts dynamic citation mappings. This processor is designed for multi-turn conversations where the citation number to document mapping is provided externally. It processes streaming tokens from an LLM, detects citations (e.g., [1], [2,3], [[4]]), and handles them according to the configured CitationMode: CitationMode.HYPERLINK (default): 1. Replaces citation markers with formatted markdown links (e.g., [[1]](url)) 2. Emits CitationInfo objects for tracking 3. Maintains the order in which documents were first cited Use case: Final reports shown to users with clickable links. CitationMode.KEEP_MARKERS: 1. Preserves original citation markers like [1], [2] unchanged 2. Does NOT emit CitationInfo objects 3. Still tracks all seen citations via get_seen_citations() Use case: When citations need later processing (e.g., renumbering). CitationMode.REMOVE: 1. Removes citation markers entirely from the output text 2. Does NOT emit CitationInfo objects 3. Still tracks all seen citations via get_seen_citations() Use case: Research agent intermediate reports. 
Features: - Accepts citation number → SearchDoc mapping via update_citation_mapping() - Configurable citation mode at initialization - Always tracks seen citations regardless of mode - Holds back tokens that might be partial citations - Maintains list of cited SearchDocs in order of first citation - Handles unicode bracket variants (【】, []) - Skips citation processing inside code blocks Example (HYPERLINK mode - default): processor = DynamicCitationProcessor() # Set up citation mapping processor.update_citation_mapping({1: search_doc1, 2: search_doc2}) # Process tokens from LLM for token in llm_stream: for result in processor.process_token(token): if isinstance(result, str): print(result) # Display text with [[1]](url) format elif isinstance(result, CitationInfo): handle_citation(result) # Track citation # Get cited documents at the end cited_docs = processor.get_cited_documents() Example (KEEP_MARKERS mode): processor = DynamicCitationProcessor(citation_mode=CitationMode.KEEP_MARKERS) processor.update_citation_mapping({1: search_doc1, 2: search_doc2}) # Process tokens from LLM for token in llm_stream: for result in processor.process_token(token): # Only strings are yielded, no CitationInfo objects print(result) # Display text with original [1] format preserved # Get all seen citations after processing seen_citations = processor.get_seen_citations() # {1: search_doc1, ...} Example (REMOVE mode): processor = DynamicCitationProcessor(citation_mode=CitationMode.REMOVE) processor.update_citation_mapping({1: search_doc1, 2: search_doc2}) # Process tokens - citations are removed but tracked for token in llm_stream: for result in processor.process_token(token): print(result) # Text without any citation markers # Citations are still tracked seen_citations = processor.get_seen_citations() """ def __init__( self, citation_mode: CitationMode = CitationMode.HYPERLINK, stop_stream: str | None = STOP_STREAM_PAT, ): """ Initialize the citation processor. 
Args: citation_mode: How to handle citations in the output. One of: - CitationMode.HYPERLINK (default): Replace [1] with [[1]](url) and emit CitationInfo objects. - CitationMode.KEEP_MARKERS: Keep original [1] markers unchanged, no CitationInfo objects emitted. - CitationMode.REMOVE: Remove citations entirely from output, no CitationInfo objects emitted. All modes track seen citations via get_seen_citations(). stop_stream: Optional stop token pattern to halt processing early. When this pattern is detected in the token stream, processing stops. Defaults to STOP_STREAM_PAT from chat configs. """ # Citation mapping from citation number to SearchDoc self.citation_to_doc: CitationMapping = {} self.seen_citations: CitationMapping = {} # citation num -> SearchDoc # Token processing state self.llm_out = "" # entire output so far self.curr_segment = "" # tokens held for citation processing self.hold = "" # tokens held for stop token processing self.stop_stream = stop_stream self.citation_mode = citation_mode # Citation tracking self.cited_documents_in_order: list[SearchDoc] = ( [] ) # SearchDocs in citation order self.cited_document_ids: set[str] = set() # all cited document_ids self.recent_cited_documents: set[str] = ( set() ) # recently cited (for deduplication) self.non_citation_count = 0 # Citation patterns # Matches potential incomplete citations: '[', '[[', '[1', '[[1', '[1,', '[1, ', etc. # Also matches unicode bracket variants: 【, [ self.possible_citation_pattern = re.compile(r"([\[【[]+(?:\d+,? ?)*$)") # Matches complete citations: # group 1: '[[1]]', [[2]], etc. (also matches 【【1】】, [[1]], 【1】, [1]) # group 2: '[1]', '[1, 2]', '[1,2,16]', etc. (also matches unicode variants) self.citation_pattern = re.compile( r"([\[【[]{2}\d+[\]】]]{2})|([\[【[]\d+(?:, ?\d+)*[\]】]])" ) def update_citation_mapping( self, citation_mapping: CitationMapping, update_duplicate_keys: bool = False, ) -> None: """ Update the citation number to SearchDoc mapping. 
This can be called multiple times to add or update mappings. New mappings will be merged with existing ones. Args: citation_mapping: Dictionary mapping citation numbers (1, 2, 3, ...) to SearchDoc objects update_duplicate_keys: If True, update existing mappings with new values when keys overlap. If False (default), filter out duplicate keys and only add non-duplicates. The default behavior is useful when OpenURL may have the same citation number as a Web Search result - in those cases, we keep the web search citation and snippet etc. """ if update_duplicate_keys: # Update all mappings, including duplicates self.citation_to_doc.update(citation_mapping) else: # Filter out duplicate keys and only add non-duplicates # Reason for this is that OpenURL may have the same citation number as a Web Search result # For those, we should just keep the web search citation and snippet etc. duplicate_keys = set(citation_mapping.keys()) & set( self.citation_to_doc.keys() ) non_duplicate_mapping = { k: v for k, v in citation_mapping.items() if k not in duplicate_keys } self.citation_to_doc.update(non_duplicate_mapping) def process_token( self, token: str | None ) -> Generator[str | CitationInfo, None, None]: """ Process a token from the LLM stream. This method: 1. Accumulates tokens until a complete citation or non-citation is found 2. Holds back potential partial citations (e.g., "[", "[1") 3. Yields text chunks when they're safe to display 4. Handles code blocks (avoids processing citations inside code) 5. Handles stop tokens 6. 
Always tracks seen citations in self.seen_citations Behavior depends on the `citation_mode` setting from __init__: - HYPERLINK: Citations are replaced with [[n]](url) format and CitationInfo objects are yielded before each formatted citation - KEEP_MARKERS: Original citation markers like [1] are preserved unchanged, no CitationInfo objects are yielded - REMOVE: Citations are removed entirely from output, no CitationInfo objects are yielded Args: token: The next token from the LLM stream, or None to signal end of stream. Pass None to flush any remaining buffered text at end of stream. Yields: str: Text chunks to display. Citation format depends on citation_mode. CitationInfo: Citation metadata (only when citation_mode=HYPERLINK) """ # None -> end of stream, flush remaining segment if token is None: if self.curr_segment: yield self.curr_segment return # Handle stop stream token if self.stop_stream: next_hold = self.hold + token if self.stop_stream in next_hold: # Extract text before the stop pattern stop_pos = next_hold.find(self.stop_stream) text_before_stop = next_hold[:stop_pos] # Process the text before stop pattern if any exists if text_before_stop: # Process text_before_stop through normal flow self.hold = "" token = text_before_stop # Continue to normal processing below else: # Stop pattern at the beginning, nothing to yield return elif next_hold == self.stop_stream[: len(next_hold)]: self.hold = next_hold return else: token = next_hold self.hold = "" self.curr_segment += token self.llm_out += token # Handle code blocks without language tags # If we see ``` followed by \n, add "plaintext" language specifier if "`" in self.curr_segment: if self.curr_segment.endswith("`"): pass elif "```" in self.curr_segment: parts = self.curr_segment.split("```") if len(parts) > 1 and len(parts[1]) > 0: piece_that_comes_after = parts[1][0] if piece_that_comes_after == "\n" and in_code_block(self.llm_out): self.curr_segment = self.curr_segment.replace( "```", "```plaintext" ) # 
Look for citations in current segment citation_matches = list(self.citation_pattern.finditer(self.curr_segment)) possible_citation_found = bool( re.search(self.possible_citation_pattern, self.curr_segment) ) result = "" if citation_matches and not in_code_block(self.llm_out): match_idx = 0 for match in citation_matches: match_span = match.span() # Get text before/between citations intermatch_str = self.curr_segment[match_idx : match_span[0]] self.non_citation_count += len(intermatch_str) match_idx = match_span[1] # Check if there is already a space before this citation if intermatch_str: has_leading_space = intermatch_str[-1].isspace() else: # No text between citations (consecutive citations) # If match_idx > 0, we've already processed a citation, so don't add space if match_idx > 0: # Consecutive citations - don't add space between them has_leading_space = True else: # Citation at start of segment - check if previous output has space segment_start_idx = len(self.llm_out) - len(self.curr_segment) if segment_start_idx > 0: has_leading_space = self.llm_out[ segment_start_idx - 1 ].isspace() else: has_leading_space = False # Reset recent citations if no citations found for a while if self.non_citation_count > 5: self.recent_cited_documents.clear() # Process the citation (returns formatted citation text and CitationInfo objects) # Always tracks seen citations regardless of citation_mode citation_text, citation_info_list = self._process_citation( match, has_leading_space ) if self.citation_mode == CitationMode.HYPERLINK: # HYPERLINK mode: Replace citations with markdown links [[n]](url) # Yield text before citation FIRST (preserve order) if intermatch_str: yield intermatch_str # Yield CitationInfo objects BEFORE the citation text # This allows the frontend to receive citation metadata before the token # that contains [[n]](link), enabling immediate rendering for citation in citation_info_list: yield citation # Then yield the formatted citation text if citation_text: 
yield citation_text elif self.citation_mode == CitationMode.KEEP_MARKERS: # KEEP_MARKERS mode: Preserve original citation markers unchanged # Yield text before citation if intermatch_str: yield intermatch_str # Yield the original citation marker as-is yield match.group() else: # CitationMode.REMOVE # REMOVE mode: Remove citations entirely from output # This strips citation markers like [1], [2], 【1】 from the output text # When removing citations, we need to handle spacing to avoid issues like: # - "text [1] more" -> "text  more" (double space) # - "text [1]." -> "text ." (space before punctuation) if intermatch_str: remaining_text = self.curr_segment[match_span[1] :] # Strip trailing space from intermatch if: # 1. Remaining text starts with space (avoids double space) # 2. Remaining text starts with punctuation (avoids space before punctuation) if intermatch_str[-1].isspace() and remaining_text: first_char = remaining_text[0] # Check if next char is space or common punctuation if first_char.isspace() or first_char in ".,;:!?)]}": intermatch_str = intermatch_str.rstrip() if intermatch_str: yield intermatch_str self.non_citation_count = 0 # Leftover text could be part of next citation self.curr_segment = self.curr_segment[match_idx:] self.non_citation_count = len(self.curr_segment) # Hold onto the current segment if potential citations found, otherwise stream it if not possible_citation_found: result += self.curr_segment self.non_citation_count += len(self.curr_segment) self.curr_segment = "" if result: yield result def _process_citation( self, match: re.Match, has_leading_space: bool ) -> tuple[str, list[CitationInfo]]: """ Process a single citation match and return formatted citation text and citation info objects. This is an internal method called by process_token(). The match string can be in various formats: '[1]', '[1, 13, 6]', '[[4]]', '【1】', '［1］', etc. This method always: 1. Extracts citation numbers from the match 2.
Looks up the corresponding SearchDoc from the mapping 3. Tracks seen citations in self.seen_citations (regardless of citation_mode) When citation_mode is HYPERLINK: 4. Creates formatted citation text as [[n]](url) 5. Creates CitationInfo objects for new citations 6. Handles deduplication of recently cited documents When citation_mode is REMOVE or KEEP_MARKERS: 4. Returns empty string and empty list (caller handles output based on mode) Args: match: Regex match object containing the citation pattern has_leading_space: Whether the text immediately before this citation ends with whitespace. Used to determine if a leading space should be added to the formatted output. Returns: Tuple of (formatted_citation_text, citation_info_list): - formatted_citation_text: Markdown-formatted citation text like "[[1]](https://example.com)" or empty string if not in HYPERLINK mode - citation_info_list: List of CitationInfo objects for newly cited documents, or empty list if not in HYPERLINK mode """ citation_str: str = match.group() # e.g., '[1]', '[1, 2, 3]', '[[1]]', '【1】' formatted = ( match.lastindex == 1 ) # True means already in form '[[1]]' or '【【1】】' citation_info_list: list[CitationInfo] = [] formatted_citation_parts: list[str] = [] # Extract citation numbers - regex ensures matched brackets, so we can simply slice citation_content = citation_str[2:-2] if formatted else citation_str[1:-1] for num_str in citation_content.split(","): num_str = num_str.strip() if not num_str: continue try: num = int(num_str) except ValueError: # Invalid citation, skip it logger.warning(f"Invalid citation number format: {num_str}") continue # Check if we have a mapping for this citation number if num not in self.citation_to_doc: logger.warning( f"Citation number {num} not found in mapping. 
Available: {list(self.citation_to_doc.keys())}" ) continue # Get the SearchDoc search_doc = self.citation_to_doc[num] doc_id = search_doc.document_id link = search_doc.link or "" # Always track seen citations regardless of citation_mode setting self.seen_citations[num] = search_doc # Only generate formatted citations and CitationInfo in HYPERLINK mode if self.citation_mode != CitationMode.HYPERLINK: continue # Format the citation text as [[n]](link) formatted_citation_parts.append(f"[[{num}]]({link})") # Skip creating CitationInfo for citations of the same work if cited recently (deduplication) if doc_id in self.recent_cited_documents: continue self.recent_cited_documents.add(doc_id) # Track cited documents and create CitationInfo only for new citations if doc_id not in self.cited_document_ids: self.cited_document_ids.add(doc_id) self.cited_documents_in_order.append(search_doc) citation_info_list.append( CitationInfo( citation_number=num, document_id=doc_id, ) ) # Join all citation parts with spaces formatted_citation_text = " ".join(formatted_citation_parts) # Apply leading space only if the text didn't already have one if formatted_citation_text and not has_leading_space: formatted_citation_text = " " + formatted_citation_text return formatted_citation_text, citation_info_list def get_cited_documents(self) -> list[SearchDoc]: """ Get the list of cited SearchDoc objects in the order they were first cited. Note: This list is only populated when `citation_mode=HYPERLINK`. When using REMOVE or KEEP_MARKERS mode, this will return an empty list. Use get_seen_citations() instead if you need to track citations without emitting CitationInfo objects. Returns: List of SearchDoc objects in the order they were first cited. Empty list if citation_mode is not HYPERLINK. """ return self.cited_documents_in_order def get_cited_document_ids(self) -> list[str]: """ Get the list of cited document IDs in the order they were first cited. 
Note: This list is only populated when `citation_mode=HYPERLINK`. When using REMOVE or KEEP_MARKERS mode, this will return an empty list. Use get_seen_citations() instead if you need to track citations without emitting CitationInfo objects. Returns: List of document IDs (strings) in the order they were first cited. Empty list if citation_mode is not HYPERLINK. """ return [doc.document_id for doc in self.cited_documents_in_order] def get_seen_citations(self) -> CitationMapping: """ Get all seen citations as a mapping from citation number to SearchDoc. This returns all citations that have been encountered during processing, regardless of the `citation_mode` setting. Citations are tracked whenever they are parsed, making this useful for cases where you need to know which citations appeared in the text without emitting CitationInfo objects. This is particularly useful when using REMOVE or KEEP_MARKERS mode, as get_cited_documents() will be empty in those cases, but get_seen_citations() will still contain all the citations that were found. Returns: Dictionary mapping citation numbers (int) to SearchDoc objects. The dictionary is keyed by the citation number as it appeared in the text (e.g., {1: SearchDoc(...), 3: SearchDoc(...)}). """ return self.seen_citations @property def num_cited_documents(self) -> int: """ Get the number of unique documents that have been cited. Note: This count is only updated when `citation_mode=HYPERLINK`. When using REMOVE or KEEP_MARKERS mode, this will always return 0. Use len(get_seen_citations()) instead if you need to count citations without emitting CitationInfo objects. Returns: Number of unique documents cited. 0 if citation_mode is not HYPERLINK. """ return len(self.cited_document_ids) def reset_recent_citations(self) -> None: """ Reset the recent citations tracker. The processor tracks "recently cited" documents to avoid emitting duplicate CitationInfo objects for the same document when it's cited multiple times in close succession. 
This method clears that tracker. This is primarily useful when `citation_mode=HYPERLINK` to allow previously cited documents to emit CitationInfo objects again. Has no effect when using REMOVE or KEEP_MARKERS mode. The recent citation tracker is also automatically cleared when more than 5 non-citation characters are processed between citations. """ self.recent_cited_documents.clear() def get_next_citation_number(self) -> int: """ Get the next available citation number for adding new documents to the mapping. This method returns the next citation number that should be used when adding new documents via update_citation_mapping(). Useful when dynamically adding citations during processing (e.g., from tool results like web search). If no citations exist yet in the mapping, returns 1. Otherwise, returns max(existing_citation_numbers) + 1. Returns: The next available citation number (1-indexed integer). Example: # After adding citations 1, 2, 3 processor.get_next_citation_number() # Returns 4 # With non-sequential citations 1, 5, 10 processor.get_next_citation_number() # Returns 11 """ if not self.citation_to_doc: return 1 return max(self.citation_to_doc.keys()) + 1 ================================================ FILE: backend/onyx/chat/citation_utils.py ================================================ import re from onyx.chat.citation_processor import CitationMapping from onyx.chat.citation_processor import DynamicCitationProcessor from onyx.context.search.models import SearchDocsResponse from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES from onyx.tools.models import ToolResponse def update_citation_processor_from_tool_response( tool_response: ToolResponse, citation_processor: DynamicCitationProcessor, ) -> None: """Update citation processor if this was a citeable tool with a SearchDocsResponse. 
Checks if the tool call is citeable and if the response contains a SearchDocsResponse, then creates a mapping from citation numbers to SearchDoc objects and updates the citation processor. Args: tool_response: The response from the tool execution (must have tool_call set) citation_processor: The DynamicCitationProcessor to update """ # Early return if tool_call is not set if tool_response.tool_call is None: return # Update citation processor if this was a search tool if tool_response.tool_call.tool_name in CITEABLE_TOOLS_NAMES: # Check if the rich_response is a SearchDocsResponse if isinstance(tool_response.rich_response, SearchDocsResponse): search_response = tool_response.rich_response # Create mapping from citation number to SearchDoc citation_to_doc: CitationMapping = {} for ( citation_num, doc_id, ) in search_response.citation_mapping.items(): # Find the SearchDoc with this doc_id matching_doc = next( ( doc for doc in search_response.search_docs if doc.document_id == doc_id ), None, ) if matching_doc: citation_to_doc[citation_num] = matching_doc # Update the citation processor citation_processor.update_citation_mapping(citation_to_doc) def extract_citation_order_from_text(text: str) -> list[int]: """Extract citation numbers from text in order of first appearance. Parses citation patterns like [1], [1, 2], [[1]], 【1】 etc. and returns the citation numbers in the order they first appear in the text. 
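A minimal standalone sketch of this extraction, assuming only the Python standard library; `CITATION_RE` and `extract_citation_order` are hypothetical names, not part of Onyx. The pattern keeps the same two alternatives (double brackets with one number, single brackets with a comma list) but uses `[0-9]` and places `]` first in the closing class so it needs no backslash escapes.

```python
import re

# Hypothetical standalone pattern covering ASCII [ ], CJK 【 】 and fullwidth ［ ］.
# A "]" placed first inside a character class is a literal, so _CLOSE matches
# any of the three closing brackets without escaping.
_OPEN = "[【［[]"
_CLOSE = "[]】］]"
CITATION_RE = re.compile(
    # groups 1-2: double brackets holding one number, e.g. [[1]] or 【【1】】
    "(" + _OPEN + "{2}([0-9]+)" + _CLOSE + "{2})"
    # groups 3-4: single brackets holding one or more numbers, e.g. [1] or [1, 2]
    + "|(" + _OPEN + "([0-9]+(?: *, *[0-9]+)*)" + _CLOSE + ")"
)


def extract_citation_order(text: str) -> list[int]:
    # Return citation numbers in order of first appearance, without duplicates.
    seen: set[int] = set()
    order: list[int] = []
    for match in CITATION_RE.finditer(text):
        numbers = match.group(2) or match.group(4)  # exactly one alternative matched
        for part in numbers.split(","):
            num = int(part)  # int() tolerates the spaces around commas
            if num not in seen:
                seen.add(num)
                order.append(num)
    return order


print(extract_citation_order("See [3] and 【1】, then [[3]] and [2, 1]."))  # [3, 1, 2]
```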
Args: text: The text containing citations Returns: List of citation numbers in order of first appearance (no duplicates) """ # Same bracket variants as the patterns in collapse_citations and DynamicCitationProcessor, with extra groups capturing the numbers # Group 2 captures the number in double bracket format: [[1]], 【【1】】, ［［1］］ # Group 4 captures the numbers in single bracket format: [1], [1, 2] citation_pattern = re.compile( r"([\[【［]{2}(\d+)[\]】］]{2})|([\[【［]([\d]+(?: *, *\d+)*)[\]】］])" ) seen: set[int] = set() order: list[int] = [] for match in citation_pattern.finditer(text): # Group 2 is for double bracket single number, group 4 is for single bracket if match.group(2): nums_str = match.group(2) elif match.group(4): nums_str = match.group(4) else: continue for num_str in nums_str.split(","): num_str = num_str.strip() if num_str: try: num = int(num_str) if num not in seen: seen.add(num) order.append(num) except ValueError: continue return order def collapse_citations( answer_text: str, existing_citation_mapping: CitationMapping, new_citation_mapping: CitationMapping, ) -> tuple[str, CitationMapping]: """Collapse the citations in the text to use the smallest possible numbers. This function takes citations in the text (like [25], [30], etc.) and replaces them with the smallest possible numbers. It starts numbering from the next available integer after the existing citation mapping. If a citation refers to a document that already exists in the existing citation mapping (matched by document_id), it uses the existing citation number instead of assigning a new one. Args: answer_text: The text containing citations to collapse (e.g., "See [25] and [30]") existing_citation_mapping: Citations already processed/displayed. These mappings are preserved unchanged in the output. new_citation_mapping: Citations from the current text that need to be collapsed. The keys are the citation numbers as they appear in answer_text.
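A hedged sketch of the renumbering rule described above, assuming plain `dict[int, str]` mappings from citation number to document_id instead of SearchDoc objects; `collapse` and `CITATION_RE` are hypothetical names. A single `doc_to_num` lookup covers both cases the docstring mentions: documents already present in the existing mapping, and documents repeated under several old numbers.

```python
import re

# Hypothetical backslash-free bracket classes: ASCII, CJK 【】 and fullwidth ［］.
_OPEN = "[【［[]"
_CLOSE = "[]】］]"
CITATION_RE = re.compile(
    "(" + _OPEN + "{2}[0-9]+" + _CLOSE + "{2})"
    + "|(" + _OPEN + "[0-9]+(?:, ?[0-9]+)*" + _CLOSE + ")"
)


def collapse(
    text: str, existing: dict[int, str], new: dict[int, str]
) -> tuple[str, dict[int, str]]:
    # Mappings are citation number -> document_id (a stand-in for SearchDoc).
    doc_to_num = {doc_id: num for num, doc_id in existing.items()}
    next_num = max(existing, default=0) + 1
    old_to_new: dict[int, int] = {}
    combined = dict(existing)
    for old_num, doc_id in new.items():
        if doc_id not in doc_to_num:
            # First time this document is seen: give it the next free number.
            doc_to_num[doc_id] = next_num
            combined[next_num] = doc_id
            next_num += 1
        old_to_new[old_num] = doc_to_num[doc_id]

    def replace(match: re.Match) -> str:
        width = 2 if match.group(1) else 1  # double vs single brackets
        s = match.group()
        parts = []
        for raw in s[width:-width].split(","):
            num = int(raw)
            # Unmapped numbers are kept as written.
            parts.append(str(old_to_new[num]) if num in old_to_new else raw.strip())
        return s[:width] + ", ".join(parts) + s[-width:]

    return CITATION_RE.sub(replace, text), combined


print(collapse("See [25] and [[30]].", {1: "doc-a"}, {25: "doc-b", 30: "doc-a"}))
# ('See [2] and [[1]].', {1: 'doc-a', 2: 'doc-b'})
```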
Returns: A tuple of (updated_text, combined_mapping) where: - updated_text: The text with citations replaced with collapsed numbers - combined_mapping: All values from existing_citation_mapping plus the new mappings with their (possibly renumbered) keys """ # Build a reverse lookup: document_id -> existing citation number doc_id_to_existing_citation: dict[str, int] = { doc.document_id: citation_num for citation_num, doc in existing_citation_mapping.items() } # Determine the next available citation number if existing_citation_mapping: next_citation_num = max(existing_citation_mapping.keys()) + 1 else: next_citation_num = 1 # Build the mapping from old citation numbers (in new_citation_mapping) to new numbers old_to_new: dict[int, int] = {} additional_mappings: CitationMapping = {} for old_num, search_doc in new_citation_mapping.items(): doc_id = search_doc.document_id # Check if this document already exists in existing citations if doc_id in doc_id_to_existing_citation: # Use the existing citation number old_to_new[old_num] = doc_id_to_existing_citation[doc_id] else: # Check if we've already assigned a new number to this document # (handles case where same doc appears with different old numbers) existing_new_num = None for mapped_old, mapped_new in old_to_new.items(): if ( mapped_old in new_citation_mapping and new_citation_mapping[mapped_old].document_id == doc_id ): existing_new_num = mapped_new break if existing_new_num is not None: old_to_new[old_num] = existing_new_num else: # Assign the next available number old_to_new[old_num] = next_citation_num additional_mappings[next_citation_num] = search_doc next_citation_num += 1 # Pattern to match citations like [25], [1, 2, 3], [[25]], etc. 
# Also matches unicode bracket variants: 【】, ［］ citation_pattern = re.compile( r"([\[【［]{2}\d+[\]】］]{2})|([\[【［]\d+(?:, ?\d+)*[\]】］])" ) def replace_citation(match: re.Match) -> str: """Replace citation numbers in a match with their new collapsed values.""" citation_str = match.group() # Determine bracket style if ( citation_str.startswith("[[") or citation_str.startswith("【【") or citation_str.startswith("［［") ): open_bracket = citation_str[:2] close_bracket = citation_str[-2:] content = citation_str[2:-2] else: open_bracket = citation_str[0] close_bracket = citation_str[-1] content = citation_str[1:-1] # Parse and replace citation numbers new_nums = [] for num_str in content.split(","): num_str = num_str.strip() if not num_str: continue try: num = int(num_str) # Only replace if we have a mapping for this number if num in old_to_new: new_nums.append(str(old_to_new[num])) else: # Keep original if not in our mapping new_nums.append(num_str) except ValueError: new_nums.append(num_str) # Reconstruct the citation with original bracket style new_content = ", ".join(new_nums) return f"{open_bracket}{new_content}{close_bracket}" # Replace all citations in the text updated_text = citation_pattern.sub(replace_citation, answer_text) # Build the combined mapping combined_mapping: CitationMapping = dict(existing_citation_mapping) combined_mapping.update(additional_mappings) return updated_text, combined_mapping ================================================ FILE: backend/onyx/chat/compression.py ================================================ """ Chat history compression via summarization. This module handles compressing long chat histories by summarizing older messages while keeping recent messages verbatim. Summaries are branch-aware: each summary's parent_message_id points to the last message when compression triggered, making it part of the tree structure.
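A minimal sketch of the trigger check and the recent/older split, assuming a simplified `Msg` record in place of ChatMessage and a caller-supplied `trigger_ratio` (the actual value of COMPRESSION_TRIGGER_RATIO is not shown here). `Msg`, `recent_token_budget`, and `split_history` are hypothetical names. The sketch omits the empty-message filter and the existing-summary cutoff that the real code applies.

```python
from dataclasses import dataclass
from typing import Optional

RECENT_MESSAGES_RATIO = 0.2  # mirrors the module constant defined below


@dataclass
class Msg:
    # Hypothetical stand-in for ChatMessage: id, message type and token count only.
    id: int
    role: str  # "user" or "assistant"
    tokens: int


def recent_token_budget(
    max_input_tokens: int, history_tokens: int, reserved_tokens: int, trigger_ratio: float
) -> Optional[int]:
    # Return None when no compression is needed, else the budget for recent messages.
    available = max_input_tokens - reserved_tokens
    if history_tokens <= int(available * trigger_ratio):
        return None
    return int(history_tokens * RECENT_MESSAGES_RATIO)


def split_history(messages: list[Msg], budget: int) -> tuple[list[Msg], list[Msg]]:
    # Walk backwards, keeping messages until the budget would be exceeded
    # (the newest message is always kept, even if it alone exceeds the budget).
    recent: list[Msg] = []
    used = 0
    for msg in reversed(messages):
        if used + msg.tokens > budget and recent:
            break
        recent.insert(0, msg)
        used += msg.tokens
    # Start the verbatim window on a user message.
    while recent and recent[0].role != "user":
        recent.pop(0)
    recent_ids = {m.id for m in recent}
    older = [m for m in messages if m.id not in recent_ids]
    return older, recent
```

Because the recent budget is a fraction of the current history rather than of the context window, there is always something left to summarize whenever compression triggers.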
""" from typing import NamedTuple from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.configs.chat_configs import COMPRESSION_TRIGGER_RATIO from onyx.configs.constants import MessageType from onyx.db.models import ChatMessage from onyx.llm.interfaces import LLM from onyx.llm.models import AssistantMessage from onyx.llm.models import ChatCompletionMessage from onyx.llm.models import SystemMessage from onyx.llm.models import UserMessage from onyx.natural_language_processing.utils import get_tokenizer from onyx.prompts.compression_prompts import PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK from onyx.prompts.compression_prompts import PROGRESSIVE_USER_REMINDER from onyx.prompts.compression_prompts import SUMMARIZATION_CUTOFF_MARKER from onyx.prompts.compression_prompts import SUMMARIZATION_PROMPT from onyx.prompts.compression_prompts import USER_REMINDER from onyx.tracing.framework.create import ensure_trace from onyx.tracing.llm_utils import llm_generation_span from onyx.tracing.llm_utils import record_llm_response from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() # Ratio of available context to allocate for recent messages after compression RECENT_MESSAGES_RATIO = 0.2 class CompressionResult(BaseModel): """Result of a compression operation.""" summary_created: bool messages_summarized: int error: str | None = None class CompressionParams(BaseModel): """Parameters for compression operation.""" should_compress: bool tokens_for_recent: int = 0 class SummaryContent(NamedTuple): """Messages split for summarization.""" older_messages: list[ChatMessage] recent_messages: list[ChatMessage] def calculate_total_history_tokens(chat_history: list[ChatMessage]) -> int: """ Calculate the total token count for the given chat history. 
Args: chat_history: Branch-aware list of messages Returns: Total token count for the history """ return sum(m.token_count or 0 for m in chat_history) def get_compression_params( max_input_tokens: int, current_history_tokens: int, reserved_tokens: int, ) -> CompressionParams: """ Calculate compression parameters based on model's context window. Args: max_input_tokens: The maximum input tokens for the LLM current_history_tokens: Current total tokens in chat history reserved_tokens: Tokens reserved for system prompt, tools, files, etc. Returns: CompressionParams indicating whether to compress and token budgets """ available = max_input_tokens - reserved_tokens # Check trigger threshold trigger_threshold = int(available * COMPRESSION_TRIGGER_RATIO) if current_history_tokens <= trigger_threshold: return CompressionParams(should_compress=False) # Calculate token budget for recent messages as a percentage of current history # This ensures we always have messages to summarize when compression triggers tokens_for_recent = int(current_history_tokens * RECENT_MESSAGES_RATIO) return CompressionParams( should_compress=True, tokens_for_recent=tokens_for_recent, ) def find_summary_for_branch( db_session: Session, chat_history: list[ChatMessage], ) -> ChatMessage | None: """ Find the most recent summary that applies to the current branch. A summary applies if its parent_message_id is in the current chat history, meaning it was created on this branch. Args: db_session: Database session chat_history: Branch-aware list of messages Returns: The applicable summary message, or None if no summary exists for this branch """ if not chat_history: return None history_ids = {m.id for m in chat_history} chat_session_id = chat_history[0].chat_session_id # Query all summaries for this session (typically few), then filter in Python. # Order by time_sent descending to get the most recent summary first. 
summaries = ( db_session.query(ChatMessage) .filter( ChatMessage.chat_session_id == chat_session_id, ChatMessage.last_summarized_message_id.isnot(None), ) .order_by(ChatMessage.time_sent.desc()) .all() ) # Optimization to avoid using IN clause for large histories for summary in summaries: if summary.parent_message_id in history_ids: return summary return None def get_messages_to_summarize( chat_history: list[ChatMessage], existing_summary: ChatMessage | None, tokens_for_recent: int, ) -> SummaryContent: """ Split messages into those to summarize and those to keep verbatim. Args: chat_history: Branch-aware list of messages existing_summary: Existing summary for this branch (if any) tokens_for_recent: Token budget for recent messages to keep Returns: SummaryContent with older_messages to summarize and recent_messages to keep """ # Filter to messages after the existing summary's cutoff using timestamp if existing_summary and existing_summary.last_summarized_message_id: cutoff_id = existing_summary.last_summarized_message_id last_summarized_msg = next(m for m in chat_history if m.id == cutoff_id) messages = [ m for m in chat_history if m.time_sent > last_summarized_msg.time_sent ] else: messages = list(chat_history) # Filter out empty messages messages = [m for m in messages if m.message] if not messages: return SummaryContent(older_messages=[], recent_messages=[]) # Work backwards from most recent, keeping messages until we exceed budget recent_messages: list[ChatMessage] = [] tokens_used = 0 for msg in reversed(messages): msg_tokens = msg.token_count or 0 if tokens_used + msg_tokens > tokens_for_recent and recent_messages: break recent_messages.insert(0, msg) tokens_used += msg_tokens # Ensure cutoff is right before a user message by moving any leading # non-user messages from recent_messages to older_messages while recent_messages and recent_messages[0].message_type != MessageType.USER: recent_messages.pop(0) # Everything else gets summarized recent_ids = {m.id for 
m in recent_messages} older_messages = [m for m in messages if m.id not in recent_ids] return SummaryContent( older_messages=older_messages, recent_messages=recent_messages ) def _build_llm_messages_for_summarization( messages: list[ChatMessage], tool_id_to_name: dict[int, str], ) -> list[UserMessage | AssistantMessage]: """Convert ChatMessage objects to LLM message format for summarization. This is intentionally different from translate_history_to_llm_format in llm_step.py: - Compacts tool calls to "[Used tools: tool1, tool2]" to save tokens in summaries - Skips TOOL_CALL_RESPONSE messages entirely (tool usage captured in assistant message) - No image/multimodal handling (summaries are text-only) - No caching or LLMConfig-specific behavior needed """ result: list[UserMessage | AssistantMessage] = [] for msg in messages: # Skip empty messages if not msg.message: continue # Handle assistant messages with tool calls compactly if msg.message_type == MessageType.ASSISTANT: if msg.tool_calls: tool_names = [ tool_id_to_name.get(tc.tool_id, "unknown") for tc in msg.tool_calls ] result.append( AssistantMessage(content=f"[Used tools: {', '.join(tool_names)}]") ) else: result.append(AssistantMessage(content=msg.message)) continue # Skip tool call response messages - tool calls are captured above via assistant messages if msg.message_type == MessageType.TOOL_CALL_RESPONSE: continue # Handle user messages if msg.message_type == MessageType.USER: result.append(UserMessage(content=msg.message)) return result def generate_summary( older_messages: list[ChatMessage], recent_messages: list[ChatMessage], llm: LLM, tool_id_to_name: dict[int, str], existing_summary: str | None = None, ) -> str: """ Generate a summary using cutoff marker approach. The cutoff marker tells the LLM to summarize only older messages, while using recent messages as context to inform what's important. 
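A minimal sketch of the resulting message order, assuming plain role/content dicts instead of SystemMessage/UserMessage/AssistantMessage, and placeholder text for the cutoff marker; `build_summary_input` and `CUTOFF_MARKER` are hypothetical names, and the real prompt constants are not reproduced.

```python
from typing import Optional

# Hypothetical placeholder for SUMMARIZATION_CUTOFF_MARKER.
CUTOFF_MARKER = "--- cutoff: summarize only the messages above ---"


def build_summary_input(
    system_prompt: str,
    older: list[dict[str, str]],
    recent: list[dict[str, str]],
    reminder: str,
    previous_summary: Optional[str] = None,
) -> list[dict[str, str]]:
    # Progressive summarization: fold the earlier summary into the system prompt.
    system = system_prompt
    if previous_summary:
        system += " Previous summary: " + previous_summary
    messages = [{"role": "system", "content": system}]
    messages.extend(older)  # to be summarized
    messages.append({"role": "user", "content": CUTOFF_MARKER})
    messages.extend(recent)  # context only; these stay verbatim in the chat
    messages.append({"role": "user", "content": reminder})
    return messages
```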
Messages are sent as separate UserMessage/AssistantMessage objects rather than being concatenated into a single message. Args: older_messages: Messages to compress into summary (before cutoff) recent_messages: Messages kept verbatim (after cutoff, for context only) llm: LLM to use for summarization tool_id_to_name: Mapping of tool IDs to display names existing_summary: Previous summary text to incorporate (progressive) Returns: Summary text """ # Build system prompt system_content = SUMMARIZATION_PROMPT if existing_summary: # Progressive summarization: append existing summary to system prompt system_content += PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK.format( previous_summary=existing_summary ) final_reminder = PROGRESSIVE_USER_REMINDER else: final_reminder = USER_REMINDER # Convert messages to LLM format (using compression-specific conversion) older_llm_messages = _build_llm_messages_for_summarization( older_messages, tool_id_to_name ) recent_llm_messages = _build_llm_messages_for_summarization( recent_messages, tool_id_to_name ) # Build message list with separate messages input_messages: list[ChatCompletionMessage] = [ SystemMessage(content=system_content), ] # Add older messages (to be summarized) input_messages.extend(older_llm_messages) # Add cutoff marker as a user message input_messages.append(UserMessage(content=SUMMARIZATION_CUTOFF_MARKER)) # Add recent messages (for context only) input_messages.extend(recent_llm_messages) # Add final reminder input_messages.append(UserMessage(content=final_reminder)) with llm_generation_span( llm=llm, flow="chat_history_summarization", input_messages=input_messages, ) as span_generation: response = llm.invoke(input_messages) record_llm_response(span_generation, response) content = response.choice.message.content if not (content and content.strip()): raise ValueError("LLM returned empty summary") return content.strip() def compress_chat_history( db_session: Session, chat_history: list[ChatMessage], llm: LLM, 
compression_params: CompressionParams, tool_id_to_name: dict[int, str], ) -> CompressionResult: """ Main compression function. Creates a summary ChatMessage. The summary message's parent_message_id points to the last message in chat_history, making it branch-aware via the tree structure. Note: This takes the entire chat history as input, splits it into older messages (to summarize) and recent messages (kept verbatim within the token budget), generates a summary of the older part, and persists the new summary message with its parent set to the last message in history. Past summary is taken into context (progressive summarization): we find at most one existing summary for this branch. If present, only messages after that summary's last_summarized_message_id are considered; the existing summary text is passed into the LLM so the new summary incorporates it instead of summarizing from scratch. For more details, see the COMPRESSION.md file. Args: db_session: Database session chat_history: Branch-aware list of messages llm: LLM to use for summarization compression_params: Parameters from get_compression_params tool_id_to_name: Mapping of tool IDs to display names Returns: CompressionResult indicating success/failure """ if not chat_history: return CompressionResult(summary_created=False, messages_summarized=0) chat_session_id = chat_history[0].chat_session_id logger.info( f"Starting compression for session {chat_session_id}, " f"history_len={len(chat_history)}, tokens_for_recent={compression_params.tokens_for_recent}" ) with ensure_trace( "chat_history_compression", group_id=str(chat_session_id), metadata={ "tenant_id": get_current_tenant_id(), "chat_session_id": str(chat_session_id), }, ): try: # Find existing summary for this branch existing_summary = find_summary_for_branch(db_session, chat_history) # Get messages to summarize summary_content = get_messages_to_summarize( chat_history, existing_summary, tokens_for_recent=compression_params.tokens_for_recent, ) if not 
summary_content.older_messages: logger.debug("No messages to summarize, skipping compression") return CompressionResult(summary_created=False, messages_summarized=0) # Generate summary (incorporate existing summary if present) existing_summary_text = ( existing_summary.message if existing_summary else None ) summary_text = generate_summary( older_messages=summary_content.older_messages, recent_messages=summary_content.recent_messages, llm=llm, tool_id_to_name=tool_id_to_name, existing_summary=existing_summary_text, ) # Calculate token count for the summary tokenizer = get_tokenizer(None, None) summary_token_count = len(tokenizer.encode(summary_text)) logger.debug( f"Generated summary ({summary_token_count} tokens): {summary_text[:200]}..." ) # Create new summary as a ChatMessage # Parent is the last message in history - this makes the summary branch-aware summary_message = ChatMessage( chat_session_id=chat_session_id, message_type=MessageType.ASSISTANT, message=summary_text, token_count=summary_token_count, parent_message_id=chat_history[-1].id, last_summarized_message_id=summary_content.older_messages[-1].id, ) db_session.add(summary_message) db_session.commit() logger.info( f"Compressed {len(summary_content.older_messages)} messages into summary " f"(session_id={chat_session_id}, " f"summary_tokens={summary_token_count})" ) return CompressionResult( summary_created=True, messages_summarized=len(summary_content.older_messages), ) except Exception as e: logger.exception(f"Compression failed for session {chat_session_id}: {e}") db_session.rollback() return CompressionResult( summary_created=False, messages_summarized=0, error=str(e), ) ================================================ FILE: backend/onyx/chat/emitter.py ================================================ import threading from queue import Queue from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import Packet class Emitter: """Routes packets from 
LLM/tool execution to the ``_run_models`` drain loop. Tags every packet with ``model_index`` and places it on ``merged_queue`` as a ``(model_idx, packet)`` tuple for ordered consumption downstream. Args: merged_queue: Shared queue owned by ``_run_models``. model_idx: Index embedded in packet placements (``0`` for N=1 runs). drain_done: Optional event set by ``_run_models`` when the drain loop exits early (e.g. HTTP disconnect). When set, ``emit`` returns immediately so worker threads can exit fast. """ def __init__( self, merged_queue: Queue[tuple[int, Packet | Exception | object]], model_idx: int = 0, drain_done: threading.Event | None = None, ) -> None: self._model_idx = model_idx self._merged_queue = merged_queue self._drain_done = drain_done def emit(self, packet: Packet) -> None: if self._drain_done is not None and self._drain_done.is_set(): return base = packet.placement or Placement(turn_index=0) tagged = Packet( placement=base.model_copy(update={"model_index": self._model_idx}), obj=packet.obj, ) self._merged_queue.put((self._model_idx, tagged)) ================================================ FILE: backend/onyx/chat/llm_loop.py ================================================ import json import time from collections.abc import Callable from typing import Any from typing import Literal from sqlalchemy.orm import Session from onyx.chat.chat_state import ChatStateContainer from onyx.chat.chat_utils import create_tool_call_failure_messages from onyx.chat.citation_processor import CitationMapping from onyx.chat.citation_processor import CitationMode from onyx.chat.citation_processor import DynamicCitationProcessor from onyx.chat.citation_utils import update_citation_processor_from_tool_response from onyx.chat.emitter import Emitter from onyx.chat.llm_step import extract_tool_calls_from_response_text from onyx.chat.llm_step import run_llm_step from onyx.chat.models import ChatMessageSimple from onyx.chat.models import ContextFileMetadata from onyx.chat.models 
import ExtractedContextFiles from onyx.chat.models import FileToolMetadata from onyx.chat.models import LlmStepResult from onyx.chat.models import ToolCallSimple from onyx.chat.prompt_utils import build_reminder_message from onyx.chat.prompt_utils import build_system_prompt from onyx.chat.prompt_utils import ( get_default_base_system_prompt, ) from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.constants import DocumentSource from onyx.configs.constants import MessageType from onyx.context.search.models import SearchDoc from onyx.context.search.models import SearchDocsResponse from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.memory import add_memory from onyx.db.memory import update_memory_at_index from onyx.db.memory import UserMemoryContext from onyx.db.models import Persona from onyx.llm.constants import LlmProviderNames from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMUserIdentity from onyx.llm.interfaces import ToolChoiceOptions from onyx.llm.utils import is_true_openai_model from onyx.prompts.chat_prompts import IMAGE_GEN_REMINDER from onyx.prompts.chat_prompts import OPEN_URL_REMINDER from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import OverallStop from onyx.server.query_and_chat.streaming_models import Packet from onyx.server.query_and_chat.streaming_models import ToolCallDebug from onyx.server.query_and_chat.streaming_models import TopLevelBranching from onyx.tools.built_in_tools import CITEABLE_TOOLS_NAMES from onyx.tools.built_in_tools import STOPPING_TOOLS_NAMES from onyx.tools.interface import Tool from onyx.tools.models import ChatFile from onyx.tools.models import CustomToolCallSummary from onyx.tools.models import MemoryToolResponseSnapshot from onyx.tools.models import PythonToolRichResponse from onyx.tools.models import ToolCallInfo from onyx.tools.models import ToolCallKickoff from onyx.tools.models 
import ToolResponse from onyx.tools.tool_implementations.images.models import ( FinalImageGenerationResponse, ) from onyx.tools.tool_implementations.memory.models import MemoryToolResponse from onyx.tools.tool_implementations.python.python_tool import PythonTool from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool from onyx.tools.tool_runner import run_tool_calls from onyx.tracing.framework.create import trace from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() class EmptyLLMResponseError(RuntimeError): """Raised when the streamed LLM response completes without a usable answer.""" def __init__( self, *, provider: str, model: str, tool_choice: ToolChoiceOptions, client_error_msg: str, error_code: str = "EMPTY_LLM_RESPONSE", is_retryable: bool = True, ) -> None: super().__init__(client_error_msg) self.provider = provider self.model = model self.tool_choice = tool_choice self.client_error_msg = client_error_msg self.error_code = error_code self.is_retryable = is_retryable def _build_empty_llm_response_error( llm: LLM, llm_step_result: LlmStepResult, tool_choice: ToolChoiceOptions, ) -> EmptyLLMResponseError: provider = llm.config.model_provider model = llm.config.model_name # OpenAI quota exhaustion has reached us as a streamed "stop" with zero content. # When the stream is completely empty and there is no reasoning/tool output, surface # the likely account-level cause instead of a generic tool-calling error. 
if ( not llm_step_result.reasoning and provider == LlmProviderNames.OPENAI and is_true_openai_model(provider, model) ): return EmptyLLMResponseError( provider=provider, model=model, tool_choice=tool_choice, client_error_msg=( "The selected OpenAI model returned an empty streamed response " "before producing any tokens. This commonly happens when the API " "key or project has no remaining quota or billing is not enabled. " "Verify quota and billing for this key and try again." ), error_code="BUDGET_EXCEEDED", is_retryable=False, ) return EmptyLLMResponseError( provider=provider, model=model, tool_choice=tool_choice, client_error_msg=( "The selected model returned no final answer before the stream " "completed. No text or tool calls were received from the upstream " "provider." ), ) def _looks_like_xml_tool_call_payload(text: str | None) -> bool: """Detect XML-style marshaled tool calls emitted as plain text.""" if not text: return False lowered = text.lower() return ( " tuple[LlmStepResult, bool]: """Attempt to extract tool calls from response text as a fallback. This is a last resort fallback for low quality LLMs or those that don't have tool calling from the serving layer. Also triggers if there's reasoning but no answer and no tool calls. 
Args: llm_step_result: The result from the LLM step tool_choice: The tool choice option used for this step fallback_extraction_attempted: Whether fallback extraction was already attempted tool_defs: List of tool definitions turn_index: The current turn index for placement Returns: Tuple of (possibly updated LlmStepResult, whether fallback was attempted this call) """ if fallback_extraction_attempted: return llm_step_result, False no_tool_calls = ( not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0 ) reasoning_but_no_answer_or_tools = ( llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls ) xml_tool_call_text_detected = no_tool_calls and ( _looks_like_xml_tool_call_payload(llm_step_result.answer) or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer) or _looks_like_xml_tool_call_payload(llm_step_result.reasoning) ) should_try_fallback = ( (tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls) or reasoning_but_no_answer_or_tools or xml_tool_call_text_detected ) if not should_try_fallback: return llm_step_result, False # Try to extract from answer first, then fall back to reasoning extracted_tool_calls: list[ToolCallKickoff] = [] if llm_step_result.answer: extracted_tool_calls = extract_tool_calls_from_response_text( response_text=llm_step_result.answer, tool_definitions=tool_defs, placement=Placement(turn_index=turn_index), ) if ( not extracted_tool_calls and llm_step_result.raw_answer and llm_step_result.raw_answer != llm_step_result.answer ): extracted_tool_calls = extract_tool_calls_from_response_text( response_text=llm_step_result.raw_answer, tool_definitions=tool_defs, placement=Placement(turn_index=turn_index), ) if not extracted_tool_calls and llm_step_result.reasoning: extracted_tool_calls = extract_tool_calls_from_response_text( response_text=llm_step_result.reasoning, tool_definitions=tool_defs, placement=Placement(turn_index=turn_index), ) if extracted_tool_calls: logger.info( f"Extracted 
{len(extracted_tool_calls)} tool call(s) from response text as fallback" ) return ( LlmStepResult( reasoning=llm_step_result.reasoning, answer=llm_step_result.answer, tool_calls=extracted_tool_calls, raw_answer=llm_step_result.raw_answer, ), True, ) return llm_step_result, True # Hardcoded, opinionated value. A typical run might break down like this: # Cycle 1: Calls web_search for something # Cycle 2: Calls open_url for some results # Cycle 3: Calls web_search for some other aspect of the question # Cycle 4: Calls open_url for some results # Cycle 5: Maybe calls open_url for some additional results or because the last set failed # Cycle 6: No more tools available, forced to answer MAX_LLM_CYCLES = 6 def _build_context_file_citation_mapping( file_metadata: list[ContextFileMetadata], starting_citation_num: int = 1, ) -> CitationMapping: """Build citation mapping for context files. Converts context file metadata into SearchDoc objects that can be cited. Citation numbers start from the provided starting number. Args: file_metadata: List of context file metadata starting_citation_num: Starting citation number (default: 1) Returns: Dictionary mapping citation numbers to SearchDoc objects """ citation_mapping: CitationMapping = {} for idx, file_meta in enumerate(file_metadata, start=starting_citation_num): search_doc = SearchDoc( document_id=file_meta.file_id, chunk_ind=0, semantic_identifier=file_meta.filename, link=None, blurb=file_meta.file_content, source_type=DocumentSource.FILE, boost=1, hidden=False, metadata={}, score=0.0, match_highlights=[file_meta.file_content], ) citation_mapping[idx] = search_doc return citation_mapping def _build_project_message( context_files: ExtractedContextFiles | None, token_counter: Callable[[str], int] | None, ) -> list[ChatMessageSimple]: """Build messages for context-injected / tool-backed files. Returns up to two messages: 1. The full-text files message (if file_texts is populated). 2.
A lightweight metadata message for files the LLM should access via the FileReaderTool (e.g. oversized files that don't fit in context). """ if not context_files: return [] messages: list[ChatMessageSimple] = [] if context_files.file_texts: messages.append( _create_context_files_message(context_files, token_counter=None) ) if context_files.file_metadata_for_tool and token_counter: messages.append( _create_file_tool_metadata_message( context_files.file_metadata_for_tool, token_counter ) ) return messages def construct_message_history( system_prompt: ChatMessageSimple | None, custom_agent_prompt: ChatMessageSimple | None, simple_chat_history: list[ChatMessageSimple], reminder_message: ChatMessageSimple | None, context_files: ExtractedContextFiles | None, available_tokens: int, last_n_user_messages: int | None = None, token_counter: Callable[[str], int] | None = None, all_injected_file_metadata: dict[str, FileToolMetadata] | None = None, ) -> list[ChatMessageSimple]: if last_n_user_messages is not None: if last_n_user_messages <= 0: raise ValueError( "filtering chat history by last N user messages must be a value greater than 0" ) # Build the project / file-metadata messages up front so we can use their # actual token counts for the budget. 
project_messages = _build_project_message(context_files, token_counter) project_messages_tokens = sum(m.token_count for m in project_messages) history_token_budget = available_tokens history_token_budget -= system_prompt.token_count if system_prompt else 0 history_token_budget -= ( custom_agent_prompt.token_count if custom_agent_prompt else 0 ) history_token_budget -= project_messages_tokens history_token_budget -= reminder_message.token_count if reminder_message else 0 if history_token_budget < 0: raise ValueError("Not enough tokens available to construct message history") if system_prompt: system_prompt.should_cache = True # If no history, build minimal context if not simple_chat_history: result = [system_prompt] if system_prompt else [] if custom_agent_prompt: result.append(custom_agent_prompt) result.extend(project_messages) if reminder_message: result.append(reminder_message) return result # If last_n_user_messages is set, filter history to only include the last n user messages if last_n_user_messages is not None: # Find all user message indices user_msg_indices = [ i for i, msg in enumerate(simple_chat_history) if msg.message_type == MessageType.USER ] if not user_msg_indices: raise ValueError("No user message found in simple_chat_history") # If we have more than n user messages, keep only the last n if len(user_msg_indices) > last_n_user_messages: # Find the index of the n-th user message from the end # For example, if last_n_user_messages=2, we want the 2nd-to-last user message nth_user_msg_index = user_msg_indices[-(last_n_user_messages)] # Keep everything from that user message onwards simple_chat_history = simple_chat_history[nth_user_msg_index:] # Find the last USER message in the history # The history may contain tool calls and responses after the last user message last_user_msg_index = None for i in range(len(simple_chat_history) - 1, -1, -1): if simple_chat_history[i].message_type == MessageType.USER: last_user_msg_index = i break if 
last_user_msg_index is None: raise ValueError("No user message found in simple_chat_history") # Split history into three parts: # 1. History before the last user message # 2. The last user message # 3. Messages after the last user message (tool calls, responses, etc.) history_before_last_user = simple_chat_history[:last_user_msg_index] last_user_message = simple_chat_history[last_user_msg_index] messages_after_last_user = simple_chat_history[last_user_msg_index + 1 :] # Calculate tokens needed for the last user message and everything after it last_user_tokens = last_user_message.token_count after_user_tokens = sum(msg.token_count for msg in messages_after_last_user) # Check if we can fit at least the last user message and messages after it required_tokens = last_user_tokens + after_user_tokens if required_tokens > history_token_budget: raise ValueError( f"Not enough tokens to include the last user message and subsequent messages. " f"Required: {required_tokens}, Available: {history_token_budget}" ) # Calculate remaining budget for history before the last user message remaining_budget = history_token_budget - required_tokens # Truncate history_before_last_user from the top to fit in remaining budget. # Track dropped file messages so we can provide their metadata to the # FileReaderTool instead. truncated_history_before: list[ChatMessageSimple] = [] dropped_file_ids: list[str] = [] current_token_count = 0 for msg in reversed(history_before_last_user): if current_token_count + msg.token_count <= remaining_budget: msg.should_cache = True truncated_history_before.insert(0, msg) current_token_count += msg.token_count else: # Can't fit this message, stop truncating. # This message and everything older is dropped. break # Collect file_ids from ALL dropped messages (those not in # truncated_history_before). The truncation loop above keeps the most # recent messages, so the dropped ones are at the start of the original # list up to (len(history) - len(kept)). 
num_kept = len(truncated_history_before) for msg in history_before_last_user[: len(history_before_last_user) - num_kept]: if msg.file_id is not None: dropped_file_ids.append(msg.file_id) # Also treat "orphaned" metadata entries as dropped -- these are files # from messages removed by summary truncation (before convert_chat_history # ran), so no ChatMessageSimple was ever tagged with their file_id. if all_injected_file_metadata: surviving_file_ids = { msg.file_id for msg in simple_chat_history if msg.file_id is not None } for fid in all_injected_file_metadata: if fid not in surviving_file_ids and fid not in dropped_file_ids: dropped_file_ids.append(fid) # Build a forgotten-files metadata message if any file messages were # dropped AND we have metadata for them (meaning the FileReaderTool is # available). Reserve tokens for this message in the budget. forgotten_files_message: ChatMessageSimple | None = None if dropped_file_ids and all_injected_file_metadata and token_counter: forgotten_meta = [ all_injected_file_metadata[fid] for fid in dropped_file_ids if fid in all_injected_file_metadata ] if forgotten_meta: logger.debug( f"FileReader: building forgotten-files message for {[(m.file_id, m.filename) for m in forgotten_meta]}" ) forgotten_files_message = _create_file_tool_metadata_message( forgotten_meta, token_counter ) # Shrink the remaining budget. If the metadata message doesn't # fit we may need to drop more history messages. remaining_budget -= forgotten_files_message.token_count while truncated_history_before and current_token_count > remaining_budget: evicted = truncated_history_before.pop(0) current_token_count -= evicted.token_count # If the evicted message is itself a file, add it to the # forgotten metadata (it's now dropped too). 
if ( evicted.file_id is not None and evicted.file_id in all_injected_file_metadata and evicted.file_id not in {m.file_id for m in forgotten_meta} ): forgotten_meta.append(all_injected_file_metadata[evicted.file_id]) # Rebuild the message with the new entry forgotten_files_message = _create_file_tool_metadata_message( forgotten_meta, token_counter ) # Attach project images to the last user message if context_files and context_files.image_files: existing_images = last_user_message.image_files or [] last_user_message = ChatMessageSimple( message=last_user_message.message, token_count=last_user_message.token_count, message_type=last_user_message.message_type, image_files=existing_images + context_files.image_files, ) # Build the final message list according to README ordering: # [system], [history_before_last_user], [custom_agent], [context_files], # [forgotten_files], [last_user_message], [messages_after_last_user], [reminder] result = [system_prompt] if system_prompt else [] # 1. Add truncated history before last user message result.extend(truncated_history_before) # 2. Add custom agent prompt (inserted before last user message) if custom_agent_prompt: result.append(custom_agent_prompt) # 3. Add context files / file-metadata messages (inserted before last user message) result.extend(project_messages) # 4. Add forgotten-files metadata (right before the user's question) if forgotten_files_message: result.append(forgotten_files_message) # 5. Add last user message (with context images attached) result.append(last_user_message) # 6. Add messages after last user message (tool calls, responses, etc.) result.extend(messages_after_last_user) # 7. Add reminder message at the very end if reminder_message: result.append(reminder_message) return _drop_orphaned_tool_call_responses(result) def _drop_orphaned_tool_call_responses( messages: list[ChatMessageSimple], ) -> list[ChatMessageSimple]: """Drop tool response messages whose tool_call_id is not in prior assistant tool calls. 
This can happen when history truncation drops an ASSISTANT tool-call message but leaves a later TOOL_CALL_RESPONSE message in context. Some providers (e.g. Ollama) reject such history with an "unexpected tool call id" error. """ known_tool_call_ids: set[str] = set() sanitized: list[ChatMessageSimple] = [] for msg in messages: if msg.message_type == MessageType.ASSISTANT and msg.tool_calls: for tool_call in msg.tool_calls: known_tool_call_ids.add(tool_call.tool_call_id) sanitized.append(msg) continue if msg.message_type == MessageType.TOOL_CALL_RESPONSE: if msg.tool_call_id and msg.tool_call_id in known_tool_call_ids: sanitized.append(msg) else: logger.debug( "Dropping orphaned tool response with tool_call_id=%s while constructing message history", msg.tool_call_id, ) continue sanitized.append(msg) return sanitized def _create_file_tool_metadata_message( file_metadata: list[FileToolMetadata], token_counter: Callable[[str], int], ) -> ChatMessageSimple: """Build a lightweight metadata-only message listing files available via FileReaderTool. Used when files are too large to fit in context and the vector DB is disabled, so the LLM must use ``read_file`` to inspect them. """ lines = [ "You have access to the following files. Use the read_file tool to " "read sections of any file. You MUST pass the file_id UUID (not the " "filename) to read_file:" ] for meta in file_metadata: lines.append( f'- file_id="{meta.file_id}" filename="{meta.filename}" (~{meta.approx_char_count:,} chars)' ) message_content = "\n".join(lines) return ChatMessageSimple( message=message_content, token_count=token_counter(message_content), message_type=MessageType.USER, ) def _create_context_files_message( context_files: ExtractedContextFiles, token_counter: Callable[[str], int] | None, # noqa: ARG001 ) -> ChatMessageSimple: """Convert context files to a ChatMessageSimple message. Format follows the README specification for document representation. 
""" import json # Format as documents JSON as described in README documents_list = [] for idx, file_text in enumerate(context_files.file_texts, start=1): title = ( context_files.file_metadata[idx - 1].filename if idx - 1 < len(context_files.file_metadata) else None ) entry: dict[str, Any] = {"document": idx} if title: entry["title"] = title entry["contents"] = file_text documents_list.append(entry) documents_json = json.dumps({"documents": documents_list}, indent=2) message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}" # Use pre-calculated token count from context_files return ChatMessageSimple( message=message_content, token_count=context_files.total_token_count, message_type=MessageType.USER, ) def run_llm_loop( emitter: Emitter, state_container: ChatStateContainer, simple_chat_history: list[ChatMessageSimple], tools: list[Tool], custom_agent_prompt: str | None, context_files: ExtractedContextFiles, persona: Persona | None, user_memory_context: UserMemoryContext | None, llm: LLM, token_counter: Callable[[str], int], db_session: Session, forced_tool_id: int | None = None, user_identity: LLMUserIdentity | None = None, chat_session_id: str | None = None, chat_files: list[ChatFile] | None = None, include_citations: bool = True, all_injected_file_metadata: dict[str, FileToolMetadata] | None = None, inject_memories_in_prompt: bool = True, ) -> None: with trace( "run_llm_loop", group_id=chat_session_id, metadata={ "tenant_id": get_current_tenant_id(), "chat_session_id": chat_session_id, }, ): # Fix some LiteLLM issues. from onyx.llm.litellm_singleton.config import ( initialize_litellm, ) # Imported here so that LiteLLM is lazy-loaded initialize_litellm() # Track when the loop starts for calculating time-to-answer loop_start_time = time.monotonic() # Initialize citation processor for handling citations dynamically # When include_citations is True, use HYPERLINK mode to format citations as [[1]](url) # When include_citations is
False, use REMOVE mode to strip citations from output citation_processor = DynamicCitationProcessor( citation_mode=( CitationMode.HYPERLINK if include_citations else CitationMode.REMOVE ) ) # Add project file citation mappings if project files are present project_citation_mapping: CitationMapping = {} if context_files.file_metadata: project_citation_mapping = _build_context_file_citation_mapping( context_files.file_metadata ) citation_processor.update_citation_mapping(project_citation_mapping) llm_step_result = LlmStepResult( reasoning=None, answer=None, tool_calls=None, raw_answer=None, ) # Pass the total budget to construct_message_history, which will handle token allocation available_tokens = llm.config.max_input_tokens tool_choice: ToolChoiceOptions = ToolChoiceOptions.AUTO # Initialize gathered_documents with project files if present gathered_documents: list[SearchDoc] | None = ( list(project_citation_mapping.values()) if project_citation_mapping else None ) # TODO allow citing of images in Projects. Since attached to the last user message, it has no text associated with it. # One future workaround is to include the images as separate user messages with citation information and process those. always_cite_documents: bool = bool( context_files.use_as_search_filter or context_files.file_texts ) should_cite_documents: bool = False ran_image_gen: bool = False just_ran_web_search: bool = False has_called_search_tool: bool = False code_interpreter_file_generated: bool = False fallback_extraction_attempted: bool = False citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL # Fetch this in a short-lived session so the long-running stream loop does # not pin a connection just to keep read state alive. 
with get_session_with_current_tenant() as prompt_db_session: default_base_system_prompt: str = get_default_base_system_prompt( prompt_db_session ) system_prompt = None custom_agent_prompt_msg = None reasoning_cycles = 0 for llm_cycle_count in range(MAX_LLM_CYCLES): # Handling tool calls based on cycle count and past cycle conditions out_of_cycles = llm_cycle_count == MAX_LLM_CYCLES - 1 if forced_tool_id: # Needs to be just the single one because the "required" currently doesn't have a specified tool, just a binary final_tools = [tool for tool in tools if tool.id == forced_tool_id] if not final_tools: raise ValueError(f"Tool {forced_tool_id} not found in tools") tool_choice = ToolChoiceOptions.REQUIRED forced_tool_id = None elif out_of_cycles or ran_image_gen: # Last cycle, no tools allowed, just answer! tool_choice = ToolChoiceOptions.NONE final_tools = [] else: tool_choice = ToolChoiceOptions.AUTO final_tools = tools # Handling the system prompt and custom agent prompt # The section below calculates the available tokens for history a bit more accurately # now that project files are loaded in. 
if persona and persona.replace_base_system_prompt: # Handles the case where user has checked off the "Replace base system prompt" checkbox system_prompt = ( ChatMessageSimple( message=persona.system_prompt, token_count=token_counter(persona.system_prompt), message_type=MessageType.SYSTEM, ) if persona.system_prompt else None ) custom_agent_prompt_msg = None else: # If it's an empty string, we assume the user does not want to include it as an empty System message if default_base_system_prompt: prompt_memory_context = ( user_memory_context if inject_memories_in_prompt else ( user_memory_context.without_memories() if user_memory_context else None ) ) system_prompt_str = build_system_prompt( base_system_prompt=default_base_system_prompt, datetime_aware=persona.datetime_aware if persona else True, user_memory_context=prompt_memory_context, tools=tools, should_cite_documents=should_cite_documents or always_cite_documents, ) system_prompt = ChatMessageSimple( message=system_prompt_str, token_count=token_counter(system_prompt_str), message_type=MessageType.SYSTEM, ) custom_agent_prompt_msg = ( ChatMessageSimple( message=custom_agent_prompt, token_count=token_counter(custom_agent_prompt), message_type=MessageType.USER, ) if custom_agent_prompt else None ) else: # If there is a custom agent prompt, it replaces the system prompt when the default system prompt is empty system_prompt = ( ChatMessageSimple( message=custom_agent_prompt, token_count=token_counter(custom_agent_prompt), message_type=MessageType.SYSTEM, ) if custom_agent_prompt else None ) custom_agent_prompt_msg = None reminder_message_text: str | None if ran_image_gen: # Some models are trained to give back images to the user for some similar tool # This is to prevent it generating things like: # [Cute Cat](attachment://a_cute_cat_sitting_playfully.png) reminder_message_text = IMAGE_GEN_REMINDER elif just_ran_web_search and not out_of_cycles: reminder_message_text = OPEN_URL_REMINDER else: # This is the default 
case, the LLM at this point may answer so it is important # to include the reminder. Potentially this should also mention citation reminder_message_text = build_reminder_message( reminder_text=( persona.task_prompt if persona and persona.task_prompt else None ), include_citation_reminder=should_cite_documents or always_cite_documents, include_file_reminder=code_interpreter_file_generated, is_last_cycle=out_of_cycles, ) reminder_msg = ( ChatMessageSimple( message=reminder_message_text, token_count=token_counter(reminder_message_text), message_type=MessageType.USER_REMINDER, ) if reminder_message_text else None ) truncated_message_history = construct_message_history( system_prompt=system_prompt, custom_agent_prompt=custom_agent_prompt_msg, simple_chat_history=simple_chat_history, reminder_message=reminder_msg, context_files=context_files, available_tokens=available_tokens, token_counter=token_counter, all_injected_file_metadata=all_injected_file_metadata, ) # This calls the LLM, yields packets (reasoning, answers, etc.) and returns the result # It also pre-processes the tool calls in preparation for running them tool_defs = [tool.tool_definition() for tool in final_tools] # Calculate total processing time from loop start until now # This measures how long the user waits before the answer starts streaming pre_answer_processing_time = time.monotonic() - loop_start_time llm_step_result, has_reasoned = run_llm_step( emitter=emitter, history=truncated_message_history, tool_definitions=tool_defs, tool_choice=tool_choice, llm=llm, placement=Placement(turn_index=llm_cycle_count + reasoning_cycles), citation_processor=citation_processor, state_container=state_container, # The rich docs representation is passed in so that when yielding the answer, it can also # immediately yield the full set of found documents. This gives us the option to show the # final set of documents immediately if desired. 
final_documents=gathered_documents, user_identity=user_identity, pre_answer_processing_time=pre_answer_processing_time, ) if has_reasoned: reasoning_cycles += 1 # Fallback extraction for LLMs that don't support tool calling natively or are lower quality # and might incorrectly output tool calls in other channels llm_step_result, attempted = _try_fallback_tool_extraction( llm_step_result=llm_step_result, tool_choice=tool_choice, fallback_extraction_attempted=fallback_extraction_attempted, tool_defs=tool_defs, turn_index=llm_cycle_count + reasoning_cycles, ) if attempted: # To prevent excessive looping with low-quality models, only one fallback attempt is allowed fallback_extraction_attempted = True # Save citation mapping after each LLM step for incremental state updates state_container.set_citation_mapping(citation_processor.citation_to_doc) # Run the tools selected by the LLM. This is more than a simple execution, # since each tool may have custom logic tool_responses: list[ToolResponse] = [] tool_calls = llm_step_result.tool_calls or [] if INTEGRATION_TESTS_MODE and tool_calls: for tool_call in tool_calls: emitter.emit( Packet( placement=tool_call.placement, obj=ToolCallDebug( tool_call_id=tool_call.tool_call_id, tool_name=tool_call.tool_name, tool_args=tool_call.tool_args, ), ) ) if len(tool_calls) > 1: emitter.emit( Packet( placement=Placement( turn_index=tool_calls[0].placement.turn_index ), obj=TopLevelBranching(num_parallel_branches=len(tool_calls)), ) ) # Quick note on why both citation_mapping and citation_processor are needed: # 1. Tools return lightweight string mappings, not SearchDoc objects # 2. The SearchDoc resolution is deliberately deferred to llm_loop.py # 3.
The citation_processor operates on SearchDoc objects and can't provide a complete reverse URL lookup for # in-flight citations # This could be cleaned up, but doing so is not trivial or worthwhile right now just_ran_web_search = False parallel_tool_call_results = run_tool_calls( tool_calls=tool_calls, tools=final_tools, message_history=truncated_message_history, user_memory_context=user_memory_context, user_info=None, # TODO: this is part of memories right now; consider separating it out citation_mapping=citation_mapping, next_citation_num=citation_processor.get_next_citation_number(), max_concurrent_tools=None, skip_search_query_expansion=has_called_search_tool, chat_files=chat_files, url_snippet_map=extract_url_snippet_map(gathered_documents or []), inject_memories_in_prompt=inject_memories_in_prompt, ) tool_responses = parallel_tool_call_results.tool_responses citation_mapping = parallel_tool_call_results.updated_citation_mapping # Failure case: give the LLM something reasonable so it can try again if tool_calls and not tool_responses: failure_messages = create_tool_call_failure_messages( tool_calls, token_counter ) simple_chat_history.extend(failure_messages) continue for tool_response in tool_responses: # Extract tool_call from the response (set by run_tool_calls) if tool_response.tool_call is None: raise ValueError("Tool response missing tool_call reference") tool_call = tool_response.tool_call tab_index = tool_call.placement.tab_index # Track if search tool was called (for skipping query expansion on subsequent calls) if tool_call.tool_name == SearchTool.NAME: has_called_search_tool = True # Track if code interpreter generated files with download links if ( tool_call.tool_name == PythonTool.NAME and not code_interpreter_file_generated ): try: parsed = json.loads(tool_response.llm_facing_response) if parsed.get("generated_files"): code_interpreter_file_generated = True except (json.JSONDecodeError, AttributeError): pass # Build a mapping of tool names to tool objects for
getting tool_id tools_by_name = {tool.name: tool for tool in final_tools} # Add the results to the chat history. Even though tools may run in parallel, # LLM APIs require linear history, so results are added sequentially. # Get the tool object to retrieve tool_id tool = tools_by_name.get(tool_call.tool_name) if not tool: raise ValueError( f"Tool '{tool_call.tool_name}' not found in tools list" ) # Extract search_docs if this is a search tool response search_docs = None displayed_docs = None if isinstance(tool_response.rich_response, SearchDocsResponse): search_docs = tool_response.rich_response.search_docs displayed_docs = tool_response.rich_response.displayed_docs # Add ALL search docs to state container for DB persistence if search_docs: state_container.add_search_docs(search_docs) if gathered_documents: gathered_documents.extend(search_docs) else: gathered_documents = search_docs # This is used for the Open URL reminder in the next cycle # only do this if the web search tool yielded results if search_docs and tool_call.tool_name == WebSearchTool.NAME: just_ran_web_search = True # Extract generated_images if this is an image generation tool response generated_images = None if isinstance( tool_response.rich_response, FinalImageGenerationResponse ): generated_images = tool_response.rich_response.generated_images # Extract generated_files if this is a code interpreter response generated_files = None if isinstance(tool_response.rich_response, PythonToolRichResponse): generated_files = ( tool_response.rich_response.generated_files or None ) # Persist memory if this is a memory tool response memory_snapshot: MemoryToolResponseSnapshot | None = None if isinstance(tool_response.rich_response, MemoryToolResponse): persisted_memory_id: int | None = None if user_memory_context and user_memory_context.user_id: if tool_response.rich_response.index_to_replace is not None: memory = update_memory_at_index( user_id=user_memory_context.user_id, 
index=tool_response.rich_response.index_to_replace, new_text=tool_response.rich_response.memory_text, db_session=db_session, ) persisted_memory_id = memory.id if memory else None else: memory = add_memory( user_id=user_memory_context.user_id, memory_text=tool_response.rich_response.memory_text, db_session=db_session, ) persisted_memory_id = memory.id operation: Literal["add", "update"] = ( "update" if tool_response.rich_response.index_to_replace is not None else "add" ) memory_snapshot = MemoryToolResponseSnapshot( memory_text=tool_response.rich_response.memory_text, operation=operation, memory_id=persisted_memory_id, index=tool_response.rich_response.index_to_replace, ) if memory_snapshot: saved_response = json.dumps(memory_snapshot.model_dump()) elif isinstance(tool_response.rich_response, CustomToolCallSummary): saved_response = json.dumps( tool_response.rich_response.model_dump() ) elif isinstance(tool_response.rich_response, str): saved_response = tool_response.rich_response else: saved_response = tool_response.llm_facing_response tool_call_info = ToolCallInfo( parent_tool_call_id=None, # Top-level tool calls are attached to the chat message turn_index=llm_cycle_count + reasoning_cycles, tab_index=tab_index, tool_name=tool_call.tool_name, tool_call_id=tool_call.tool_call_id, tool_id=tool.id, reasoning_tokens=llm_step_result.reasoning, # All tool calls from this loop share the same reasoning tool_call_arguments=tool_call.tool_args, tool_call_response=saved_response, search_docs=displayed_docs or search_docs, generated_images=generated_images, generated_files=generated_files, ) # Add to state container for partial save support state_container.add_tool_call(tool_call_info) # Update citation processor if this was a search tool update_citation_processor_from_tool_response( tool_response, citation_processor ) # After processing all tool responses for this turn, add messages to history # using OpenAI parallel tool calling format: # 1. 
ONE ASSISTANT message with tool_calls array # 2. N TOOL_CALL_RESPONSE messages (one per tool call) if tool_responses: # Filter to only responses with valid tool_call references valid_tool_responses = [ tr for tr in tool_responses if tr.tool_call is not None ] # Build ToolCallSimple list for all tool calls in this turn tool_calls_simple: list[ToolCallSimple] = [] for tool_response in valid_tool_responses: tc = tool_response.tool_call assert ( tc is not None ) # Already filtered above; this is only for type checking tool_call_message = tc.to_msg_str() tool_call_token_count = token_counter(tool_call_message) tool_calls_simple.append( ToolCallSimple( tool_call_id=tc.tool_call_id, tool_name=tc.tool_name, tool_arguments=tc.tool_args, token_count=tool_call_token_count, ) ) # Create ONE ASSISTANT message with all tool calls for this turn total_tool_call_tokens = sum(tc.token_count for tc in tool_calls_simple) assistant_with_tools = ChatMessageSimple( message="", # No text content when making tool calls token_count=total_tool_call_tokens, message_type=MessageType.ASSISTANT, tool_calls=tool_calls_simple, image_files=None, ) simple_chat_history.append(assistant_with_tools) # Add TOOL_CALL_RESPONSE messages for each tool call for tool_response in valid_tool_responses: tc = tool_response.tool_call assert tc is not None # Already filtered above tool_response_message = tool_response.llm_facing_response tool_response_token_count = token_counter(tool_response_message) tool_response_msg = ChatMessageSimple( message=tool_response_message, token_count=tool_response_token_count, message_type=MessageType.TOOL_CALL_RESPONSE, tool_call_id=tc.tool_call_id, image_files=None, ) simple_chat_history.append(tool_response_msg) # No tool calls means the LLM has answered, so wrap up if not llm_step_result.tool_calls or len(llm_step_result.tool_calls) == 0: break # Certain tools do not allow further actions, so force the LLM to wrap up on the next cycle if any( tool.tool_name in STOPPING_TOOLS_NAMES
for tool in llm_step_result.tool_calls ): ran_image_gen = True if llm_step_result.tool_calls and any( tool.tool_name in CITEABLE_TOOLS_NAMES for tool in llm_step_result.tool_calls ): # As long as 1 tool with citeable documents is called at any point, we ask the LLM to try to cite should_cite_documents = True if not llm_step_result.answer and not llm_step_result.tool_calls: raise _build_empty_llm_response_error( llm=llm, llm_step_result=llm_step_result, tool_choice=tool_choice, ) if not llm_step_result.answer: raise RuntimeError( "The LLM did not return a final answer after tool execution. " "Typically this indicates invalid tool-call output, a model/provider mismatch, " "or serving API misconfiguration." ) emitter.emit( Packet( placement=Placement(turn_index=llm_cycle_count + reasoning_cycles), obj=OverallStop(type="stop"), ) ) ================================================ FILE: backend/onyx/chat/llm_step.py ================================================ import json import re import time import uuid from collections.abc import Callable from collections.abc import Generator from collections.abc import Mapping from collections.abc import Sequence from html import unescape from typing import Any from typing import cast from onyx.chat.chat_state import ChatStateContainer from onyx.chat.citation_processor import DynamicCitationProcessor from onyx.chat.emitter import Emitter from onyx.chat.models import ChatMessageSimple from onyx.chat.models import LlmStepResult from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY from onyx.configs.constants import MessageType from onyx.context.search.models import SearchDoc from onyx.file_store.models import ChatFileType from onyx.llm.constants import LlmProviderNames from onyx.llm.interfaces import LanguageModelInput from onyx.llm.interfaces import LLM from onyx.llm.interfaces import 
LLMConfig from onyx.llm.interfaces import LLMUserIdentity from onyx.llm.interfaces import ToolChoiceOptions from onyx.llm.model_response import Delta from onyx.llm.models import AssistantMessage from onyx.llm.models import ChatCompletionMessage from onyx.llm.models import FunctionCall from onyx.llm.models import ImageContentPart from onyx.llm.models import ImageUrlDetail from onyx.llm.models import ReasoningEffort from onyx.llm.models import SystemMessage from onyx.llm.models import TextContentPart from onyx.llm.models import ToolCall from onyx.llm.models import ToolMessage from onyx.llm.models import UserMessage from onyx.llm.prompt_cache.processor import process_with_prompt_cache from onyx.llm.utils import model_needs_formatting_reenabled from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN from onyx.prompts.constants import SYSTEM_REMINDER_TAG_CLOSE from onyx.prompts.constants import SYSTEM_REMINDER_TAG_OPEN from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import AgentResponseDelta from onyx.server.query_and_chat.streaming_models import AgentResponseStart from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.server.query_and_chat.streaming_models import Packet from onyx.server.query_and_chat.streaming_models import ReasoningDelta from onyx.server.query_and_chat.streaming_models import ReasoningDone from onyx.server.query_and_chat.streaming_models import ReasoningStart from onyx.tools.models import ToolCallKickoff from onyx.tracing.framework.create import generation_span from onyx.utils.b64 import get_image_type_from_bytes from onyx.utils.jsonriver import Parser from onyx.utils.logger import setup_logger from onyx.utils.postgres_sanitization import sanitize_string from onyx.utils.text_processing import find_all_json_objects logger = setup_logger() _XML_INVOKE_BLOCK_RE = re.compile( r"[^>]*)>(?P.*?)", re.IGNORECASE | re.DOTALL, ) _XML_PARAMETER_RE = re.compile( 
r"[^>]*)>(?P.*?)", re.IGNORECASE | re.DOTALL, ) _FUNCTION_CALLS_OPEN_MARKER = " None: self._pending = "" self._inside_function_calls_block = False def process(self, content: str) -> str: if not content: return "" self._pending += content output_parts: list[str] = [] while self._pending: pending_lower = self._pending.lower() if self._inside_function_calls_block: end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER) if end_idx == -1: # Keep buffering until we see the close marker. return "".join(output_parts) # Drop the whole function_calls block. self._pending = self._pending[ end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) : ] self._inside_function_calls_block = False continue start_idx = _find_function_calls_open_marker(pending_lower) if start_idx == -1: # Keep only a possible prefix of " 0: output_parts.append(self._pending[:emit_upto]) self._pending = self._pending[emit_upto:] return "".join(output_parts) if start_idx > 0: output_parts.append(self._pending[:start_idx]) # Enter block-stripping mode and keep scanning for close marker. self._pending = self._pending[start_idx:] self._inside_function_calls_block = True return "".join(output_parts) def flush(self) -> str: if self._inside_function_calls_block: # Drop any incomplete block at stream end. self._pending = "" self._inside_function_calls_block = False return "" remaining = self._pending self._pending = "" return remaining def _matching_open_marker_prefix_len(text: str) -> int: """Return longest suffix of text that matches prefix of " bool: return char is None or char in {">", " ", "\t", "\n", "\r"} def _find_function_calls_open_marker(text_lower: str) -> int: """Find ' Any: """Attempt to parse a JSON string value into its Python equivalent. If value is a string that looks like a JSON array or object, parse it. Otherwise return the value unchanged. 
This handles the case where the LLM returns arguments like: - queries: '["query1", "query2"]' instead of ["query1", "query2"] """ if not isinstance(value, str): return value stripped = value.strip() # Only attempt to parse if it looks like a JSON array or object if not ( (stripped.startswith("[") and stripped.endswith("]")) or (stripped.startswith("{") and stripped.endswith("}")) ): return value try: return json.loads(stripped) except json.JSONDecodeError: return value def _parse_tool_args_to_dict(raw_args: Any) -> dict[str, Any]: """Parse tool arguments into a dict. Normal case: - raw_args == '{"queries":[...]}' -> dict via json.loads Defensive case (JSON string literal of an object): - raw_args == '"{\\"queries\\":[...]}"' -> json.loads -> str -> json.loads -> dict Also handles the case where argument values are JSON strings that need parsing: - {"queries": '["q1", "q2"]'} -> {"queries": ["q1", "q2"]} Anything else returns {}. """ if raw_args is None: return {} if isinstance(raw_args, dict): # Parse any string values that look like JSON arrays/objects return { k: _try_parse_json_string(sanitize_string(v) if isinstance(v, str) else v) for k, v in raw_args.items() } if not isinstance(raw_args, str): return {} # Sanitize before parsing to remove NULL bytes and surrogates raw_args = sanitize_string(raw_args) try: parsed1: Any = json.loads(raw_args) except json.JSONDecodeError: return {} if isinstance(parsed1, dict): # Parse any string values that look like JSON arrays/objects return {k: _try_parse_json_string(v) for k, v in parsed1.items()} if isinstance(parsed1, str): try: parsed2: Any = json.loads(parsed1) except json.JSONDecodeError: return {} if isinstance(parsed2, dict): # Parse any string values that look like JSON arrays/objects return {k: _try_parse_json_string(v) for k, v in parsed2.items()} return {} return {} def _format_message_history_for_logging( message_history: LanguageModelInput, ) -> str: """Format message history for logging, with special handling 
for tool calls. Tool calls are formatted as JSON with 4-space indentation for readability. """ formatted_lines = [] separator = "================================================" # Handle single ChatCompletionMessage - wrap in list for uniform processing if isinstance( message_history, (SystemMessage, UserMessage, AssistantMessage, ToolMessage) ): message_history = [message_history] # Handle sequence of messages for i, msg in enumerate(message_history): if isinstance(msg, SystemMessage): formatted_lines.append(f"Message {i + 1} [system]:") formatted_lines.append(separator) formatted_lines.append(f"{msg.content}") elif isinstance(msg, UserMessage): formatted_lines.append(f"Message {i + 1} [user]:") formatted_lines.append(separator) if isinstance(msg.content, str): formatted_lines.append(f"{msg.content}") elif isinstance(msg.content, list): # Handle multimodal content (text + images) for part in msg.content: if isinstance(part, TextContentPart): formatted_lines.append(f"{part.text}") elif isinstance(part, ImageContentPart): url = part.image_url.url formatted_lines.append(f"[Image: {url[:50]}...]") elif isinstance(msg, AssistantMessage): formatted_lines.append(f"Message {i + 1} [assistant]:") formatted_lines.append(separator) if msg.content: formatted_lines.append(f"{msg.content}") if msg.tool_calls: formatted_lines.append("Tool calls:") for tool_call in msg.tool_calls: tool_call_dict: dict[str, Any] = { "id": tool_call.id, "type": tool_call.type, "function": { "name": tool_call.function.name, "arguments": tool_call.function.arguments, }, } tool_call_json = json.dumps(tool_call_dict, indent=4) formatted_lines.append(tool_call_json) elif isinstance(msg, ToolMessage): formatted_lines.append(f"Message {i + 1} [tool]:") formatted_lines.append(separator) formatted_lines.append(f"Tool call ID: {msg.tool_call_id}") formatted_lines.append(f"Response: {msg.content}") else: # Fallback for unknown message types formatted_lines.append(f"Message {i + 1} [unknown]:") 
formatted_lines.append(separator) formatted_lines.append(f"{msg}") # Add separator before next message (or at end) if i < len(message_history) - 1: formatted_lines.append(separator) return "\n".join(formatted_lines) def _update_tool_call_with_delta( tool_calls_in_progress: dict[int, dict[str, Any]], tool_call_delta: Any, ) -> None: index = tool_call_delta.index if index not in tool_calls_in_progress: tool_calls_in_progress[index] = { # Fallback ID in case the provider never sends one via deltas. "id": f"fallback_{uuid.uuid4().hex}", "name": None, "arguments": "", } if tool_call_delta.id: tool_calls_in_progress[index]["id"] = tool_call_delta.id if tool_call_delta.function: if tool_call_delta.function.name: tool_calls_in_progress[index]["name"] = tool_call_delta.function.name if tool_call_delta.function.arguments: tool_calls_in_progress[index][ "arguments" ] += tool_call_delta.function.arguments def _extract_tool_call_kickoffs( id_to_tool_call_map: dict[int, dict[str, Any]], turn_index: int, tab_index: int | None = None, sub_turn_index: int | None = None, ) -> list[ToolCallKickoff]: """Extract ToolCallKickoff objects from the tool call map. Returns a list of ToolCallKickoff objects for valid tool calls (those with both id and name). Each tool call is assigned the given turn_index and a tab_index based on its order. 
Args: id_to_tool_call_map: Map of tool call index to tool call data turn_index: The turn index for this set of tool calls tab_index: If provided, use this tab_index for all tool calls (otherwise auto-increment) sub_turn_index: The sub-turn index for nested tool calls """ tool_calls: list[ToolCallKickoff] = [] tab_index_calculated = 0 for tool_call_data in id_to_tool_call_map.values(): if tool_call_data.get("id") and tool_call_data.get("name"): tool_args = _parse_tool_args_to_dict(tool_call_data.get("arguments")) tool_calls.append( ToolCallKickoff( tool_call_id=tool_call_data["id"], tool_name=tool_call_data["name"], tool_args=tool_args, placement=Placement( turn_index=turn_index, tab_index=( tab_index_calculated if tab_index is None else tab_index ), sub_turn_index=sub_turn_index, ), ) ) tab_index_calculated += 1 return tool_calls def extract_tool_calls_from_response_text( response_text: str | None, tool_definitions: list[dict], placement: Placement, ) -> list[ToolCallKickoff]: """Extract tool calls from LLM response text by matching JSON against tool definitions. This is a fallback mechanism for when the LLM was expected to return tool calls but didn't use the proper tool call format. It searches for tool calls embedded in response text (JSON first, then XML-like invoke blocks) that match available tool definitions. 
Args: response_text: The LLM's text response to search for tool calls tool_definitions: List of tool definitions to match against placement: Placement information for the tool calls Returns: List of ToolCallKickoff objects for any matched tool calls """ if not response_text or not tool_definitions: return [] # Build a map of tool names to their definitions tool_name_to_def: dict[str, dict] = {} for tool_def in tool_definitions: if tool_def.get("type") == "function" and "function" in tool_def: func_def = tool_def["function"] tool_name = func_def.get("name") if tool_name: tool_name_to_def[tool_name] = func_def if not tool_name_to_def: return [] matched_tool_calls: list[tuple[str, dict[str, Any]]] = [] # Find all JSON objects in the response text json_objects = find_all_json_objects(response_text) prev_json_obj: dict[str, Any] | None = None prev_tool_call: tuple[str, dict[str, Any]] | None = None for json_obj in json_objects: matched_tool_call = _try_match_json_to_tool(json_obj, tool_name_to_def) if not matched_tool_call: continue # `find_all_json_objects` can return both an outer tool-call object and # its nested arguments object. If both resolve to the same tool call, # drop only this nested duplicate artifact. if ( prev_json_obj is not None and prev_tool_call is not None and matched_tool_call == prev_tool_call and _is_nested_arguments_duplicate( previous_json_obj=prev_json_obj, current_json_obj=json_obj, tool_name_to_def=tool_name_to_def, ) ): continue matched_tool_calls.append(matched_tool_call) prev_json_obj = json_obj prev_tool_call = matched_tool_call # Some providers/models emit XML-style function calls instead of JSON objects. # Keep this as a fallback behind JSON extraction to preserve current behavior. 
if not matched_tool_calls: matched_tool_calls = _extract_xml_tool_calls_from_response_text( response_text=response_text, tool_name_to_def=tool_name_to_def, ) tool_calls: list[ToolCallKickoff] = [] for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls): tool_calls.append( ToolCallKickoff( tool_call_id=f"extracted_{uuid.uuid4().hex[:8]}", tool_name=tool_name, tool_args=tool_args, placement=Placement( turn_index=placement.turn_index, tab_index=tab_index, sub_turn_index=placement.sub_turn_index, ), ) ) logger.info( f"Extracted {len(tool_calls)} tool call(s) from response text as fallback" ) return tool_calls def _extract_xml_tool_calls_from_response_text( response_text: str, tool_name_to_def: dict[str, dict], ) -> list[tuple[str, dict[str, Any]]]: """Extract XML-style tool calls from response text. Supports formats such as: ["foo"] """ matched_tool_calls: list[tuple[str, dict[str, Any]]] = [] for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text): invoke_attrs = invoke_match.group("attrs") tool_name = _extract_xml_attribute(invoke_attrs, "name") if not tool_name or tool_name not in tool_name_to_def: continue tool_args: dict[str, Any] = {} invoke_body = invoke_match.group("body") for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body): parameter_attrs = parameter_match.group("attrs") parameter_name = _extract_xml_attribute(parameter_attrs, "name") if not parameter_name: continue string_attr = _extract_xml_attribute(parameter_attrs, "string") tool_args[parameter_name] = _parse_xml_parameter_value( raw_value=parameter_match.group("value"), string_attr=string_attr, ) matched_tool_calls.append((tool_name, tool_args)) return matched_tool_calls def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None: """Extract a single XML-style attribute value from a tag attribute string.""" attr_match = re.search( rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""", attrs, flags=re.IGNORECASE | re.DOTALL, ) if not attr_match: return None 
return sanitize_string(unescape(attr_match.group(2).strip())) def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any: """Parse a parameter value from XML-style tool call payloads.""" value = sanitize_string(unescape(raw_value).strip()) if string_attr and string_attr.lower() == "true": return value try: return json.loads(value) except json.JSONDecodeError: return value def _resolve_tool_arguments(obj: dict[str, Any]) -> dict[str, Any] | None: """Extract and parse an arguments/parameters value from a tool-call-like object. Looks for "arguments" or "parameters" keys, handles JSON-string values, and returns a dict if successful, or None otherwise. """ arguments = obj.get("arguments", obj.get("parameters", {})) if isinstance(arguments, str): arguments = sanitize_string(arguments) try: arguments = json.loads(arguments) except json.JSONDecodeError: arguments = {} if isinstance(arguments, dict): return arguments return None def _try_match_json_to_tool( json_obj: dict[str, Any], tool_name_to_def: dict[str, dict], ) -> tuple[str, dict[str, Any]] | None: """Try to match a JSON object to a tool definition. Supports several formats: 1. Direct tool call format: {"name": "tool_name", "arguments": {...}} 2. Function call format: {"function": {"name": "tool_name", "arguments": {...}}} 3. Tool name as key: {"tool_name": {...arguments...}} 4. 
Arguments matching a tool's parameter schema Args: json_obj: The JSON object to match tool_name_to_def: Map of tool names to their function definitions Returns: Tuple of (tool_name, tool_args) if matched, None otherwise """ # Format 1: Direct tool call format {"name": "...", "arguments": {...}} if "name" in json_obj and json_obj["name"] in tool_name_to_def: tool_name = json_obj["name"] arguments = _resolve_tool_arguments(json_obj) if arguments is not None: return (tool_name, arguments) # Format 2: Function call format {"function": {"name": "...", "arguments": {...}}} if "function" in json_obj and isinstance(json_obj["function"], dict): func_obj = json_obj["function"] if "name" in func_obj and func_obj["name"] in tool_name_to_def: tool_name = func_obj["name"] arguments = _resolve_tool_arguments(func_obj) if arguments is not None: return (tool_name, arguments) # Format 3: Tool name as key {"tool_name": {...arguments...}} for tool_name in tool_name_to_def: if tool_name in json_obj: arguments = json_obj[tool_name] if isinstance(arguments, dict): return (tool_name, arguments) # Format 4: Check if the JSON object matches a tool's parameter schema for tool_name, func_def in tool_name_to_def.items(): params = func_def.get("parameters", {}) properties = params.get("properties", {}) required = params.get("required", []) if not properties: continue # Check if all required parameters are present (empty required = all optional) if all(req in json_obj for req in required): # Check if any of the tool's properties are in the JSON object matching_props = [prop for prop in properties if prop in json_obj] if matching_props: # Filter to only include known properties filtered_args = {k: v for k, v in json_obj.items() if k in properties} return (tool_name, filtered_args) return None def _is_nested_arguments_duplicate( previous_json_obj: dict[str, Any], current_json_obj: dict[str, Any], tool_name_to_def: dict[str, dict], ) -> bool: """Detect when current object is the nested args object 
from previous tool call.""" extracted_args = _extract_nested_arguments_obj(previous_json_obj, tool_name_to_def) return extracted_args is not None and current_json_obj == extracted_args def _extract_nested_arguments_obj( json_obj: dict[str, Any], tool_name_to_def: dict[str, dict], ) -> dict[str, Any] | None: # Format 1: {"name": "...", "arguments": {...}} or {"name": "...", "parameters": {...}} if "name" in json_obj and json_obj["name"] in tool_name_to_def: args_obj = json_obj.get("arguments", json_obj.get("parameters")) if isinstance(args_obj, dict): return args_obj # Format 2: {"function": {"name": "...", "arguments": {...}}} if "function" in json_obj and isinstance(json_obj["function"], dict): function_obj = json_obj["function"] if "name" in function_obj and function_obj["name"] in tool_name_to_def: args_obj = function_obj.get("arguments", function_obj.get("parameters")) if isinstance(args_obj, dict): return args_obj # Format 3: {"tool_name": {...arguments...}} for tool_name in tool_name_to_def: if tool_name in json_obj and isinstance(json_obj[tool_name], dict): return json_obj[tool_name] return None def _build_structured_assistant_message(msg: ChatMessageSimple) -> AssistantMessage: tool_calls_list: list[ToolCall] | None = None if msg.tool_calls: tool_calls_list = [ ToolCall( id=tc.tool_call_id, type="function", function=FunctionCall( name=tc.tool_name, arguments=json.dumps(tc.tool_arguments), ), ) for tc in msg.tool_calls ] return AssistantMessage( role="assistant", content=msg.message or None, tool_calls=tool_calls_list, ) def _build_structured_tool_response_message(msg: ChatMessageSimple) -> ToolMessage: if not msg.tool_call_id: raise ValueError( f"Tool call response message encountered but tool_call_id is not available. 
Message: {msg}" ) return ToolMessage( role="tool", content=msg.message, tool_call_id=msg.tool_call_id, ) class _HistoryMessageFormatter: def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage: raise NotImplementedError def format_tool_response_message( self, msg: ChatMessageSimple ) -> ToolMessage | UserMessage: raise NotImplementedError class _DefaultHistoryMessageFormatter(_HistoryMessageFormatter): def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage: return _build_structured_assistant_message(msg) def format_tool_response_message(self, msg: ChatMessageSimple) -> ToolMessage: return _build_structured_tool_response_message(msg) class _OllamaHistoryMessageFormatter(_HistoryMessageFormatter): def format_assistant_message(self, msg: ChatMessageSimple) -> AssistantMessage: if not msg.tool_calls: return _build_structured_assistant_message(msg) tool_call_lines = [ ( f"[Tool Call] name={tc.tool_name} id={tc.tool_call_id} args={json.dumps(tc.tool_arguments)}" ) for tc in msg.tool_calls ] assistant_content = ( "\n".join([msg.message, *tool_call_lines]) if msg.message else "\n".join(tool_call_lines) ) return AssistantMessage( role="assistant", content=assistant_content, tool_calls=None, ) def format_tool_response_message(self, msg: ChatMessageSimple) -> UserMessage: if not msg.tool_call_id: raise ValueError( f"Tool call response message encountered but tool_call_id is not available. 
Message: {msg}" ) return UserMessage( role="user", content=f"[Tool Result] id={msg.tool_call_id}\n{msg.message}", ) _DEFAULT_HISTORY_MESSAGE_FORMATTER = _DefaultHistoryMessageFormatter() _OLLAMA_HISTORY_MESSAGE_FORMATTER = _OllamaHistoryMessageFormatter() def _get_history_message_formatter(llm_config: LLMConfig) -> _HistoryMessageFormatter: if llm_config.model_provider == LlmProviderNames.OLLAMA_CHAT: return _OLLAMA_HISTORY_MESSAGE_FORMATTER return _DEFAULT_HISTORY_MESSAGE_FORMATTER def translate_history_to_llm_format( history: list[ChatMessageSimple], llm_config: LLMConfig, ) -> LanguageModelInput: """Convert a list of ChatMessageSimple to LanguageModelInput format. Converts ChatMessageSimple messages to ChatCompletionMessage format, handling different message types and image files for multimodal support. """ messages: list[ChatCompletionMessage] = [] history_message_formatter = _get_history_message_formatter(llm_config) # Note: cacheability is computed from pre-translation ChatMessageSimple types. # Some providers flatten tool history into plain assistant/user text, so this split # may be less semantically meaningful, but it remains safe and order-preserving. 
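# Worked example of the split computed below (illustrative flags only, assuming
# PROMPT_CACHE_CHAT_HISTORY is enabled and every message has one of the listed
# types): for should_cache = [True, True, False, True] the loop ends with
# last_cacheable_msg_idx = 1. The False at index 2 breaks the all-previous-cacheable
# chain, so index 3 is excluded even though its own flag is True; messages[:2]
# becomes the cacheable prefix and messages[2:] the suffix.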
last_cacheable_msg_idx = -1 all_previous_msgs_cacheable = True for idx, msg in enumerate(history): # if the message is being added to the history if PROMPT_CACHE_CHAT_HISTORY and msg.message_type in [ MessageType.SYSTEM, MessageType.USER, MessageType.USER_REMINDER, MessageType.ASSISTANT, MessageType.TOOL_CALL_RESPONSE, ]: all_previous_msgs_cacheable = ( all_previous_msgs_cacheable and msg.should_cache ) if all_previous_msgs_cacheable: last_cacheable_msg_idx = idx if msg.message_type == MessageType.SYSTEM: system_msg = SystemMessage( role="system", content=msg.message, ) messages.append(system_msg) elif msg.message_type == MessageType.USER: # Handle user messages with potential images if msg.image_files: # Build content parts: text + images content_parts: list[TextContentPart | ImageContentPart] = [ TextContentPart( type="text", text=msg.message, ) ] # Add image parts for img_file in msg.image_files: if img_file.file_type == ChatFileType.IMAGE: try: image_type = get_image_type_from_bytes(img_file.content) base64_data = img_file.to_base64() image_url = f"data:{image_type};base64,{base64_data}" image_part = ImageContentPart( type="image_url", image_url=ImageUrlDetail( url=image_url, detail=None, ), ) content_parts.append(image_part) except Exception as e: logger.warning( f"Failed to process image file {img_file.file_id}: {e}. Skipping image." 
) user_msg = UserMessage( role="user", content=content_parts, ) messages.append(user_msg) else: # Simple text-only user message user_msg_text = UserMessage( role="user", content=msg.message, ) messages.append(user_msg_text) elif msg.message_type == MessageType.USER_REMINDER: # User reminder messages are wrapped with system-reminder tags # and converted to UserMessage (LLM APIs don't have a native reminder type) wrapped_content = f"{SYSTEM_REMINDER_TAG_OPEN}\n{msg.message}\n{SYSTEM_REMINDER_TAG_CLOSE}" reminder_msg = UserMessage( role="user", content=wrapped_content, ) messages.append(reminder_msg) elif msg.message_type == MessageType.ASSISTANT: messages.append(history_message_formatter.format_assistant_message(msg)) elif msg.message_type == MessageType.TOOL_CALL_RESPONSE: messages.append(history_message_formatter.format_tool_response_message(msg)) else: logger.warning( f"Unknown message type {msg.message_type} in history. Skipping message." ) # Apply model-specific formatting when translating to LLM format (e.g. 
OpenAI # reasoning models need CODE_BLOCK_MARKDOWN prefix for correct markdown generation) if model_needs_formatting_reenabled(llm_config.model_name): for i, m in enumerate(messages): if isinstance(m, SystemMessage): messages[i] = SystemMessage( role="system", content=CODE_BLOCK_MARKDOWN + m.content, ) break # prompt caching: rely on should_cache in ChatMessageSimple to # pick the split point for the cacheable prefix and suffix if last_cacheable_msg_idx != -1: processed_messages, _ = process_with_prompt_cache( llm_config=llm_config, cacheable_prefix=messages[: last_cacheable_msg_idx + 1], suffix=messages[last_cacheable_msg_idx + 1 :], continuation=False, ) assert isinstance(processed_messages, list) # for mypy messages = processed_messages return messages def _increment_turns( turn_index: int, sub_turn_index: int | None ) -> tuple[int, int | None]: if sub_turn_index is None: return turn_index + 1, None else: return turn_index, sub_turn_index + 1 def _delta_has_action(delta: Delta) -> bool: return bool(delta.content or delta.reasoning_content or delta.tool_calls) def run_llm_step_pkt_generator( history: list[ChatMessageSimple], tool_definitions: list[dict], tool_choice: ToolChoiceOptions, llm: LLM, placement: Placement, state_container: ChatStateContainer | None, citation_processor: DynamicCitationProcessor | None, reasoning_effort: ReasoningEffort = ReasoningEffort.AUTO, final_documents: list[SearchDoc] | None = None, user_identity: LLMUserIdentity | None = None, custom_token_processor: ( Callable[[Delta | None, Any], tuple[Delta | None, Any]] | None ) = None, max_tokens: int | None = None, # TODO: Temporary handling of nested tool calls with agents, figure out a better way to handle this use_existing_tab_index: bool = False, is_deep_research: bool = False, pre_answer_processing_time: float | None = None, timeout_override: int | None = None, ) -> Generator[Packet, None, tuple[LlmStepResult, bool]]: """Run an LLM step and stream the response as packets. 
NOTE: DO NOT TOUCH THIS FUNCTION BEFORE ASKING YUHONG, this is very finicky and delicate logic that is core to the app's main functionality. This generator function streams LLM responses, processing reasoning content, answer content, tool calls, and citations. It yields Packet objects for real-time streaming to clients and accumulates the final result. Args: history: List of chat messages in the conversation history. tool_definitions: List of tool definitions available to the LLM. tool_choice: Tool choice configuration (e.g., "auto", "required", "none"). llm: Language model interface to use for generation. placement: Placement info (turn_index, tab_index, sub_turn_index) for positioning packets in the conversation UI. state_container: Container for storing chat state (reasoning, answers). citation_processor: Optional processor for extracting and formatting citations from the response. If provided, processes tokens to identify citations. reasoning_effort: Optional reasoning effort configuration for models that support reasoning (e.g., o1 models). final_documents: Optional list of search documents to include in the response start packet. user_identity: Optional user identity information for the LLM. custom_token_processor: Optional callable that processes each token delta before yielding. Receives (delta, processor_state) and returns (modified_delta, new_processor_state). Can return None for delta to skip. max_tokens: Optional maximum number of tokens for the LLM response. use_existing_tab_index: If True, use the tab_index from placement for all tool calls instead of auto-incrementing. is_deep_research: If True, treat content before tool calls as reasoning when tool_choice is REQUIRED. pre_answer_processing_time: Optional time spent processing before the answer started, recorded in state_container for analytics. timeout_override: Optional timeout override for the LLM call. 
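Because this function is a generator that yields packets and then returns a value, the (LlmStepResult, bool) tuple only becomes available on StopIteration.value once the stream is exhausted, which is how the wrapper that consumes this generator retrieves it. A minimal sketch of that pattern, using hypothetical stand-ins (fake_step, drain) that yield plain strings instead of Packet objects:

```python
from collections.abc import Generator


def fake_step() -> Generator[str, None, tuple[str, bool]]:
    # Hypothetical stand-in: stream two "packets", then return (result, has_reasoned)
    yield "ReasoningStart"
    yield "AgentResponseDelta"
    return "final answer", True


def drain(
    gen: Generator[str, None, tuple[str, bool]],
) -> tuple[list[str], tuple[str, bool]]:
    emitted: list[str] = []
    while True:
        try:
            # Forward each packet (the real wrapper calls emitter.emit here)
            emitted.append(next(gen))
        except StopIteration as e:
            # A generator's return value is carried on StopIteration.value
            return emitted, e.value
```

Inside another generator, result = yield from gen forwards the packets and captures the same return value without the explicit try/except.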
Yields: Packet: Streaming packets containing: - ReasoningStart/ReasoningDelta/ReasoningDone for reasoning content - AgentResponseStart/AgentResponseDelta for answer content - CitationInfo for extracted citations - ToolCallKickoff for tool calls (extracted at the end) Returns: tuple[LlmStepResult, bool]: A tuple containing: - LlmStepResult: The final result with accumulated reasoning, answer, and tool calls (if any). - bool: Whether reasoning occurred during this step. This should be used to increment the turn index or sub_turn index for the rest of the LLM loop. Note: The function handles incremental state updates, saving reasoning and answer tokens to the state container as they are generated. Tool calls are extracted and yielded only after the stream completes. """ turn_index = placement.turn_index tab_index = placement.tab_index sub_turn_index = placement.sub_turn_index def _current_placement() -> Placement: return Placement( turn_index=turn_index, tab_index=tab_index, sub_turn_index=sub_turn_index, ) llm_msg_history = translate_history_to_llm_format(history, llm.config) has_reasoned = False if LOG_ONYX_MODEL_INTERACTIONS: logger.debug( f"Message history:\n{_format_message_history_for_logging(llm_msg_history)}" ) id_to_tool_call_map: dict[int, dict[str, Any]] = {} arg_parsers: dict[int, Parser] = {} reasoning_start = False answer_start = False accumulated_reasoning = "" accumulated_answer = "" accumulated_raw_answer = "" stream_chunk_count = 0 actionable_chunk_count = 0 empty_chunk_count = 0 finish_reasons: set[str] = set() xml_tool_call_content_filter = _XmlToolCallContentFilter() processor_state: Any = None with generation_span( model=llm.config.model_name, model_config={ "base_url": str(llm.config.api_base or ""), "model_impl": "litellm", }, ) as span_generation: span_generation.span_data.input = cast( Sequence[Mapping[str, Any]], llm_msg_history ) stream_start_time = time.monotonic() first_action_recorded = False def _emit_citation_results( results: 
Generator[str | CitationInfo, None, None], ) -> Generator[Packet, None, None]: """Yield packets for citation processor results (str or CitationInfo).""" nonlocal accumulated_answer for result in results: if isinstance(result, str): accumulated_answer += result if state_container: state_container.set_answer_tokens(accumulated_answer) yield Packet( placement=_current_placement(), obj=AgentResponseDelta(content=result), ) elif isinstance(result, CitationInfo): yield Packet( placement=_current_placement(), obj=result, ) if state_container: state_container.add_emitted_citation(result.citation_number) def _close_reasoning_if_active() -> Generator[Packet, None, None]: """Emit ReasoningDone and increment turns if reasoning is in progress.""" nonlocal reasoning_start nonlocal has_reasoned nonlocal turn_index nonlocal sub_turn_index if reasoning_start: yield Packet( placement=Placement( turn_index=turn_index, tab_index=tab_index, sub_turn_index=sub_turn_index, ), obj=ReasoningDone(), ) has_reasoned = True turn_index, sub_turn_index = _increment_turns( turn_index, sub_turn_index ) reasoning_start = False def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]: nonlocal accumulated_answer nonlocal accumulated_reasoning nonlocal answer_start nonlocal reasoning_start nonlocal turn_index nonlocal sub_turn_index # When tool_choice is REQUIRED, content before tool calls is reasoning/thinking # about which tool to call, not an actual answer to the user. # Treat this content as reasoning instead of answer. 
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED: accumulated_reasoning += content_chunk if state_container: state_container.set_reasoning_tokens(accumulated_reasoning) if not reasoning_start: yield Packet( placement=_current_placement(), obj=ReasoningStart(), ) yield Packet( placement=_current_placement(), obj=ReasoningDelta(reasoning=content_chunk), ) reasoning_start = True return # Normal flow for AUTO or NONE tool choice yield from _close_reasoning_if_active() if not answer_start: # Store pre-answer processing time in state container for save_chat if state_container and pre_answer_processing_time is not None: state_container.set_pre_answer_processing_time( pre_answer_processing_time ) yield Packet( placement=_current_placement(), obj=AgentResponseStart( final_documents=final_documents, pre_answer_processing_seconds=pre_answer_processing_time, ), ) answer_start = True if citation_processor: yield from _emit_citation_results( citation_processor.process_token(content_chunk) ) else: accumulated_answer += content_chunk # Save answer incrementally to state container if state_container: state_container.set_answer_tokens(accumulated_answer) yield Packet( placement=_current_placement(), obj=AgentResponseDelta(content=content_chunk), ) for packet in llm.stream( prompt=llm_msg_history, tools=tool_definitions, tool_choice=tool_choice, structured_response_format=None, # TODO max_tokens=max_tokens, reasoning_effort=reasoning_effort, user_identity=user_identity, timeout_override=timeout_override, ): stream_chunk_count += 1 if packet.usage: usage = packet.usage span_generation.span_data.usage = { "input_tokens": usage.prompt_tokens, "output_tokens": usage.completion_tokens, "cache_read_input_tokens": usage.cache_read_input_tokens, "cache_creation_input_tokens": usage.cache_creation_input_tokens, } # Note: LLM cost tracking is now handled in multi_llm.py finish_reason = packet.choice.finish_reason if finish_reason: finish_reasons.add(str(finish_reason)) delta = 
packet.choice.delta # Weird behavior from some model providers, just log and ignore for now if ( not delta.content and delta.reasoning_content is None and not delta.tool_calls ): empty_chunk_count += 1 logger.warning( "LLM packet is empty (no content, reasoning, or tool calls). " f"finish_reason={finish_reason}. Skipping: {packet}" ) continue if not first_action_recorded and _delta_has_action(delta): span_generation.span_data.time_to_first_action_seconds = ( time.monotonic() - stream_start_time ) first_action_recorded = True if _delta_has_action(delta): actionable_chunk_count += 1 if custom_token_processor: # The custom token processor can modify the deltas for specific custom logic # It can also return a state so that it can handle aggregated delta logic etc. # Loosely typed so the function can be flexible modified_delta, processor_state = custom_token_processor( delta, processor_state ) if modified_delta is None: continue delta = modified_delta # Should only happen once, frontend does not expect multiple # ReasoningStart or ReasoningDone packets. if delta.reasoning_content: accumulated_reasoning += delta.reasoning_content # Save reasoning incrementally to state container if state_container: state_container.set_reasoning_tokens(accumulated_reasoning) if not reasoning_start: yield Packet( placement=_current_placement(), obj=ReasoningStart(), ) yield Packet( placement=_current_placement(), obj=ReasoningDelta(reasoning=delta.reasoning_content), ) reasoning_start = True if delta.content: # Keep raw content for fallback extraction. Display content can be # filtered and, in deep-research REQUIRED mode, routed as reasoning. 
accumulated_raw_answer += delta.content filtered_content = xml_tool_call_content_filter.process(delta.content) if filtered_content: yield from _emit_content_chunk(filtered_content) if delta.tool_calls: yield from _close_reasoning_if_active() for tool_call_delta in delta.tool_calls: # maybe_emit_argument_delta depends on _update_tool_call_with_delta being called first to attach the delta _update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta) yield from maybe_emit_argument_delta( tool_calls_in_progress=id_to_tool_call_map, tool_call_delta=tool_call_delta, placement=_current_placement(), parsers=arg_parsers, ) # Flush any tail text buffered while checking for split " tuple[LlmStepResult, bool]: """Wrapper around run_llm_step_pkt_generator that consumes packets and emits them. Returns: tuple[LlmStepResult, bool]: The LLM step result and whether reasoning occurred. """ step_generator = run_llm_step_pkt_generator( history=history, tool_definitions=tool_definitions, tool_choice=tool_choice, llm=llm, placement=placement, state_container=state_container, citation_processor=citation_processor, reasoning_effort=reasoning_effort, final_documents=final_documents, user_identity=user_identity, custom_token_processor=custom_token_processor, max_tokens=max_tokens, use_existing_tab_index=use_existing_tab_index, is_deep_research=is_deep_research, pre_answer_processing_time=pre_answer_processing_time, timeout_override=timeout_override, ) while True: try: packet = next(step_generator) emitter.emit(packet) except StopIteration as e: llm_step_result, has_reasoned = e.value return llm_step_result, has_reasoned ================================================ FILE: backend/onyx/chat/models.py ================================================ from collections.abc import Iterator from typing import Any from uuid import UUID from pydantic import BaseModel from onyx.configs.constants import MessageType from onyx.context.search.models import SearchDoc from onyx.file_store.models import InMemoryChatFile from
onyx.server.query_and_chat.models import MessageResponseIDInfo from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.server.query_and_chat.streaming_models import GeneratedImage from onyx.server.query_and_chat.streaming_models import Packet from onyx.tools.models import SearchToolUsage from onyx.tools.models import ToolCallKickoff from onyx.tools.tool_implementations.custom.base_tool_types import ToolResultType class StreamingError(BaseModel): error: str stack_trace: str | None = None error_code: str | None = ( None # e.g., "RATE_LIMIT", "AUTH_ERROR", "TOOL_CALL_FAILED" ) is_retryable: bool = True # Hint to frontend if retry might help details: dict | None = None # Additional context (tool name, model name, etc.) class CustomToolResponse(BaseModel): response: ToolResultType tool_name: str class CreateChatSessionID(BaseModel): chat_session_id: UUID AnswerStreamPart = ( Packet | MessageResponseIDInfo | MultiModelMessageResponseIDInfo | StreamingError | CreateChatSessionID ) AnswerStream = Iterator[AnswerStreamPart] class ToolCallResponse(BaseModel): """Tool call with full details for non-streaming response.""" tool_name: str tool_arguments: dict[str, Any] tool_result: str search_docs: list[SearchDoc] | None = None generated_images: list[GeneratedImage] | None = None # Reasoning that led to the tool call pre_reasoning: str | None = None class ChatBasicResponse(BaseModel): # This is built piece by piece, any of these can be None as the flow could break answer: str answer_citationless: str top_documents: list[SearchDoc] error_msg: str | None message_id: int citation_info: list[CitationInfo] class ChatFullResponse(BaseModel): """Complete non-streaming response with all available data. NOTE: This model is used for the core flow of the Onyx application, any changes to it should be reviewed and approved by an experienced team member. It is very important that this model 1.
avoids bloat and 2. remains backwards compatible across versions. """ # Core response fields answer: str answer_citationless: str pre_answer_reasoning: str | None = None tool_calls: list[ToolCallResponse] = [] # Documents & citations top_documents: list[SearchDoc] citation_info: list[CitationInfo] # Metadata message_id: int chat_session_id: UUID | None = None error_msg: str | None = None class ChatLoadedFile(InMemoryChatFile): content_text: str | None token_count: int class ToolCallSimple(BaseModel): """Tool call for ChatMessageSimple representation (mirrors OpenAI format). Used when an ASSISTANT message contains one or more tool calls. Each tool call has an ID, name, arguments, and token count for tracking. """ tool_call_id: str tool_name: str tool_arguments: dict[str, Any] token_count: int = 0 class ChatMessageSimple(BaseModel): message: str token_count: int message_type: MessageType # Only for USER type messages image_files: list[ChatLoadedFile] | None = None # Only for TOOL_CALL_RESPONSE type messages tool_call_id: str | None = None # For ASSISTANT messages with tool calls (OpenAI parallel tool calling format) tool_calls: list[ToolCallSimple] | None = None # The last message for which this is true # AND is true for all previous messages # (counting from the start of the history) # represents the end of the cacheable prefix # used for prompt caching should_cache: bool = False # When this message represents an injected text file, this is the file's ID. # Used to detect which file messages survive context-window truncation. file_id: str | None = None class ContextFileMetadata(BaseModel): """Metadata for a context-injected file to enable citation support.""" file_id: str filename: str file_content: str class FileToolMetadata(BaseModel): """Lightweight metadata for exposing files to the FileReaderTool. Used when files cannot be loaded directly into context (project too large or persona-attached user_files without direct-load path).
The LLM receives a listing of these so it knows which files it can read via ``read_file``. """ file_id: str filename: str approx_char_count: int class ChatHistoryResult(BaseModel): """Result of converting chat history to simple format. Bundles the simple messages with metadata for every text file that was injected into the history. After context-window truncation drops older messages, callers compare surviving ``file_id`` tags against this map to discover "forgotten" files whose metadata should be provided to the FileReaderTool. """ simple_messages: list[ChatMessageSimple] all_injected_file_metadata: dict[str, FileToolMetadata] class ExtractedContextFiles(BaseModel): """Result of attempting to load user files (from a project or persona) into context.""" file_texts: list[str] image_files: list[ChatLoadedFile] use_as_search_filter: bool total_token_count: int file_metadata: list[ContextFileMetadata] uncapped_token_count: int | None # Lightweight metadata for files exposed via FileReaderTool instead of being # loaded into context (metadata-only file types, plus every file when the files # don't fit in context and the vector DB is disabled). file_metadata_for_tool: list[FileToolMetadata] = [] class SearchParams(BaseModel): """Resolved search filter IDs and search-tool usage for a chat turn.""" project_id_filter: int | None persona_id_filter: int | None search_usage: SearchToolUsage class LlmStepResult(BaseModel): reasoning: str | None answer: str | None tool_calls: list[ToolCallKickoff] | None # Raw LLM text before any display-oriented filtering/sanitization. # Used for fallback tool-call extraction when providers emit calls as text. raw_answer: str | None = None ================================================ FILE: backend/onyx/chat/process_message.py ================================================ """ IMPORTANT: familiarize yourself with the design concepts prior to contributing to this file. An overview can be found in the README.md file in this directory.
""" import contextvars import io import queue import re import threading import traceback from collections.abc import Callable from collections.abc import Generator from concurrent.futures import ThreadPoolExecutor from contextvars import Token from typing import Final from uuid import UUID from sqlalchemy.orm import Session from onyx.cache.factory import get_cache_backend from onyx.chat.chat_processing_checker import set_processing_status from onyx.chat.chat_state import AvailableFiles from onyx.chat.chat_state import ChatStateContainer from onyx.chat.chat_state import ChatTurnSetup from onyx.chat.chat_utils import build_file_context from onyx.chat.chat_utils import convert_chat_history from onyx.chat.chat_utils import create_chat_history_chain from onyx.chat.chat_utils import create_chat_session_from_request from onyx.chat.chat_utils import get_custom_agent_prompt from onyx.chat.chat_utils import is_last_assistant_message_clarification from onyx.chat.chat_utils import load_all_chat_files from onyx.chat.compression import calculate_total_history_tokens from onyx.chat.compression import compress_chat_history from onyx.chat.compression import find_summary_for_branch from onyx.chat.compression import get_compression_params from onyx.chat.emitter import Emitter from onyx.chat.llm_loop import EmptyLLMResponseError from onyx.chat.llm_loop import run_llm_loop from onyx.chat.models import AnswerStream from onyx.chat.models import AnswerStreamPart from onyx.chat.models import ChatBasicResponse from onyx.chat.models import ChatFullResponse from onyx.chat.models import ChatLoadedFile from onyx.chat.models import ChatMessageSimple from onyx.chat.models import ContextFileMetadata from onyx.chat.models import CreateChatSessionID from onyx.chat.models import ExtractedContextFiles from onyx.chat.models import FileToolMetadata from onyx.chat.models import SearchParams from onyx.chat.models import StreamingError from onyx.chat.models import ToolCallResponse from 
onyx.chat.prompt_utils import calculate_reserved_tokens from onyx.chat.save_chat import save_chat_turn from onyx.chat.stop_signal_checker import is_connected as check_stop_signal from onyx.chat.stop_signal_checker import reset_cancel_status from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.configs.constants import DocumentSource from onyx.configs.constants import MessageType from onyx.configs.constants import MilestoneRecordType from onyx.context.search.models import BaseFilters from onyx.context.search.models import SearchDoc from onyx.db.chat import create_new_chat_message from onyx.db.chat import get_chat_session_by_id from onyx.db.chat import get_or_create_root_message from onyx.db.chat import reserve_message_id from onyx.db.chat import reserve_multi_model_message_ids from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import HookPoint from onyx.db.memory import get_memories from onyx.db.models import ChatMessage from onyx.db.models import Persona from onyx.db.models import User from onyx.db.models import UserFile from onyx.db.projects import get_user_files_from_project from onyx.db.tools import get_tools from onyx.deep_research.dr_loop import run_deep_research_llm_loop from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import log_onyx_error from onyx.error_handling.exceptions import OnyxError from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_store.models import ChatFileType from onyx.file_store.models import InMemoryChatFile from onyx.file_store.utils import load_in_memory_chat_files from onyx.file_store.utils import verify_user_files from onyx.hooks.executor import execute_hook from onyx.hooks.executor import HookSkipped from onyx.hooks.executor import HookSoftFailed from onyx.hooks.points.query_processing import 
QueryProcessingPayload from onyx.hooks.points.query_processing import QueryProcessingResponse from onyx.llm.factory import get_llm_for_persona from onyx.llm.factory import get_llm_token_counter from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMUserIdentity from onyx.llm.override_models import LLMOverride from onyx.llm.request_context import reset_llm_mock_response from onyx.llm.request_context import set_llm_mock_response from onyx.llm.utils import litellm_exception_to_error_msg from onyx.onyxbot.slack.models import SlackContext from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE from onyx.server.query_and_chat.models import MessageResponseIDInfo from onyx.server.query_and_chat.models import ModelResponseSlot from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo from onyx.server.query_and_chat.models import SendMessageRequest from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import AgentResponseDelta from onyx.server.query_and_chat.streaming_models import AgentResponseStart from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.server.query_and_chat.streaming_models import OverallStop from onyx.server.query_and_chat.streaming_models import Packet from onyx.server.usage_limits import check_llm_cost_limit_for_provider from onyx.tools.constants import FILE_READER_TOOL_ID from onyx.tools.constants import SEARCH_TOOL_ID from onyx.tools.models import ChatFile from onyx.tools.models import SearchToolUsage from onyx.tools.tool_constructor import construct_tools from onyx.tools.tool_constructor import CustomToolConfig from onyx.tools.tool_constructor import FileReaderToolConfig from onyx.tools.tool_constructor import SearchToolConfig from onyx.utils.logger import setup_logger from onyx.utils.telemetry import mt_cloud_telemetry from 
onyx.utils.timing import log_function_time from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() ERROR_TYPE_CANCELLED = "cancelled" APPROX_CHARS_PER_TOKEN = 4 def _collect_available_file_ids( chat_history: list[ChatMessage], project_id: int | None, user_id: UUID | None, db_session: Session, ) -> AvailableFiles: """Collect all file IDs the FileReaderTool should be allowed to access. Returns *separate* lists for chat-attached files (``file_record`` IDs) and project/user files (``user_file`` IDs) so the tool can pick the right loader without a try/except fallback.""" chat_file_ids: set[UUID] = set() user_file_ids: set[UUID] = set() for msg in chat_history: if not msg.files: continue for fd in msg.files: try: chat_file_ids.add(UUID(fd["id"])) except (ValueError, KeyError): pass if project_id: user_files = get_user_files_from_project( project_id=project_id, user_id=user_id, db_session=db_session, ) for uf in user_files: user_file_ids.add(uf.id) return AvailableFiles( user_file_ids=list(user_file_ids), chat_file_ids=list(chat_file_ids), ) def _should_enable_slack_search( persona: Persona, filters: BaseFilters | None, ) -> bool: """Determine if Slack search should be enabled. Returns True if: - Source type filter exists and includes Slack, OR - Default persona with no source type filter """ source_types = filters.source_type if filters else None return (source_types is not None and DocumentSource.SLACK in source_types) or ( persona.id == DEFAULT_PERSONA_ID and source_types is None ) def _convert_loaded_files_to_chat_files( loaded_files: list[ChatLoadedFile], ) -> list[ChatFile]: """Convert ChatLoadedFile objects to ChatFile for tool usage (e.g., PythonTool). 
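The conversion keeps only non-empty files and falls back to a file_<id> name when the filename is missing. A minimal sketch of that rule, with hypothetical stand-in dataclasses (LoadedFile, ToolFile) in place of ChatLoadedFile and ChatFile:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class LoadedFile:  # hypothetical stand-in for ChatLoadedFile
    file_id: str
    content: bytes
    filename: str | None


@dataclass
class ToolFile:  # hypothetical stand-in for ChatFile
    filename: str
    content: bytes


def to_tool_files(loaded: list[LoadedFile]) -> list[ToolFile]:
    # Skip empty files and fall back to "file_<id>" when no filename is known
    return [
        ToolFile(filename=f.filename or f"file_{f.file_id}", content=f.content)
        for f in loaded
        if len(f.content) > 0
    ]
```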
Args: loaded_files: List of ChatLoadedFile objects from the chat history Returns: List of ChatFile objects that can be passed to tools """ chat_files = [] for loaded_file in loaded_files: if len(loaded_file.content) > 0: chat_files.append( ChatFile( filename=loaded_file.filename or f"file_{loaded_file.file_id}", content=loaded_file.content, ) ) return chat_files def resolve_context_user_files( persona: Persona, project_id: int | None, user_id: UUID | None, db_session: Session, ) -> list[UserFile]: """Apply the precedence rule to decide which user files to load. A custom persona fully supersedes the project. When a chat uses a custom persona, the project is purely organisational — its files are never loaded and never made searchable. Custom persona → persona's own user_files (may be empty). Default persona inside a project → project files. Otherwise → empty list. """ if persona.id != DEFAULT_PERSONA_ID: return list(persona.user_files) if persona.user_files else [] if project_id: return get_user_files_from_project( project_id=project_id, user_id=user_id, db_session=db_session, ) return [] def _empty_extracted_context_files() -> ExtractedContextFiles: return ExtractedContextFiles( file_texts=[], image_files=[], use_as_search_filter=False, total_token_count=0, file_metadata=[], uncapped_token_count=None, ) def _extract_text_from_in_memory_file(f: InMemoryChatFile) -> str | None: """Extract text content from an InMemoryChatFile. PLAIN_TEXT: the content is pre-extracted UTF-8 plaintext stored during ingestion — decode directly. DOC / CSV / other text types: the content is the original file bytes — use extract_file_text which handles encoding detection and format parsing. 
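For PLAIN_TEXT files the decode step reduces to the one-liner below. A minimal sketch on raw bytes; decode_plain_text is a hypothetical helper name used only for illustration, and chr(0) is the NUL character the branch strips:

```python
def decode_plain_text(content: bytes) -> str:
    # Same steps as the PLAIN_TEXT branch: invalid UTF-8 bytes are dropped
    # instead of raising, then NUL characters are removed.
    return content.decode("utf-8", errors="ignore").replace(chr(0), "")
```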
""" try: if f.file_type == ChatFileType.PLAIN_TEXT: return f.content.decode("utf-8", errors="ignore").replace("\x00", "") return extract_file_text( file=io.BytesIO(f.content), file_name=f.filename or "", break_on_unprocessable=False, ) except Exception: logger.warning(f"Failed to extract text from file {f.file_id}", exc_info=True) return None def extract_context_files( user_files: list[UserFile], llm_max_context_window: int, reserved_token_count: int, db_session: Session, # Because the tokenizer is a generic tokenizer, the token count may be incorrect. # To account for this, the maximum context that is allowed for this function is # 60% of the context window left after reserved tokens. It also keeps projects with # many files from pushing chat history out of the window too quickly. max_llm_context_percentage: float = 0.6, ) -> ExtractedContextFiles: """Load user files into context if they fit; otherwise flag for search. The caller is responsible for deciding *which* user files to pass in (project files, persona files, etc.). This function only cares about the all-or-nothing fit check and the actual content loading. Args: user_files: The user files to consider, as selected by the caller llm_max_context_window: Maximum tokens allowed in the LLM context window reserved_token_count: Number of tokens to reserve for other content db_session: Database session max_llm_context_percentage: Maximum fraction of the context window remaining after reserved tokens that the files may use. Returns: ExtractedContextFiles containing: - List of text content strings from context files (text files only) - List of image files from context (ChatLoadedFile objects) - Whether to use the files as a search filter instead (they don't fit and the vector DB is enabled) - Total token count of all extracted files - File metadata for context files - Uncapped token count of all candidate files (metadata-only files excluded) - FileReaderTool metadata for metadata-only files, or for every file when they don't fit and the vector DB is disabled """ # TODO(yuhong): I believe this is not handling all file types correctly.
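# Worked example of the fit check below (illustrative numbers only): with
# llm_max_context_window=128_000, reserved_token_count=8_000 and the default
# max_llm_context_percentage=0.6, max_actual_tokens = (128_000 - 8_000) * 0.6
# = 72_000. If the non-metadata-only files total 72_000 tokens or more, none of
# them are loaded: use_as_search_filter is set instead or, when DISABLE_VECTOR_DB
# is on, every file is exposed to FileReaderTool via metadata.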
if not user_files: return _empty_extracted_context_files() # Aggregate tokens for the file content that will be added # Skip tokens for those with metadata only aggregate_tokens = sum( uf.token_count or 0 for uf in user_files if not mime_type_to_chat_file_type(uf.file_type).use_metadata_only() ) max_actual_tokens = ( llm_max_context_window - reserved_token_count ) * max_llm_context_percentage if aggregate_tokens >= max_actual_tokens: use_as_search_filter = not DISABLE_VECTOR_DB if DISABLE_VECTOR_DB: overflow_tool_metadata = [_build_tool_metadata(uf) for uf in user_files] else: overflow_tool_metadata = [ _build_tool_metadata(uf) for uf in user_files if mime_type_to_chat_file_type(uf.file_type).use_metadata_only() ] return ExtractedContextFiles( file_texts=[], image_files=[], use_as_search_filter=use_as_search_filter, total_token_count=0, file_metadata=[], uncapped_token_count=aggregate_tokens, file_metadata_for_tool=overflow_tool_metadata, ) # Files fit — load them into context user_file_map = {uf.file_id: uf for uf in user_files} in_memory_files = load_in_memory_chat_files( user_file_ids=[uf.id for uf in user_files], db_session=db_session, ) file_texts: list[str] = [] image_files: list[ChatLoadedFile] = [] file_metadata: list[ContextFileMetadata] = [] tool_metadata: list[FileToolMetadata] = [] total_token_count = 0 for f in in_memory_files: uf = user_file_map.get(str(f.file_id)) filename = f.filename or f"file_{f.file_id}" if f.file_type.use_metadata_only(): # Metadata-only files are not injected as full text. 
# Only the metadata is provided, with LLM using tools if not uf: logger.error( f"File with id={f.file_id} in metadata-only path with no associated user file" ) continue tool_metadata.append(_build_tool_metadata(uf)) elif f.file_type.is_text_file(): text_content = _extract_text_from_in_memory_file(f) if not text_content: continue if not uf: logger.warning(f"No user file for file_id={f.file_id}") continue file_texts.append(text_content) file_metadata.append( ContextFileMetadata( file_id=str(uf.id), filename=filename, file_content=text_content, ) ) if uf.token_count: total_token_count += uf.token_count elif f.file_type == ChatFileType.IMAGE: token_count = uf.token_count if uf and uf.token_count else 0 total_token_count += token_count image_files.append( ChatLoadedFile( file_id=f.file_id, content=f.content, file_type=f.file_type, filename=f.filename, content_text=None, token_count=token_count, ) ) return ExtractedContextFiles( file_texts=file_texts, image_files=image_files, use_as_search_filter=False, total_token_count=total_token_count, file_metadata=file_metadata, uncapped_token_count=aggregate_tokens, file_metadata_for_tool=tool_metadata, ) def _build_tool_metadata(user_file: UserFile) -> FileToolMetadata: """Build lightweight FileToolMetadata from a UserFile record. Delegates to ``build_file_context`` so that the file ID exposed to the LLM is always consistent with what FileReaderTool expects. """ return build_file_context( tool_file_id=str(user_file.id), filename=user_file.name, file_type=mime_type_to_chat_file_type(user_file.file_type), approx_char_count=(user_file.token_count or 0) * APPROX_CHARS_PER_TOKEN, ).tool_metadata def determine_search_params( persona_id: int, project_id: int | None, extracted_context_files: ExtractedContextFiles, ) -> SearchParams: """Decide which search filter IDs and search-tool usage apply for a chat turn. 
A custom persona fully supersedes the project — project files are never searchable and the search tool config is entirely controlled by the persona. The project_id filter is only set for the default persona. For the default persona inside a project: - Files overflow → ENABLED (vector DB scopes to these files) - Files fit → DISABLED (content already in prompt) - No files at all → DISABLED (nothing to search) """ is_custom_persona = persona_id != DEFAULT_PERSONA_ID project_id_filter: int | None = None persona_id_filter: int | None = None if extracted_context_files.use_as_search_filter: if is_custom_persona: persona_id_filter = persona_id else: project_id_filter = project_id search_usage = SearchToolUsage.AUTO if not is_custom_persona and project_id: has_context_files = bool(extracted_context_files.uncapped_token_count) files_loaded_in_context = bool(extracted_context_files.file_texts) if extracted_context_files.use_as_search_filter: search_usage = SearchToolUsage.ENABLED elif files_loaded_in_context or not has_context_files: search_usage = SearchToolUsage.DISABLED return SearchParams( project_id_filter=project_id_filter, persona_id_filter=persona_id_filter, search_usage=search_usage, ) def _resolve_query_processing_hook_result( hook_result: QueryProcessingResponse | HookSkipped | HookSoftFailed, message_text: str, ) -> str: """Apply the Query Processing hook result to the message text. Returns the (possibly rewritten) message text, or raises OnyxError with QUERY_REJECTED if the hook signals rejection (query is null or empty). HookSkipped and HookSoftFailed are pass-throughs — the original text is returned unchanged. """ if isinstance(hook_result, (HookSkipped, HookSoftFailed)): return message_text if not (hook_result.query and hook_result.query.strip()): raise OnyxError( OnyxErrorCode.QUERY_REJECTED, hook_result.rejection_message or "The hook extension for query processing did not return a valid query. 
No rejection reason was provided.", ) return hook_result.query.strip() def build_chat_turn( new_msg_req: SendMessageRequest, user: User, db_session: Session, # None → single-model (persona default LLM); non-empty list → multi-model (one LLM per override) llm_overrides: list[LLMOverride] | None, *, litellm_additional_headers: dict[str, str] | None = None, custom_tool_additional_headers: dict[str, str] | None = None, mcp_headers: dict[str, str] | None = None, bypass_acl: bool = False, # Slack context for federated Slack search slack_context: SlackContext | None = None, # Additional context to include in the chat history, e.g. Slack threads where the # conversation cannot be represented by a chain of User/Assistant messages. # NOTE: not stored in the database, only passed in to the LLM as context additional_context: str | None = None, ) -> Generator[AnswerStreamPart, None, ChatTurnSetup]: """Shared setup generator for both single-model and multi-model chat turns. Yields the packet(s) the frontend needs for request tracking, then returns an immutable ``ChatTurnSetup`` containing everything the execution strategy needs. Callers use:: setup = yield from build_chat_turn(new_msg_req, ..., llm_overrides=...) to forward yielded packets upstream while receiving the return value locally. Args: llm_overrides: ``None`` → single-model (persona default LLM). Non-empty list → multi-model (one LLM per override). 
""" tenant_id = get_current_tenant_id() is_multi = bool(llm_overrides) user_id = user.id llm_user_identifier = ( "anonymous_user" if user.is_anonymous else (user.email or str(user_id)) ) # ── Session resolution ─────────────────────────────────────────────────── if not new_msg_req.chat_session_id: if not new_msg_req.chat_session_info: raise RuntimeError("Must specify a chat session id or chat session info") chat_session = create_chat_session_from_request( chat_session_request=new_msg_req.chat_session_info, user_id=user_id, db_session=db_session, ) yield CreateChatSessionID(chat_session_id=chat_session.id) chat_session = get_chat_session_by_id( chat_session_id=chat_session.id, user_id=user_id, db_session=db_session, eager_load_persona=True, ) else: chat_session = get_chat_session_by_id( chat_session_id=new_msg_req.chat_session_id, user_id=user_id, db_session=db_session, eager_load_persona=True, ) persona = chat_session.persona message_text = new_msg_req.message user_identity = LLMUserIdentity( user_id=llm_user_identifier, session_id=str(chat_session.id) ) # Milestone tracking, most devs using the API don't need to understand this mt_cloud_telemetry( tenant_id=tenant_id, distinct_id=str(user.id) if not user.is_anonymous else tenant_id, event=MilestoneRecordType.MULTIPLE_ASSISTANTS, ) mt_cloud_telemetry( tenant_id=tenant_id, distinct_id=str(user.id) if not user.is_anonymous else tenant_id, event=MilestoneRecordType.USER_MESSAGE_SENT, properties={ "origin": new_msg_req.origin.value, "has_files": len(new_msg_req.file_descriptors) > 0, "has_project": chat_session.project_id is not None, "has_persona": persona is not None and persona.id != DEFAULT_PERSONA_ID, "deep_research": new_msg_req.deep_research, }, ) # Check LLM cost limits before using the LLM (only for Onyx-managed keys), # then build the LLM instance(s). 
llms: list[LLM] = [] model_display_names: list[str] = [] selected_overrides: list[LLMOverride | None] = ( list(llm_overrides or []) if is_multi else [new_msg_req.llm_override or chat_session.llm_override] ) for override in selected_overrides: llm = get_llm_for_persona( persona=persona, user=user, llm_override=override, additional_headers=litellm_additional_headers, ) check_llm_cost_limit_for_provider( db_session=db_session, tenant_id=tenant_id, llm_provider_api_key=llm.config.api_key, ) llms.append(llm) model_display_names.append(_build_model_display_name(override)) token_counter = get_llm_token_counter(llms[0]) # not sure why we do this, but to maintain parity with previous code: if not is_multi: model_display_names = [""] # Verify that the user-specified files actually belong to the user verify_user_files( user_files=new_msg_req.file_descriptors, user_id=user_id, db_session=db_session, project_id=chat_session.project_id, ) # Re-create linear history of messages chat_history = create_chat_history_chain( chat_session_id=chat_session.id, db_session=db_session ) # Determine the parent message based on the request: # - AUTO_PLACE_AFTER_LATEST_MESSAGE (-1): auto-place after latest message in chain # - None or root ID: regeneration from root (first message) # - positive int: place after that specific parent message root_message = get_or_create_root_message( chat_session_id=chat_session.id, db_session=db_session ) if new_msg_req.parent_message_id == AUTO_PLACE_AFTER_LATEST_MESSAGE: parent_message = chat_history[-1] if chat_history else root_message elif ( new_msg_req.parent_message_id is None or new_msg_req.parent_message_id == root_message.id ): # Regeneration from root — clear history so we start fresh parent_message = root_message chat_history = [] else: parent_message = None for i in range(len(chat_history) - 1, -1, -1): if chat_history[i].id == new_msg_req.parent_message_id: parent_message = chat_history[i] # Truncate to only messages up to and including the parent 
chat_history = chat_history[: i + 1] break if parent_message is None: raise ValueError( "The new message sent is not on the latest mainline of messages" ) # ── Query Processing hook + user message ───────────────────────────────── # Skipped on regeneration (parent is USER type): message already exists/was accepted. if parent_message.message_type == MessageType.USER: user_message = parent_message else: # New message — run the Query Processing hook before saving to DB. # Skipped on regeneration: the message already exists and was accepted previously. # Skip for empty/whitespace-only messages — no meaningful query to process, # and SendMessageRequest.message has no min_length guard. if message_text.strip(): hook_result = execute_hook( db_session=db_session, hook_point=HookPoint.QUERY_PROCESSING, payload=QueryProcessingPayload( query=message_text, # Pass None for anonymous users or authenticated users without an email # (e.g. some SSO flows). QueryProcessingPayload.user_email is str | None, # so None is accepted and serialised as null in both cases. user_email=None if user.is_anonymous else user.email, chat_session_id=str(chat_session.id), ).model_dump(), response_type=QueryProcessingResponse, ) message_text = _resolve_query_processing_hook_result( hook_result, message_text ) user_message = create_new_chat_message( chat_session_id=chat_session.id, parent_message=parent_message, message=message_text, token_count=token_counter(message_text), message_type=MessageType.USER, files=new_msg_req.file_descriptors, db_session=db_session, commit=True, ) chat_history.append(user_message) # Collect file IDs for the file reader tool *before* summary truncation so # that files attached to older (summarized-away) messages are still accessible # via the FileReaderTool. 
available_files = _collect_available_file_ids( chat_history=chat_history, project_id=chat_session.project_id, user_id=user_id, db_session=db_session, ) # Find applicable summary for the current branch summary_message = find_summary_for_branch(db_session, chat_history) # Collect file metadata from messages that will be dropped by summary truncation. # These become "pre-summarized" file metadata so the forgotten-file mechanism can # still tell the LLM about them. summarized_file_metadata: dict[str, FileToolMetadata] = {} if summary_message and summary_message.last_summarized_message_id: cutoff_id = summary_message.last_summarized_message_id for msg in chat_history: if msg.id > cutoff_id or not msg.files: continue for fd in msg.files: file_id = fd.get("id") if not file_id: continue summarized_file_metadata[file_id] = FileToolMetadata( file_id=file_id, filename=fd.get("name") or "unknown", # We don't know the exact size without loading the file, # but 0 signals "unknown" to the LLM. approx_char_count=0, ) # Filter chat_history to only messages after the cutoff chat_history = [m for m in chat_history if m.id > cutoff_id] # Compute skip-clarification flag for deep research path (cheap, always available) skip_clarification = is_last_assistant_message_clarification(chat_history) user_memory_context = get_memories(user, db_session) # This prompt may come from the Agent or Project. Fetched here (before run_llm_loop) # because the inner loop shouldn't need to access the DB-form chat history, but we # need it early for token reservation. custom_agent_prompt = get_custom_agent_prompt(persona, chat_session) # When use_memories is disabled, strip memories from the prompt context but keep # user info/preferences. The full context is still passed to the LLM loop for # memory tool persistence. 
prompt_memory_context = ( user_memory_context if user.use_memories else user_memory_context.without_memories() ) # ── Token reservation ──────────────────────────────────────────────────── max_reserved_system_prompt_tokens_str = (persona.system_prompt or "") + ( custom_agent_prompt or "" ) reserved_token_count = calculate_reserved_tokens( db_session=db_session, persona_system_prompt=max_reserved_system_prompt_tokens_str, token_counter=token_counter, files=new_msg_req.file_descriptors, user_memory_context=prompt_memory_context, ) # Determine which user files to use. A custom persona fully supersedes the project — # project files are never loaded or searchable when a custom persona is in play. # Only the default persona inside a project uses the project's files. context_user_files = resolve_context_user_files( persona=persona, project_id=chat_session.project_id, user_id=user_id, db_session=db_session, ) # Use the smallest context window across models for safety (harmless for N=1). llm_max_context_window = min(llm.config.max_input_tokens for llm in llms) extracted_context_files = extract_context_files( user_files=context_user_files, llm_max_context_window=llm_max_context_window, reserved_token_count=reserved_token_count, db_session=db_session, ) search_params = determine_search_params( persona_id=persona.id, project_id=chat_session.project_id, extracted_context_files=extracted_context_files, ) # Also grant access to persona-attached user files for FileReaderTool if persona.user_files: existing = set(available_files.user_file_ids) for uf in persona.user_files: if uf.id not in existing: available_files.user_file_ids.append(uf.id) all_tools = get_tools(db_session) tool_id_to_name_map = {tool.id: tool.name for tool in all_tools} search_tool_id = next( (tool.id for tool in all_tools if tool.in_code_tool_id == SEARCH_TOOL_ID), None ) forced_tool_id = new_msg_req.forced_tool_id if ( search_params.search_usage == SearchToolUsage.DISABLED and forced_tool_id is not None and 
search_tool_id is not None and forced_tool_id == search_tool_id ): forced_tool_id = None # TODO(nmgarza5): Once summarization is done, we don't need to load all files from the beginning. # Load all files needed for this chat chain into memory. files = load_all_chat_files(chat_history, db_session) # Convert loaded files to ChatFile format for tools like PythonTool chat_files_for_tools = _convert_loaded_files_to_chat_files(files) # ── Reserve assistant message ID(s) → yield to frontend ────────────────── if is_multi: assert llm_overrides is not None reserved_messages = reserve_multi_model_message_ids( db_session=db_session, chat_session_id=chat_session.id, parent_message_id=user_message.id, model_display_names=model_display_names, ) yield MultiModelMessageResponseIDInfo( user_message_id=user_message.id, responses=[ ModelResponseSlot(message_id=m.id, model_name=name) for m, name in zip(reserved_messages, model_display_names) ], ) else: assistant_response = reserve_message_id( db_session=db_session, chat_session_id=chat_session.id, parent_message=user_message.id, message_type=MessageType.ASSISTANT, ) reserved_messages = [assistant_response] yield MessageResponseIDInfo( user_message_id=user_message.id, reserved_assistant_message_id=assistant_response.id, ) # Convert the chat history into a simple format that is free of any DB objects # and is easy to parse for the agent loop. has_file_reader_tool = any( tool.in_code_tool_id == FILE_READER_TOOL_ID for tool in persona.tools ) chat_history_result = convert_chat_history( chat_history=chat_history, files=files, context_image_files=extracted_context_files.image_files, additional_context=additional_context or new_msg_req.additional_context, token_counter=token_counter, tool_id_to_name_map=tool_id_to_name_map, ) simple_chat_history = chat_history_result.simple_messages # Metadata for every text file injected into the history. 
After context-window # truncation drops older messages, the LLM loop compares surviving file_id tags # against this map to discover "forgotten" files and provide their metadata to # FileReaderTool. all_injected_file_metadata: dict[str, FileToolMetadata] = ( chat_history_result.all_injected_file_metadata if has_file_reader_tool else {} ) # Merge in file metadata from messages dropped by summary truncation. These files # are no longer in simple_chat_history so they'd be invisible to the forgotten-file # mechanism — they'll always appear as "forgotten" since no surviving message carries # their file_id tag. if summarized_file_metadata: for fid, meta in summarized_file_metadata.items(): all_injected_file_metadata.setdefault(fid, meta) if all_injected_file_metadata: logger.debug( f"FileReader: file metadata for LLM: {[(fid, m.filename) for fid, m in all_injected_file_metadata.items()]}" ) if summary_message is not None: summary_simple = ChatMessageSimple( message=summary_message.message, token_count=summary_message.token_count, message_type=MessageType.ASSISTANT, ) simple_chat_history.insert(0, summary_simple) # ── Stop signal and processing status ──────────────────────────────────── cache = get_cache_backend() reset_cancel_status(chat_session.id, cache) def check_is_connected() -> bool: return check_stop_signal(chat_session.id, cache) set_processing_status( chat_session_id=chat_session.id, cache=cache, value=True, ) # Release any read transaction before the long-running LLM stream. # If commit fails here, reset the processing status before propagating — # otherwise the chat session appears stuck at "processing" permanently. 
try: db_session.commit() except Exception: set_processing_status(chat_session_id=chat_session.id, cache=cache, value=False) raise return ChatTurnSetup( new_msg_req=new_msg_req, chat_session=chat_session, persona=persona, user_message=user_message, user_identity=user_identity, llms=llms, model_display_names=model_display_names, simple_chat_history=simple_chat_history, extracted_context_files=extracted_context_files, reserved_messages=reserved_messages, reserved_token_count=reserved_token_count, search_params=search_params, all_injected_file_metadata=all_injected_file_metadata, available_files=available_files, tool_id_to_name_map=tool_id_to_name_map, forced_tool_id=forced_tool_id, files=files, chat_files_for_tools=chat_files_for_tools, custom_agent_prompt=custom_agent_prompt, user_memory_context=user_memory_context, skip_clarification=skip_clarification, check_is_connected=check_is_connected, cache=cache, bypass_acl=bypass_acl, slack_context=slack_context, custom_tool_additional_headers=custom_tool_additional_headers, mcp_headers=mcp_headers, ) # Sentinel placed on the merged queue when a model thread finishes. _MODEL_DONE = object() # How often the drain loop polls for user-initiated cancellation (stop button). _CANCEL_POLL_INTERVAL_S: Final[float] = 0.05 def _run_models( setup: ChatTurnSetup, user: User, db_session: Session, external_state_container: ChatStateContainer | None = None, ) -> AnswerStream: """Stream packets from one or more LLM loops running in parallel worker threads. Each model gets its own worker thread, DB session, and ``Emitter``. Threads write packets to a shared unbounded queue as they are produced; the drain loop yields them in arrival order so the caller receives a single interleaved stream regardless of how many models are running. Single-model (N=1) and multi-model (N>1) use the same execution path. Every packet is tagged with ``model_index`` by the model's Emitter — ``0`` for N=1, ``0``/``1``/``2`` for multi-model. 
Args: setup: Fully constructed turn context — LLMs, persona, history, tool config. user: Authenticated user making the request. db_session: Caller's DB session (used for setup reads; each worker opens its own session because SQLAlchemy sessions are not thread-safe). external_state_container: Pre-constructed state container for the first model. Used by evals and the non-streaming API path so the caller can inspect accumulated state (tool calls, answer tokens, citations) after the stream is consumed. When ``None`` a fresh container is created automatically. Returns: Generator yielding ``Packet`` objects as they arrive from worker threads — answer tokens, tool output, citations — followed by a terminal ``Packet`` containing ``OverallStop`` once all models complete (or one containing ``OverallStop(stop_reason="user_cancelled")`` if the connection drops). """ n_models = len(setup.llms) merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = queue.Queue() state_containers: list[ChatStateContainer] = [ ( external_state_container if (external_state_container is not None and i == 0) else ChatStateContainer() ) for i in range(n_models) ] model_succeeded: list[bool] = [False] * n_models # Set to True when a model raises an exception (distinct from "still running"). # Used in the stop-button path to avoid calling completion for errored models. model_errored: list[bool] = [False] * n_models # Set when the drain loop exits early (HTTP disconnect / GeneratorExit). # Signals emitters to skip future puts so workers exit promptly. drain_done = threading.Event() def _run_model(model_idx: int) -> None: """Run one LLM loop inside a worker thread, writing packets to ``merged_queue``.""" model_emitter = Emitter( model_idx=model_idx, merged_queue=merged_queue, drain_done=drain_done, ) sc = state_containers[model_idx] model_llm = setup.llms[model_idx] try: # Each worker opens its own session — SQLAlchemy sessions are not thread-safe. 
# Do NOT write to the outer db_session (or any shared DB state) from here; # all DB writes in this thread must go through thread_db_session. with get_session_with_current_tenant() as thread_db_session: thread_tool_dict = construct_tools( persona=setup.persona, db_session=thread_db_session, emitter=model_emitter, user=user, llm=model_llm, search_tool_config=SearchToolConfig( user_selected_filters=setup.new_msg_req.internal_search_filters, project_id_filter=setup.search_params.project_id_filter, persona_id_filter=setup.search_params.persona_id_filter, bypass_acl=setup.bypass_acl, slack_context=setup.slack_context, enable_slack_search=_should_enable_slack_search( setup.persona, setup.new_msg_req.internal_search_filters ), ), custom_tool_config=CustomToolConfig( chat_session_id=setup.chat_session.id, message_id=setup.user_message.id, additional_headers=setup.custom_tool_additional_headers, mcp_headers=setup.mcp_headers, ), file_reader_tool_config=FileReaderToolConfig( user_file_ids=setup.available_files.user_file_ids, chat_file_ids=setup.available_files.chat_file_ids, ), allowed_tool_ids=setup.new_msg_req.allowed_tool_ids, search_usage_forcing_setting=setup.search_params.search_usage, ) model_tools = [ tool for tool_list in thread_tool_dict.values() for tool in tool_list ] if setup.forced_tool_id and setup.forced_tool_id not in { tool.id for tool in model_tools }: raise ValueError( f"Forced tool {setup.forced_tool_id} not found in tools" ) # Per-thread copy: run_llm_loop mutates simple_chat_history in-place. 
if n_models == 1 and setup.new_msg_req.deep_research: if setup.chat_session.project_id: raise RuntimeError( "Deep research is not supported for projects" ) run_deep_research_llm_loop( emitter=model_emitter, state_container=sc, simple_chat_history=list(setup.simple_chat_history), tools=model_tools, custom_agent_prompt=setup.custom_agent_prompt, llm=model_llm, token_counter=get_llm_token_counter(model_llm), db_session=thread_db_session, skip_clarification=setup.skip_clarification, user_identity=setup.user_identity, chat_session_id=str(setup.chat_session.id), all_injected_file_metadata=setup.all_injected_file_metadata, ) else: run_llm_loop( emitter=model_emitter, state_container=sc, simple_chat_history=list(setup.simple_chat_history), tools=model_tools, custom_agent_prompt=setup.custom_agent_prompt, context_files=setup.extracted_context_files, persona=setup.persona, user_memory_context=setup.user_memory_context, llm=model_llm, token_counter=get_llm_token_counter(model_llm), db_session=thread_db_session, forced_tool_id=setup.forced_tool_id, user_identity=setup.user_identity, chat_session_id=str(setup.chat_session.id), chat_files=setup.chat_files_for_tools, include_citations=setup.new_msg_req.include_citations, all_injected_file_metadata=setup.all_injected_file_metadata, inject_memories_in_prompt=user.use_memories, ) model_succeeded[model_idx] = True except Exception as e: model_errored[model_idx] = True merged_queue.put((model_idx, e)) finally: merged_queue.put((model_idx, _MODEL_DONE)) def _delete_orphaned_message(model_idx: int, context: str) -> None: """Delete a reserved ChatMessage that was never populated due to a model error.""" try: orphaned = db_session.get( ChatMessage, setup.reserved_messages[model_idx].id ) if orphaned is not None: db_session.delete(orphaned) db_session.commit() except Exception: logger.exception( "%s orphan cleanup failed for model %d (%s)", context, model_idx, setup.model_display_names[model_idx], ) # Copy contextvars before submitting 
futures — ThreadPoolExecutor does NOT # auto-propagate contextvars in Python 3.11; threads would inherit a blank context. worker_context = contextvars.copy_context() executor = ThreadPoolExecutor( max_workers=n_models, thread_name_prefix="multi-model" ) completion_persisted: bool = False try: for i in range(n_models): executor.submit(worker_context.run, _run_model, i) # ── Main thread: merge and yield packets ──────────────────────────── models_remaining = n_models while models_remaining > 0: try: model_idx, item = merged_queue.get(timeout=_CANCEL_POLL_INTERVAL_S) except queue.Empty: # Check for user-initiated cancellation every 50 ms. if not setup.check_is_connected(): # Save state for every model before exiting. # - Succeeded models: full answer (is_connected=True). # - Still-in-flight models: partial answer + "stopped by user". # - Errored models: delete the orphaned reserved message; do NOT # save "stopped by user" for a model that actually threw an exception. for i in range(n_models): if model_errored[i]: _delete_orphaned_message(i, "stop-button") continue try: succeeded = model_succeeded[i] llm_loop_completion_handle( state_container=state_containers[i], is_connected=lambda: succeeded, db_session=db_session, assistant_message=setup.reserved_messages[i], llm=setup.llms[i], reserved_tokens=setup.reserved_token_count, ) except Exception: logger.exception( "stop-button completion failed for model %d (%s)", i, setup.model_display_names[i], ) yield Packet( placement=Placement(turn_index=0), obj=OverallStop(type="stop", stop_reason="user_cancelled"), ) completion_persisted = True return continue else: if item is _MODEL_DONE: models_remaining -= 1 elif isinstance(item, Exception): # Yield a tagged error for this model but keep the other models running. # Do NOT decrement models_remaining — _run_model's finally always posts # _MODEL_DONE, which is the sole completion signal. 
error_msg = str(item) stack_trace = "".join( traceback.format_exception(type(item), item, item.__traceback__) ) model_llm = setup.llms[model_idx] if model_llm.config.api_key and len(model_llm.config.api_key) > 2: error_msg = error_msg.replace( model_llm.config.api_key, "[REDACTED_API_KEY]" ) stack_trace = stack_trace.replace( model_llm.config.api_key, "[REDACTED_API_KEY]" ) yield StreamingError( error=error_msg, stack_trace=stack_trace, error_code="MODEL_ERROR", is_retryable=True, details={ "model": model_llm.config.model_name, "provider": model_llm.config.model_provider, "model_index": model_idx, }, ) elif isinstance(item, Packet): # model_index already embedded by the model's Emitter in _run_model yield item # ── Completion: save each successful model's response ─────────────── # All model loops have completed (run_llm_loop returned) — no more writes # to state_containers. Worker threads may still be closing their own DB # sessions, but the main-thread db_session is unshared and safe to use. for i in range(n_models): if not model_succeeded[i]: # Model errored — delete its orphaned reserved message. _delete_orphaned_message(i, "normal") continue try: llm_loop_completion_handle( state_container=state_containers[i], is_connected=setup.check_is_connected, db_session=db_session, assistant_message=setup.reserved_messages[i], llm=setup.llms[i], reserved_tokens=setup.reserved_token_count, ) except Exception: logger.exception( "normal completion failed for model %d (%s)", i, setup.model_display_names[i], ) completion_persisted = True finally: if completion_persisted: # Normal exit or stop-button exit: completion already persisted. # Threads are done (normal path) or can finish in the background (stop-button). executor.shutdown(wait=False) else: # Early exit (GeneratorExit from raw HTTP disconnect, or unhandled # exception in the drain loop). # 1. Signal emitters to stop — future emit() calls return immediately, # so workers exit their LLM loops promptly. 
drain_done.set() # 2. Wait for all workers to finish. Once drain_done is set the Emitter # short-circuits, so workers should exit quickly. executor.shutdown(wait=True) # 3. All workers are done — complete from the main thread only. for i in range(n_models): if model_succeeded[i]: try: llm_loop_completion_handle( state_container=state_containers[i], # Model already finished — persist full response. is_connected=lambda: True, db_session=db_session, assistant_message=setup.reserved_messages[i], llm=setup.llms[i], reserved_tokens=setup.reserved_token_count, ) except Exception: logger.exception( "disconnect completion failed for model %d (%s)", i, setup.model_display_names[i], ) elif model_errored[i]: _delete_orphaned_message(i, "disconnect") # 4. Drain buffered packets from memory — no consumer is running. while not merged_queue.empty(): try: merged_queue.get_nowait() except queue.Empty: break def _stream_chat_turn( new_msg_req: SendMessageRequest, user: User, db_session: Session, llm_overrides: list[LLMOverride] | None = None, litellm_additional_headers: dict[str, str] | None = None, custom_tool_additional_headers: dict[str, str] | None = None, mcp_headers: dict[str, str] | None = None, bypass_acl: bool = False, additional_context: str | None = None, slack_context: SlackContext | None = None, external_state_container: ChatStateContainer | None = None, ) -> AnswerStream: """Private implementation for single-model and multi-model chat turn streaming. Builds the turn context via ``build_chat_turn``, then streams packets from ``_run_models`` back to the caller. Handles setup errors, LLM errors, and cancellation uniformly, saving whatever partial state has been accumulated before re-raising or yielding a terminal error packet. Not called directly — use the public wrappers: - ``handle_stream_message_objects`` for single-model (N=1) requests. - ``handle_multi_model_stream`` for side-by-side multi-model comparison (N>1). 
Args: new_msg_req: The incoming chat request from the user. user: Authenticated user; may be anonymous for public personas. db_session: Database session for this request. llm_overrides: ``None`` → single-model (persona default LLM). Non-empty list → multi-model (one LLM per override, 2–3 items). litellm_additional_headers: Extra headers forwarded to the LLM provider. custom_tool_additional_headers: Extra headers for custom tool HTTP calls. mcp_headers: Extra headers for MCP tool calls. bypass_acl: If ``True``, document ACL checks are skipped (used by Slack bot). additional_context: Extra context prepended to the LLM's chat history, not stored in the DB (used for Slack thread hydration). slack_context: Federated Slack search context passed through to the search tool. external_state_container: Optional pre-constructed state container. When provided, accumulated state (tool calls, citations, answer tokens) is written into it so the caller can inspect the result after streaming. Returns: Generator yielding ``Packet`` objects — answer tokens, tool output, citations — followed by a terminal ``Packet`` containing ``OverallStop``. """ if new_msg_req.mock_llm_response is not None and not INTEGRATION_TESTS_MODE: raise ValueError( "mock_llm_response can only be used when INTEGRATION_TESTS_MODE=true" ) mock_response_token: Token[str | None] | None = None setup: ChatTurnSetup | None = None try: setup = yield from build_chat_turn( new_msg_req=new_msg_req, user=user, db_session=db_session, llm_overrides=llm_overrides, litellm_additional_headers=litellm_additional_headers, custom_tool_additional_headers=custom_tool_additional_headers, mcp_headers=mcp_headers, bypass_acl=bypass_acl, slack_context=slack_context, additional_context=additional_context, ) # Set mock response token right before the LLM stream begins so that # run_in_background threads inherit the correct context. 
if new_msg_req.mock_llm_response is not None: mock_response_token = set_llm_mock_response(new_msg_req.mock_llm_response) yield from _run_models( setup=setup, user=user, db_session=db_session, external_state_container=external_state_container, ) except OnyxError as e: if e.error_code is not OnyxErrorCode.QUERY_REJECTED: log_onyx_error(e) yield StreamingError( error=e.detail, error_code=e.error_code.code, is_retryable=e.status_code >= 500, ) db_session.rollback() return except ValueError as e: logger.exception("Failed to process chat message.") yield StreamingError( error=str(e), error_code="VALIDATION_ERROR", is_retryable=True, ) db_session.rollback() return except EmptyLLMResponseError as e: stack_trace = traceback.format_exc() logger.warning( f"LLM returned an empty response (provider={e.provider}, model={e.model}, tool_choice={e.tool_choice})" ) yield StreamingError( error=e.client_error_msg, stack_trace=stack_trace, error_code=e.error_code, is_retryable=e.is_retryable, details={ "model": e.model, "provider": e.provider, "tool_choice": e.tool_choice.value, }, ) db_session.rollback() except Exception as e: logger.exception(f"Failed to process chat message due to {e}") stack_trace = traceback.format_exc() llm = setup.llms[0] if setup else None if llm: client_error_msg, error_code, is_retryable = litellm_exception_to_error_msg( e, llm ) if llm.config.api_key and len(llm.config.api_key) > 2: client_error_msg = client_error_msg.replace( llm.config.api_key, "[REDACTED_API_KEY]" ) stack_trace = stack_trace.replace( llm.config.api_key, "[REDACTED_API_KEY]" ) yield StreamingError( error=client_error_msg, stack_trace=stack_trace, error_code=error_code, is_retryable=is_retryable, details={ "model": llm.config.model_name, "provider": llm.config.model_provider, }, ) else: yield StreamingError( error="Failed to initialize the chat. 
Please check your configuration and try again.", stack_trace=stack_trace, error_code="INIT_FAILED", is_retryable=True, ) db_session.rollback() finally: if mock_response_token is not None: reset_llm_mock_response(mock_response_token) try: if setup is not None: set_processing_status( chat_session_id=setup.chat_session.id, cache=setup.cache, value=False, ) except Exception: logger.exception("Error in setting processing status") def handle_stream_message_objects( new_msg_req: SendMessageRequest, user: User, db_session: Session, litellm_additional_headers: dict[str, str] | None = None, custom_tool_additional_headers: dict[str, str] | None = None, mcp_headers: dict[str, str] | None = None, bypass_acl: bool = False, additional_context: str | None = None, slack_context: SlackContext | None = None, external_state_container: ChatStateContainer | None = None, ) -> AnswerStream: """Single-model streaming entrypoint. For multi-model comparison, use ``handle_multi_model_stream``.""" yield from _stream_chat_turn( new_msg_req=new_msg_req, user=user, db_session=db_session, llm_overrides=None, litellm_additional_headers=litellm_additional_headers, custom_tool_additional_headers=custom_tool_additional_headers, mcp_headers=mcp_headers, bypass_acl=bypass_acl, additional_context=additional_context, slack_context=slack_context, external_state_container=external_state_container, ) def _build_model_display_name(override: LLMOverride | None) -> str: """Build a human-readable display name from an LLM override.""" if override is None: return "unknown" return override.display_name or override.model_version or "unknown" def handle_multi_model_stream( new_msg_req: SendMessageRequest, user: User, db_session: Session, llm_overrides: list[LLMOverride], litellm_additional_headers: dict[str, str] | None = None, custom_tool_additional_headers: dict[str, str] | None = None, mcp_headers: dict[str, str] | None = None, ) -> AnswerStream: """Thin wrapper for side-by-side multi-model comparison (2–3 
models). Validates the override list and delegates to ``_stream_chat_turn``, which handles both single-model and multi-model execution via the same path. Args: new_msg_req: The incoming chat request. ``deep_research`` must be ``False``. user: Authenticated user making the request. db_session: Database session for this request. llm_overrides: Exactly 2 or 3 ``LLMOverride`` objects — one per model to run. litellm_additional_headers: Extra headers forwarded to each LLM provider. custom_tool_additional_headers: Extra headers for custom tool HTTP calls. mcp_headers: Extra headers for MCP tool calls. Returns: Generator yielding interleaved ``Packet`` objects from all models, each tagged with ``model_index`` in its placement. """ n_models = len(llm_overrides) if n_models < 2 or n_models > 3: yield StreamingError( error=f"Multi-model requires 2-3 overrides, got {n_models}", error_code="VALIDATION_ERROR", is_retryable=False, ) return if new_msg_req.deep_research: yield StreamingError( error="Multi-model is not supported with deep research", error_code="VALIDATION_ERROR", is_retryable=False, ) return yield from _stream_chat_turn( new_msg_req=new_msg_req, user=user, db_session=db_session, llm_overrides=llm_overrides, litellm_additional_headers=litellm_additional_headers, custom_tool_additional_headers=custom_tool_additional_headers, mcp_headers=mcp_headers, ) def llm_loop_completion_handle( state_container: ChatStateContainer, is_connected: Callable[[], bool], db_session: Session, assistant_message: ChatMessage, llm: LLM, reserved_tokens: int, ) -> None: chat_session_id = assistant_message.chat_session_id # Snapshot all state under the container's lock before any DB write. # Worker threads may still be running (e.g. user-cancellation path), so # direct attribute access is not thread-safe — use the provided getters. 
answer_tokens = state_container.get_answer_tokens() reasoning_tokens = state_container.get_reasoning_tokens() citation_to_doc = state_container.get_citation_to_doc() tool_calls = state_container.get_tool_calls() is_clarification = state_container.get_is_clarification() all_search_docs = state_container.get_all_search_docs() emitted_citations = state_container.get_emitted_citations() pre_answer_processing_time = state_container.get_pre_answer_processing_time() completed_normally = is_connected() if completed_normally: if answer_tokens is None: raise RuntimeError( "LLM run completed normally but did not return an answer." ) final_answer = answer_tokens else: # Stopped by user - append stop message logger.debug(f"Chat session {chat_session_id} stopped by user") if answer_tokens: final_answer = ( answer_tokens + " ... \n\nGeneration was stopped by the user." ) else: final_answer = "The generation was stopped by the user." save_chat_turn( message_text=final_answer, reasoning_tokens=reasoning_tokens, citation_to_doc=citation_to_doc, tool_calls=tool_calls, all_search_docs=all_search_docs, db_session=db_session, assistant_message=assistant_message, is_clarification=is_clarification, emitted_citations=emitted_citations, pre_answer_processing_time=pre_answer_processing_time, ) # Check if compression is needed after saving the message updated_chat_history = create_chat_history_chain( chat_session_id=chat_session_id, db_session=db_session, ) total_tokens = calculate_total_history_tokens(updated_chat_history) compression_params = get_compression_params( max_input_tokens=llm.config.max_input_tokens, current_history_tokens=total_tokens, reserved_tokens=reserved_tokens, ) if compression_params.should_compress: # Build tool mapping for formatting messages all_tools = get_tools(db_session) tool_id_to_name = {tool.id: tool.name for tool in all_tools} compress_chat_history( db_session=db_session, chat_history=updated_chat_history, llm=llm, compression_params=compression_params, 
tool_id_to_name=tool_id_to_name, ) _CITATION_LINK_START_PATTERN = re.compile(r"\s*\[\[\d+\]\]\(") def _find_markdown_link_end(text: str, destination_start: int) -> int | None: depth = 0 i = destination_start while i < len(text): curr = text[i] if curr == "\\": i += 2 continue if curr == "(": depth += 1 elif curr == ")": if depth == 0: return i depth -= 1 i += 1 return None def remove_answer_citations(answer: str) -> str: stripped_parts: list[str] = [] cursor = 0 while match := _CITATION_LINK_START_PATTERN.search(answer, cursor): stripped_parts.append(answer[cursor : match.start()]) link_end = _find_markdown_link_end(answer, match.end()) if link_end is None: stripped_parts.append(answer[match.start() :]) return "".join(stripped_parts) cursor = link_end + 1 stripped_parts.append(answer[cursor:]) return "".join(stripped_parts) @log_function_time() def gather_stream( packets: AnswerStream, ) -> ChatBasicResponse: answer: str | None = None citations: list[CitationInfo] = [] error_msg: str | None = None message_id: int | None = None top_documents: list[SearchDoc] = [] for packet in packets: if isinstance(packet, Packet): # Handle the different packet object types if isinstance(packet.obj, AgentResponseStart): # AgentResponseStart contains the final documents if packet.obj.final_documents: top_documents = packet.obj.final_documents elif isinstance(packet.obj, AgentResponseDelta): # AgentResponseDelta contains incremental content updates if answer is None: answer = "" if packet.obj.content: answer += packet.obj.content elif isinstance(packet.obj, CitationInfo): # CitationInfo contains citation information citations.append(packet.obj) elif isinstance(packet, StreamingError): error_msg = packet.error elif isinstance(packet, MessageResponseIDInfo): message_id = packet.reserved_assistant_message_id if message_id is None: raise ValueError("Message ID is required") if answer is None: if error_msg is not None: answer = "" else: # This should never be the case as these 
non-streamed flows do not have a stop-generation signal raise RuntimeError("Answer was not generated") return ChatBasicResponse( answer=answer, answer_citationless=remove_answer_citations(answer), citation_info=citations, message_id=message_id, error_msg=error_msg, top_documents=top_documents, ) @log_function_time() def gather_stream_full( packets: AnswerStream, state_container: ChatStateContainer, ) -> ChatFullResponse: """ Aggregate streaming packets and state container into a complete ChatFullResponse. This function consumes all packets from the stream and combines them with the accumulated state from the ChatStateContainer to build a complete response including answer, reasoning, citations, and tool calls. Args: packets: The stream of packets from handle_stream_message_objects state_container: The state container that accumulates tool calls, reasoning, etc. Returns: ChatFullResponse with all available data """ answer: str | None = None citations: list[CitationInfo] = [] error_msg: str | None = None message_id: int | None = None top_documents: list[SearchDoc] = [] chat_session_id: UUID | None = None for packet in packets: if isinstance(packet, Packet): if isinstance(packet.obj, AgentResponseStart): if packet.obj.final_documents: top_documents = packet.obj.final_documents elif isinstance(packet.obj, AgentResponseDelta): if answer is None: answer = "" if packet.obj.content: answer += packet.obj.content elif isinstance(packet.obj, CitationInfo): citations.append(packet.obj) elif isinstance(packet, StreamingError): error_msg = packet.error elif isinstance(packet, MessageResponseIDInfo): message_id = packet.reserved_assistant_message_id elif isinstance(packet, CreateChatSessionID): chat_session_id = packet.chat_session_id if message_id is None: raise ValueError("Message ID is required") # Use state_container for complete answer (handles edge cases gracefully) final_answer = state_container.get_answer_tokens() or answer or "" # Get reasoning from state container (None 
when model doesn't produce reasoning) reasoning = state_container.get_reasoning_tokens() # Convert ToolCallInfo list to ToolCallResponse list tool_call_responses = [ ToolCallResponse( tool_name=tc.tool_name, tool_arguments=tc.tool_call_arguments, tool_result=tc.tool_call_response, search_docs=tc.search_docs, generated_images=tc.generated_images, pre_reasoning=tc.reasoning_tokens, ) for tc in state_container.get_tool_calls() ] return ChatFullResponse( answer=final_answer, answer_citationless=remove_answer_citations(final_answer), pre_answer_reasoning=reasoning, tool_calls=tool_call_responses, top_documents=top_documents, citation_info=citations, message_id=message_id, chat_session_id=chat_session_id, error_msg=error_msg, ) ================================================ FILE: backend/onyx/chat/prompt_utils.py ================================================ from collections.abc import Callable from collections.abc import Sequence from uuid import UUID from sqlalchemy.orm import Session from onyx.db.memory import UserMemoryContext from onyx.db.persona import get_default_behavior_persona from onyx.db.user_file import calculate_user_files_token_count from onyx.file_store.models import FileDescriptor from onyx.prompts.chat_prompts import CITATION_REMINDER from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT from onyx.prompts.chat_prompts import FILE_REMINDER from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE from onyx.prompts.prompt_utils import get_company_context from onyx.prompts.prompt_utils import handle_onyx_date_awareness from onyx.prompts.prompt_utils import replace_citation_guidance_tag from onyx.prompts.prompt_utils import replace_reminder_tag from onyx.prompts.tool_prompts import GENERATE_IMAGE_GUIDANCE from onyx.prompts.tool_prompts import INTERNAL_SEARCH_GUIDANCE from onyx.prompts.tool_prompts import MEMORY_GUIDANCE from onyx.prompts.tool_prompts import 
OPEN_URLS_GUIDANCE from onyx.prompts.tool_prompts import PYTHON_TOOL_GUIDANCE from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT from onyx.prompts.user_info import TEAM_INFORMATION_PROMPT from onyx.prompts.user_info import USER_INFORMATION_HEADER from onyx.prompts.user_info import USER_MEMORIES_PROMPT from onyx.prompts.user_info import USER_PREFERENCES_PROMPT from onyx.prompts.user_info import USER_ROLE_PROMPT from onyx.tools.interface import Tool from onyx.tools.tool_implementations.images.image_generation_tool import ( ImageGenerationTool, ) from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool from onyx.tools.tool_implementations.python.python_tool import PythonTool from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool from onyx.utils.timing import log_function_time def get_default_base_system_prompt(db_session: Session) -> str: default_persona = get_default_behavior_persona(db_session) return ( default_persona.system_prompt if default_persona and default_persona.system_prompt is not None else DEFAULT_SYSTEM_PROMPT ) @log_function_time(print_only=True) def calculate_reserved_tokens( db_session: Session, persona_system_prompt: str, token_counter: Callable[[str], int], files: list[FileDescriptor] | None = None, user_memory_context: UserMemoryContext | None = None, ) -> int: """ Calculate reserved token count for system prompt and user files. 
This is used for token estimation purposes to reserve space for: - The system prompt (base + custom agent prompt + all guidance) - User files attached to the message Args: db_session: Database session persona_system_prompt: Custom agent system prompt (can be empty string) token_counter: Function that counts tokens in text files: List of file descriptors from the chat message (optional) user_memory_context: User memory context (optional) Returns: Total reserved token count """ base_system_prompt = get_default_base_system_prompt(db_session) # This is for token estimation purposes fake_system_prompt = build_system_prompt( base_system_prompt=base_system_prompt, datetime_aware=True, user_memory_context=user_memory_context, tools=None, should_cite_documents=True, include_all_guidance=True, ) custom_agent_prompt = persona_system_prompt if persona_system_prompt else "" reserved_token_count = token_counter( # Annoying that the dict has no attributes now custom_agent_prompt + " " + fake_system_prompt ) # Calculate total token count for files in the last message file_token_count = 0 if files: # Extract user_file_id from each file descriptor user_file_ids: list[UUID] = [] for file in files: uid = file.get("user_file_id") if not uid: continue try: user_file_ids.append(UUID(uid)) except (TypeError, ValueError, AttributeError): # Skip invalid user_file_id values continue if user_file_ids: file_token_count = calculate_user_files_token_count( user_file_ids, db_session ) reserved_token_count += file_token_count return reserved_token_count def build_reminder_message( reminder_text: str | None, include_citation_reminder: bool, include_file_reminder: bool, is_last_cycle: bool, ) -> str | None: reminder = reminder_text.strip() if reminder_text else "" if is_last_cycle: reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER if include_citation_reminder: reminder += "\n\n" + CITATION_REMINDER if include_file_reminder: reminder += "\n\n" + FILE_REMINDER reminder = reminder.strip() return 
reminder if reminder else None def _build_user_information_section( user_memory_context: UserMemoryContext | None, company_context: str | None, ) -> str: """Build the complete '# User Information' section with all sub-sections in the correct order: Basic Info → Team Info → Preferences → Memories.""" sections: list[str] = [] if user_memory_context: ctx = user_memory_context has_basic_info = ctx.user_info.name or ctx.user_info.email or ctx.user_info.role if has_basic_info: role_line = ( USER_ROLE_PROMPT.format(user_role=ctx.user_info.role).strip() if ctx.user_info.role else "" ) if role_line: role_line = "\n" + role_line sections.append( BASIC_INFORMATION_PROMPT.format( user_name=ctx.user_info.name or "", user_email=ctx.user_info.email or "", user_role=role_line, ) ) if company_context: sections.append( TEAM_INFORMATION_PROMPT.format(team_information=company_context.strip()) ) if user_memory_context: ctx = user_memory_context if ctx.user_preferences: sections.append( USER_PREFERENCES_PROMPT.format(user_preferences=ctx.user_preferences) ) if ctx.memories: formatted_memories = "\n".join(f"- {memory}" for memory in ctx.memories) sections.append( USER_MEMORIES_PROMPT.format(user_memories=formatted_memories) ) if not sections: return "" return USER_INFORMATION_HEADER + "\n".join(sections) def build_system_prompt( base_system_prompt: str, datetime_aware: bool = False, user_memory_context: UserMemoryContext | None = None, tools: Sequence[Tool] | None = None, should_cite_documents: bool = False, include_all_guidance: bool = False, ) -> str: """Should only be called with the default behavior system prompt. If the user has replaced the default behavior prompt with their custom agent prompt, do not call this function. 
""" system_prompt = handle_onyx_date_awareness(base_system_prompt, datetime_aware) # Replace citation guidance placeholder if present system_prompt, should_append_citation_guidance = replace_citation_guidance_tag( system_prompt, should_cite_documents=should_cite_documents, include_all_guidance=include_all_guidance, ) # Replace reminder tag placeholder if present system_prompt = replace_reminder_tag(system_prompt) company_context = get_company_context() user_info_section = _build_user_information_section( user_memory_context, company_context ) system_prompt += user_info_section # Append citation guidance after company context if placeholder was not present # This maintains backward compatibility and ensures citations are always enforced when needed if should_append_citation_guidance: system_prompt += REQUIRE_CITATION_GUIDANCE if include_all_guidance: tool_sections = [ TOOL_DESCRIPTION_SEARCH_GUIDANCE, INTERNAL_SEARCH_GUIDANCE, WEB_SEARCH_GUIDANCE.format( site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE ), OPEN_URLS_GUIDANCE, PYTHON_TOOL_GUIDANCE, GENERATE_IMAGE_GUIDANCE, MEMORY_GUIDANCE, ] system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections) return system_prompt if tools: has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools) has_internal_search = any(isinstance(tool, SearchTool) for tool in tools) has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools) has_python = any(isinstance(tool, PythonTool) for tool in tools) has_generate_image = any( isinstance(tool, ImageGenerationTool) for tool in tools ) has_memory = any(isinstance(tool, MemoryTool) for tool in tools) tool_guidance_sections: list[str] = [] if has_web_search or has_internal_search or include_all_guidance: tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE) # These are not included at the Tool level because the ordering may matter. 
if has_internal_search or include_all_guidance: tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE) if has_web_search or include_all_guidance: site_disabled_guidance = "" if has_web_search: web_search_tool = next( (t for t in tools if isinstance(t, WebSearchTool)), None ) if web_search_tool and not web_search_tool.supports_site_filter: site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE tool_guidance_sections.append( WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance) ) if has_open_urls or include_all_guidance: tool_guidance_sections.append(OPEN_URLS_GUIDANCE) if has_python or include_all_guidance: tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE) if has_generate_image or include_all_guidance: tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE) if has_memory or include_all_guidance: tool_guidance_sections.append(MEMORY_GUIDANCE) if tool_guidance_sections: system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections) return system_prompt ================================================ FILE: backend/onyx/chat/save_chat.py ================================================ import json import mimetypes from sqlalchemy.orm import Session from onyx.chat.chat_state import ChatStateContainer from onyx.chat.chat_state import SearchDocKey from onyx.configs.constants import DocumentSource from onyx.context.search.models import SearchDoc from onyx.db.chat import add_search_docs_to_chat_message from onyx.db.chat import add_search_docs_to_tool_call from onyx.db.chat import create_db_search_doc from onyx.db.models import ChatMessage from onyx.db.models import ToolCall from onyx.db.tools import create_tool_call_no_commit from onyx.file_store.models import FileDescriptor from onyx.natural_language_processing.utils import BaseTokenizer from onyx.natural_language_processing.utils import get_tokenizer from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type from onyx.tools.models import ToolCallInfo from 
onyx.utils.logger import setup_logger from onyx.utils.postgres_sanitization import sanitize_string logger = setup_logger() def _extract_referenced_file_descriptors( tool_calls: list[ToolCallInfo], message_text: str, ) -> list[FileDescriptor]: """Extract FileDescriptors for code interpreter files referenced in the message text.""" descriptors: list[FileDescriptor] = [] for tool_call_info in tool_calls: if not tool_call_info.generated_files: continue for gen_file in tool_call_info.generated_files: file_id = ( gen_file.file_link.rsplit("/", 1)[-1] if gen_file.file_link else "" ) if file_id and file_id in message_text: mime_type, _ = mimetypes.guess_type(gen_file.filename) descriptors.append( FileDescriptor( id=file_id, type=mime_type_to_chat_file_type(mime_type), name=gen_file.filename, ) ) return descriptors def _create_and_link_tool_calls( tool_calls: list[ToolCallInfo], assistant_message: ChatMessage, db_session: Session, default_tokenizer: BaseTokenizer, tool_call_to_search_doc_ids: dict[str, list[int]], ) -> None: """ Create ToolCall entries and link parent references and SearchDocs. This function handles the logic of: 1. Creating all ToolCall objects (with temporary parent references) 2. Flushing to get DB IDs 3. Building mappings and updating parent references 4. 
Linking SearchDocs to ToolCalls Args: tool_calls: List of tool call information to create assistant_message: The ChatMessage these tool calls belong to db_session: Database session default_tokenizer: Tokenizer for calculating token counts tool_call_to_search_doc_ids: Mapping from tool_call_id to list of search_doc IDs """ # Create all ToolCall objects first (without parent_tool_call_id set) # We'll update parent references after flushing to get IDs tool_call_objects: list[ToolCall] = [] tool_call_info_map: dict[str, ToolCallInfo] = {} for tool_call_info in tool_calls: tool_call_info_map[tool_call_info.tool_call_id] = tool_call_info # Calculate tool_call_tokens from arguments try: arguments_json_str = json.dumps(tool_call_info.tool_call_arguments) tool_call_tokens = len(default_tokenizer.encode(arguments_json_str)) except Exception as e: logger.warning( f"Failed to tokenize tool call arguments for {tool_call_info.tool_call_id}: {e}. Using length as (over) estimate." ) arguments_json_str = json.dumps(tool_call_info.tool_call_arguments) tool_call_tokens = len(arguments_json_str) parent_message_id = ( assistant_message.id if tool_call_info.parent_tool_call_id is None else None ) # Create ToolCall DB entry (parent_tool_call_id will be set after flush) # This is needed to get the IDs for the parent pointers tool_call = create_tool_call_no_commit( chat_session_id=assistant_message.chat_session_id, parent_chat_message_id=parent_message_id, turn_number=tool_call_info.turn_index, tool_id=tool_call_info.tool_id, tool_call_id=tool_call_info.tool_call_id, tool_call_arguments=tool_call_info.tool_call_arguments, tool_call_response=tool_call_info.tool_call_response, tool_call_tokens=tool_call_tokens, db_session=db_session, parent_tool_call_id=None, # Will be updated after flush reasoning_tokens=tool_call_info.reasoning_tokens, generated_images=( [img.model_dump() for img in tool_call_info.generated_images] if tool_call_info.generated_images else None ), 
tab_index=tool_call_info.tab_index, add_only=True, ) # Flush to get all of the IDs db_session.flush() tool_call_objects.append(tool_call) # Build mapping of tool calls (tool_call_id string -> DB id int) tool_call_map: dict[str, int] = {} for tool_call_obj in tool_call_objects: tool_call_map[tool_call_obj.tool_call_id] = tool_call_obj.id # Update parent_tool_call_id for all tool calls # Filter out orphaned children (whose parents don't exist) - this can happen # when generation is stopped mid-execution and parent tool calls were cancelled valid_tool_calls: list[ToolCall] = [] for tool_call_obj in tool_call_objects: tool_call_info = tool_call_info_map[tool_call_obj.tool_call_id] if tool_call_info.parent_tool_call_id is not None: parent_id = tool_call_map.get(tool_call_info.parent_tool_call_id) if parent_id is not None: tool_call_obj.parent_tool_call_id = parent_id valid_tool_calls.append(tool_call_obj) else: # Parent doesn't exist (likely cancelled) - skip this orphaned child logger.warning( f"Skipping tool call '{tool_call_obj.tool_call_id}' with missing parent " f"'{tool_call_info.parent_tool_call_id}' (likely cancelled during execution)" ) # Remove from DB session to prevent saving db_session.delete(tool_call_obj) else: # Top-level tool call (no parent) valid_tool_calls.append(tool_call_obj) # Link SearchDocs only to valid ToolCalls for tool_call_obj in valid_tool_calls: search_doc_ids = tool_call_to_search_doc_ids.get(tool_call_obj.tool_call_id, []) if search_doc_ids: add_search_docs_to_tool_call( tool_call_id=tool_call_obj.id, search_doc_ids=search_doc_ids, db_session=db_session, ) def save_chat_turn( message_text: str, reasoning_tokens: str | None, tool_calls: list[ToolCallInfo], citation_to_doc: dict[int, SearchDoc], all_search_docs: dict[SearchDocKey, SearchDoc], db_session: Session, assistant_message: ChatMessage, is_clarification: bool = False, emitted_citations: set[int] | None = None, pre_answer_processing_time: float | None = None, ) -> None: """ Save a 
chat turn by populating the assistant_message and creating related entities. This function: 1. Updates the ChatMessage with text, reasoning tokens, and token count 2. Creates DB SearchDoc entries from pre-deduplicated all_search_docs 3. Builds tool_call -> search_doc mapping for displayed docs 4. Builds citation mapping from citation_to_doc 5. Links all unique SearchDocs to the ChatMessage 6. Creates ToolCall entries and links SearchDocs to them 7. Builds the citations mapping for the ChatMessage Args: message_text: The message content to save reasoning_tokens: Optional reasoning tokens for the message tool_calls: List of tool call information to create ToolCall entries (may include search_docs) citation_to_doc: Mapping from citation number to SearchDoc for building citations all_search_docs: Pre-deduplicated search docs from ChatStateContainer db_session: Database session for persistence assistant_message: The ChatMessage object to populate (should already exist in DB) is_clarification: Whether this assistant message is a clarification question (deep research flow) emitted_citations: Set of citation numbers that were actually emitted during streaming. If provided, only citations in this set will be saved; others are filtered out. pre_answer_processing_time: Duration of processing before answer starts (in seconds) """ # 1. 
Update ChatMessage with message content, reasoning tokens, and token count sanitized_message_text = ( sanitize_string(message_text) if message_text else message_text ) assistant_message.message = sanitized_message_text assistant_message.reasoning_tokens = ( sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens ) assistant_message.is_clarification = is_clarification # Use pre-answer processing time (captured when MESSAGE_START was emitted) if pre_answer_processing_time is not None: assistant_message.processing_duration_seconds = pre_answer_processing_time # When storing, count tokens with the system default tokenizer rather than the # LLM-specific one. default_tokenizer = get_tokenizer(None, None) if sanitized_message_text: assistant_message.token_count = len( default_tokenizer.encode(sanitized_message_text) ) else: assistant_message.token_count = 0 # 2. Create DB SearchDoc entries from pre-deduplicated all_search_docs search_doc_key_to_id: dict[SearchDocKey, int] = {} for key, search_doc_py in all_search_docs.items(): db_search_doc = create_db_search_doc( server_search_doc=search_doc_py, db_session=db_session, commit=False, ) search_doc_key_to_id[key] = db_search_doc.id # 3.
Build tool_call -> search_doc mapping (for displayed docs in each tool call) tool_call_to_search_doc_ids: dict[str, list[int]] = {} for tool_call_info in tool_calls: if tool_call_info.search_docs: search_doc_ids_for_tool: list[int] = [] for search_doc_py in tool_call_info.search_docs: key = ChatStateContainer.create_search_doc_key(search_doc_py) if key in search_doc_key_to_id: search_doc_ids_for_tool.append(search_doc_key_to_id[key]) else: # Displayed doc not in all_search_docs - create it # This can happen if displayed_docs contains docs not in search_docs db_search_doc = create_db_search_doc( server_search_doc=search_doc_py, db_session=db_session, commit=False, ) search_doc_key_to_id[key] = db_search_doc.id search_doc_ids_for_tool.append(db_search_doc.id) tool_call_to_search_doc_ids[tool_call_info.tool_call_id] = list( set(search_doc_ids_for_tool) ) # Collect all search doc IDs for ChatMessage linking all_search_doc_ids_set: set[int] = set(search_doc_key_to_id.values()) # 4. Build a citation mapping from the citation number to the saved DB SearchDoc ID # Only include citations that were actually emitted during streaming citation_number_to_search_doc_id: dict[int, int] = {} for citation_num, search_doc_py in citation_to_doc.items(): # Skip citations that weren't actually emitted (if emitted_citations is provided) if emitted_citations is not None and citation_num not in emitted_citations: continue # Create the unique key for this SearchDoc version search_doc_key = ChatStateContainer.create_search_doc_key(search_doc_py) # Get the search doc ID (should already exist from processing tool_calls) if search_doc_key in search_doc_key_to_id: db_search_doc_id = search_doc_key_to_id[search_doc_key] else: # Citation doc not found in tool call search_docs # Expected case: Project files (source_type=FILE) are cited but don't come from tool calls # Unexpected case: Other citation-only docs (indicates a potential issue upstream) is_project_file = search_doc_py.source_type == 
DocumentSource.FILE if is_project_file: logger.info( f"Project file citation {search_doc_py.document_id} not in tool calls, creating it" ) else: logger.warning( f"Citation doc {search_doc_py.document_id} not found in tool call search_docs, creating it" ) # Create the SearchDoc in the database # NOTE: It's important that this maps to the saved DB Document ID, because # the match-highlights are specific to this saved version, not any document that has # the same document_id. db_search_doc = create_db_search_doc( server_search_doc=search_doc_py, db_session=db_session, commit=False, ) db_search_doc_id = db_search_doc.id search_doc_key_to_id[search_doc_key] = db_search_doc_id # Link project files to ChatMessage to enable frontend preview if is_project_file: all_search_doc_ids_set.add(db_search_doc_id) # Build mapping from citation number to search doc ID citation_number_to_search_doc_id[citation_num] = db_search_doc_id # 5. Link all unique SearchDocs (from both tool calls and citations) to ChatMessage final_search_doc_ids: list[int] = list(all_search_doc_ids_set) if final_search_doc_ids: add_search_docs_to_chat_message( chat_message_id=assistant_message.id, search_doc_ids=final_search_doc_ids, db_session=db_session, ) # 6. Create ToolCall entries and link SearchDocs to them _create_and_link_tool_calls( tool_calls=tool_calls, assistant_message=assistant_message, db_session=db_session, default_tokenizer=default_tokenizer, tool_call_to_search_doc_ids=tool_call_to_search_doc_ids, ) # 7. Build citations mapping - use the mapping we already built in step 4 assistant_message.citations = ( citation_number_to_search_doc_id if citation_number_to_search_doc_id else None ) # 8. Attach code interpreter generated files that the assistant actually # referenced in its response, so they are available via load_all_chat_files # on subsequent turns. Files not mentioned are intermediate artifacts. 
if sanitized_message_text: referenced = _extract_referenced_file_descriptors( tool_calls, sanitized_message_text ) if referenced: existing_files = assistant_message.files or [] assistant_message.files = existing_files + referenced # Finally save the messages, tool calls, and docs db_session.commit() ================================================ FILE: backend/onyx/chat/stop_signal_checker.py ================================================ from uuid import UUID from onyx.cache.interface import CacheBackend PREFIX = "chatsessionstop" FENCE_PREFIX = f"{PREFIX}_fence" FENCE_TTL = 10 * 60 # 10 minutes def _get_fence_key(chat_session_id: UUID) -> str: """Generate the cache key for a chat session stop signal fence. Args: chat_session_id: The UUID of the chat session Returns: The fence key string. Tenant isolation is handled automatically by the cache backend (Redis key-prefixing or Postgres schema routing). """ return f"{FENCE_PREFIX}_{chat_session_id}" def set_fence(chat_session_id: UUID, cache: CacheBackend, value: bool) -> None: """Set or clear the stop signal fence for a chat session. Args: chat_session_id: The UUID of the chat session cache: Tenant-aware cache backend value: True to set the fence (stop signal), False to clear it """ fence_key = _get_fence_key(chat_session_id) if not value: cache.delete(fence_key) return cache.set(fence_key, 0, ex=FENCE_TTL) def is_connected(chat_session_id: UUID, cache: CacheBackend) -> bool: """Check if the chat session should continue (not stopped). Args: chat_session_id: The UUID of the chat session to check cache: Tenant-aware cache backend Returns: True if the session should continue, False if it should stop """ return not cache.exists(_get_fence_key(chat_session_id)) def reset_cancel_status(chat_session_id: UUID, cache: CacheBackend) -> None: """Clear the stop signal for a chat session. 
Args: chat_session_id: The UUID of the chat session cache: Tenant-aware cache backend """ cache.delete(_get_fence_key(chat_session_id)) ================================================ FILE: backend/onyx/chat/tool_call_args_streaming.py ================================================ from collections.abc import Generator from collections.abc import Mapping from typing import Any from typing import Type from onyx.llm.model_response import ChatCompletionDeltaToolCall from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import Packet from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS from onyx.tools.interface import Tool from onyx.utils.jsonriver import Parser def _get_tool_class( tool_calls_in_progress: Mapping[int, Mapping[str, Any]], tool_call_delta: ChatCompletionDeltaToolCall, ) -> Type[Tool] | None: """Look up the Tool subclass for a streaming tool call delta.""" tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name") if not tool_name: return None return TOOL_NAME_TO_CLASS.get(tool_name) def maybe_emit_argument_delta( tool_calls_in_progress: Mapping[int, Mapping[str, Any]], tool_call_delta: ChatCompletionDeltaToolCall, placement: Placement, parsers: dict[int, Parser], ) -> Generator[Packet, None, None]: """Emit decoded tool-call argument deltas to the frontend. Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse the JSON argument string and extract only the newly-appended content for each string-valued argument. NOTE: Non-string arguments (numbers, booleans, null, arrays, objects) are skipped — they are available in the final tool-call kickoff packet. ``parsers`` is a mutable dict keyed by tool-call index. A new ``Parser`` is created automatically for each new index. 
""" tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta) if not tool_cls or not tool_cls.should_emit_argument_deltas(): return fn = tool_call_delta.function delta_fragment = fn.arguments if fn else None if not delta_fragment: return idx = tool_call_delta.index if idx not in parsers: parsers[idx] = Parser() parser = parsers[idx] deltas = parser.feed(delta_fragment) argument_deltas: dict[str, str] = {} for delta in deltas: if isinstance(delta, dict): for key, value in delta.items(): if isinstance(value, str): argument_deltas[key] = argument_deltas.get(key, "") + value if not argument_deltas: return tc_data = tool_calls_in_progress[tool_call_delta.index] yield Packet( placement=placement, obj=ToolCallArgumentDelta( tool_type=tc_data.get("name", ""), argument_deltas=argument_deltas, ), ) ================================================ FILE: backend/onyx/configs/__init__.py ================================================ ================================================ FILE: backend/onyx/configs/agent_configs.py ================================================ import os AGENT_DEFAULT_RETRIEVAL_HITS = 15 AGENT_DEFAULT_RERANKING_HITS = 10 AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8 AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3 AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5 AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = 25 AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = 35 AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5 AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3 AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10 AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000 INITIAL_SEARCH_DECOMPOSITION_ENABLED = True AGENT_DEFAULT_RETRIEVAL_HITS = 15 AGENT_DEFAULT_RERANKING_HITS = 10 AGENT_DEFAULT_MAX_VERIFIVATION_HITS = 30 AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS = 8 AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION = 3 AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = 5 AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS = 5 
AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3 AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10 AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000 AGENT_ALLOW_REFINEMENT = os.environ.get("AGENT_ALLOW_REFINEMENT", "").lower() == "true" AGENT_ANSWER_GENERATION_BY_FAST_LLM = ( os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true" ) AGENT_RETRIEVAL_STATS = ( os.environ.get("AGENT_RETRIEVAL_STATS", "").lower() != "false" ) # default True AGENT_MAX_VERIFICATION_HITS = int( os.environ.get("AGENT_MAX_VERIFICATION_HITS") or AGENT_DEFAULT_MAX_VERIFIVATION_HITS ) # 30 AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS ) # 15 AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS ) # 15 # Reranking agent configs # Reranking stats - no influence on flow outside of stats collection AGENT_RERANKING_STATS = ( os.environ.get("AGENT_RERANKING_STATS", "").lower() == "true" ) # default False AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS ) # 15 AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RERANKING_HITS ) # 10 AGENT_NUM_DOCS_FOR_DECOMPOSITION = int( os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION") or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION ) # 3 AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int( os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION") or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION ) # 5 AGENT_EXPLORATORY_SEARCH_RESULTS = int( os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS") or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS ) # 5 AGENT_MIN_ORIG_QUESTION_DOCS = int( os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS") or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS ) # 3 AGENT_MAX_ANSWER_CONTEXT_DOCS = int( os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS") or
AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS ) # 8 AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int( os.environ.get("AGENT_MAX_STATIC_HISTORY_WORD_LENGTH") or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH ) # 2000 AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER = int( os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER") or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER ) # 25 AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER = int( os.environ.get("AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER") or AGENT_DEFAULT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER ) # 35 AGENT_RETRIEVAL_STATS = ( os.environ.get("AGENT_RETRIEVAL_STATS", "").lower() != "false" ) # default True AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS ) # 15 AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS ) # 15 # Reranking agent configs # Reranking stats - no influence on flow outside of stats collection AGENT_RERANKING_STATS = ( os.environ.get("AGENT_RERANKING_STATS", "").lower() == "true" ) # default False AGENT_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RETRIEVAL_HITS ) # 15 AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS = int( os.environ.get("AGENT_RERANKING_MAX_QUERY_RETRIEVAL_RESULTS") or AGENT_DEFAULT_RERANKING_HITS ) # 10 AGENT_NUM_DOCS_FOR_DECOMPOSITION = int( os.environ.get("AGENT_NUM_DOCS_FOR_DECOMPOSITION") or AGENT_DEFAULT_NUM_DOCS_FOR_INITIAL_DECOMPOSITION ) # 3 AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION = int( os.environ.get("AGENT_NUM_DOCS_FOR_REFINED_DECOMPOSITION") or AGENT_DEFAULT_NUM_DOCS_FOR_REFINED_DECOMPOSITION ) # 5 AGENT_EXPLORATORY_SEARCH_RESULTS = int( os.environ.get("AGENT_EXPLORATORY_SEARCH_RESULTS") or AGENT_DEFAULT_EXPLORATORY_SEARCH_RESULTS ) # 5 AGENT_MIN_ORIG_QUESTION_DOCS = int( os.environ.get("AGENT_MIN_ORIG_QUESTION_DOCS") or AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS ) # 3
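
# --- Illustrative sketch (hypothetical helper, not part of Onyx) ---------------
# A minimal sketch of explicit boolean env-var parsing with a stated default.
# It is shown because compact forms are easy to get wrong: for example,
# `(not os.environ.get(NAME) == "False") or True` can never evaluate to False,
# and `(not os.environ.get(NAME) == "True") or False` is True whenever NAME is
# unset, so neither matches a "default True" / "default False" intent.
# The name `_example_env_flag` is an assumption made for illustration only.
def _example_env_flag(name: str, default: bool) -> bool:
    """Return the boolean value of env var `name`, or `default` if unset/blank."""
    raw = os.environ.get(name)
    if raw is None or raw.strip() == "":
        return default
    # Only the (case-insensitive) spelling "true" counts as enabled, matching
    # the `.lower() == "true"` convention used elsewhere in these configs.
    return raw.strip().lower() == "true"


# Usage sketch:
#   _example_env_flag("AGENT_RETRIEVAL_STATS", default=True)   # True unless set, e.g. to "false"
#   _example_env_flag("AGENT_RERANKING_STATS", default=False)  # False unless set to "true"
# --------------------------------------------------------------------------------
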
AGENT_MAX_ANSWER_CONTEXT_DOCS = int( os.environ.get("AGENT_MAX_ANSWER_CONTEXT_DOCS") or AGENT_DEFAULT_SUB_QUESTION_MAX_CONTEXT_HITS ) # 8 AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int( os.environ.get("AGENT_MAX_STATIC_HISTORY_WORD_LENGTH") or AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH ) # 2000 AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 15 # in seconds AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION ) AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 45 # in seconds AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int( os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION") or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION ) AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 8 # in seconds AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int( os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION") or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 8 # in seconds AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 45 # in seconds AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 8 # in seconds AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION ) 
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 10 # in seconds AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 9 # in seconds AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 45 # in seconds AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 15 # in seconds AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 40 # in seconds AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 20 # in seconds AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 60 # in seconds AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 6 # in seconds AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK ) AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 
12 # in seconds AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int( os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK") or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 12 # in seconds AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 6 # in seconds AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 6 # in seconds AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION ) AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 8 # in seconds AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int( os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION") or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 6 # in seconds AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS ) 
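
# --- Illustrative sketch (hypothetical helper, not part of Onyx) ---------------
# A minimal sketch that names the pattern repeated throughout this module:
# `int(os.environ.get(NAME) or DEFAULT)`. Because `or` falls through on any
# falsy value, an unset *or empty* variable yields DEFAULT, while a set value
# such as "0" is a non-empty (truthy) string and is kept. A non-numeric value
# still raises ValueError from int(). `_example_env_int` is an assumed name
# used for illustration only.
def _example_env_int(name: str, default: int) -> int:
    """Read env var `name` as an int, using `default` when unset or empty."""
    return int(os.environ.get(name) or default)


# Usage sketch (same semantics as the timeout constants defined in this module):
#   _example_env_int("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS", 12)
# --------------------------------------------------------------------------------
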
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 12 # in seconds AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int( os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS") or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS ) AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 6 # in seconds AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int( os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION") or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION ) AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 12 # in seconds AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int( os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION") or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION ) AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4 AGENT_MAX_TOKENS_VALIDATION = int( os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION ) AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256 AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int( os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION") or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION ) AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024 AGENT_MAX_TOKENS_ANSWER_GENERATION = int( os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION") or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION ) AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256 AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int( os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION") or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION ) AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024 AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int( os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION") or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION ) AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64 AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int( os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION") or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION ) AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128 AGENT_MAX_TOKENS_HISTORY_SUMMARY = int( 
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY") or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY ) # Parameters for the Thoughtful/Deep Research flows TF_DR_TIMEOUT_LONG = int(os.environ.get("TF_DR_TIMEOUT_LONG") or 120) TF_DR_TIMEOUT_SHORT = int(os.environ.get("TF_DR_TIMEOUT_SHORT") or 60) TF_DR_DEFAULT_FAST = (os.environ.get("TF_DR_DEFAULT_FAST") or "False").lower() == "true" GRAPH_VERSION_NAME: str = "a" ================================================ FILE: backend/onyx/configs/app_configs.py ================================================ import json import os import urllib.parse from datetime import datetime from datetime import timezone from typing import cast from onyx.auth.schemas import AuthBackend from onyx.cache.interface import CacheBackendType from onyx.configs.constants import AuthType from onyx.configs.constants import QueryHistoryType from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT from onyx.utils.logger import setup_logger logger = setup_logger() ##### # App Configs ##### APP_HOST = "0.0.0.0" APP_PORT = 8080 # API_PREFIX is used to prepend a base path for all API routes # generally used if using a reverse proxy which doesn't support stripping the `/api` # prefix from requests directed towards the API server. In these cases, set this to `/api` APP_API_PREFIX = os.environ.get("API_PREFIX", "") # Certain services need to make HTTP requests to the API server, such as the MCP server and Discord bot API_SERVER_PROTOCOL = os.environ.get("API_SERVER_PROTOCOL", "http") API_SERVER_HOST = os.environ.get("API_SERVER_HOST", "127.0.0.1") # This override allows self-hosting the MCP server with Onyx Cloud backend. 
API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS = os.environ.get( "API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS" ) # Whether to send user metadata (user_id/email and session_id) to the LLM provider. # Disabled by default. SEND_USER_METADATA_TO_LLM_PROVIDER = ( os.environ.get("SEND_USER_METADATA_TO_LLM_PROVIDER", "") ).lower() == "true" ##### # User Facing Features Configs ##### BLURB_SIZE = 128 # Number Encoder Tokens included in the chunk blurb # Hard ceiling for the admin-configurable file upload size (in MB). # Self-hosted customers can raise or lower this via the environment variable. _raw_max_upload_size_mb = int(os.environ.get("MAX_ALLOWED_UPLOAD_SIZE_MB", "250")) if _raw_max_upload_size_mb < 0: logger.warning( "MAX_ALLOWED_UPLOAD_SIZE_MB=%d is negative; falling back to 250", _raw_max_upload_size_mb, ) _raw_max_upload_size_mb = 250 MAX_ALLOWED_UPLOAD_SIZE_MB = _raw_max_upload_size_mb # Default fallback for the per-user file upload size limit (in MB) when no # admin-configured value exists. Clamped to MAX_ALLOWED_UPLOAD_SIZE_MB at # runtime so this never silently exceeds the hard ceiling. _raw_default_upload_size_mb = int( os.environ.get("DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB", "100") ) if _raw_default_upload_size_mb < 0: logger.warning( "DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB=%d is negative; falling back to 100", _raw_default_upload_size_mb, ) _raw_default_upload_size_mb = 100 DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = _raw_default_upload_size_mb GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int( os.environ.get("GENERATIVE_MODEL_ACCESS_CHECK_FREQ") or 86400 ) # 1 day # Controls whether users can use User Knowledge (personal documents) in assistants DISABLE_USER_KNOWLEDGE = os.environ.get("DISABLE_USER_KNOWLEDGE", "").lower() == "true" # Disables vector DB (Vespa/OpenSearch) entirely. When True, connectors and RAG search # are disabled but core chat, tools, user file uploads, and Projects still work. 
DISABLE_VECTOR_DB = os.environ.get("DISABLE_VECTOR_DB", "").lower() == "true" # Which backend to use for caching, locks, and ephemeral state. # "redis" (default) or "postgres" (only valid when DISABLE_VECTOR_DB=true). CACHE_BACKEND = CacheBackendType( os.environ.get("CACHE_BACKEND", CacheBackendType.REDIS) ) # If set to true, will show extra/uncommon connectors in the "Other" category SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true" # Controls whether to allow admin query history reports with: # 1. associated user emails # 2. anonymized user emails # 3. no queries ONYX_QUERY_HISTORY_TYPE = QueryHistoryType( (os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower() ) ##### # Web Configs ##### # WEB_DOMAIN is used to set the redirect_uri after login flows # NOTE: if you are having problems accessing the Onyx web UI locally (especially # on Windows, try setting this to `http://127.0.0.1:3000` instead and see if that # fixes it) WEB_DOMAIN = os.environ.get("WEB_DOMAIN") or "http://localhost:3000" ##### # Auth Configs ##### # Silently default to basic - warnings/errors logged in verify_auth_setting() # which only runs on app startup, not during migrations/scripts _auth_type_str = (os.environ.get("AUTH_TYPE") or "").lower() if _auth_type_str in [auth_type.value for auth_type in AuthType]: AUTH_TYPE = AuthType(_auth_type_str) else: AUTH_TYPE = AuthType.BASIC PASSWORD_MIN_LENGTH = int(os.getenv("PASSWORD_MIN_LENGTH", 8)) PASSWORD_MAX_LENGTH = int(os.getenv("PASSWORD_MAX_LENGTH", 64)) PASSWORD_REQUIRE_UPPERCASE = ( os.environ.get("PASSWORD_REQUIRE_UPPERCASE", "false").lower() == "true" ) PASSWORD_REQUIRE_LOWERCASE = ( os.environ.get("PASSWORD_REQUIRE_LOWERCASE", "false").lower() == "true" ) PASSWORD_REQUIRE_DIGIT = ( os.environ.get("PASSWORD_REQUIRE_DIGIT", "false").lower() == "true" ) PASSWORD_REQUIRE_SPECIAL_CHAR = ( os.environ.get("PASSWORD_REQUIRE_SPECIAL_CHAR", "false").lower() == "true" ) # Encryption key 
secret is used to encrypt connector credentials, API keys, and other sensitive # information. This provides an extra layer of security on top of Postgres access controls # and is available in Onyx EE ENCRYPTION_KEY_SECRET = os.environ.get("ENCRYPTION_KEY_SECRET") or "" # Turn off mask if admin users should see full credentials for data connectors. MASK_CREDENTIAL_PREFIX = ( os.environ.get("MASK_CREDENTIAL_PREFIX", "True").lower() != "false" ) AUTH_BACKEND = AuthBackend(os.environ.get("AUTH_BACKEND") or AuthBackend.REDIS.value) SESSION_EXPIRE_TIME_SECONDS = int( os.environ.get("SESSION_EXPIRE_TIME_SECONDS") or os.environ.get("REDIS_AUTH_EXPIRE_TIME_SECONDS") or 86400 * 7 ) # 7 days # Default request timeout, mostly used by connectors REQUEST_TIMEOUT_SECONDS = int(os.environ.get("REQUEST_TIMEOUT_SECONDS") or 60) # Set `VALID_EMAIL_DOMAINS` to a comma-separated list of domains in order to # restrict access to Onyx to only users with emails from those domains. # E.g. `VALID_EMAIL_DOMAINS=example.com,example.org` will restrict Onyx # signups to users with either an @example.com or an @example.org email.
# NOTE: maintaining `VALID_EMAIL_DOMAIN` to keep backwards compatibility _VALID_EMAIL_DOMAIN = os.environ.get("VALID_EMAIL_DOMAIN", "") _VALID_EMAIL_DOMAINS_STR = ( os.environ.get("VALID_EMAIL_DOMAINS", "") or _VALID_EMAIL_DOMAIN ) VALID_EMAIL_DOMAINS = ( [ domain.strip().lower() for domain in _VALID_EMAIL_DOMAINS_STR.split(",") if domain.strip() ] if _VALID_EMAIL_DOMAINS_STR else [] ) # Disposable email blocking - blocks temporary/throwaway email addresses # Set to empty string to disable disposable email blocking DISPOSABLE_EMAIL_DOMAINS_URL = os.environ.get( "DISPOSABLE_EMAIL_DOMAINS_URL", "https://disposable.github.io/disposable-email-domains/domains.json", ) # OAuth Login Flow # Used for both Google OAuth2 and OIDC flows OAUTH_CLIENT_ID = ( os.environ.get("OAUTH_CLIENT_ID", os.environ.get("GOOGLE_OAUTH_CLIENT_ID")) or "" ) OAUTH_CLIENT_SECRET = ( os.environ.get("OAUTH_CLIENT_SECRET", os.environ.get("GOOGLE_OAUTH_CLIENT_SECRET")) or "" ) # Whether Google OAuth is enabled (requires both client ID and secret) OAUTH_ENABLED = bool(OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET) # OpenID Connect configuration URL for OIDC integrations OPENID_CONFIG_URL = os.environ.get("OPENID_CONFIG_URL") or "" # Applicable for OIDC Auth, allows you to override the scopes that # are requested from the OIDC provider. Currently used when passing # over access tokens to tool calls and the tool needs more scopes OIDC_SCOPE_OVERRIDE: list[str] | None = None _OIDC_SCOPE_OVERRIDE = os.environ.get("OIDC_SCOPE_OVERRIDE") if _OIDC_SCOPE_OVERRIDE: try: OIDC_SCOPE_OVERRIDE = [ scope.strip() for scope in _OIDC_SCOPE_OVERRIDE.split(",") ] except Exception: pass # Enables PKCE for OIDC login flow. Disabled by default to preserve # backwards compatibility for existing OIDC deployments. 
OIDC_PKCE_ENABLED = os.environ.get("OIDC_PKCE_ENABLED", "").lower() == "true" # Applicable for SAML Auth SAML_CONF_DIR = os.environ.get("SAML_CONF_DIR") or "/app/onyx/configs/saml_config" # JWT Public Key URL for JWT token verification JWT_PUBLIC_KEY_URL: str | None = os.getenv("JWT_PUBLIC_KEY_URL", None) USER_AUTH_SECRET = os.environ.get("USER_AUTH_SECRET", "") if AUTH_TYPE == AuthType.BASIC and not USER_AUTH_SECRET: logger.warning( "USER_AUTH_SECRET is not set. This is required for secure password reset " "and email verification tokens. Please set USER_AUTH_SECRET in production." ) # Duration (in seconds) for which the FastAPI Users JWT token remains valid in the user's browser. # By default, this is set to match the Redis expiry time for consistency. AUTH_COOKIE_EXPIRE_TIME_SECONDS = int( os.environ.get("AUTH_COOKIE_EXPIRE_TIME_SECONDS") or 86400 * 7 ) # 7 days # for basic auth REQUIRE_EMAIL_VERIFICATION = ( os.environ.get("REQUIRE_EMAIL_VERIFICATION", "").lower() == "true" ) SMTP_SERVER = os.environ.get("SMTP_SERVER") or "" SMTP_PORT = int(os.environ.get("SMTP_PORT") or "587") SMTP_USER = os.environ.get("SMTP_USER") or "" SMTP_PASS = os.environ.get("SMTP_PASS") or "" EMAIL_FROM = os.environ.get("EMAIL_FROM") or SMTP_USER SENDGRID_API_KEY = os.environ.get("SENDGRID_API_KEY") or "" EMAIL_CONFIGURED = all([SMTP_SERVER, SMTP_USER, SMTP_PASS]) or SENDGRID_API_KEY # If set, Onyx will listen to the `expires_at` returned by the identity # provider (e.g. Okta, Google, etc.) and force the user to re-authenticate # after this time has elapsed. Disabled since by default many auth providers # have very short expiry times (e.g. 
1 hour) which provide a poor user experience TRACK_EXTERNAL_IDP_EXPIRY = ( os.environ.get("TRACK_EXTERNAL_IDP_EXPIRY", "").lower() == "true" ) ##### # DB Configs ##### DOCUMENT_INDEX_NAME = "danswer_index" # OpenSearch Configs OPENSEARCH_HOST = os.environ.get("OPENSEARCH_HOST") or "localhost" OPENSEARCH_REST_API_PORT = int(os.environ.get("OPENSEARCH_REST_API_PORT") or 9200) # TODO(andrei): 60 seconds is too much, we're just setting a high default # timeout for now to examine why queries are slow. # NOTE: This timeout applies to all requests the client makes, including bulk # indexing. DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S = int( os.environ.get("DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S") or 60 ) # TODO(andrei): 50 seconds is too much, we're just setting a high default # timeout for now to examine why queries are slow. # NOTE: To get useful partial results, this value should be less than the client # timeout above. DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int( os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50 ) OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin") OPENSEARCH_ADMIN_PASSWORD = os.environ.get( "OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!" ) USING_AWS_MANAGED_OPENSEARCH = ( os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true" ) # Profiling adds some overhead to OpenSearch operations. This overhead is # unknown right now. Defaults to True. OPENSEARCH_PROFILING_DISABLED = ( os.environ.get("OPENSEARCH_PROFILING_DISABLED", "true").lower() == "true" ) # Whether to disable match highlights for OpenSearch. Defaults to True for now # as we investigate query performance. OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED = ( os.environ.get("OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED", "true").lower() == "true" ) # When enabled, OpenSearch returns detailed score breakdowns for each hit. # Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation. 
# In practice, for hybrid search the slowdown seems to be closer to 1000x.
OPENSEARCH_EXPLAIN_ENABLED = (
    os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"

# This is the "base" config for now. The idea is that, at least for our dev
# environments, we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
# of OpenSearch running for the relevant Onyx instance.
# NOTE: Now enabled by default, unless the env indicates otherwise.
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
    os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "true").lower() == "true"
)
# NOTE: This effectively does nothing anymore; admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# When the "base" config above is true, this controls whether we retrieve from
# OpenSearch or Vespa. We want to be able to quickly toggle this in the event we
# see issues with OpenSearch retrieval in our dev environments.
ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX = (
    ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
    and os.environ.get("ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX", "").lower() == "true"
)
# Whether we should check for and create an index if necessary every time we
# instantiate an OpenSearchDocumentIndex on multitenant cloud. Defaults to True.
VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT = (
    os.environ.get("VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT", "true").lower()
    == "true"
)
OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE = int(
    os.environ.get("OPENSEARCH_MIGRATION_GET_VESPA_CHUNKS_PAGE_SIZE") or 500
)
# If set, will override the default number of shards and replicas for the index.
OPENSEARCH_INDEX_NUM_SHARDS: int | None = (
    int(os.environ["OPENSEARCH_INDEX_NUM_SHARDS"])
    if os.environ.get("OPENSEARCH_INDEX_NUM_SHARDS", None) is not None
    else None
)
OPENSEARCH_INDEX_NUM_REPLICAS: int | None = (
    int(os.environ["OPENSEARCH_INDEX_NUM_REPLICAS"])
    if os.environ.get("OPENSEARCH_INDEX_NUM_REPLICAS", None) is not None
    else None
)

ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH = (
    os.environ.get("ONYX_SEARCH_UI_USES_OPENSEARCH_KEYWORD_SEARCH", "").lower()
    == "true"
)

VESPA_HOST = os.environ.get("VESPA_HOST") or "localhost"
# NOTE: this is used if and only if the vespa config server is accessible via a
# different host than the main vespa application
VESPA_CONFIG_SERVER_HOST = os.environ.get("VESPA_CONFIG_SERVER_HOST") or VESPA_HOST
VESPA_PORT = os.environ.get("VESPA_PORT") or "8081"
VESPA_TENANT_PORT = os.environ.get("VESPA_TENANT_PORT") or "19071"
# the number of times to try and connect to vespa on startup before giving up
VESPA_NUM_ATTEMPTS_ON_STARTUP = int(os.environ.get("NUM_RETRIES_ON_STARTUP") or 10)

VESPA_CLOUD_URL = os.environ.get("VESPA_CLOUD_URL", "")
VESPA_CLOUD_CERT_PATH = os.environ.get("VESPA_CLOUD_CERT_PATH")
VESPA_CLOUD_KEY_PATH = os.environ.get("VESPA_CLOUD_KEY_PATH")

# Number of documents in a batch during indexing (further batching done by chunks before passing to bi-encoder)
INDEX_BATCH_SIZE = int(os.environ.get("INDEX_BATCH_SIZE") or 16)

MAX_DRIVE_WORKERS = int(os.environ.get("MAX_DRIVE_WORKERS", 4))

# Below are intended to match the environment variable names used by the official postgres docker image
# https://hub.docker.com/_/postgres
POSTGRES_USER = os.environ.get("POSTGRES_USER") or "postgres"
# URL-encode the password for asyncpg to avoid issues with special characters on some machines.
POSTGRES_PASSWORD = urllib.parse.quote_plus(
    os.environ.get("POSTGRES_PASSWORD") or "password"
)
POSTGRES_HOST = os.environ.get("POSTGRES_HOST") or "127.0.0.1"
POSTGRES_PORT = os.environ.get("POSTGRES_PORT") or "5432"
POSTGRES_DB = os.environ.get("POSTGRES_DB") or "postgres"
AWS_REGION_NAME = os.environ.get("AWS_REGION_NAME") or "us-east-2"

POSTGRES_API_SERVER_POOL_SIZE = int(
    os.environ.get("POSTGRES_API_SERVER_POOL_SIZE") or 40
)
POSTGRES_API_SERVER_POOL_OVERFLOW = int(
    os.environ.get("POSTGRES_API_SERVER_POOL_OVERFLOW") or 10
)

POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE = int(
    os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE") or 10
)
POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW = int(
    os.environ.get("POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW") or 5
)

# defaults to False
# generally should only be used for
POSTGRES_USE_NULL_POOL = os.environ.get("POSTGRES_USE_NULL_POOL", "").lower() == "true"

# defaults to False
POSTGRES_POOL_PRE_PING = os.environ.get("POSTGRES_POOL_PRE_PING", "").lower() == "true"

# recycle timeout in seconds
POSTGRES_POOL_RECYCLE_DEFAULT = 60 * 20  # 20 minutes
try:
    POSTGRES_POOL_RECYCLE = int(
        os.environ.get("POSTGRES_POOL_RECYCLE", POSTGRES_POOL_RECYCLE_DEFAULT)
    )
except ValueError:
    POSTGRES_POOL_RECYCLE = POSTGRES_POOL_RECYCLE_DEFAULT

# RDS IAM authentication - enables IAM-based authentication for PostgreSQL
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"

# Redis IAM authentication - enables IAM-based authentication for Redis ElastiCache
# Note: This is separate from RDS IAM auth as they use different authentication mechanisms
USE_REDIS_IAM_AUTH = os.getenv("USE_REDIS_IAM_AUTH", "False").lower() == "true"

REDIS_SSL = os.getenv("REDIS_SSL", "").lower() == "true"
REDIS_HOST = os.environ.get("REDIS_HOST") or "localhost"
REDIS_PORT = int(os.environ.get("REDIS_PORT", 6379))
REDIS_PASSWORD = os.environ.get("REDIS_PASSWORD") or ""
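# Illustrative sketch (hypothetical helper, not part of Onyx): this shows why
# POSTGRES_PASSWORD is passed through urllib.parse.quote_plus above. A raw
# password containing characters such as "@", ":" or "/" would otherwise be
# misparsed once it is embedded in a connection URL. The
# "postgresql+asyncpg://" URL layout below is an assumption for illustration
# only; the real engine setup lives elsewhere in the codebase.
import urllib.parse


def _example_build_postgres_url(
    user: str, raw_password: str, host: str, port: str, db: str
) -> str:
    # quote_plus uses safe="" by default, so every reserved character is
    # percent-encoded: "@" -> "%40", ":" -> "%3A", "/" -> "%2F".
    # Note that it encodes a space as "+" rather than "%20".
    encoded_password = urllib.parse.quote_plus(raw_password)
    return f"postgresql+asyncpg://{user}:{encoded_password}@{host}:{port}/{db}"

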
# this assumes that other redis settings remain the same as the primary
REDIS_REPLICA_HOST = os.environ.get("REDIS_REPLICA_HOST") or REDIS_HOST

REDIS_AUTH_KEY_PREFIX = "fastapi_users_token:"

# Rate limiting for auth endpoints
RATE_LIMIT_WINDOW_SECONDS: int | None = None
_rate_limit_window_seconds_str = os.environ.get("RATE_LIMIT_WINDOW_SECONDS")
if _rate_limit_window_seconds_str is not None:
    try:
        RATE_LIMIT_WINDOW_SECONDS = int(_rate_limit_window_seconds_str)
    except ValueError:
        pass

RATE_LIMIT_MAX_REQUESTS: int | None = None
_rate_limit_max_requests_str = os.environ.get("RATE_LIMIT_MAX_REQUESTS")
if _rate_limit_max_requests_str is not None:
    try:
        RATE_LIMIT_MAX_REQUESTS = int(_rate_limit_max_requests_str)
    except ValueError:
        pass

AUTH_RATE_LIMITING_ENABLED = RATE_LIMIT_MAX_REQUESTS and RATE_LIMIT_WINDOW_SECONDS
# Used for general redis things
REDIS_DB_NUMBER = int(os.environ.get("REDIS_DB_NUMBER", 0))

# Used by celery as broker and backend
REDIS_DB_NUMBER_CELERY_RESULT_BACKEND = int(
    os.environ.get("REDIS_DB_NUMBER_CELERY_RESULT_BACKEND", 14)
)
REDIS_DB_NUMBER_CELERY = int(os.environ.get("REDIS_DB_NUMBER_CELERY", 15))  # broker

# will propagate to both our redis client as well as celery's redis client
REDIS_HEALTH_CHECK_INTERVAL = int(os.environ.get("REDIS_HEALTH_CHECK_INTERVAL", 60))

# our redis client only, not celery's
REDIS_POOL_MAX_CONNECTIONS = int(os.environ.get("REDIS_POOL_MAX_CONNECTIONS", 128))

# https://docs.celeryq.dev/en/stable/userguide/configuration.html#redis-backend-settings
# should be one of "required", "optional", or "none"
REDIS_SSL_CERT_REQS = os.getenv("REDIS_SSL_CERT_REQS", "none")
REDIS_SSL_CA_CERTS = os.getenv("REDIS_SSL_CA_CERTS", None)

CELERY_RESULT_EXPIRES = int(os.environ.get("CELERY_RESULT_EXPIRES", 86400))  # seconds

# https://docs.celeryq.dev/en/stable/userguide/configuration.html#broker-pool-limit
# Setting to None may help when there is a proxy in the way closing idle connections
_CELERY_BROKER_POOL_LIMIT_DEFAULT = 10
try:
    CELERY_BROKER_POOL_LIMIT = int(
        os.environ.get("CELERY_BROKER_POOL_LIMIT", _CELERY_BROKER_POOL_LIMIT_DEFAULT)
    )
except ValueError:
    CELERY_BROKER_POOL_LIMIT = _CELERY_BROKER_POOL_LIMIT_DEFAULT

_CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT = 24
try:
    CELERY_WORKER_LIGHT_CONCURRENCY = int(
        os.environ.get(
            "CELERY_WORKER_LIGHT_CONCURRENCY",
            _CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT,
        )
    )
except ValueError:
    CELERY_WORKER_LIGHT_CONCURRENCY = _CELERY_WORKER_LIGHT_CONCURRENCY_DEFAULT

_CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT = 8
try:
    CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = int(
        os.environ.get(
            "CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER",
            _CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT,
        )
    )
except ValueError:
    CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER = (
        _CELERY_WORKER_LIGHT_PREFETCH_MULTIPLIER_DEFAULT
    )

_CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT = 6
try:
    env_value = os.environ.get("CELERY_WORKER_DOCPROCESSING_CONCURRENCY")
    if not env_value:
        env_value = os.environ.get("NUM_INDEXING_WORKERS")

    if not env_value:
        env_value = str(_CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT)
    CELERY_WORKER_DOCPROCESSING_CONCURRENCY = int(env_value)
except ValueError:
    CELERY_WORKER_DOCPROCESSING_CONCURRENCY = (
        _CELERY_WORKER_DOCPROCESSING_CONCURRENCY_DEFAULT
    )

_CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT = 1
try:
    env_value = os.environ.get("CELERY_WORKER_DOCFETCHING_CONCURRENCY")
    if not env_value:
        env_value = os.environ.get("NUM_DOCFETCHING_WORKERS")

    if not env_value:
        env_value = str(_CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT)
    CELERY_WORKER_DOCFETCHING_CONCURRENCY = int(env_value)
except ValueError:
    CELERY_WORKER_DOCFETCHING_CONCURRENCY = (
        _CELERY_WORKER_DOCFETCHING_CONCURRENCY_DEFAULT
    )

CELERY_WORKER_PRIMARY_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_PRIMARY_CONCURRENCY") or 4
)

CELERY_WORKER_PRIMARY_POOL_OVERFLOW = int(
    os.environ.get("CELERY_WORKER_PRIMARY_POOL_OVERFLOW") or 4
)

# Individual worker concurrency settings
CELERY_WORKER_HEAVY_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_HEAVY_CONCURRENCY") or 4
)

CELERY_WORKER_MONITORING_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_MONITORING_CONCURRENCY") or 1
)

CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY = int(
    os.environ.get("CELERY_WORKER_USER_FILE_PROCESSING_CONCURRENCY") or 2
)

# The maximum number of tasks that can be queued up to sync to Vespa in a single pass
VESPA_SYNC_MAX_TASKS = 8192

DB_YIELD_PER_DEFAULT = 64

#####
# Connector Configs
#####
POLL_CONNECTOR_OFFSET = 30  # Minutes overlap between poll windows

# View the list here:
# https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/factory.py
# If this is empty, all connectors are enabled. This option is for security-heavy orgs where
# only very select connectors are enabled and admins cannot add other connector types.
ENABLED_CONNECTOR_TYPES = os.environ.get("ENABLED_CONNECTOR_TYPES") or ""

# If set to true, curators can only access and edit assistants that they created
CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS = (
    os.environ.get("CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS", "").lower()
    == "true"
)

# Some calls to get information on expert users are quite costly, especially with rate limiting.
# Since experts are not used in the actual user experience, this is currently turned off
# for some connectors.
ENABLE_EXPENSIVE_EXPERT_CALLS = False


# TODO these should be available for frontend configuration, via advanced options expandable
WEB_CONNECTOR_IGNORED_CLASSES = os.environ.get(
    "WEB_CONNECTOR_IGNORED_CLASSES", "sidebar,footer"
).split(",")
WEB_CONNECTOR_IGNORED_ELEMENTS = os.environ.get(
    "WEB_CONNECTOR_IGNORED_ELEMENTS", "nav,footer,meta,script,style,symbol,aside"
).split(",")
WEB_CONNECTOR_OAUTH_CLIENT_ID = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_ID")
WEB_CONNECTOR_OAUTH_CLIENT_SECRET = os.environ.get("WEB_CONNECTOR_OAUTH_CLIENT_SECRET")
WEB_CONNECTOR_OAUTH_TOKEN_URL = os.environ.get("WEB_CONNECTOR_OAUTH_TOKEN_URL")
WEB_CONNECTOR_VALIDATE_URLS = os.environ.get("WEB_CONNECTOR_VALIDATE_URLS")

HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY = os.environ.get(
    "HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY",
    HtmlBasedConnectorTransformLinksStrategy.STRIP,
)

NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP = (
    os.environ.get("NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP", "").lower()
    == "true"
)


#####
# Confluence Connector Configs
#####

CONFLUENCE_CONNECTOR_LABELS_TO_SKIP = [
    ignored_tag
    for ignored_tag in os.environ.get("CONFLUENCE_CONNECTOR_LABELS_TO_SKIP", "").split(
        ","
    )
    if ignored_tag
]

# Attachments exceeding this size will not be retrieved (in bytes)
CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)
# Attachments with more chars than this will not be indexed. This is to prevent extremely
# large files from freezing indexing. 200,000 is ~100 google doc pages.
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD = int(
    os.environ.get("CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD", 200_000)
)

# A JSON-formatted array. Each item in the array should have the following structure:
# {
#     "user_id": "1234567890",
#     "username": "bob",
#     "display_name": "Bob Fitzgerald",
#     "email": "bob@example.com",
#     "type": "known"
# }
_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = os.environ.get(
    "CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE", ""
)
CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE = cast(
    list[dict[str, str]] | None,
    (
        json.loads(_RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE)
        if _RAW_CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE
        else None
    ),
)

# Due to breakages in the confluence API, the timezone offset must be specified client side
# to match the user's specified timezone.
# The current state of affairs:
# CQL queries are parsed in the user's timezone and cannot be specified in UTC
# No API retrieves the user's timezone
# All data is returned in UTC, so we can't derive the user's timezone from that
# https://community.developer.atlassian.com/t/confluence-cloud-time-zone-get-via-rest-api/35954/16
# https://jira.atlassian.com/browse/CONFCLOUD-69670


def get_current_tz_offset() -> int:
    # datetime now() gets local time, datetime.now(timezone.utc) gets UTC time.
    # remove tzinfo to compare non-timezone-aware objects.
    time_diff = datetime.now() - datetime.now(timezone.utc).replace(tzinfo=None)
    return round(time_diff.total_seconds() / 3600)


# enter as a floating point offset from UTC in hours (-24 < val < 24)
# this will be applied globally, so it probably makes sense to transition this to per
# connector at some point.
# For the default value, we assume that the user's local timezone is more likely to be
# correct (i.e. the configured user's timezone or the default server one) than UTC.
# https://developer.atlassian.com/cloud/confluence/cql-fields/#created
CONFLUENCE_TIMEZONE_OFFSET = float(
    os.environ.get("CONFLUENCE_TIMEZONE_OFFSET", get_current_tz_offset())
)

CONFLUENCE_USE_ONYX_USERS_FOR_GROUP_SYNC = (
    os.environ.get("CONFLUENCE_USE_ONYX_USERS_FOR_GROUP_SYNC", "").lower() == "true"
)

GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD = int(
    os.environ.get("GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD", 10 * 1024 * 1024)
)

# Default size threshold for Drupal Wiki attachments (10MB)
DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD = int(
    os.environ.get("DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD", 10 * 1024 * 1024)
)

# Default size threshold for SharePoint files (20MB)
SHAREPOINT_CONNECTOR_SIZE_THRESHOLD = int(
    os.environ.get("SHAREPOINT_CONNECTOR_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

# When True, group sync enumerates every Azure AD group in the tenant (expensive).
# When False (default), only groups found in site role assignments are synced.
# Can be overridden per-connector via the "exhaustive_ad_enumeration" key in
# connector_specific_config.
SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION = (
    os.environ.get("SHAREPOINT_EXHAUSTIVE_AD_ENUMERATION", "").lower() == "true"
)

BLOB_STORAGE_SIZE_THRESHOLD = int(
    os.environ.get("BLOB_STORAGE_SIZE_THRESHOLD", 20 * 1024 * 1024)
)

JIRA_CONNECTOR_LABELS_TO_SKIP = [
    ignored_tag
    for ignored_tag in os.environ.get("JIRA_CONNECTOR_LABELS_TO_SKIP", "").split(",")
    if ignored_tag
]
# Maximum size for Jira tickets in bytes (default: 100KB)
JIRA_CONNECTOR_MAX_TICKET_SIZE = int(
    os.environ.get("JIRA_CONNECTOR_MAX_TICKET_SIZE", 100 * 1024)
)
JIRA_SLIM_PAGE_SIZE = int(os.environ.get("JIRA_SLIM_PAGE_SIZE", 500))

GONG_CONNECTOR_START_TIME = os.environ.get("GONG_CONNECTOR_START_TIME")

GITHUB_CONNECTOR_BASE_URL = os.environ.get("GITHUB_CONNECTOR_BASE_URL") or None

GITLAB_CONNECTOR_INCLUDE_CODE_FILES = (
    os.environ.get("GITLAB_CONNECTOR_INCLUDE_CODE_FILES", "").lower() == "true"
)

# Typically set to http://localhost:3000 for OAuth connector development
CONNECTOR_LOCALHOST_OVERRIDE = os.getenv("CONNECTOR_LOCALHOST_OVERRIDE")

# Egnyte specific configs
EGNYTE_CLIENT_ID = os.getenv("EGNYTE_CLIENT_ID")
EGNYTE_CLIENT_SECRET = os.getenv("EGNYTE_CLIENT_SECRET")

# Linear specific configs
LINEAR_CLIENT_ID = os.getenv("LINEAR_CLIENT_ID")
LINEAR_CLIENT_SECRET = os.getenv("LINEAR_CLIENT_SECRET")

# Slack specific configs
SLACK_NUM_THREADS = int(os.getenv("SLACK_NUM_THREADS") or 8)
MAX_SLACK_QUERY_EXPANSIONS = int(os.environ.get("MAX_SLACK_QUERY_EXPANSIONS", "5"))

# Slack federated search thread context settings
# Batch size for fetching thread context (controls concurrent API calls per batch)
SLACK_THREAD_CONTEXT_BATCH_SIZE = int(
    os.environ.get("SLACK_THREAD_CONTEXT_BATCH_SIZE", "5")
)
# Maximum messages to fetch thread context for (top N by relevance get full context)
MAX_SLACK_THREAD_CONTEXT_MESSAGES = int(
    os.environ.get("MAX_SLACK_THREAD_CONTEXT_MESSAGES", "5")
)

# TestRail specific configs
TESTRAIL_BASE_URL = os.environ.get("TESTRAIL_BASE_URL", "")
TESTRAIL_USERNAME = os.environ.get("TESTRAIL_USERNAME", "")
TESTRAIL_API_KEY = os.environ.get("TESTRAIL_API_KEY", "")

LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE = (
    os.environ.get("LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE", "").lower()
    == "true"
)

DEFAULT_PRUNING_FREQ = 60 * 60 * 24  # Once a day

ALLOW_SIMULTANEOUS_PRUNING = (
    os.environ.get("ALLOW_SIMULTANEOUS_PRUNING", "").lower() == "true"
)

# This is the maximum rate at which documents are queried for a pruning job. 0 disables the limitation.
MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE = int(
    os.environ.get("MAX_PRUNING_DOCUMENT_RETRIEVAL_PER_MINUTE", 0)
)

# comma delimited list of zendesk article labels to skip indexing for
ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS = os.environ.get(
    "ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS", ""
).split(",")


#####
# Indexing Configs
#####
# NOTE: Currently only supported in the Confluence and Google Drive connectors +
# only handles some failures (Confluence = handles API call failures, Google
# Drive = handles failures pulling files / parsing them)
CONTINUE_ON_CONNECTOR_FAILURE = os.environ.get(
    "CONTINUE_ON_CONNECTOR_FAILURE", ""
).lower() not in ["false", ""]
# When swapping to a new embedding model, a secondary index is created in the background. To conserve
# resources, we pause updates on the primary index by default while the secondary index is created.
DISABLE_INDEX_UPDATE_ON_SWAP = (
    os.environ.get("DISABLE_INDEX_UPDATE_ON_SWAP", "").lower() == "true"
)
# More accurate results at the expense of indexing speed and index size (stores additional 4 MINI_CHUNK vectors)
ENABLE_MULTIPASS_INDEXING = (
    os.environ.get("ENABLE_MULTIPASS_INDEXING", "").lower() == "true"
)
# Enable contextual retrieval
ENABLE_CONTEXTUAL_RAG = os.environ.get("ENABLE_CONTEXTUAL_RAG", "").lower() == "true"

DEFAULT_CONTEXTUAL_RAG_LLM_NAME = "gpt-4o-mini"
DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER = "DevEnvPresetOpenAI"
# Finer grained chunking for more detail retention
# Slightly larger since the sentence aware split is a max cutoff so most minichunks will be under MINI_CHUNK_SIZE
# tokens. But we need it to be at least as big as 1/4th chunk size to avoid having a tiny mini-chunk at the end
MINI_CHUNK_SIZE = 150

# This is the number of regular chunks per large chunk
LARGE_CHUNK_RATIO = 4

# The maximum number of chunks that can be held for 1 document processing batch
# The purpose of this is to set an upper bound on memory usage
MAX_CHUNKS_PER_DOC_BATCH = int(os.environ.get("MAX_CHUNKS_PER_DOC_BATCH") or 1000)

# Include the document level metadata in each chunk. If the metadata is too long, then it is thrown out
# We don't want the metadata to overwhelm the actual contents of the chunk
SKIP_METADATA_IN_CHUNK = os.environ.get("SKIP_METADATA_IN_CHUNK", "").lower() == "true"

# The indexer will warn in the logs whenever a document exceeds this threshold (in bytes)
INDEXING_SIZE_WARNING_THRESHOLD = int(
    os.environ.get("INDEXING_SIZE_WARNING_THRESHOLD") or 100 * 1024 * 1024
)

# during indexing, will log verbose memory diff stats every x batches and at the end.
# 0 disables this behavior and is the default.
INDEXING_TRACER_INTERVAL = int(os.environ.get("INDEXING_TRACER_INTERVAL") or 0)

# Enable multi-threaded embedding model calls for parallel processing
# Note: only applies for API-based embedding models
INDEXING_EMBEDDING_MODEL_NUM_THREADS = int(
    os.environ.get("INDEXING_EMBEDDING_MODEL_NUM_THREADS") or 8
)

# Maximum document size (in characters) and maximum file size (in bytes) to be indexed
MAX_DOCUMENT_CHARS = int(os.environ.get("MAX_DOCUMENT_CHARS") or 5_000_000)
MAX_FILE_SIZE_BYTES = int(
    os.environ.get("MAX_FILE_SIZE_BYTES") or 2 * 1024 * 1024 * 1024
)  # 2GB in bytes

# Use document summary for contextual rag
USE_DOCUMENT_SUMMARY = os.environ.get("USE_DOCUMENT_SUMMARY", "true").lower() == "true"
# Use chunk summary for contextual rag
USE_CHUNK_SUMMARY = os.environ.get("USE_CHUNK_SUMMARY", "true").lower() == "true"
# Average summary embeddings for contextual rag (not yet implemented)
AVERAGE_SUMMARY_EMBEDDINGS = (
    os.environ.get("AVERAGE_SUMMARY_EMBEDDINGS", "false").lower() == "true"
)

MAX_TOKENS_FOR_FULL_INCLUSION = 4096

# The intent was to have this be configurable per query, but I don't think any
# codepath was actually configuring this, so for the migrated Vespa interface
# we'll just use the default value, but also have it be configurable by env var.
RECENCY_BIAS_MULTIPLIER = float(os.environ.get("RECENCY_BIAS_MULTIPLIER") or 1.0)

# Should match the rerank-count value set in
# backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja.
RERANK_COUNT = int(os.environ.get("RERANK_COUNT") or 1000)


#####
# Tool Configs
#####
# Code Interpreter Service Configuration
CODE_INTERPRETER_BASE_URL = os.environ.get(
    "CODE_INTERPRETER_BASE_URL", "http://localhost:8000"
)

CODE_INTERPRETER_DEFAULT_TIMEOUT_MS = int(
    os.environ.get("CODE_INTERPRETER_DEFAULT_TIMEOUT_MS") or 60_000
)

CODE_INTERPRETER_MAX_OUTPUT_LENGTH = int(
    os.environ.get("CODE_INTERPRETER_MAX_OUTPUT_LENGTH") or 50_000
)


#####
# Miscellaneous
#####
JOB_TIMEOUT = 60 * 60 * 6  # 6 hours default
# Logs only Onyx model interactions like prompts, responses, messages etc.
LOG_ONYX_MODEL_INTERACTIONS = (
    os.environ.get("LOG_ONYX_MODEL_INTERACTIONS", "").lower() == "true"
)

PROMPT_CACHE_CHAT_HISTORY = (
    os.environ.get("PROMPT_CACHE_CHAT_HISTORY", "").lower() == "true"
)
# If set to `true` will enable additional logs about Vespa query performance
# (time spent on finding the right docs + time spent fetching summaries from disk)
LOG_VESPA_TIMING_INFORMATION = (
    os.environ.get("LOG_VESPA_TIMING_INFORMATION", "").lower() == "true"
)
LOG_ENDPOINT_LATENCY = os.environ.get("LOG_ENDPOINT_LATENCY", "").lower() == "true"
LOG_POSTGRES_LATENCY = os.environ.get("LOG_POSTGRES_LATENCY", "").lower() == "true"
LOG_POSTGRES_CONN_COUNTS = (
    os.environ.get("LOG_POSTGRES_CONN_COUNTS", "").lower() == "true"
)
# Anonymous usage telemetry
DISABLE_TELEMETRY = os.environ.get("DISABLE_TELEMETRY", "").lower() == "true"

#####
# Braintrust Configuration
#####
# Braintrust project name
BRAINTRUST_PROJECT = os.environ.get("BRAINTRUST_PROJECT", "Onyx")
# Braintrust API key - if provided, Braintrust tracing will be enabled
BRAINTRUST_API_KEY = os.environ.get("BRAINTRUST_API_KEY") or ""
# Maximum concurrency for Braintrust evaluations
# None means unlimited concurrency, otherwise specify a number
_braintrust_concurrency = os.environ.get("BRAINTRUST_MAX_CONCURRENCY")
BRAINTRUST_MAX_CONCURRENCY = (
    int(_braintrust_concurrency) if _braintrust_concurrency else None
)

#####
# Scheduled Evals Configuration
#####
# Comma-separated list of Braintrust dataset names to run on schedule
SCHEDULED_EVAL_DATASET_NAMES = [
    name.strip()
    for name in os.environ.get("SCHEDULED_EVAL_DATASET_NAMES", "").split(",")
    if name.strip()
]
# Email address to use for search permissions during scheduled evals
SCHEDULED_EVAL_PERMISSIONS_EMAIL = os.environ.get(
    "SCHEDULED_EVAL_PERMISSIONS_EMAIL", "roshan@onyx.app"
)
# Braintrust project name to use for scheduled evals
SCHEDULED_EVAL_PROJECT = os.environ.get("SCHEDULED_EVAL_PROJECT", "st-dev")

#####
# Langfuse Configuration
#####
# Langfuse API credentials - if provided, Langfuse tracing will be enabled
LANGFUSE_SECRET_KEY = os.environ.get("LANGFUSE_SECRET_KEY") or ""
LANGFUSE_PUBLIC_KEY = os.environ.get("LANGFUSE_PUBLIC_KEY") or ""
LANGFUSE_HOST = os.environ.get("LANGFUSE_HOST") or ""  # For self-hosted Langfuse

# Define custom query/answer conditions to validate the query and the LLM answer.
# Format: list of strings
CUSTOM_ANSWER_VALIDITY_CONDITIONS = json.loads(
    os.environ.get("CUSTOM_ANSWER_VALIDITY_CONDITIONS", "[]")
)

VESPA_REQUEST_TIMEOUT = int(os.environ.get("VESPA_REQUEST_TIMEOUT") or "15")
# This is the timeout for the client side of the Vespa migration task. When
# exceeded, an exception is raised in our code. This value should be higher than
# VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT.
VESPA_MIGRATION_REQUEST_TIMEOUT_S = int(
    os.environ.get("VESPA_MIGRATION_REQUEST_TIMEOUT_S") or "120"
)
# This is the timeout Vespa uses on the server side to know when to wrap up its
# traversal and try to report partial results. This differs from the client
# timeout above, which raises an exception in our code when exceeded. This
# timeout allows Vespa to return gracefully. This value should be lower than
# VESPA_MIGRATION_REQUEST_TIMEOUT_S. Formatted as a number of seconds with an
# "s" suffix (e.g. "110s").
VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT = os.environ.get(
    "VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT", "110s"
)

SYSTEM_RECURSION_LIMIT = int(os.environ.get("SYSTEM_RECURSION_LIMIT") or "1000")

PARSE_WITH_TRAFILATURA = os.environ.get("PARSE_WITH_TRAFILATURA", "").lower() == "true"

# allow for custom error messages for different errors returned by litellm
# for example, can specify: {"Violated content safety policy": "EVIL REQUEST!!!"}
# to make it so that if an LLM call returns an error containing "Violated content safety policy"
# the end user will see "EVIL REQUEST!!!" instead of the default error message.
_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = os.environ.get(
    "LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS", ""
)
LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS: dict[str, str] | None = None
try:
    LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS = cast(
        dict[str, str], json.loads(_LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS)
    )
except json.JSONDecodeError:
    pass

# Auto LLM Configuration - fetches model configs from GitHub for providers in Auto mode
AUTO_LLM_CONFIG_URL = os.environ.get(
    "AUTO_LLM_CONFIG_URL",
    "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/onyx/llm/well_known_providers/recommended-models.json",
)

# How often to check for auto LLM model updates (in seconds)
AUTO_LLM_UPDATE_INTERVAL_SECONDS = int(
    os.environ.get("AUTO_LLM_UPDATE_INTERVAL_SECONDS", 1800)  # 30 minutes
)

#####
# Enterprise Edition Configs
#####
# NOTE: this should only be enabled if you have purchased an enterprise license.
# if you're interested in an enterprise license, please reach out to us at
# founders@onyx.app OR message Chris Weaver or Yuhong Sun in the Onyx
# Discord community https://discord.gg/4NA5SbzrWb
ENTERPRISE_EDITION_ENABLED = (
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() == "true"
)

#####
# Image Generation Configuration (DEPRECATED)
# These environment variables will be deprecated soon.
# To configure image generation, please visit the Image Generation page in the Admin Panel.
#####
# Azure Image Configurations
AZURE_IMAGE_API_VERSION = os.environ.get("AZURE_IMAGE_API_VERSION") or os.environ.get(
    "AZURE_DALLE_API_VERSION"
)
AZURE_IMAGE_API_KEY = os.environ.get("AZURE_IMAGE_API_KEY") or os.environ.get(
    "AZURE_DALLE_API_KEY"
)
AZURE_IMAGE_API_BASE = os.environ.get("AZURE_IMAGE_API_BASE") or os.environ.get(
    "AZURE_DALLE_API_BASE"
)
AZURE_IMAGE_DEPLOYMENT_NAME = os.environ.get(
    "AZURE_IMAGE_DEPLOYMENT_NAME"
) or os.environ.get("AZURE_DALLE_DEPLOYMENT_NAME")

# configurable image model
IMAGE_MODEL_NAME = os.environ.get("IMAGE_MODEL_NAME", "gpt-image-1")
IMAGE_MODEL_PROVIDER = os.environ.get("IMAGE_MODEL_PROVIDER", "openai")

# Use managed Vespa (Vespa Cloud). If set, must also set VESPA_CLOUD_URL, VESPA_CLOUD_CERT_PATH and VESPA_CLOUD_KEY_PATH
MANAGED_VESPA = os.environ.get("MANAGED_VESPA", "").lower() == "true"

ENABLE_EMAIL_INVITES = os.environ.get("ENABLE_EMAIL_INVITES", "").lower() == "true"

# Limit on number of users a free trial tenant can invite (cloud only)
NUM_FREE_TRIAL_USER_INVITES = int(os.environ.get("NUM_FREE_TRIAL_USER_INVITES", "10"))

# Security and authentication
DATA_PLANE_SECRET = os.environ.get(
    "DATA_PLANE_SECRET", ""
)  # Used for secure communication between the control and data plane
EXPECTED_API_KEY = os.environ.get(
    "EXPECTED_API_KEY", ""
)  # Additional security check for the control plane API

# API configuration
CONTROL_PLANE_API_BASE_URL = os.environ.get(
    "CONTROL_PLANE_API_BASE_URL", "http://localhost:8082"
)

OAUTH_SLACK_CLIENT_ID = os.environ.get("OAUTH_SLACK_CLIENT_ID", "")
OAUTH_SLACK_CLIENT_SECRET = os.environ.get("OAUTH_SLACK_CLIENT_SECRET", "")
OAUTH_CONFLUENCE_CLOUD_CLIENT_ID = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_ID", ""
)
OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET = os.environ.get(
    "OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET", ""
)
OAUTH_GOOGLE_DRIVE_CLIENT_ID = os.environ.get("OAUTH_GOOGLE_DRIVE_CLIENT_ID", "")
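# Illustrative sketch (hypothetical helper, not part of Onyx): the Azure image
# settings above read the new AZURE_IMAGE_* variable first and fall back to the
# legacy AZURE_DALLE_* name. Because the fallback uses `or`, a new variable that
# is set but empty ("") behaves like an unset one, so the legacy value wins. A
# plain mapping stands in for os.environ here so the behavior can be checked
# without touching the process environment.
from collections.abc import Mapping


def _example_get_with_legacy_fallback(
    env: Mapping[str, str], name: str, legacy_name: str
) -> str | None:
    # Mirrors the pattern: os.environ.get(name) or os.environ.get(legacy_name)
    return env.get(name) or env.get(legacy_name)

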
OAUTH_GOOGLE_DRIVE_CLIENT_SECRET = os.environ.get(
    "OAUTH_GOOGLE_DRIVE_CLIENT_SECRET", ""
)

# JWT configuration
JWT_ALGORITHM = "HS256"

#####
# API Key Configs
#####
# refers to the rounds described here: https://passlib.readthedocs.io/en/stable/lib/passlib.hash.sha256_crypt.html
_API_KEY_HASH_ROUNDS_RAW = os.environ.get("API_KEY_HASH_ROUNDS")
API_KEY_HASH_ROUNDS = (
    int(_API_KEY_HASH_ROUNDS_RAW) if _API_KEY_HASH_ROUNDS_RAW else None
)

#####
# MCP Server Configs
#####
MCP_SERVER_ENABLED = os.environ.get("MCP_SERVER_ENABLED", "").lower() == "true"
MCP_SERVER_HOST = os.environ.get("MCP_SERVER_HOST", "0.0.0.0")
MCP_SERVER_PORT = int(os.environ.get("MCP_SERVER_PORT") or 8090)

# CORS origins for MCP clients (comma-separated)
# Local dev: "http://localhost:*"
# Production: "https://trusted-client.com,https://another-client.com"
MCP_SERVER_CORS_ORIGINS = [
    origin.strip()
    for origin in os.environ.get("MCP_SERVER_CORS_ORIGINS", "").split(",")
    if origin.strip()
]


POD_NAME = os.environ.get("POD_NAME")
POD_NAMESPACE = os.environ.get("POD_NAMESPACE")


DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

INTEGRATION_TESTS_MODE = os.environ.get("INTEGRATION_TESTS_MODE", "").lower() == "true"

#####
# Captcha Configuration (for cloud signup protection)
#####
# Enable captcha verification for new user registration
CAPTCHA_ENABLED = os.environ.get("CAPTCHA_ENABLED", "").lower() == "true"

# Google reCAPTCHA secret key (server-side validation)
RECAPTCHA_SECRET_KEY = os.environ.get("RECAPTCHA_SECRET_KEY", "")

# Minimum score threshold for reCAPTCHA v3 (0.0-1.0, higher = more likely human)
# 0.5 is the recommended default
RECAPTCHA_SCORE_THRESHOLD = float(os.environ.get("RECAPTCHA_SCORE_THRESHOLD", "0.5"))

MOCK_CONNECTOR_FILE_PATH = os.environ.get("MOCK_CONNECTOR_FILE_PATH")

# Set to true to mock LLM responses for testing purposes
MOCK_LLM_RESPONSE = (
    os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
)


DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20

# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))


# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
    "IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
    DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
)

# The user prompt for image summarization - the image filename will be automatically prepended
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
    "IMAGE_SUMMARIZATION_USER_PROMPT",
    DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
)

# Knowledge Graph Read Only User Configuration
DB_READONLY_USER: str = os.environ.get("DB_READONLY_USER", "db_readonly_user")
DB_READONLY_PASSWORD: str = urllib.parse.quote_plus(
    os.environ.get("DB_READONLY_PASSWORD") or "password"
)

# File Store Configuration
# Which backend to use for file storage: "s3" (S3/MinIO) or "postgres" (PostgreSQL Large Objects)
FILE_STORE_BACKEND = os.environ.get("FILE_STORE_BACKEND", "s3")

S3_FILE_STORE_BUCKET_NAME = (
    os.environ.get("S3_FILE_STORE_BUCKET_NAME") or "onyx-file-store-bucket"
)
S3_FILE_STORE_PREFIX = os.environ.get("S3_FILE_STORE_PREFIX") or "onyx-files"
# S3_ENDPOINT_URL is for MinIO and other S3-compatible storage. Leave blank for AWS S3.
S3_ENDPOINT_URL = os.environ.get("S3_ENDPOINT_URL")
S3_VERIFY_SSL = os.environ.get("S3_VERIFY_SSL", "").lower() == "true"

# S3/MinIO Access Keys
S3_AWS_ACCESS_KEY_ID = os.environ.get("S3_AWS_ACCESS_KEY_ID")
S3_AWS_SECRET_ACCESS_KEY = os.environ.get("S3_AWS_SECRET_ACCESS_KEY")

# Should we force S3 local checksumming
S3_GENERATE_LOCAL_CHECKSUM = (
    os.environ.get("S3_GENERATE_LOCAL_CHECKSUM", "").lower() == "true"
)

# Forcing Vespa Language
# English: en, German: de, etc. See: https://docs.vespa.ai/en/linguistics.html
VESPA_LANGUAGE_OVERRIDE = os.environ.get("VESPA_LANGUAGE_OVERRIDE")


#####
# Default LLM API Keys (for cloud deployments)
# These are Onyx-managed API keys provided to tenants by default
#####
OPENAI_DEFAULT_API_KEY = os.environ.get("OPENAI_DEFAULT_API_KEY")
ANTHROPIC_DEFAULT_API_KEY = os.environ.get("ANTHROPIC_DEFAULT_API_KEY")
COHERE_DEFAULT_API_KEY = os.environ.get("COHERE_DEFAULT_API_KEY")
VERTEXAI_DEFAULT_CREDENTIALS = os.environ.get("VERTEXAI_DEFAULT_CREDENTIALS")
VERTEXAI_DEFAULT_LOCATION = os.environ.get("VERTEXAI_DEFAULT_LOCATION", "global")
OPENROUTER_DEFAULT_API_KEY = os.environ.get("OPENROUTER_DEFAULT_API_KEY")

INSTANCE_TYPE = (
    "managed"
    if os.environ.get("IS_MANAGED_INSTANCE", "").lower() == "true"
    else "cloud" if AUTH_TYPE == AuthType.CLOUD else "self_hosted"
)


## Discord Bot Configuration
DISCORD_BOT_TOKEN = os.environ.get("DISCORD_BOT_TOKEN")
DISCORD_BOT_INVOKE_CHAR = os.environ.get("DISCORD_BOT_INVOKE_CHAR", "!")


## Stripe Configuration
# URL to fetch the Stripe publishable key from a public S3 bucket.
# Publishable keys are safe to expose publicly - they can only initialize
# Stripe.js and tokenize payment info, not make charges or access data.
STRIPE_PUBLISHABLE_KEY_URL = ( "https://onyx-stripe-public.s3.amazonaws.com/publishable-key.txt" ) # Override for local testing with Stripe test keys (pk_test_*) STRIPE_PUBLISHABLE_KEY_OVERRIDE = os.environ.get("STRIPE_PUBLISHABLE_KEY") ================================================ FILE: backend/onyx/configs/chat_configs.py ================================================ import os PROMPTS_YAML = "./onyx/seeding/prompts.yaml" PERSONAS_YAML = "./onyx/seeding/personas.yaml" NUM_RETURNED_HITS = 50 # May be less depending on model MAX_CHUNKS_FED_TO_CHAT = int(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 25) # 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay # Capped in Vespa at 0.5 DOC_TIME_DECAY = float( os.environ.get("DOC_TIME_DECAY") or 0.5 # Hits limit at 2 years by default ) BASE_RECENCY_DECAY = 0.5 FAVOR_RECENT_DECAY_MULTIPLIER = 2.0 # For the highest matching base size chunk, how many chunks above and below do we pull in by default # Note this is not in any of the deployment configs yet # Currently only applies to search flow not chat CONTEXT_CHUNKS_ABOVE = int(os.environ.get("CONTEXT_CHUNKS_ABOVE") or 1) CONTEXT_CHUNKS_BELOW = int(os.environ.get("CONTEXT_CHUNKS_BELOW") or 1) # Fairly long but this is to account for edge cases where the LLM pauses for much longer than usual # The alternative is to fail the request completely so this is intended to be fairly lenient. LLM_SOCKET_READ_TIMEOUT = int( os.environ.get("LLM_SOCKET_READ_TIMEOUT") or "60" ) # 60 seconds # Weighting factor between vector and keyword Search; 1 for completely vector # search, 0 for keyword. Enforces a valid range of [0, 1]. A supplied value from # the env outside of this range will be clipped to the respective end of the # range. Defaults to 0.5. HYBRID_ALPHA = max(0, min(1, float(os.environ.get("HYBRID_ALPHA") or 0.5))) # Weighting factor between Title and Content of documents during search, 1 for completely # Title based. 
Default heavily favors Content because Title is also included at the top of # Content. This is to avoid cases where the Content is very relevant but it may not be clear # if the title is separated out. Title is more of a "boost" than a separate field. TITLE_CONTENT_RATIO = max( 0, min(1, float(os.environ.get("TITLE_CONTENT_RATIO") or 0.10)) ) # Stops streaming answers back to the UI if this pattern is seen: STOP_STREAM_PAT = os.environ.get("STOP_STREAM_PAT") or None # Set this to "true" to hard delete chats # This will make chats unviewable by admins after a user deletes them # As opposed to soft deleting them, which just hides them from non-admin users HARD_DELETE_CHATS = os.environ.get("HARD_DELETE_CHATS", "").lower() == "true" # Internet Search NUM_INTERNET_SEARCH_RESULTS = int(os.environ.get("NUM_INTERNET_SEARCH_RESULTS") or 10) NUM_INTERNET_SEARCH_CHUNKS = int(os.environ.get("NUM_INTERNET_SEARCH_CHUNKS") or 50) VESPA_SEARCHER_THREADS = int(os.environ.get("VESPA_SEARCHER_THREADS") or 2) # Whether or not to use the semantic & keyword search expansions for Basic Search USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH = ( os.environ.get("USE_SEMANTIC_KEYWORD_EXPANSIONS_BASIC_SEARCH", "false").lower() == "true" ) # Chat History Compression # Trigger compression when history exceeds this ratio of available context window COMPRESSION_TRIGGER_RATIO = float(os.environ.get("COMPRESSION_TRIGGER_RATIO", "0.75")) SKIP_DEEP_RESEARCH_CLARIFICATION = ( os.environ.get("SKIP_DEEP_RESEARCH_CLARIFICATION", "false").lower() == "true" ) ================================================ FILE: backend/onyx/configs/constants.py ================================================ import platform import re import socket from enum import auto from enum import Enum ONYX_DEFAULT_APPLICATION_NAME = "Onyx" ONYX_DISCORD_URL = "https://discord.gg/4NA5SbzrWb" ONYX_UTM_SOURCE = "onyx_app" SLACK_USER_TOKEN_PREFIX = "xoxp-" SLACK_BOT_TOKEN_PREFIX = "xoxb-" ONYX_EMAILABLE_LOGO_MAX_DIM = 512 SOURCE_TYPE = 
"source_type" # stored in the `metadata` of a chunk. Used to signify that this chunk should # not be used for QA. For example, Google Drive file types which can't be parsed # are still useful as a search result but not for QA. IGNORE_FOR_QA = "ignore_for_qa" # NOTE: deprecated, only used for porting key from old system GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key" PUBLIC_DOC_PAT = "PUBLIC" ID_SEPARATOR = ":;:" DEFAULT_BOOST = 0 # Tag for endpoints that should be included in the public API documentation PUBLIC_API_TAGS: list[str | Enum] = ["public"] # Cookies FASTAPI_USERS_AUTH_COOKIE_NAME = ( "fastapiusersauth" # Currently a constant, but logic allows for configuration ) TENANT_ID_COOKIE_NAME = "onyx_tid" # tenant id - for workaround cases ANONYMOUS_USER_COOKIE_NAME = "onyx_anonymous_user" # ID used in UserInfo API responses for anonymous users (not a UUID, just a string identifier) ANONYMOUS_USER_INFO_ID = "__anonymous_user__" # Placeholder user for migrating no-auth data to first registered user NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001" NO_AUTH_PLACEHOLDER_USER_EMAIL = "no-auth-placeholder@onyx.app" # Real anonymous user in DB for anonymous access feature ANONYMOUS_USER_UUID = "00000000-0000-0000-0000-000000000002" ANONYMOUS_USER_EMAIL = "anonymous@onyx.app" # For chunking/processing chunks RETURN_SEPARATOR = "\n\r\n" SECTION_SEPARATOR = "\n\n" # For combining attributes, doesn't have to be unique/perfect to work INDEX_SEPARATOR = "===" # For File Connector Metadata override file ONYX_METADATA_FILENAME = ".onyx_metadata.json" # Messages DISABLED_GEN_AI_MSG = ( "Your System Admin has disabled the Generative AI functionalities of Onyx.\n" "Please contact them if you wish to have this enabled.\n" "You can still use Onyx as a search engine." 
) ##### # Version Pattern Configs ##### # Version patterns for Docker image tags STABLE_VERSION_PATTERN = re.compile(r"^v(\d+)\.(\d+)\.(\d+)$") DEV_VERSION_PATTERN = re.compile(r"^v(\d+)\.(\d+)\.(\d+)-beta\.(\d+)$") DEFAULT_PERSONA_ID = 0 DEFAULT_CC_PAIR_ID = 1 CANCEL_CHECK_INTERVAL = 20 DISPATCH_SEP_CHAR = "\n" FORMAT_DOCS_SEPARATOR = "\n\n" NUM_EXPLORATORY_DOCS = 15 # Postgres connection constants for application_name POSTGRES_WEB_APP_NAME = "web" POSTGRES_INDEXER_APP_NAME = "indexer" POSTGRES_CELERY_APP_NAME = "celery" POSTGRES_CELERY_BEAT_APP_NAME = "celery_beat" POSTGRES_CELERY_WORKER_PRIMARY_APP_NAME = "celery_worker_primary" POSTGRES_CELERY_WORKER_LIGHT_APP_NAME = "celery_worker_light" POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME = "celery_worker_docprocessing" POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME = "celery_worker_docfetching" POSTGRES_CELERY_WORKER_INDEXING_CHILD_APP_NAME = "celery_worker_indexing_child" POSTGRES_CELERY_WORKER_HEAVY_APP_NAME = "celery_worker_heavy" POSTGRES_CELERY_WORKER_MONITORING_APP_NAME = "celery_worker_monitoring" POSTGRES_CELERY_WORKER_USER_FILE_PROCESSING_APP_NAME = ( "celery_worker_user_file_processing" ) POSTGRES_PERMISSIONS_APP_NAME = "permissions" POSTGRES_UNKNOWN_APP_NAME = "unknown" SSL_CERT_FILE = "bundle.pem" # API Keys DANSWER_API_KEY_PREFIX = "API_KEY__" DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN = "onyxapikey.ai" UNNAMED_KEY_PLACEHOLDER = "Unnamed" DISCORD_SERVICE_API_KEY_NAME = "discord-bot-service" # Key-Value store keys KV_REINDEX_KEY = "needs_reindexing" KV_UNSTRUCTURED_API_KEY = "unstructured_api_key" KV_USER_STORE_KEY = "INVITED_USERS" KV_PENDING_USERS_KEY = "PENDING_USERS" KV_ANONYMOUS_USER_PREFERENCES_KEY = "anonymous_user_preferences" KV_ANONYMOUS_USER_PERSONALIZATION_KEY = "anonymous_user_personalization" KV_CRED_KEY = "credential_id_{}" KV_GMAIL_CRED_KEY = "gmail_app_credential" KV_GMAIL_SERVICE_ACCOUNT_KEY = "gmail_service_account_key" KV_GOOGLE_DRIVE_CRED_KEY = "google_drive_app_credential" 
KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY = "google_drive_service_account_key" KV_GEN_AI_KEY_CHECK_TIME = "genai_api_key_last_check_time" KV_SETTINGS_KEY = "onyx_settings" KV_CUSTOMER_UUID_KEY = "customer_uuid" KV_INSTANCE_DOMAIN_KEY = "instance_domain" KV_ENTERPRISE_SETTINGS_KEY = "onyx_enterprise_settings" KV_CUSTOM_ANALYTICS_SCRIPT_KEY = "__custom_analytics_script__" KV_KG_CONFIG_KEY = "kg_config" # NOTE: we use this timeout / 4 in various places to refresh a lock # might be worth separating this timeout into separate timeouts for each situation CELERY_GENERIC_BEAT_LOCK_TIMEOUT = 120 CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT = 120 CELERY_PRIMARY_WORKER_LOCK_TIMEOUT = 120 # hard timeout applied by the watchdog to the indexing connector run # to handle hung connectors CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT = 3 * 60 * 60 # 3 hours (in seconds) # soft timeout for the lock taken by the indexing connector run # allows the lock to eventually expire if the managing code around it dies # if we could get callbacks as object bytes are downloaded, we could lower this a lot. # CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 15 minutes # hard termination should always fire first if the connector is hung CELERY_INDEXING_LOCK_TIMEOUT = CELERY_INDEXING_WATCHDOG_CONNECTOR_TIMEOUT + 900 # Heartbeat interval for indexing worker liveness detection INDEXING_WORKER_HEARTBEAT_INTERVAL = 30 # seconds # how long a task should wait for its associated fence to be ready CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT = 5 * 60 # 5 min # needs to be long enough to cover the maximum time it takes to download an object # if we could get callbacks as object bytes are downloaded, we could lower this a lot. CELERY_PRUNING_LOCK_TIMEOUT = 3600 # 1 hour (in seconds) CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT = 3600 # 1 hour (in seconds) CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds) # How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent # indefinite queue growth. Workers drop tasks older than this without touching # the DB, so a shorter value = faster drain of stale duplicates. CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds) # Maximum number of tasks allowed in the user-file-processing queue before the # beat generator stops adding more. Prevents unbounded queue growth when workers # fall behind. USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500 # How long a queued user-file-project-sync task remains valid. # Should be short enough to discard stale queue entries under load while still # allowing workers enough time to pick up new tasks. CELERY_USER_FILE_PROJECT_SYNC_TASK_EXPIRES = 60 # 1 minute (in seconds) # Max queue depth before user-file-project-sync producers stop enqueuing. # This applies backpressure when workers are falling behind. USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH = 500 CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds) # How long a queued user-file-delete task is valid before workers discard it. # Mirrors the processing task expiry to prevent indefinite queue growth when # files are stuck in DELETING status and the beat keeps re-enqueuing them. CELERY_USER_FILE_DELETE_TASK_EXPIRES = 60 # 1 minute (in seconds) # Max queue depth before the delete beat stops enqueuing more delete tasks. 
USER_FILE_DELETE_MAX_QUEUE_DEPTH = 500 CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds) DANSWER_REDIS_FUNCTION_LOCK_PREFIX = "da_function_lock:" TMP_DRALPHA_PERSONA_NAME = "KG Beta" class DocumentSource(str, Enum): # Special case, document passed in via Onyx APIs without specifying a source type INGESTION_API = "ingestion_api" SLACK = "slack" WEB = "web" GOOGLE_DRIVE = "google_drive" GMAIL = "gmail" REQUESTTRACKER = "requesttracker" GITHUB = "github" GITBOOK = "gitbook" GITLAB = "gitlab" GURU = "guru" BOOKSTACK = "bookstack" OUTLINE = "outline" CONFLUENCE = "confluence" JIRA = "jira" SLAB = "slab" PRODUCTBOARD = "productboard" FILE = "file" CODA = "coda" CANVAS = "canvas" NOTION = "notion" ZULIP = "zulip" LINEAR = "linear" HUBSPOT = "hubspot" DOCUMENT360 = "document360" GONG = "gong" GOOGLE_SITES = "google_sites" ZENDESK = "zendesk" LOOPIO = "loopio" DROPBOX = "dropbox" SHAREPOINT = "sharepoint" TEAMS = "teams" SALESFORCE = "salesforce" DISCOURSE = "discourse" AXERO = "axero" CLICKUP = "clickup" MEDIAWIKI = "mediawiki" WIKIPEDIA = "wikipedia" ASANA = "asana" S3 = "s3" R2 = "r2" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" OCI_STORAGE = "oci_storage" XENFORO = "xenforo" NOT_APPLICABLE = "not_applicable" DISCORD = "discord" FRESHDESK = "freshdesk" FIREFLIES = "fireflies" EGNYTE = "egnyte" AIRTABLE = "airtable" HIGHSPOT = "highspot" DRUPAL_WIKI = "drupal_wiki" IMAP = "imap" BITBUCKET = "bitbucket" TESTRAIL = "testrail" # Special case just for integration tests MOCK_CONNECTOR = "mock_connector" # Special case for user files USER_FILE = "user_file" # Raw files for Craft sandbox access (xlsx, pptx, docx, etc.) 
# Uses RAW_BINARY processing mode - no text extraction CRAFT_FILE = "craft_file" class FederatedConnectorSource(str, Enum): FEDERATED_SLACK = "federated_slack" def to_non_federated_source(self) -> DocumentSource | None: if self == FederatedConnectorSource.FEDERATED_SLACK: return DocumentSource.SLACK return None DocumentSourceRequiringTenantContext: list[DocumentSource] = [DocumentSource.FILE] class NotificationType(str, Enum): REINDEX = "reindex" PERSONA_SHARED = "persona_shared" TRIAL_ENDS_TWO_DAYS = "two_day_trial_ending" # 2 days left in trial RELEASE_NOTES = "release_notes" ASSISTANT_FILES_READY = "assistant_files_ready" FEATURE_ANNOUNCEMENT = "feature_announcement" class BlobType(str, Enum): R2 = "r2" S3 = "s3" GOOGLE_CLOUD_STORAGE = "google_cloud_storage" OCI_STORAGE = "oci_storage" class DocumentIndexType(str, Enum): COMBINED = "combined" # Vespa SPLIT = "split" # Typesense + Qdrant class AuthType(str, Enum): BASIC = "basic" GOOGLE_OAUTH = "google_oauth" OIDC = "oidc" SAML = "saml" # google auth and basic CLOUD = "cloud" class QueryHistoryType(str, Enum): DISABLED = "disabled" ANONYMIZED = "anonymized" NORMAL = "normal" # Special characters for password validation PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?" 
class SessionType(str, Enum): CHAT = "Chat" SEARCH = "Search" SLACK = "Slack" class QAFeedbackType(str, Enum): LIKE = "like" # User likes the answer, used for metrics DISLIKE = "dislike" # User dislikes the answer, used for metrics MIXED = "mixed" # User likes some answers and dislikes others, used for chat session metrics class SearchFeedbackType(str, Enum): ENDORSE = "endorse" # boost this document for all future queries REJECT = "reject" # down-boost this document for all future queries HIDE = "hide" # mark this document as untrusted, hide from LLM UNHIDE = "unhide" class MessageType(str, Enum): # Using OpenAI standards, Langchain equivalent shown in comment # System message is always constructed on the fly, not saved SYSTEM = "system" # SystemMessage USER = "user" # HumanMessage ASSISTANT = "assistant" # AIMessage - Can include tool_calls field for parallel tool calling TOOL_CALL_RESPONSE = "tool_call_response" USER_REMINDER = "user_reminder" # Custom Onyx message type which is translated into a USER message when passed to the LLM class ChatMessageSimpleType(str, Enum): USER = "user" ASSISTANT = "assistant" TOOL_CALL = "tool_call" FILE_TEXT = "file_text" class TokenRateLimitScope(str, Enum): USER = "user" USER_GROUP = "user_group" GLOBAL = "global" class FileStoreType(str, Enum): S3 = "s3" POSTGRES = "postgres" class FileOrigin(str, Enum): CHAT_UPLOAD = "chat_upload" CHAT_IMAGE_GEN = "chat_image_gen" CONNECTOR = "connector" CONNECTOR_METADATA = "connector_metadata" GENERATED_REPORT = "generated_report" INDEXING_CHECKPOINT = "indexing_checkpoint" PLAINTEXT_CACHE = "plaintext_cache" OTHER = "other" QUERY_HISTORY_CSV = "query_history_csv" SANDBOX_SNAPSHOT = "sandbox_snapshot" USER_FILE = "user_file" class FileType(str, Enum): CSV = "text/csv" class MilestoneRecordType(str, Enum): TENANT_CREATED = "tenant_created" USER_SIGNED_UP = "user_signed_up" VISITED_ADMIN_PAGE = "visited_admin_page" CREATED_CONNECTOR = "created_connector" CONNECTOR_SUCCEEDED = 
"connector_succeeded" RAN_QUERY = "ran_query" USER_MESSAGE_SENT = "user_message_sent" MULTIPLE_ASSISTANTS = "multiple_assistants" CREATED_ASSISTANT = "created_assistant" CREATED_ONYX_BOT = "created_onyx_bot" REQUESTED_CONNECTOR = "requested_connector" class PostgresAdvisoryLocks(Enum): KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto() class OnyxCeleryQueues: # "celery" is the default queue defined by celery and also the queue # we are running in the primary worker to run system tasks # Tasks running in this queue should be designed specifically to run quickly PRIMARY = "celery" # Light queue VESPA_METADATA_SYNC = "vespa_metadata_sync" DOC_PERMISSIONS_UPSERT = "doc_permissions_upsert" CONNECTOR_DELETION = "connector_deletion" LLM_MODEL_UPDATE = "llm_model_update" CHECKPOINT_CLEANUP = "checkpoint_cleanup" INDEX_ATTEMPT_CLEANUP = "index_attempt_cleanup" # Heavy queue CONNECTOR_PRUNING = "connector_pruning" CONNECTOR_DOC_PERMISSIONS_SYNC = "connector_doc_permissions_sync" CONNECTOR_EXTERNAL_GROUP_SYNC = "connector_external_group_sync" CONNECTOR_HIERARCHY_FETCHING = "connector_hierarchy_fetching" CSV_GENERATION = "csv_generation" # User file processing queue USER_FILE_PROCESSING = "user_file_processing" USER_FILE_PROJECT_SYNC = "user_file_project_sync" USER_FILE_DELETE = "user_file_delete" # Document processing pipeline queue DOCPROCESSING = "docprocessing" CONNECTOR_DOC_FETCHING = "connector_doc_fetching" # Monitoring queue MONITORING = "monitoring" # Sandbox processing queue SANDBOX = "sandbox" OPENSEARCH_MIGRATION = "opensearch_migration" class OnyxRedisLocks: PRIMARY_WORKER = "da_lock:primary_worker" CHECK_VESPA_SYNC_BEAT_LOCK = "da_lock:check_vespa_sync_beat" CHECK_CONNECTOR_DELETION_BEAT_LOCK = "da_lock:check_connector_deletion_beat" CHECK_PRUNE_BEAT_LOCK = "da_lock:check_prune_beat" CHECK_HIERARCHY_FETCHING_BEAT_LOCK = "da_lock:check_hierarchy_fetching_beat" CHECK_INDEXING_BEAT_LOCK = "da_lock:check_indexing_beat" CHECK_CHECKPOINT_CLEANUP_BEAT_LOCK = 
"da_lock:check_checkpoint_cleanup_beat" CHECK_INDEX_ATTEMPT_CLEANUP_BEAT_LOCK = "da_lock:check_index_attempt_cleanup_beat" CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK = ( "da_lock:check_connector_doc_permissions_sync_beat" ) CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK = ( "da_lock:check_connector_external_group_sync_beat" ) OPENSEARCH_MIGRATION_BEAT_LOCK = "da_lock:opensearch_migration_beat" MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes" CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants" CLOUD_PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant" CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = ( "da_lock:connector_doc_permissions_sync" ) CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX = "da_lock:connector_external_group_sync" PRUNING_LOCK_PREFIX = "da_lock:pruning" INDEXING_METADATA_PREFIX = "da_metadata:indexing" SLACK_BOT_LOCK = "da_lock:slack_bot" SLACK_BOT_HEARTBEAT_PREFIX = "da_heartbeat:slack_bot" ANONYMOUS_USER_ENABLED = "anonymous_user_enabled" CLOUD_BEAT_TASK_GENERATOR_LOCK = "da_lock:cloud_beat_task_generator" CLOUD_CHECK_ALEMBIC_BEAT_LOCK = "da_lock:cloud_check_alembic" # User file processing USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat" USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing" # Short-lived key set when a task is enqueued; cleared when the worker picks it up. # Prevents the beat from re-enqueuing the same file while a task is already queued. USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued" USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat" USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync" USER_FILE_PROJECT_SYNC_QUEUED_PREFIX = "da_lock:user_file_project_sync_queued" USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat" USER_FILE_DELETE_LOCK_PREFIX = "da_lock:user_file_delete" # Short-lived key set when a delete task is enqueued; cleared when the worker picks it up. 
# Prevents the beat from re-enqueuing the same file while a delete task is already queued. USER_FILE_DELETE_QUEUED_PREFIX = "da_lock:user_file_delete_queued" # Release notes RELEASE_NOTES_FETCH_LOCK = "da_lock:release_notes_fetch" # Sandbox cleanup CLEANUP_IDLE_SANDBOXES_BEAT_LOCK = "da_lock:cleanup_idle_sandboxes_beat" CLEANUP_OLD_SNAPSHOTS_BEAT_LOCK = "da_lock:cleanup_old_snapshots_beat" # Sandbox file sync SANDBOX_FILE_SYNC_LOCK_PREFIX = "da_lock:sandbox_file_sync" class OnyxRedisSignals: BLOCK_VALIDATE_INDEXING_FENCES = "signal:block_validate_indexing_fences" BLOCK_VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = ( "signal:block_validate_external_group_sync_fences" ) BLOCK_VALIDATE_PERMISSION_SYNC_FENCES = ( "signal:block_validate_permission_sync_fences" ) BLOCK_PRUNING = "signal:block_pruning" BLOCK_VALIDATE_PRUNING_FENCES = "signal:block_validate_pruning_fences" BLOCK_BUILD_FENCE_LOOKUP_TABLE = "signal:block_build_fence_lookup_table" BLOCK_VALIDATE_CONNECTOR_DELETION_FENCES = ( "signal:block_validate_connector_deletion_fences" ) class OnyxRedisConstants: ACTIVE_FENCES = "active_fences" class OnyxCeleryPriority(int, Enum): HIGHEST = 0 HIGH = auto() MEDIUM = auto() LOW = auto() LOWEST = auto() # a prefix used to distinguish system wide tasks in the cloud ONYX_CLOUD_CELERY_TASK_PREFIX = "cloud" # the tenant id we use for system level redis operations ONYX_CLOUD_TENANT_ID = "cloud" # the redis namespace for runtime variables ONYX_CLOUD_REDIS_RUNTIME = "runtime" CLOUD_BUILD_FENCE_LOOKUP_TABLE_INTERVAL_DEFAULT = 600 class OnyxCeleryTask: DEFAULT = "celery" CLOUD_BEAT_TASK_GENERATOR = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_generate_beat_tasks" CLOUD_MONITOR_ALEMBIC = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_alembic" CLOUD_MONITOR_CELERY_QUEUES = ( f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues" ) CLOUD_CHECK_AVAILABLE_TENANTS = ( f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants" ) CLOUD_MONITOR_CELERY_PIDBOX = ( 
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_pidbox" ) CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task" CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task" CHECK_FOR_INDEXING = "check_for_indexing" CHECK_FOR_PRUNING = "check_for_pruning" CHECK_FOR_HIERARCHY_FETCHING = "check_for_hierarchy_fetching" CHECK_FOR_DOC_PERMISSIONS_SYNC = "check_for_doc_permissions_sync" CHECK_FOR_EXTERNAL_GROUP_SYNC = "check_for_external_group_sync" CHECK_FOR_AUTO_LLM_UPDATE = "check_for_auto_llm_update" # User file processing CHECK_FOR_USER_FILE_PROCESSING = "check_for_user_file_processing" PROCESS_SINGLE_USER_FILE = "process_single_user_file" CHECK_FOR_USER_FILE_PROJECT_SYNC = "check_for_user_file_project_sync" PROCESS_SINGLE_USER_FILE_PROJECT_SYNC = "process_single_user_file_project_sync" CHECK_FOR_USER_FILE_DELETE = "check_for_user_file_delete" DELETE_SINGLE_USER_FILE = "delete_single_user_file" # Connector checkpoint cleanup CHECK_FOR_CHECKPOINT_CLEANUP = "check_for_checkpoint_cleanup" CLEANUP_CHECKPOINT = "cleanup_checkpoint" # Connector index attempt cleanup CHECK_FOR_INDEX_ATTEMPT_CLEANUP = "check_for_index_attempt_cleanup" CLEANUP_INDEX_ATTEMPT = "cleanup_index_attempt" MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes" MONITOR_CELERY_QUEUES = "monitor_celery_queues" MONITOR_PROCESS_MEMORY = "monitor_process_memory" CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat" KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task" CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = ( "connector_permission_sync_generator_task" ) UPDATE_EXTERNAL_DOCUMENT_PERMISSIONS_TASK = ( "update_external_document_permissions_task" ) CONNECTOR_EXTERNAL_GROUP_SYNC_GENERATOR_TASK = ( "connector_external_group_sync_generator_task" ) # New split indexing tasks CONNECTOR_DOC_FETCHING_TASK = "connector_doc_fetching_task" DOCPROCESSING_TASK = "docprocessing_task" CONNECTOR_PRUNING_GENERATOR_TASK = "connector_pruning_generator_task" CONNECTOR_HIERARCHY_FETCHING_TASK = 
"connector_hierarchy_fetching_task" DOCUMENT_BY_CC_PAIR_CLEANUP_TASK = "document_by_cc_pair_cleanup_task" VESPA_METADATA_SYNC_TASK = "vespa_metadata_sync_task" # chat retention CHECK_TTL_MANAGEMENT_TASK = "check_ttl_management_task" PERFORM_TTL_MANAGEMENT_TASK = "perform_ttl_management_task" GENERATE_USAGE_REPORT_TASK = "generate_usage_report_task" EVAL_RUN_TASK = "eval_run_task" SCHEDULED_EVAL_TASK = "scheduled_eval_task" EXPORT_QUERY_HISTORY_TASK = "export_query_history_task" EXPORT_QUERY_HISTORY_CLEANUP_TASK = "export_query_history_cleanup_task" # Hook execution log retention HOOK_EXECUTION_LOG_CLEANUP_TASK = "hook_execution_log_cleanup_task" # Sandbox cleanup CLEANUP_IDLE_SANDBOXES = "cleanup_idle_sandboxes" CLEANUP_OLD_SNAPSHOTS = "cleanup_old_snapshots" # Sandbox file sync SANDBOX_FILE_SYNC = "sandbox_file_sync" CHECK_FOR_DOCUMENTS_FOR_OPENSEARCH_MIGRATION_TASK = ( "check_for_documents_for_opensearch_migration_task" ) MIGRATE_DOCUMENTS_FROM_VESPA_TO_OPENSEARCH_TASK = ( "migrate_documents_from_vespa_to_opensearch_task" ) MIGRATE_CHUNKS_FROM_VESPA_TO_OPENSEARCH_TASK = ( "migrate_chunks_from_vespa_to_opensearch_task" ) # this needs to correspond to the matching entry in supervisord ONYX_CELERY_BEAT_HEARTBEAT_KEY = "onyx:celery:beat:heartbeat" REDIS_SOCKET_KEEPALIVE_OPTIONS = {} REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPINTVL] = 15 REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPCNT] = 3 if platform.system() == "Darwin": REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPALIVE] = 60 # type: ignore[attr-defined,unused-ignore] else: REDIS_SOCKET_KEEPALIVE_OPTIONS[socket.TCP_KEEPIDLE] = 60 # type: ignore[attr-defined,unused-ignore] class OnyxCallTypes(str, Enum): FIREFLIES = "FIREFLIES" GONG = "GONG" NUM_DAYS_TO_KEEP_CHECKPOINTS = 7 # checkpoints are queried based on index attempts, so we need to keep index attempts for one more day NUM_DAYS_TO_KEEP_INDEX_ATTEMPTS = NUM_DAYS_TO_KEEP_CHECKPOINTS + 1 # TODO: this should likely be stored in the database
DocumentSourceDescription: dict[DocumentSource, str] = { # Special case, document passed in via Onyx APIs without specifying a source type DocumentSource.INGESTION_API: "ingestion_api", DocumentSource.SLACK: "slack channels for discussions and collaboration", DocumentSource.WEB: "indexed web pages", DocumentSource.GOOGLE_DRIVE: "google drive documents (docs, sheets, etc.)", DocumentSource.GMAIL: "email messages", DocumentSource.REQUESTTRACKER: "requesttracker", DocumentSource.GITHUB: "github data (issues, PRs)", DocumentSource.GITBOOK: "gitbook data", DocumentSource.GITLAB: "gitlab data", DocumentSource.BITBUCKET: "bitbucket data", DocumentSource.GURU: "guru data", DocumentSource.BOOKSTACK: "bookstack data", DocumentSource.OUTLINE: "outline data", DocumentSource.CONFLUENCE: "confluence data (pages, spaces, etc.)", DocumentSource.JIRA: "jira data (issues, tickets, projects, etc.)", DocumentSource.SLAB: "slab data", DocumentSource.PRODUCTBOARD: "productboard data (boards, etc.)", DocumentSource.FILE: "files", DocumentSource.CANVAS: "canvas lms - courses, pages, assignments, and announcements", DocumentSource.CODA: "coda - team workspace with docs, tables, and pages", DocumentSource.NOTION: "notion data - a workspace that combines note-taking, \ project management, and collaboration tools into a single, customizable platform", DocumentSource.ZULIP: "zulip data", DocumentSource.LINEAR: "linear data - project management tool, including tickets etc.", DocumentSource.HUBSPOT: "hubspot data - CRM and marketing automation data", DocumentSource.DOCUMENT360: "document360 data", DocumentSource.GONG: "gong - call transcripts", DocumentSource.GOOGLE_SITES: "google_sites - websites", DocumentSource.ZENDESK: "zendesk - customer support data", DocumentSource.LOOPIO: "loopio - rfp data", DocumentSource.DROPBOX: "dropbox - files", DocumentSource.SHAREPOINT: "sharepoint - files", DocumentSource.TEAMS: "teams - chat and collaboration", DocumentSource.SALESFORCE: "salesforce - CRM 
data", DocumentSource.DISCOURSE: "discourse - discussion forums", DocumentSource.AXERO: "axero - employee engagement data", DocumentSource.CLICKUP: "clickup - project management tool", DocumentSource.MEDIAWIKI: "mediawiki - wiki data", DocumentSource.WIKIPEDIA: "wikipedia - encyclopedia data", DocumentSource.ASANA: "asana", DocumentSource.S3: "s3", DocumentSource.R2: "r2", DocumentSource.GOOGLE_CLOUD_STORAGE: "google_cloud_storage - cloud storage", DocumentSource.OCI_STORAGE: "oci_storage - cloud storage", DocumentSource.XENFORO: "xenforo - forum data", DocumentSource.DISCORD: "discord - chat and collaboration", DocumentSource.FRESHDESK: "freshdesk - customer support data", DocumentSource.FIREFLIES: "fireflies - call transcripts", DocumentSource.EGNYTE: "egnyte - files", DocumentSource.AIRTABLE: "airtable - database", DocumentSource.HIGHSPOT: "highspot - CRM data", DocumentSource.DRUPAL_WIKI: "drupal wiki - knowledge base content (pages, spaces, attachments)", DocumentSource.IMAP: "imap - email data", DocumentSource.TESTRAIL: "testrail - test case management tool for QA processes", } ================================================ FILE: backend/onyx/configs/embedding_configs.py ================================================ from pydantic import BaseModel from onyx.db.enums import EmbeddingPrecision class _BaseEmbeddingModel(BaseModel): """Private model for defining base embedding model configurations.""" name: str dim: int index_name: str class SupportedEmbeddingModel(BaseModel): name: str dim: int index_name: str embedding_precision: EmbeddingPrecision # Base embedding model configurations (without precision) _BASE_EMBEDDING_MODELS = [ # Cloud-based models _BaseEmbeddingModel( name="cohere/embed-english-v3.0", dim=1024, index_name="danswer_chunk_cohere_embed_english_v3_0", ), _BaseEmbeddingModel( name="cohere/embed-english-v3.0", dim=1024, index_name="danswer_chunk_embed_english_v3_0", ), _BaseEmbeddingModel( name="cohere/embed-english-light-v3.0", dim=384, 
index_name="danswer_chunk_cohere_embed_english_light_v3_0", ), _BaseEmbeddingModel( name="cohere/embed-english-light-v3.0", dim=384, index_name="danswer_chunk_embed_english_light_v3_0", ), _BaseEmbeddingModel( name="openai/text-embedding-3-large", dim=3072, index_name="danswer_chunk_openai_text_embedding_3_large", ), _BaseEmbeddingModel( name="openai/text-embedding-3-large", dim=3072, index_name="danswer_chunk_text_embedding_3_large", ), _BaseEmbeddingModel( name="openai/text-embedding-3-small", dim=1536, index_name="danswer_chunk_openai_text_embedding_3_small", ), _BaseEmbeddingModel( name="openai/text-embedding-3-small", dim=1536, index_name="danswer_chunk_text_embedding_3_small", ), _BaseEmbeddingModel( name="google/gemini-embedding-001", dim=3072, index_name="danswer_chunk_gemini_embedding_001", ), _BaseEmbeddingModel( name="google/text-embedding-005", dim=768, index_name="danswer_chunk_text_embedding_005", ), _BaseEmbeddingModel( name="voyage/voyage-large-2-instruct", dim=1024, index_name="danswer_chunk_voyage_large_2_instruct", ), _BaseEmbeddingModel( name="voyage/voyage-large-2-instruct", dim=1024, index_name="danswer_chunk_large_2_instruct", ), _BaseEmbeddingModel( name="voyage/voyage-light-2-instruct", dim=384, index_name="danswer_chunk_voyage_light_2_instruct", ), _BaseEmbeddingModel( name="voyage/voyage-light-2-instruct", dim=384, index_name="danswer_chunk_light_2_instruct", ), # Self-hosted models _BaseEmbeddingModel( name="nomic-ai/nomic-embed-text-v1", dim=768, index_name="danswer_chunk_nomic_ai_nomic_embed_text_v1", ), _BaseEmbeddingModel( name="nomic-ai/nomic-embed-text-v1", dim=768, index_name="danswer_chunk_nomic_embed_text_v1", ), _BaseEmbeddingModel( name="intfloat/e5-base-v2", dim=768, index_name="danswer_chunk_intfloat_e5_base_v2", ), _BaseEmbeddingModel( name="intfloat/e5-small-v2", dim=384, index_name="danswer_chunk_intfloat_e5_small_v2", ), _BaseEmbeddingModel( name="intfloat/multilingual-e5-base", dim=768, 
index_name="danswer_chunk_intfloat_multilingual_e5_base", ), _BaseEmbeddingModel( name="intfloat/multilingual-e5-small", dim=384, index_name="danswer_chunk_intfloat_multilingual_e5_small", ), ] # Automatically generate both FLOAT and BFLOAT16 versions of all models SUPPORTED_EMBEDDING_MODELS = [ # BFLOAT16 precision versions *[ SupportedEmbeddingModel( name=model.name, dim=model.dim, index_name=f"{model.index_name}_bfloat16", embedding_precision=EmbeddingPrecision.BFLOAT16, ) for model in _BASE_EMBEDDING_MODELS ], # FLOAT precision versions # NOTE: need to keep this one for backwards compatibility. We now default to # BFLOAT16. *[ SupportedEmbeddingModel( name=model.name, dim=model.dim, index_name=model.index_name, embedding_precision=EmbeddingPrecision.FLOAT, ) for model in _BASE_EMBEDDING_MODELS ], ] ================================================ FILE: backend/onyx/configs/kg_configs.py ================================================ import os KG_RESEARCH_NUM_RETRIEVED_DOCS: int = int( os.environ.get("KG_RESEARCH_NUM_RETRIEVED_DOCS", "25") ) KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES: int = int( os.environ.get("KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES", "10") ) KG_ENTITY_EXTRACTION_TIMEOUT: int = int( os.environ.get("KG_ENTITY_EXTRACTION_TIMEOUT", "15") ) KG_RELATIONSHIP_EXTRACTION_TIMEOUT: int = int( os.environ.get("KG_RELATIONSHIP_EXTRACTION_TIMEOUT", "15") ) KG_STRATEGY_GENERATION_TIMEOUT: int = int( os.environ.get("KG_STRATEGY_GENERATION_TIMEOUT", "20") ) KG_SQL_GENERATION_TIMEOUT: int = int(os.environ.get("KG_SQL_GENERATION_TIMEOUT", "40")) KG_SQL_GENERATION_TIMEOUT_OVERRIDE: int = int( os.environ.get("KG_SQL_GENERATION_TIMEOUT_OVERRIDE", "40") ) KG_SQL_GENERATION_MAX_TOKENS: int = int( os.environ.get("KG_SQL_GENERATION_MAX_TOKENS", "1500") ) KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX: str = os.environ.get( "KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX", "allowed_docs" ) KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX: str = os.environ.get( 
"KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX", "kg_relationships_with_access" ) KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX: str = os.environ.get( "KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX", "kg_entities_with_access" ) KG_FILTER_CONSTRUCTION_TIMEOUT: int = int( os.environ.get("KG_FILTER_CONSTRUCTION_TIMEOUT", "15") ) KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT: int = int( os.environ.get("KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT", "100") ) KG_FILTERED_SEARCH_TIMEOUT: int = int( os.environ.get("KG_FILTERED_SEARCH_TIMEOUT", "30") ) KG_OBJECT_SOURCE_RESEARCH_TIMEOUT: int = int( os.environ.get("KG_OBJECT_SOURCE_RESEARCH_TIMEOUT", "30") ) KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION: int = int( os.environ.get("KG_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION", "45") ) KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION: int = int( os.environ.get("KG_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION", "15") ) KG_MAX_TOKENS_ANSWER_GENERATION: int = int( os.environ.get("KG_MAX_TOKENS_ANSWER_GENERATION", "1024") ) KG_MAX_DEEP_SEARCH_RESULTS: int = int( os.environ.get("KG_MAX_DEEP_SEARCH_RESULTS", "30") ) KG_METADATA_TRACKING_THRESHOLD: int = int( os.environ.get("KG_METADATA_TRACKING_THRESHOLD", "10") ) KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH: int = int( os.environ.get("KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH", "2") ) _KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT: float = max( 1e-3, min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT", "0.25"))), ) _KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT: float = max( 1e-3, min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT", "0.25"))), ) _KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT: float = max( 1e-3, min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT", "0.5"))), ) _KG_NORMALIZATION_RERANK_NGRAM_SUMS: float = ( _KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT + _KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT + _KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT ) KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS: tuple[float, float, float] = ( _KG_NORMALIZATION_RERANK_UNIGRAM_WEIGHT / 
_KG_NORMALIZATION_RERANK_NGRAM_SUMS, _KG_NORMALIZATION_RERANK_BIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS, _KG_NORMALIZATION_RERANK_TRIGRAM_WEIGHT / _KG_NORMALIZATION_RERANK_NGRAM_SUMS, ) KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT: float = max( 0, min(1, float(os.environ.get("KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT", "0.25"))), ) KG_NORMALIZATION_RERANK_THRESHOLD: float = float( os.environ.get("KG_NORMALIZATION_RERANK_THRESHOLD", "0.3") ) KG_CLUSTERING_RETRIEVE_THRESHOLD: float = float( os.environ.get("KG_CLUSTERING_RETRIEVE_THRESHOLD", "0.6") ) KG_CLUSTERING_THRESHOLD: float = float( os.environ.get("KG_CLUSTERING_THRESHOLD", "0.96") ) KG_MAX_SEARCH_DOCUMENTS: int = int(os.environ.get("KG_MAX_SEARCH_DOCUMENTS", "15")) KG_MAX_DECOMPOSITION_SEGMENTS: int = int( os.environ.get("KG_MAX_DECOMPOSITION_SEGMENTS", "10") ) KG_BETA_ASSISTANT_DESCRIPTION = "The KG Beta assistant uses the Onyx Knowledge Graph (beta) structure \ to answer questions" ================================================ FILE: backend/onyx/configs/llm_configs.py ================================================ from onyx.configs.app_configs import DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB from onyx.server.settings.store import load_settings def get_image_extraction_and_analysis_enabled() -> bool: """Get image extraction and analysis enabled setting from workspace settings or fallback to False""" try: settings = load_settings() if settings.image_extraction_and_analysis_enabled is not None: return settings.image_extraction_and_analysis_enabled except Exception: pass return False def get_search_time_image_analysis_enabled() -> bool: """Get search time image analysis enabled setting from workspace settings or fallback to False""" try: settings = load_settings() if settings.search_time_image_analysis_enabled is not None: return settings.search_time_image_analysis_enabled except Exception: pass return False def get_image_analysis_max_size_mb() -> int: """Get image analysis max size MB setting from 
workspace settings or fallback to environment variable""" try: settings = load_settings() if settings.image_analysis_max_size_mb is not None: return settings.image_analysis_max_size_mb except Exception: pass return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB ================================================ FILE: backend/onyx/configs/model_configs.py ================================================ import json import os ##### # Embedding/Reranking Model Configs ##### # Important considerations when choosing models # Max token count needs to be high considering the use case (at least 512) # Models used must be MIT or Apache licensed # Inference/Indexing speed # https://huggingface.co/DOCUMENT_ENCODER_MODEL # The usable models configured below must be SentenceTransformer compatible # NOTE: DO NOT CHANGE OR SET THESE UNLESS YOU KNOW WHAT YOU ARE DOING # IDEALLY, YOU SHOULD CHANGE EMBEDDING MODELS VIA THE UI DEFAULT_DOCUMENT_ENCODER_MODEL = "nomic-ai/nomic-embed-text-v1" DOCUMENT_ENCODER_MODEL = ( os.environ.get("DOCUMENT_ENCODER_MODEL") or DEFAULT_DOCUMENT_ENCODER_MODEL ) # If the below is changed, Vespa deployment must also be changed DOC_EMBEDDING_DIM = int(os.environ.get("DOC_EMBEDDING_DIM") or 768) NORMALIZE_EMBEDDINGS = ( os.environ.get("NORMALIZE_EMBEDDINGS") or "true" ).lower() == "true" # Old default model settings, which are needed for an automatic easy upgrade OLD_DEFAULT_DOCUMENT_ENCODER_MODEL = "thenlper/gte-small" OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM = 384 OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS = False # These are only used if reranking is turned off, to normalize the direct retrieval scores for display # Currently unused SIM_SCORE_RANGE_LOW = float(os.environ.get("SIM_SCORE_RANGE_LOW") or 0.0) SIM_SCORE_RANGE_HIGH = float(os.environ.get("SIM_SCORE_RANGE_HIGH") or 1.0) # Certain models like e5, BGE, etc. use a prefix for asymmetric retrieval (queries are generally shorter than docs) ASYM_QUERY_PREFIX = os.environ.get("ASYM_QUERY_PREFIX", "search_query: ") ASYM_PASSAGE_PREFIX
= os.environ.get("ASYM_PASSAGE_PREFIX", "search_document: ") # Purely an optimization, memory limitation consideration # User's set embedding batch size overrides the default encoding batch sizes EMBEDDING_BATCH_SIZE = int(os.environ.get("EMBEDDING_BATCH_SIZE") or 0) or None BATCH_SIZE_ENCODE_CHUNKS = EMBEDDING_BATCH_SIZE or 8 # don't send over too many chunks at once, as sending too many could cause timeouts BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES = EMBEDDING_BATCH_SIZE or 512 # For score display purposes, only way is to know the expected ranges CROSS_ENCODER_RANGE_MAX = 1 CROSS_ENCODER_RANGE_MIN = 0 ##### # Generative AI Model Configs ##### # NOTE: the 2 below should only be used for dev. GEN_AI_API_KEY = os.environ.get("GEN_AI_API_KEY") GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") # Override the auto-detection of LLM max context length GEN_AI_MAX_TOKENS = int(os.environ.get("GEN_AI_MAX_TOKENS") or 0) or None # Set this to be enough for an answer + quotes. Also used for Chat # This is the minimum token context we will leave for the LLM to generate an answer GEN_AI_NUM_RESERVED_OUTPUT_TOKENS = int( os.environ.get("GEN_AI_NUM_RESERVED_OUTPUT_TOKENS") or 1024 ) # Fallback token limit for models where the max context is unknown # Set conservatively at 32K to handle most modern models GEN_AI_MODEL_FALLBACK_MAX_TOKENS = int( os.environ.get("GEN_AI_MODEL_FALLBACK_MAX_TOKENS") or 32000 ) # This is used when computing how much context space is available for documents # ahead of time in order to let the user know if they can "select" more documents # It represents a maximum "expected" number of input tokens from the latest user # message. At query time, we don't actually enforce this - we will only throw an # error if the total # of tokens exceeds the max input tokens. 
GEN_AI_SINGLE_USER_MESSAGE_EXPECTED_MAX_TOKENS = 512 GEN_AI_TEMPERATURE = float(os.environ.get("GEN_AI_TEMPERATURE") or 0) # should be used if you are using a custom LLM inference provider that doesn't support # streaming format AND you are still using the langchain/litellm LLM class DISABLE_LITELLM_STREAMING = ( os.environ.get("DISABLE_LITELLM_STREAMING") or "false" ).lower() == "true" # extra headers to pass to LiteLLM LITELLM_EXTRA_HEADERS: dict[str, str] | None = None _LITELLM_EXTRA_HEADERS_RAW = os.environ.get("LITELLM_EXTRA_HEADERS") if _LITELLM_EXTRA_HEADERS_RAW: try: LITELLM_EXTRA_HEADERS = json.loads(_LITELLM_EXTRA_HEADERS_RAW) except Exception: # need to import here to avoid circular imports from onyx.utils.logger import setup_logger logger = setup_logger() logger.error( "Failed to parse LITELLM_EXTRA_HEADERS, must be a valid JSON object" ) # if specified, will pass through request headers to the call to the LLM LITELLM_PASS_THROUGH_HEADERS: list[str] | None = None _LITELLM_PASS_THROUGH_HEADERS_RAW = os.environ.get("LITELLM_PASS_THROUGH_HEADERS") if _LITELLM_PASS_THROUGH_HEADERS_RAW: try: LITELLM_PASS_THROUGH_HEADERS = json.loads(_LITELLM_PASS_THROUGH_HEADERS_RAW) except Exception: # need to import here to avoid circular imports from onyx.utils.logger import setup_logger logger = setup_logger() logger.error( "Failed to parse LITELLM_PASS_THROUGH_HEADERS, must be a valid JSON object" ) # if specified, will merge the specified JSON with the existing body of the # request before sending it to the LLM LITELLM_EXTRA_BODY: dict | None = None _LITELLM_EXTRA_BODY_RAW = os.environ.get("LITELLM_EXTRA_BODY") if _LITELLM_EXTRA_BODY_RAW: try: LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW) except Exception: pass ##### # Prompt Caching Configs ##### # Enable prompt caching framework ENABLE_PROMPT_CACHING = ( os.environ.get("ENABLE_PROMPT_CACHING", "true").lower() != "false" ) # Cache TTL multiplier - store caches slightly longer than provider TTL # This allows 
for some clock skew and ensures we don't lose cache metadata prematurely PROMPT_CACHE_REDIS_TTL_MULTIPLIER = float( os.environ.get("PROMPT_CACHE_REDIS_TTL_MULTIPLIER") or 1.2 ) ================================================ FILE: backend/onyx/configs/onyxbot_configs.py ================================================ import os ##### # Onyx Slack Bot Configs ##### ONYX_BOT_NUM_RETRIES = int(os.environ.get("ONYX_BOT_NUM_RETRIES", "5")) # Number of docs to display in "Reference Documents" ONYX_BOT_NUM_DOCS_TO_DISPLAY = int(os.environ.get("ONYX_BOT_NUM_DOCS_TO_DISPLAY", "5")) # If the LLM fails to answer, Onyx can still show the "Reference Documents" ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER = os.environ.get( "ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER", "" ).lower() not in ["false", ""] # Emoji that Onyx reacts with while it is considering a message ONYX_BOT_REACT_EMOJI = os.environ.get("ONYX_BOT_REACT_EMOJI") or "eyes" # Emoji to use when the user needs more help ONYX_BOT_FOLLOWUP_EMOJI = os.environ.get("ONYX_BOT_FOLLOWUP_EMOJI") or "sos" # Visibility of the message shown when someone gives feedback on an OnyxBot AI answer # Defaults to Private if not provided or invalid # Private: Only visible to the user who clicked the feedback # Anonymous: Public but anonymous # Public: Visible with the name of the user who submitted the feedback ONYX_BOT_FEEDBACK_VISIBILITY = ( os.environ.get("ONYX_BOT_FEEDBACK_VISIBILITY") or "private" ) # Should OnyxBot send an apology message if it's not able to find an answer # That way the user isn't confused as to why OnyxBot reacted but then said nothing # Off by default to be less intrusive (we don't want to send a notification that just says we couldn't help) NOTIFY_SLACKBOT_NO_ANSWER = ( os.environ.get("NOTIFY_SLACKBOT_NO_ANSWER", "").lower() == "true" ) # Mostly for debugging purposes: explains what went wrong # if OnyxBot couldn't find an answer ONYX_BOT_DISPLAY_ERROR_MSGS = os.environ.get( "ONYX_BOT_DISPLAY_ERROR_MSGS", "" ).lower() not in
[ "false", "", ] # Maximum questions per minute; uncapped by default ONYX_BOT_MAX_QPM = int(os.environ.get("ONYX_BOT_MAX_QPM") or 0) or None # Maximum time to wait when a question is queued ONYX_BOT_MAX_WAIT_TIME = int(os.environ.get("ONYX_BOT_MAX_WAIT_TIME") or 180) # Time (in minutes) after which a Slack message is sent to the user to remind them to give feedback. # Set to 0 to disable it (default) ONYX_BOT_FEEDBACK_REMINDER = int(os.environ.get("ONYX_BOT_FEEDBACK_REMINDER") or 0) # ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD is the number of # responses OnyxBot can send in a given time period. # Set to 0 to disable the limit. ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD = int( os.environ.get("ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD", "5000") ) # ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS is the number # of seconds until the response limit is reset. ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS = int( os.environ.get("ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS", "86400") ) ================================================ FILE: backend/onyx/configs/research_configs.py ================================================ ================================================ FILE: backend/onyx/configs/saml_config/template.settings.json ================================================ { "strict": true, "debug": false, "idp": { "entityId": "", "singleSignOnService": { "url": "https://trial-1234567.okta.com/home/trial-1234567_onyx/somevalues/somevalues", "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" }, "x509cert": "" }, "sp": { "entityId": "", "assertionConsumerService": { "url": "http://127.0.0.1:3000/auth/saml/callback", "binding": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" }, "x509cert": "" } } ================================================ FILE: backend/onyx/configs/tool_configs.py ================================================ import json import os IMAGE_GENERATION_OUTPUT_FORMAT = os.environ.get( "IMAGE_GENERATION_OUTPUT_FORMAT", "b64_json" ) # if specified,
will pass through request headers to API calls made by custom tools CUSTOM_TOOL_PASS_THROUGH_HEADERS: list[str] | None = None _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW = os.environ.get( "CUSTOM_TOOL_PASS_THROUGH_HEADERS" ) if _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW: try: CUSTOM_TOOL_PASS_THROUGH_HEADERS = json.loads( _CUSTOM_TOOL_PASS_THROUGH_HEADERS_RAW ) except Exception: # need to import here to avoid circular imports from onyx.utils.logger import setup_logger logger = setup_logger() logger.error( "Failed to parse CUSTOM_TOOL_PASS_THROUGH_HEADERS, must be a valid JSON object" ) ================================================ FILE: backend/onyx/connectors/README.md ================================================ # Writing a new Onyx Connector This README covers how to contribute a new Connector for Onyx. It includes an overview of the design, interfaces, and required changes. Thank you for your contribution! ### Connector Overview Connectors come in several different flows: - Load Connector: - Bulk indexes documents to reflect a point in time. This type of connector generally works either by pulling all documents via a connector's API or by loading the documents from some sort of dump file. - Poll Connector: - Incrementally updates documents based on a provided time range. It is used by the background job to pull the latest changes and additions since the last round of polling. This connector helps keep the document index up to date without needing to fetch/embed/index every document, which would be too slow to do frequently on large sets of documents. - Slim Connector: - This connector should be a lighter-weight method of checking all documents in the source to see if they still exist. - This connector should be identical to the Poll or Load Connector except that it only fetches the IDs of the documents, not the documents themselves. - This is used by our pruning job, which removes old documents from the index. - The optional start and end datetimes can be ignored.
- Event Based connectors: - Connectors that listen to events and update documents accordingly. - These are currently not used by the background job; this type exists for future design purposes. ### Connector Implementation Refer to [interfaces.py](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/interfaces.py) and to this Pull Request for a new connector, created by our first contributor (shoutout to Dan Brown): [Reference Pull Request](https://github.com/onyx-dot-app/onyx/pull/139) For implementing a Slim Connector, refer to the comments in this PR: [Slim Connector PR](https://github.com/onyx-dot-app/onyx/pull/3303/files) All new connectors should have tests added to the `backend/tests/daily/connectors` directory. Refer to the above PR for an example of adding tests for a new connector. #### Implementing the new Connector The connector must subclass one or more of `LoadConnector`, `PollConnector`, `CheckpointedConnector`, or `CheckpointedConnectorWithPermSync`. The `__init__` should take arguments that configure which documents the connector will fetch and where it finds them. For example, if you have a wiki site, it may include the configuration for the team, topic, folder, etc. of the documents to fetch. It may also include the base domain of the wiki. Alternatively, if all the access information of the connector is stored in the credential/token, then there may be no required arguments. `load_credentials` should take a dictionary that provides all the access information the connector might need. For example, this could be the user's username and access token. Refer to the existing connectors for `load_from_state` and `poll_source` examples. There is not yet a process to listen for EventConnector events; this will come later. #### Development Tip It may be handy to test your new connector separately from the rest of the stack while developing.
Follow the template below:
```python
if __name__ == "__main__":
    import time

    test_connector = NewConnector(space="engineering")
    test_connector.load_credentials({
        "user_id": "foobar",
        "access_token": "fake_token"
    })
    all_docs = test_connector.load_from_state()

    current = time.time()
    one_day_ago = current - 24 * 60 * 60  # 1 day
    latest_docs = test_connector.poll_source(one_day_ago, current)
```
> Note: Be sure to set PYTHONPATH to onyx/backend before running the `__main__` block above. ### Additional Required Changes: #### Backend Changes - Add a new type to [DocumentSource](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/configs/constants.py) - Add a mapping from DocumentSource (and optionally connector type) to the right connector class [here](https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/connectors/factory.py#L33) #### Frontend Changes - Add the new Connector definition to the `SOURCE_METADATA_MAP` [here](https://github.com/onyx-dot-app/onyx/blob/main/web/src/lib/sources.ts#L59). - Add the definition for the new Form to the `connectorConfigs` object [here](https://github.com/onyx-dot-app/onyx/blob/main/web/src/lib/connectors/connectors.ts#L79). #### Docs Changes Create the new connector page (with guiding images!) explaining how to get the connector credentials and how to set up the connector in Onyx. Then open a Pull Request in [https://github.com/onyx-dot-app/documentation](https://github.com/onyx-dot-app/documentation). ### Before opening a PR 1. Be sure to fully test changes end to end by setting up the connector and updating the index with new docs from it. To make it easier to review, please attach a video showing the successful creation of the connector via the UI (starting from the `Add Connector` page). 2. Add a folder + tests under the `backend/tests/daily/connectors` directory.
For an example, check out the [test for Confluence](https://github.com/onyx-dot-app/onyx/blob/main/backend/tests/daily/connectors/confluence/test_confluence_basic.py). In the PR description, include a guide on how to set up the new source to pass the test. Before merging, we will re-create the environment and make sure the test(s) pass. 3. Be sure to run linting and formatting; refer to the formatting and linting section in [CONTRIBUTING.md](https://github.com/onyx-dot-app/onyx/blob/main/CONTRIBUTING.md#formatting-and-linting). ================================================ FILE: backend/onyx/connectors/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/airtable/airtable_connector.py ================================================ import contextvars import re from concurrent.futures import as_completed from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from io import BytesIO from typing import Any from typing import cast import requests from pyairtable import Api as AirtableApi from pyairtable.api.types import RecordDict from pyairtable.models.schema import TableSchema from retry import retry from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.extract_file_text import get_file_ext from onyx.utils.logger import setup_logger logger = setup_logger() # NOTE: all are made lowercase to avoid case sensitivity issues # These field types are
considered metadata by default when # treat_all_non_attachment_fields_as_metadata is False DEFAULT_METADATA_FIELD_TYPES = { "singlecollaborator", "collaborator", "createdby", "singleselect", "multipleselects", "checkbox", "date", "datetime", "email", "phone", "url", "number", "currency", "duration", "percent", "rating", "createdtime", "lastmodifiedtime", "autonumber", "rollup", "lookup", "count", "formula", "date", } class AirtableClientNotSetUpError(PermissionError): def __init__(self) -> None: super().__init__("Airtable Client is not set up, was load_credentials called?") # Matches URLs like https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide # Captures: base_id (appXXX), table_id (tblYYY), and optionally view_id (viwZZZ) _AIRTABLE_URL_PATTERN = re.compile( r"https?://airtable\.com/(app[A-Za-z0-9]+)/(tbl[A-Za-z0-9]+)(?:/(viw[A-Za-z0-9]+))?", ) def parse_airtable_url( url: str, ) -> tuple[str, str, str | None]: """Parse an Airtable URL into (base_id, table_id, view_id). Accepts URLs like: https://airtable.com/appXXX/tblYYY https://airtable.com/appXXX/tblYYY/viwZZZ https://airtable.com/appXXX/tblYYY/viwZZZ?blocks=hide Returns: (base_id, table_id, view_id or None) Raises: ValueError if the URL doesn't match the expected format. """ match = _AIRTABLE_URL_PATTERN.search(url.strip()) if not match: raise ValueError( f"Could not parse Airtable URL: '{url}'. Expected format: https://airtable.com/appXXX/tblYYY[/viwZZZ]" ) return match.group(1), match.group(2), match.group(3) class AirtableConnector(LoadConnector): def __init__( self, base_id: str = "", table_name_or_id: str = "", airtable_url: str = "", treat_all_non_attachment_fields_as_metadata: bool = False, view_id: str | None = None, share_id: str | None = None, batch_size: int = INDEX_BATCH_SIZE, ) -> None: """Initialize an AirtableConnector. 
Args: base_id: The ID of the Airtable base (not required when airtable_url is set) table_name_or_id: The name or ID of the table (not required when airtable_url is set) airtable_url: An Airtable URL to parse base_id, table_id, and view_id from. Overrides base_id, table_name_or_id, and view_id if provided. treat_all_non_attachment_fields_as_metadata: If True, all fields except attachments will be treated as metadata. If False, only fields with types in DEFAULT_METADATA_FIELD_TYPES will be treated as metadata. view_id: Optional ID of a specific view to use share_id: Optional ID of a "share" to use for generating record URLs batch_size: Number of records to process in each batch Mode is auto-detected: if a specific table is identified (via URL or base_id + table_name_or_id), the connector indexes that single table. Otherwise, it discovers and indexes all accessible bases and tables. """ # If a URL is provided, parse it to extract base_id, table_id, and view_id if airtable_url: parsed_base_id, parsed_table_id, parsed_view_id = parse_airtable_url( airtable_url ) base_id = parsed_base_id table_name_or_id = parsed_table_id if parsed_view_id: view_id = parsed_view_id self.base_id = base_id self.table_name_or_id = table_name_or_id self.index_all = not (base_id and table_name_or_id) self.view_id = view_id self.share_id = share_id self.batch_size = batch_size self._airtable_client: AirtableApi | None = None self.treat_all_non_attachment_fields_as_metadata = ( treat_all_non_attachment_fields_as_metadata ) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self._airtable_client = AirtableApi(credentials["airtable_access_token"]) return None @property def airtable_client(self) -> AirtableApi: if not self._airtable_client: raise AirtableClientNotSetUpError() return self._airtable_client def validate_connector_settings(self) -> None: if self.index_all: try: bases = self.airtable_client.bases() if not bases: raise ConnectorValidationError( "No bases 
found. Ensure your API token has access to at least one base." ) except ConnectorValidationError: raise except Exception as e: raise ConnectorValidationError(f"Failed to list Airtable bases: {e}") else: if not self.base_id or not self.table_name_or_id: raise ConnectorValidationError( "A valid Airtable URL or base_id and table_name_or_id are required when not using index_all mode." ) try: table = self.airtable_client.table(self.base_id, self.table_name_or_id) table.schema() except Exception as e: raise ConnectorValidationError( f"Failed to access table '{self.table_name_or_id}' in base '{self.base_id}': {e}" ) @classmethod def _get_record_url( cls, base_id: str, table_id: str, record_id: str, share_id: str | None, view_id: str | None, field_id: str | None = None, attachment_id: str | None = None, ) -> str: """Constructs the URL for a record, optionally including field and attachment IDs Full possible structure is: https://airtable.com/BASE_ID/SHARE_ID/TABLE_ID/VIEW_ID/RECORD_ID/FIELD_ID/ATTACHMENT_ID """ # If we have a shared link, use that view for better UX if share_id: base_url = f"https://airtable.com/{base_id}/{share_id}/{table_id}" else: base_url = f"https://airtable.com/{base_id}/{table_id}" if view_id: base_url = f"{base_url}/{view_id}" base_url = f"{base_url}/{record_id}" if field_id and attachment_id: return f"{base_url}/{field_id}/{attachment_id}?blocks=hide" return base_url def _extract_field_values( self, field_id: str, field_name: str, field_info: Any, field_type: str, base_id: str, table_id: str, view_id: str | None, record_id: str, ) -> list[tuple[str, str]]: """ Extract value(s) + links from a field regardless of its type. Attachments are represented as multiple sections, and therefore returned as a list of tuples (value, link). 
""" if field_info is None: return [] # skip references to other records for now (would need to do another # request to get the actual record name/type) # TODO: support this if field_type == "multipleRecordLinks": return [] # Get the base URL for this record default_link = self._get_record_url( base_id, table_id, record_id, self.share_id, self.view_id or view_id ) if field_type == "multipleAttachments": attachment_texts: list[tuple[str, str]] = [] for attachment in field_info: url = attachment.get("url") filename = attachment.get("filename", "") if not url: continue @retry( tries=5, delay=1, backoff=2, max_delay=10, ) def get_attachment_with_retry(url: str, record_id: str) -> bytes | None: try: attachment_response = requests.get(url) attachment_response.raise_for_status() return attachment_response.content except requests.exceptions.HTTPError as e: if e.response.status_code == 410: logger.info(f"Refreshing attachment for {filename}") # Re-fetch the record to get a fresh URL refreshed_record = self.airtable_client.table( base_id, table_id ).get(record_id) for refreshed_attachment in refreshed_record["fields"][ field_name ]: if refreshed_attachment.get("filename") == filename: new_url = refreshed_attachment.get("url") if new_url: attachment_response = requests.get(new_url) attachment_response.raise_for_status() return attachment_response.content logger.error(f"Failed to refresh attachment for {filename}") raise attachment_content = get_attachment_with_retry(url, record_id) if attachment_content: try: file_ext = get_file_ext(filename) attachment_id = attachment["id"] attachment_text = extract_file_text( BytesIO(attachment_content), filename, break_on_unprocessable=False, extension=file_ext, ) if attachment_text: # Use the helper method to construct attachment URLs attachment_link = self._get_record_url( base_id, table_id, record_id, self.share_id, self.view_id or view_id, field_id, attachment_id, ) attachment_texts.append( (f"{filename}:\n{attachment_text}", 
attachment_link) ) except Exception as e: logger.warning( f"Failed to process attachment {filename}: {str(e)}" ) return attachment_texts if field_type in ["singleCollaborator", "collaborator", "createdBy"]: combined = [] collab_name = field_info.get("name") collab_email = field_info.get("email") if collab_name: combined.append(collab_name) if collab_email: combined.append(f"({collab_email})") return [(" ".join(combined) if combined else str(field_info), default_link)] if isinstance(field_info, list): return [(str(item), default_link) for item in field_info] return [(str(field_info), default_link)] def _should_be_metadata(self, field_type: str) -> bool: """Determine if a field type should be treated as metadata. When treat_all_non_attachment_fields_as_metadata is True, all fields except attachments are treated as metadata. Otherwise, only fields with types listed in DEFAULT_METADATA_FIELD_TYPES are treated as metadata.""" if self.treat_all_non_attachment_fields_as_metadata: return field_type.lower() != "multipleattachments" return field_type.lower() in DEFAULT_METADATA_FIELD_TYPES def _process_field( self, field_id: str, field_name: str, field_info: Any, field_type: str, base_id: str, table_id: str, view_id: str | None, record_id: str, ) -> tuple[list[TextSection], dict[str, str | list[str]]]: """ Process a single Airtable field and return sections or metadata. 
        Args:
            field_id: ID of the field (used to build attachment links)
            field_name: Name of the field
            field_info: Raw field information from Airtable
            field_type: Airtable field type
            base_id: ID of the base containing the record
            table_id: ID of the table containing the record
            view_id: ID of the view used to build record links, if any
            record_id: ID of the record being processed

        Returns:
            (list of Sections, dict of metadata)
        """
        if field_info is None:
            return [], {}

        # Get the value(s) for the field
        field_value_and_links = self._extract_field_values(
            field_id=field_id,
            field_name=field_name,
            field_info=field_info,
            field_type=field_type,
            base_id=base_id,
            table_id=table_id,
            view_id=view_id,
            record_id=record_id,
        )
        if len(field_value_and_links) == 0:
            return [], {}

        # Determine if it should be metadata or a section
        if self._should_be_metadata(field_type):
            field_values = [value for value, _ in field_value_and_links]
            if len(field_values) > 1:
                return [], {field_name: field_values}
            return [], {field_name: field_values[0]}

        # Otherwise, create relevant sections
        sections = [
            TextSection(
                link=link,
                text=(
                    f"{field_name}:\n------------------------\n{text}\n------------------------"
                ),
            )
            for text, link in field_value_and_links
        ]
        return sections, {}

    def _process_record(
        self,
        record: RecordDict,
        table_schema: TableSchema,
        primary_field_name: str | None,
        base_id: str,
        base_name: str | None = None,
    ) -> Document | None:
        """Process a single Airtable record into a Document.
        Args:
            record: The Airtable record to process
            table_schema: Schema information for the table
            primary_field_name: Name of the primary field, if any
            base_id: The ID of the base this record belongs to
            base_name: The name of the base (used in semantic ID for index_all mode)

        Returns:
            Document representing the record, or None if no sections were produced
        """
        table_id = table_schema.id
        table_name = table_schema.name
        record_id = record["id"]
        fields = record["fields"]
        sections: list[TextSection] = []
        metadata: dict[str, str | list[str]] = {}

        # Get primary field value if it exists
        primary_field_value = (
            fields.get(primary_field_name) if primary_field_name else None
        )
        view_id = table_schema.views[0].id if table_schema.views else None

        for field_schema in table_schema.fields:
            field_name = field_schema.name
            field_val = fields.get(field_name)
            field_type = field_schema.type

            logger.debug(
                f"Processing field '{field_name}' of type '{field_type}' for record '{record_id}'."
            )

            field_sections, field_metadata = self._process_field(
                field_id=field_schema.id,
                field_name=field_name,
                field_info=field_val,
                field_type=field_type,
                base_id=base_id,
                table_id=table_id,
                view_id=view_id,
                record_id=record_id,
            )

            sections.extend(field_sections)
            metadata.update(field_metadata)

        if not sections:
            logger.warning(f"No sections found for record {record_id}")
            return None

        # Include base name in semantic ID only in index_all mode
        if self.index_all and base_name:
            semantic_id = (
                f"{base_name} > {table_name}: {primary_field_value}"
                if primary_field_value
                else f"{base_name} > {table_name}"
            )
        else:
            semantic_id = (
                f"{table_name}: {primary_field_value}"
                if primary_field_value
                else table_name
            )

        # Build hierarchy source_path for Craft file system subdirectory structure.
        # This creates: airtable/{base_name}/{table_name}/record.json
        source_path: list[str] = []
        if base_name:
            source_path.append(base_name)
        source_path.append(table_name)

        return Document(
            id=f"airtable__{record_id}",
            sections=(cast(list[TextSection | ImageSection], sections)),
            source=DocumentSource.AIRTABLE,
            semantic_identifier=semantic_id,
            metadata=metadata,
            doc_metadata={
                "hierarchy": {
                    "source_path": source_path,
                    "base_id": base_id,
                    "table_id": table_id,
                    "table_name": table_name,
                    **({"base_name": base_name} if base_name else {}),
                }
            },
        )

    def _resolve_base_name(self, base_id: str) -> str | None:
        """Try to resolve a human-readable base name from the API."""
        try:
            for base_info in self.airtable_client.bases():
                if base_info.id == base_id:
                    return base_info.name
        except Exception:
            logger.debug(f"Could not resolve base name for {base_id}")
        return None

    def _index_table(
        self,
        base_id: str,
        table_name_or_id: str,
        base_name: str | None = None,
    ) -> GenerateDocumentsOutput:
        """Index all records from a single table. Yields batches of Documents."""
        # Resolve base name for hierarchy if not provided
        if base_name is None:
            base_name = self._resolve_base_name(base_id)

        table = self.airtable_client.table(base_id, table_name_or_id)
        records = table.all()

        table_schema = table.schema()
        primary_field_name = None

        # Find a primary field from the schema
        for field in table_schema.fields:
            if field.id == table_schema.primary_field_id:
                primary_field_name = field.name
                break

        logger.info(
            f"Processing {len(records)} records from table '{table_schema.name}' in base '{base_name or base_id}'."
        )

        if not records:
            return

        # Process records in parallel batches using ThreadPoolExecutor
        PARALLEL_BATCH_SIZE = 8
        max_workers = min(PARALLEL_BATCH_SIZE, len(records))

        for i in range(0, len(records), PARALLEL_BATCH_SIZE):
            batch_records = records[i : i + PARALLEL_BATCH_SIZE]
            record_documents: list[Document | HierarchyNode] = []

            with ThreadPoolExecutor(max_workers=max_workers) as executor:
                # Submit batch tasks
                future_to_record: dict[Future[Document | None], RecordDict] = {}
                for record in batch_records:
                    # Capture the current context so that the thread gets the current tenant ID
                    current_context = contextvars.copy_context()
                    future_to_record[
                        executor.submit(
                            current_context.run,
                            self._process_record,
                            record=record,
                            table_schema=table_schema,
                            primary_field_name=primary_field_name,
                            base_id=base_id,
                            base_name=base_name,
                        )
                    ] = record

                # Wait for all tasks in this batch to complete
                for future in as_completed(future_to_record):
                    record = future_to_record[future]
                    try:
                        document = future.result()
                        if document:
                            record_documents.append(document)
                    except Exception as e:
                        logger.exception(f"Failed to process record {record['id']}")
                        raise e

            if record_documents:
                yield record_documents

    def load_from_state(self) -> GenerateDocumentsOutput:
        """
        Fetch all records from one or all tables.

        NOTE: Airtable does not support filtering by time updated, so
        we have to fetch all records every time.
        """
        if not self.airtable_client:
            raise AirtableClientNotSetUpError()

        if self.index_all:
            yield from self._load_all()
        else:
            yield from self._index_table(
                base_id=self.base_id,
                table_name_or_id=self.table_name_or_id,
            )

    def _load_all(self) -> GenerateDocumentsOutput:
        """Discover all bases and tables, then index everything."""
        bases = self.airtable_client.bases()
        logger.info(f"Discovered {len(bases)} Airtable base(s).")

        for base_info in bases:
            base_id = base_info.id
            base_name = base_info.name
            logger.info(f"Listing tables for base '{base_name}' ({base_id}).")

            try:
                base = self.airtable_client.base(base_id)
                tables = base.tables()
            except Exception:
                logger.exception(
                    f"Failed to list tables for base '{base_name}' ({base_id}), skipping."
                )
                continue

            logger.info(f"Found {len(tables)} table(s) in base '{base_name}'.")

            for table in tables:
                try:
                    yield from self._index_table(
                        base_id=base_id,
                        table_name_or_id=table.id,
                        base_name=base_name,
                    )
                except Exception:
                    logger.exception(
                        f"Failed to index table '{table.name}' ({table.id}) in base '{base_name}' ({base_id}), skipping."
                    )
                    continue


================================================
FILE: backend/onyx/connectors/asana/__init__.py
================================================


================================================
FILE: backend/onyx/connectors/asana/asana_api.py
================================================
import time
from collections.abc import Iterator
from datetime import datetime
from typing import Dict

import asana  # type: ignore

from onyx.utils.logger import setup_logger

logger = setup_logger()


# https://github.com/Asana/python-asana/tree/master?tab=readme-ov-file#documentation-for-api-endpoints
class AsanaTask:
    def __init__(
        self,
        id: str,
        title: str,
        text: str,
        link: str,
        last_modified: datetime,
        project_gid: str,
        project_name: str,
    ) -> None:
        self.id = id
        self.title = title
        self.text = text
        self.link = link
        self.last_modified = last_modified
        self.project_gid = project_gid
        self.project_name = project_name

    def __str__(self) -> str:
        return f"ID: {self.id}\nTitle: {self.title}\nLast modified: {self.last_modified}\nText: {self.text}"


class AsanaAPI:
    def __init__(
        self, api_token: str, workspace_gid: str, team_gid: str | None
    ) -> None:
        # Users fetched so far, keyed by user gid
        self._user_cache: dict[str, Dict] = {}
        self.workspace_gid = workspace_gid
        self.team_gid = team_gid

        self.configuration = asana.Configuration()
        self.api_client = asana.ApiClient(self.configuration)
        self.tasks_api = asana.TasksApi(self.api_client)
        self.stories_api = asana.StoriesApi(self.api_client)
        self.users_api = asana.UsersApi(self.api_client)
        self.project_api = asana.ProjectsApi(self.api_client)
        self.workspaces_api = asana.WorkspacesApi(self.api_client)
        self.api_error_count = 0
        self.configuration.access_token = api_token
        self.task_count = 0

    def get_tasks(
        self, project_gids: list[str] | None, start_date: str
    ) -> Iterator[AsanaTask]:
        """Get all tasks from the projects with the given gids that were modified since the given date.
        If project_gids is None, get all tasks from all projects in the workspace."""
        logger.info("Starting to fetch Asana projects")
        projects = self.project_api.get_projects(
            opts={
                "workspace": self.workspace_gid,
                "opt_fields": "gid,name,archived,modified_at",
            }
        )
        start_seconds = int(time.mktime(datetime.now().timetuple()))
        projects_list = []
        project_count = 0
        for project_info in projects:
            project_gid = project_info["gid"]
            if project_gids is None or project_gid in project_gids:
                projects_list.append(project_gid)
            else:
                logger.debug(
                    f"Skipping project: {project_gid} - not in accepted project_gids"
                )
            project_count += 1
            if project_count % 100 == 0:
                logger.info(f"Processed {project_count} projects")
        logger.info(f"Found {len(projects_list)} projects to process")
        for project_gid in projects_list:
            for task in self._get_tasks_for_project(
                project_gid, start_date, start_seconds
            ):
                yield task
        logger.info(f"Completed fetching {self.task_count} tasks from Asana")
        if self.api_error_count > 0:
            logger.warning(
                f"Encountered {self.api_error_count} API errors during task fetching"
            )

    def _get_tasks_for_project(
        self, project_gid: str, start_date: str, start_seconds: int
    ) -> Iterator[AsanaTask]:
        project = self.project_api.get_project(project_gid, opts={})
        project_name = project.get("name", project_gid)
        team = project.get("team") or {}
        team_gid = team.get("gid")

        if project.get("archived"):
            logger.info(f"Skipping archived project: {project_name} ({project_gid})")
            return
        if not team_gid:
            logger.info(
                f"Skipping project without a team: {project_name} ({project_gid})"
            )
            return
        if project.get("privacy_setting") == "private":
            if self.team_gid and team_gid != self.team_gid:
                logger.info(
                    f"Skipping private project not in configured team: {project_name} ({project_gid})"
                )
                return
            logger.info(
                f"Processing private project in configured team: {project_name} ({project_gid})"
            )

        simple_start_date = start_date.split(".")[0].split("+")[0]
        logger.info(
            f"Fetching tasks modified since {simple_start_date} for project: {project_name} ({project_gid})"
        )

        opts = {
            "opt_fields": "name,memberships,memberships.project,completed_at,completed_by,created_at,"
            "created_by,custom_fields,dependencies,due_at,due_on,external,html_notes,liked,likes,"
            "modified_at,notes,num_hearts,parent,projects,resource_subtype,resource_type,start_on,"
            "workspace,permalink_url",
            "modified_since": start_date,
        }
        tasks_from_api = self.tasks_api.get_tasks_for_project(project_gid, opts)
        for data in tasks_from_api:
            self.task_count += 1
            if self.task_count % 10 == 0:
                end_seconds = time.mktime(datetime.now().timetuple())
                runtime_seconds = end_seconds - start_seconds
                if runtime_seconds > 0:
                    logger.info(
                        f"Processed {self.task_count} tasks in {runtime_seconds:.0f} seconds "
                        f"({self.task_count / runtime_seconds:.2f} tasks/second)"
                    )

            logger.debug(f"Processing Asana task: {data['name']}")

            text = self._construct_task_text(data)

            try:
                text += self._fetch_and_add_comments(data["gid"])

                last_modified_date = self.format_date(data["modified_at"])
                text += f"Last modified: {last_modified_date}\n"

                task = AsanaTask(
                    id=data["gid"],
                    title=data["name"],
                    text=text,
                    link=data["permalink_url"],
                    last_modified=datetime.fromisoformat(data["modified_at"]),
                    project_gid=project_gid,
                    project_name=project_name,
                )
                yield task
            except Exception:
                logger.error(
                    f"Error processing task {data['gid']} in project {project_gid}",
                    exc_info=True,
                )
                self.api_error_count += 1

    def _construct_task_text(self, data: Dict) -> str:
        text = f"{data['name']}\n\n"

        if data["notes"]:
            text += f"{data['notes']}\n\n"

        if data["created_by"] and data["created_by"]["gid"]:
            creator = self.get_user(data["created_by"]["gid"])["name"]
            created_date = self.format_date(data["created_at"])
            text += f"Created by: {creator} on {created_date}\n"

        if data["due_on"]:
            due_date = self.format_date(data["due_on"])
            text += f"Due date: {due_date}\n"

        if data["completed_at"]:
            completed_date = self.format_date(data["completed_at"])
            text += f"Completed on: {completed_date}\n"

        text += "\n"
        return text

    def _fetch_and_add_comments(self, task_gid: str) -> str:
        text = ""
        stories_opts: Dict[str, str] = {}
        story_start = time.time()
        stories = self.stories_api.get_stories_for_task(task_gid, stories_opts)

        story_count = 0
        comment_count = 0

        for story in stories:
            story_count += 1
            if story["resource_subtype"] == "comment_added":
                comment = self.stories_api.get_story(
                    story["gid"], opts={"opt_fields": "text,created_by,created_at"}
                )
                commenter = self.get_user(comment["created_by"]["gid"])["name"]
                text += f"Comment by {commenter}: {comment['text']}\n\n"
                comment_count += 1

        story_duration = time.time() - story_start
        logger.debug(
            f"Processed {story_count} stories (including {comment_count} comments) in {story_duration:.2f} seconds"
        )

        return text

    def get_user(self, user_gid: str) -> Dict:
        # Cache users per gid so each creator/commenter resolves to their own name
        cached_user = self._user_cache.get(user_gid)
        if cached_user is not None:
            return cached_user

        user = self.users_api.get_user(user_gid, {"opt_fields": "name,email"})
        if not user:
            logger.warning(f"Unable to fetch user information for user_gid: {user_gid}")
            return {"name": "Unknown"}

        self._user_cache[user_gid] = user
        return user

    def format_date(self, date_str: str) -> str:
        date = datetime.fromisoformat(date_str)
        return time.strftime("%Y-%m-%d", date.timetuple())

    def get_time(self) -> str:
        return time.strftime("%Y-%m-%d %H:%M:%S", time.localtime())


================================================
FILE: backend/onyx/connectors/asana/connector.py
================================================
import datetime
from typing import Any

from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.asana import asana_api
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger

logger = setup_logger()


class AsanaConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        asana_workspace_id: str,
        asana_project_ids: str | None = None,
        asana_team_id: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
        continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE,
    ) -> None:
        self.workspace_id = asana_workspace_id.strip()
        if asana_project_ids:
            project_ids = [
                project_id.strip()
                for project_id in asana_project_ids.split(",")
                if project_id.strip()
            ]
            self.project_ids_to_index = project_ids or None
        else:
            self.project_ids_to_index = None
        self.asana_team_id = (asana_team_id.strip() or None) if asana_team_id else None
        self.batch_size = batch_size
        self.continue_on_failure = continue_on_failure
        logger.info(
            f"AsanaConnector initialized with workspace_id: {asana_workspace_id}"
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.api_token = credentials["asana_api_token_secret"]
        self.asana_client = asana_api.AsanaAPI(
            api_token=self.api_token,
            workspace_gid=self.workspace_id,
            team_gid=self.asana_team_id,
        )
        logger.info("Asana credentials loaded and API client initialized")
        return None

    def poll_source(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch | None,  # noqa: ARG002
    ) -> GenerateDocumentsOutput:
        start_time = datetime.datetime.fromtimestamp(start).isoformat()
        logger.info(f"Starting Asana poll from {start_time}")
        asana = asana_api.AsanaAPI(
            api_token=self.api_token,
            workspace_gid=self.workspace_id,
            team_gid=self.asana_team_id,
        )
        docs_batch: list[Document | HierarchyNode] = []
        tasks = asana.get_tasks(self.project_ids_to_index, start_time)

        for task in tasks:
            doc = self._message_to_doc(task)
            docs_batch.append(doc)

            if len(docs_batch) >= self.batch_size:
                logger.info(f"Yielding batch of {len(docs_batch)} documents")
                yield docs_batch
                docs_batch = []

        if docs_batch:
            logger.info(f"Yielding final batch of {len(docs_batch)} documents")
            yield docs_batch

        logger.info("Asana poll completed")

    def load_from_state(self) -> GenerateDocumentsOutput:
        logger.notice("Starting full index of all Asana tasks")
        return self.poll_source(start=0, end=None)

    def _message_to_doc(self, task: asana_api.AsanaTask) -> Document:
        logger.debug(f"Converting Asana task {task.id} to Document")
        return Document(
            id=task.id,
            sections=[TextSection(link=task.link, text=task.text)],
            doc_updated_at=task.last_modified,
            source=DocumentSource.ASANA,
            semantic_identifier=task.title,
            metadata={
                "group": task.project_gid,
                "project": task.project_name,
            },
        )


if __name__ == "__main__":
    import time
    import os

    logger.notice("Starting Asana connector test")
    connector = AsanaConnector(
        os.environ["WORKSPACE_ID"],
        os.environ["PROJECT_IDS"],
        os.environ["TEAM_ID"],
    )
    connector.load_credentials(
        {
            "asana_api_token_secret": os.environ["API_TOKEN"],
        }
    )
    logger.info("Loading all documents from Asana")
    all_docs = connector.load_from_state()
    current = time.time()
    one_day_ago = current - 24 * 60 * 60  # 1 day
    logger.info("Polling for documents updated in the last 24 hours")
    latest_docs = connector.poll_source(one_day_ago, current)
    for docs in latest_docs:
        for doc in docs:
            if isinstance(doc, HierarchyNode):
                print("hierarchynode:", doc.display_name)
            else:
                print(doc.id)
    logger.notice("Asana connector test completed")


================================================
FILE: backend/onyx/connectors/axero/__init__.py
================================================


================================================
FILE: backend/onyx/connectors/axero/connector.py
================================================
import time
from datetime import datetime
from datetime import timezone
from typing import Any

import requests
from pydantic import BaseModel

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    process_in_batches,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
    rate_limit_builder,
)
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder


logger = setup_logger()


ENTITY_NAME_MAP = {1: "Forum", 3: "Article", 4: "Blog", 9: "Wiki"}


def _get_auth_header(api_key: str) -> dict[str, str]:
    return {"Rest-Api-Key": api_key}


@retry_builder()
@rate_limit_builder(max_calls=5, period=1)
def _rate_limited_request(
    endpoint: str, headers: dict, params: dict | None = None
) -> Any:
    # https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/370/rest-api
    return requests.get(endpoint, headers=headers, params=params)


# https://my.axerosolutions.com/spaces/5/communifire-documentation/wiki/view/595/rest-api-get-content-list
def _get_entities(
    entity_type: int,
    api_key: str,
    axero_base_url: str,
    start: datetime,
    end: datetime,
    space_id: str | None = None,
) -> list[dict]:
    endpoint = axero_base_url + "api/content/list"
    page_num = 1
    pages_fetched = 0
    pages_to_return = []
    break_out = False
    while True:
        params = {
            "EntityType": str(entity_type),
            "SortColumn": "DateUpdated",
            "SortOrder": "1",  # descending
            "StartPage": str(page_num),
        }

        if space_id is not None:
            params["SpaceID"] = space_id

        res = _rate_limited_request(
            endpoint, headers=_get_auth_header(api_key), params=params
        )
        res.raise_for_status()

        # Axero limitations:
        # No next page token, can paginate but things may have changed
        # for example, a doc that hasn't been read in by Onyx is updated and is now at the
        # front of the list
        # due to this limitation and the fact that Axero has no rate limiting but API calls can cause
        # increased latency for the team, we have to just fetch all the pages quickly to reduce the
        # chance of missing a document due to an update (it will still get updated next pass)
        # Assumes the volume of data isn't too big to store in memory (probably fine)
        data = res.json()
        total_records = data["TotalRecords"]
        contents = data["ResponseData"]
        pages_fetched += len(contents)
        logger.debug(f"Fetched {pages_fetched} {ENTITY_NAME_MAP[entity_type]}")

        for page in contents:
            update_time = time_str_to_utc(page["DateUpdated"])

            if update_time > end:
                continue

            if update_time < start:
                break_out = True
                break

            pages_to_return.append(page)

        if pages_fetched >= total_records:
            break

        page_num += 1

        if break_out:
            break

    return pages_to_return


def _get_obj_by_id(obj_id: int, api_key: str, axero_base_url: str) -> dict:
    endpoint = axero_base_url + f"api/content/{obj_id}"
    res = _rate_limited_request(endpoint, headers=_get_auth_header(api_key))
    res.raise_for_status()

    return res.json()


class AxeroForum(BaseModel):
    doc_id: str
    title: str
    link: str
    initial_content: str
    responses: list[str]
    last_update: datetime


def _map_post_to_parent(
    posts: list[dict],
    api_key: str,
    axero_base_url: str,
) -> list[AxeroForum]:
    """Cannot be handled in batches since the posts aren't ordered or structured in any
    way; any number of them may need to be mapped to the same initial post."""
    epoch_str = "1970-01-01T00:00:00.000"
    post_map: dict[int, AxeroForum] = {}

    for ind, post in enumerate(posts):
        if (ind + 1) % 25 == 0:
            logger.debug(f"Processed {ind + 1} posts or responses")

        post_time = time_str_to_utc(
            post.get("DateUpdated") or post.get("DateCreated") or epoch_str
        )
        p_id = post.get("ParentContentID")
        if p_id in post_map:
            axero_forum = post_map[p_id]
            axero_forum.responses.insert(0, post.get("ContentSummary"))
            axero_forum.last_update = max(axero_forum.last_update, post_time)
        else:
            initial_post_d = _get_obj_by_id(p_id, api_key, axero_base_url)[
                "ResponseData"
            ]
            initial_post_time = time_str_to_utc(
                initial_post_d.get("DateUpdated")
                or initial_post_d.get("DateCreated")
                or epoch_str
            )
            post_map[p_id] = AxeroForum(
                doc_id="AXERO_" + str(initial_post_d.get("ContentID")),
                title=initial_post_d.get("ContentTitle"),
                link=initial_post_d.get("ContentURL"),
                initial_content=initial_post_d.get("ContentSummary"),
                responses=[post.get("ContentSummary")],
                last_update=max(post_time, initial_post_time),
            )

    return list(post_map.values())


def _get_forums(
    api_key: str,
    axero_base_url: str,
    space_id: str | None = None,
) -> list[dict]:
    endpoint = axero_base_url + "api/content/list"
    page_num = 1
    pages_fetched = 0
    pages_to_return = []

    while True:
        params = {
            "EntityType": "54",
            "SortColumn": "DateUpdated",
            "SortOrder": "1",  # descending
            "StartPage": str(page_num),
        }

        if space_id is not None:
            params["SpaceID"] = space_id

        res = _rate_limited_request(
            endpoint, headers=_get_auth_header(api_key), params=params
        )
        res.raise_for_status()

        data = res.json()
        total_records = data["TotalRecords"]
        contents = data["ResponseData"]
        pages_fetched += len(contents)
        logger.debug(f"Fetched {pages_fetched} forums")

        for page in contents:
            pages_to_return.append(page)

        if pages_fetched >= total_records:
            break

        page_num += 1

    return pages_to_return


def _translate_forum_to_doc(af: AxeroForum) -> Document:
    doc = Document(
        id=af.doc_id,
        sections=[TextSection(link=af.link, text=reply) for reply in af.responses],
        source=DocumentSource.AXERO,
        semantic_identifier=af.title,
        doc_updated_at=af.last_update,
        metadata={},
    )

    return doc


def _translate_content_to_doc(content: dict) -> Document:
    page_text = ""
    summary = content.get("ContentSummary")
    body = content.get("ContentBody")
    if summary:
        page_text += f"{summary}\n"

    if body:
        content_parsed = parse_html_page_basic(body)
        page_text += content_parsed

    doc = Document(
        id="AXERO_" + str(content["ContentID"]),
        sections=[TextSection(link=content["ContentURL"], text=page_text)],
        source=DocumentSource.AXERO,
        semantic_identifier=content["ContentTitle"],
        doc_updated_at=time_str_to_utc(content["DateUpdated"]),
        metadata={"space": content["SpaceName"]},
    )

    return doc


class AxeroConnector(PollConnector):
    def __init__(
        self,
        # Strings of the integer ids of the spaces
        spaces: list[str] | None = None,
        include_article: bool = True,
        include_blog: bool = True,
        include_wiki: bool = True,
        include_forum: bool = True,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.include_article = include_article
        self.include_blog = include_blog
        self.include_wiki = include_wiki
        self.include_forum = include_forum
        self.batch_size = batch_size
        self.space_ids = spaces
        self.axero_key = None
        self.base_url = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.axero_key = credentials["axero_api_token"]
        # As the API key specifically applies to a particular deployment, this is
        # included as part of the credential
        base_url = credentials["base_url"]
        if not base_url.endswith("/"):
            base_url += "/"
        self.base_url = base_url
        return None

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        if not self.axero_key or not self.base_url:
            raise ConnectorMissingCredentialError("Axero")

        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)

        entity_types = []
        if self.include_article:
            entity_types.append(3)
        if self.include_blog:
            entity_types.append(4)
        if self.include_wiki:
            entity_types.append(9)

        iterable_space_ids = self.space_ids if self.space_ids else [None]

        for space_id in iterable_space_ids:
            for entity in entity_types:
                axero_obj = _get_entities(
                    entity_type=entity,
                    api_key=self.axero_key,
                    axero_base_url=self.base_url,
                    start=start_datetime,
                    end=end_datetime,
                    space_id=space_id,
                )
                yield from process_in_batches(
                    objects=axero_obj,
                    process_function=_translate_content_to_doc,
                    batch_size=self.batch_size,
                )
            if self.include_forum:
                forums_posts = _get_forums(
                    api_key=self.axero_key,
                    axero_base_url=self.base_url,
                    space_id=space_id,
                )

                all_axero_forums = _map_post_to_parent(
                    posts=forums_posts,
                    api_key=self.axero_key,
                    axero_base_url=self.base_url,
                )

                filtered_forums = [
                    f
                    for f in all_axero_forums
                    if f.last_update >= start_datetime and f.last_update <= end_datetime
                ]

                yield from process_in_batches(
                    objects=filtered_forums,
                    process_function=_translate_forum_to_doc,
                    batch_size=self.batch_size,
                )


if __name__ == "__main__":
    import os

    connector = AxeroConnector()
    connector.load_credentials(
        {
            "axero_api_token": os.environ["AXERO_API_TOKEN"],
            "base_url": os.environ["AXERO_BASE_URL"],
        }
    )
    current = time.time()

    one_year_ago = current - 24 * 60 * 60 * 365
    latest_docs = connector.poll_source(one_year_ago, current)

    print(next(latest_docs))


================================================
FILE: backend/onyx/connectors/bitbucket/__init__.py
================================================


================================================
FILE: backend/onyx/connectors/bitbucket/connector.py
================================================
from __future__ import annotations

import copy
from collections.abc import Callable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TYPE_CHECKING

from typing_extensions import override

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS
from onyx.configs.constants import DocumentSource
from onyx.connectors.bitbucket.utils import build_auth_client
from onyx.connectors.bitbucket.utils import list_repositories
from onyx.connectors.bitbucket.utils import map_pr_to_document
from onyx.connectors.bitbucket.utils import paginate
from onyx.connectors.bitbucket.utils import PR_LIST_RESPONSE_FIELDS
from onyx.connectors.bitbucket.utils import SLIM_PR_LIST_RESPONSE_FIELDS
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

if TYPE_CHECKING:
    import httpx


logger = setup_logger()


class BitbucketConnectorCheckpoint(ConnectorCheckpoint):
    """Checkpoint state for resumable Bitbucket PR indexing.

    Fields:
        repos_queue: Materialized list of repository slugs to process.
        current_repo_index: Index of the repository currently being processed.
        next_url: Bitbucket "next" URL for continuing pagination within the current repo.
    """

    repos_queue: list[str] = []
    current_repo_index: int = 0
    next_url: str | None = None


class BitbucketConnector(
    CheckpointedConnector[BitbucketConnectorCheckpoint],
    SlimConnectorWithPermSync,
):
    """Connector for indexing Bitbucket Cloud pull requests.

    Args:
        workspace: Bitbucket workspace ID.
        repositories: Comma-separated list of repository slugs to index.
        projects: Comma-separated list of project keys to index all repositories within.
        batch_size: Max number of documents to yield per batch.
    """

    def __init__(
        self,
        workspace: str,
        repositories: str | None = None,
        projects: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.workspace = workspace
        self._repositories = (
            [s.strip() for s in repositories.split(",") if s.strip()]
            if repositories
            else None
        )
        self._projects: list[str] | None = (
            [s.strip() for s in projects.split(",") if s.strip()] if projects else None
        )
        self.batch_size = batch_size
        self.email: str | None = None
        self.api_token: str | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        """Load API token-based credentials.

        Expects a dict with keys: `bitbucket_email`, `bitbucket_api_token`.
        """
        self.email = credentials.get("bitbucket_email")
        self.api_token = credentials.get("bitbucket_api_token")
        if not self.email or not self.api_token:
            raise ConnectorMissingCredentialError("Bitbucket")
        return None

    def _client(self) -> httpx.Client:
        """Build an authenticated HTTP client, or raise if credentials are missing."""
        if not self.email or not self.api_token:
            raise ConnectorMissingCredentialError("Bitbucket")
        return build_auth_client(self.email, self.api_token)

    def _iter_pull_requests_for_repo(
        self,
        client: httpx.Client,
        repo_slug: str,
        params: dict[str, Any] | None = None,
        start_url: str | None = None,
        on_page: Callable[[str | None], None] | None = None,
    ) -> Iterator[dict[str, Any]]:
        base = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}/{repo_slug}/pullrequests"
        yield from paginate(
            client,
            base,
            params,
            start_url=start_url,
            on_page=on_page,
        )

    def _build_params(
        self,
        fields: str = PR_LIST_RESPONSE_FIELDS,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> dict[str, Any]:
        """Build Bitbucket fetch params.

        Always include OPEN, MERGED, and DECLINED PRs.
        If both ``start`` and ``end`` are provided, apply a
        single updated_on time window.
""" def _iso(ts: SecondsSinceUnixEpoch) -> str: return datetime.fromtimestamp(ts, tz=timezone.utc).isoformat() def _tc_epoch( lower_epoch: SecondsSinceUnixEpoch | None, upper_epoch: SecondsSinceUnixEpoch | None, ) -> str | None: if lower_epoch is not None and upper_epoch is not None: lower_iso = _iso(lower_epoch) upper_iso = _iso(upper_epoch) return f'(updated_on >= "{lower_iso}" AND updated_on <= "{upper_iso}")' return None params: dict[str, Any] = {"fields": fields, "pagelen": 50} time_clause = _tc_epoch(start, end) q = '(state = "OPEN" OR state = "MERGED" OR state = "DECLINED")' if time_clause: q = f"{q} AND {time_clause}" params["q"] = q return params def _iter_target_repositories(self, client: httpx.Client) -> Iterator[str]: """Yield repository slugs based on configuration. Priority: - repositories list - projects list (list repos by project key) - workspace (all repos) """ if self._repositories: for slug in self._repositories: yield slug return if self._projects: for project_key in self._projects: for repo in list_repositories(client, self.workspace, project_key): slug_val = repo.get("slug") if isinstance(slug_val, str) and slug_val: yield slug_val return for repo in list_repositories(client, self.workspace, None): slug_val = repo.get("slug") if isinstance(slug_val, str) and slug_val: yield slug_val @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: BitbucketConnectorCheckpoint, ) -> CheckpointOutput[BitbucketConnectorCheckpoint]: """Resumable PR ingestion across repos and pages within a time window. Yields Documents (or ConnectorFailure for per-PR mapping failures) and returns an updated checkpoint that records repo position and next page URL. 
""" new_checkpoint = copy.deepcopy(checkpoint) with self._client() as client: # Materialize target repositories once if not new_checkpoint.repos_queue: # Deduplicate and sort slugs so repository order is deterministic across runs repos_list = list(self._iter_target_repositories(client)) new_checkpoint.repos_queue = sorted(set(repos_list)) new_checkpoint.current_repo_index = 0 new_checkpoint.next_url = None repos = new_checkpoint.repos_queue if not repos or new_checkpoint.current_repo_index >= len(repos): new_checkpoint.has_more = False return new_checkpoint repo_slug = repos[new_checkpoint.current_repo_index] first_page_params = self._build_params( fields=PR_LIST_RESPONSE_FIELDS, start=start, end=end, ) def _on_page(next_url: str | None) -> None: new_checkpoint.next_url = next_url for pr in self._iter_pull_requests_for_repo( client, repo_slug, params=first_page_params, start_url=new_checkpoint.next_url, on_page=_on_page, ): try: document = map_pr_to_document(pr, self.workspace, repo_slug) yield document except Exception as e: pr_id = pr.get("id") pr_link = ( f"https://bitbucket.org/{self.workspace}/{repo_slug}/pull-requests/{pr_id}" if pr_id is not None else None ) yield ConnectorFailure( failed_document=DocumentFailure( document_id=( f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:{pr_id}" if pr_id is not None else f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{repo_slug}:pr:unknown" ), document_link=pr_link, ), failure_message=f"Failed to process Bitbucket PR: {e}", exception=e, ) # Advance to next repository (if any) and set has_more accordingly new_checkpoint.current_repo_index += 1 new_checkpoint.next_url = None new_checkpoint.has_more = new_checkpoint.current_repo_index < len(repos) return new_checkpoint @override def build_dummy_checkpoint(self) -> BitbucketConnectorCheckpoint: """Create an initial checkpoint with work remaining.""" return BitbucketConnectorCheckpoint(has_more=True) @override def validate_checkpoint_json( self, checkpoint_json:
str ) -> BitbucketConnectorCheckpoint: """Validate and deserialize a checkpoint instance from JSON.""" return BitbucketConnectorCheckpoint.model_validate_json(checkpoint_json) def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> Iterator[list[SlimDocument | HierarchyNode]]: """Return only document IDs for all existing pull requests.""" batch: list[SlimDocument | HierarchyNode] = [] params = self._build_params( fields=SLIM_PR_LIST_RESPONSE_FIELDS, start=start, end=end, ) with self._client() as client: for slug in self._iter_target_repositories(client): for pr in self._iter_pull_requests_for_repo( client, slug, params=params ): pr_id = pr["id"] doc_id = f"{DocumentSource.BITBUCKET.value}:{self.workspace}:{slug}:pr:{pr_id}" batch.append(SlimDocument(id=doc_id)) if len(batch) >= self.batch_size: yield batch batch = [] if callback: if callback.should_stop(): # Note: this is not actually used for permission sync yet, just pruning raise RuntimeError( "bitbucket_pr_sync: Stop signal detected" ) callback.progress("bitbucket_pr_sync", len(batch)) if batch: yield batch def validate_connector_settings(self) -> None: """Validate Bitbucket credentials and workspace access by probing a lightweight endpoint. Raises: CredentialExpiredError: on HTTP 401 InsufficientPermissionsError: on HTTP 403 UnexpectedValidationError: on any other failure """ try: with self._client() as client: url = f"https://api.bitbucket.org/2.0/repositories/{self.workspace}" resp = client.get( url, params={"pagelen": 1, "fields": "pagelen"}, timeout=REQUEST_TIMEOUT_SECONDS, ) if resp.status_code == 401: raise CredentialExpiredError( "Invalid or expired Bitbucket credentials (HTTP 401)." ) if resp.status_code == 403: raise InsufficientPermissionsError( "Insufficient permissions to access Bitbucket workspace (HTTP 403)." 
) if resp.status_code < 200 or resp.status_code >= 300: raise UnexpectedValidationError( f"Unexpected Bitbucket error (status={resp.status_code})." ) except Exception as e: # Network or other unexpected errors if isinstance( e, ( CredentialExpiredError, InsufficientPermissionsError, UnexpectedValidationError, ConnectorMissingCredentialError, ), ): raise raise UnexpectedValidationError( f"Unexpected error while validating Bitbucket settings: {e}" ) ================================================ FILE: backend/onyx/connectors/bitbucket/utils.py ================================================ from __future__ import annotations import time from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import Any import httpx from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder logger = setup_logger() # Fields requested from Bitbucket PR list endpoint to ensure rich PR data PR_LIST_RESPONSE_FIELDS: str = ",".join( [ "next", "page", "pagelen", "values.author", "values.close_source_branch", "values.closed_by", "values.comment_count", "values.created_on", "values.description", "values.destination", "values.draft", "values.id", "values.links", "values.merge_commit", "values.participants", "values.reason", "values.rendered", "values.reviewers", "values.source", "values.state", "values.summary", "values.task_count", "values.title", "values.type", "values.updated_on", ] ) # Minimal fields for slim retrieval (IDs only) SLIM_PR_LIST_RESPONSE_FIELDS: str = ",".join( [ 
"next", "page", "pagelen", "values.id", ] ) # Minimal fields for repository list calls REPO_LIST_RESPONSE_FIELDS: str = ",".join( [ "next", "page", "pagelen", "values.slug", "values.full_name", "values.project.key", ] ) class BitbucketRetriableError(Exception): """Raised for retriable Bitbucket conditions (429, 5xx).""" class BitbucketNonRetriableError(Exception): """Raised for non-retriable Bitbucket client errors (4xx except 429).""" @retry_builder( tries=6, delay=1, backoff=2, max_delay=30, exceptions=(BitbucketRetriableError, httpx.RequestError), ) @rate_limit_builder(max_calls=60, period=60) def bitbucket_get( client: httpx.Client, url: str, params: dict[str, Any] | None = None ) -> httpx.Response: """Perform a GET against Bitbucket with retry and rate limiting. Retries on 429 and 5xx responses, and on transport errors. Honors `Retry-After` header for 429 when present by sleeping before retrying. """ try: response = client.get(url, params=params, timeout=REQUEST_TIMEOUT_SECONDS) except httpx.RequestError: # Allow retry_builder to handle retries of transport errors raise try: response.raise_for_status() except httpx.HTTPStatusError as e: status = e.response.status_code if e.response is not None else None if status == 429: retry_after = e.response.headers.get("Retry-After") if e.response else None if retry_after is not None: try: time.sleep(int(retry_after)) except (TypeError, ValueError): pass raise BitbucketRetriableError("Bitbucket rate limit exceeded (429)") from e if status is not None and 500 <= status < 600: raise BitbucketRetriableError(f"Bitbucket server error: {status}") from e if status is not None and 400 <= status < 500: raise BitbucketNonRetriableError(f"Bitbucket client error: {status}") from e # Unknown status, propagate raise return response def build_auth_client(email: str, api_token: str) -> httpx.Client: """Create an authenticated httpx client for Bitbucket Cloud API.""" return httpx.Client(auth=(email, api_token), http2=True) def paginate( 
client: httpx.Client, url: str, params: dict[str, Any] | None = None, start_url: str | None = None, on_page: Callable[[str | None], None] | None = None, ) -> Iterator[dict[str, Any]]: """Iterate over paginated Bitbucket API responses yielding individual values. Args: client: Authenticated HTTP client. url: Base collection URL (first page when start_url is None). params: Query params for the first page. start_url: If provided, start from this absolute URL (ignores params). on_page: Optional callback invoked after each page with the next page URL. """ next_url = start_url or url # If resuming from a next URL, do not pass params again query = params.copy() if params else None query = None if start_url else query while next_url: resp = bitbucket_get(client, next_url, params=query) data = resp.json() values = data.get("values", []) for item in values: yield item next_url = data.get("next") if on_page is not None: on_page(next_url) # only include params on first call, next_url will contain all necessary params query = None def list_repositories( client: httpx.Client, workspace: str, project_key: str | None = None ) -> Iterator[dict[str, Any]]: """List repositories in a workspace, optionally filtered by project key.""" base_url = f"https://api.bitbucket.org/2.0/repositories/{workspace}" params: dict[str, Any] = { "fields": REPO_LIST_RESPONSE_FIELDS, "pagelen": 100, # Ensure deterministic ordering "sort": "full_name", } if project_key: params["q"] = f'project.key="{project_key}"' yield from paginate(client, base_url, params) def map_pr_to_document(pr: dict[str, Any], workspace: str, repo_slug: str) -> Document: """Map a Bitbucket pull request JSON to Onyx Document.""" pr_id = pr["id"] title = pr.get("title") or f"PR {pr_id}" description = pr.get("description") or "" state = pr.get("state") draft = pr.get("draft", False) author = pr.get("author", {}) reviewers = pr.get("reviewers", []) participants = pr.get("participants", []) link = pr.get("links", {}).get("html", 
{}).get("href") or ( f"https://bitbucket.org/{workspace}/{repo_slug}/pull-requests/{pr_id}" ) created_on = pr.get("created_on") updated_on = pr.get("updated_on") updated_dt = ( datetime.fromisoformat(updated_on.replace("Z", "+00:00")).astimezone( timezone.utc ) if isinstance(updated_on, str) else None ) source_branch = pr.get("source", {}).get("branch", {}).get("name", "") destination_branch = pr.get("destination", {}).get("branch", {}).get("name", "") approved_by = [ _get_user_name(p.get("user", {})) for p in participants if p.get("approved") ] primary_owner = None if author: primary_owner = BasicExpertInfo( display_name=_get_user_name(author), ) secondary_owners = [ BasicExpertInfo(display_name=_get_user_name(r)) for r in reviewers ] or None reviewer_names = [_get_user_name(r) for r in reviewers] # Create a concise summary of key PR info created_date = created_on.split("T")[0] if created_on else "N/A" updated_date = updated_on.split("T")[0] if updated_on else "N/A" content_text = ( "Pull Request Information:\n" f"- Pull Request ID: {pr_id}\n" f"- Title: {title}\n" f"- State: {state or 'N/A'} {'(Draft)' if draft else ''}\n" ) if state == "DECLINED": content_text += f"- Reason: {pr.get('reason', 'N/A')}\n" content_text += ( f"- Author: {_get_user_name(author) if author else 'N/A'}\n" f"- Reviewers: {', '.join(reviewer_names) if reviewer_names else 'N/A'}\n" f"- Branch: {source_branch} -> {destination_branch}\n" f"- Created: {created_date}\n" f"- Updated: {updated_date}" ) if description: content_text += f"\n\nDescription:\n{description}" sections: list[TextSection | ImageSection] = [ TextSection(link=link, text=content_text) ] metadata: dict[str, str | list[str]] = { "object_type": "PullRequest", "workspace": workspace, "repository": repo_slug, "pr_key": f"{workspace}/{repo_slug}#{pr_id}", "id": str(pr_id), "title": title, "state": state or "", "draft": str(bool(draft)), "link": link, "author": _get_user_name(author) if author else "", "reviewers": reviewer_names, 
"approved_by": approved_by, "comment_count": str(pr.get("comment_count", "")), "task_count": str(pr.get("task_count", "")), "created_on": created_on or "", "updated_on": updated_on or "", "source_branch": source_branch, "destination_branch": destination_branch, "closed_by": ( _get_user_name(pr.get("closed_by", {})) if pr.get("closed_by") else "" ), "close_source_branch": str(bool(pr.get("close_source_branch", False))), } return Document( id=f"{DocumentSource.BITBUCKET.value}:{workspace}:{repo_slug}:pr:{pr_id}", sections=sections, source=DocumentSource.BITBUCKET, semantic_identifier=f"#{pr_id}: {title}", title=title, doc_updated_at=updated_dt, primary_owners=[primary_owner] if primary_owner else None, secondary_owners=secondary_owners, metadata=metadata, ) def _get_user_name(user: dict[str, Any]) -> str: return user.get("display_name") or user.get("nickname") or "unknown" ================================================ FILE: backend/onyx/connectors/blob/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/blob/connector.py ================================================ import os import time from collections.abc import Mapping from datetime import datetime from datetime import timezone from io import BytesIO from numbers import Integral from typing import Any from typing import Optional from urllib.parse import quote import boto3 from botocore.client import Config from botocore.credentials import RefreshableCredentials from botocore.exceptions import ClientError from botocore.exceptions import NoCredentialsError from botocore.exceptions import PartialCredentialsError from botocore.session import get_session from mypy_boto3_s3 import S3Client from onyx.configs.app_configs import BLOB_STORAGE_SIZE_THRESHOLD from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import BlobType from onyx.configs.constants import DocumentSource from onyx.configs.constants 
import FileOrigin from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( process_onyx_metadata, ) from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_text_and_images from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.image_utils import store_image_and_create_section from onyx.utils.logger import setup_logger logger = setup_logger() DOWNLOAD_CHUNK_SIZE = 1024 * 1024 SIZE_THRESHOLD_BUFFER = 64 class BlobStorageConnector(LoadConnector, PollConnector): def __init__( self, bucket_type: str, bucket_name: str, prefix: str = "", batch_size: int = INDEX_BATCH_SIZE, european_residency: bool = False, ) -> None: self.bucket_type: BlobType = BlobType(bucket_type) self.bucket_name = bucket_name.strip() self.prefix = prefix if not prefix or prefix.endswith("/") else prefix + "/" self.batch_size = batch_size self.s3_client: Optional[S3Client] = None self._allow_images: bool | None = None self.size_threshold: int | None = BLOB_STORAGE_SIZE_THRESHOLD self.bucket_region: Optional[str] = None self.european_residency: bool = european_residency def set_allow_images(self, allow_images: bool) -> None: """Set whether to process images 
in this connector.""" logger.info(f"Setting allow_images to {allow_images}.") self._allow_images = allow_images def _detect_bucket_region(self) -> None: """Detect and cache the actual region of the S3 bucket using head_bucket.""" if self.s3_client is None: logger.warning( "S3 client not initialized. Skipping bucket region detection." ) return try: response = self.s3_client.head_bucket(Bucket=self.bucket_name) # The region is returned as 'BucketRegion' or in the 'x-amz-bucket-region' response header self.bucket_region = response.get("BucketRegion") or response.get( "ResponseMetadata", {} ).get("HTTPHeaders", {}).get("x-amz-bucket-region") if self.bucket_region: logger.debug(f"Detected bucket region: {self.bucket_region}") else: logger.warning("Bucket region not found in head_bucket response") except Exception as e: logger.warning(f"Failed to detect bucket region via head_bucket: {e}") def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Checks for boto3 credentials based on the bucket type. (1) R2: Access Key ID, Secret Access Key, Account ID (2) S3: AWS Access Key ID and AWS Secret Access Key, an IAM role ARN to assume, or the instance's own role (3) GOOGLE_CLOUD_STORAGE: Access Key ID, Secret Access Key (4) OCI_STORAGE: Namespace, Region, Access Key ID, Secret Access Key For each bucket type, the method initializes the appropriate S3 client: - R2: Uses Cloudflare R2 endpoint with S3v4 signature - S3: Creates a standard boto3 S3 client - GOOGLE_CLOUD_STORAGE: Uses Google Cloud Storage endpoint - OCI_STORAGE: Uses Oracle Cloud Infrastructure Object Storage endpoint Raises ConnectorMissingCredentialError if required credentials are missing. Raises ConnectorValidationError for an unknown S3 authentication method. Raises ValueError for unsupported bucket types.
""" logger.debug( f"Loading credentials for {self.bucket_name} of type {self.bucket_type}" ) if self.bucket_type == BlobType.R2: if not all( credentials.get(key) for key in ["r2_access_key_id", "r2_secret_access_key", "account_id"] ): raise ConnectorMissingCredentialError("Cloudflare R2") # Use EU endpoint if european_residency is enabled subdomain = "eu." if self.european_residency else "" endpoint_url = f"https://{credentials['account_id']}.{subdomain}r2.cloudflarestorage.com" self.s3_client = boto3.client( "s3", endpoint_url=endpoint_url, aws_access_key_id=credentials["r2_access_key_id"], aws_secret_access_key=credentials["r2_secret_access_key"], region_name="auto", config=Config(signature_version="s3v4"), ) elif self.bucket_type == BlobType.S3: # For S3, we can use access keys, an assumed IAM role, or the instance's own role. authentication_method = credentials.get( "authentication_method", "access_key" ) logger.debug( f"Using authentication method: {authentication_method} for S3 bucket." ) if authentication_method == "access_key": logger.debug("Using access key authentication for S3 bucket.") if not all( credentials.get(key) for key in ["aws_access_key_id", "aws_secret_access_key"] ): raise ConnectorMissingCredentialError("Amazon S3") session = boto3.Session( aws_access_key_id=credentials["aws_access_key_id"], aws_secret_access_key=credentials["aws_secret_access_key"], ) self.s3_client = session.client("s3") elif authentication_method == "iam_role": # If using an IAM role, we assume it via STS and let botocore refresh the temporary credentials. role_arn = credentials.get("aws_role_arn") if not role_arn: raise ConnectorMissingCredentialError( "Amazon S3 IAM role ARN is required for assuming role."
) def _refresh_credentials() -> dict[str, str]: """Refreshes the credentials for the assumed role.""" sts_client = boto3.client("sts") assumed_role_object = sts_client.assume_role( RoleArn=role_arn, RoleSessionName=f"onyx_blob_storage_{int(time.time())}", ) creds = assumed_role_object["Credentials"] return { "access_key": creds["AccessKeyId"], "secret_key": creds["SecretAccessKey"], "token": creds["SessionToken"], "expiry_time": creds["Expiration"].isoformat(), } refreshable = RefreshableCredentials.create_from_metadata( metadata=_refresh_credentials(), refresh_using=_refresh_credentials, method="sts-assume-role", ) botocore_session = get_session() botocore_session._credentials = refreshable # type: ignore[attr-defined] session = boto3.Session(botocore_session=botocore_session) self.s3_client = session.client("s3") elif authentication_method == "assume_role": # We will assume the instance role to access S3. logger.debug("Using instance role authentication for S3 bucket.") self.s3_client = boto3.client("s3") else: raise ConnectorValidationError("Invalid authentication method for S3. 
") # This is important for correct citation links # NOTE: the client region actually doesn't matter for accessing the bucket self._detect_bucket_region() elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: if not all( credentials.get(key) for key in ["access_key_id", "secret_access_key"] ): raise ConnectorMissingCredentialError("Google Cloud Storage") self.s3_client = boto3.client( "s3", endpoint_url="https://storage.googleapis.com", aws_access_key_id=credentials["access_key_id"], aws_secret_access_key=credentials["secret_access_key"], region_name="auto", ) elif self.bucket_type == BlobType.OCI_STORAGE: if not all( credentials.get(key) for key in ["namespace", "region", "access_key_id", "secret_access_key"] ): raise ConnectorMissingCredentialError("Oracle Cloud Infrastructure") self.s3_client = boto3.client( "s3", endpoint_url=f"https://{credentials['namespace']}.compat.objectstorage.{credentials['region']}.oraclecloud.com", aws_access_key_id=credentials["access_key_id"], aws_secret_access_key=credentials["secret_access_key"], region_name=credentials["region"], ) else: raise ValueError(f"Unsupported bucket type: {self.bucket_type}") return None def _download_object(self, key: str) -> bytes | None: if self.s3_client is None: raise ConnectorMissingCredentialError("Blob storage") response = self.s3_client.get_object(Bucket=self.bucket_name, Key=key) body = response["Body"] try: if self.size_threshold is None: return body.read() return self._read_stream_with_limit(body, key) finally: body.close() def _read_stream_with_limit(self, body: Any, key: str) -> bytes | None: if self.size_threshold is None: return body.read() bytes_read = 0 chunks: list[bytes] = [] chunk_size = min( DOWNLOAD_CHUNK_SIZE, self.size_threshold + SIZE_THRESHOLD_BUFFER ) for chunk in body.iter_chunks(chunk_size=chunk_size): if not chunk: continue chunks.append(chunk) bytes_read += len(chunk) if bytes_read > self.size_threshold + SIZE_THRESHOLD_BUFFER: logger.warning( f"{key} exceeds size 
threshold of {self.size_threshold}. Skipping." ) return None return b"".join(chunks) # NOTE: Left in as it may be useful for one-off access to documents and sharing across orgs. # def _get_presigned_url(self, key: str) -> str: # if self.s3_client is None: # raise ConnectorMissingCredentialError("Blob storage") # url = self.s3_client.generate_presigned_url( # "get_object", # Params={"Bucket": self.bucket_name, "Key": key}, # ExpiresIn=self.presign_length, # ) # return url def _get_blob_link(self, key: str) -> str: # NOTE: We store the object dashboard URL instead of the actual object URL # This is because the actual object URL requires S3 client authentication # Accessing it through the browser returns an unauthorized error if self.s3_client is None: raise ConnectorMissingCredentialError("Blob storage") # URL encode the key to handle special characters, spaces, etc. # safe='/' keeps forward slashes unencoded for proper path structure encoded_key = quote(key, safe="/") if self.bucket_type == BlobType.R2: account_id = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0] subdomain = "eu/" if self.european_residency else "default/" return f"https://dash.cloudflare.com/{account_id}/r2/{subdomain}buckets/{self.bucket_name}/objects/{encoded_key}/details" elif self.bucket_type == BlobType.S3: region = self.bucket_region or self.s3_client.meta.region_name return f"https://s3.console.aws.amazon.com/s3/object/{self.bucket_name}?region={region}&prefix={encoded_key}" elif self.bucket_type == BlobType.GOOGLE_CLOUD_STORAGE: return f"https://console.cloud.google.com/storage/browser/_details/{self.bucket_name}/{encoded_key}" elif self.bucket_type == BlobType.OCI_STORAGE: namespace = self.s3_client.meta.endpoint_url.split("//")[1].split(".")[0] region = self.s3_client.meta.region_name return f"https://objectstorage.{region}.oraclecloud.com/n/{namespace}/b/{self.bucket_name}/o/{encoded_key}" else: # This should never happen!
raise ValueError(f"Unsupported bucket type: {self.bucket_type}") @staticmethod def _extract_size_bytes(obj: Mapping[str, Any]) -> int | None: """Return the first numeric size field found in the object metadata.""" candidate_keys = ( "Size", "size", "ContentLength", "content_length", "Content-Length", "contentLength", "bytes", "Bytes", ) def _normalize(value: Any) -> int | None: if value is None or isinstance(value, bool): return None if isinstance(value, Integral): return int(value) try: numeric = float(value) except (TypeError, ValueError): return None if numeric >= 0 and numeric.is_integer(): return int(numeric) return None for key in candidate_keys: if key in obj: normalized = _normalize(obj.get(key)) if normalized is not None: return normalized for key, value in obj.items(): if not isinstance(key, str): continue lowered_key = key.lower() if "size" in lowered_key or "length" in lowered_key: normalized = _normalize(value) if normalized is not None: return normalized return None def _yield_blob_objects( self, start: datetime, end: datetime, ) -> GenerateDocumentsOutput: if self.s3_client is None: raise ConnectorMissingCredentialError("Blob storage") paginator = self.s3_client.get_paginator("list_objects_v2") pages = paginator.paginate(Bucket=self.bucket_name, Prefix=self.prefix) batch: list[Document | HierarchyNode] = [] for page in pages: if "Contents" not in page: continue for obj in page["Contents"]: if obj["Key"].endswith("/"): continue last_modified = obj["LastModified"].replace(tzinfo=timezone.utc) if not start <= last_modified <= end: continue file_name = os.path.basename(obj["Key"]) file_ext = get_file_ext(file_name) key = obj["Key"] link = self._get_blob_link(key) size_bytes = self._extract_size_bytes(obj) if ( self.size_threshold is not None and isinstance(size_bytes, int) and size_bytes > self.size_threshold ): logger.warning( f"{file_name} exceeds size threshold of {self.size_threshold}. Skipping."
) continue # Handle image files if file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS: if not self._allow_images: logger.debug( f"Skipping image file: {key} (image processing not enabled)" ) continue # Process the image file try: downloaded_file = self._download_object(key) if downloaded_file is None: continue # TODO: Refactor to avoid direct DB access in connector # This will require broader refactoring across the codebase image_section, _ = store_image_and_create_section( image_data=downloaded_file, file_id=f"{self.bucket_type}_{self.bucket_name}_{key.replace('/', '_')}", display_name=file_name, link=link, file_origin=FileOrigin.CONNECTOR, ) batch.append( Document( id=f"{self.bucket_type}:{self.bucket_name}:{key}", sections=[image_section], source=DocumentSource(self.bucket_type.value), semantic_identifier=file_name, doc_updated_at=last_modified, metadata={}, ) ) if len(batch) == self.batch_size: yield batch batch = [] except Exception: logger.exception(f"Error processing image {key}") continue # Handle text and document files try: downloaded_file = self._download_object(key) if downloaded_file is None: continue extraction_result = extract_text_and_images( BytesIO(downloaded_file), file_name=file_name ) onyx_metadata, custom_tags = process_onyx_metadata( extraction_result.metadata ) file_display_name = onyx_metadata.file_display_name or file_name time_updated = onyx_metadata.doc_updated_at or last_modified link = onyx_metadata.link or link primary_owners = onyx_metadata.primary_owners secondary_owners = onyx_metadata.secondary_owners source_type = onyx_metadata.source_type or DocumentSource( self.bucket_type.value ) sections: list[TextSection | ImageSection] = [] if extraction_result.text_content.strip(): logger.debug( f"Creating TextSection for {file_name} with link: {link}" ) sections.append( TextSection( link=link, text=extraction_result.text_content.strip(), ) ) batch.append( Document( id=f"{self.bucket_type}:{self.bucket_name}:{key}", sections=( sections if 
sections else [TextSection(link=link, text="")] ), source=source_type, semantic_identifier=file_display_name, doc_updated_at=time_updated, metadata=custom_tags, primary_owners=primary_owners, secondary_owners=secondary_owners, ) ) if len(batch) == self.batch_size: yield batch batch = [] except Exception: logger.exception(f"Error processing object {key}") if batch: yield batch def load_from_state(self) -> GenerateDocumentsOutput: logger.debug("Loading blob objects") return self._yield_blob_objects( start=datetime(1970, 1, 1, tzinfo=timezone.utc), end=datetime.now(timezone.utc), ) def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: if self.s3_client is None: raise ConnectorMissingCredentialError("Blob storage") start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) for batch in self._yield_blob_objects(start_datetime, end_datetime): yield batch return None def validate_connector_settings(self) -> None: if self.s3_client is None: raise ConnectorMissingCredentialError( "Blob storage credentials not loaded." ) if not self.bucket_name: raise ConnectorValidationError( "No bucket name was provided in connector settings." ) try: # We only fetch one object/page as a lightweight validation step. # This ensures we trigger typical S3 permission checks (ListObjectsV2, etc.). self.s3_client.list_objects_v2( Bucket=self.bucket_name, Prefix=self.prefix, MaxKeys=1 ) except NoCredentialsError: raise ConnectorMissingCredentialError( "No valid blob storage credentials found or provided to boto3." ) except PartialCredentialsError: raise ConnectorMissingCredentialError( "Partial or incomplete blob storage credentials provided to boto3."
) except ClientError as e: error_code = e.response["Error"].get("Code", "") status_code = e.response["ResponseMetadata"].get("HTTPStatusCode") # Most common S3 error cases if error_code in [ "AccessDenied", "InvalidAccessKeyId", "SignatureDoesNotMatch", ]: if status_code == 403 or error_code == "AccessDenied": raise InsufficientPermissionsError( f"Insufficient permissions to list objects in bucket '{self.bucket_name}'. " "Please check your bucket policy and/or IAM policy." ) if status_code == 401 or error_code == "SignatureDoesNotMatch": raise CredentialExpiredError( "Provided blob storage credentials appear invalid or expired." ) raise CredentialExpiredError( f"Credential issue encountered ({error_code})." ) if error_code == "NoSuchBucket" or status_code == 404: raise ConnectorValidationError( f"Bucket '{self.bucket_name}' does not exist or cannot be found." ) raise ConnectorValidationError( f"Unexpected S3 client error (code={error_code}, status={status_code}): {e}" ) except Exception as e: # Catch-all for anything not captured by the above # Since we are unsure of the error and it may not disable the connector, # raise an unexpected error (does not disable connector) raise UnexpectedValidationError( f"Unexpected error during blob storage settings validation: {e}" ) if __name__ == "__main__": credentials_dict = { "aws_access_key_id": os.environ.get("AWS_ACCESS_KEY_ID"), "aws_secret_access_key": os.environ.get("AWS_SECRET_ACCESS_KEY"), } # Initialize the connector connector = BlobStorageConnector( bucket_type=os.environ.get("BUCKET_TYPE") or "s3", bucket_name=os.environ.get("BUCKET_NAME") or "test", prefix="", ) try: connector.load_credentials(credentials_dict) document_batch_generator = connector.load_from_state() for document_batch in document_batch_generator: print("First batch of documents:") for doc in document_batch: if isinstance(doc, HierarchyNode): print("hierarchynode:", doc.display_name) continue print(f"Document ID: {doc.id}") print(f"Semantic 
Identifier: {doc.semantic_identifier}") print(f"Source: {doc.source}") print(f"Updated At: {doc.doc_updated_at}") print("Sections:") for section in doc.sections: print(f" - Link: {section.link}") if isinstance(section, TextSection) and section.text is not None: print(f" - Text: {section.text[:100]}...") elif hasattr(section, "image_file_id") and section.image_file_id: print(f" - Image: {section.image_file_id}") else: print("Error: Unknown section type") print("---") break except ConnectorMissingCredentialError as e: print(f"Error: {e}") except Exception as e: print(f"An unexpected error occurred: {e}") ================================================ FILE: backend/onyx/connectors/bookstack/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/bookstack/client.py ================================================ from typing import Any import requests class BookStackClientRequestFailedError(ConnectionError): def __init__(self, status: int, error: str) -> None: self.status_code = status self.error = error super().__init__( "BookStack Client request failed with status {status}: {error}".format( status=status, error=error ) ) class BookStackApiClient: def __init__( self, base_url: str, token_id: str, token_secret: str, ) -> None: self.base_url = base_url self.token_id = token_id self.token_secret = token_secret def get(self, endpoint: str, params: dict[str, str]) -> dict[str, Any]: url: str = self._build_url(endpoint) headers = self._build_headers() response = requests.get(url, headers=headers, params=params) try: json = response.json() except Exception: json = {} if response.status_code >= 300: error = response.reason response_error = json.get("error", {}).get("message", "") if response_error: error = response_error raise BookStackClientRequestFailedError(response.status_code, error) return json def _build_headers(self) -> dict[str, str]: auth = "Token " + self.token_id + ":" + 
self.token_secret return { "Authorization": auth, "Accept": "application/json", } def _build_url(self, endpoint: str) -> str: return self.base_url.rstrip("/") + "/api/" + endpoint.lstrip("/") def build_app_url(self, endpoint: str) -> str: return self.base_url.rstrip("/") + "/" + endpoint.lstrip("/") ================================================ FILE: backend/onyx/connectors/bookstack/connector.py ================================================ import html import time from collections.abc import Callable from datetime import datetime from typing import Any from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.bookstack.client import BookStackApiClient from onyx.connectors.bookstack.client import BookStackClientRequestFailedError from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import parse_html_page_basic class BookstackConnector(LoadConnector, PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.batch_size = batch_size self.bookstack_client: BookStackApiClient | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.bookstack_client = BookStackApiClient( base_url=credentials["bookstack_base_url"], 
token_id=credentials["bookstack_api_token_id"], token_secret=credentials["bookstack_api_token_secret"], ) return None @staticmethod def _get_doc_batch( batch_size: int, bookstack_client: BookStackApiClient, endpoint: str, transformer: Callable[[BookStackApiClient, dict], Document], start_ind: int, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> tuple[list[Document | HierarchyNode], int]: params = { "count": str(batch_size), "offset": str(start_ind), "sort": "+id", } if start: params["filter[updated_at:gte]"] = datetime.utcfromtimestamp( start ).strftime("%Y-%m-%d") if end: params["filter[updated_at:lte]"] = datetime.utcfromtimestamp(end).strftime( "%Y-%m-%d" ) batch = bookstack_client.get(endpoint, params=params).get("data", []) doc_batch: list[Document | HierarchyNode] = [ transformer(bookstack_client, item) for item in batch ] return doc_batch, len(batch) @staticmethod def _book_to_document( bookstack_client: BookStackApiClient, book: dict[str, Any] ) -> Document: url = bookstack_client.build_app_url("/books/" + str(book.get("slug"))) title = str(book.get("name", "")) text = book.get("name", "") + "\n" + book.get("description", "") updated_at_str = ( str(book.get("updated_at")) if book.get("updated_at") is not None else None ) return Document( id="book__" + str(book.get("id")), sections=[TextSection(link=url, text=text)], source=DocumentSource.BOOKSTACK, semantic_identifier="Book: " + title, title=title, doc_updated_at=( time_str_to_utc(updated_at_str) if updated_at_str is not None else None ), metadata={"type": "book"}, ) @staticmethod def _chapter_to_document( bookstack_client: BookStackApiClient, chapter: dict[str, Any] ) -> Document: url = bookstack_client.build_app_url( "/books/" + str(chapter.get("book_slug")) + "/chapter/" + str(chapter.get("slug")) ) title = str(chapter.get("name", "")) text = chapter.get("name", "") + "\n" + chapter.get("description", "") updated_at_str = ( str(chapter.get("updated_at")) if 
chapter.get("updated_at") is not None else None ) return Document( id="chapter__" + str(chapter.get("id")), sections=[TextSection(link=url, text=text)], source=DocumentSource.BOOKSTACK, semantic_identifier="Chapter: " + title, title=title, doc_updated_at=( time_str_to_utc(updated_at_str) if updated_at_str is not None else None ), metadata={"type": "chapter"}, ) @staticmethod def _shelf_to_document( bookstack_client: BookStackApiClient, shelf: dict[str, Any] ) -> Document: url = bookstack_client.build_app_url("/shelves/" + str(shelf.get("slug"))) title = str(shelf.get("name", "")) text = shelf.get("name", "") + "\n" + shelf.get("description", "") updated_at_str = ( str(shelf.get("updated_at")) if shelf.get("updated_at") is not None else None ) return Document( id="shelf:" + str(shelf.get("id")), sections=[TextSection(link=url, text=text)], source=DocumentSource.BOOKSTACK, semantic_identifier="Shelf: " + title, title=title, doc_updated_at=( time_str_to_utc(updated_at_str) if updated_at_str is not None else None ), metadata={"type": "shelf"}, ) @staticmethod def _page_to_document( bookstack_client: BookStackApiClient, page: dict[str, Any] ) -> Document: page_id = str(page.get("id")) title = str(page.get("name", "")) page_data = bookstack_client.get("/pages/" + page_id, {}) url = bookstack_client.build_app_url( "/books/" + str(page.get("book_slug")) + "/page/" + str(page_data.get("slug")) ) page_html = "
<h1>" + html.escape(title) + "</h1>
" + str(page_data.get("html")) text = parse_html_page_basic(page_html) updated_at_str = ( str(page_data.get("updated_at")) if page_data.get("updated_at") is not None else None ) time.sleep(0.1) return Document( id="page:" + page_id, sections=[TextSection(link=url, text=text)], source=DocumentSource.BOOKSTACK, semantic_identifier="Page: " + str(title), title=str(title), doc_updated_at=( time_str_to_utc(updated_at_str) if updated_at_str is not None else None ), metadata={"type": "page"}, ) def load_from_state(self) -> GenerateDocumentsOutput: if self.bookstack_client is None: raise ConnectorMissingCredentialError("Bookstack") return self.poll_source(None, None) def poll_source( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> GenerateDocumentsOutput: if self.bookstack_client is None: raise ConnectorMissingCredentialError("Bookstack") transform_by_endpoint: dict[ str, Callable[[BookStackApiClient, dict], Document] ] = { "/books": self._book_to_document, "/chapters": self._chapter_to_document, "/shelves": self._shelf_to_document, "/pages": self._page_to_document, } for endpoint, transform in transform_by_endpoint.items(): start_ind = 0 while True: doc_batch, num_results = self._get_doc_batch( batch_size=self.batch_size, bookstack_client=self.bookstack_client, endpoint=endpoint, transformer=transform, start_ind=start_ind, start=start, end=end, ) start_ind += num_results if doc_batch: yield doc_batch if num_results < self.batch_size: break else: time.sleep(0.2) def validate_connector_settings(self) -> None: """ Validate that the BookStack credentials and connector settings are correct. Specifically checks that we can make an authenticated request to BookStack. """ if not self.bookstack_client: raise ConnectorMissingCredentialError( "BookStack credentials have not been loaded." 
) try: # Attempt to fetch a small batch of books (arbitrary endpoint) to verify credentials _ = self.bookstack_client.get( "/books", params={"count": "1", "offset": "0"} ) except BookStackClientRequestFailedError as e: # Check for HTTP status codes if e.status_code == 401: raise CredentialExpiredError( "Your BookStack credentials appear to be invalid or expired (HTTP 401)." ) from e elif e.status_code == 403: raise InsufficientPermissionsError( "The configured BookStack token does not have sufficient permissions (HTTP 403)." ) from e else: raise ConnectorValidationError( f"Unexpected BookStack error (status={e.status_code}): {e}" ) from e except Exception as exc: raise ConnectorValidationError( f"Unexpected error while validating BookStack connector settings: {exc}" ) from exc ================================================ FILE: backend/onyx/connectors/canvas/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/canvas/access.py ================================================ """ Permissioning / AccessControl logic for Canvas courses. CE stub — returns None (no permissions). The EE implementation is loaded at runtime via ``fetch_versioned_implementation``. 
""" from collections.abc import Callable from typing import cast from onyx.access.models import ExternalAccess from onyx.connectors.canvas.client import CanvasApiClient from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import global_version def get_course_permissions( canvas_client: CanvasApiClient, course_id: int, ) -> ExternalAccess | None: if not global_version.is_ee_version(): return None ee_get_course_permissions = cast( Callable[[CanvasApiClient, int], ExternalAccess | None], fetch_versioned_implementation( "onyx.external_permissions.canvas.access", "get_course_permissions", ), ) return ee_get_course_permissions(canvas_client, course_id) ================================================ FILE: backend/onyx/connectors/canvas/client.py ================================================ from __future__ import annotations import logging import re from collections.abc import Iterator from typing import Any from urllib.parse import urlparse from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rl_requests, ) from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError logger = logging.getLogger(__name__) # Requests timeout in seconds. _CANVAS_CALL_TIMEOUT: int = 30 _CANVAS_API_VERSION: str = "/api/v1" # Matches the "next" URL in a Canvas Link header, e.g.: # ; rel="next" # Captures the URL inside the angle brackets. _NEXT_LINK_PATTERN: re.Pattern[str] = re.compile(r'<([^>]+)>;\s*rel="next"') _STATUS_TO_ERROR_CODE: dict[int, OnyxErrorCode] = { 401: OnyxErrorCode.CREDENTIAL_EXPIRED, 403: OnyxErrorCode.INSUFFICIENT_PERMISSIONS, 404: OnyxErrorCode.BAD_GATEWAY, 429: OnyxErrorCode.RATE_LIMITED, } def _error_code_for_status(status_code: int) -> OnyxErrorCode: """Map an HTTP status code to the appropriate OnyxErrorCode. Expects a >= 400 status code. 
Known codes (401, 403, 404, 429) are mapped to specific error codes; all other codes (unrecognised 4xx and 5xx) map to BAD_GATEWAY as unexpected upstream errors. """ if status_code in _STATUS_TO_ERROR_CODE: return _STATUS_TO_ERROR_CODE[status_code] return OnyxErrorCode.BAD_GATEWAY class CanvasApiClient: def __init__( self, bearer_token: str, canvas_base_url: str, ) -> None: parsed_base = urlparse(canvas_base_url) if not parsed_base.hostname: raise ValueError("canvas_base_url must include a valid host") if parsed_base.scheme != "https": raise ValueError("canvas_base_url must use https") self._bearer_token = bearer_token self.base_url = ( canvas_base_url.rstrip("/").removesuffix(_CANVAS_API_VERSION) + _CANVAS_API_VERSION ) # Hostname is already validated above; reuse parsed_base instead # of re-parsing. Used by _parse_next_link to validate pagination URLs. self._expected_host: str = parsed_base.hostname def get( self, endpoint: str = "", params: dict[str, Any] | None = None, full_url: str | None = None, ) -> tuple[Any, str | None]: """Make a GET request to the Canvas API. Returns a tuple of (json_body, next_url). next_url is parsed from the Link header and is None if there are no more pages. If full_url is provided, it is used directly (for following pagination links). Security note: full_url must only be set to values returned by ``_parse_next_link``, which validates the host against the configured Canvas base URL. Passing an arbitrary URL would leak the bearer token. """ # full_url is used when following pagination (Canvas returns the # next-page URL in the Link header). For the first request we build # the URL from the endpoint name instead. 
url = full_url if full_url else self._build_url(endpoint) headers = self._build_headers() response = rl_requests.get( url, headers=headers, params=params if not full_url else None, timeout=_CANVAS_CALL_TIMEOUT, ) try: response_json = response.json() except ValueError as e: if response.status_code < 300: raise OnyxError( OnyxErrorCode.BAD_GATEWAY, detail=f"Invalid JSON in Canvas response: {e}", ) logger.warning( "Failed to parse JSON from Canvas error response (status=%d): %s", response.status_code, e, ) response_json = {} if response.status_code >= 400: # Try to extract the most specific error message from the # Canvas response body. Canvas uses three different shapes # depending on the endpoint and error type: default_error: str = response.reason or f"HTTP {response.status_code}" error = default_error if isinstance(response_json, dict): # Shape 1: {"error": {"message": "Not authorized"}} error_field = response_json.get("error") if isinstance(error_field, dict): response_error = error_field.get("message", "") if response_error: error = response_error # Shape 2: {"error": "Invalid access token"} elif isinstance(error_field, str): error = error_field # Shape 3: {"errors": [{"message": "..."}]} # Used for validation errors. Only use as fallback if # we didn't already find a more specific message above. if error == default_error: errors_list = response_json.get("errors") if isinstance(errors_list, list) and errors_list: first_error = errors_list[0] if isinstance(first_error, dict): msg = first_error.get("message", "") if msg: error = msg raise OnyxError( _error_code_for_status(response.status_code), detail=error, status_code_override=response.status_code, ) next_url = self._parse_next_link(response.headers.get("Link", "")) return response_json, next_url def _parse_next_link(self, link_header: str) -> str | None: """Extract the 'next' URL from a Canvas Link header. 
Only returns URLs whose host matches the configured Canvas base URL to prevent leaking the bearer token to arbitrary hosts. """ expected_host = self._expected_host for match in _NEXT_LINK_PATTERN.finditer(link_header): url = match.group(1) parsed_url = urlparse(url) if parsed_url.hostname != expected_host: raise OnyxError( OnyxErrorCode.BAD_GATEWAY, detail=( "Canvas pagination returned an unexpected host " f"({parsed_url.hostname}); expected {expected_host}" ), ) if parsed_url.scheme != "https": raise OnyxError( OnyxErrorCode.BAD_GATEWAY, detail=( "Canvas pagination link must use https, " f"got {parsed_url.scheme!r}" ), ) return url return None def _build_headers(self) -> dict[str, str]: """Return the Authorization header with the bearer token.""" return {"Authorization": f"Bearer {self._bearer_token}"} def _build_url(self, endpoint: str) -> str: """Build a full Canvas API URL from an endpoint path. Assumes endpoint is non-empty (e.g. ``"courses"``, ``"announcements"``). This is only called for the first request of a sequence, where an endpoint is required; the emptiness check guards against future changes that might make endpoint optional. Leading slashes are stripped to avoid a double slash in the result; self.base_url is already normalized with no trailing slash. """ final_url = self.base_url clean_endpoint = endpoint.lstrip("/") if clean_endpoint: final_url += "/" + clean_endpoint return final_url def paginate( self, endpoint: str, params: dict[str, Any] | None = None, ) -> Iterator[list[Any]]: """Yield each page of results, following Link-header pagination. Makes the first request with endpoint + params, then follows next_url from Link headers for subsequent pages. 
""" response, next_url = self.get(endpoint, params=params) while True: if not response: break yield response if not next_url: break response, next_url = self.get(full_url=next_url) ================================================ FILE: backend/onyx/connectors/canvas/connector.py ================================================ from datetime import datetime from datetime import timezone from typing import Any from typing import cast from typing import Literal from typing import NoReturn from typing import TypeAlias from pydantic import BaseModel from retry import retry from typing_extensions import override from onyx.access.models import ExternalAccess from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.canvas.access import get_course_permissions from onyx.connectors.canvas.client import CanvasApiClient from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.error_handling.exceptions import OnyxError from onyx.file_processing.html_utils import parse_html_page_basic from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() def 
_handle_canvas_api_error(e: OnyxError) -> NoReturn: """Map Canvas API errors to connector framework exceptions.""" if e.status_code == 401: raise CredentialExpiredError( "Canvas API token is invalid or expired (HTTP 401)." ) elif e.status_code == 403: raise InsufficientPermissionsError( "Canvas API token does not have sufficient permissions (HTTP 403)." ) elif e.status_code == 429: raise ConnectorValidationError( "Canvas rate-limit exceeded (HTTP 429). Please try again later." ) elif e.status_code >= 500: raise UnexpectedValidationError( f"Unexpected Canvas HTTP error (status={e.status_code}): {e}" ) else: raise ConnectorValidationError( f"Canvas API error (status={e.status_code}): {e}" ) class CanvasCourse(BaseModel): id: int name: str | None = None course_code: str | None = None created_at: str | None = None workflow_state: str | None = None @classmethod def from_api(cls, payload: dict[str, Any]) -> "CanvasCourse": return cls( id=payload["id"], name=payload.get("name"), course_code=payload.get("course_code"), created_at=payload.get("created_at"), workflow_state=payload.get("workflow_state"), ) class CanvasPage(BaseModel): page_id: int url: str title: str body: str | None = None created_at: str | None = None updated_at: str | None = None course_id: int @classmethod def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasPage": return cls( page_id=payload["page_id"], url=payload["url"], title=payload["title"], body=payload.get("body"), created_at=payload.get("created_at"), updated_at=payload.get("updated_at"), course_id=course_id, ) class CanvasAssignment(BaseModel): id: int name: str description: str | None = None html_url: str course_id: int created_at: str | None = None updated_at: str | None = None due_at: str | None = None @classmethod def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAssignment": return cls( id=payload["id"], name=payload["name"], description=payload.get("description"), html_url=payload["html_url"], 
course_id=course_id, created_at=payload.get("created_at"), updated_at=payload.get("updated_at"), due_at=payload.get("due_at"), ) class CanvasAnnouncement(BaseModel): id: int title: str message: str | None = None html_url: str posted_at: str | None = None course_id: int @classmethod def from_api(cls, payload: dict[str, Any], course_id: int) -> "CanvasAnnouncement": return cls( id=payload["id"], title=payload["title"], message=payload.get("message"), html_url=payload["html_url"], posted_at=payload.get("posted_at"), course_id=course_id, ) CanvasStage: TypeAlias = Literal["pages", "assignments", "announcements"] class CanvasConnectorCheckpoint(ConnectorCheckpoint): """Checkpoint state for resumable Canvas indexing. Fields: course_ids: Materialized list of course IDs to process. current_course_index: Index into course_ids for current course. stage: Which item type we're processing for the current course. next_url: Pagination cursor within the current stage. None means start from the first page; a URL means resume from that page. Invariant: If current_course_index is incremented, stage must be reset to "pages" and next_url must be reset to None. 
""" course_ids: list[int] = [] current_course_index: int = 0 stage: CanvasStage = "pages" next_url: str | None = None def advance_course(self) -> None: """Move to the next course and reset within-course state.""" self.current_course_index += 1 self.stage = "pages" self.next_url = None class CanvasConnector( CheckpointedConnectorWithPermSync[CanvasConnectorCheckpoint], SlimConnectorWithPermSync, ): def __init__( self, canvas_base_url: str, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.canvas_base_url = canvas_base_url.rstrip("/").removesuffix("/api/v1") self.batch_size = batch_size self._canvas_client: CanvasApiClient | None = None self._course_permissions_cache: dict[int, ExternalAccess | None] = {} @property def canvas_client(self) -> CanvasApiClient: if self._canvas_client is None: raise ConnectorMissingCredentialError("Canvas") return self._canvas_client def _get_course_permissions(self, course_id: int) -> ExternalAccess | None: """Get course permissions with caching.""" if course_id not in self._course_permissions_cache: self._course_permissions_cache[course_id] = get_course_permissions( canvas_client=self.canvas_client, course_id=course_id, ) return self._course_permissions_cache[course_id] @retry(tries=3, delay=1, backoff=2) def _list_courses(self) -> list[CanvasCourse]: """Fetch all courses accessible to the authenticated user.""" logger.debug("Fetching Canvas courses") courses: list[CanvasCourse] = [] for page in self.canvas_client.paginate( "courses", params={"per_page": "100", "state[]": "available"} ): courses.extend(CanvasCourse.from_api(c) for c in page) return courses @retry(tries=3, delay=1, backoff=2) def _list_pages(self, course_id: int) -> list[CanvasPage]: """Fetch all pages for a given course.""" logger.debug(f"Fetching pages for course {course_id}") pages: list[CanvasPage] = [] for page in self.canvas_client.paginate( f"courses/{course_id}/pages", params={"per_page": "100", "include[]": "body", "published": "true"}, ): 
pages.extend(CanvasPage.from_api(p, course_id=course_id) for p in page) return pages @retry(tries=3, delay=1, backoff=2) def _list_assignments(self, course_id: int) -> list[CanvasAssignment]: """Fetch all assignments for a given course.""" logger.debug(f"Fetching assignments for course {course_id}") assignments: list[CanvasAssignment] = [] for page in self.canvas_client.paginate( f"courses/{course_id}/assignments", params={"per_page": "100", "published": "true"}, ): assignments.extend( CanvasAssignment.from_api(a, course_id=course_id) for a in page ) return assignments @retry(tries=3, delay=1, backoff=2) def _list_announcements(self, course_id: int) -> list[CanvasAnnouncement]: """Fetch all announcements for a given course.""" logger.debug(f"Fetching announcements for course {course_id}") announcements: list[CanvasAnnouncement] = [] for page in self.canvas_client.paginate( "announcements", params={ "per_page": "100", "context_codes[]": f"course_{course_id}", "active_only": "true", }, ): announcements.extend( CanvasAnnouncement.from_api(a, course_id=course_id) for a in page ) return announcements def _build_document( self, doc_id: str, link: str, text: str, semantic_identifier: str, doc_updated_at: datetime | None, course_id: int, doc_type: str, ) -> Document: """Build a Document with standard Canvas fields.""" return Document( id=doc_id, sections=cast( list[TextSection | ImageSection], [TextSection(link=link, text=text)], ), source=DocumentSource.CANVAS, semantic_identifier=semantic_identifier, doc_updated_at=doc_updated_at, metadata={"course_id": str(course_id), "type": doc_type}, ) def _convert_page_to_document(self, page: CanvasPage) -> Document: """Convert a Canvas page to a Document.""" link = f"{self.canvas_base_url}/courses/{page.course_id}/pages/{page.url}" text_parts = [page.title] body_text = parse_html_page_basic(page.body) if page.body else "" if body_text: text_parts.append(body_text) doc_updated_at = ( 
datetime.fromisoformat(page.updated_at.replace("Z", "+00:00")).astimezone( timezone.utc ) if page.updated_at else None ) document = self._build_document( doc_id=f"canvas-page-{page.course_id}-{page.page_id}", link=link, text="\n\n".join(text_parts), semantic_identifier=page.title or f"Page {page.page_id}", doc_updated_at=doc_updated_at, course_id=page.course_id, doc_type="page", ) return document def _convert_assignment_to_document(self, assignment: CanvasAssignment) -> Document: """Convert a Canvas assignment to a Document.""" text_parts = [assignment.name] desc_text = ( parse_html_page_basic(assignment.description) if assignment.description else "" ) if desc_text: text_parts.append(desc_text) if assignment.due_at: due_dt = datetime.fromisoformat( assignment.due_at.replace("Z", "+00:00") ).astimezone(timezone.utc) text_parts.append(f"Due: {due_dt.strftime('%B %d, %Y %H:%M UTC')}") doc_updated_at = ( datetime.fromisoformat( assignment.updated_at.replace("Z", "+00:00") ).astimezone(timezone.utc) if assignment.updated_at else None ) document = self._build_document( doc_id=f"canvas-assignment-{assignment.course_id}-{assignment.id}", link=assignment.html_url, text="\n\n".join(text_parts), semantic_identifier=assignment.name or f"Assignment {assignment.id}", doc_updated_at=doc_updated_at, course_id=assignment.course_id, doc_type="assignment", ) return document def _convert_announcement_to_document( self, announcement: CanvasAnnouncement ) -> Document: """Convert a Canvas announcement to a Document.""" text_parts = [announcement.title] msg_text = ( parse_html_page_basic(announcement.message) if announcement.message else "" ) if msg_text: text_parts.append(msg_text) doc_updated_at = ( datetime.fromisoformat( announcement.posted_at.replace("Z", "+00:00") ).astimezone(timezone.utc) if announcement.posted_at else None ) document = self._build_document( doc_id=f"canvas-announcement-{announcement.course_id}-{announcement.id}", link=announcement.html_url, 
text="\n\n".join(text_parts), semantic_identifier=announcement.title or f"Announcement {announcement.id}", doc_updated_at=doc_updated_at, course_id=announcement.course_id, doc_type="announcement", ) return document @override def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load and validate Canvas credentials.""" access_token = credentials.get("canvas_access_token") if not access_token: raise ConnectorMissingCredentialError("Canvas") try: client = CanvasApiClient( bearer_token=access_token, canvas_base_url=self.canvas_base_url, ) client.get("courses", params={"per_page": "1"}) except ValueError as e: raise ConnectorValidationError(f"Invalid Canvas base URL: {e}") except OnyxError as e: _handle_canvas_api_error(e) self._canvas_client = client return None @override def validate_connector_settings(self) -> None: """Validate Canvas connector settings by testing API access.""" try: self.canvas_client.get("courses", params={"per_page": "1"}) logger.info("Canvas connector settings validated successfully") except OnyxError as e: _handle_canvas_api_error(e) except ConnectorMissingCredentialError: raise except Exception as exc: raise UnexpectedValidationError( f"Unexpected error during Canvas settings validation: {exc}" ) @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: CanvasConnectorCheckpoint, ) -> CheckpointOutput[CanvasConnectorCheckpoint]: # TODO(benwu408): implemented in PR3 (checkpoint) raise NotImplementedError @override def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: CanvasConnectorCheckpoint, ) -> CheckpointOutput[CanvasConnectorCheckpoint]: # TODO(benwu408): implemented in PR3 (checkpoint) raise NotImplementedError @override def build_dummy_checkpoint(self) -> CanvasConnectorCheckpoint: # TODO(benwu408): implemented in PR3 (checkpoint) raise NotImplementedError @override def 
validate_checkpoint_json( self, checkpoint_json: str ) -> CanvasConnectorCheckpoint: # TODO(benwu408): implemented in PR3 (checkpoint) raise NotImplementedError @override def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: # TODO(benwu408): implemented in PR4 (perm sync) raise NotImplementedError ================================================ FILE: backend/onyx/connectors/clickup/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/clickup/connector.py ================================================ from datetime import datetime from datetime import timezone from typing import Any from typing import Optional import requests from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.utils.retry_wrapper import retry_builder CLICKUP_API_BASE_URL = "https://api.clickup.com/api/v2" class ClickupConnector(LoadConnector, PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, api_token: str | None = None, team_id: str | None = None, connector_type: str | None = None, connector_ids: list[str] | None = None, retrieve_task_comments: bool = True, ) -> None: self.batch_size = batch_size 
self.api_token = api_token self.team_id = team_id self.connector_type = connector_type if connector_type else "workspace" self.connector_ids = connector_ids self.retrieve_task_comments = retrieve_task_comments def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.api_token = credentials["clickup_api_token"] self.team_id = credentials["clickup_team_id"] return None @retry_builder() @rate_limit_builder(max_calls=100, period=60) def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any: if not self.api_token: raise ConnectorMissingCredentialError("Clickup") headers = {"Authorization": self.api_token} response = requests.get( f"{CLICKUP_API_BASE_URL}/{endpoint}", headers=headers, params=params ) response.raise_for_status() return response.json() def _get_task_comments(self, task_id: str) -> list[TextSection]: url_endpoint = f"/task/{task_id}/comment" response = self._make_request(url_endpoint) comments = [ TextSection( link=f"https://app.clickup.com/t/{task_id}?comment={comment_dict['id']}", text=comment_dict["comment_text"], ) for comment_dict in response["comments"] ] return comments def _get_all_tasks_filtered( self, start: int | None = None, end: int | None = None, ) -> GenerateDocumentsOutput: doc_batch: list[Document | HierarchyNode] = [] page: int = 0 params = { "include_markdown_description": "true", "include_closed": "true", "page": page, } if start is not None: params["date_updated_gt"] = start if end is not None: params["date_updated_lt"] = end if self.connector_type == "list": params["list_ids[]"] = self.connector_ids elif self.connector_type == "folder": params["project_ids[]"] = self.connector_ids elif self.connector_type == "space": params["space_ids[]"] = self.connector_ids url_endpoint = f"/team/{self.team_id}/task" while True: response = self._make_request(url_endpoint, params) page += 1 params["page"] = page for task in response["tasks"]: document = Document( id=task["id"], 
source=DocumentSource.CLICKUP, semantic_identifier=task["name"], doc_updated_at=( datetime.fromtimestamp( round(float(task["date_updated"]) / 1000, 3), tz=timezone.utc ) ), primary_owners=[ BasicExpertInfo( display_name=task["creator"]["username"], email=task["creator"]["email"], ) ], secondary_owners=[ BasicExpertInfo( display_name=assignee["username"], email=assignee["email"], ) for assignee in task["assignees"] ], title=task["name"], sections=[ TextSection( link=task["url"], text=( task["markdown_description"] if "markdown_description" in task else task["description"] ), ) ], metadata={ "id": task["id"], "status": task["status"]["status"], "list": task["list"]["name"], "project": task["project"]["name"], "folder": task["folder"]["name"], "space_id": task["space"]["id"], "tags": [tag["name"] for tag in task["tags"]], "priority": ( task["priority"]["priority"] if "priority" in task and task["priority"] is not None else "" ), }, ) extra_fields = [ "date_created", "date_updated", "date_closed", "date_done", "due_date", ] for extra_field in extra_fields: if extra_field in task and task[extra_field] is not None: document.metadata[extra_field] = task[extra_field] if self.retrieve_task_comments: document.sections.extend(self._get_task_comments(task["id"])) doc_batch.append(document) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if response.get("last_page") is True or len(response["tasks"]) < 100: break if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: if self.api_token is None: raise ConnectorMissingCredentialError("Clickup") return self._get_all_tasks_filtered(None, None) def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: if self.api_token is None: raise ConnectorMissingCredentialError("Clickup") return self._get_all_tasks_filtered(int(start * 1000), int(end * 1000)) if __name__ == "__main__": import os clickup_connector = ClickupConnector()
clickup_connector.load_credentials( { "clickup_api_token": os.environ["clickup_api_token"], "clickup_team_id": os.environ["clickup_team_id"], } ) latest_docs = clickup_connector.load_from_state() for doc in latest_docs: print(doc) ================================================ FILE: backend/onyx/connectors/coda/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/coda/connector.py ================================================ import os from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any from typing import cast from typing import Dict from typing import List from typing import Optional from pydantic import BaseModel from retry import retry from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rl_requests, ) from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.batching import batch_generator from onyx.utils.logger import setup_logger _CODA_CALL_TIMEOUT = 30 _CODA_BASE_URL = "https://coda.io/apis/v1" logger = setup_logger() class CodaClientRequestFailedError(ConnectionError): def __init__(self, message: str, status_code: int): super().__init__( f"Coda API request failed with status {status_code}: {message}" ) 
self.status_code = status_code class CodaDoc(BaseModel): id: str browser_link: str name: str created_at: str updated_at: str workspace_id: str workspace_name: str folder_id: str | None folder_name: str | None class CodaPage(BaseModel): id: str browser_link: str name: str content_type: str created_at: str updated_at: str doc_id: str class CodaTable(BaseModel): id: str name: str browser_link: str created_at: str updated_at: str doc_id: str class CodaRow(BaseModel): id: str name: Optional[str] = None index: Optional[int] = None browser_link: str created_at: str updated_at: str values: Dict[str, Any] table_id: str doc_id: str class CodaApiClient: def __init__( self, bearer_token: str, ) -> None: self.bearer_token = bearer_token self.base_url = os.environ.get("CODA_BASE_URL", _CODA_BASE_URL) def get( self, endpoint: str, params: Optional[dict[str, str]] = None ) -> dict[str, Any]: url = self._build_url(endpoint) headers = self._build_headers() response = rl_requests.get( url, headers=headers, params=params, timeout=_CODA_CALL_TIMEOUT ) try: json = response.json() except Exception: json = {} if response.status_code >= 300: error = response.reason response_error = json.get("error", {}).get("message", "") if response_error: error = response_error raise CodaClientRequestFailedError(error, response.status_code) return json def _build_headers(self) -> Dict[str, str]: return {"Authorization": f"Bearer {self.bearer_token}"} def _build_url(self, endpoint: str) -> str: return self.base_url.rstrip("/") + "/" + endpoint.lstrip("/") class CodaConnector(LoadConnector, PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, index_page_content: bool = True, workspace_id: str | None = None, ) -> None: self.batch_size = batch_size self.index_page_content = index_page_content self.workspace_id = workspace_id self._coda_client: CodaApiClient | None = None @property def coda_client(self) -> CodaApiClient: if self._coda_client is None: raise 
ConnectorMissingCredentialError("Coda") return self._coda_client @retry(tries=3, delay=1, backoff=2) def _get_doc(self, doc_id: str) -> CodaDoc: """Fetch a specific Coda document by its ID.""" logger.debug(f"Fetching Coda doc with ID: {doc_id}") try: response = self.coda_client.get(f"docs/{doc_id}") except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError(f"Failed to fetch doc: {doc_id}") from e else: raise return CodaDoc( id=response["id"], browser_link=response["browserLink"], name=response["name"], created_at=response["createdAt"], updated_at=response["updatedAt"], workspace_id=response["workspace"]["id"], workspace_name=response["workspace"]["name"], folder_id=response["folder"]["id"] if response.get("folder") else None, folder_name=response["folder"]["name"] if response.get("folder") else None, ) @retry(tries=3, delay=1, backoff=2) def _get_page(self, doc_id: str, page_id: str) -> CodaPage: """Fetch a specific page from a Coda document.""" logger.debug(f"Fetching Coda page with ID: {page_id}") try: response = self.coda_client.get(f"docs/{doc_id}/pages/{page_id}") except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError( f"Failed to fetch page: {page_id} from doc: {doc_id}" ) from e else: raise return CodaPage( id=response["id"], doc_id=doc_id, browser_link=response["browserLink"], name=response["name"], content_type=response["contentType"], created_at=response["createdAt"], updated_at=response["updatedAt"], ) @retry(tries=3, delay=1, backoff=2) def _get_table(self, doc_id: str, table_id: str) -> CodaTable: """Fetch a specific table from a Coda document.""" logger.debug(f"Fetching Coda table with ID: {table_id}") try: response = self.coda_client.get(f"docs/{doc_id}/tables/{table_id}") except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError( f"Failed to fetch table: {table_id} from doc: {doc_id}" ) from e else: raise return CodaTable( 
id=response["id"], name=response["name"], browser_link=response["browserLink"], created_at=response["createdAt"], updated_at=response["updatedAt"], doc_id=doc_id, ) @retry(tries=3, delay=1, backoff=2) def _get_row(self, doc_id: str, table_id: str, row_id: str) -> CodaRow: """Fetch a specific row from a Coda table.""" logger.debug(f"Fetching Coda row with ID: {row_id}") try: response = self.coda_client.get( f"docs/{doc_id}/tables/{table_id}/rows/{row_id}" ) except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError( f"Failed to fetch row: {row_id} from table: {table_id} in doc: {doc_id}" ) from e else: raise values = {} for col_name, col_value in response.get("values", {}).items(): values[col_name] = col_value return CodaRow( id=response["id"], name=response.get("name"), index=response.get("index"), browser_link=response["browserLink"], created_at=response["createdAt"], updated_at=response["updatedAt"], values=values, table_id=table_id, doc_id=doc_id, ) @retry(tries=3, delay=1, backoff=2) def _list_all_docs( self, endpoint: str = "docs", params: Optional[Dict[str, str]] = None ) -> List[CodaDoc]: """List all Coda documents in the workspace.""" logger.debug("Listing documents in Coda") all_docs: List[CodaDoc] = [] next_page_token: str | None = None params = params or {} if self.workspace_id: params["workspaceId"] = self.workspace_id while True: if next_page_token: params["pageToken"] = next_page_token try: response = self.coda_client.get(endpoint, params=params) except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError("Failed to list docs") from e else: raise items = response.get("items", []) for item in items: doc = CodaDoc( id=item["id"], browser_link=item["browserLink"], name=item["name"], created_at=item["createdAt"], updated_at=item["updatedAt"], workspace_id=item["workspace"]["id"], workspace_name=item["workspace"]["name"], folder_id=item["folder"]["id"] if item.get("folder") else 
None, folder_name=item["folder"]["name"] if item.get("folder") else None, ) all_docs.append(doc) next_page_token = response.get("nextPageToken") if not next_page_token: break logger.debug(f"Found {len(all_docs)} docs") return all_docs @retry(tries=3, delay=1, backoff=2) def _list_pages_in_doc(self, doc_id: str) -> List[CodaPage]: """List all pages in a Coda document.""" logger.debug(f"Listing pages in Coda doc with ID: {doc_id}") pages: List[CodaPage] = [] endpoint = f"docs/{doc_id}/pages" params: Dict[str, str] = {} next_page_token: str | None = None while True: if next_page_token: params["pageToken"] = next_page_token try: response = self.coda_client.get(endpoint, params=params) except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError( f"Failed to list pages for doc: {doc_id}" ) from e else: raise items = response.get("items", []) for item in items: # can be removed if we don't care to skip hidden pages if item.get("isHidden", False): continue pages.append( CodaPage( id=item["id"], browser_link=item["browserLink"], name=item["name"], content_type=item["contentType"], created_at=item["createdAt"], updated_at=item["updatedAt"], doc_id=doc_id, ) ) next_page_token = response.get("nextPageToken") if not next_page_token: break logger.debug(f"Found {len(pages)} pages in doc {doc_id}") return pages @retry(tries=3, delay=1, backoff=2) def _fetch_page_content(self, doc_id: str, page_id: str) -> str: """Fetch the content of a Coda page.""" logger.debug(f"Fetching content for page {page_id} in doc {doc_id}") content_parts = [] next_page_token: str | None = None params: Dict[str, str] = {} while True: if next_page_token: params["pageToken"] = next_page_token try: response = self.coda_client.get( f"docs/{doc_id}/pages/{page_id}/content", params=params ) except CodaClientRequestFailedError as e: if e.status_code == 404: logger.debug(f"No content available for page {page_id}") return "" raise items = response.get("items", []) for item in 
items: item_content = item.get("itemContent", {}) content_text = item_content.get("content", "") if content_text: content_parts.append(content_text) next_page_token = response.get("nextPageToken") if not next_page_token: break return "\n\n".join(content_parts) @retry(tries=3, delay=1, backoff=2) def _list_tables(self, doc_id: str) -> List[CodaTable]: """List all tables in a Coda document.""" logger.debug(f"Listing tables in Coda doc with ID: {doc_id}") tables: List[CodaTable] = [] endpoint = f"docs/{doc_id}/tables" params: Dict[str, str] = {} next_page_token: str | None = None while True: if next_page_token: params["pageToken"] = next_page_token try: response = self.coda_client.get(endpoint, params=params) except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError( f"Failed to list tables for doc: {doc_id}" ) from e else: raise items = response.get("items", []) for item in items: tables.append( CodaTable( id=item["id"], browser_link=item["browserLink"], name=item["name"], created_at=item["createdAt"], updated_at=item["updatedAt"], doc_id=doc_id, ) ) next_page_token = response.get("nextPageToken") if not next_page_token: break logger.debug(f"Found {len(tables)} tables in doc {doc_id}") return tables @retry(tries=3, delay=1, backoff=2) def _list_rows_and_values(self, doc_id: str, table_id: str) -> List[CodaRow]: """List all rows and their values in a table.""" logger.debug(f"Listing rows in Coda table: {table_id} in Coda doc: {doc_id}") rows: List[CodaRow] = [] endpoint = f"docs/{doc_id}/tables/{table_id}/rows" params: Dict[str, str] = {"valueFormat": "rich"} next_page_token: str | None = None while True: if next_page_token: params["pageToken"] = next_page_token try: response = self.coda_client.get(endpoint, params=params) except CodaClientRequestFailedError as e: if e.status_code == 404: raise ConnectorValidationError( f"Failed to list rows for table: {table_id} in doc: {doc_id}" ) from e else: raise items = 
response.get("items", []) for item in items: values = {} for col_name, col_value in item.get("values", {}).items(): values[col_name] = col_value rows.append( CodaRow( id=item["id"], name=item["name"], index=item["index"], browser_link=item["browserLink"], created_at=item["createdAt"], updated_at=item["updatedAt"], values=values, table_id=table_id, doc_id=doc_id, ) ) next_page_token = response.get("nextPageToken") if not next_page_token: break logger.debug(f"Found {len(rows)} rows in table {table_id}") return rows def _convert_page_to_document(self, page: CodaPage, content: str = "") -> Document: """Convert a page into a Document.""" page_updated = datetime.fromisoformat(page.updated_at).astimezone(timezone.utc) text_parts = [page.name, page.browser_link] if content: text_parts.append(content) sections = [TextSection(link=page.browser_link, text="\n\n".join(text_parts))] return Document( id=f"coda-page-{page.doc_id}-{page.id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.CODA, semantic_identifier=page.name or f"Page {page.id}", doc_updated_at=page_updated, metadata={ "browser_link": page.browser_link, "doc_id": page.doc_id, "content_type": page.content_type, }, ) def _convert_table_with_rows_to_document( self, table: CodaTable, rows: List[CodaRow] ) -> Document: """Convert a table and its rows into a single Document with multiple sections (one per row).""" table_updated = datetime.fromisoformat(table.updated_at).astimezone( timezone.utc ) sections: List[TextSection] = [] for row in rows: content_text = " ".join( str(v) if not isinstance(v, list) else " ".join(map(str, v)) for v in row.values.values() ) row_name = row.name or f"Row {row.index or row.id}" text = f"{row_name}: {content_text}" if content_text else row_name sections.append(TextSection(link=row.browser_link, text=text)) # If no rows, create a single section for the table itself if not sections: sections = [ TextSection(link=table.browser_link, text=f"Table: 
{table.name}") ] return Document( id=f"coda-table-{table.doc_id}-{table.id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.CODA, semantic_identifier=table.name or f"Table {table.id}", doc_updated_at=table_updated, metadata={ "browser_link": table.browser_link, "doc_id": table.doc_id, "row_count": str(len(rows)), }, ) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Load and validate Coda credentials.""" self._coda_client = CodaApiClient(bearer_token=credentials["coda_bearer_token"]) try: self._coda_client.get("docs", params={"limit": "1"}) except CodaClientRequestFailedError as e: if e.status_code == 401: raise ConnectorMissingCredentialError("Invalid Coda API token") raise return None def load_from_state(self) -> GenerateDocumentsOutput: """Load all documents from Coda workspace.""" def _iter_documents() -> Generator[Document, None, None]: docs = self._list_all_docs() logger.info(f"Found {len(docs)} Coda docs to process") for doc in docs: logger.debug(f"Processing doc: {doc.name} ({doc.id})") try: pages = self._list_pages_in_doc(doc.id) for page in pages: content = "" if self.index_page_content: try: content = self._fetch_page_content(doc.id, page.id) except Exception as e: logger.warning( f"Failed to fetch content for page {page.id}: {e}" ) yield self._convert_page_to_document(page, content) except ConnectorValidationError as e: logger.warning(f"Failed to list pages for doc {doc.id}: {e}") try: tables = self._list_tables(doc.id) for table in tables: try: rows = self._list_rows_and_values(doc.id, table.id) yield self._convert_table_with_rows_to_document(table, rows) except ConnectorValidationError as e: logger.warning( f"Failed to list rows for table {table.id}: {e}" ) yield self._convert_table_with_rows_to_document(table, []) except ConnectorValidationError as e: logger.warning(f"Failed to list tables for doc {doc.id}: {e}") return batch_generator(_iter_documents(), self.batch_size) def 
poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: """ Polls the Coda API for pages and tables updated between the start and end timestamps. A page is re-indexed when its own update time falls in the window; a table is re-indexed when the table itself or any of its rows was updated in the window. """ def _iter_documents() -> Generator[Document, None, None]: docs = self._list_all_docs() logger.info( f"Polling {len(docs)} Coda docs for updates between {start} and {end}" ) for doc in docs: try: pages = self._list_pages_in_doc(doc.id) for page in pages: page_timestamp = ( datetime.fromisoformat(page.updated_at) .astimezone(timezone.utc) .timestamp() ) if start < page_timestamp <= end: content = "" if self.index_page_content: try: content = self._fetch_page_content(doc.id, page.id) except Exception as e: logger.warning( f"Failed to fetch content for page {page.id}: {e}" ) yield self._convert_page_to_document(page, content) except ConnectorValidationError as e: logger.warning(f"Failed to list pages for doc {doc.id}: {e}") try: tables = self._list_tables(doc.id) for table in tables: table_timestamp = ( datetime.fromisoformat(table.updated_at) .astimezone(timezone.utc) .timestamp() ) try: rows = self._list_rows_and_values(doc.id, table.id) table_or_rows_updated = start < table_timestamp <= end if not table_or_rows_updated: for row in rows: row_timestamp = ( datetime.fromisoformat(row.updated_at) .astimezone(timezone.utc) .timestamp() ) if start < row_timestamp <= end: table_or_rows_updated = True break if table_or_rows_updated: yield self._convert_table_with_rows_to_document( table, rows ) except ConnectorValidationError as e: logger.warning( f"Failed to list rows for table {table.id}: {e}" ) if table_timestamp > start and table_timestamp <= end: yield self._convert_table_with_rows_to_document( table, [] ) except ConnectorValidationError as e: logger.warning(f"Failed to list tables for doc {doc.id}: {e}") return batch_generator(_iter_documents(), self.batch_size) def validate_connector_settings(self)
-> None: """Validates the Coda connector settings by calling the 'whoami' endpoint.""" try: response = self.coda_client.get("whoami") logger.info( f"Coda connector validated for user: {response.get('name', 'Unknown')}" ) if self.workspace_id: params = {"workspaceId": self.workspace_id, "limit": "1"} self.coda_client.get("docs", params=params) logger.info(f"Validated access to workspace: {self.workspace_id}") except CodaClientRequestFailedError as e: if e.status_code == 401: raise CredentialExpiredError( "Coda credential appears to be invalid or expired (HTTP 401)." ) elif e.status_code == 404: raise ConnectorValidationError( "Coda workspace not found or not accessible (HTTP 404). " "Please verify the workspace_id is correct and shared with the integration." ) elif e.status_code == 429: raise ConnectorValidationError( "Validation failed due to Coda rate limits being exceeded (HTTP 429). Please try again later." ) else: raise UnexpectedValidationError( f"Unexpected Coda HTTP error (status={e.status_code}): {e}" ) except Exception as exc: raise UnexpectedValidationError( f"Unexpected error during Coda settings validation: {exc}" ) ================================================ FILE: backend/onyx/connectors/confluence/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/confluence/access.py ================================================ from collections.abc import Callable from typing import Any from typing import cast from onyx.access.models import ExternalAccess from onyx.connectors.confluence.onyx_confluence import OnyxConfluence from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import global_version def get_page_restrictions( confluence_client: OnyxConfluence, page_id: str, page_restrictions: dict[str, Any], ancestors: list[dict[str, Any]], ) -> ExternalAccess | None: """ Get page access restrictions for a
Confluence page. This functionality requires Enterprise Edition. Note: This wrapper is only called from permission sync path. Group IDs are left unprefixed here because upsert_document_external_perms handles prefixing. Args: confluence_client: OnyxConfluence client instance page_id: The ID of the page page_restrictions: Dictionary containing page restriction data ancestors: List of ancestor pages with their restriction data Returns: ExternalAccess object for the page. None if EE is not enabled or no restrictions found. """ # Check if EE is enabled if not global_version.is_ee_version(): return None # Fetch the EE implementation ee_get_all_page_restrictions = cast( Callable[ [OnyxConfluence, str, dict[str, Any], list[dict[str, Any]], bool], ExternalAccess | None, ], fetch_versioned_implementation( "onyx.external_permissions.confluence.page_access", "get_page_restrictions" ), ) # add_prefix=False: permission sync path - upsert_document_external_perms handles prefixing return ee_get_all_page_restrictions( confluence_client, page_id, page_restrictions, ancestors, False ) def get_all_space_permissions( confluence_client: OnyxConfluence, is_cloud: bool, ) -> dict[str, ExternalAccess]: """ Get access permissions for all spaces in Confluence. This functionality requires Enterprise Edition. Note: This wrapper is only called from permission sync path. Group IDs are left unprefixed here because upsert_document_external_perms handles prefixing. Args: confluence_client: OnyxConfluence client instance is_cloud: Whether this is a Confluence Cloud instance Returns: Dictionary mapping space keys to ExternalAccess objects. Empty dict if EE is not enabled. 
""" # Check if EE is enabled if not global_version.is_ee_version(): return {} # Fetch the EE implementation ee_get_all_space_permissions = cast( Callable[ [OnyxConfluence, bool, bool], dict[str, ExternalAccess], ], fetch_versioned_implementation( "onyx.external_permissions.confluence.space_access", "get_all_space_permissions", ), ) # add_prefix=False: permission sync path - upsert_document_external_perms handles prefixing return ee_get_all_space_permissions(confluence_client, is_cloud, False) ================================================ FILE: backend/onyx/connectors/confluence/connector.py ================================================ import copy from collections.abc import Generator from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from urllib.parse import quote from atlassian.errors import ApiError # type: ignore from requests.exceptions import HTTPError from typing_extensions import override from onyx.access.models import ExternalAccess from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_LABELS_TO_SKIP from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.confluence.access import get_all_space_permissions from onyx.connectors.confluence.access import get_page_restrictions from onyx.connectors.confluence.onyx_confluence import extract_text_from_confluence_html from onyx.connectors.confluence.onyx_confluence import OnyxConfluence from onyx.connectors.confluence.utils import build_confluence_document_id from onyx.connectors.confluence.utils import convert_attachment_to_content from onyx.connectors.confluence.utils import datetime_from_string from onyx.connectors.confluence.utils import update_param_in_path from onyx.connectors.confluence.utils import validate_attachment_filetype from 
onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( is_atlassian_date_error, ) from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import ConnectorCheckpoint from onyx.connectors.interfaces import ConnectorFailure from onyx.connectors.interfaces import CredentialsConnector from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.db.enums import HierarchyNodeType from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() # Potential Improvements # 1. 
Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost _COMMENT_EXPANSION_FIELDS = ["body.storage.value"] _PAGE_EXPANSION_FIELDS = [ "body.storage.value", "version", "space", "metadata.labels", "history.lastUpdated", "ancestors", # For hierarchy node tracking ] _ATTACHMENT_EXPANSION_FIELDS = [ "version", "space", "metadata.labels", ] _RESTRICTIONS_EXPANSION_FIELDS = [ "space", "restrictions.read.restrictions.user", "restrictions.read.restrictions.group", "ancestors.restrictions.read.restrictions.user", "ancestors.restrictions.read.restrictions.group", ] _SLIM_DOC_BATCH_SIZE = 5000 ONE_HOUR = 3600 ONE_DAY = ONE_HOUR * 24 MAX_CACHED_IDS = 100 def _get_page_id(page: dict[str, Any], allow_missing: bool = False) -> str: if allow_missing and "id" not in page: return "unknown" return str(page["id"]) class ConfluenceCheckpoint(ConnectorCheckpoint): next_page_url: str | None class ConfluenceConnector( CheckpointedConnector[ConfluenceCheckpoint], SlimConnector, SlimConnectorWithPermSync, CredentialsConnector, ): def __init__( self, wiki_base: str, is_cloud: bool, space: str = "", page_id: str = "", index_recursively: bool = False, cql_query: str | None = None, batch_size: int = INDEX_BATCH_SIZE, continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, # if a page has one of the labels specified in this list, we will just # skip it. This is generally used to avoid indexing extra sensitive # pages. 
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP, timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET, scoped_token: bool = False, ) -> None: self.wiki_base = wiki_base self.is_cloud = is_cloud self.space = space self.page_id = page_id self.index_recursively = index_recursively self.cql_query = cql_query self.batch_size = batch_size self.labels_to_skip = labels_to_skip self.timezone_offset = timezone_offset self.scoped_token = scoped_token self._confluence_client: OnyxConfluence | None = None self._low_timeout_confluence_client: OnyxConfluence | None = None self._fetched_titles: set[str] = set() self.allow_images = False # Track hierarchy nodes we've already yielded to avoid duplicates self.seen_hierarchy_node_raw_ids: set[str] = set() # Remove trailing slash from wiki_base if present self.wiki_base = wiki_base.rstrip("/") """ If nothing is provided, we default to fetching all pages. Only one (or none) of the following options should be specified, so the order shouldn't matter. However, we use elif to ensure that at most one of them is applied (cql_query first, then page_id, then space). """ base_cql_page_query = "type=page" if cql_query: base_cql_page_query = cql_query elif page_id: if index_recursively: base_cql_page_query += f" and (ancestor='{page_id}' or id='{page_id}')" else: base_cql_page_query += f" and id='{page_id}'" elif space: uri_safe_space = quote(space) base_cql_page_query += f" and space='{uri_safe_space}'" self.base_cql_page_query = base_cql_page_query self.cql_label_filter = "" if labels_to_skip: labels_to_skip = list(set(labels_to_skip)) comma_separated_labels = ",".join( f"'{quote(label)}'" for label in labels_to_skip ) self.cql_label_filter = f" and label not in ({comma_separated_labels})" self.timezone: timezone = timezone(offset=timedelta(hours=timezone_offset)) self.credentials_provider: CredentialsProviderInterface | None = None self.probe_kwargs = { "max_backoff_retries": 6, "max_backoff_seconds": 10, } self.final_kwargs = { "max_backoff_retries": 10,
"max_backoff_seconds": 60, } # deprecated self.continue_on_failure = continue_on_failure def set_allow_images(self, value: bool) -> None: logger.info(f"Setting allow_images to {value}.") self.allow_images = value def _yield_space_hierarchy_nodes( self, ) -> Generator[HierarchyNode, None, None]: """Yield hierarchy nodes for all spaces we're indexing.""" space_keys = [self.space] if self.space else None for space in self.confluence_client.retrieve_confluence_spaces( space_keys=space_keys, limit=50, ): space_key = space.get("key") if not space_key or space_key in self.seen_hierarchy_node_raw_ids: continue self.seen_hierarchy_node_raw_ids.add(space_key) # Build space link space_link = f"{self.wiki_base}/spaces/{space_key}" yield HierarchyNode( raw_node_id=space_key, raw_parent_id=None, # Parent is SOURCE display_name=space.get("name", space_key), link=space_link, node_type=HierarchyNodeType.SPACE, ) def _yield_ancestor_hierarchy_nodes( self, page: dict[str, Any], ) -> Generator[HierarchyNode, None, None]: """Yield hierarchy nodes for all unseen ancestors of this page. Any page that appears as an ancestor of another page IS a hierarchy node (it has at least one child - the page we're currently processing). This ensures parent nodes are always yielded before child documents. Note: raw_node_id for page hierarchy nodes uses the page URL (same as document.id) to enable document<->hierarchy node linking in the indexing pipeline. Space hierarchy nodes use the space key since they don't have documents. 
""" ancestors = page.get("ancestors", []) space_key = page.get("space", {}).get("key") # Ensure space is yielded first (if not already) if space_key and space_key not in self.seen_hierarchy_node_raw_ids: self.seen_hierarchy_node_raw_ids.add(space_key) space = page.get("space", {}) yield HierarchyNode( raw_node_id=space_key, raw_parent_id=None, # Parent is SOURCE display_name=space.get("name", space_key), link=f"{self.wiki_base}/spaces/{space_key}", node_type=HierarchyNodeType.SPACE, ) # Walk through ancestors (root to immediate parent) # Build a list of (ancestor_url, ancestor_data) pairs first ancestor_urls: list[str | None] = [] for ancestor in ancestors: if "_links" in ancestor and "webui" in ancestor["_links"]: ancestor_urls.append( build_confluence_document_id( self.wiki_base, ancestor["_links"]["webui"], self.is_cloud ) ) else: ancestor_urls.append(None) for i, ancestor in enumerate(ancestors): ancestor_url = ancestor_urls[i] if not ancestor_url: # Can't build URL for this ancestor, skip it continue if ancestor_url in self.seen_hierarchy_node_raw_ids: continue self.seen_hierarchy_node_raw_ids.add(ancestor_url) # Determine parent of this ancestor if i == 0: # First ancestor - parent is the space parent_raw_id = space_key else: # Parent is the previous ancestor (use URL) parent_raw_id = ancestor_urls[i - 1] yield HierarchyNode( raw_node_id=ancestor_url, # Use URL to match document.id raw_parent_id=parent_raw_id, display_name=ancestor.get("title", f"Page {ancestor.get('id')}"), link=ancestor_url, node_type=HierarchyNodeType.PAGE, ) def _get_parent_hierarchy_raw_id(self, page: dict[str, Any]) -> str | None: """Get the raw hierarchy node ID of this page's parent. Returns: - Parent page URL if page has a parent page (last item in ancestors) - Space key if page is at top level of space - None if we can't determine Note: For pages, we return URLs (to match document.id and hierarchy node raw_node_id). For spaces, we return the space key (spaces don't have documents). 
""" ancestors = page.get("ancestors", []) if ancestors: # Last ancestor is the immediate parent page - use URL parent = ancestors[-1] if "_links" in parent and "webui" in parent["_links"]: return build_confluence_document_id( self.wiki_base, parent["_links"]["webui"], self.is_cloud ) # Fallback to page ID if URL not available (shouldn't happen normally) return str(parent.get("id")) # Top-level page - parent is the space (use space key) return page.get("space", {}).get("key") def _maybe_yield_page_hierarchy_node( self, page: dict[str, Any] ) -> HierarchyNode | None: """Yield a hierarchy node for this page if not already yielded. Used when a page has attachments - attachments are children of the page in the hierarchy, so the page must be a hierarchy node. Note: raw_node_id uses the page URL (same as document.id) to enable document<->hierarchy node linking in the indexing pipeline. """ # Build page URL - we use this as raw_node_id to match document.id if "_links" not in page or "webui" not in page["_links"]: return None # Can't build URL, skip page_url = build_confluence_document_id( self.wiki_base, page["_links"]["webui"], self.is_cloud ) if page_url in self.seen_hierarchy_node_raw_ids: return None self.seen_hierarchy_node_raw_ids.add(page_url) # Get parent hierarchy ID parent_raw_id = self._get_parent_hierarchy_raw_id(page) return HierarchyNode( raw_node_id=page_url, # Use URL to match document.id raw_parent_id=parent_raw_id, display_name=page.get("title", f"Page {_get_page_id(page)}"), link=page_url, node_type=HierarchyNodeType.PAGE, ) @property def confluence_client(self) -> OnyxConfluence: if self._confluence_client is None: raise ConnectorMissingCredentialError("Confluence") return self._confluence_client @property def low_timeout_confluence_client(self) -> OnyxConfluence: if self._low_timeout_confluence_client is None: raise ConnectorMissingCredentialError("Confluence") return self._low_timeout_confluence_client def set_credentials_provider( self, 
credentials_provider: CredentialsProviderInterface ) -> None: self.credentials_provider = credentials_provider # raises exception if there's a problem confluence_client = OnyxConfluence( is_cloud=self.is_cloud, url=self.wiki_base, credentials_provider=credentials_provider, scoped_token=self.scoped_token, ) confluence_client._probe_connection(**self.probe_kwargs) confluence_client._initialize_connection(**self.final_kwargs) self._confluence_client = confluence_client # create a low timeout confluence client for sync flows low_timeout_confluence_client = OnyxConfluence( is_cloud=self.is_cloud, url=self.wiki_base, credentials_provider=credentials_provider, timeout=3, scoped_token=self.scoped_token, ) low_timeout_confluence_client._probe_connection(**self.probe_kwargs) low_timeout_confluence_client._initialize_connection(**self.final_kwargs) self._low_timeout_confluence_client = low_timeout_confluence_client def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: raise NotImplementedError("Use set_credentials_provider with this connector.") def _construct_page_cql_query( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> str: """ Constructs a CQL query for use in the confluence API. See https://developer.atlassian.com/server/confluence/advanced-searching-using-cql/ for more information. This is JUST the CQL, not the full URL used to hit the API. Use _build_page_retrieval_url to get the full URL. 
""" page_query = self.base_cql_page_query + self.cql_label_filter # Add time filters if start: formatted_start_time = datetime.fromtimestamp( start, tz=self.timezone ).strftime("%Y-%m-%d %H:%M") page_query += f" and lastmodified >= '{formatted_start_time}'" if end: formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime( "%Y-%m-%d %H:%M" ) page_query += f" and lastmodified <= '{formatted_end_time}'" page_query += " order by lastmodified asc" return page_query def _construct_attachment_query( self, confluence_page_id: str, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> str: attachment_query = f"type=attachment and container='{confluence_page_id}'" attachment_query += self.cql_label_filter # Add time filters to avoid reprocessing unchanged attachments during refresh if start: formatted_start_time = datetime.fromtimestamp( start, tz=self.timezone ).strftime("%Y-%m-%d %H:%M") attachment_query += f" and lastmodified >= '{formatted_start_time}'" if end: formatted_end_time = datetime.fromtimestamp(end, tz=self.timezone).strftime( "%Y-%m-%d %H:%M" ) attachment_query += f" and lastmodified <= '{formatted_end_time}'" attachment_query += " order by lastmodified asc" return attachment_query def _get_comment_string_for_page_id(self, page_id: str) -> str: comment_string = "" comment_cql = f"type=comment and container='{page_id}'" comment_cql += self.cql_label_filter expand = ",".join(_COMMENT_EXPANSION_FIELDS) for comment in self.confluence_client.paginated_cql_retrieval( cql=comment_cql, expand=expand, ): comment_string += "\nComment:\n" comment_string += extract_text_from_confluence_html( confluence_client=self.confluence_client, confluence_object=comment, fetched_titles=set(), ) return comment_string def _convert_page_to_document( self, page: dict[str, Any] ) -> Document | ConnectorFailure: """ Converts a Confluence page to a Document object. Includes the page content, comments, and attachments. 
""" page_id = page_url = "" try: # Extract basic page information page_id = _get_page_id(page) page_title = page["title"] logger.info(f"Converting page {page_title} to document") page_url = build_confluence_document_id( self.wiki_base, page["_links"]["webui"], self.is_cloud ) # Get the page content page_content = extract_text_from_confluence_html( self.confluence_client, page, self._fetched_titles ) # Create the main section for the page content sections: list[TextSection | ImageSection] = [ TextSection(text=page_content, link=page_url) ] # Process comments if available comment_text = self._get_comment_string_for_page_id(page_id) if comment_text: sections.append( TextSection(text=comment_text, link=f"{page_url}#comments") ) # Note: attachments are no longer merged into the page document. # They are indexed as separate documents downstream. # Extract metadata metadata = {} if "space" in page: metadata["space"] = page["space"].get("name", "") # Extract labels labels = [] if "metadata" in page and "labels" in page["metadata"]: for label in page["metadata"]["labels"].get("results", []): labels.append(label.get("name", "")) if labels: metadata["labels"] = labels # Extract owners primary_owners = [] if "version" in page and "by" in page["version"]: author = page["version"]["by"] display_name = author.get("displayName", "Unknown") email = author.get("email", "unknown@domain.invalid") primary_owners.append( BasicExpertInfo(display_name=display_name, email=email) ) # Determine parent hierarchy node parent_hierarchy_raw_node_id = self._get_parent_hierarchy_raw_id(page) # Create the document return Document( id=page_url, sections=sections, source=DocumentSource.CONFLUENCE, semantic_identifier=page_title, metadata=metadata, doc_updated_at=datetime_from_string(page["version"]["when"]), primary_owners=primary_owners if primary_owners else None, parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id, ) except Exception as e: logger.error(f"Error converting page {page.get('id', 
'unknown')}: {e}") if is_atlassian_date_error(e): # propagate error to be caught and retried raise return ConnectorFailure( failed_document=DocumentFailure( document_id=page_id, document_link=page_url, ), failure_message=f"Error converting page {page.get('id', 'unknown')}: {e}", exception=e, ) def _fetch_page_attachments( self, page: dict[str, Any], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> tuple[list[Document | HierarchyNode], list[ConnectorFailure]]: """ Fetches the page's attachments and converts each accepted attachment into its own Document; attachments are not merged into the page document. Returns the attachment Documents plus a list of ConnectorFailures for attachments that could not be processed. If there is at least one valid attachment, the page itself is also returned as a HierarchyNode, placed before its attachments (since attachments are children of the page in the hierarchy). """ attachment_query = self._construct_attachment_query( _get_page_id(page), start, end ) attachment_failures: list[ConnectorFailure] = [] attachment_docs: list[Document | HierarchyNode] = [] page_url = "" page_hierarchy_node_yielded = False try: for attachment in self.confluence_client.paginated_cql_retrieval( cql=attachment_query, expand=",".join(_ATTACHMENT_EXPANSION_FIELDS), ): media_type: str = attachment.get("metadata", {}).get("mediaType", "") # TODO(rkuo): this check is partially redundant with validate_attachment_filetype # and checks in convert_attachment_to_content/process_attachment # but doing the check here avoids an unnecessary download. Due for refactoring.
if not self.allow_images: if media_type.startswith("image/"): logger.info( f"Skipping attachment because allow images is False: {attachment['title']}" ) continue if not validate_attachment_filetype( attachment, ): logger.info( f"Skipping attachment because it is not an accepted file type: {attachment['title']}" ) continue logger.info( f"Processing attachment: {attachment['title']} attached to page {page['title']}" ) # Attachment document id: use the download URL for stable identity try: object_url = build_confluence_document_id( self.wiki_base, attachment["_links"]["download"], self.is_cloud ) except Exception as e: logger.warning( f"Invalid attachment url for id {attachment['id']}, skipping" ) logger.debug(f"Error building attachment url: {e}") continue try: response = convert_attachment_to_content( confluence_client=self.confluence_client, attachment=attachment, page_id=_get_page_id(page), allow_images=self.allow_images, ) if response is None: continue content_text, file_storage_name = response sections: list[TextSection | ImageSection] = [] if content_text: sections.append(TextSection(text=content_text, link=object_url)) elif file_storage_name: sections.append( ImageSection( link=object_url, image_file_id=file_storage_name ) ) # Build attachment-specific metadata attachment_metadata: dict[str, str | list[str]] = {} if "space" in attachment: attachment_metadata["space"] = attachment["space"].get( "name", "" ) labels: list[str] = [] if "metadata" in attachment and "labels" in attachment["metadata"]: for label in attachment["metadata"]["labels"].get( "results", [] ): labels.append(label.get("name", "")) if labels: attachment_metadata["labels"] = labels page_url = page_url or build_confluence_document_id( self.wiki_base, page["_links"]["webui"], self.is_cloud ) attachment_metadata["parent_page_id"] = page_url attachment_id = build_confluence_document_id( self.wiki_base, attachment["_links"]["webui"], self.is_cloud ) primary_owners: list[BasicExpertInfo] | None = 
None if "version" in attachment and "by" in attachment["version"]: author = attachment["version"]["by"] display_name = author.get("displayName", "Unknown") email = author.get("email", "unknown@domain.invalid") primary_owners = [ BasicExpertInfo(display_name=display_name, email=email) ] # Attachments have their parent page as the hierarchy parent # Use page URL to match the hierarchy node's raw_node_id attachment_parent_hierarchy_raw_id = page_url attachment_doc = Document( id=attachment_id, sections=sections, source=DocumentSource.CONFLUENCE, semantic_identifier=attachment.get("title", object_url), metadata=attachment_metadata, doc_updated_at=( datetime_from_string(attachment["version"]["when"]) if attachment.get("version") and attachment["version"].get("when") else None ), primary_owners=primary_owners, parent_hierarchy_raw_node_id=attachment_parent_hierarchy_raw_id, ) # If this is the first valid attachment, yield the page as a # hierarchy node (attachments are children of the page) if not page_hierarchy_node_yielded: page_hierarchy_node = self._maybe_yield_page_hierarchy_node( page ) if page_hierarchy_node: attachment_docs.append(page_hierarchy_node) page_hierarchy_node_yielded = True attachment_docs.append(attachment_doc) except Exception as e: logger.error( f"Failed to extract/summarize attachment {attachment['title']}", exc_info=e, ) if is_atlassian_date_error(e): # propagate error to be caught and retried raise attachment_failures.append( ConnectorFailure( failed_document=DocumentFailure( document_id=object_url, document_link=object_url, ), failure_message=f"Failed to extract/summarize attachment {attachment['title']} for doc {object_url}", exception=e, ) ) except HTTPError as e: # If we get a 403 after all retries, the user likely doesn't have permission # to access attachments on this page. Log and skip rather than failing the whole job. 
page_id = _get_page_id(page, allow_missing=True) page_title = page.get("title", "unknown") # requests.Response is falsy for 4xx/5xx status codes, so compare against None explicitly if e.response is not None and e.response.status_code in [401, 403]: failure_message_prefix = ( "Invalid credentials (401)" if e.response.status_code == 401 else "Permission denied (403)" ) failure_message = ( f"{failure_message_prefix} when fetching attachments for page '{page_title}' " f"(ID: {page_id}). The user may not have permission to query attachments on this page. " "Skipping attachments for this page." ) logger.warning(failure_message) # Build the page URL for the failure record try: page_url = build_confluence_document_id( self.wiki_base, page["_links"]["webui"], self.is_cloud ) except Exception: page_url = f"page_id:{page_id}" return [], [ ConnectorFailure( failed_document=DocumentFailure( document_id=page_id, document_link=page_url, ), failure_message=failure_message, exception=e, ) ] else: raise return attachment_docs, attachment_failures def _fetch_document_batches( self, checkpoint: ConfluenceCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> CheckpointOutput[ConfluenceCheckpoint]: """ Yields Documents, HierarchyNodes, and ConnectorFailures. For each page: - Yield hierarchy nodes for spaces and ancestor pages (parent-before-child ordering) - Create a Document with a Section for the page text (plus one for comments, if any) - Then fetch attachments. For each attachment: - Attempt to convert it with convert_attachment_to_content(...) - If successful, yield it as a separate Document whose parent hierarchy node is the page.
""" checkpoint = copy.deepcopy(checkpoint) # Yield space hierarchy nodes FIRST (only once per connector run) if not checkpoint.next_page_url: yield from self._yield_space_hierarchy_nodes() # use "start" when last_updated is 0 or for confluence server start_ts = start page_query_url = checkpoint.next_page_url or self._build_page_retrieval_url( start_ts, end, self.batch_size ) logger.debug(f"page_query_url: {page_query_url}") # store the next page start for confluence server, cursor for confluence cloud def store_next_page_url(next_page_url: str) -> None: checkpoint.next_page_url = next_page_url for page in self.confluence_client.paginated_page_retrieval( cql_url=page_query_url, limit=self.batch_size, next_page_callback=store_next_page_url, ): # Yield hierarchy nodes for all ancestors (parent-before-child ordering) yield from self._yield_ancestor_hierarchy_nodes(page) # Build doc from page doc_or_failure = self._convert_page_to_document(page) if isinstance(doc_or_failure, ConnectorFailure): yield doc_or_failure continue # yield completed document (or failure) yield doc_or_failure # Now get attachments for that page: attachment_docs, attachment_failures = self._fetch_page_attachments( page, start, end ) # yield attached docs and failures yield from attachment_docs yield from attachment_failures # Create checkpoint once a full page of results is returned if checkpoint.next_page_url and checkpoint.next_page_url != page_query_url: return checkpoint checkpoint.has_more = False return checkpoint def _build_page_retrieval_url( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, limit: int, ) -> str: """ Builds the full URL used to retrieve pages from the confluence API. This can be used as input to the confluence client's _paginate_url or paginated_page_retrieval methods. 
""" page_query = self._construct_page_cql_query(start, end) cql_url = self.confluence_client.build_cql_url( page_query, expand=",".join(_PAGE_EXPANSION_FIELDS) ) return update_param_in_path(cql_url, "limit", str(limit)) @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: ConfluenceCheckpoint, ) -> CheckpointOutput[ConfluenceCheckpoint]: end += ONE_DAY # handle time zone weirdness # _fetch_document_batches is a generator, so it must be iterated inside the # try block; returning it un-iterated would let errors escape this handler try: return (yield from self._fetch_document_batches(checkpoint, start, end)) except Exception as e: if is_atlassian_date_error(e) and start is not None: logger.warning( "Confluence says we provided an invalid 'updated' field. This may indicate " "a real issue, but can also appear during edge cases like daylight " f"savings time changes. Retrying with a 1 hour offset. Error: {e}" ) return (yield from self._fetch_document_batches(checkpoint, start - ONE_HOUR, end)) raise @override def build_dummy_checkpoint(self) -> ConfluenceCheckpoint: return ConfluenceCheckpoint(has_more=True, next_page_url=None) @override def validate_checkpoint_json(self, checkpoint_json: str) -> ConfluenceCheckpoint: return ConfluenceCheckpoint.model_validate_json(checkpoint_json) @override def retrieve_all_slim_docs( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: return self._retrieve_all_slim_docs( start=start, end=end, callback=callback, include_permissions=False, ) def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: """ Return 'slim' docs (IDs + minimal permission data). Does not fetch actual text. 
Used primarily for incremental permission sync.
""" return self._retrieve_all_slim_docs( start=start, end=end, callback=callback, include_permissions=True, ) def _retrieve_all_slim_docs( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, include_permissions: bool = True, ) -> GenerateSlimDocumentOutput: doc_metadata_list: list[SlimDocument | HierarchyNode] = [] restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS) space_level_access_info: dict[str, ExternalAccess] = {} if include_permissions: space_level_access_info = get_all_space_permissions( self.confluence_client, self.is_cloud ) # Yield space hierarchy nodes first for node in self._yield_space_hierarchy_nodes(): doc_metadata_list.append(node) def get_external_access( doc_id: str, restrictions: dict[str, Any], ancestors: list[dict[str, Any]] ) -> ExternalAccess | None: return get_page_restrictions( self.confluence_client, doc_id, restrictions, ancestors ) or space_level_access_info.get(page_space_key) # Query pages (with optional time filtering for indexing_start) page_query = self._construct_page_cql_query(start, end) for page in self.confluence_client.cql_paginate_all_expansions( cql=page_query, expand=restrictions_expand, limit=_SLIM_DOC_BATCH_SIZE, ): # Yield ancestor hierarchy nodes for this page for node in self._yield_ancestor_hierarchy_nodes(page): doc_metadata_list.append(node) page_id = _get_page_id(page) page_restrictions = page.get("restrictions") or {} page_space_key = page.get("space", {}).get("key") page_ancestors = page.get("ancestors", []) page_id = build_confluence_document_id( self.wiki_base, page["_links"]["webui"], self.is_cloud ) doc_metadata_list.append( SlimDocument( id=page_id, external_access=( get_external_access(page_id, page_restrictions, page_ancestors) if include_permissions else None ), parent_hierarchy_raw_node_id=self._get_parent_hierarchy_raw_id( page ), ) ) # Query attachments for each page page_hierarchy_node_yielded = 
False attachment_query = self._construct_attachment_query( _get_page_id(page), start, end ) for attachment in self.confluence_client.cql_paginate_all_expansions( cql=attachment_query, expand=restrictions_expand, limit=_SLIM_DOC_BATCH_SIZE, ): # Images are not skipped here: if you skip images, you'll skip them in the permission sync if not validate_attachment_filetype( attachment, ): continue # If this page has valid attachments and we haven't yielded it as a # hierarchy node yet, do so now (attachments are children of the page) if not page_hierarchy_node_yielded: page_node = self._maybe_yield_page_hierarchy_node(page) if page_node: doc_metadata_list.append(page_node) page_hierarchy_node_yielded = True attachment_restrictions = attachment.get("restrictions", {}) if not attachment_restrictions: attachment_restrictions = page_restrictions or {} attachment_space_key = attachment.get("space", {}).get("key") if not attachment_space_key: attachment_space_key = page_space_key attachment_id = build_confluence_document_id( self.wiki_base, attachment["_links"]["webui"], self.is_cloud, ) doc_metadata_list.append( SlimDocument( id=attachment_id, external_access=( # fall back to space-level access for the attachment's own space get_page_restrictions( self.confluence_client, attachment_id, attachment_restrictions, [] ) or space_level_access_info.get(attachment_space_key) if include_permissions else None ), parent_hierarchy_raw_node_id=page_id, ) ) if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE: yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE] doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:] if callback and callback.should_stop(): raise RuntimeError( "retrieve_all_slim_docs_perm_sync: Stop signal detected" ) if callback: callback.progress("retrieve_all_slim_docs_perm_sync", 1) yield doc_metadata_list def validate_connector_settings(self) -> None: try: spaces_iter = self.low_timeout_confluence_client.retrieve_confluence_spaces( limit=1, ) first_space = next(spaces_iter, None) except 
HTTPError as e: # requests.Response is falsy for 4xx/5xx, so compare against None explicitly status_code = e.response.status_code if e.response is not None else None if status_code == 401: raise
CredentialExpiredError( "Invalid or expired Confluence credentials (HTTP 401)." ) elif status_code == 403: raise InsufficientPermissionsError( "Insufficient permissions to access Confluence resources (HTTP 403)." ) raise UnexpectedValidationError( f"Unexpected Confluence error (status={status_code}): {e}" ) except Exception as e: raise UnexpectedValidationError( f"Unexpected error while validating Confluence settings: {e}" ) if not first_space: raise ConnectorValidationError( "No Confluence spaces found. Either your credentials lack permissions, or " "there truly are no spaces in this Confluence instance." ) if self.space: try: self.low_timeout_confluence_client.get_space(self.space) except ApiError as e: raise ConnectorValidationError( "Invalid Confluence space key provided" ) from e if __name__ == "__main__": import os from onyx.utils.variable_functionality import global_version from tests.daily.connectors.utils import load_all_from_connector # For connector permission testing, set EE to true. global_version.set_ee() # base url wiki_base = os.environ["CONFLUENCE_URL"] # auth stuff username = os.environ["CONFLUENCE_USERNAME"] access_token = os.environ["CONFLUENCE_ACCESS_TOKEN"] is_cloud = os.environ["CONFLUENCE_IS_CLOUD"].lower() == "true" # space + page space = os.environ["CONFLUENCE_SPACE_KEY"] # page_id = os.environ["CONFLUENCE_PAGE_ID"] confluence_connector = ConfluenceConnector( wiki_base=wiki_base, space=space, is_cloud=is_cloud, # page_id=page_id, ) credentials_provider = OnyxStaticCredentialsProvider( None, DocumentSource.CONFLUENCE, { "confluence_username": username, "confluence_access_token": access_token, }, ) confluence_connector.set_credentials_provider(credentials_provider) start = 0.0 end = datetime.now().timestamp() # Fetch all `SlimDocuments`. for slim_doc in confluence_connector.retrieve_all_slim_docs_perm_sync(): print(slim_doc) # Fetch all `Documents`. 
for doc in load_all_from_connector( connector=confluence_connector, start=start, end=end, ).documents: print(doc) ================================================ FILE: backend/onyx/connectors/confluence/models.py ================================================ from pydantic import BaseModel class ConfluenceUser(BaseModel): user_id: str # accountId in Cloud, userKey in Server username: str | None # Confluence Cloud doesn't give usernames display_name: str # Confluence Data Center doesn't give email back by default, # have to fetch it with a different endpoint email: str | None type: str ================================================ FILE: backend/onyx/connectors/confluence/onyx_confluence.py ================================================ """ # README (notes on Confluence pagination): We've noticed that the `search/users` and `users/memberof` endpoints for Confluence Cloud use offset-based pagination as opposed to cursor-based. We also know that page-retrieval uses cursor-based pagination. Our default pagination strategy right now for cloud is to assume cursor-based. However, if you notice that a cloud API is not being properly paginated (i.e., if the `_links.next` is not appearing in the returned payload), then you can force offset-based pagination. # TODO (@raunakab) We haven't explored all of the cloud APIs' pagination strategies. @raunakab take time to go through this and figure them out. 
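Illustration (a minimal sketch, not the client's actual implementation) of the two pagination strategies described above: `paginate_cursor` keeps following `_links.next` until a payload omits it, while `paginate_offset` advances a `start` index instead. The `fetch` callables, the in-memory sample data, and the stop-on-short-page rule are assumptions for this example; a real endpoint may need a different termination check.

```python
from collections.abc import Callable, Iterator
from typing import Any

Page = dict[str, Any]


def paginate_cursor(fetch: Callable[[str], Page], first_url: str) -> Iterator[Any]:
    # Cursor-based: follow _links.next until the payload no longer provides one
    url: str | None = first_url
    while url:
        payload = fetch(url)
        yield from payload.get("results", [])
        url = payload.get("_links", {}).get("next")


def paginate_offset(fetch_at: Callable[[int, int], Page], limit: int) -> Iterator[Any]:
    # Offset-based: advance start by the number of results returned and stop
    # on a short (or empty) page; assumes limit > 0
    start = 0
    while True:
        results = fetch_at(start, limit).get("results", [])
        yield from results
        if len(results) < limit:
            break
        start += len(results)


PAGES = {
    "/page1": {"results": [1, 2], "_links": {"next": "/page2"}},
    "/page2": {"results": [3], "_links": {}},
}
print(list(paginate_cursor(PAGES.__getitem__, "/page1")))  # [1, 2, 3]

ITEMS = list(range(7))
print(list(paginate_offset(lambda s, n: {"results": ITEMS[s:s + n]}, 3)))
# [0, 1, 2, 3, 4, 5, 6]
```

Cursor pagination needs no knowledge of page size, which is why it is the default for cloud; the offset variant is the fallback when `_links.next` is missing from the payload.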
""" import json import time from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterator from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from typing import cast from typing import TypeVar from urllib.parse import quote import bs4 from atlassian import Confluence # type:ignore from redis import Redis from requests import HTTPError from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID from onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET from onyx.connectors.confluence.models import ConfluenceUser from onyx.connectors.confluence.user_profile_override import ( process_confluence_user_profiles_override, ) from onyx.connectors.confluence.utils import _handle_http_error from onyx.connectors.confluence.utils import confluence_refresh_tokens from onyx.connectors.confluence.utils import get_start_param_from_url from onyx.connectors.confluence.utils import update_param_in_path from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.file_processing.html_utils import format_document_soup from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() F = TypeVar("F", bound=Callable[..., Any]) # https://jira.atlassian.com/browse/CONFCLOUD-76433 _PROBLEMATIC_EXPANSIONS = "body.storage.value" _REPLACEMENT_EXPANSIONS = "body.view.value" _USER_NOT_FOUND = "Unknown Confluence User" _USER_ID_TO_DISPLAY_NAME_CACHE: dict[str, str | None] = {} _USER_EMAIL_CACHE: dict[str, str | None] = {} _DEFAULT_PAGINATION_LIMIT = 1000 _CONFLUENCE_SPACES_API_V1 = "rest/api/space" _CONFLUENCE_SPACES_API_V2 = "wiki/api/v2/spaces" class ConfluenceRateLimitError(Exception): pass class OnyxConfluence: """ This is a custom 
Confluence class that overrides the default Confluence class to add a custom CQL method, which is necessary because the default Confluence class does not properly support CQL expansions. All methods are automatically wrapped with handle_confluence_rate_limit. """ CREDENTIAL_PREFIX = "connector:confluence:credential" CREDENTIAL_TTL = 300 # 5 min PROBE_TIMEOUT = 5 # 5 seconds def __init__( self, is_cloud: bool, url: str, credentials_provider: CredentialsProviderInterface, timeout: int | None = None, scoped_token: bool = False, # should generally not be passed in, but making it overridable for # easier testing confluence_user_profiles_override: list[dict[str, str]] | None = ( CONFLUENCE_CONNECTOR_USER_PROFILES_OVERRIDE ), ) -> None: self.base_url = url #'/'.join(url.rstrip("/").split("/")[:-1]) url = scoped_url(url, "confluence") if scoped_token else url self._is_cloud = is_cloud self._url = url.rstrip("/") self._credentials_provider = credentials_provider self.scoped_token = scoped_token self.redis_client: Redis | None = None self.static_credentials: dict[str, Any] | None = None if self._credentials_provider.is_dynamic(): self.redis_client = get_redis_client( tenant_id=credentials_provider.get_tenant_id() ) else: self.static_credentials = self._credentials_provider.get_credentials() self._confluence = Confluence(url) self.credential_key: str = ( self.CREDENTIAL_PREFIX + f":credential_{self._credentials_provider.get_provider_key()}" ) self._kwargs: Any = None self.shared_base_kwargs: dict[str, str | int | bool] = { "api_version": "cloud" if is_cloud else "latest", "backoff_and_retry": False, "cloud": is_cloud, } if timeout: self.shared_base_kwargs["timeout"] = timeout self._confluence_user_profiles_override = ( process_confluence_user_profiles_override(confluence_user_profiles_override) if confluence_user_profiles_override else None ) def _renew_credentials(self) -> tuple[dict[str, Any], bool]: """Returns a tuple of: 1.
The up to date credentials 2. True if the credentials were updated This method is intended to be used within a distributed lock. Lock, call this, update credentials if the tokens were refreshed, then release """ # static credentials are preloaded, so no locking/redis required if self.static_credentials: return self.static_credentials, False if not self.redis_client: raise RuntimeError("self.redis_client is None") # dynamic credentials need locking # check redis first, then fallback to the DB credential_raw = self.redis_client.get(self.credential_key) if credential_raw is not None: credential_bytes = cast(bytes, credential_raw) credential_str = credential_bytes.decode("utf-8") credential_json: dict[str, Any] = json.loads(credential_str) else: credential_json = self._credentials_provider.get_credentials() if "confluence_refresh_token" not in credential_json: # static credentials ... cache them permanently and return self.static_credentials = credential_json return credential_json, False if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID: raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!") if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET: raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!") # check if we should refresh tokens. 
we're deciding to refresh halfway # to expiration now = datetime.now(timezone.utc) created_at = datetime.fromisoformat(credential_json["created_at"]) expires_in: int = credential_json["expires_in"] renew_at = created_at + timedelta(seconds=expires_in // 2) if now <= renew_at: # cached/current credentials are reasonably up to date return credential_json, False # we need to refresh logger.info("Renewing Confluence Cloud credentials...") new_credentials = confluence_refresh_tokens( OAUTH_CONFLUENCE_CLOUD_CLIENT_ID, OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET, credential_json["cloud_id"], credential_json["confluence_refresh_token"], ) # store the new credentials to redis and to the db thru the provider # redis: we use a 5 min TTL because we are given a 10 minute grace period # when keys are rotated. it's easier to expire the cached credentials # reasonably frequently rather than trying to handle strong synchronization # between the db and redis everywhere the credentials might be updated new_credential_str = json.dumps(new_credentials) self.redis_client.set( self.credential_key, new_credential_str, nx=True, ex=self.CREDENTIAL_TTL ) self._credentials_provider.set_credentials(new_credentials) return new_credentials, True @staticmethod def _make_oauth2_dict(credentials: dict[str, Any]) -> dict[str, Any]: oauth2_dict: dict[str, Any] = {} if "confluence_refresh_token" in credentials: oauth2_dict["client_id"] = OAUTH_CONFLUENCE_CLOUD_CLIENT_ID oauth2_dict["token"] = {} oauth2_dict["token"]["access_token"] = credentials[ "confluence_access_token" ] return oauth2_dict def _build_spaces_url( self, is_v2: bool, base_url: str, limit: int, space_keys: list[str] | None, start: int | None = None, ) -> str: """Build URL for Confluence spaces API with query parameters.""" key_param = "keys" if is_v2 else "spaceKey" params = [f"limit={limit}"] if space_keys: params.append(f"{key_param}={','.join(space_keys)}") if start is not None and not is_v2: params.append(f"start={start}") return 
f"{base_url}?{'&'.join(params)}" def _paginate_spaces_for_endpoint( self, is_v2: bool, base_url: str, limit: int, space_keys: list[str] | None, ) -> Iterator[dict[str, Any]]: """Internal helper to paginate through spaces for a specific API endpoint.""" start = 0 url = self._build_spaces_url( is_v2, base_url, limit, space_keys, start if not is_v2 else None ) while url: response = self.get(url, advanced_mode=True) response.raise_for_status() data = response.json() results = data.get("results", []) if not results: return yield from results if is_v2: url = data.get("_links", {}).get("next", "") else: if len(results) < limit: return start += len(results) url = self._build_spaces_url(is_v2, base_url, limit, space_keys, start) def retrieve_confluence_spaces( self, space_keys: list[str] | None = None, limit: int = 50, ) -> Iterator[dict[str, str]]: """ Retrieve spaces from Confluence using v2 API (Cloud) or v1 API (Server/fallback). Args: space_keys: Optional list of space keys to filter by limit: Results per page (default 50) Yields: Space dictionaries with keys: id, key, name, type, status, etc. Note: For Cloud instances, attempts v2 API first. If v2 returns 404, automatically falls back to v1 API for compatibility with older instances. """ # Determine API version once use_v2 = self._is_cloud and not self.scoped_token base_url = _CONFLUENCE_SPACES_API_V2 if use_v2 else _CONFLUENCE_SPACES_API_V1 try: yield from self._paginate_spaces_for_endpoint( use_v2, base_url, limit, space_keys ) except HTTPError as e: if e.response.status_code == 404 and use_v2: logger.warning( "v2 spaces API returned 404, falling back to v1 API. This may indicate an older Confluence Cloud instance." 
) # Fallback to v1 yield from self._paginate_spaces_for_endpoint( False, _CONFLUENCE_SPACES_API_V1, limit, space_keys ) else: raise def _probe_connection( self, **kwargs: Any, ) -> None: merged_kwargs = {**self.shared_base_kwargs, **kwargs} # add special timeout to make sure that we don't hang indefinitely merged_kwargs["timeout"] = self.PROBE_TIMEOUT with self._credentials_provider: credentials, _ = self._renew_credentials() if self.scoped_token: # v2 endpoint doesn't always work with scoped tokens, use v1 token = credentials["confluence_access_token"] probe_url = f"{self.base_url}/{_CONFLUENCE_SPACES_API_V1}?limit=1" import requests try: r = requests.get( probe_url, headers={"Authorization": f"Bearer {token}"}, timeout=10, ) r.raise_for_status() except HTTPError as e: if e.response.status_code == 403: logger.warning( "scoped token authenticated but not valid for probe endpoint (spaces)" ) else: if "WWW-Authenticate" in e.response.headers: logger.warning( f"WWW-Authenticate: {e.response.headers['WWW-Authenticate']}" ) logger.warning(f"Full error: {e.response.text}") raise e return # Initialize connection with probe timeout settings self._confluence = self._initialize_connection_helper( credentials, **merged_kwargs ) # Retrieve first space to validate connection spaces_iter = self.retrieve_confluence_spaces(limit=1) first_space = next(spaces_iter, None) if not first_space: raise RuntimeError( f"No spaces found at {self._url}! Check your credentials and wiki_base and make sure is_cloud is set correctly." 
) logger.info("Confluence probe succeeded.") def _initialize_connection( self, **kwargs: Any, ) -> None: """Called externally to init the connection in a thread safe manner.""" merged_kwargs = {**self.shared_base_kwargs, **kwargs} with self._credentials_provider: credentials, _ = self._renew_credentials() self._confluence = self._initialize_connection_helper( credentials, **merged_kwargs ) self._kwargs = merged_kwargs def _initialize_connection_helper( self, credentials: dict[str, Any], **kwargs: Any, ) -> Confluence: """Called internally to init the connection. Distributed locking to prevent multiple threads from modifying the credentials must be handled around this function.""" confluence = None # probe connection with direct client, no retries if "confluence_refresh_token" in credentials: logger.info("Connecting to Confluence Cloud with OAuth Access Token.") oauth2_dict: dict[str, Any] = OnyxConfluence._make_oauth2_dict(credentials) url = f"https://api.atlassian.com/ex/confluence/{credentials['cloud_id']}" confluence = Confluence(url=url, oauth2=oauth2_dict, **kwargs) else: logger.info( f"Connecting to Confluence with Personal Access Token as user: {credentials['confluence_username']}" ) if self._is_cloud: confluence = Confluence( url=self._url, username=credentials["confluence_username"], password=credentials["confluence_access_token"], **kwargs, ) else: confluence = Confluence( url=self._url, token=credentials["confluence_access_token"], **kwargs, ) return confluence # https://developer.atlassian.com/cloud/confluence/rate-limiting/ # This uses the native rate limiting option provided by the # confluence client and otherwise applies a simpler set of error handling. 
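# Illustrative sketch, not part of the Onyx source: the wrapper that follows bounds retries by
# both an attempt count (MAX_RETRIES) and a wall-clock deadline (TIMEOUT). The names
# `call_with_retries` and `TransientError` are hypothetical; `TransientError` stands in for
# whatever `_handle_http_error` treats as retryable, and the fixed delay is an assumption
# (the original sleeps in 1-second slices until a computed deadline).
```python
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


class TransientError(Exception):
    """Hypothetical stand-in for a retryable failure such as an HTTP 429 or 503."""


def call_with_retries(
    fn: Callable[[], T],
    max_retries: int = 5,
    deadline_seconds: float = 600.0,
    delay_seconds: float = 1.0,
) -> T:
    """Call `fn`, retrying on TransientError up to `max_retries` times within a deadline."""
    deadline = time.monotonic() + deadline_seconds
    for attempt in range(max_retries):
        # Give up once the overall time budget is spent, even if attempts remain.
        if time.monotonic() > deadline:
            raise TimeoutError(f"attempts took longer than {deadline_seconds} seconds")
        try:
            return fn()
        except TransientError:
            if attempt == max_retries - 1:
                raise  # out of attempts: surface the last error
            time.sleep(delay_seconds)
    # Only reachable when max_retries <= 0.
    raise ValueError("max_retries must be at least 1")
```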
def _make_rate_limited_confluence_method( self, name: str, credential_provider: CredentialsProviderInterface | None ) -> Callable[..., Any]: def wrapped_call(*args: Any, **kwargs: Any) -> Any: MAX_RETRIES = 5 TIMEOUT = 600 timeout_at = time.monotonic() + TIMEOUT for attempt in range(MAX_RETRIES): if time.monotonic() > timeout_at: raise TimeoutError( f"Confluence call attempts took longer than {TIMEOUT} seconds." ) # we're relying more on the client to rate limit itself # and applying our own retries in a more specific set of circumstances try: if credential_provider: with credential_provider: credentials, renewed = self._renew_credentials() if renewed: self._confluence = self._initialize_connection_helper( credentials, **self._kwargs ) attr = getattr(self._confluence, name, None) if attr is None: # The underlying Confluence client doesn't have this attribute raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) return attr(*args, **kwargs) else: attr = getattr(self._confluence, name, None) if attr is None: # The underlying Confluence client doesn't have this attribute raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) return attr(*args, **kwargs) except HTTPError as e: # delay_until is a time.monotonic() timestamp, not a duration delay_until = _handle_http_error(e, attempt, MAX_RETRIES) logger.warning( f"HTTPError in confluence call. Retrying in {max(0.0, delay_until - time.monotonic()):.1f} seconds..." ) while time.monotonic() < delay_until: # in the future, check a signal here to exit time.sleep(1) except AttributeError as e: # Some error within the Confluence library, unclear why it fails. # Users reported it to be intermittent, so just retry if attempt == MAX_RETRIES - 1: raise e logger.exception( "Confluence Client raised an AttributeError. Retrying..."
) time.sleep(5) return wrapped_call def __getattr__(self, name: str) -> Any: """Dynamically intercept attribute/method access.""" attr = getattr(self._confluence, name, None) if attr is None: # The underlying Confluence client doesn't have this attribute raise AttributeError( f"'{type(self).__name__}' object has no attribute '{name}'" ) # Non-callable attributes are returned as-is, without any credential refresh if not callable(attr): return attr # don't wrap private methods (names starting with "_") if name.startswith("_"): return attr # wrap the method with our retry handler rate_limited_method: Callable[..., Any] = ( self._make_rate_limited_confluence_method(name, self._credentials_provider) ) return rate_limited_method def _try_one_by_one_for_paginated_url( self, url_suffix: str, initial_start: int, limit: int, ) -> Generator[dict[str, Any], None, str | None]: """ Go through `limit` items, starting at `initial_start`, one by one (i.e. using `limit=1` for each call). If we encounter an error, we skip the item and try the next one. We yield the items we were able to retrieve successfully. Returns the expected next url_suffix. Returns None if it thinks we've hit the end. TODO (chris): make this yield failures as well as successes. TODO (chris): make this work for confluence cloud somehow.
""" if self._is_cloud: raise RuntimeError("This method is not implemented for Confluence Cloud.") found_empty_page = False temp_url_suffix = url_suffix for ind in range(limit): try: temp_url_suffix = update_param_in_path( url_suffix, "start", str(initial_start + ind) ) temp_url_suffix = update_param_in_path(temp_url_suffix, "limit", "1") logger.info(f"Making recovery confluence call to {temp_url_suffix}") raw_response = self.get(path=temp_url_suffix, advanced_mode=True) raw_response.raise_for_status() latest_results = raw_response.json().get("results", []) yield from latest_results if not latest_results: # no more results, break out of the loop logger.info( f"No results found for call '{temp_url_suffix}'. Stopping pagination." ) found_empty_page = True break except Exception: logger.exception( f"Error in confluence call to {temp_url_suffix}. Continuing." ) if found_empty_page: return None # if we got here, we successfully tried `limit` items return update_param_in_path(url_suffix, "start", str(initial_start + limit)) def _paginate_url( self, url_suffix: str, limit: int | None = None, # Called with the next url to use to get the next page next_page_callback: Callable[[str], None] | None = None, force_offset_pagination: bool = False, ) -> Iterator[dict[str, Any]]: """ This will paginate through the top level query. """ if not limit: limit = _DEFAULT_PAGINATION_LIMIT url_suffix = update_param_in_path(url_suffix, "limit", str(limit)) while url_suffix: logger.debug(f"Making confluence call to {url_suffix}") try: # Only pass params if they're not already in the URL to avoid duplicate # params accumulating. Confluence's _links.next already includes these.
params = {} if "body-format=" not in url_suffix: params["body-format"] = "atlas_doc_format" if "expand=" not in url_suffix: params["expand"] = "body.atlas_doc_format" raw_response = self.get( path=url_suffix, advanced_mode=True, params=params, ) except Exception as e: logger.exception(f"Error in confluence call to {url_suffix}") raise e try: raw_response.raise_for_status() except Exception as e: logger.warning(f"Error in confluence call to {url_suffix}") # If the problematic expansion is in the url, replace it # with the replacement expansion and try again # If that fails, raise the error if _PROBLEMATIC_EXPANSIONS in url_suffix: logger.warning( f"Replacing {_PROBLEMATIC_EXPANSIONS} with {_REPLACEMENT_EXPANSIONS} and trying again." ) url_suffix = url_suffix.replace( _PROBLEMATIC_EXPANSIONS, _REPLACEMENT_EXPANSIONS, ) continue # If we fail due to a 500, try one by one. # NOTE: this iterative approach only works for server, since cloud uses cursor-based # pagination if raw_response.status_code == 500 and not self._is_cloud: initial_start = get_start_param_from_url(url_suffix) if initial_start is None: # can't handle this if we don't have offset-based pagination raise # this will just yield the successful items from the batch new_url_suffix = yield from self._try_one_by_one_for_paginated_url( url_suffix, initial_start=initial_start, limit=limit, ) # this means we ran into an empty page if new_url_suffix is None: if next_page_callback: next_page_callback("") break url_suffix = new_url_suffix continue else: logger.exception( f"Error in confluence call to {url_suffix} \n" f"Raw Response Text: {raw_response.text} \n" f"Full Response: {raw_response.__dict__} \n" f"Error: {e} \n" ) raise try: next_response = raw_response.json() except Exception as e: logger.exception( f"Failed to parse response as JSON. Response: {raw_response.__dict__}" ) raise e # Yield the results individually. 
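# Illustrative sketch, not part of the Onyx source: on a 500 from an offset-paginated
# Server/Data Center endpoint, the code above falls back to `_try_one_by_one_for_paginated_url`,
# re-requesting the failed page one item at a time. `set_query_param` is a hypothetical
# stand-in for `update_param_in_path` (whose implementation is not shown here), and `fetch`
# is an assumed callable that returns a page's "results" list or raises on failure.
```python
from collections.abc import Callable, Iterator
from urllib.parse import parse_qsl, urlencode, urlsplit, urlunsplit


def set_query_param(url: str, key: str, value: str) -> str:
    """Return `url` with `key` set to `value`, dropping any existing values for `key`."""
    parts = urlsplit(url)
    params = [
        (k, v) for k, v in parse_qsl(parts.query, keep_blank_values=True) if k != key
    ]
    params.append((key, value))
    return urlunsplit(parts._replace(query=urlencode(params)))


def fetch_one_by_one(
    fetch: Callable[[str], list[dict]],
    url: str,
    initial_start: int,
    limit: int,
) -> Iterator[dict]:
    """Re-request a failed page with limit=1 per item, skipping items that still fail.

    Unlike the original, this sketch does not return the next-page URL.
    """
    for offset in range(limit):
        item_url = set_query_param(url, "start", str(initial_start + offset))
        item_url = set_query_param(item_url, "limit", "1")
        try:
            results = fetch(item_url)
        except Exception:
            continue  # skip this item and move on to the next offset
        if not results:
            return  # an empty page means we've run past the end
        yield from results
```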
results = cast(list[dict[str, Any]], next_response.get("results", [])) # Note 1: # Make sure we don't update the start by more than the amount # of results we were able to retrieve. The Confluence API has a # weird behavior where if you pass in a limit that is too large for # the configured server, it will artificially limit the amount of # results returned BUT will not apply this to the start parameter. # This will cause us to miss results. # # Note 2: # We specifically perform manual yielding (i.e., `for x in xs: yield x`) as opposed to using a `yield from xs` # because we *have to call the `next_page_callback`* prior to yielding the last element! # # If we did: # # ```py # yield from results # if next_page_callback: # next_page_callback(url_suffix) # ``` # # then the logic would fail since the iterator would finish (and the calling scope would exit out of its driving # loop) prior to the callback being called. old_url_suffix = url_suffix updated_start = get_start_param_from_url(old_url_suffix) url_suffix = cast(str, next_response.get("_links", {}).get("next", "")) for i, result in enumerate(results): updated_start += 1 if url_suffix and next_page_callback and i == len(results) - 1: # update the url if we're on the last result in the page if not self._is_cloud: # If confluence claims there are more results, we update the start param # based on how many results were returned and try again. url_suffix = update_param_in_path( url_suffix, "start", str(updated_start) ) # notify the caller of the new url next_page_callback(url_suffix) elif force_offset_pagination and i == len(results) - 1: url_suffix = update_param_in_path( old_url_suffix, "start", str(updated_start) ) yield result # we've observed that Confluence sometimes returns a next link despite giving # 0 results. This is a bug with Confluence, so we need to check for it and # stop paginating. 
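# Illustrative sketch, not part of the Onyx source: Note 2 above says the next-page callback
# must fire before the last element is yielded. The two hypothetical generators below show
# why: a consumer that stops right after the last item never resumes the generator, so any
# code placed after `yield from` never runs.
```python
from collections.abc import Callable, Iterator


def page_callback_first(
    results: list[str], next_url: str, on_next: Callable[[str], None]
) -> Iterator[str]:
    """Notify `on_next` before handing out the final item (the pattern used above)."""
    for i, result in enumerate(results):
        if next_url and i == len(results) - 1:
            on_next(next_url)
        yield result


def page_callback_after(
    results: list[str], next_url: str, on_next: Callable[[str], None]
) -> Iterator[str]:
    """Naive version: the callback only runs if the consumer asks for one more item."""
    yield from results
    if next_url:
        on_next(next_url)


def consume_until_last(items: Iterator[str], last: str) -> None:
    for item in items:
        if item == last:
            break  # the caller exits its loop as soon as it sees the last item
```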
if url_suffix and not results: logger.info( f"No results found for call '{old_url_suffix}' despite next link being present. Stopping pagination." ) break def build_cql_url(self, cql: str, expand: str | None = None) -> str: expand_string = f"&expand={expand}" if expand else "" return f"rest/api/content/search?cql={cql}{expand_string}" def paginated_cql_retrieval( self, cql: str, expand: str | None = None, limit: int | None = None, ) -> Iterator[dict[str, Any]]: """ The content/search endpoint can be used to fetch pages, attachments, and comments. """ cql_url = self.build_cql_url(cql, expand) yield from self._paginate_url(cql_url, limit) def paginated_page_retrieval( self, cql_url: str, limit: int, # Called with the next url to use to get the next page next_page_callback: Callable[[str], None] | None = None, ) -> Iterator[dict[str, Any]]: """ Error handling (and testing) wrapper for _paginate_url, because the current approach to page retrieval involves handling the next page links manually. """ try: yield from self._paginate_url( cql_url, limit=limit, next_page_callback=next_page_callback ) except Exception as e: logger.exception(f"Error in paginated_page_retrieval: {e}") raise e def cql_paginate_all_expansions( self, cql: str, expand: str | None = None, limit: int | None = None, ) -> Iterator[dict[str, Any]]: """ This function will paginate through the top level query first, then paginate through all of the expansions. 
""" def _traverse_and_update(data: dict | list) -> None: if isinstance(data, dict): next_url = data.get("_links", {}).get("next") if next_url and "results" in data: data["results"].extend(self._paginate_url(next_url, limit=limit)) for value in data.values(): _traverse_and_update(value) elif isinstance(data, list): for item in data: _traverse_and_update(item) for confluence_object in self.paginated_cql_retrieval(cql, expand, limit): _traverse_and_update(confluence_object) yield confluence_object def paginated_cql_user_retrieval( self, expand: str | None = None, limit: int | None = None, ) -> Iterator[ConfluenceUser]: """ The search/user endpoint can be used to fetch users. It's a separate endpoint from the content/search endpoint used only for users. Otherwise it's very similar to the content/search endpoint. """ # this is needed since there is a live bug with Confluence Server/Data Center # where not all users are returned by the APIs. This is a workaround needed until # that is patched. 
if self._confluence_user_profiles_override: yield from self._confluence_user_profiles_override elif self._is_cloud: cql = "type=user" url = "rest/api/search/user" expand_string = f"&expand={expand}" if expand else "" url += f"?cql={cql}{expand_string}" for user_result in self._paginate_url( url, limit, force_offset_pagination=True ): # Example response: # { # 'user': { # 'type': 'known', # 'accountId': '712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d', # 'accountType': 'atlassian', # 'email': 'chris@danswer.ai', # 'publicName': 'Chris Weaver', # 'profilePicture': { # 'path': '/wiki/aa-avatar/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d', # 'width': 48, # 'height': 48, # 'isDefault': False # }, # 'displayName': 'Chris Weaver', # 'isExternalCollaborator': False, # '_expandable': { # 'operations': '', # 'personalSpace': '' # }, # '_links': { # 'self': 'https://danswerai.atlassian.net/wiki/rest/api/user?accountId=712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d' # } # }, # 'title': 'Chris Weaver', # 'excerpt': '', # 'url': '/people/712020:35e60fbb-d0f3-4c91-b8c1-f2dd1d69462d', # 'breadcrumbs': [], # 'entityType': 'user', # 'iconCssClass': 'aui-icon content-type-profile', # 'lastModified': '2025-02-18T04:08:03.579Z', # 'score': 0.0 # } user = user_result["user"] yield ConfluenceUser( user_id=user["accountId"], username=None, display_name=user["displayName"], email=user.get("email"), type=user["accountType"], ) else: for user in self._paginate_url("rest/api/user/list", limit): yield ConfluenceUser( user_id=user["userKey"], username=user["username"], display_name=user["displayName"], email=None, type=user.get("type", "user"), ) def paginated_groups_by_user_retrieval( self, user_id: str, # accountId in Cloud, userKey in Server limit: int | None = None, ) -> Iterator[dict[str, Any]]: """ This is not an SQL like query. It's a confluence specific endpoint that can be used to fetch groups. 
""" user_field = "accountId" if self._is_cloud else "key" user_value = user_id # Server uses userKey (but calls it key during the API call), Cloud uses accountId user_query = f"{user_field}={quote(user_value)}" url = f"rest/api/user/memberof?{user_query}" yield from self._paginate_url(url, limit, force_offset_pagination=True) def paginated_groups_retrieval( self, limit: int | None = None, ) -> Iterator[dict[str, Any]]: """ This is not an SQL like query. It's a confluence specific endpoint that can be used to fetch groups. """ yield from self._paginate_url("rest/api/group", limit) def paginated_group_members_retrieval( self, group_name: str, limit: int | None = None, ) -> Iterator[dict[str, Any]]: """ This is not an SQL like query. It's a confluence specific endpoint that can be used to fetch the members of a group. THIS DOESN'T WORK FOR SERVER because it breaks when there is a slash in the group name. E.g. neither "test/group" nor "test%2Fgroup" works for confluence. """ group_name = quote(group_name) yield from self._paginate_url(f"rest/api/group/{group_name}/member", limit) def get_all_space_permissions_server( self, space_key: str, ) -> list[dict[str, Any]]: """ This is a confluence server/data center specific method that can be used to fetch the permissions of a space. NOTE: This uses the JSON-RPC API which is the ONLY way to get space permissions on Confluence Server/Data Center. The REST API equivalent (expand=permissions) is Cloud-only and not available on Data Center as of version 8.9.x. 
If this fails with 401 Unauthorized, the customer needs to enable JSON-RPC: Confluence Admin -> General Configuration -> Further Configuration -> Enable "Remote API (XML-RPC & SOAP)" """ url = "rpc/json-rpc/confluenceservice-v2" data = { "jsonrpc": "2.0", "method": "getSpacePermissionSets", "id": 7, "params": [space_key], } try: response = self.post(url, data=data) except HTTPError as e: if e.response is not None and e.response.status_code == 401: raise HTTPError( "Unauthorized (401) when calling JSON-RPC API for space permissions. " "This is likely because the Remote API is disabled. " "To fix: Confluence Admin -> General Configuration -> Further Configuration " "-> Enable 'Remote API (XML-RPC & SOAP)'", response=e.response, ) from e raise logger.debug(f"jsonrpc response: {response}") if not response.get("result"): logger.warning( f"No jsonrpc response for space permissions for space {space_key}\nResponse: {response}" ) return response.get("result", []) def get_current_user(self, expand: str | None = None) -> Any: """ Implements a method that isn't in the third party client. Get information about the current user :param expand: OPTIONAL expand for get status of user. Possible param is "status". 
Results are "Active, Deactivated" :return: Returns the user details """ from atlassian.errors import ApiPermissionError # type:ignore url = "rest/api/user/current" params = {} if expand: params["expand"] = expand try: response = self.get(url, params=params) except HTTPError as e: if e.response.status_code == 403: raise ApiPermissionError( "The calling user does not have permission", reason=e ) raise return response def get_user_email_from_username__server( confluence_client: OnyxConfluence, user_name: str ) -> str | None: global _USER_EMAIL_CACHE if _USER_EMAIL_CACHE.get(user_name) is None: try: response = confluence_client.get_mobile_parameters(user_name) email = response.get("email") except HTTPError as e: status_code = e.response.status_code if e.response is not None else "N/A" logger.warning( f"Failed to get confluence email for {user_name}: HTTP {status_code} - {e}" ) # For now, we'll just return None and log a warning. This means # we will keep retrying to get the email every group sync. email = None except Exception as e: logger.warning( f"Failed to get confluence email for {user_name}: {type(e).__name__} - {e}" ) email = None _USER_EMAIL_CACHE[user_name] = email return _USER_EMAIL_CACHE[user_name] def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str: """Get Confluence Display Name based on the account-id or userkey value Args: user_id (str): The user id (i.e: the account-id or userkey) confluence_client (Confluence): The Confluence Client Returns: str: The User Display Name. 
'Unknown User' if the user is deactivated or not found """ global _USER_ID_TO_DISPLAY_NAME_CACHE if _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) is None: try: result = confluence_client.get_user_details_by_userkey(user_id) found_display_name = result.get("displayName") except Exception: found_display_name = None if not found_display_name: try: result = confluence_client.get_user_details_by_accountid(user_id) found_display_name = result.get("displayName") except Exception: found_display_name = None _USER_ID_TO_DISPLAY_NAME_CACHE[user_id] = found_display_name return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND def sanitize_attachment_title(title: str) -> str: """ Sanitize the attachment title to be a valid HTML attribute. """ return title.replace("<", "_").replace(">", "_").replace(" ", "_").replace(":", "_") def extract_text_from_confluence_html( confluence_client: OnyxConfluence, confluence_object: dict[str, Any], fetched_titles: set[str], ) -> str: """Parse a Confluence HTML page and replace each user ID with the user's display name Args: confluence_object (dict): The confluence object as a dict confluence_client (Confluence): Confluence client fetched_titles (set[str]): The titles of the pages that have already been fetched Returns: str: loaded and formatted Confluence page """ body = confluence_object["body"] object_html = body.get("storage", body.get("view", {})).get("value") soup = bs4.BeautifulSoup(object_html, "html.parser") _remove_macro_stylings(soup=soup) for user in soup.findAll("ri:user"): user_id = ( user.attrs["ri:account-id"] if "ri:account-id" in user.attrs else user.get("ri:userkey") ) if not user_id: logger.warning( f"ri:userkey not found in ri:user element.
Found attrs: {user.attrs}" ) continue # Prefix with @ to mark a user mention, which is clearer for the LLM user.replaceWith("@" + _get_user(confluence_client, user_id)) for html_page_reference in soup.findAll("ac:structured-macro"): # Only process "include" macros, i.e. pages embedded within this page if html_page_reference.attrs.get("ac:name") != "include": continue page_data = html_page_reference.find("ri:page") if not page_data: logger.warning( f"Skipping retrieval of {html_page_reference} because page data is missing" ) continue page_title = page_data.attrs.get("ri:content-title") if not page_title: # only fetch pages that have a title logger.warning( f"Skipping retrieval of {html_page_reference} because it has no title" ) continue if page_title in fetched_titles: # prevent recursive fetching of pages logger.debug(f"Skipping {page_title} because it has already been fetched") continue fetched_titles.add(page_title) # Wrap this in a try-except because there are some pages that might not exist try: page_query = f"type=page and title='{quote(page_title)}'" page_contents: dict[str, Any] | None = None # Confluence only enforces title uniqueness within a space, and this query is not # scoped to a space, so take the first match for page in confluence_client.paginated_cql_retrieval( cql=page_query, expand="body.storage.value", limit=1, ): page_contents = page break except Exception as e: logger.warning( f"Error getting page contents for object {confluence_object}: {e}" ) continue if not page_contents: continue text_from_page = extract_text_from_confluence_html( confluence_client=confluence_client, confluence_object=page_contents, fetched_titles=fetched_titles, ) html_page_reference.replaceWith(text_from_page) for html_link_body in soup.findAll("ac:link-body"): # This extracts the text from inline links in the page so they can be # represented in the document text as plain text try: text_from_link = html_link_body.text html_link_body.replaceWith(f"(LINK TEXT: {text_from_link})") 
except Exception as e: logger.warning(f"Error processing
ac:link-body: {e}") for html_attachment in soup.findAll("ri:attachment"): # This extracts the text from inline attachments in the page so they can be # represented in the document text as plain text try: html_attachment.replaceWith( f"{sanitize_attachment_title(html_attachment.attrs['ri:filename'])}" ) # to be replaced later except Exception as e: logger.warning(f"Error processing ri:attachment: {e}") return format_document_soup(soup) def _remove_macro_stylings(soup: bs4.BeautifulSoup) -> None: for macro_root in soup.findAll("ac:structured-macro"): if not isinstance(macro_root, bs4.Tag): continue macro_styling = macro_root.find(name="ac:parameter", attrs={"ac:name": "page"}) if not macro_styling or not isinstance(macro_styling, bs4.Tag): continue macro_styling.extract() ================================================ FILE: backend/onyx/connectors/confluence/user_profile_override.py ================================================ from onyx.connectors.confluence.models import ConfluenceUser def process_confluence_user_profiles_override( confluence_user_email_override: list[dict[str, str]], ) -> list[ConfluenceUser]: return [ ConfluenceUser( user_id=override["user_id"], # username is not returned by the Confluence Server API anyway username=override["username"], display_name=override["display_name"], email=override["email"], type=override["type"], ) for override in confluence_user_email_override if override is not None ] ================================================ FILE: backend/onyx/connectors/confluence/utils.py ================================================ import math import time from collections.abc import Callable from datetime import datetime from datetime import timedelta from datetime import timezone from io import BytesIO from pathlib import Path from typing import Any from typing import cast from typing import TYPE_CHECKING from typing import TypeVar from urllib.parse import parse_qs from urllib.parse import quote from urllib.parse import urljoin
from urllib.parse import urlparse import requests from pydantic import BaseModel from onyx.configs.app_configs import ( CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD, ) from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD from onyx.configs.constants import FileOrigin from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.file_types import OnyxMimeTypes from onyx.file_processing.image_utils import store_image_and_create_section from onyx.utils.logger import setup_logger if TYPE_CHECKING: from onyx.connectors.confluence.onyx_confluence import OnyxConfluence logger = setup_logger() CONFLUENCE_OAUTH_TOKEN_URL = "https://auth.atlassian.com/oauth/token" RATE_LIMIT_MESSAGE_LOWERCASE = "Rate limit exceeded".lower() class TokenResponse(BaseModel): access_token: str expires_in: int token_type: str refresh_token: str scope: str def validate_attachment_filetype( attachment: dict[str, Any], ) -> bool: """ Validates if the attachment is a supported file type. """ media_type = attachment.get("metadata", {}).get("mediaType", "") if media_type.startswith("image/"): return media_type in OnyxMimeTypes.IMAGE_MIME_TYPES # For non-image files, check if we support the extension title = attachment.get("title", "") extension = get_file_ext(title) return extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS class AttachmentProcessingResult(BaseModel): """ A container for results after processing a Confluence attachment. 'text' is the textual content of the attachment. 'file_name' is the final file name used in FileStore to store the content. 'error' holds an exception or string if something failed. 
""" text: str | None file_name: str | None error: str | None = None def _make_attachment_link( confluence_client: "OnyxConfluence", attachment: dict[str, Any], parent_content_id: str | None = None, ) -> str | None: download_link = "" if "api.atlassian.com" in confluence_client.url: # https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get if not parent_content_id: logger.warning( "parent_content_id is required to download attachments from Confluence Cloud!" ) return None download_link = ( confluence_client.url + f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download" ) else: download_link = confluence_client.url + attachment["_links"]["download"] return download_link def process_attachment( confluence_client: "OnyxConfluence", attachment: dict[str, Any], parent_content_id: str | None, allow_images: bool, ) -> AttachmentProcessingResult: """ Processes a Confluence attachment. If it's a document, extracts text, or if it's an image, stores it for later analysis. Returns a structured result. 
""" try: # Get the media type from the attachment metadata media_type: str = attachment.get("metadata", {}).get("mediaType", "") # Validate the attachment type if not validate_attachment_filetype(attachment): return AttachmentProcessingResult( text=None, file_name=None, error=f"Unsupported file type: {media_type}", ) attachment_link = _make_attachment_link( confluence_client, attachment, parent_content_id ) if not attachment_link: return AttachmentProcessingResult( text=None, file_name=None, error="Failed to make attachment link" ) attachment_size = attachment["extensions"]["fileSize"] if media_type.startswith("image/"): if not allow_images: return AttachmentProcessingResult( text=None, file_name=None, error="Image downloading is not enabled", ) else: if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD: logger.warning( f"Skipping {attachment_link} due to size. " f"size={attachment_size} " f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}" ) return AttachmentProcessingResult( text=None, file_name=None, error=f"Attachment too large: {attachment_size} bytes", ) logger.info( f"Downloading attachment: title={attachment['title']} length={attachment_size} link={attachment_link}" ) # Download the attachment resp: requests.Response = confluence_client._session.get(attachment_link) if resp.status_code != 200: logger.warning( f"Failed to fetch {attachment_link} with status code {resp.status_code}" ) return AttachmentProcessingResult( text=None, file_name=None, error=f"Attachment download status code is {resp.status_code}", ) raw_bytes = resp.content if not raw_bytes: return AttachmentProcessingResult( text=None, file_name=None, error="Attachment content is empty" ) # Process image attachments if media_type.startswith("image/"): return _process_image_attachment( confluence_client, attachment, raw_bytes, media_type ) # Process document attachments try: text = extract_file_text( file=BytesIO(raw_bytes), file_name=attachment["title"], ) # Skip if the
text is too long if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD: return AttachmentProcessingResult( text=None, file_name=None, error=f"Attachment text too long: {len(text)} chars", ) return AttachmentProcessingResult(text=text, file_name=None, error=None) except Exception as e: return AttachmentProcessingResult( text=None, file_name=None, error=f"Failed to extract text: {e}" ) except Exception as e: return AttachmentProcessingResult( text=None, file_name=None, error=f"Failed to process attachment: {e}" ) def _process_image_attachment( confluence_client: "OnyxConfluence", # noqa: ARG001 attachment: dict[str, Any], raw_bytes: bytes, media_type: str, ) -> AttachmentProcessingResult: """Process an image attachment by saving it without generating a summary.""" try: # Use the standardized image storage and section creation section, file_name = store_image_and_create_section( image_data=raw_bytes, file_id=Path(attachment["id"]).name, display_name=attachment["title"], media_type=media_type, file_origin=FileOrigin.CONNECTOR, ) logger.info(f"Stored image attachment with file name: {file_name}") # Return empty text but include the file_name for later processing return AttachmentProcessingResult(text="", file_name=file_name, error=None) except Exception as e: msg = f"Image storage failed for {attachment['title']}: {e}" logger.error(msg, exc_info=e) return AttachmentProcessingResult(text=None, file_name=None, error=msg) def convert_attachment_to_content( confluence_client: "OnyxConfluence", attachment: dict[str, Any], page_id: str, allow_images: bool, ) -> tuple[str | None, str | None] | None: """ Facade function which: 1. Validates attachment type 2. Extracts content or stores image for later processing 3. 
Returns (content_text, stored_file_name) or None if we should skip it """ media_type = attachment.get("metadata", {}).get("mediaType", "") # Quick check for unsupported types: if media_type.startswith("video/") or media_type == "application/gliffy+json": logger.warning( f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}" ) return None result = process_attachment(confluence_client, attachment, page_id, allow_images) if result.error is not None: logger.warning( f"Attachment {attachment['title']} encountered error: {result.error}" ) return None # Return the text and the file name return result.text, result.file_name def build_confluence_document_id( base_url: str, content_url: str, is_cloud: bool ) -> str: """For confluence, the document id is the page url for a page based document or the attachment download url for an attachment based document Args: base_url (str): The base url of the Confluence instance content_url (str): The url of the page or attachment download url Returns: str: The document id """ # NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't # end with "/" because it believes that makes it a file. 
final_url = base_url.rstrip("/") + "/" if is_cloud and not final_url.endswith("/wiki/"): final_url = urljoin(final_url, "wiki") + "/" final_url = urljoin(final_url, content_url.lstrip("/")) return final_url def datetime_from_string(datetime_string: str) -> datetime: datetime_object = datetime.fromisoformat(datetime_string) if datetime_object.tzinfo is None: # If no timezone info, assume it is UTC datetime_object = datetime_object.replace(tzinfo=timezone.utc) else: # If not in UTC, translate it datetime_object = datetime_object.astimezone(timezone.utc) return datetime_object def confluence_refresh_tokens( client_id: str, client_secret: str, cloud_id: str, refresh_token: str ) -> dict[str, Any]: # rotate the refresh and access token # Note that access tokens are only good for an hour in confluence cloud, # so we're going to have problems if the connector runs for longer # https://developer.atlassian.com/cloud/confluence/oauth-2-3lo-apps/#use-a-refresh-token-to-get-another-access-token-and-refresh-token-pair response = requests.post( CONFLUENCE_OAUTH_TOKEN_URL, headers={"Content-Type": "application/x-www-form-urlencoded"}, data={ "grant_type": "refresh_token", "client_id": client_id, "client_secret": client_secret, "refresh_token": refresh_token, }, ) try: token_response = TokenResponse.model_validate_json(response.text) except Exception: raise RuntimeError("Confluence Cloud token refresh failed.") now = datetime.now(timezone.utc) expires_at = now + timedelta(seconds=token_response.expires_in) new_credentials: dict[str, Any] = {} new_credentials["confluence_access_token"] = token_response.access_token new_credentials["confluence_refresh_token"] = token_response.refresh_token new_credentials["created_at"] = now.isoformat() new_credentials["expires_at"] = expires_at.isoformat() new_credentials["expires_in"] = token_response.expires_in new_credentials["scope"] = token_response.scope new_credentials["cloud_id"] = cloud_id return new_credentials F = TypeVar("F", 
bound=Callable[..., Any]) # https://developer.atlassian.com/cloud/confluence/rate-limiting/ # this uses the native rate limiting option provided by the # confluence client and otherwise applies a simpler set of error handling def handle_confluence_rate_limit(confluence_call: F) -> F: def wrapped_call(*args: list[Any], **kwargs: Any) -> Any: MAX_RETRIES = 5 TIMEOUT = 600 timeout_at = time.monotonic() + TIMEOUT for attempt in range(MAX_RETRIES): if time.monotonic() > timeout_at: raise TimeoutError( f"Confluence call attempts took longer than {TIMEOUT} seconds." ) try: # we're relying more on the client to rate limit itself # and applying our own retries in a more specific set of circumstances return confluence_call(*args, **kwargs) except requests.HTTPError as e: delay_until = _handle_http_error(e, attempt, MAX_RETRIES) logger.warning( f"HTTPError in confluence call. Retrying in {max(0.0, delay_until - time.monotonic()):.0f} seconds..." ) while time.monotonic() < delay_until: # in the future, check a signal here to exit time.sleep(1) except AttributeError as e: # Some error within the Confluence library, unclear why it fails. # Users reported it to be intermittent, so just retry if attempt == MAX_RETRIES - 1: raise e logger.exception( "Confluence Client raised an AttributeError. Retrying..." ) time.sleep(5) return cast(F, wrapped_call) def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int: MIN_DELAY = 2 MAX_DELAY = 60 STARTING_DELAY = 5 BACKOFF = 2 # Check if the response or headers are None to avoid potential AttributeError if e.response is None or e.response.headers is None: logger.warning("HTTPError with `None` as response or as headers") raise e # Confluence Server returns 403 when rate limited if e.response.status_code == 403: FORBIDDEN_MAX_RETRY_ATTEMPTS = 7 FORBIDDEN_RETRY_DELAY = 10 if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS: logger.warning( "403 error. This sometimes happens when we hit " f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
) return math.ceil(time.monotonic() + FORBIDDEN_RETRY_DELAY) raise e if e.response.status_code >= 500: if attempt >= max_retries - 1: raise e delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY) logger.warning( f"Server error {e.response.status_code}. " f"Retrying in {delay} seconds (attempt {attempt + 1})..." ) return math.ceil(time.monotonic() + delay) if ( e.response.status_code != 429 and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower() ): raise e retry_after = None retry_after_header = e.response.headers.get("Retry-After") if retry_after_header is not None: try: retry_after = int(retry_after_header) if retry_after > MAX_DELAY: logger.warning( f"Clamping retry_after from {retry_after} to {MAX_DELAY} seconds..." ) retry_after = MAX_DELAY if retry_after < MIN_DELAY: retry_after = MIN_DELAY except ValueError: pass if retry_after is not None: logger.warning( f"Rate limiting with retry header. Retrying after {retry_after} seconds..." ) delay = retry_after else: logger.warning( "Rate limiting without retry header. Retrying with exponential backoff..." ) delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY) delay_until = math.ceil(time.monotonic() + delay) return delay_until def get_single_param_from_url(url: str, param: str) -> str | None: """Get a parameter from a url""" parsed_url = urlparse(url) return parse_qs(parsed_url.query).get(param, [None])[0] def get_start_param_from_url(url: str) -> int: """Get the start parameter from a url""" start_str = get_single_param_from_url(url, "start") return int(start_str) if start_str else 0 def update_param_in_path(path: str, param: str, value: str) -> str: """Update a parameter in a path. Path should look something like: /api/rest/users?start=0&limit=10 """ parsed_url = urlparse(path) query_params = parse_qs(parsed_url.query) query_params[param] = [value] return ( path.split("?")[0] + "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items()) ) ================================================ FILE: backend/onyx/connectors/connector_runner.py ================================================ import sys import time from collections.abc import Generator from datetime import datetime from typing import Generic from typing import TypeVar from onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.utils.logger import setup_logger logger = setup_logger() TimeRange = tuple[datetime, datetime] CT = TypeVar("CT", bound=ConnectorCheckpoint) def batched_doc_ids( checkpoint_connector_generator: CheckpointOutput[CT], batch_size: int, ) -> Generator[set[str], None, None]: batch: set[str] = set() for document, hierarchy_node, failure, next_checkpoint in CheckpointOutputWrapper[ CT ]()(checkpoint_connector_generator): if document is not None: batch.add(document.id) elif ( failure and failure.failed_document and failure.failed_document.document_id ): batch.add(failure.failed_document.document_id) # HierarchyNodes don't have IDs that need to be batched for doc processing if len(batch) >= batch_size: yield batch batch = set() if len(batch) > 0: yield batch class CheckpointOutputWrapper(Generic[CT]): """ Wraps a CheckpointOutput generator to give things back in a more digestible format, specifically for Document outputs. The connector format is easier for the connector implementor (e.g. 
it enforces exactly one new checkpoint is returned AND that the checkpoint is at the end), thus the different formats. """ def __init__(self) -> None: self.next_checkpoint: CT | None = None def __call__( self, checkpoint_connector_generator: CheckpointOutput[CT], ) -> Generator[ tuple[ Document | None, HierarchyNode | None, ConnectorFailure | None, CT | None ], None, None, ]: # grabs the final return value and stores it in the `next_checkpoint` variable def _inner_wrapper( checkpoint_connector_generator: CheckpointOutput[CT], ) -> CheckpointOutput[CT]: self.next_checkpoint = yield from checkpoint_connector_generator return self.next_checkpoint # not used for item in _inner_wrapper(checkpoint_connector_generator): if isinstance(item, Document): yield item, None, None, None elif isinstance(item, HierarchyNode): yield None, item, None, None elif isinstance(item, ConnectorFailure): yield None, None, item, None else: raise ValueError(f"Invalid connector output type: {type(item)}") if self.next_checkpoint is None: raise RuntimeError( "Checkpoint is None. This should never happen - the connector should always return a checkpoint." 
) yield None, None, None, self.next_checkpoint class ConnectorRunner(Generic[CT]): """ Handles: - Batching - Additional exception logging - Combining different connector types to a single interface """ def __init__( self, connector: BaseConnector, batch_size: int, # cannot be True for non-checkpointed connectors include_permissions: bool, time_range: TimeRange | None = None, ): if not isinstance(connector, CheckpointedConnector) and include_permissions: raise ValueError( "include_permissions cannot be True for non-checkpointed connectors" ) self.connector = connector self.time_range = time_range self.batch_size = batch_size self.include_permissions = include_permissions self.doc_batch: list[Document] = [] self.hierarchy_node_batch: list[HierarchyNode] = [] def run(self, checkpoint: CT) -> Generator[ tuple[ list[Document] | None, list[HierarchyNode] | None, ConnectorFailure | None, CT | None, ], None, None, ]: """ Yields batches of Documents, HierarchyNodes, failures, and checkpoints. Returns tuples of: - (doc_batch, None, None, None) - batch of documents - (None, hierarchy_batch, None, None) - batch of hierarchy nodes - (None, None, failure, None) - a connector failure - (None, None, None, checkpoint) - new checkpoint """ try: if isinstance(self.connector, CheckpointedConnector): if self.time_range is None: raise ValueError("time_range is required for CheckpointedConnector") start = time.monotonic() if self.include_permissions: if not isinstance( self.connector, CheckpointedConnectorWithPermSync ): raise ValueError( "Connector does not support permission syncing" ) load_from_checkpoint = ( self.connector.load_from_checkpoint_with_perm_sync ) else: load_from_checkpoint = self.connector.load_from_checkpoint checkpoint_connector_generator = load_from_checkpoint( start=self.time_range[0].timestamp(), end=self.time_range[1].timestamp(), checkpoint=checkpoint, ) next_checkpoint: CT | None = None # this is guaranteed to always run at least once with next_checkpoint being 
non-None for ( document, hierarchy_node, failure, next_checkpoint, ) in CheckpointOutputWrapper[CT]()(checkpoint_connector_generator): if document is not None: self.doc_batch.append(document) if hierarchy_node is not None: self.hierarchy_node_batch.append(hierarchy_node) if failure is not None: yield None, None, failure, None # Yield hierarchy nodes batch if it reaches batch_size # (yield nodes before docs to maintain parent-before-child invariant) if len(self.hierarchy_node_batch) >= self.batch_size: yield None, self.hierarchy_node_batch, None, None self.hierarchy_node_batch = [] # Yield document batch if it reaches batch_size # First flush any pending hierarchy nodes to ensure parents exist if len(self.doc_batch) >= self.batch_size: if len(self.hierarchy_node_batch) > 0: yield None, self.hierarchy_node_batch, None, None self.hierarchy_node_batch = [] yield self.doc_batch, None, None, None self.doc_batch = [] # yield remaining hierarchy nodes first (parents before children) if len(self.hierarchy_node_batch) > 0: yield None, self.hierarchy_node_batch, None, None self.hierarchy_node_batch = [] # yield remaining documents if len(self.doc_batch) > 0: yield self.doc_batch, None, None, None self.doc_batch = [] yield None, None, None, next_checkpoint logger.debug( f"Connector took {time.monotonic() - start} seconds to get to the next checkpoint." 
) else: finished_checkpoint = self.connector.build_dummy_checkpoint() finished_checkpoint.has_more = False if isinstance(self.connector, PollConnector): if self.time_range is None: raise ValueError("time_range is required for PollConnector") for batch in self.connector.poll_source( start=self.time_range[0].timestamp(), end=self.time_range[1].timestamp(), ): docs, nodes = self._separate_batch(batch) if nodes: yield None, nodes, None, None if docs: yield docs, None, None, None yield None, None, None, finished_checkpoint elif isinstance(self.connector, LoadConnector): for batch in self.connector.load_from_state(): docs, nodes = self._separate_batch(batch) if nodes: yield None, nodes, None, None if docs: yield docs, None, None, None yield None, None, None, finished_checkpoint else: raise ValueError(f"Invalid connector. type: {type(self.connector)}") except Exception: exc_type, _, exc_traceback = sys.exc_info() # Traverse the traceback to find the last frame where the exception was raised tb = exc_traceback if tb is None: logger.error("No traceback found for exception") raise while tb.tb_next: tb = tb.tb_next # Move to the next frame in the traceback # Get the local variables from the frame where the exception occurred local_vars = tb.tb_frame.f_locals local_vars_str = "\n".join( f"{key}: {value}" for key, value in local_vars.items() ) logger.error( f"Error in connector. 
type: {exc_type};\nlocal_vars below -> \n{local_vars_str[:1024]}" ) raise def _separate_batch( self, batch: list[Document | HierarchyNode] ) -> tuple[list[Document], list[HierarchyNode]]: """Separate a mixed batch into Documents and HierarchyNodes.""" docs: list[Document] = [] nodes: list[HierarchyNode] = [] for item in batch: if isinstance(item, Document): docs.append(item) elif isinstance(item, HierarchyNode): nodes.append(item) return docs, nodes ================================================ FILE: backend/onyx/connectors/credentials_provider.py ================================================ import uuid from types import TracebackType from typing import Any from redis.lock import Lock as RedisLock from sqlalchemy import select from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.models import Credential from onyx.redis.redis_pool import get_redis_client class OnyxDBCredentialsProvider( CredentialsProviderInterface["OnyxDBCredentialsProvider"] ): """Implementation to allow the connector to callback and update credentials in the db. Required in cases where credentials can rotate while the connector is running. 
""" LOCK_TTL = 900 # TTL of the lock def __init__(self, tenant_id: str, connector_name: str, credential_id: int): self._tenant_id = tenant_id self._connector_name = connector_name self._credential_id = credential_id self.redis_client = get_redis_client(tenant_id=tenant_id) # lock used to prevent overlapping renewal of credentials self.lock_key = f"da_lock:connector:{connector_name}:credential_{credential_id}" self._lock: RedisLock = self.redis_client.lock(self.lock_key, self.LOCK_TTL) def __enter__(self) -> "OnyxDBCredentialsProvider": acquired = self._lock.acquire(blocking_timeout=self.LOCK_TTL) if not acquired: raise RuntimeError(f"Could not acquire lock for key: {self.lock_key}") return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: """Release the lock when exiting the context.""" if self._lock and self._lock.owned(): self._lock.release() def get_tenant_id(self) -> str | None: return self._tenant_id def get_provider_key(self) -> str: return str(self._credential_id) def get_credentials(self) -> dict[str, Any]: with get_session_with_tenant(tenant_id=self._tenant_id) as db_session: credential = db_session.execute( select(Credential).where(Credential.id == self._credential_id) ).scalar_one_or_none() if credential is None: raise ValueError( f"No credential found: credential={self._credential_id}" ) if credential.credential_json is None: return {} return credential.credential_json.get_value(apply_mask=False) def set_credentials(self, credential_json: dict[str, Any]) -> None: with get_session_with_tenant(tenant_id=self._tenant_id) as db_session: try: credential = db_session.execute( select(Credential) .where(Credential.id == self._credential_id) .with_for_update() ).scalar_one_or_none() if credential is None: raise ValueError( f"No credential found: credential={self._credential_id}" ) credential.credential_json = credential_json # type: ignore[assignment] db_session.commit() except Exception:
db_session.rollback() raise def is_dynamic(self) -> bool: return True class OnyxStaticCredentialsProvider( CredentialsProviderInterface["OnyxStaticCredentialsProvider"] ): """Implementation (a very simple one!) to handle static credentials.""" def __init__( self, tenant_id: str | None, connector_name: str, credential_json: dict[str, Any], ): self._tenant_id = tenant_id self._connector_name = connector_name self._credential_json = credential_json self._provider_key = str(uuid.uuid4()) def __enter__(self) -> "OnyxStaticCredentialsProvider": return self def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: pass def get_tenant_id(self) -> str | None: return self._tenant_id def get_provider_key(self) -> str: return self._provider_key def get_credentials(self) -> dict[str, Any]: return self._credential_json def set_credentials(self, credential_json: dict[str, Any]) -> None: self._credential_json = credential_json def is_dynamic(self) -> bool: return False ================================================ FILE: backend/onyx/connectors/cross_connector_utils/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/cross_connector_utils/miscellaneous_utils.py ================================================ import re from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import Any from typing import TypeVar from urllib.parse import urljoin from urllib.parse import urlparse import requests from dateutil.parser import parse from dateutil.parser import ParserError from onyx.configs.app_configs import CONNECTOR_LOCALHOST_OVERRIDE from onyx.configs.constants import DocumentSource from onyx.configs.constants import IGNORE_FOR_QA from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import OnyxMetadata from 
onyx.utils.logger import setup_logger from onyx.utils.text_processing import is_valid_email T = TypeVar("T") U = TypeVar("U") logger = setup_logger() def datetime_to_utc(dt: datetime) -> datetime: if dt.tzinfo is None or dt.tzinfo.utcoffset(dt) is None: dt = dt.replace(tzinfo=timezone.utc) return dt.astimezone(timezone.utc) def time_str_to_utc(datetime_str: str) -> datetime: # Remove all timezone abbreviations in parentheses normalized = re.sub(r"\([A-Z]+\)", "", datetime_str).strip() # Remove any remaining parentheses and their contents normalized = re.sub(r"\(.*?\)", "", normalized).strip() candidates: list[str] = [normalized] # Some sources (e.g. Gmail) may prefix the value with labels like "Date:" label_stripped = re.sub( r"^\s*[A-Za-z][A-Za-z\s_-]*:\s*", "", normalized, count=1 ).strip() if label_stripped and label_stripped != normalized: candidates.append(label_stripped) # Fix common format issues (e.g. "0000" => "+0000") for candidate in list(candidates): if " 0000" in candidate: fixed = candidate.replace(" 0000", " +0000") if fixed not in candidates: candidates.append(fixed) last_exception: Exception | None = None for candidate in candidates: try: dt = parse(candidate) return datetime_to_utc(dt) except (ValueError, ParserError) as exc: last_exception = exc if last_exception is not None: raise last_exception # Fallback in case parsing failed without raising (should not happen) raise ValueError(f"Unable to parse datetime string: {datetime_str}") # TODO: use this function in other connectors def datetime_from_utc_timestamp(timestamp: int) -> datetime: """Convert a Unix timestamp to a datetime object in UTC""" return datetime.fromtimestamp(timestamp, tz=timezone.utc) def basic_expert_info_representation(info: BasicExpertInfo) -> str | None: if info.first_name and info.last_name: middle = f" {info.middle_initial}" if info.middle_initial else "" return f"{info.first_name}{middle} {info.last_name}" if info.display_name: return info.display_name if info.email and is_valid_email(info.email): return info.email if
info.first_name: return info.first_name return None def get_experts_stores_representations( experts: list[BasicExpertInfo] | None, ) -> list[str] | None: """Gets string representations of experts supplied. If an expert cannot be represented as a string, it is omitted from the result. """ if not experts: return None reps: list[str | None] = [ basic_expert_info_representation(owner) for owner in experts ] return [owner for owner in reps if owner is not None] def process_in_batches( objects: list[T], process_function: Callable[[T], U], batch_size: int ) -> Iterator[list[U]]: for i in range(0, len(objects), batch_size): yield [process_function(obj) for obj in objects[i : i + batch_size]] def get_metadata_keys_to_ignore() -> list[str]: return [IGNORE_FOR_QA] def _parse_document_source(connector_type: Any) -> DocumentSource | None: if connector_type is None: return None if isinstance(connector_type, DocumentSource): return connector_type if not isinstance(connector_type, str): logger.warning(f"Invalid connector_type type: {type(connector_type).__name__}") return None normalized = re.sub(r"[\s\-]+", "_", connector_type.strip().lower()) try: return DocumentSource(normalized) except ValueError: logger.warning( f"Invalid connector_type value: '{connector_type}' (normalized: '{normalized}')" ) return None def process_onyx_metadata( metadata: dict[str, Any], ) -> tuple[OnyxMetadata, dict[str, Any]]: """ Users may set Onyx metadata and custom tags in text files. https://docs.onyx.app/admins/connectors/official/file Any unrecognized fields are treated as custom tags. 
""" p_owner_names = metadata.get("primary_owners") p_owners = ( [BasicExpertInfo(display_name=name) for name in p_owner_names] if p_owner_names else None ) s_owner_names = metadata.get("secondary_owners") s_owners = ( [BasicExpertInfo(display_name=name) for name in s_owner_names] if s_owner_names else None ) source_type = _parse_document_source(metadata.get("connector_type")) dt_str = metadata.get("doc_updated_at") doc_updated_at = time_str_to_utc(dt_str) if dt_str else None return ( OnyxMetadata( document_id=metadata.get("id"), source_type=source_type, link=metadata.get("link"), file_display_name=metadata.get("file_display_name"), title=metadata.get("title"), primary_owners=p_owners, secondary_owners=s_owners, doc_updated_at=doc_updated_at, ), { k: v for k, v in metadata.items() if k not in [ "document_id", "time_updated", "doc_updated_at", "link", "primary_owners", "secondary_owners", "filename", "file_display_name", "title", "connector_type", "pdf_password", "mime_type", ] }, ) def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str: if CONNECTOR_LOCALHOST_OVERRIDE: # Used for development base_domain = CONNECTOR_LOCALHOST_OVERRIDE return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}" def is_atlassian_date_error(e: Exception) -> bool: return "field 'updated' is invalid" in str(e) def get_cloudId(base_url: str) -> str: tenant_info_url = urljoin(base_url, "/_edge/tenant_info") response = requests.get(tenant_info_url, timeout=10) response.raise_for_status() return response.json()["cloudId"] def scoped_url(url: str, product: str) -> str: parsed = urlparse(url) base_url = parsed.scheme + "://" + parsed.netloc cloud_id = get_cloudId(base_url) return f"https://api.atlassian.com/ex/{product}/{cloud_id}{parsed.path}" ================================================ FILE: backend/onyx/connectors/cross_connector_utils/rate_limit_wrapper.py ================================================ import time from collections.abc import Callable 
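# Illustrative sketch (hypothetical class, not part of Onyx): the sliding-window
# bookkeeping that `_RateLimitDecorator` below performs, with the clock passed in
# explicitly so it can be exercised deterministically. A call is admitted only if
# fewer than `max_calls` calls happened within the last `period` seconds; older
# timestamps are discarded first, mirroring `_cleanup`. Where this sketch returns
# False, the real decorator sleeps with exponential backoff and re-checks.
class _SketchSlidingWindow:
    def __init__(self, max_calls: int, period: float) -> None:
        self.max_calls = max_calls
        self.period = period
        self.call_history: list[float] = []

    def try_acquire(self, now: float) -> bool:
        # Drop calls that fell out of the window (same strict `>` as `_cleanup`)
        cutoff = now - self.period
        self.call_history = [t for t in self.call_history if t > cutoff]
        if len(self.call_history) >= self.max_calls:
            return False  # Window is full; caller must wait
        self.call_history.append(now)
        return True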
from functools import wraps from typing import Any from typing import cast from typing import TypeVar import requests from onyx.utils.logger import setup_logger logger = setup_logger() F = TypeVar("F", bound=Callable[..., Any]) class RateLimitTriedTooManyTimesError(Exception): pass class _RateLimitDecorator: """Builds a generic wrapper/decorator for calls to external APIs that prevents making more than `max_calls` requests per `period` Implementation inspired by the `ratelimit` library: https://github.com/tomasbasham/ratelimit. NOTE: is not thread safe. """ def __init__( self, max_calls: int, period: float, # in seconds sleep_time: float = 2, # in seconds sleep_backoff: float = 2, # applies exponential backoff max_num_sleep: int = 0, ): self.max_calls = max_calls self.period = period self.sleep_time = sleep_time self.sleep_backoff = sleep_backoff self.max_num_sleep = max_num_sleep self.call_history: list[float] = [] self.curr_calls = 0 def __call__(self, func: F) -> F: @wraps(func) def wrapped_func(*args: list, **kwargs: dict[str, Any]) -> Any: # cleanup calls which are no longer relevant self._cleanup() # check if we've exceeded the rate limit sleep_cnt = 0 while len(self.call_history) == self.max_calls: sleep_time = self.sleep_time * (self.sleep_backoff**sleep_cnt) logger.notice( f"Rate limit exceeded for function {func.__name__}. Waiting {sleep_time} seconds before retrying." 
) time.sleep(sleep_time) sleep_cnt += 1 if self.max_num_sleep != 0 and sleep_cnt >= self.max_num_sleep: raise RateLimitTriedTooManyTimesError( f"Exceeded '{self.max_num_sleep}' retries for function '{func.__name__}'" ) self._cleanup() # add the current call to the call history self.call_history.append(time.monotonic()) return func(*args, **kwargs) return cast(F, wrapped_func) def _cleanup(self) -> None: curr_time = time.monotonic() time_to_expire_before = curr_time - self.period self.call_history = [ call_time for call_time in self.call_history if call_time > time_to_expire_before ] rate_limit_builder = _RateLimitDecorator """If you want to allow the external service to tell you when you've hit the rate limit, use the following instead""" R = TypeVar("R", bound=Callable[..., requests.Response]) def wrap_request_to_handle_ratelimiting( request_fn: R, default_wait_time_sec: int = 30, max_waits: int = 30 ) -> R: def wrapped_request(*args: list, **kwargs: dict[str, Any]) -> requests.Response: for _ in range(max_waits): response = request_fn(*args, **kwargs) if response.status_code == 429: try: wait_time = int( response.headers.get("Retry-After", default_wait_time_sec) ) except ValueError: wait_time = default_wait_time_sec time.sleep(wait_time) continue return response raise RateLimitTriedTooManyTimesError(f"Exceeded '{max_waits}' retries") return cast(R, wrapped_request) _rate_limited_get = wrap_request_to_handle_ratelimiting(requests.get) _rate_limited_post = wrap_request_to_handle_ratelimiting(requests.post) class _RateLimitedRequest: get = _rate_limited_get post = _rate_limited_post rl_requests = _RateLimitedRequest ================================================ FILE: backend/onyx/connectors/discord/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/discord/connector.py ================================================ import asyncio from collections.abc import AsyncGenerator 
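# Illustrative sketch (hypothetical helpers, not part of Onyx): two small rules
# the Discord connector below applies. Message previews in the semantic
# identifier are cut to `_SNIPPET_LENGTH` (30) characters, trailing whitespace
# at the cut is trimmed, and "..." is appended; a requested start time earlier
# than Discord's epoch (2015-01-01 UTC) is clamped to that epoch before history
# is fetched.
from datetime import datetime
from datetime import timezone


def _sketch_snippet(content: str, length: int = 30) -> str:
    if len(content) > length:
        # Trim whitespace left at the cut point before adding the ellipsis
        return content[:length].rstrip() + "..."
    return content


def _sketch_clamp_start(start_time: datetime | None) -> datetime | None:
    discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc)
    if start_time and start_time < discord_epoch:
        return discord_epoch
    return start_time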
from collections.abc import AsyncIterable from collections.abc import Iterable from datetime import datetime from datetime import timezone from typing import Any from typing import cast from discord import Client from discord.channel import TextChannel from discord.channel import Thread from discord.enums import MessageType from discord.errors import LoginFailure from discord.flags import Intents from discord.message import Message as DiscordMessage from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import CredentialInvalidError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() _DISCORD_DOC_ID_PREFIX = "DISCORD_" _SNIPPET_LENGTH = 30 def _convert_message_to_document( message: DiscordMessage, sections: list[TextSection], ) -> Document: """ Convert a Discord message to a Document. Sections are collected before calling this function because collecting them relies on async calls to fetch the thread history, if there is one. """ metadata: dict[str, str | list[str]] = {} semantic_substring = "" # Messages come from TextChannels or from their threads; only TextChannel messages get channel metadata here if isinstance(message.channel, TextChannel) and ( channel_name := message.channel.name ): metadata["Channel"] = channel_name semantic_substring += f" in Channel: #{channel_name}" # Single messages don't have a title title = "" # If there is a thread, add more detail to the metadata, title, and semantic identifier if
isinstance(message.channel, Thread): # Threads do have a title title = message.channel.name # If it's a thread, update the metadata, title, and semantic_substring metadata["Thread"] = title # Add more detail to the semantic identifier if available semantic_substring += f" in Thread: {title}" snippet: str = ( message.content[:_SNIPPET_LENGTH].rstrip() + "..." if len(message.content) > _SNIPPET_LENGTH else message.content ) semantic_identifier = f"{message.author.name} said{semantic_substring}: {snippet}" return Document( id=f"{_DISCORD_DOC_ID_PREFIX}{message.id}", source=DocumentSource.DISCORD, semantic_identifier=semantic_identifier, doc_updated_at=message.edited_at, title=title, sections=(cast(list[TextSection | ImageSection], sections)), metadata=metadata, ) async def _fetch_filtered_channels( discord_client: Client, server_ids: list[int] | None, channel_names: list[str] | None, ) -> list[TextChannel]: filtered_channels: list[TextChannel] = [] for channel in discord_client.get_all_channels(): if not channel.permissions_for(channel.guild.me).read_message_history: continue if not isinstance(channel, TextChannel): continue if server_ids and len(server_ids) > 0 and channel.guild.id not in server_ids: continue if channel_names and channel.name not in channel_names: continue filtered_channels.append(channel) logger.info(f"Found {len(filtered_channels)} channels for the authenticated user") return filtered_channels async def _fetch_documents_from_channel( channel: TextChannel, start_time: datetime | None, end_time: datetime | None, ) -> AsyncIterable[Document]: # Discord's epoch starts at 2015-01-01 discord_epoch = datetime(2015, 1, 1, tzinfo=timezone.utc) if start_time and start_time < discord_epoch: start_time = discord_epoch # NOTE: limit=None is the correct way to fetch all messages and threads with pagination # The discord package erroneously uses limit for both pagination AND number of results # This causes the history and archived_threads methods to return 100
results even if there are more results within the filters # Pagination is handled automatically (100 results at a time) when limit=None async for channel_message in channel.history( limit=None, after=start_time, before=end_time, ): # Skip messages that are not the default type if channel_message.type != MessageType.default: continue sections: list[TextSection] = [ TextSection( text=channel_message.content, link=channel_message.jump_url, ) ] yield _convert_message_to_document(channel_message, sections) for active_thread in channel.threads: async for thread_message in active_thread.history( limit=None, after=start_time, before=end_time, ): # Skip messages that are not the default type if thread_message.type != MessageType.default: continue sections = [ TextSection( text=thread_message.content, link=thread_message.jump_url, ) ] yield _convert_message_to_document(thread_message, sections) async for archived_thread in channel.archived_threads( limit=None, ): async for thread_message in archived_thread.history( limit=None, after=start_time, before=end_time, ): # Skip messages that are not the default type if thread_message.type != MessageType.default: continue sections = [ TextSection( text=thread_message.content, link=thread_message.jump_url, ) ] yield _convert_message_to_document(thread_message, sections) def _manage_async_retrieval( token: str, requested_start_date_string: str, channel_names: list[str], server_ids: list[int], start: datetime | None = None, end: datetime | None = None, ) -> Iterable[Document]: # parse requested_start_date_string to datetime pull_date: datetime | None = ( datetime.strptime(requested_start_date_string, "%Y-%m-%d").replace( tzinfo=timezone.utc ) if requested_start_date_string else None ) # Set start_time to the later of start and pull_date, or whichever is provided start_time = max(filter(None, [start, pull_date])) if start or pull_date else None end_time: datetime | None = end async def _async_fetch() -> AsyncGenerator[Document, None]: 
intents = Intents.default() intents.message_content = True async with Client(intents=intents) as discord_client: start_task = asyncio.create_task(discord_client.start(token)) ready_task = asyncio.create_task(discord_client.wait_until_ready()) done, _ = await asyncio.wait( {start_task, ready_task}, return_when=asyncio.FIRST_COMPLETED, ) # start() runs indefinitely once connected, so it only lands # in `done` when login/connection failed — propagate the error. if start_task in done: ready_task.cancel() start_task.result() filtered_channels: list[TextChannel] = await _fetch_filtered_channels( discord_client=discord_client, server_ids=server_ids, channel_names=channel_names, ) for channel in filtered_channels: async for doc in _fetch_documents_from_channel( channel=channel, start_time=start_time, end_time=end_time, ): yield doc def run_and_yield() -> Iterable[Document]: loop = asyncio.new_event_loop() async_gen = _async_fetch() try: while True: try: doc = loop.run_until_complete(anext(async_gen)) yield doc except StopAsyncIteration: break finally: # Must close the async generator before the loop so the Discord # client's `async with` block can await its shutdown coroutine. # The nested try/finally ensures the loop always closes even if # aclose() raises (same pattern as cursor.close() before conn.close()). 
try: loop.run_until_complete(async_gen.aclose()) finally: loop.close() return run_and_yield() class DiscordConnector(PollConnector, LoadConnector): def __init__( self, server_ids: list[str] = [], channel_names: list[str] = [], # YYYY-MM-DD start_date: str | None = None, batch_size: int = INDEX_BATCH_SIZE, ): self.batch_size = batch_size self.channel_names: list[str] = channel_names if channel_names else [] self.server_ids: list[int] = ( [int(server_id) for server_id in server_ids] if server_ids else [] ) self._discord_bot_token: str | None = None self.requested_start_date_string: str = start_date or "" @property def discord_bot_token(self) -> str: if self._discord_bot_token is None: raise ConnectorMissingCredentialError("Discord") return self._discord_bot_token def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self._discord_bot_token = credentials["discord_bot_token"] return None def validate_connector_settings(self) -> None: loop = asyncio.new_event_loop() try: client = Client(intents=Intents.default()) try: loop.run_until_complete(client.login(self.discord_bot_token)) except LoginFailure as e: raise CredentialInvalidError(f"Invalid Discord bot token: {e}") finally: loop.run_until_complete(client.close()) finally: loop.close() def _manage_doc_batching( self, start: datetime | None = None, end: datetime | None = None, ) -> GenerateDocumentsOutput: doc_batch: list[Document | HierarchyNode] = [] for doc in _manage_async_retrieval( token=self.discord_bot_token, requested_start_date_string=self.requested_start_date_string, channel_names=self.channel_names, server_ids=self.server_ids, start=start, end=end, ): doc_batch.append(doc) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: return self._manage_doc_batching( datetime.fromtimestamp(start, tz=timezone.utc), 
datetime.fromtimestamp(end, tz=timezone.utc), ) def load_from_state(self) -> GenerateDocumentsOutput: return self._manage_doc_batching(None, None) if __name__ == "__main__": import os import time end = time.time() # 1 day start = end - 24 * 60 * 60 * 1 # "1,2,3" server_ids: str | None = os.environ.get("server_ids", None) # "channel1,channel2" channel_names: str | None = os.environ.get("channel_names", None) connector = DiscordConnector( server_ids=server_ids.split(",") if server_ids else [], channel_names=channel_names.split(",") if channel_names else [], start_date=os.environ.get("start_date", None), ) connector.load_credentials( {"discord_bot_token": os.environ.get("discord_bot_token")} ) for doc_batch in connector.poll_source(start, end): for doc in doc_batch: print(doc) ================================================ FILE: backend/onyx/connectors/discourse/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/discourse/connector.py ================================================ import time import urllib.parse from datetime import datetime from datetime import timezone from typing import Any from typing import cast import requests from pydantic import BaseModel from requests import Response from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from 
onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import parse_html_page_basic from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder logger = setup_logger() class DiscoursePerms(BaseModel): api_key: str api_username: str @retry_builder() def discourse_request( endpoint: str, perms: DiscoursePerms, params: dict | None = None ) -> Response: headers = {"Api-Key": perms.api_key, "Api-Username": perms.api_username} response = requests.get(endpoint, headers=headers, params=params) response.raise_for_status() return response class DiscourseConnector(PollConnector): def __init__( self, base_url: str, categories: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, ) -> None: parsed_url = urllib.parse.urlparse(base_url) if not parsed_url.scheme: base_url = "https://" + base_url self.base_url = base_url self.categories = [c.lower() for c in categories] if categories else [] self.category_id_map: dict[int, dict] = {} self.batch_size = batch_size self.permissions: DiscoursePerms | None = None self.active_categories: set | None = None @rate_limit_builder(max_calls=50, period=60) def _make_request(self, endpoint: str, params: dict | None = None) -> Response: if not self.permissions: raise ConnectorMissingCredentialError("Discourse") return discourse_request(endpoint, self.permissions, params) def _get_categories_map( self, ) -> None: assert self.permissions is not None categories_endpoint = urllib.parse.urljoin(self.base_url, "categories.json") response = self._make_request( endpoint=categories_endpoint, params={"include_subcategories": True}, ) categories = response.json()["category_list"]["categories"] self.category_id_map = { cat["id"]: {"name": cat["name"], "slug": cat["slug"]} for cat in categories if not self.categories or cat["name"].lower() in self.categories } self.active_categories = set(self.category_id_map) def _get_doc_from_topic(self, 
topic_id: int) -> Document: assert self.permissions is not None topic_endpoint = urllib.parse.urljoin(self.base_url, f"t/{topic_id}.json") response = self._make_request(endpoint=topic_endpoint) topic = response.json() topic_url = urllib.parse.urljoin(self.base_url, f"t/{topic['slug']}") sections = [] poster = None responders = [] seen_names = set() for ind, post in enumerate(topic["post_stream"]["posts"]): if ind == 0: poster_name = post.get("name") if poster_name: seen_names.add(poster_name) poster = BasicExpertInfo(display_name=poster_name) else: responder_name = post.get("name") if responder_name and responder_name not in seen_names: seen_names.add(responder_name) responders.append(BasicExpertInfo(display_name=responder_name)) sections.append( TextSection(link=topic_url, text=parse_html_page_basic(post["cooked"])) ) category_name = self.category_id_map.get(topic["category_id"], {}).get("name") metadata: dict[str, str | list[str]] = ( { "category": category_name, } if category_name else {} ) if topic.get("tags"): metadata["tags"] = topic["tags"] doc = Document( id="_".join([DocumentSource.DISCOURSE.value, str(topic["id"])]), sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.DISCOURSE, semantic_identifier=topic["title"], doc_updated_at=time_str_to_utc(topic["last_posted_at"]), primary_owners=[poster] if poster else None, secondary_owners=responders or None, metadata=metadata, ) return doc def _get_latest_topics( self, start: datetime | None, end: datetime | None, page: int ) -> list[int]: assert self.permissions is not None topic_ids = [] if not self.categories: latest_endpoint = urllib.parse.urljoin( self.base_url, f"latest.json?page={page}" ) response = self._make_request(endpoint=latest_endpoint) topics = response.json()["topic_list"]["topics"] else: topics = [] empty_categories = [] for category_id, category_dict in self.category_id_map.items(): category_endpoint = urllib.parse.urljoin( self.base_url, 
f"c/{category_dict['slug']}/{category_id}.json?page={page}&sys=latest", ) response = self._make_request(endpoint=category_endpoint) new_topics = response.json()["topic_list"]["topics"] if len(new_topics) == 0: empty_categories.append(category_id) topics.extend(new_topics) for empty_category in empty_categories: self.category_id_map.pop(empty_category) for topic in topics: last_time = topic.get("last_posted_at") if not last_time: continue last_time_dt = time_str_to_utc(last_time) if (start and start > last_time_dt) or (end and end < last_time_dt): continue topic_ids.append(topic["id"]) return topic_ids def _yield_discourse_documents( self, start: datetime, end: datetime, ) -> GenerateDocumentsOutput: page = 0 while topic_ids := self._get_latest_topics(start, end, page): doc_batch: list[Document | HierarchyNode] = [] for topic_id in topic_ids: doc_batch.append(self._get_doc_from_topic(topic_id)) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch page += 1 def load_credentials( self, credentials: dict[str, Any], ) -> dict[str, Any] | None: self.permissions = DiscoursePerms( api_key=credentials["discourse_api_key"], api_username=credentials["discourse_api_username"], ) return None def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: if self.permissions is None: raise ConnectorMissingCredentialError("Discourse") start_datetime = datetime.utcfromtimestamp(start).replace(tzinfo=timezone.utc) end_datetime = datetime.utcfromtimestamp(end).replace(tzinfo=timezone.utc) self._get_categories_map() yield from self._yield_discourse_documents(start_datetime, end_datetime) if __name__ == "__main__": import os connector = DiscourseConnector(base_url=os.environ["DISCOURSE_BASE_URL"]) connector.load_credentials( { "discourse_api_key": os.environ["DISCOURSE_API_KEY"], "discourse_api_username": os.environ["DISCOURSE_API_USERNAME"], } ) current = time.time() one_year_ago = current - 
24 * 60 * 60 * 360 latest_docs = connector.poll_source(one_year_ago, current) print(next(latest_docs)) ================================================ FILE: backend/onyx/connectors/document360/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/document360/connector.py ================================================ from datetime import datetime from datetime import timezone from typing import Any from typing import List from typing import Optional import requests from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.document360.utils import flatten_child_categories from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import parse_html_page_basic from onyx.utils.retry_wrapper import retry_builder # Limitations and Potential Improvements # 1. The "Categories themselves contain potentially relevant information" but they're not pulled in # 2. Only the HTML Articles are supported, Document360 also has a Markdown and "Block" format # 3. 
The contents are not as cleaned up as other HTML connectors DOCUMENT360_BASE_URL = "https://portal.document360.io" DOCUMENT360_API_BASE_URL = "https://apihub.document360.io/v2" class Document360Connector(LoadConnector, PollConnector): def __init__( self, workspace: str, categories: List[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, portal_id: Optional[str] = None, api_token: Optional[str] = None, ) -> None: self.portal_id = portal_id self.workspace = workspace self.categories = categories self.batch_size = batch_size self.api_token = api_token def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]: self.api_token = credentials.get("document360_api_token") self.portal_id = credentials.get("portal_id") return None # rate limiting set based on the enterprise plan: https://apidocs.document360.com/apidocs/rate-limiting # NOTE: retry will handle cases where user is not on enterprise plan - we will just hit the rate limit # and then retry after a period @retry_builder() @rate_limit_builder(max_calls=100, period=60) def _make_request(self, endpoint: str, params: Optional[dict] = None) -> Any: if not self.api_token: raise ConnectorMissingCredentialError("Document360") headers = {"accept": "application/json", "api_token": self.api_token} response = requests.get( f"{DOCUMENT360_API_BASE_URL}/{endpoint}", headers=headers, params=params ) response.raise_for_status() return response.json()["data"] def _get_workspace_id_by_name(self) -> str: projects = self._make_request("ProjectVersions") workspace_id = next( ( project["id"] for project in projects if project["version_code_name"] == self.workspace ), None, ) if workspace_id is None: raise ValueError("Not able to find Workspace ID by the user provided name") return workspace_id def _get_articles_with_category(self, workspace_id: str) -> Any: all_categories = self._make_request( f"ProjectVersions/{workspace_id}/categories" ) articles_with_category = [] for category in all_categories: if not 
self.categories or category["name"] in self.categories: for article in category["articles"]: articles_with_category.append( {"id": article["id"], "category_name": category["name"]} ) for child_category in category["child_categories"]: all_nested_categories = flatten_child_categories(child_category) for nested_category in all_nested_categories: for article in nested_category["articles"]: articles_with_category.append( { "id": article["id"], "category_name": nested_category["name"], } ) return articles_with_category def _process_articles( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: if self.api_token is None: raise ConnectorMissingCredentialError("Document360") workspace_id = self._get_workspace_id_by_name() articles = self._get_articles_with_category(workspace_id) doc_batch: List[Document | HierarchyNode] = [] for article in articles: article_details = self._make_request( f"Articles/{article['id']}", {"langCode": "en"} ) updated_at = datetime.strptime( article_details["modified_at"], "%Y-%m-%dT%H:%M:%S.%fZ" ).replace(tzinfo=timezone.utc) if start is not None and updated_at < start: continue if end is not None and updated_at > end: continue authors = [ BasicExpertInfo( display_name=author.get("name"), email=author["email_id"] ) for author in article_details.get("authors", []) if author["email_id"] ] doc_link = ( article_details["url"] if article_details.get("url") else f"{DOCUMENT360_BASE_URL}/{self.portal_id}/document/v1/view/{article['id']}" ) html_content = article_details["html_content"] article_content = ( parse_html_page_basic(html_content) if html_content is not None else "" ) doc_text = ( f"{article_details.get('description', '')}\n{article_content}".strip() ) document = Document( id=article_details["id"], sections=[TextSection(link=doc_link, text=doc_text)], source=DocumentSource.DOCUMENT360, semantic_identifier=article_details["title"], doc_updated_at=updated_at, primary_owners=authors, metadata={ 
"workspace": self.workspace, "category": article["category_name"], }, ) doc_batch.append(document) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: return self._process_articles() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) return self._process_articles(start_datetime, end_datetime) if __name__ == "__main__": import time import os document360_connector = Document360Connector(os.environ["DOCUMENT360_WORKSPACE"]) document360_connector.load_credentials( { "portal_id": os.environ["DOCUMENT360_PORTAL_ID"], "document360_api_token": os.environ["DOCUMENT360_API_TOKEN"], } ) current = time.time() one_year_ago = current - 24 * 60 * 60 * 360 latest_docs = document360_connector.poll_source(one_year_ago, current) for doc in latest_docs: print(doc) ================================================ FILE: backend/onyx/connectors/document360/utils.py ================================================ def flatten_child_categories(category: dict) -> list[dict]: if not category["child_categories"]: return [category] else: flattened_categories = [category] for child_category in category["child_categories"]: flattened_categories.extend(flatten_child_categories(child_category)) return flattened_categories ================================================ FILE: backend/onyx/connectors/dropbox/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/dropbox/connector.py ================================================ from datetime import timezone from io import BytesIO from typing import Any from dropbox import Dropbox # type: ignore[import-untyped] from dropbox.exceptions import ApiError # type: ignore[import-untyped] 
from dropbox.exceptions import AuthError from dropbox.files import FileMetadata # type: ignore[import-untyped] from dropbox.files import FolderMetadata from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialInvalidError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_file_text from onyx.utils.logger import setup_logger logger = setup_logger() class DropboxConnector(LoadConnector, PollConnector): def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: self.batch_size = batch_size self.dropbox_client: Dropbox | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.dropbox_client = Dropbox(credentials["dropbox_access_token"]) return None def _download_file(self, path: str) -> bytes: """Download a single file from Dropbox.""" if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") _, resp = self.dropbox_client.files_download(path) return resp.content def _get_shared_link(self, path: str) -> str: """Create a shared link for a file in Dropbox.""" if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") try: # Check if a shared link already exists shared_links = self.dropbox_client.sharing_list_shared_links(path=path) if shared_links.links: return shared_links.links[0].url link_metadata = ( 
self.dropbox_client.sharing_create_shared_link_with_settings(path) ) return link_metadata.url except ApiError as err: logger.exception(f"Failed to create a shared link for {path}: {err}") return "" def _yield_files_recursive( self, path: str, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, ) -> GenerateDocumentsOutput: """Yield files in batches from a specified Dropbox folder, including subfolders.""" if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") result = self.dropbox_client.files_list_folder( path, limit=self.batch_size, recursive=False, include_non_downloadable_files=False, ) while True: batch: list[Document | HierarchyNode] = [] for entry in result.entries: if isinstance(entry, FileMetadata): modified_time = entry.client_modified if modified_time.tzinfo is None: # If no timezone info, assume it is UTC modified_time = modified_time.replace(tzinfo=timezone.utc) else: # Otherwise, convert it to UTC modified_time = modified_time.astimezone(timezone.utc) time_as_seconds = int(modified_time.timestamp()) if start and time_as_seconds < start: continue if end and time_as_seconds > end: continue downloaded_file = self._download_file(entry.path_display) link = self._get_shared_link(entry.path_display) try: text = extract_file_text( BytesIO(downloaded_file), file_name=entry.name, break_on_unprocessable=False, ) batch.append( Document( id=f"doc:{entry.id}", sections=[TextSection(link=link, text=text)], source=DocumentSource.DROPBOX, semantic_identifier=entry.name, doc_updated_at=modified_time, metadata={"type": "article"}, ) ) except Exception as e: logger.exception( f"Error extracting text from file {entry.path_display}: {e}" ) elif isinstance(entry, FolderMetadata): yield from self._yield_files_recursive(entry.path_lower, start, end) if batch: yield batch if not result.has_more: break result = self.dropbox_client.files_list_folder_continue(result.cursor) def load_from_state(self) ->
GenerateDocumentsOutput: return self.poll_source(None, None) def poll_source( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> GenerateDocumentsOutput: if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox") for batch in self._yield_files_recursive("", start, end): yield batch return None def validate_connector_settings(self) -> None: if self.dropbox_client is None: raise ConnectorMissingCredentialError("Dropbox credentials not loaded.") try: self.dropbox_client.files_list_folder(path="", limit=1) except AuthError as e: logger.exception("Failed to validate Dropbox credentials") raise CredentialInvalidError(f"Dropbox credential is invalid: {e.error}") except ApiError as e: if ( e.error is not None and "insufficient_permissions" in str(e.error).lower() ): raise InsufficientPermissionsError( "Your Dropbox token does not have sufficient permissions." ) raise ConnectorValidationError( f"Unexpected Dropbox error during validation: {e.user_message_text or e}" ) except Exception as e: raise Exception(f"Unexpected error during Dropbox settings validation: {e}") if __name__ == "__main__": import os connector = DropboxConnector() connector.load_credentials( { "dropbox_access_token": os.environ["DROPBOX_ACCESS_TOKEN"], } ) document_batches = connector.load_from_state() print(next(document_batches)) ================================================ FILE: backend/onyx/connectors/drupal_wiki/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/drupal_wiki/connector.py ================================================ import mimetypes from io import BytesIO from typing import Any import requests from typing_extensions import override from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE from onyx.configs.app_configs import DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD from onyx.configs.app_configs import INDEX_BATCH_SIZE from 
onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( datetime_from_utc_timestamp, ) from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rate_limit_builder from onyx.connectors.cross_connector_utils.rate_limit_wrapper import rl_requests from onyx.connectors.drupal_wiki.models import DrupalWikiCheckpoint from onyx.connectors.drupal_wiki.models import DrupalWikiPage from onyx.connectors.drupal_wiki.models import DrupalWikiPageResponse from onyx.connectors.drupal_wiki.models import DrupalWikiSpaceResponse from onyx.connectors.drupal_wiki.utils import build_drupal_wiki_document_id from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import ConnectorFailure from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnector from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_text_and_images from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.html_utils import parse_html_page_basic from onyx.file_processing.image_utils import store_image_and_create_section from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.b64 import get_image_type_from_bytes from onyx.utils.logger import setup_logger from 
onyx.utils.retry_wrapper import retry_builder logger = setup_logger() MAX_API_PAGE_SIZE = 2000 # max allowed by API DRUPAL_WIKI_SPACE_KEY = "space" rate_limited_get = retry_builder()( rate_limit_builder(max_calls=10, period=1)(rl_requests.get) ) class DrupalWikiConnector( CheckpointedConnector[DrupalWikiCheckpoint], SlimConnector, ): # Deprecated parameters that may exist in old connector configurations _DEPRECATED_PARAMS = {"drupal_wiki_scope", "include_all_spaces"} def __init__( self, base_url: str, spaces: list[str] | None = None, pages: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, continue_on_failure: bool = CONTINUE_ON_CONNECTOR_FAILURE, include_attachments: bool = False, allow_images: bool = False, **kwargs: Any, ) -> None: """ Initialize the Drupal Wiki connector. Args: base_url: The base URL of the Drupal Wiki instance (e.g., https://help.drupal-wiki.com) spaces: List of space IDs to index. If None and pages is also None, all spaces will be indexed. pages: List of page IDs to index. If provided, these specific pages will be indexed. batch_size: Number of documents to process in a batch. continue_on_failure: If True, continue indexing even if some documents fail. include_attachments: If True, enable processing of page attachments including images and documents. allow_images: If True, enable processing of image attachments. """ ######################################################### # TODO: Remove this after 02/01/2026 and remove **kwargs from the function signature # Check for deprecated parameters from old connector configurations # If attempting to update without deleting the connector: # Remove the deprecated parameters from the custom_connector_config in the relevant connector table rows deprecated_found = set(kwargs.keys()) & self._DEPRECATED_PARAMS if deprecated_found: raise ConnectorValidationError( f"Outdated Drupal Wiki connector configuration detected " f"(found deprecated parameters: {', '.join(deprecated_found)}). 
" f"Please delete and recreate this connector, or contact Onyx support " f"for assistance with updating the configuration without deleting the connector." ) # Reject any other unexpected parameters if kwargs: raise ConnectorValidationError( f"Unexpected parameters for Drupal Wiki connector: {', '.join(kwargs.keys())}" ) ######################################################### self.base_url = base_url.rstrip("/") self.spaces = spaces or [] self.pages = pages or [] # If no specific spaces or pages are provided, index all spaces self.include_all_spaces = not self.spaces and not self.pages self.batch_size = batch_size self.continue_on_failure = continue_on_failure # Attachment processing configuration self.include_attachments = include_attachments self.allow_images = allow_images self.headers: dict[str, str] = {"Accept": "application/json"} self._api_token: str | None = None # set by load_credentials def set_allow_images(self, value: bool) -> None: logger.info(f"Setting allow_images to {value}.") self.allow_images = value def _get_page_attachments(self, page_id: int) -> list[dict[str, Any]]: """ Get all attachments for a specific page. Args: page_id: ID of the page. Returns: List of attachment dictionaries. """ url = f"{self.base_url}/api/rest/scope/api/attachment" params = {"pageId": str(page_id)} logger.debug(f"Fetching attachments for page {page_id} from {url}") try: response = rate_limited_get(url, headers=self.headers, params=params) response.raise_for_status() attachments = response.json() logger.info(f"Found {len(attachments)} attachments for page {page_id}") return attachments except Exception as e: logger.warning(f"Failed to fetch attachments for page {page_id}: {e}") return [] def _download_attachment(self, attachment_id: int) -> bytes: """ Download attachment content. Args: attachment_id: ID of the attachment to download. Returns: Raw bytes of the attachment. 
""" url = f"{self.base_url}/api/rest/scope/api/attachment/{attachment_id}/download" logger.info(f"Downloading attachment {attachment_id} from {url}") # Use headers without Accept for binary downloads download_headers = {"Authorization": f"Bearer {self._api_token}"} response = rate_limited_get(url, headers=download_headers) response.raise_for_status() return response.content def _validate_attachment_filetype(self, attachment: dict[str, Any]) -> bool: """ Validate if the attachment file type is supported. Args: attachment: Attachment dictionary from Drupal Wiki API. Returns: True if the file type is supported, False otherwise. """ file_name = attachment.get("fileName", "") if not file_name: return False # Get file extension file_extension = get_file_ext(file_name) if file_extension in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS: return True logger.warning(f"Unsupported file type: {file_extension} for {file_name}") return False def _get_media_type_from_filename(self, filename: str) -> str: """ Get media type from filename using the standard mimetypes library. Args: filename: The filename. Returns: Media type string. """ mime_type, _encoding = mimetypes.guess_type(filename) return mime_type or "application/octet-stream" def _process_attachment( self, attachment: dict[str, Any], page_id: int, download_url: str, ) -> tuple[list[TextSection | ImageSection], str | None]: """ Process a single attachment and return generated sections. Args: attachment: Attachment dictionary from Drupal Wiki API. page_id: ID of the parent page. download_url: Direct download URL for the attachment. Returns: Tuple of (sections, error_message). If error_message is not None, the sections list should be treated as invalid. 
""" sections: list[TextSection | ImageSection] = [] try: if not self._validate_attachment_filetype(attachment): return ( [], f"Unsupported file type: {attachment.get('fileName', 'unknown')}", ) attachment_id = attachment["id"] file_name = attachment.get("fileName", f"attachment_{attachment_id}") file_size = attachment.get("fileSize", 0) media_type = self._get_media_type_from_filename(file_name) if file_size > DRUPAL_WIKI_ATTACHMENT_SIZE_THRESHOLD: return [], f"Attachment too large: {file_size} bytes" try: raw_bytes = self._download_attachment(attachment_id) except Exception as e: return [], f"Failed to download attachment: {e}" if media_type.startswith("image/"): if not self.allow_images: logger.info( f"Skipping image attachment {file_name} because allow_images is False", ) return [], None try: image_section, _ = store_image_and_create_section( image_data=raw_bytes, file_id=str(attachment_id), display_name=attachment.get( "name", attachment.get("fileName", "Unknown") ), link=download_url, media_type=media_type, file_origin=FileOrigin.CONNECTOR, ) sections.append(image_section) logger.debug(f"Stored image attachment with file name: {file_name}") except Exception as e: return [], f"Image storage failed: {e}" return sections, None image_counter = 0 def _store_embedded_image(image_data: bytes, image_name: str) -> None: nonlocal image_counter if not self.allow_images: return media_for_image = self._get_media_type_from_filename(image_name) if media_for_image == "application/octet-stream": try: media_for_image = get_image_type_from_bytes(image_data) except ValueError: logger.warning( f"Unable to determine media type for embedded image {image_name} on attachment {file_name}" ) image_counter += 1 display_name = ( image_name or f"{attachment.get('name', file_name)} - embedded image {image_counter}" ) try: image_section, _ = store_image_and_create_section( image_data=image_data, file_id=f"{attachment_id}_embedded_{image_counter}", display_name=display_name, link=download_url, 
media_type=media_for_image, file_origin=FileOrigin.CONNECTOR, ) sections.append(image_section) except Exception as err: logger.warning( f"Failed to store embedded image {image_name or image_counter} for attachment {file_name}: {err}" ) extraction_result = extract_text_and_images( file=BytesIO(raw_bytes), file_name=file_name, content_type=media_type, image_callback=_store_embedded_image if self.allow_images else None, ) text_content = extraction_result.text_content.strip() if text_content: sections.insert(0, TextSection(text=text_content, link=download_url)) logger.info( f"Extracted {len(text_content)} characters from {file_name}" ) elif not sections: return [], f"No text extracted for {file_name}" return sections, None except Exception as e: logger.error( f"Failed to process attachment {attachment.get('name', 'unknown')} on page {page_id}: {e}" ) return [], f"Failed to process attachment: {e}" def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """ Load credentials for the Drupal Wiki connector. Args: credentials: Dictionary containing the API token. Returns: None """ api_token = credentials.get("drupal_wiki_api_token", "").strip() if not api_token: raise ConnectorValidationError( "API token is required for Drupal Wiki connector" ) self._api_token = api_token self.headers.update( { "Authorization": f"Bearer {api_token}", } ) return None def _get_space_ids(self) -> list[int]: """ Get all space IDs from the Drupal Wiki instance. Returns: List of space IDs (deduplicated). The list is sorted to be deterministic. 
""" url = f"{self.base_url}/api/rest/scope/api/space" size = MAX_API_PAGE_SIZE page = 0 all_space_ids: set[int] = set() has_more = True last_num_ids = -1 while has_more and len(all_space_ids) > last_num_ids: last_num_ids = len(all_space_ids) params = {"size": size, "page": page} logger.debug(f"Fetching spaces from {url} (page={page}, size={size})") response = rate_limited_get(url, headers=self.headers, params=params) response.raise_for_status() resp_json = response.json() space_response = DrupalWikiSpaceResponse.model_validate(resp_json) logger.info(f"Fetched {len(space_response.content)} spaces from {page}") # Collect ids into the set to deduplicate for space in space_response.content: all_space_ids.add(space.id) # Continue if we got a full page, indicating there might be more has_more = len(space_response.content) >= size page += 1 # Return a deterministic, sorted list of ids space_id_list = list(sorted(all_space_ids)) logger.debug(f"Total spaces fetched: {len(space_id_list)}") return space_id_list def _get_pages_for_space( self, space_id: int, modified_after: SecondsSinceUnixEpoch | None = None ) -> list[DrupalWikiPage]: """ Get all pages for a specific space, optionally filtered by modification time. Args: space_id: ID of the space. modified_after: Only return pages modified after this timestamp (seconds since Unix epoch). Returns: List of DrupalWikiPage objects. 
""" url = f"{self.base_url}/api/rest/scope/api/page" size = MAX_API_PAGE_SIZE page = 0 all_pages = [] has_more = True while has_more: params: dict[str, str | int] = { DRUPAL_WIKI_SPACE_KEY: str(space_id), "size": size, "page": page, } # Add modifiedAfter parameter if provided if modified_after is not None: params["modifiedAfter"] = int(modified_after) logger.debug( f"Fetching pages for space {space_id} from {url} ({page=}, {size=}, {modified_after=})" ) response = rate_limited_get(url, headers=self.headers, params=params) response.raise_for_status() resp_json = response.json() try: page_response = DrupalWikiPageResponse.model_validate(resp_json) except Exception as e: logger.error(f"Failed to validate Drupal Wiki page response: {e}") raise ConnectorValidationError(f"Invalid API response format: {e}") logger.info( f"Fetched {len(page_response.content)} pages in space {space_id} (page={page})" ) # Pydantic should automatically parse content items as DrupalWikiPage objects # If validation fails, it will raise an exception which we should catch all_pages.extend(page_response.content) # Continue if we got a full page, indicating there might be more has_more = len(page_response.content) >= size page += 1 logger.debug(f"Total pages fetched for space {space_id}: {len(all_pages)}") return all_pages def _get_page_content(self, page_id: int) -> DrupalWikiPage: """ Get the content of a specific page. Args: page_id: ID of the page. Returns: DrupalWikiPage object. """ url = f"{self.base_url}/api/rest/scope/api/page/{page_id}" response = rate_limited_get(url, headers=self.headers) response.raise_for_status() return DrupalWikiPage.model_validate(response.json()) def _process_page(self, page: DrupalWikiPage) -> Document | ConnectorFailure: """ Process a page and convert it to a Document. Args: page: DrupalWikiPage object. Returns: Document object or ConnectorFailure. 
""" try: # Extract text from HTML, handle None body text_content = parse_html_page_basic(page.body or "") # Ensure text_content is a string, not None if text_content is None: text_content = "" # Create document URL page_url = build_drupal_wiki_document_id(self.base_url, page.id) # Create sections with just the page content sections: list[TextSection | ImageSection] = [ TextSection(text=text_content, link=page_url) ] # Only process attachments if self.include_attachments is True if self.include_attachments: attachments = self._get_page_attachments(page.id) for attachment in attachments: logger.info( f"Processing attachment: {attachment.get('name', 'Unknown')} (ID: {attachment['id']})" ) # Use downloadUrl from API; fallback to page URL raw_download = attachment.get("downloadUrl") if raw_download: download_url = ( raw_download if raw_download.startswith("http") else f"{self.base_url.rstrip('/')}" + raw_download ) else: download_url = page_url # Process the attachment attachment_sections, error = self._process_attachment( attachment, page.id, download_url ) if error: logger.warning( f"Error processing attachment {attachment.get('name', 'Unknown')}: {error}" ) continue if attachment_sections: sections.extend(attachment_sections) logger.debug( f"Added {len(attachment_sections)} section(s) for attachment {attachment.get('name', 'Unknown')}" ) # Create metadata metadata: dict[str, str | list[str]] = { "space_id": str(page.homeSpace), "page_id": str(page.id), "type": page.type, } # Create document return Document( id=page_url, sections=sections, source=DocumentSource.DRUPAL_WIKI, semantic_identifier=page.title, metadata=metadata, doc_updated_at=datetime_from_utc_timestamp(page.lastModified), ) except Exception as e: logger.error(f"Error processing page {page.id}: {e}") return ConnectorFailure( failed_document=DocumentFailure( document_id=str(page.id), document_link=build_drupal_wiki_document_id(self.base_url, page.id), ), failure_message=f"Error processing page {page.id}: 
{e}", exception=e, ) @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: DrupalWikiCheckpoint, ) -> CheckpointOutput[DrupalWikiCheckpoint]: """ Load documents from a checkpoint. Args: start: Start time as seconds since Unix epoch. end: End time as seconds since Unix epoch. checkpoint: Checkpoint to resume from. Returns: Generator yielding documents and the updated checkpoint. """ # Ensure page_ids is not None if checkpoint.page_ids is None: checkpoint.page_ids = [] # Initialize page_ids from self.pages if not already set if not checkpoint.page_ids and self.pages: logger.info(f"Initializing page_ids from self.pages: {self.pages}") checkpoint.page_ids = [int(page_id.strip()) for page_id in self.pages] # Ensure spaces is not None if checkpoint.spaces is None: checkpoint.spaces = [] while checkpoint.current_page_id_index < len(checkpoint.page_ids): page_id = checkpoint.page_ids[checkpoint.current_page_id_index] logger.debug(f"Processing page ID: {page_id}") try: # Get the page content directly page = self._get_page_content(page_id) # Skip pages outside the time range if not self._is_page_in_time_range(page.lastModified, start, end): logger.info(f"Skipping page {page_id} - outside time range") checkpoint.current_page_id_index += 1 continue # Process the page doc_or_failure = self._process_page(page) yield doc_or_failure except Exception as e: logger.error(f"Error processing page ID {page_id}: {e}") yield ConnectorFailure( failed_document=DocumentFailure( document_id=str(page_id), document_link=build_drupal_wiki_document_id( self.base_url, page_id ), ), failure_message=f"Error processing page ID {page_id}: {e}", exception=e, ) # Move to the next page ID checkpoint.current_page_id_index += 1 # TODO: The main benefit of CheckpointedConnectors is that they can "save their work" # by storing a checkpoint so transient errors are easy to recover from: simply resume # from the last checkpoint. 
The way to get checkpoints saved is to return them somewhere # in the middle of this function. The guarantee our checkpointing system gives to you, # the connector implementer, is that when you return a checkpoint, this connector will # at a later time (generally within a few seconds) call the load_from_checkpoint function # again with the checkpoint you last returned as long as has_more=True. # Process spaces if include_all_spaces is True or spaces are provided if self.include_all_spaces or self.spaces: # If include_all_spaces is True, always fetch all spaces if self.include_all_spaces: logger.info("Fetching all spaces") # Fetch all spaces all_space_ids = self._get_space_ids() # checkpoint.spaces expects a list of ints; assign returned list checkpoint.spaces = all_space_ids logger.info(f"Found {len(checkpoint.spaces)} spaces to process") # Otherwise, use provided spaces if checkpoint is empty elif not checkpoint.spaces: logger.info(f"Using provided spaces: {self.spaces}") # Use provided spaces checkpoint.spaces = [int(space_id.strip()) for space_id in self.spaces] # Process spaces from the checkpoint while checkpoint.current_space_index < len(checkpoint.spaces): space_id = checkpoint.spaces[checkpoint.current_space_index] logger.debug(f"Processing space ID: {space_id}") # Get pages for the current space, filtered by start time if provided pages = self._get_pages_for_space(space_id, modified_after=start) # Process pages from the checkpoint while checkpoint.current_page_index < len(pages): page = pages[checkpoint.current_page_index] logger.debug(f"Processing page: {page.title} (ID: {page.id})") # For space-based pages, we already filtered by modifiedAfter in the API call # Only need to check the end time boundary if end and page.lastModified >= end: logger.info( f"Skipping page {page.id} - outside time range (after end)" ) checkpoint.current_page_index += 1 continue # Process the page doc_or_failure = self._process_page(page) yield doc_or_failure # Move to the next 
page checkpoint.current_page_index += 1 # Move to the next space checkpoint.current_space_index += 1 checkpoint.current_page_index = 0 # All spaces and pages processed logger.info("Finished processing all spaces and pages") checkpoint.has_more = False return checkpoint @override def build_dummy_checkpoint(self) -> DrupalWikiCheckpoint: """ Build a dummy checkpoint. Returns: DrupalWikiCheckpoint with default values. """ return DrupalWikiCheckpoint( has_more=True, current_space_index=0, current_page_index=0, current_page_id_index=0, spaces=[], page_ids=[], is_processing_specific_pages=False, ) @override def validate_checkpoint_json(self, checkpoint_json: str) -> DrupalWikiCheckpoint: """ Validate a checkpoint JSON string. Args: checkpoint_json: JSON string representing a checkpoint. Returns: Validated DrupalWikiCheckpoint. """ return DrupalWikiCheckpoint.model_validate_json(checkpoint_json) # TODO: unify approach with load_from_checkpoint. # Ideally slim retrieval shares a lot of the same code with non-slim # and we pass in a param is_slim to the main helper function # that does the retrieval. @override def retrieve_all_slim_docs( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: """ Retrieve all slim documents. Args: start: Start time as seconds since Unix epoch. end: End time as seconds since Unix epoch. callback: Callback for indexing heartbeat. Returns: Generator yielding batches of SlimDocument objects. 
""" slim_docs: list[SlimDocument | HierarchyNode] = [] logger.info( f"Starting retrieve_all_slim_docs with include_all_spaces={self.include_all_spaces}, spaces={self.spaces}" ) # Process specific page IDs if provided if self.pages: logger.info(f"Processing specific pages: {self.pages}") for page_id in self.pages: try: # Get the page content directly page_content = self._get_page_content(int(page_id.strip())) # Skip pages outside the time range if not self._is_page_in_time_range( page_content.lastModified, start, end ): logger.info(f"Skipping page {page_id} - outside time range") continue # Create slim document for the page page_url = build_drupal_wiki_document_id( self.base_url, page_content.id ) slim_docs.append( SlimDocument( id=page_url, ) ) logger.debug(f"Added slim document for page {page_content.id}") # Process attachments for this page attachments = self._get_page_attachments(page_content.id) for attachment in attachments: if self._validate_attachment_filetype(attachment): attachment_url = f"{page_url}#attachment-{attachment['id']}" slim_docs.append( SlimDocument( id=attachment_url, ) ) logger.debug( f"Added slim document for attachment {attachment['id']}" ) # Yield batch if it reaches the batch size if len(slim_docs) >= self.batch_size: logger.debug( f"Yielding batch of {len(slim_docs)} slim documents" ) yield slim_docs slim_docs = [] if callback and callback.should_stop(): return if callback: callback.progress("retrieve_all_slim_docs", 1) except Exception as e: logger.error( f"Error processing page ID {page_id} for slim documents: {e}" ) # Process spaces if include_all_spaces is True or spaces are provided if self.include_all_spaces or self.spaces: logger.info("Processing spaces for slim documents") # Get spaces to process spaces_to_process = [] if self.include_all_spaces: logger.info("Fetching all spaces for slim documents") # Fetch all spaces all_space_ids = self._get_space_ids() spaces_to_process = all_space_ids logger.info(f"Found 
{len(spaces_to_process)} spaces to process") else: logger.info(f"Using provided spaces: {self.spaces}") # Use provided spaces spaces_to_process = [int(space_id.strip()) for space_id in self.spaces] # Process each space for space_id in spaces_to_process: logger.info(f"Processing space ID: {space_id}") # Get pages for the current space, filtered by start time if provided pages = self._get_pages_for_space(space_id, modified_after=start) # Process each page for page in pages: logger.debug(f"Processing page: {page.title} (ID: {page.id})") # Skip pages outside the time range if end and page.lastModified >= end: logger.info( f"Skipping page {page.id} - outside time range (after end)" ) continue # Create slim document for the page page_url = build_drupal_wiki_document_id(self.base_url, page.id) slim_docs.append( SlimDocument( id=page_url, ) ) logger.info(f"Added slim document for page {page.id}") # Process attachments for this page attachments = self._get_page_attachments(page.id) for attachment in attachments: if self._validate_attachment_filetype(attachment): attachment_url = f"{page_url}#attachment-{attachment['id']}" slim_docs.append( SlimDocument( id=attachment_url, ) ) logger.info( f"Added slim document for attachment {attachment['id']}" ) # Yield batch if it reaches the batch size if len(slim_docs) >= self.batch_size: logger.info( f"Yielding batch of {len(slim_docs)} slim documents" ) yield slim_docs slim_docs = [] if callback and callback.should_stop(): return if callback: callback.progress("retrieve_all_slim_docs", 1) # Yield remaining documents if slim_docs: logger.debug(f"Yielding final batch of {len(slim_docs)} slim documents") yield slim_docs def validate_connector_settings(self) -> None: """ Validate the connector settings. Raises: ConnectorValidationError: If the settings are invalid. 
""" if not self.headers: raise ConnectorMissingCredentialError("Drupal Wiki") try: # Try to fetch spaces to validate the connection # Call the new helper which returns the list of space ids self._get_space_ids() except requests.exceptions.RequestException as e: raise ConnectorValidationError(f"Failed to connect to Drupal Wiki: {e}") def _is_page_in_time_range( self, last_modified: int, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, ) -> bool: """ Check if a page's last modified timestamp falls within the specified time range. Args: last_modified: The page's last modified timestamp. start: Start time as seconds since Unix epoch (inclusive). end: End time as seconds since Unix epoch (exclusive). Returns: True if the page is within the time range, False otherwise. """ return (not start or last_modified >= start) and ( not end or last_modified < end ) ================================================ FILE: backend/onyx/connectors/drupal_wiki/models.py ================================================ from enum import Enum from typing import Generic from typing import List from typing import Optional from typing import TypeVar from pydantic import BaseModel from onyx.connectors.interfaces import ConnectorCheckpoint class SpaceAccessStatus(str, Enum): """Enum for Drupal Wiki space access status""" PRIVATE = "PRIVATE" ANONYMOUS = "ANONYMOUS" AUTHENTICATED = "AUTHENTICATED" class DrupalWikiSpace(BaseModel): """Model for a Drupal Wiki space""" id: int name: str type: str description: Optional[str] = None accessStatus: Optional[SpaceAccessStatus] = None color: Optional[str] = None class DrupalWikiPage(BaseModel): """Model for a Drupal Wiki page""" id: int title: str homeSpace: int lastModified: int type: str body: Optional[str] = None T = TypeVar("T") class DrupalWikiBaseResponse(BaseModel, Generic[T]): """Base model for Drupal Wiki API responses""" totalPages: int totalElements: int size: int content: List[T] number: int first: bool last: bool 
numberOfElements: int empty: bool class DrupalWikiSpaceResponse(DrupalWikiBaseResponse[DrupalWikiSpace]): """Model for the response from the Drupal Wiki spaces API""" class DrupalWikiPageResponse(DrupalWikiBaseResponse[DrupalWikiPage]): """Model for the response from the Drupal Wiki pages API""" class DrupalWikiCheckpoint(ConnectorCheckpoint): """Checkpoint for the Drupal Wiki connector""" current_space_index: int = 0 current_page_index: int = 0 current_page_id_index: int = 0 spaces: List[int] = [] page_ids: List[int] = [] is_processing_specific_pages: bool = False ================================================ FILE: backend/onyx/connectors/drupal_wiki/utils.py ================================================ from onyx.utils.logger import setup_logger logger = setup_logger() def build_drupal_wiki_document_id(base_url: str, page_id: int) -> str: """Build a document ID for a Drupal Wiki page using the real URL format""" # Ensure base_url ends with a slash base_url = base_url.rstrip("/") + "/" return f"{base_url}node/{page_id}" ================================================ FILE: backend/onyx/connectors/egnyte/connector.py ================================================ import io import os from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any from typing import IO from urllib.parse import quote from pydantic import Field from onyx.configs.app_configs import EGNYTE_CLIENT_ID from onyx.configs.app_configs import EGNYTE_CLIENT_SECRET from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_oauth_callback_uri, ) from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import OAuthConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import 
SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import detect_encoding from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.extract_file_text import read_text_file from onyx.file_processing.file_types import OnyxFileExtensions from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import request_with_retries logger = setup_logger() _EGNYTE_API_BASE = "https://{domain}.egnyte.com/pubapi/v1" _EGNYTE_APP_BASE = "https://{domain}.egnyte.com" def _parse_last_modified(last_modified: str) -> datetime: return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( tzinfo=timezone.utc ) def _process_egnyte_file( file_metadata: dict[str, Any], file_content: IO, base_url: str, folder_path: str | None = None, ) -> Document | None: """Process an Egnyte file into a Document object. Args: file_metadata: The file metadata returned by the Egnyte API file_content: A binary file-like object containing the raw file content base_url: The base URL for the Egnyte instance folder_path: Optional folder path; files outside it are rejected Returns: The Document, or None if the file extension is not supported Raises: ValueError: If the file path does not fall under folder_path """ # Reject files whose path does not fall under the folder path filter if folder_path and not file_metadata["path"].startswith(folder_path): raise ValueError( f"File path {file_metadata['path']} does not match folder path {folder_path}" ) file_name = file_metadata["name"] extension = get_file_ext(file_name) # Explicitly excluding image extensions here.
TODO: consider allowing images if extension not in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS: logger.warning(f"Skipping file '{file_name}' with extension '{extension}'") return None # Extract text content based on file type # TODO @wenxi-onyx: convert to extract_text_and_images if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS: encoding = detect_encoding(file_content) # Bind the embedded Onyx metadata to a separate name so the Egnyte file_metadata used below (group_id, path, entry_id) is not overwritten file_content_raw, _onyx_metadata = read_text_file( file_content, encoding=encoding, ignore_onyx_metadata=False ) else: file_content_raw = extract_file_text( file=file_content, file_name=file_name, break_on_unprocessable=True, ) # Build the web URL for the file web_url = f"{base_url}/navigate/file/{file_metadata['group_id']}" # Create document metadata metadata: dict[str, str | list[str]] = { "file_path": file_metadata["path"], "last_modified": file_metadata.get("last_modified", ""), } # Add lock info if present if lock_info := file_metadata.get("lock_info"): metadata["lock_owner"] = ( f"{lock_info.get('first_name', '')} {lock_info.get('last_name', '')}" ) # Create the document owners primary_owner = None if uploaded_by := file_metadata.get("uploaded_by"): primary_owner = BasicExpertInfo( email=uploaded_by, # Using username as email since that's what we have ) # Create the document return Document( id=f"egnyte-{file_metadata['entry_id']}", sections=[TextSection(text=file_content_raw.strip(), link=web_url)], source=DocumentSource.EGNYTE, semantic_identifier=file_name, metadata=metadata, doc_updated_at=( _parse_last_modified(file_metadata["last_modified"]) if "last_modified" in file_metadata else None ), primary_owners=[primary_owner] if primary_owner else None, ) class EgnyteConnector(LoadConnector, PollConnector, OAuthConnector): class AdditionalOauthKwargs(OAuthConnector.AdditionalOauthKwargs): egnyte_domain: str = Field( title="Egnyte Domain", description=( "The domain for the Egnyte instance (e.g.
'company' for company.egnyte.com)" ), ) def __init__( self, folder_path: str | None = None, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.domain = "" # will always be set in `load_credentials` self.folder_path = folder_path or "" # Root folder if not specified self.batch_size = batch_size self.access_token: str | None = None @classmethod def oauth_id(cls) -> DocumentSource: return DocumentSource.EGNYTE @classmethod def oauth_authorization_url( cls, base_domain: str, state: str, additional_kwargs: dict[str, str], ) -> str: if not EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs) callback_uri = get_oauth_callback_uri(base_domain, "egnyte") return ( f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token" f"?client_id={EGNYTE_CLIENT_ID}" f"&redirect_uri={callback_uri}" f"&scope=Egnyte.filesystem" f"&state={state}" f"&response_type=code" ) @classmethod def oauth_code_to_token( cls, base_domain: str, code: str, additional_kwargs: dict[str, str], ) -> dict[str, Any]: if not EGNYTE_CLIENT_ID: raise ValueError("EGNYTE_CLIENT_ID environment variable must be set") if not EGNYTE_CLIENT_SECRET: raise ValueError("EGNYTE_CLIENT_SECRET environment variable must be set") oauth_kwargs = cls.AdditionalOauthKwargs(**additional_kwargs) # Exchange code for token url = f"https://{oauth_kwargs.egnyte_domain}.egnyte.com/puboauth/token" redirect_uri = get_oauth_callback_uri(base_domain, "egnyte") data = { "client_id": EGNYTE_CLIENT_ID, "client_secret": EGNYTE_CLIENT_SECRET, "code": code, "grant_type": "authorization_code", "redirect_uri": redirect_uri, "scope": "Egnyte.filesystem", } headers = {"Content-Type": "application/x-www-form-urlencoded"} response = request_with_retries( method="POST", url=url, data=data, headers=headers, # try a lot faster since this is a realtime flow backoff=0, delay=0.1, ) if not response.ok: raise RuntimeError(f"Failed to exchange code for token: 
{response.text}") token_data = response.json() return { "domain": oauth_kwargs.egnyte_domain, "access_token": token_data["access_token"], } def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.domain = credentials["domain"] self.access_token = credentials["access_token"] return None def _get_files_list( self, path: str, ) -> Generator[dict[str, Any], None, None]: if not self.access_token or not self.domain: raise ConnectorMissingCredentialError("Egnyte") headers = { "Authorization": f"Bearer {self.access_token}", } params: dict[str, Any] = { "list_content": True, } url_encoded_path = quote(path or "") url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs/{url_encoded_path}" response = request_with_retries( method="GET", url=url, headers=headers, params=params ) if not response.ok: raise RuntimeError(f"Failed to fetch files from Egnyte: {response.text}") data = response.json() # Yield files from current directory for file in data.get("files", []): yield file # Recursively traverse folders for folder in data.get("folders", []): yield from self._get_files_list(folder["path"]) def _should_index_file( self, file: dict[str, Any], start_time: datetime | None = None, end_time: datetime | None = None, ) -> bool: """Return True if file should be included based on filters.""" if file["is_folder"]: return False file_modified = _parse_last_modified(file["last_modified"]) if start_time and file_modified < start_time: return False if end_time and file_modified > end_time: return False return True def _process_files( self, start_time: datetime | None = None, end_time: datetime | None = None, ) -> Generator[list[Document | HierarchyNode], None, None]: current_batch: list[Document | HierarchyNode] = [] # Iterate through yielded files and filter them for file in self._get_files_list(self.folder_path): if not self._should_index_file(file, start_time, end_time): logger.debug(f"Skipping file '{file['path']}'.") continue try: # Set up request with 
streaming enabled headers = { "Authorization": f"Bearer {self.access_token}", } url_encoded_path = quote(file["path"]) url = f"{_EGNYTE_API_BASE.format(domain=self.domain)}/fs-content/{url_encoded_path}" response = request_with_retries( method="GET", url=url, headers=headers, stream=True, ) if not response.ok: logger.error( f"Failed to fetch file content: {file['path']} (status code: {response.status_code})" ) continue # Stream the response content into a BytesIO buffer buffer = io.BytesIO() for chunk in response.iter_content(chunk_size=8192): if chunk: buffer.write(chunk) # Reset buffer's position to the start buffer.seek(0) # Process the streamed file content doc = _process_egnyte_file( file_metadata=file, file_content=buffer, base_url=_EGNYTE_APP_BASE.format(domain=self.domain), folder_path=self.folder_path, ) if doc is not None: current_batch.append(doc) if len(current_batch) >= self.batch_size: yield current_batch current_batch = [] except Exception: logger.exception(f"Failed to process file {file['path']}") continue if current_batch: yield current_batch def load_from_state(self) -> GenerateDocumentsOutput: yield from self._process_files() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_time = datetime.fromtimestamp(start, tz=timezone.utc) end_time = datetime.fromtimestamp(end, tz=timezone.utc) yield from self._process_files(start_time=start_time, end_time=end_time) if __name__ == "__main__": connector = EgnyteConnector() connector.load_credentials( { "domain": os.environ["EGNYTE_DOMAIN"], "access_token": os.environ["EGNYTE_ACCESS_TOKEN"], } ) document_batches = connector.load_from_state() print(next(document_batches)) ================================================ FILE: backend/onyx/connectors/exceptions.py ================================================ class ValidationError(Exception): """General exception for validation errors.""" def __init__(self, message: str): self.message = message 
super().__init__(self.message) class ConnectorValidationError(ValidationError): """General exception for connector validation errors.""" def __init__(self, message: str): self.message = message super().__init__(self.message) class UnexpectedValidationError(ValidationError): """Raised when an unexpected error occurs during connector validation. Unexpected errors don't necessarily mean the credential is invalid, but rather that there was an error during the validation process or we encountered a currently unhandled error case. Currently, unexpected validation errors are defined as transient and should not be used to disable the connector. """ def __init__(self, message: str = "Unexpected error during connector validation"): super().__init__(message) class CredentialInvalidError(ConnectorValidationError): """Raised when a connector's credential is invalid.""" def __init__(self, message: str = "Credential is invalid"): super().__init__(message) class CredentialExpiredError(ConnectorValidationError): """Raised when a connector's credential is expired.""" def __init__(self, message: str = "Credential has expired"): super().__init__(message) class InsufficientPermissionsError(ConnectorValidationError): """Raised when the credential does not have sufficient API permissions.""" def __init__( self, message: str = "Insufficient permissions for the requested operation" ): super().__init__(message) ================================================ FILE: backend/onyx/connectors/factory.py ================================================ import importlib from typing import Any from typing import Type from sqlalchemy.orm import Session from onyx.configs.app_configs import INTEGRATION_TESTS_MODE from onyx.configs.constants import DocumentSource from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider from onyx.connectors.exceptions import ConnectorValidationError from 
onyx.connectors.interfaces import BaseConnector from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import CredentialsConnector from onyx.connectors.interfaces import EventConnector from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.models import InputType from onyx.connectors.registry import CONNECTOR_CLASS_MAP from onyx.db.connector import fetch_connector_by_id from onyx.db.credentials import backend_update_credential_json from onyx.db.credentials import fetch_credential_by_id from onyx.db.enums import AccessType from onyx.db.models import Credential from shared_configs.contextvars import get_current_tenant_id class ConnectorMissingException(Exception): pass # Cache for already imported connector classes _connector_cache: dict[DocumentSource, Type[BaseConnector]] = {} def _load_connector_class(source: DocumentSource) -> Type[BaseConnector]: """Dynamically load and cache a connector class.""" if source in _connector_cache: return _connector_cache[source] if source not in CONNECTOR_CLASS_MAP: raise ConnectorMissingException(f"Connector not found for source={source}") mapping = CONNECTOR_CLASS_MAP[source] try: module = importlib.import_module(mapping.module_path) connector_class = getattr(module, mapping.class_name) _connector_cache[source] = connector_class return connector_class except (ImportError, AttributeError) as e: raise ConnectorMissingException( f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}" ) def _validate_connector_supports_input_type( connector: Type[BaseConnector], input_type: InputType | None, source: DocumentSource, ) -> None: """Validate that a connector supports the requested input type.""" if input_type is None: return # Check each input type requirement separately for clarity load_state_unsupported = input_type == InputType.LOAD_STATE and not issubclass( connector, LoadConnector ) poll_unsupported = ( 
input_type == InputType.POLL # Either poll or checkpoint works for this, in the future # all connectors should be checkpoint connectors and ( not issubclass(connector, PollConnector) and not issubclass(connector, CheckpointedConnector) ) ) event_unsupported = input_type == InputType.EVENT and not issubclass( connector, EventConnector ) if any([load_state_unsupported, poll_unsupported, event_unsupported]): raise ConnectorMissingException( f"Connector for source={source} does not accept input_type={input_type}" ) def identify_connector_class( source: DocumentSource, input_type: InputType | None = None, ) -> Type[BaseConnector]: # Load the connector class using lazy loading connector = _load_connector_class(source) # Validate connector supports the requested input_type _validate_connector_supports_input_type(connector, input_type, source) return connector def instantiate_connector( db_session: Session, source: DocumentSource, input_type: InputType, connector_specific_config: dict[str, Any], credential: Credential, ) -> BaseConnector: connector_class = identify_connector_class(source, input_type) connector = connector_class(**connector_specific_config) if isinstance(connector, CredentialsConnector): provider = OnyxDBCredentialsProvider( get_current_tenant_id(), str(source), credential.id ) connector.set_credentials_provider(provider) else: credential_json = ( credential.credential_json.get_value(apply_mask=False) if credential.credential_json else {} ) new_credentials = connector.load_credentials(credential_json) if new_credentials is not None: backend_update_credential_json(credential, new_credentials, db_session) connector.set_allow_images(get_image_extraction_and_analysis_enabled()) return connector def validate_ccpair_for_user( connector_id: int, credential_id: int, access_type: AccessType, db_session: Session, enforce_creation: bool = True, ) -> bool: if INTEGRATION_TESTS_MODE: return True # Validate the connector settings connector = 
fetch_connector_by_id(connector_id, db_session) credential = fetch_credential_by_id( credential_id, db_session, ) if not connector: raise ValueError("Connector not found") if ( connector.source == DocumentSource.INGESTION_API or connector.source == DocumentSource.MOCK_CONNECTOR ): return True if not credential: raise ValueError("Credential not found") try: runnable_connector = instantiate_connector( db_session=db_session, source=connector.source, input_type=connector.input_type, connector_specific_config=connector.connector_specific_config, credential=credential, ) except ConnectorValidationError as e: raise e except Exception as e: if enforce_creation: raise ConnectorValidationError(str(e)) else: return False runnable_connector.validate_connector_settings() if access_type == AccessType.SYNC: runnable_connector.validate_perm_sync() return True ================================================ FILE: backend/onyx/connectors/file/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/file/connector.py ================================================ import json import os from datetime import datetime from datetime import timezone from pathlib import Path from typing import Any from typing import IO from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( process_onyx_metadata, ) from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_text_and_images from onyx.file_processing.extract_file_text import get_file_ext from 
onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.image_utils import store_image_and_create_section from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger logger = setup_logger() def _create_image_section( image_data: bytes, parent_file_name: str, display_name: str, media_type: str | None = None, link: str | None = None, idx: int = 0, ) -> tuple[ImageSection, str | None]: """ Creates an ImageSection for an image file or embedded image. Stores the image in FileStore but does not generate a summary. Args: image_data: Raw image bytes parent_file_name: Name of the parent file (for embedded images) display_name: Display name for the image media_type: MIME type of the image; defaults to "application/octet-stream" when None link: Optional link attached to the resulting section idx: Index for embedded images (0 for a standalone image file) Returns: Tuple of (ImageSection, stored_file_name or None) """ # Create a unique identifier for the image file_id = f"{parent_file_name}_embedded_{idx}" if idx > 0 else parent_file_name # Store the image and create a section try: section, stored_file_name = store_image_and_create_section( image_data=image_data, file_id=file_id, display_name=display_name, media_type=( media_type if media_type is not None else "application/octet-stream" ), link=link, file_origin=FileOrigin.CONNECTOR, ) return section, stored_file_name except Exception as e: logger.error(f"Failed to store image {display_name}: {e}") raise e def _process_file( file_id: str, file_name: str, file: IO[Any], metadata: dict[str, Any] | None, pdf_pass: str | None, file_type: str | None, ) -> list[Document]: """ Process a file and return a list of Documents. For images, creates ImageSection objects without summarization. For documents with embedded images, extracts and stores the images.
""" if metadata is None: metadata = {} # Get file extension and determine file type extension = get_file_ext(file_name) if extension not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS: logger.warning( f"Skipping file '{file_name}' with unrecognized extension '{extension}'" ) return [] # If a zip is uploaded with a metadata file, we can process it here onyx_metadata, custom_tags = process_onyx_metadata(metadata) file_display_name = onyx_metadata.file_display_name or os.path.basename(file_name) time_updated = onyx_metadata.doc_updated_at or datetime.now(timezone.utc) primary_owners = onyx_metadata.primary_owners secondary_owners = onyx_metadata.secondary_owners link = onyx_metadata.link # These metadata items are not settable by the user source_type = onyx_metadata.source_type or DocumentSource.FILE doc_id = onyx_metadata.document_id or f"FILE_CONNECTOR__{file_id}" title = metadata.get("title") or file_display_name # 1) If the file itself is an image, handle that scenario quickly if extension in OnyxFileExtensions.IMAGE_EXTENSIONS: # Read the image data image_data = file.read() if not image_data: logger.warning(f"Empty image file: {file_name}") return [] # Create an ImageSection for the image try: section, _ = _create_image_section( image_data=image_data, parent_file_name=file_id, display_name=title, media_type=file_type, ) return [ Document( id=doc_id, sections=[section], source=source_type, semantic_identifier=file_display_name, title=title, doc_updated_at=time_updated, primary_owners=primary_owners, secondary_owners=secondary_owners, metadata=custom_tags, ) ] except Exception as e: logger.error(f"Failed to process image file {file_name}: {e}") return [] # 2) Otherwise: text-based approach. Possibly with embedded images. 
file.seek(0) # Extract text and images from the file extraction_result = extract_text_and_images( file=file, file_name=file_name, pdf_pass=pdf_pass, content_type=file_type, ) # Each file may have file-specific ONYX_METADATA https://docs.onyx.app/admins/connectors/official/file # If so, we should add it to any metadata processed so far if extraction_result.metadata: logger.debug( f"Found file-specific metadata for {file_name}: {extraction_result.metadata}" ) onyx_metadata, more_custom_tags = process_onyx_metadata( extraction_result.metadata ) # Add file-specific tags custom_tags.update(more_custom_tags) # File-specific metadata overrides metadata processed so far source_type = onyx_metadata.source_type or source_type primary_owners = onyx_metadata.primary_owners or primary_owners secondary_owners = onyx_metadata.secondary_owners or secondary_owners time_updated = onyx_metadata.doc_updated_at or time_updated file_display_name = onyx_metadata.file_display_name or file_display_name title = onyx_metadata.title or onyx_metadata.file_display_name or title link = onyx_metadata.link or link # Build sections: first the text as a single Section sections: list[TextSection | ImageSection] = [] if extraction_result.text_content.strip(): logger.debug(f"Creating TextSection for {file_name} with link: {link}") sections.append( TextSection(link=link, text=extraction_result.text_content.strip()) ) # Then any extracted images from docx, PDFs, etc. 
for idx, (img_data, img_name) in enumerate( extraction_result.embedded_images, start=1 ): # Store each embedded image as a separate file in FileStore # and create a section with the image reference try: image_section, stored_file_name = _create_image_section( image_data=img_data, parent_file_name=file_id, display_name=f"{title} - image {idx}", media_type="application/octet-stream", # Default media type for embedded images idx=idx, ) sections.append(image_section) logger.debug( f"Created ImageSection for embedded image {idx} in {file_name}, stored as: {stored_file_name}" ) except Exception as e: logger.warning( f"Failed to process embedded image {idx} in {file_name}: {e}" ) return [ Document( id=doc_id, sections=sections, source=source_type, semantic_identifier=file_display_name, title=title, doc_updated_at=time_updated, primary_owners=primary_owners, secondary_owners=secondary_owners, metadata=custom_tags, ) ] class LocalFileConnector(LoadConnector): """ Connector that reads files from Postgres and yields Documents, including embedded image extraction without summarization. file_locations are S3/Filestore UUIDs file_names are the names of the files """ # Note: file_names is optional (defaults to None) so that it does not break backwards compatibility. # If add_file_names migration is not run, old file connector configs will not have file_names. # file_names is only used for display purposes in the UI and file_locations is used as a fallback.
def __init__( self, file_locations: list[Path | str], file_names: list[str] | None = None, # noqa: ARG002 zip_metadata_file_id: str | None = None, zip_metadata: dict[str, Any] | None = None, # Deprecated, for backwards compat batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.file_locations = [str(loc) for loc in file_locations] self.batch_size = batch_size self.pdf_pass: str | None = None self._zip_metadata_file_id = zip_metadata_file_id self._zip_metadata_deprecated = zip_metadata def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.pdf_pass = credentials.get("pdf_password") return None def load_from_state(self) -> GenerateDocumentsOutput: """ Iterates over each file path, fetches from Postgres, tries to parse text or images, and yields Document batches. """ # Load metadata dict at start (from file store or deprecated inline format) zip_metadata: dict[str, Any] = {} if self._zip_metadata_file_id: try: file_store = get_default_file_store() metadata_io = file_store.read_file( file_id=self._zip_metadata_file_id, mode="b" ) metadata_bytes = metadata_io.read() loaded_metadata = json.loads(metadata_bytes) if isinstance(loaded_metadata, list): zip_metadata = {d["filename"]: d for d in loaded_metadata} else: zip_metadata = loaded_metadata except Exception as e: logger.warning(f"Failed to load metadata from file store: {e}") elif self._zip_metadata_deprecated: logger.warning( "Using deprecated inline zip_metadata dict. Re-upload files to use the new file store format." 
) zip_metadata = self._zip_metadata_deprecated documents: list[Document | HierarchyNode] = [] for file_id in self.file_locations: file_store = get_default_file_store() file_record = file_store.read_file_record(file_id=file_id) if not file_record: # typically an unsupported extension logger.warning(f"No file record found for '{file_id}' in PG; skipping.") continue metadata = zip_metadata.get( file_record.display_name, {} ) or zip_metadata.get(os.path.basename(file_record.display_name), {}) file_io = file_store.read_file(file_id=file_id, mode="b") new_docs = _process_file( file_id=file_id, file_name=file_record.display_name, file=file_io, metadata=metadata, pdf_pass=self.pdf_pass, file_type=file_record.file_type, ) documents.extend(new_docs) if len(documents) >= self.batch_size: yield documents documents = [] if documents: yield documents if __name__ == "__main__": connector = LocalFileConnector( file_locations=[os.environ["TEST_FILE"]], file_names=[os.environ["TEST_FILE"]], ) connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")}) doc_batches = connector.load_from_state() for batch in doc_batches: print("BATCH:", batch) ================================================ FILE: backend/onyx/connectors/fireflies/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/fireflies/connector.py ================================================ from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import cast from typing import List import requests from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import 
BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() _FIREFLIES_ID_PREFIX = "FIREFLIES_" _FIREFLIES_API_URL = "https://api.fireflies.ai/graphql" _FIREFLIES_TRANSCRIPT_QUERY_SIZE = 50 # Max page size is 50 _FIREFLIES_API_QUERY = """ query Transcripts($fromDate: DateTime, $toDate: DateTime, $limit: Int!, $skip: Int!) { transcripts(fromDate: $fromDate, toDate: $toDate, limit: $limit, skip: $skip) { id title organizer_email participants date duration transcript_url sentences { text speaker_name start_time } } } """ ONE_MINUTE = 60 def _create_doc_from_transcript(transcript: dict) -> Document | None: sections: List[TextSection] = [] current_speaker_name = None current_link = "" current_text = "" if transcript["sentences"] is None: return None for sentence in transcript["sentences"]: if sentence["speaker_name"] != current_speaker_name: if current_speaker_name is not None: sections.append( TextSection( link=current_link, text=current_text.strip(), ) ) current_speaker_name = sentence.get("speaker_name") or "Unknown Speaker" current_link = f"{transcript['transcript_url']}?t={sentence['start_time']}" current_text = f"{current_speaker_name}: " cleaned_text = sentence["text"].replace("\xa0", " ") current_text += f"{cleaned_text} " # Sometimes these timestamped links do not work; this is a bug on the Fireflies side.
sections.append( TextSection( link=current_link, text=current_text.strip(), ) ) fireflies_id = _FIREFLIES_ID_PREFIX + transcript["id"] meeting_title = transcript["title"] or "No Title" meeting_date_unix = transcript["date"] meeting_date = datetime.fromtimestamp(meeting_date_unix / 1000, tz=timezone.utc) # Build hierarchy based on meeting date (year-month) year_month = meeting_date.strftime("%Y-%m") meeting_organizer_email = transcript["organizer_email"] organizer_email_user_info = [BasicExpertInfo(email=meeting_organizer_email)] meeting_participants_email_list = [] for participant in transcript.get("participants", []): if participant != meeting_organizer_email and participant: meeting_participants_email_list.append(BasicExpertInfo(email=participant)) return Document( id=fireflies_id, sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.FIREFLIES, semantic_identifier=meeting_title, doc_metadata={ "hierarchy": { "source_path": [year_month], "year_month": year_month, "meeting_title": meeting_title, "organizer_email": meeting_organizer_email, } }, metadata={ k: str(v) for k, v in { "meeting_date": meeting_date, "duration_min": transcript.get("duration"), }.items() if v is not None }, doc_updated_at=meeting_date, primary_owners=organizer_email_user_info, secondary_owners=meeting_participants_email_list, ) # If not all transcripts are being indexed, try using a more-recently-generated # API key. 
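The speaker-grouping loop in `_create_doc_from_transcript` merges consecutive sentences from the same speaker into one section whose link deep-links to the moment that turn starts. The sketch below isolates that idea as a standalone function; `group_sentences_by_speaker` is a hypothetical name, not part of the Onyx codebase, and it returns plain `(link, text)` tuples instead of `TextSection` objects so it runs without Onyx installed. One deliberate difference, stated as an assumption about intended behavior: this sketch applies the `"Unknown Speaker"` fallback *before* comparing speakers, whereas the connector compares the raw `speaker_name` first, so in the connector each consecutive sentence with a missing speaker name starts its own section.

```python
from typing import Any


def group_sentences_by_speaker(
    sentences: list[dict[str, Any]], transcript_url: str
) -> list[tuple[str, str]]:
    """Merge consecutive same-speaker sentences into (link, text) pairs.

    Hypothetical helper mirroring the grouping loop in
    _create_doc_from_transcript; not part of the Onyx codebase.
    """
    sections: list[tuple[str, str]] = []
    current_speaker: str | None = None
    current_link = ""
    current_text = ""
    for sentence in sentences:
        # Normalize a missing speaker name before comparing (differs from the connector)
        speaker = sentence.get("speaker_name") or "Unknown Speaker"
        if speaker != current_speaker:
            # Flush the previous speaker's turn before starting a new one
            if current_speaker is not None:
                sections.append((current_link, current_text.strip()))
            current_speaker = speaker
            # Deep link to the timestamp where this speaker's turn begins
            current_link = f"{transcript_url}?t={sentence['start_time']}"
            current_text = f"{speaker}: "
        # Replace non-breaking spaces, as the connector does
        current_text += sentence["text"].replace("\xa0", " ") + " "
    # Flush the final turn, if any sentences were seen
    if current_speaker is not None:
        sections.append((current_link, current_text.strip()))
    return sections
```

An empty sentence list yields no sections here, while the connector's loop, as written, still appends one section built from the empty initial link and text.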
class FirefliesConnector(PollConnector, LoadConnector): def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: self.batch_size = batch_size def load_credentials(self, credentials: dict[str, str]) -> None: api_key = credentials.get("fireflies_api_key") if not isinstance(api_key, str): raise ConnectorMissingCredentialError( "The Fireflies API key must be a string" ) self.api_key = api_key return None def _fetch_transcripts( self, start_datetime: str | None = None, end_datetime: str | None = None ) -> Iterator[List[dict]]: if self.api_key is None: raise ConnectorMissingCredentialError("Missing API key") headers = { "Content-Type": "application/json", "Authorization": "Bearer " + self.api_key, } skip = 0 variables: dict[str, int | str] = { "limit": _FIREFLIES_TRANSCRIPT_QUERY_SIZE, } if start_datetime: variables["fromDate"] = start_datetime if end_datetime: variables["toDate"] = end_datetime while True: variables["skip"] = skip response = requests.post( _FIREFLIES_API_URL, headers=headers, json={"query": _FIREFLIES_API_QUERY, "variables": variables}, ) response.raise_for_status() if response.status_code == 204: break received_transcripts = response.json() parsed_transcripts = received_transcripts.get("data", {}).get( "transcripts", [] ) yield parsed_transcripts if len(parsed_transcripts) < _FIREFLIES_TRANSCRIPT_QUERY_SIZE: break skip += _FIREFLIES_TRANSCRIPT_QUERY_SIZE def _process_transcripts( self, start: str | None = None, end: str | None = None ) -> GenerateDocumentsOutput: doc_batch: List[Document | HierarchyNode] = [] for transcript_batch in self._fetch_transcripts(start, end): for transcript in transcript_batch: if doc := _create_doc_from_transcript(transcript): doc_batch.append(doc) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: return self._process_transcripts() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> 
GenerateDocumentsOutput: # add some leeway to account for any timezone funkiness and/or bad handling # of start time on the Fireflies side start = max(0, start - ONE_MINUTE) start_datetime = datetime.fromtimestamp(start, tz=timezone.utc).strftime( "%Y-%m-%dT%H:%M:%S.000Z" ) end_datetime = datetime.fromtimestamp(end, tz=timezone.utc).strftime( "%Y-%m-%dT%H:%M:%S.000Z" ) yield from self._process_transcripts(start_datetime, end_datetime) ================================================ FILE: backend/onyx/connectors/freshdesk/__init__,py ================================================ ================================================ FILE: backend/onyx/connectors/freshdesk/connector.py ================================================ import json from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import List import requests from retry import retry from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rl_requests, ) from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import parse_html_page_basic from onyx.utils.logger import setup_logger logger = setup_logger() _FRESHDESK_ID_PREFIX = "FRESHDESK_" _TICKET_FIELDS_TO_INCLUDE = { "fr_escalated", "spam", "priority", "source", "status", "type", "is_escalated", "tags", "nr_due_by", "nr_escalated", "cc_emails", "fwd_emails", "reply_cc_emails", "ticket_cc_emails", "support_email", "to_emails", } _SOURCE_NUMBER_TYPE_MAP: dict[int, str] = { 
1: "Email", 2: "Portal", 3: "Phone", 7: "Chat", 9: "Feedback Widget", 10: "Outbound Email", } _PRIORITY_NUMBER_TYPE_MAP: dict[int, str] = { 1: "low", 2: "medium", 3: "high", 4: "urgent", } _STATUS_NUMBER_TYPE_MAP: dict[int, str] = { 2: "open", 3: "pending", 4: "resolved", 5: "closed", } # TODO: unify this with other generic rate limited requests with retries (e.g. Axero, Notion?) @retry(tries=3, delay=1, backoff=2) def _rate_limited_freshdesk_get( url: str, auth: tuple, params: dict ) -> requests.Response: return rl_requests.get(url, auth=auth, params=params) def _create_metadata_from_ticket(ticket: dict) -> dict: metadata: dict[str, str | list[str]] = {} # Combine all emails into a list so there are no repeated emails email_data: set[str] = set() for key, value in ticket.items(): # Skip fields that aren't useful for embedding if key not in _TICKET_FIELDS_TO_INCLUDE: continue # Skip empty fields if not value or value == "[]": continue # Convert strings or lists to strings stringified_value: str | list[str] if isinstance(value, list): stringified_value = [str(item) for item in value] else: stringified_value = str(value) if "email" in key: if isinstance(stringified_value, list): email_data.update(stringified_value) else: email_data.add(stringified_value) else: metadata[key] = stringified_value if email_data: metadata["emails"] = list(email_data) # Convert source numbers to human-parsable string if source_number := ticket.get("source"): metadata["source"] = _SOURCE_NUMBER_TYPE_MAP.get( source_number, "Unknown Source Type" ) # Convert priority numbers to human-parsable string if priority_number := ticket.get("priority"): metadata["priority"] = _PRIORITY_NUMBER_TYPE_MAP.get( priority_number, "Unknown Priority" ) # Convert status to human-parsable string if status_number := ticket.get("status"): metadata["status"] = _STATUS_NUMBER_TYPE_MAP.get( status_number, "Unknown Status" ) due_by = datetime.fromisoformat(ticket["due_by"].replace("Z", "+00:00")) metadata["overdue"] = 
str(datetime.now(timezone.utc) > due_by) return metadata def _create_doc_from_ticket(ticket: dict, domain: str) -> Document: # Use the ticket description as the text text = f"Ticket description: {parse_html_page_basic(ticket.get('description_text', ''))}" metadata = _create_metadata_from_ticket(ticket) # The link is also used in the ID because it is more unique than the ticket ID alone link = f"https://{domain}.freshdesk.com/helpdesk/tickets/{ticket['id']}" return Document( id=_FRESHDESK_ID_PREFIX + link, sections=[ TextSection( link=link, text=text, ) ], source=DocumentSource.FRESHDESK, semantic_identifier=ticket["subject"], metadata=metadata, doc_updated_at=datetime.fromisoformat( ticket["updated_at"].replace("Z", "+00:00") ), ) class FreshdeskConnector(PollConnector, LoadConnector): def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None: self.batch_size = batch_size def load_credentials(self, credentials: dict[str, str | int]) -> None: api_key = credentials.get("freshdesk_api_key") domain = credentials.get("freshdesk_domain") if not all(isinstance(cred, str) for cred in [domain, api_key]): raise ConnectorMissingCredentialError( "All Freshdesk credentials must be strings" ) # TODO: Move the domain to the connector-specific configuration instead of part of the credential # Then apply normalization and validation against the config # Clean and normalize the domain URL domain = str(domain).strip().lower() # Remove any trailing slashes domain = domain.rstrip("/") # Remove protocol if present if domain.startswith(("http://", "https://")): domain = domain.replace("http://", "").replace("https://", "") # Remove .freshdesk.com suffix and any API paths if present if ".freshdesk.com" in domain: domain = domain.split(".freshdesk.com")[0] if not domain: raise ConnectorMissingCredentialError("Freshdesk domain cannot be empty") self.api_key = str(api_key) self.domain = domain def _fetch_tickets( self, start: datetime | None = None, end: datetime | None = None, # noqa:
ARG002 ) -> Iterator[List[dict]]: """ 'end' is not currently used, so we may double fetch tickets created after the indexing starts but before the actual call is made. To use 'end' would require us to use the search endpoint but it has limitations, namely having to fetch all IDs and then individually fetch each ticket because there is no 'include' field available for this endpoint: https://developers.freshdesk.com/api/#filter_tickets """ if self.api_key is None or self.domain is None: raise ConnectorMissingCredentialError("freshdesk") base_url = f"https://{self.domain}.freshdesk.com/api/v2/tickets" params: dict[str, int | str] = { "include": "description", "per_page": 50, "page": 1, } if start: params["updated_since"] = start.isoformat() while True: # Freshdesk API uses API key as the username and any value as the password. response = _rate_limited_freshdesk_get( base_url, auth=(self.api_key, "CanYouBelieveFreshdeskDoesThis"), params=params, ) response.raise_for_status() if response.status_code == 204: break tickets = json.loads(response.content) logger.info( f"Fetched {len(tickets)} tickets from Freshdesk API (Page {params['page']})" ) yield tickets if len(tickets) < int(params["per_page"]): break params["page"] = int(params["page"]) + 1 def _process_tickets( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: doc_batch: List[Document | HierarchyNode] = [] for ticket_batch in self._fetch_tickets(start, end): for ticket in ticket_batch: doc_batch.append(_create_doc_from_ticket(ticket, self.domain)) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: return self._process_tickets() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) yield from 
self._process_tickets(start_datetime, end_datetime) ================================================ FILE: backend/onyx/connectors/gitbook/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/gitbook/connector.py ================================================ from datetime import datetime from datetime import timezone from typing import Any from urllib.parse import urljoin import requests from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() GITBOOK_API_BASE = "https://api.gitbook.com/v1/" class GitbookApiClient: def __init__(self, access_token: str) -> None: self.access_token = access_token def get(self, endpoint: str, params: dict[str, Any] | None = None) -> Any: headers = { "Authorization": f"Bearer {self.access_token}", "Content-Type": "application/json", } url = urljoin(GITBOOK_API_BASE, endpoint.lstrip("/")) response = requests.get(url, headers=headers, params=params) response.raise_for_status() return response.json() def get_page_content(self, space_id: str, page_id: str) -> dict[str, Any]: return self.get(f"/spaces/{space_id}/content/page/{page_id}") def _extract_text_from_document(document: dict[str, Any]) -> str: """Extract text content from GitBook document structure by parsing the document nodes into markdown format.""" def parse_leaf(leaf: dict[str, Any]) -> str: text = leaf.get("text", "") leaf.get("marks", []) return 
text def parse_text_node(node: dict[str, Any]) -> str: text = "" for leaf in node.get("leaves", []): text += parse_leaf(leaf) return text def parse_block_node(node: dict[str, Any]) -> str: block_type = node.get("type", "") result = "" if block_type == "heading-1": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"# {text}\n\n" elif block_type == "heading-2": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"## {text}\n\n" elif block_type == "heading-3": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"### {text}\n\n" elif block_type == "heading-4": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"#### {text}\n\n" elif block_type == "heading-5": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"##### {text}\n\n" elif block_type == "heading-6": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"###### {text}\n\n" elif block_type == "list-unordered": for list_item in node.get("nodes", []): paragraph = list_item.get("nodes", [])[0] text = "".join(parse_text_node(n) for n in paragraph.get("nodes", [])) result += f"* {text}\n" result += "\n" elif block_type == "paragraph": text = "".join(parse_text_node(n) for n in node.get("nodes", [])) result = f"{text}\n\n" elif block_type == "list-tasks": for task_item in node.get("nodes", []): checked = task_item.get("data", {}).get("checked", False) paragraph = task_item.get("nodes", [])[0] text = "".join(parse_text_node(n) for n in paragraph.get("nodes", [])) checkbox = "[x]" if checked else "[ ]" result += f"- {checkbox} {text}\n" result += "\n" elif block_type == "code": for code_line in node.get("nodes", []): if code_line.get("type") == "code-line": text = "".join( parse_text_node(n) for n in code_line.get("nodes", []) ) result += f"{text}\n" result += "\n" elif block_type == "blockquote": for quote_node in node.get("nodes", []): if quote_node.get("type") == 
"paragraph": text = "".join( parse_text_node(n) for n in quote_node.get("nodes", []) ) result += f"> {text}\n" result += "\n" elif block_type == "table": records = node.get("data", {}).get("records", {}) definition = node.get("data", {}).get("definition", {}) view = node.get("data", {}).get("view", {}) columns = view.get("columns", []) header_cells = [] for col_id in columns: col_def = definition.get(col_id, {}) header_cells.append(col_def.get("title", "")) result = "| " + " | ".join(header_cells) + " |\n" result += "|" + "---|" * len(header_cells) + "\n" sorted_records = sorted( records.items(), key=lambda x: x[1].get("orderIndex", "") ) for record_id, record_data in sorted_records: values = record_data.get("values", {}) row_cells = [] for col_id in columns: fragment_id = values.get(col_id, "") fragment_text = "" for fragment in node.get("fragments", []): if fragment.get("fragment") == fragment_id: for frag_node in fragment.get("nodes", []): if frag_node.get("type") == "paragraph": fragment_text = "".join( parse_text_node(n) for n in frag_node.get("nodes", []) ) break row_cells.append(fragment_text) result += "| " + " | ".join(row_cells) + " |\n" result += "\n" return result if not document or "document" not in document: return "" markdown = "" nodes = document["document"].get("nodes", []) for node in nodes: markdown += parse_block_node(node) return markdown def _convert_page_to_document( client: GitbookApiClient, space_id: str, page: dict[str, Any] ) -> Document: page_id = page["id"] page_content = client.get_page_content(space_id, page_id) return Document( id=f"gitbook-{space_id}-{page_id}", sections=[ TextSection( link=page.get("urls", {}).get("app", ""), text=_extract_text_from_document(page_content), ) ], source=DocumentSource.GITBOOK, semantic_identifier=page.get("title", ""), doc_updated_at=datetime.fromisoformat(page["updatedAt"]).replace( tzinfo=timezone.utc ), metadata={ "path": page.get("path", ""), "type": page.get("type", ""), "kind": page.get("kind", 
""), }, ) class GitbookConnector(LoadConnector, PollConnector): def __init__( self, space_id: str, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.space_id = space_id self.batch_size = batch_size self.access_token: str | None = None self.client: GitbookApiClient | None = None def load_credentials(self, credentials: dict[str, Any]) -> None: access_token = credentials.get("gitbook_api_key") if not access_token: raise ConnectorMissingCredentialError("GitBook access token") self.access_token = access_token self.client = GitbookApiClient(access_token) def _fetch_all_pages( self, start: datetime | None = None, end: datetime | None = None, ) -> GenerateDocumentsOutput: if not self.client: raise ConnectorMissingCredentialError("GitBook") try: content = self.client.get(f"/spaces/{self.space_id}/content/pages") pages: list[dict[str, Any]] = content.get("pages", []) current_batch: list[Document | HierarchyNode] = [] logger.info(f"Found {len(pages)} root pages.") logger.info( f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}" ) while pages: page = pages.pop(0) # Enqueue child pages before any skip below, so pages nested under a # never-edited or out-of-window parent are still visited pages.extend(page.get("pages", [])) updated_at_raw = page.get("updatedAt") if updated_at_raw is None: # if updatedAt is not present, that means the page has never been edited continue updated_at = datetime.fromisoformat(updated_at_raw) if start and updated_at < start: continue if end and updated_at > end: continue current_batch.append( _convert_page_to_document(self.client, self.space_id, page) ) if len(current_batch) >= self.batch_size: yield current_batch current_batch = [] if current_batch: yield current_batch except requests.RequestException as e: logger.error(f"Error fetching GitBook content: {str(e)}") raise def load_from_state(self) -> GenerateDocumentsOutput: return self._fetch_all_pages() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) end_datetime =
datetime.fromtimestamp(end, tz=timezone.utc) return self._fetch_all_pages(start_datetime, end_datetime) if __name__ == "__main__": import os connector = GitbookConnector( space_id=os.environ["GITBOOK_SPACE_ID"], ) connector.load_credentials({"gitbook_api_key": os.environ["GITBOOK_API_KEY"]}) document_batches = connector.load_from_state() print(next(document_batches)) ================================================ FILE: backend/onyx/connectors/github/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/github/connector.py ================================================ import copy from collections.abc import Callable from collections.abc import Generator from datetime import datetime from datetime import timedelta from datetime import timezone from enum import Enum from typing import Any from typing import cast from github import Github from github import RateLimitExceededException from github import Repository from github.GithubException import GithubException from github.Issue import Issue from github.NamedUser import NamedUser from github.PaginatedList import PaginatedList from github.PullRequest import PullRequest from pydantic import BaseModel from typing_extensions import override from onyx.access.models import ExternalAccess from onyx.configs.app_configs import GITHUB_CONNECTOR_BASE_URL from onyx.configs.constants import DocumentSource from onyx.connectors.connector_runner import ConnectorRunner from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.github.models import SerializedRepository from onyx.connectors.github.rate_limit_utils import sleep_after_rate_limit_exception from onyx.connectors.github.utils import deserialize_repository from 
onyx.connectors.github.utils import get_external_access_permission from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import ConnectorCheckpoint from onyx.connectors.interfaces import ConnectorFailure from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() ITEMS_PER_PAGE = 100 CURSOR_LOG_FREQUENCY = 50 _MAX_NUM_RATE_LIMIT_RETRIES = 5 ONE_DAY = timedelta(days=1) SLIM_BATCH_SIZE = 100 # Cases # X (from start) standard run, no fallback to cursor-based pagination # X (from start) standard run errors, fallback to cursor-based pagination # X error in the middle of a page # X no errors: run to completion # X (from checkpoint) standard run, no fallback to cursor-based pagination # X (from checkpoint) continue from cursor-based pagination # - retrying # - no retrying # things to check: # checkpoint state on return # checkpoint progress (no infinite loop) class DocMetadata(BaseModel): repo: str def get_nextUrl_key(pag_list: PaginatedList[PullRequest | Issue]) -> str: if "_PaginatedList__nextUrl" in pag_list.__dict__: return "_PaginatedList__nextUrl" for key in pag_list.__dict__: if "__nextUrl" in key: return key for key in pag_list.__dict__: if "nextUrl" in key: return key return "" def get_nextUrl( pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str ) -> str | None: return getattr(pag_list, nextUrl_key) if nextUrl_key else None def set_nextUrl( pag_list: PaginatedList[PullRequest | Issue], nextUrl_key: str, nextUrl: str ) -> None: if nextUrl_key: setattr(pag_list, nextUrl_key, nextUrl) elif nextUrl: raise ValueError("Next URL key not found: " + str(pag_list.__dict__)) def 
_paginate_until_error( git_objs: Callable[[], PaginatedList[PullRequest | Issue]], cursor_url: str | None, prev_num_objs: int, cursor_url_callback: Callable[[str | None, int], None], retrying: bool = False, ) -> Generator[PullRequest | Issue, None, None]: num_objs = prev_num_objs pag_list = git_objs() nextUrl_key = get_nextUrl_key(pag_list) if cursor_url: set_nextUrl(pag_list, nextUrl_key, cursor_url) elif retrying: # if we are retrying, we want to skip the objects retrieved # over previous calls. Unfortunately, this WILL retrieve all # pages before the one we are resuming from, so we really # don't want this case to be hit often logger.warning( "Retrying from a previous cursor-based pagination call. " "This will retrieve all pages before the one we are resuming from, " "which may take a while and consume many API calls." ) pag_list = cast(PaginatedList[PullRequest | Issue], pag_list[prev_num_objs:]) num_objs = 0 try: # this for loop handles cursor-based pagination for issue_or_pr in pag_list: num_objs += 1 yield issue_or_pr # used to store the current cursor url in the checkpoint. This value # is updated during iteration over pag_list. cursor_url_callback(get_nextUrl(pag_list, nextUrl_key), num_objs) if num_objs % CURSOR_LOG_FREQUENCY == 0: logger.info( f"Retrieved {num_objs} objects with current cursor url: {get_nextUrl(pag_list, nextUrl_key)}" ) except Exception as e: logger.exception(f"Error during cursor-based pagination: {e}") if num_objs - prev_num_objs > 0: raise if get_nextUrl(pag_list, nextUrl_key) is not None and not retrying: logger.info( "Assuming that this error is due to cursor " "expiration because no objects were retrieved. " "Retrying from the first page." 
) yield from _paginate_until_error( git_objs, None, prev_num_objs, cursor_url_callback, retrying=True ) return # for no cursor url or if we reach this point after a retry, raise the error raise def _get_batch_rate_limited( # We pass in a callable because we want git_objs to produce a fresh # PaginatedList each time it's called to avoid using the same object for cursor-based pagination # from a partial offset-based pagination call. git_objs: Callable[[], PaginatedList], page_num: int, cursor_url: str | None, prev_num_objs: int, cursor_url_callback: Callable[[str | None, int], None], github_client: Github, attempt_num: int = 0, ) -> Generator[PullRequest | Issue, None, None]: if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: raise RuntimeError( "Re-tried fetching batch too many times. Something is going wrong with fetching objects from Github" ) try: if cursor_url: # when this is set, we are resuming from an earlier # cursor-based pagination call. yield from _paginate_until_error( git_objs, cursor_url, prev_num_objs, cursor_url_callback ) return objs = list(git_objs().get_page(page_num)) # fetch all data here to disable lazy loading later # this is needed to capture the rate limit exception here (if one occurs) for obj in objs: if hasattr(obj, "raw_data"): getattr(obj, "raw_data") yield from objs except RateLimitExceededException: sleep_after_rate_limit_exception(github_client) yield from _get_batch_rate_limited( git_objs, page_num, cursor_url, prev_num_objs, cursor_url_callback, github_client, attempt_num + 1, ) except GithubException as e: if not ( e.status == 422 and ( "cursor" in (e.message or "") or "cursor" in (e.data or {}).get("message", "") ) ): raise # Fallback to a cursor-based pagination strategy # This can happen for "large datasets," but as far as we can tell # there is no documentation of this error on the web.
# Error message: # "Pagination with the page parameter is not supported for large datasets, # please use cursor based pagination (after/before)" yield from _paginate_until_error( git_objs, cursor_url, prev_num_objs, cursor_url_callback ) def _get_userinfo(user: NamedUser) -> dict[str, str]: def _safe_get(attr_name: str) -> str | None: try: return cast(str | None, getattr(user, attr_name)) except GithubException: logger.debug(f"Error getting {attr_name} for user") return None return { k: v for k, v in { "login": _safe_get("login"), "name": _safe_get("name"), "email": _safe_get("email"), }.items() if v is not None } def _convert_pr_to_document( pull_request: PullRequest, repo_external_access: ExternalAccess | None ) -> Document: repo_full_name = pull_request.base.repo.full_name if pull_request.base else "" # Split full_name (e.g., "owner/repo") into owner and repo parts = repo_full_name.split("/", 1) owner_name = parts[0] if parts else "" repo_name = parts[1] if len(parts) > 1 else repo_full_name doc_metadata = { "repo": repo_full_name, "hierarchy": { "source_path": [owner_name, repo_name, "pull_requests"], "owner": owner_name, "repo": repo_name, "object_type": "pull_request", }, } return Document( id=pull_request.html_url, sections=[ TextSection(link=pull_request.html_url, text=pull_request.body or "") ], external_access=repo_external_access, source=DocumentSource.GITHUB, semantic_identifier=f"{pull_request.number}: {pull_request.title}", # updated_at is UTC time but is timezone unaware, explicitly add UTC # as there is logic in indexing to prevent wrong timestamped docs # due to local time discrepancies with UTC doc_updated_at=( pull_request.updated_at.replace(tzinfo=timezone.utc) if pull_request.updated_at else None ), # this metadata is used in perm sync doc_metadata=doc_metadata, metadata={ k: [str(vi) for vi in v] if isinstance(v, list) else str(v) for k, v in { "object_type": "PullRequest", "id": pull_request.number, "merged": pull_request.merged, "state": 
pull_request.state, "user": _get_userinfo(pull_request.user) if pull_request.user else None, "assignees": [ _get_userinfo(assignee) for assignee in pull_request.assignees ], "repo": ( pull_request.base.repo.full_name if pull_request.base else None ), "num_commits": str(pull_request.commits), "num_files_changed": str(pull_request.changed_files), "labels": [label.name for label in pull_request.labels], "created_at": ( pull_request.created_at.replace(tzinfo=timezone.utc) if pull_request.created_at else None ), "updated_at": ( pull_request.updated_at.replace(tzinfo=timezone.utc) if pull_request.updated_at else None ), "closed_at": ( pull_request.closed_at.replace(tzinfo=timezone.utc) if pull_request.closed_at else None ), "merged_at": ( pull_request.merged_at.replace(tzinfo=timezone.utc) if pull_request.merged_at else None ), "merged_by": ( _get_userinfo(pull_request.merged_by) if pull_request.merged_by else None ), }.items() if v is not None }, ) def _fetch_issue_comments(issue: Issue) -> str: comments = issue.get_comments() return "\nComment: ".join(comment.body for comment in comments) def _convert_issue_to_document( issue: Issue, repo_external_access: ExternalAccess | None ) -> Document: repo_full_name = issue.repository.full_name if issue.repository else "" # Split full_name (e.g., "owner/repo") into owner and repo parts = repo_full_name.split("/", 1) owner_name = parts[0] if parts else "" repo_name = parts[1] if len(parts) > 1 else repo_full_name doc_metadata = { "repo": repo_full_name, "hierarchy": { "source_path": [owner_name, repo_name, "issues"], "owner": owner_name, "repo": repo_name, "object_type": "issue", }, } return Document( id=issue.html_url, sections=[TextSection(link=issue.html_url, text=issue.body or "")], source=DocumentSource.GITHUB, external_access=repo_external_access, semantic_identifier=f"{issue.number}: {issue.title}", # updated_at is UTC time but is timezone unaware doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc), # this 
metadata is used in perm sync doc_metadata=doc_metadata, metadata={ k: [str(vi) for vi in v] if isinstance(v, list) else str(v) for k, v in { "object_type": "Issue", "id": issue.number, "state": issue.state, "user": _get_userinfo(issue.user) if issue.user else None, "assignees": [_get_userinfo(assignee) for assignee in issue.assignees], "repo": issue.repository.full_name if issue.repository else None, "labels": [label.name for label in issue.labels], "created_at": ( issue.created_at.replace(tzinfo=timezone.utc) if issue.created_at else None ), "updated_at": ( issue.updated_at.replace(tzinfo=timezone.utc) if issue.updated_at else None ), "closed_at": ( issue.closed_at.replace(tzinfo=timezone.utc) if issue.closed_at else None ), "closed_by": ( _get_userinfo(issue.closed_by) if issue.closed_by else None ), }.items() if v is not None }, ) class GithubConnectorStage(Enum): START = "start" PRS = "prs" ISSUES = "issues" class GithubConnectorCheckpoint(ConnectorCheckpoint): stage: GithubConnectorStage curr_page: int cached_repo_ids: list[int] | None = None cached_repo: SerializedRepository | None = None # Used for the fallback cursor-based pagination strategy num_retrieved: int cursor_url: str | None = None def reset(self) -> None: """ Resets curr_page, num_retrieved, and cursor_url to their initial values (0, 0, None) """ self.curr_page = 0 self.num_retrieved = 0 self.cursor_url = None def make_cursor_url_callback( checkpoint: GithubConnectorCheckpoint, ) -> Callable[[str | None, int], None]: def cursor_url_callback(cursor_url: str | None, num_objs: int) -> None: # we want to maintain the old cursor url so code after retrieval # can determine that we are using the fallback cursor-based pagination strategy if cursor_url: checkpoint.cursor_url = cursor_url checkpoint.num_retrieved = num_objs return cursor_url_callback class GithubConnector(CheckpointedConnectorWithPermSync[GithubConnectorCheckpoint]): def __init__( self, repo_owner: str, repositories: str | None = None, 
state_filter: str = "all", include_prs: bool = True, include_issues: bool = False, ) -> None: self.repo_owner = repo_owner self.repositories = repositories self.state_filter = state_filter self.include_prs = include_prs self.include_issues = include_issues self.github_client: Github | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: # defaults to 30 items per page, can be set to as high as 100 self.github_client = ( Github( credentials["github_access_token"], base_url=GITHUB_CONNECTOR_BASE_URL, per_page=ITEMS_PER_PAGE, ) if GITHUB_CONNECTOR_BASE_URL else Github(credentials["github_access_token"], per_page=ITEMS_PER_PAGE) ) return None def get_github_repo( self, github_client: Github, attempt_num: int = 0 ) -> Repository.Repository: if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: raise RuntimeError( "Re-tried fetching repo too many times. Something is going wrong with fetching objects from Github" ) try: return github_client.get_repo(f"{self.repo_owner}/{self.repositories}") except RateLimitExceededException: sleep_after_rate_limit_exception(github_client) return self.get_github_repo(github_client, attempt_num + 1) def get_github_repos( self, github_client: Github, attempt_num: int = 0 ) -> list[Repository.Repository]: """Get specific repositories based on comma-separated repo_name string.""" if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: raise RuntimeError( "Re-tried fetching repos too many times. 
Something is going wrong with fetching objects from Github" ) try: repos = [] # Split repo_name by comma and strip whitespace repo_names = [ name.strip() for name in (cast(str, self.repositories)).split(",") ] for repo_name in repo_names: if repo_name: # Skip empty strings try: repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}") repos.append(repo) except GithubException as e: logger.warning( f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}" ) return repos except RateLimitExceededException: sleep_after_rate_limit_exception(github_client) return self.get_github_repos(github_client, attempt_num + 1) def get_all_repos( self, github_client: Github, attempt_num: int = 0 ) -> list[Repository.Repository]: if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES: raise RuntimeError( "Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github" ) try: # Try to get organization first try: org = github_client.get_organization(self.repo_owner) return list(org.get_repos()) except GithubException: # If not an org, try as a user user = github_client.get_user(self.repo_owner) return list(user.get_repos()) except RateLimitExceededException: sleep_after_rate_limit_exception(github_client) return self.get_all_repos(github_client, attempt_num + 1) def fetch_configured_repos(self) -> list[Repository.Repository]: """ Fetch the configured repositories based on the connector settings. Returns: list[Repository.Repository]: The configured repositories. 
""" assert self.github_client is not None # mypy if self.repositories: if "," in self.repositories: return self.get_github_repos(self.github_client) else: return [self.get_github_repo(self.github_client)] else: return self.get_all_repos(self.github_client) def _pull_requests_func( self, repo: Repository.Repository ) -> Callable[[], PaginatedList[PullRequest]]: return lambda: repo.get_pulls( state=self.state_filter, sort="updated", direction="desc" ) def _issues_func( self, repo: Repository.Repository ) -> Callable[[], PaginatedList[Issue]]: return lambda: repo.get_issues( state=self.state_filter, sort="updated", direction="desc" ) def _fetch_from_github( self, checkpoint: GithubConnectorCheckpoint, start: datetime | None = None, end: datetime | None = None, include_permissions: bool = False, ) -> Generator[Document | ConnectorFailure, None, GithubConnectorCheckpoint]: if self.github_client is None: raise ConnectorMissingCredentialError("GitHub") checkpoint = copy.deepcopy(checkpoint) # First run of the connector, fetch all repos and store in checkpoint if checkpoint.cached_repo_ids is None: repos = self.fetch_configured_repos() if not repos: checkpoint.has_more = False return checkpoint curr_repo = repos.pop() checkpoint.cached_repo_ids = [repo.id for repo in repos] checkpoint.cached_repo = SerializedRepository( id=curr_repo.id, headers=curr_repo.raw_headers, raw_data=curr_repo.raw_data, ) checkpoint.stage = GithubConnectorStage.PRS checkpoint.curr_page = 0 # save checkpoint with repo ids retrieved return checkpoint if checkpoint.cached_repo is None: raise ValueError("No repo saved in checkpoint") # Deserialize the repository from the checkpoint repo = deserialize_repository(checkpoint.cached_repo, self.github_client) cursor_url_callback = make_cursor_url_callback(checkpoint) repo_external_access: ExternalAccess | None = None if include_permissions: repo_external_access = get_external_access_permission( repo, self.github_client ) if self.include_prs and 
checkpoint.stage == GithubConnectorStage.PRS: logger.info(f"Fetching PRs for repo: {repo.name}") pr_batch = _get_batch_rate_limited( self._pull_requests_func(repo), checkpoint.curr_page, checkpoint.cursor_url, checkpoint.num_retrieved, cursor_url_callback, self.github_client, ) checkpoint.curr_page += 1 # NOTE: not used for cursor-based fallback done_with_prs = False num_prs = 0 pr = None for pr in pr_batch: num_prs += 1 # we iterate backwards in time, so once a PR is older than the start date we stop processing PRs if ( start is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) < start ): done_with_prs = True break # Skip PRs updated after the end date if ( end is not None and pr.updated_at and pr.updated_at.replace(tzinfo=timezone.utc) > end ): continue try: yield _convert_pr_to_document( cast(PullRequest, pr), repo_external_access ) except Exception as e: error_msg = f"Error converting PR to document: {e}" logger.exception(error_msg) yield ConnectorFailure( failed_document=DocumentFailure( document_id=str(pr.id), document_link=pr.html_url ), failure_message=error_msg, exception=e, ) continue # If we reach this point with a cursor url in the checkpoint, we were using # the fallback cursor-based pagination strategy. That strategy tries to get all # PRs, so having cursor_url set means we are done with PRs. However, we need to # return AFTER the checkpoint reset to avoid infinite loops. # If we found any PRs on the page and there are more PRs to get, return the checkpoint. # In offset mode, while indexing without time constraints, the PR batch # will be empty when we're done.
used_cursor = checkpoint.cursor_url is not None logger.info(f"Fetched {num_prs} PRs for repo: {repo.name}") if num_prs > 0 and not done_with_prs and not used_cursor: return checkpoint # if we went past the start date during the loop or there are no more # PRs to get, we move on to issues checkpoint.stage = GithubConnectorStage.ISSUES checkpoint.reset() if used_cursor: # save the checkpoint after changing stage; next run will continue from issues return checkpoint checkpoint.stage = GithubConnectorStage.ISSUES if self.include_issues and checkpoint.stage == GithubConnectorStage.ISSUES: logger.info(f"Fetching issues for repo: {repo.name}") issue_batch = list( _get_batch_rate_limited( self._issues_func(repo), checkpoint.curr_page, checkpoint.cursor_url, checkpoint.num_retrieved, cursor_url_callback, self.github_client, ) ) logger.info(f"Fetched {len(issue_batch)} issues for repo: {repo.name}") checkpoint.curr_page += 1 done_with_issues = False num_issues = 0 for issue in issue_batch: num_issues += 1 issue = cast(Issue, issue) # we iterate backwards in time, so once an issue is older than the start date we stop processing issues if ( start is not None and issue.updated_at.replace(tzinfo=timezone.utc) < start ): done_with_issues = True break # Skip issues updated after the end date if ( end is not None and issue.updated_at.replace(tzinfo=timezone.utc) > end ): continue if issue.pull_request is not None: # PRs are handled separately continue try: yield _convert_issue_to_document(issue, repo_external_access) except Exception as e: error_msg = f"Error converting issue to document: {e}" logger.exception(error_msg) yield ConnectorFailure( failed_document=DocumentFailure( document_id=str(issue.id), document_link=issue.html_url, ), failure_message=error_msg, exception=e, ) continue logger.info(f"Fetched {num_issues} issues for repo: {repo.name}") # if we found any issues on the page, and we're not done, return the checkpoint.
# don't return if we're using cursor-based pagination to avoid infinite loops if num_issues > 0 and not done_with_issues and not checkpoint.cursor_url: return checkpoint # if we went past the start date during the loop or there are no more # issues to get, we move on to the next repo checkpoint.stage = GithubConnectorStage.PRS checkpoint.reset() checkpoint.has_more = len(checkpoint.cached_repo_ids) > 0 if checkpoint.cached_repo_ids: next_id = checkpoint.cached_repo_ids.pop() next_repo = self.github_client.get_repo(next_id) checkpoint.cached_repo = SerializedRepository( id=next_id, headers=next_repo.raw_headers, raw_data=next_repo.raw_data, ) checkpoint.stage = GithubConnectorStage.PRS checkpoint.reset() if checkpoint.cached_repo_ids: logger.info( f"{len(checkpoint.cached_repo_ids)} repos remaining (IDs: {checkpoint.cached_repo_ids})" ) else: logger.info("No more repos remaining") return checkpoint def _load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GithubConnectorCheckpoint, include_permissions: bool = False, ) -> CheckpointOutput[GithubConnectorCheckpoint]: start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) # add a day for timezone safety end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) + ONE_DAY # Move start time back by 3 hours, since some Issues/PRs are getting dropped # Could be due to delayed processing on GitHub side # Issues/PRs not updated since the last poll will be short-circuited (skipped) and not re-embedded adjusted_start_datetime = start_datetime - timedelta(hours=3) epoch = datetime.fromtimestamp(0, tz=timezone.utc) if adjusted_start_datetime < epoch: adjusted_start_datetime = epoch return self._fetch_from_github( checkpoint, start=adjusted_start_datetime, end=end_datetime, include_permissions=include_permissions, ) @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GithubConnectorCheckpoint, ) ->
CheckpointOutput[GithubConnectorCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=False ) @override def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GithubConnectorCheckpoint, ) -> CheckpointOutput[GithubConnectorCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=True ) def validate_connector_settings(self) -> None: if self.github_client is None: raise ConnectorMissingCredentialError("GitHub credentials not loaded.") if not self.repo_owner: raise ConnectorValidationError( "Invalid connector settings: 'repo_owner' must be provided." ) try: if self.repositories: if "," in self.repositories: # Multiple repositories specified repo_names = [name.strip() for name in self.repositories.split(",")] if not repo_names: raise ConnectorValidationError( "Invalid connector settings: No valid repository names provided." ) # Validate at least one repository exists and is accessible valid_repos = False validation_errors = [] for repo_name in repo_names: if not repo_name: continue try: test_repo = self.github_client.get_repo( f"{self.repo_owner}/{repo_name}" ) logger.info( f"Successfully accessed repository: {self.repo_owner}/{repo_name}" ) test_repo.get_contents("") valid_repos = True # If at least one repo is valid, we can proceed break except GithubException as e: validation_errors.append( f"Repository '{repo_name}': {e.data.get('message', str(e))}" ) if not valid_repos: error_msg = ( "None of the specified repositories could be accessed: " ) error_msg += ", ".join(validation_errors) raise ConnectorValidationError(error_msg) else: # Single repository (backward compatibility) test_repo = self.github_client.get_repo( f"{self.repo_owner}/{self.repositories}" ) test_repo.get_contents("") else: # Try to get organization first try: org = self.github_client.get_organization(self.repo_owner) total_count = org.get_repos().totalCount if 
total_count == 0: raise ConnectorValidationError( f"Found no repos for organization: {self.repo_owner}. Does the credential have the right scopes?" ) except GithubException as e: # Check for missing SSO MISSING_SSO_ERROR_MESSAGE = "You must grant your Personal Access token access to this organization".lower() if MISSING_SSO_ERROR_MESSAGE in str(e).lower(): SSO_GUIDE_LINK = ( "https://docs.github.com/en/enterprise-cloud@latest/authentication/" "authenticating-with-saml-single-sign-on/" "authorizing-a-personal-access-token-for-use-with-saml-single-sign-on" ) raise ConnectorValidationError( f"Your GitHub token is missing authorization to access the " f"`{self.repo_owner}` organization. Please follow the guide to " f"authorize your token: {SSO_GUIDE_LINK}" ) # If not an org, try as a user user = self.github_client.get_user(self.repo_owner) # Check if we can access any repos total_count = user.get_repos().totalCount if total_count == 0: raise ConnectorValidationError( f"Found no repos for user: {self.repo_owner}. Does the credential have the right scopes?" ) except RateLimitExceededException: raise UnexpectedValidationError( "Validation failed due to GitHub rate-limits being exceeded. Please try again later." ) except GithubException as e: if e.status == 401: raise CredentialExpiredError( "GitHub credential appears to be invalid or expired (HTTP 401)." ) elif e.status == 403: raise InsufficientPermissionsError( "Your GitHub token does not have sufficient permissions for this repository (HTTP 403)." 
                )
            elif e.status == 404:
                if self.repositories:
                    if "," in self.repositories:
                        raise ConnectorValidationError(
                            f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
                        )
                    else:
                        raise ConnectorValidationError(
                            f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
                        )
                else:
                    raise ConnectorValidationError(
                        f"GitHub user or organization not found: {self.repo_owner}"
                    )
            else:
                raise ConnectorValidationError(
                    f"Unexpected GitHub error (status={e.status}): {e.data}"
                )

        except Exception as exc:
            raise Exception(
                f"Unexpected error during GitHub settings validation: {exc}"
            )

    def validate_checkpoint_json(
        self, checkpoint_json: str
    ) -> GithubConnectorCheckpoint:
        return GithubConnectorCheckpoint.model_validate_json(checkpoint_json)

    def build_dummy_checkpoint(self) -> GithubConnectorCheckpoint:
        return GithubConnectorCheckpoint(
            stage=GithubConnectorStage.PRS, curr_page=0, has_more=True, num_retrieved=0
        )


if __name__ == "__main__":
    import os

    # Initialize the connector
    connector = GithubConnector(
        repo_owner=os.environ["REPO_OWNER"],
        repositories=os.environ.get("REPOSITORIES"),
    )
    connector.load_credentials(
        {"github_access_token": os.environ["ACCESS_TOKEN_GITHUB"]}
    )

    if connector.github_client:
        get_external_access_permission(
            connector.get_github_repos(connector.github_client).pop(),
            connector.github_client,
        )

    # Create a time range from epoch to now
    end_time = datetime.now(timezone.utc)
    start_time = datetime.fromtimestamp(0, tz=timezone.utc)
    time_range = (start_time, end_time)

    # Initialize the runner with a batch size of 10
    runner: ConnectorRunner[GithubConnectorCheckpoint] = ConnectorRunner(
        connector, batch_size=10, include_permissions=False, time_range=time_range
    )

    # Get initial checkpoint
    checkpoint = connector.build_dummy_checkpoint()

    # Run the connector
    while checkpoint.has_more:
        for doc_batch, hierarchy_node_batch, failure, next_checkpoint in runner.run(
            checkpoint
        ):
            if doc_batch:
                print(f"Retrieved batch of {len(doc_batch)} documents")
                for doc in doc_batch:
                    print(f"Document: {doc.semantic_identifier}")
            if failure:
                print(f"Failure: {failure.failure_message}")
            if next_checkpoint:
                checkpoint = next_checkpoint

================================================
FILE: backend/onyx/connectors/github/models.py
================================================
from typing import Any

from github import Repository
from github.Requester import Requester
from pydantic import BaseModel


class SerializedRepository(BaseModel):
    # id is part of the raw_data as well, just pulled out for convenience
    id: int
    headers: dict[str, str | int]
    raw_data: dict[str, Any]

    def to_Repository(self, requester: Requester) -> Repository.Repository:
        return Repository.Repository(
            requester, self.headers, self.raw_data, completed=True
        )

================================================
FILE: backend/onyx/connectors/github/rate_limit_utils.py
================================================
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from github import Github

from onyx.utils.logger import setup_logger

logger = setup_logger()


def sleep_after_rate_limit_exception(github_client: Github) -> None:
    """
    Sleep until the GitHub rate limit resets.

    Args:
        github_client: The GitHub client that hit the rate limit
    """
    sleep_time = github_client.get_rate_limit().core.reset.replace(
        tzinfo=timezone.utc
    ) - datetime.now(tz=timezone.utc)
    sleep_time += timedelta(minutes=1)  # add an extra minute just to be safe
    logger.notice(f"Ran into Github rate-limit. Sleeping {sleep_time.seconds} seconds.")
    time.sleep(sleep_time.total_seconds())

================================================
FILE: backend/onyx/connectors/github/utils.py
================================================
from collections.abc import Callable
from typing import cast

from github import Github
from github.Repository import Repository

from onyx.access.models import ExternalAccess
from onyx.connectors.github.models import SerializedRepository
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version

logger = setup_logger()


def get_external_access_permission(
    repo: Repository, github_client: Github
) -> ExternalAccess:
    """
    Get the external access permission for a repository.
    This functionality requires Enterprise Edition.
    """
    # Check if EE is enabled
    if not global_version.is_ee_version():
        # For the MIT version, return an empty ExternalAccess (private document)
        return ExternalAccess.empty()

    # Fetch the EE implementation
    ee_get_external_access_permission = cast(
        Callable[[Repository, Github, bool], ExternalAccess],
        fetch_versioned_implementation(
            "onyx.external_permissions.github.utils",
            "get_external_access_permission",
        ),
    )

    return ee_get_external_access_permission(repo, github_client, True)


def deserialize_repository(
    cached_repo: SerializedRepository, github_client: Github
) -> Repository:
    """
    Deserialize a SerializedRepository back into a Repository object.
    """
    # Try to access the requester - different PyGithub versions may use different attribute names
    try:
        # Try to get the requester using getattr to avoid linter errors
        requester = getattr(github_client, "_requester", None)
        if requester is None:
            requester = getattr(github_client, "_Github__requester", None)
            if requester is None:
                # If we can't find the requester attribute, we need to fall back to recreating the repo
                raise AttributeError("Could not find requester attribute")

        return cached_repo.to_Repository(requester)
    except Exception as e:
        # If all else fails, re-fetch the repo directly
        logger.warning(
            f"Failed to deserialize repository: {e}. Attempting to re-fetch."
        )
        repo_id = cached_repo.id
        return github_client.get_repo(repo_id)

================================================
FILE: backend/onyx/connectors/gitlab/__init__.py
================================================

================================================
FILE: backend/onyx/connectors/gitlab/connector.py
================================================
import fnmatch
import itertools
from collections import deque
from collections.abc import Iterable
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import TypeVar

import gitlab
import pytz
from gitlab.v4.objects import Project

from onyx.configs.app_configs import GITLAB_CONNECTOR_INCLUDE_CODE_FILES
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger


T = TypeVar("T")


logger = setup_logger()

# List of directories/Files to exclude
exclude_patterns = [
    "logs",
    ".github/",
    ".gitlab/",
    ".pre-commit-config.yaml",
]


def _batch_gitlab_objects(git_objs: Iterable[T], batch_size: int) -> Iterator[list[T]]:
    it = iter(git_objs)
    while True:
        batch = list(itertools.islice(it, batch_size))
        if not batch:
            break
        yield batch


def get_author(author: Any) -> BasicExpertInfo:
    return BasicExpertInfo(
        display_name=author.get("name"),
    )


def _convert_merge_request_to_document(mr: Any) -> Document:
    doc = Document(
        id=mr.web_url,
        sections=[TextSection(link=mr.web_url, text=mr.description or "")],
        source=DocumentSource.GITLAB,
        semantic_identifier=mr.title,
        # updated_at is UTC time but is timezone unaware, explicitly add UTC
        # as there is logic in indexing to prevent wrong timestamped docs
        # due to local time discrepancies with UTC
        doc_updated_at=mr.updated_at.replace(tzinfo=timezone.utc),
        primary_owners=[get_author(mr.author)],
        metadata={"state": mr.state, "type": "MergeRequest"},
    )
    return doc


def _convert_issue_to_document(issue: Any) -> Document:
    doc = Document(
        id=issue.web_url,
        sections=[TextSection(link=issue.web_url, text=issue.description or "")],
        source=DocumentSource.GITLAB,
        semantic_identifier=issue.title,
        # updated_at is UTC time but is timezone unaware, explicitly add UTC
        # as there is logic in indexing to prevent wrong timestamped docs
        # due to local time discrepancies with UTC
        doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
        primary_owners=[get_author(issue.author)],
        metadata={"state": issue.state, "type": issue.type if issue.type else "Issue"},
    )
    return doc


def _convert_code_to_document(
    project: Project, file: Any, url: str, projectName: str, projectOwner: str
) -> Document:
    # Dynamically get the default branch from the project object
    default_branch = project.default_branch

    # Fetch the file content using the correct branch
    file_content_obj = project.files.get(
        file_path=file["path"],
        ref=default_branch,  # Use the default branch
    )
    try:
        file_content = file_content_obj.decode().decode("utf-8")
    except UnicodeDecodeError:
        file_content = file_content_obj.decode().decode("latin-1")

    # Construct the file URL dynamically using the default branch
    file_url = (
        f"{url}/{projectOwner}/{projectName}/-/blob/{default_branch}/{file['path']}"
    )

    # Create and return a Document object
    doc = Document(
        id=file["id"],
        sections=[TextSection(link=file_url, text=file_content)],
        source=DocumentSource.GITLAB,
        semantic_identifier=file["name"],
        # use an aware UTC timestamp (a naive local now() relabeled as UTC is wrong)
        doc_updated_at=datetime.now(timezone.utc),
        primary_owners=[],  # Add owners if needed
        metadata={"type": "CodeFile"},
    )
    return doc


def _should_exclude(path: str) -> bool:
    """Check if a path matches any of the exclude patterns."""
    return any(fnmatch.fnmatch(path, pattern) for pattern in exclude_patterns)


class GitlabConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        project_owner: str,
        project_name: str,
        batch_size: int = INDEX_BATCH_SIZE,
        state_filter: str = "all",
        include_mrs: bool = True,
        include_issues: bool = True,
        include_code_files: bool = GITLAB_CONNECTOR_INCLUDE_CODE_FILES,
    ) -> None:
        self.project_owner = project_owner
        self.project_name = project_name
        self.batch_size = batch_size
        self.state_filter = state_filter
        self.include_mrs = include_mrs
        self.include_issues = include_issues
        self.include_code_files = include_code_files
        self.gitlab_client: gitlab.Gitlab | None = None

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.gitlab_client = gitlab.Gitlab(
            credentials["gitlab_url"], private_token=credentials["gitlab_access_token"]
        )
        return None

    def _fetch_from_gitlab(
        self, start: datetime | None = None, end: datetime | None = None
    ) -> GenerateDocumentsOutput:
        if self.gitlab_client is None:
            raise ConnectorMissingCredentialError("Gitlab")
        project: Project = self.gitlab_client.projects.get(
            f"{self.project_owner}/{self.project_name}"
        )

        # Fetch code files
        if self.include_code_files:
            # Walk the tree breadth-first, since project.repository_tree with
            # recursive=True is slow to load
            queue = deque([""])  # Start with the root directory
            while queue:
                current_path = queue.popleft()
                files = project.repository_tree(path=current_path, all=True)
                for file_batch in _batch_gitlab_objects(files, self.batch_size):
                    code_doc_batch: list[Document | HierarchyNode] = []
                    for file in file_batch:
                        if _should_exclude(file["path"]):
                            continue

                        if file["type"] == "blob":
                            code_doc_batch.append(
                                _convert_code_to_document(
                                    project,
                                    file,
                                    self.gitlab_client.url,
                                    self.project_name,
                                    self.project_owner,
                                )
                            )
                        elif file["type"] == "tree":
                            queue.append(file["path"])

                    if code_doc_batch:
                        yield code_doc_batch

        if self.include_mrs:
            merge_requests = project.mergerequests.list(
                state=self.state_filter,
                order_by="updated_at",
                sort="desc",
                iterator=True,
            )

            done_with_mrs = False
            for mr_batch in _batch_gitlab_objects(merge_requests, self.batch_size):
                mr_doc_batch: list[Document | HierarchyNode] = []
                for mr in mr_batch:
                    mr.updated_at = datetime.strptime(
                        mr.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
                    )
                    # MRs are sorted newest first, so every remaining MR is older
                    # than start. Stop paging MRs with `break` (not `return`) so the
                    # generator still goes on to fetch issues below.
                    if start is not None and mr.updated_at < start.replace(
                        tzinfo=pytz.UTC
                    ):
                        done_with_mrs = True
                        break
                    if end is not None and mr.updated_at > end.replace(tzinfo=pytz.UTC):
                        continue
                    mr_doc_batch.append(_convert_merge_request_to_document(mr))
                yield mr_doc_batch
                if done_with_mrs:
                    break

        if self.include_issues:
            # Sort newest first so the early return on the start date below is safe
            issues = project.issues.list(
                state=self.state_filter,
                order_by="updated_at",
                sort="desc",
                iterator=True,
            )

            for issue_batch in _batch_gitlab_objects(issues, self.batch_size):
                issue_doc_batch: list[Document | HierarchyNode] = []
                for issue in issue_batch:
                    issue.updated_at = datetime.strptime(
                        issue.updated_at, "%Y-%m-%dT%H:%M:%S.%f%z"
                    )
                    if start is not None:
                        start = start.replace(tzinfo=pytz.UTC)
                        if issue.updated_at < start:
                            yield issue_doc_batch
                            return
                    if end is not None:
                        end = end.replace(tzinfo=pytz.UTC)
                        if issue.updated_at > end:
                            continue
                    issue_doc_batch.append(_convert_issue_to_document(issue))
                yield issue_doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._fetch_from_gitlab()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_datetime = datetime.fromtimestamp(start, tz=timezone.utc)
        end_datetime = datetime.fromtimestamp(end, tz=timezone.utc)
        return self._fetch_from_gitlab(start_datetime, end_datetime)


if __name__ == "__main__":
    import os

    connector = GitlabConnector(
        # gitlab_url="https://gitlab.com/api/v4",
        project_owner=os.environ["PROJECT_OWNER"],
        project_name=os.environ["PROJECT_NAME"],
        batch_size=10,
        state_filter="all",
        include_mrs=True,
        include_issues=True,
        include_code_files=GITLAB_CONNECTOR_INCLUDE_CODE_FILES,
    )

    connector.load_credentials(
        {
            "gitlab_access_token": os.environ["GITLAB_ACCESS_TOKEN"],
            "gitlab_url": os.environ["GITLAB_URL"],
        }
    )
    document_batches = connector.load_from_state()
    print(next(document_batches))

================================================
FILE: backend/onyx/connectors/gmail/__init__.py
================================================

================================================
FILE: backend/onyx/connectors/gmail/connector.py
================================================
from base64 import urlsafe_b64decode
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import Dict

from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from googleapiclient.errors import HttpError  # type: ignore

from onyx.access.models import ExternalAccess
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.google_utils.google_auth import get_google_creds
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.google_utils import (
    execute_paginated_retrieval_with_max_pages,
)
from onyx.connectors.google_utils.google_utils import execute_single_retrieval
from onyx.connectors.google_utils.google_utils import PAGE_TOKEN_KEY
from onyx.connectors.google_utils.resources import get_admin_service
from onyx.connectors.google_utils.resources import get_gmail_service
from onyx.connectors.google_utils.resources import GmailService
from onyx.connectors.google_utils.shared_constants import (
    DB_CREDENTIALS_PRIMARY_ADMIN_KEY,
)
from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import ConnectorFailure
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder

logger = setup_logger()

# This is for the initial list call to get the thread ids
THREAD_LIST_FIELDS = "nextPageToken, threads(id)"

# These are the fields to retrieve using the ID from the initial list call
PARTS_FIELDS = "parts(body(data), mimeType)"
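
# Illustrative sketch (hypothetical helper, not part of the upstream connector).
# `_get_message_body` below walks the nested MIME `parts` tree with an explicit
# stack, decodes only text/plain parts, and skips any part whose estimated
# decoded size exceeds a byte cap. This standalone version assumes a payload
# shaped like the Gmail API `payload` object (`mimeType`, `body.data`, `parts`);
# the name `_example_plain_text_from_payload` and its `max_bytes` parameter are
# made up for illustration, and decode errors are not handled here.
from base64 import urlsafe_b64decode
from typing import Any


def _example_plain_text_from_payload(payload: dict[str, Any], max_bytes: int) -> str:
    chunks: list[str] = []
    stack: list[dict[str, Any]] = [payload]
    while stack:
        part = stack.pop()
        # Push children in reverse so they are popped (and joined) in document order
        stack.extend(reversed(part.get("parts", [])))
        if part.get("mimeType") != "text/plain":
            continue
        data = part.get("body", {}).get("data", "")
        # 4 base64 characters decode to at most 3 bytes, so this is an upper-bound
        # estimate of the decoded size that avoids decoding oversized parts at all
        if not data or (len(data) * 3) // 4 > max_bytes:
            continue
        chunks.append(urlsafe_b64decode(data).decode())
    return "".join(chunks)


# Usage (hypothetical payload): a multipart/alternative message whose text/plain
# parts carry base64 for "hello " and "world" returns "hello world"; a text/html
# sibling is ignored, and max_bytes=0 skips every non-empty part.
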
PAYLOAD_FIELDS = f"payload(headers, {PARTS_FIELDS})"
MESSAGES_FIELDS = f"messages(id, {PAYLOAD_FIELDS})"
THREADS_FIELDS = f"threads(id, {MESSAGES_FIELDS})"
THREAD_FIELDS = f"id, {MESSAGES_FIELDS}"

EMAIL_FIELDS = [
    "cc",
    "bcc",
    "from",
    "to",
]

MAX_MESSAGE_BODY_BYTES = 10 * 1024 * 1024  # 10MB cap to keep large threads safe

PAGES_PER_CHECKPOINT = 1

add_retries = retry_builder(tries=50, max_delay=30)


def _is_mail_service_disabled_error(error: HttpError) -> bool:
    """Detect if the Gmail API is telling us the mailbox is not provisioned."""
    if error.resp.status != 400:
        return False

    error_message = str(error)
    return (
        "Mail service not enabled" in error_message
        or "failedPrecondition" in error_message
    )


def _build_time_range_query(
    time_range_start: SecondsSinceUnixEpoch | None = None,
    time_range_end: SecondsSinceUnixEpoch | None = None,
) -> str | None:
    query = ""
    if time_range_start is not None and time_range_start != 0:
        query += f"after:{int(time_range_start)}"
    if time_range_end is not None and time_range_end != 0:
        query += f" before:{int(time_range_end)}"
    query = query.strip()

    if len(query) == 0:
        return None

    return query


def _clean_email_and_extract_name(email: str) -> tuple[str, str | None]:
    email = email.strip()
    if "<" in email and ">" in email:
        # Handle format: "Display Name <email@example.com>"
        display_name = email[: email.find("<")].strip()
        email_address = email[email.find("<") + 1 : email.find(">")].strip()
        return email_address, display_name if display_name else None
    else:
        # Handle plain email address
        return email.strip(), None


def _get_owners_from_emails(emails: dict[str, str | None]) -> list[BasicExpertInfo]:
    owners = []
    for email, names in emails.items():
        if names:
            name_parts = names.split(" ")
            first_name = " ".join(name_parts[:-1])
            last_name = name_parts[-1]
        else:
            first_name = None
            last_name = None
        owners.append(
            BasicExpertInfo(email=email, first_name=first_name, last_name=last_name)
        )
    return owners


def _get_message_body(payload: dict[str, Any]) -> str:
    """
    Gmail threads can contain large inline parts (including attachments
    transmitted as base64). Only decode text/plain parts and skip anything
    that breaches the safety threshold to protect against OOMs.
    """
    message_body_chunks: list[str] = []
    stack = [payload]

    while stack:
        part = stack.pop()
        if not part:
            continue

        children = part.get("parts", [])
        stack.extend(reversed(children))

        mime_type = part.get("mimeType")
        if mime_type != "text/plain":
            continue

        body = part.get("body", {})
        data = body.get("data", "")

        if not data:
            continue

        # base64 inflates storage by ~4/3; work with decoded size estimate
        approx_decoded_size = (len(data) * 3) // 4
        if approx_decoded_size > MAX_MESSAGE_BODY_BYTES:
            logger.warning(
                "Skipping oversized Gmail message part (%s bytes > %s limit)",
                approx_decoded_size,
                MAX_MESSAGE_BODY_BYTES,
            )
            continue

        try:
            text = urlsafe_b64decode(data).decode()
        except (ValueError, UnicodeDecodeError) as error:
            logger.warning("Failed to decode Gmail message part: %s", error)
            continue

        message_body_chunks.append(text)

    return "".join(message_body_chunks)


def _build_document_link(thread_id: str) -> str:
    return f"https://mail.google.com/mail/u/0/#inbox/{thread_id}"


def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
    link = _build_document_link(message["id"])

    payload = message.get("payload", {})
    headers = payload.get("headers", [])
    metadata: dict[str, Any] = {}

    for header in headers:
        name = header.get("name").lower()
        value = header.get("value")
        if name in EMAIL_FIELDS:
            metadata[name] = value
        if name == "subject":
            metadata["subject"] = value
        if name == "date":
            metadata["updated_at"] = value

    if labels := message.get("labelIds"):
        metadata["labels"] = labels

    message_data = ""
    for name, value in metadata.items():
        # updated_at isn't very useful for the LLM
        if name != "updated_at":
            message_data += f"{name}: {value}\n"

    message_body_text: str = _get_message_body(payload)

    return TextSection(link=link, text=message_body_text + message_data), metadata


def thread_to_document(
    full_thread: Dict[str, Any], email_used_to_fetch_thread: str
) -> Document | None:
    all_messages = full_thread.get("messages", [])
    if not all_messages:
        return None

    sections = []
    semantic_identifier = ""
    updated_at = None
    from_emails: dict[str, str | None] = {}
    other_emails: dict[str, str | None] = {}
    for message in all_messages:
        section, message_metadata = message_to_section(message)
        sections.append(section)

        for name, value in message_metadata.items():
            if name in EMAIL_FIELDS:
                email, display_name = _clean_email_and_extract_name(value)
                if name == "from":
                    from_emails[email] = (
                        display_name if not from_emails.get(email) else None
                    )
                else:
                    other_emails[email] = (
                        display_name if not other_emails.get(email) else None
                    )

        # If we haven't set the semantic identifier yet, set it to the subject of the first message
        if not semantic_identifier:
            semantic_identifier = message_metadata.get("subject", "")

        if message_metadata.get("updated_at"):
            updated_at = message_metadata.get("updated_at")

    updated_at_datetime = None
    if updated_at:
        updated_at_datetime = time_str_to_utc(updated_at)

    id = full_thread.get("id")
    if not id:
        raise ValueError("Thread ID is required")

    primary_owners = _get_owners_from_emails(from_emails)
    secondary_owners = _get_owners_from_emails(other_emails)

    # If emails have no subject, match Gmail's default "no subject"
    # Search will break without a semantic identifier
    if not semantic_identifier:
        semantic_identifier = "(no subject)"

    # NOTE: we're choosing to unconditionally include perm sync info
    # (external_access) as it doesn't cost much space
    return Document(
        id=id,
        semantic_identifier=semantic_identifier,
        sections=cast(list[TextSection | ImageSection], sections),
        source=DocumentSource.GMAIL,
        # This is used to perform permission sync
        primary_owners=primary_owners,
        secondary_owners=secondary_owners,
        doc_updated_at=updated_at_datetime,
        # Not adding emails to metadata because it's already in the sections
        metadata={},
        external_access=ExternalAccess(
            external_user_emails={email_used_to_fetch_thread},
            external_user_group_ids=set(),
            is_public=False,
        ),
    )


def _full_thread_from_id(
    thread_id: str,
    user_email: str,
    gmail_service: GmailService,
) -> Document | ConnectorFailure | None:
    try:
        thread = next(
            execute_single_retrieval(
                retrieval_function=gmail_service.users().threads().get,
                list_key=None,
                userId=user_email,
                fields=THREAD_FIELDS,
                id=thread_id,
                continue_on_404_or_403=True,
            ),
            None,
        )
        if thread is None:
            raise ValueError(f"Thread {thread_id} not found")
        return thread_to_document(thread, user_email)
    except Exception as e:
        return ConnectorFailure(
            failed_document=DocumentFailure(
                document_id=thread_id, document_link=_build_document_link(thread_id)
            ),
            failure_message=f"Failed to retrieve thread {thread_id}",
            exception=e,
        )


def _slim_thread_from_id(
    thread_id: str,
    user_email: str,
    gmail_service: GmailService,  # noqa: ARG001
) -> SlimDocument:
    return SlimDocument(
        id=thread_id,
        external_access=ExternalAccess(
            external_user_emails={user_email},
            external_user_group_ids=set(),
            is_public=False,
        ),
    )


class GmailCheckpoint(ConnectorCheckpoint):
    user_emails: list[str] = []  # stack of user emails to process
    page_token: str | None = None


class GmailConnector(
    SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GmailCheckpoint]
):
    def __init__(self, batch_size: int = INDEX_BATCH_SIZE) -> None:
        self.batch_size = batch_size

        self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
        self._primary_admin_email: str | None = None

    @property
    def primary_admin_email(self) -> str:
        if self._primary_admin_email is None:
            raise RuntimeError(
                "Primary admin email missing, should not call this property before calling load_credentials"
            )
        return self._primary_admin_email

    @property
    def google_domain(self) -> str:
        if self._primary_admin_email is None:
            raise RuntimeError(
                "Primary admin email missing, should not call this property before calling load_credentials"
            )
        return self._primary_admin_email.split("@")[-1]

    @property
    def creds(self) -> OAuthCredentials | ServiceAccountCredentials:
        if self._creds is None:
            raise RuntimeError(
                "Creds missing, should not call this property before calling load_credentials"
            )
        return self._creds

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None:
        primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY]
        self._primary_admin_email = primary_admin_email

        self._creds, new_creds_dict = get_google_creds(
            credentials=credentials,
            source=DocumentSource.GMAIL,
        )
        return new_creds_dict

    def _get_all_user_emails(self) -> list[str]:
        """
        List all user emails if we are on a Google Workspace domain.
        If the domain is gmail.com, or if we attempt to call the Admin SDK and
        get a 404 or 403, fall back to using the single user.
        A 404 indicates a personal Gmail account with no Workspace domain.
        A 403 indicates insufficient permissions (e.g., OAuth user without admin privileges).
        """

        try:
            admin_service = get_admin_service(self.creds, self.primary_admin_email)
            emails = []
            for user in execute_paginated_retrieval(
                retrieval_function=admin_service.users().list,
                list_key="users",
                fields=USER_FIELDS,
                domain=self.google_domain,
            ):
                if email := user.get("primaryEmail"):
                    emails.append(email)
            return emails

        except HttpError as e:
            if e.resp.status == 404:
                logger.warning(
                    "Received 404 from Admin SDK; this may indicate a personal Gmail account "
                    "with no Workspace domain. Falling back to single user."
                )
                return [self.primary_admin_email]
            elif e.resp.status == 403:
                logger.warning(
                    "Received 403 from Admin SDK; this may indicate insufficient permissions "
                    "(e.g., OAuth user without admin privileges or service account without "
                    "domain-wide delegation). Falling back to single user."
                )
                return [self.primary_admin_email]
            raise

    def _fetch_threads_impl(
        self,
        user_email: str,
        time_range_start: SecondsSinceUnixEpoch | None = None,
        time_range_end: SecondsSinceUnixEpoch | None = None,
        callback: IndexingHeartbeatInterface | None = None,
        page_token: str | None = None,
        set_page_token: Callable[[str | None], None] = lambda x: None,  # noqa: ARG005
        is_slim: bool = False,
    ) -> Iterator[Document | ConnectorFailure] | GenerateSlimDocumentOutput:
        query = _build_time_range_query(time_range_start, time_range_end)
        slim_doc_batch: list[SlimDocument | HierarchyNode] = []
        logger.info(
            f"Fetching {'slim' if is_slim else 'full'} threads for user: {user_email}"
        )

        gmail_service = get_gmail_service(self.creds, user_email)
        try:
            for thread in execute_paginated_retrieval_with_max_pages(
                max_num_pages=PAGES_PER_CHECKPOINT,
                retrieval_function=gmail_service.users().threads().list,
                list_key="threads",
                userId=user_email,
                fields=THREAD_LIST_FIELDS,
                q=query,
                continue_on_404_or_403=True,
                **({PAGE_TOKEN_KEY: page_token} if page_token else {}),
            ):
                # if a page token is returned, set it and leave the function
                if isinstance(thread, str):
                    # flush slim docs collected from this page first; returning
                    # without yielding them would silently drop them
                    if slim_doc_batch:
                        yield slim_doc_batch
                    set_page_token(thread)
                    return
                if is_slim:
                    slim_doc_batch.append(
                        SlimDocument(
                            id=thread["id"],
                            external_access=ExternalAccess(
                                external_user_emails={user_email},
                                external_user_group_ids=set(),
                                is_public=False,
                            ),
                        )
                    )
                    if len(slim_doc_batch) >= SLIM_BATCH_SIZE:
                        yield slim_doc_batch
                        slim_doc_batch = []
                else:
                    result = _full_thread_from_id(
                        thread["id"], user_email, gmail_service
                    )
                    if result is not None:
                        yield result
                if callback:
                    tag = (
                        "retrieve_all_slim_docs_perm_sync"
                        if is_slim
                        else "gmail_retrieve_all_docs"
                    )
                    if callback.should_stop():
                        raise RuntimeError(f"{tag}: Stop signal detected")

                    callback.progress(tag, 1)
            if slim_doc_batch:
                yield slim_doc_batch

            # done with user
            set_page_token(None)
        except HttpError as e:
            if _is_mail_service_disabled_error(e):
                logger.warning(
                    "Skipping Gmail sync for %s because the mailbox is disabled.",
                    user_email,
                )
                return
            raise
def _fetch_threads( self, user_email: str, page_token: str | None = None, set_page_token: Callable[[str | None], None] = lambda x: None, # noqa: ARG005 time_range_start: SecondsSinceUnixEpoch | None = None, time_range_end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> Iterator[Document | ConnectorFailure]: yield from cast( Iterator[Document | ConnectorFailure], self._fetch_threads_impl( user_email, time_range_start, time_range_end, callback, page_token, set_page_token, False, ), ) def _fetch_slim_threads( self, user_email: str, page_token: str | None = None, set_page_token: Callable[[str | None], None] = lambda x: None, # noqa: ARG005 time_range_start: SecondsSinceUnixEpoch | None = None, time_range_end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: yield from cast( GenerateSlimDocumentOutput, self._fetch_threads_impl( user_email, time_range_start, time_range_end, callback, page_token, set_page_token, True, ), ) def _load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GmailCheckpoint, ) -> CheckpointOutput[GmailCheckpoint]: if not checkpoint.user_emails: checkpoint.user_emails = self._get_all_user_emails() try: def set_page_token(page_token: str | None) -> None: checkpoint.page_token = page_token yield from self._fetch_threads( checkpoint.user_emails[-1], checkpoint.page_token, set_page_token, start, end, callback=None, ) if checkpoint.page_token is None: # we're done with this user checkpoint.user_emails.pop() if len(checkpoint.user_emails) == 0: checkpoint.has_more = False return checkpoint except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GmailCheckpoint, ) -> CheckpointOutput[GmailCheckpoint]: return 
self._load_from_checkpoint( start=start, end=end, checkpoint=checkpoint, ) def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GmailCheckpoint, ) -> CheckpointOutput[GmailCheckpoint]: # NOTE: we're choosing to unconditionally include perm sync info # (external_access) as it doesn't cost much space return self._load_from_checkpoint( start=start, end=end, checkpoint=checkpoint, ) def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: try: pt_dict: dict[str, str | None] = {PAGE_TOKEN_KEY: None} def set_page_token(page_token: str | None) -> None: pt_dict[PAGE_TOKEN_KEY] = page_token for user_email in self._get_all_user_emails(): # page tokens are per-user: start each user from the first page and keep # fetching until _fetch_slim_threads reports the user is done (token is None) pt_dict[PAGE_TOKEN_KEY] = None while True: yield from self._fetch_slim_threads( user_email, pt_dict[PAGE_TOKEN_KEY], set_page_token, start, end, callback=callback, ) if pt_dict[PAGE_TOKEN_KEY] is None: break except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e def build_dummy_checkpoint(self) -> GmailCheckpoint: return GmailCheckpoint(has_more=True) def validate_checkpoint_json(self, checkpoint_json: str) -> GmailCheckpoint: return GmailCheckpoint.model_validate_json(checkpoint_json) if __name__ == "__main__": pass ================================================ FILE: backend/onyx/connectors/gong/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/gong/connector.py ================================================ import base64 import time from collections.abc import Generator from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from typing import cast import requests from requests.adapters import HTTPAdapter from urllib3.util import Retry from onyx.configs.app_configs import
CONTINUE_ON_CONNECTOR_FAILURE from onyx.configs.app_configs import GONG_CONNECTOR_START_TIME from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() class GongConnector(LoadConnector, PollConnector): BASE_URL = "https://api.gong.io" MAX_CALL_DETAILS_ATTEMPTS = 6 CALL_DETAILS_DELAY = 30 # in seconds # Gong API limit is 3 calls/sec — stay safely under it MIN_REQUEST_INTERVAL = 0.5 # seconds between requests def __init__( self, workspaces: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, continue_on_fail: bool = CONTINUE_ON_CONNECTOR_FAILURE, hide_user_info: bool = False, ) -> None: self.workspaces = workspaces self.batch_size: int = batch_size self.continue_on_fail = continue_on_fail self.auth_token_basic: str | None = None self.hide_user_info = hide_user_info self._last_request_time: float = 0.0 # urllib3 Retry already respects the Retry-After header by default # (respect_retry_after_header=True), so on 429 it will sleep for the # duration Gong specifies before retrying. 
retry_strategy = Retry( total=10, backoff_factor=2, status_forcelist=[429, 500, 502, 503, 504], # By default urllib3 only retries idempotent methods, which excludes POST. # The transcript and call-details endpoints used below are read-only POST # queries, so they are safe to retry. allowed_methods=frozenset({"GET", "POST"}), ) session = requests.Session() session.mount(GongConnector.BASE_URL, HTTPAdapter(max_retries=retry_strategy)) self._session = session @staticmethod def make_url(endpoint: str) -> str: url = f"{GongConnector.BASE_URL}{endpoint}" return url def _throttled_request( self, method: str, url: str, **kwargs: Any ) -> requests.Response: """Rate-limited request wrapper. Enforces MIN_REQUEST_INTERVAL between calls to stay under Gong's 3 calls/sec limit and avoid triggering 429s.""" now = time.monotonic() elapsed = now - self._last_request_time if elapsed < self.MIN_REQUEST_INTERVAL: time.sleep(self.MIN_REQUEST_INTERVAL - elapsed) response = self._session.request(method, url, **kwargs) self._last_request_time = time.monotonic() return response def _get_workspace_id_map(self) -> dict[str, str]: response = self._throttled_request( "GET", GongConnector.make_url("/v2/workspaces") ) response.raise_for_status() workspaces_details = response.json().get("workspaces") name_id_map = { workspace["name"]: workspace["id"] for workspace in workspaces_details } id_id_map = { workspace["id"]: workspace["id"] for workspace in workspaces_details } # In the rare case that a workspace's name equals another workspace's id, # the name mapping wins, so the user input is treated as a name return {**id_id_map, **name_id_map} def _get_transcript_batches( self, start_datetime: str | None = None, end_datetime: str | None = None ) -> Generator[list[dict[str, Any]], None, None]: body: dict[str, dict] = {"filter": {}} if start_datetime: body["filter"]["fromDateTime"] = start_datetime if end_datetime: body["filter"]["toDateTime"] = end_datetime # Retrieve transcripts for every call in the time range (per workspace, if # workspaces are configured) and yield them in batches of self.batch_size.
transcripts: list[dict[str, Any]] = [] workspace_list = self.workspaces or [None] # type: ignore workspace_map = self._get_workspace_id_map() if self.workspaces else {} for workspace in workspace_list: # Pagination cursors are per workspace, so drop any cursor left over # from the previous workspace before requesting the first page body.pop("cursor", None) if workspace: logger.info(f"Updating Gong workspace: {workspace}") workspace_id = workspace_map.get(workspace) if not workspace_id: logger.error(f"Invalid Gong workspace: {workspace}") if not self.continue_on_fail: raise ValueError(f"Invalid workspace: {workspace}") continue body["filter"]["workspaceId"] = workspace_id else: if "workspaceId" in body["filter"]: del body["filter"]["workspaceId"] while True: response = self._throttled_request( "POST", GongConnector.make_url("/v2/calls/transcript"), json=body ) # If no calls in the range, just break out if response.status_code == 404: break try: response.raise_for_status() except Exception: logger.error(f"Error fetching transcripts: {response.text}") raise data = response.json() call_transcripts = data.get("callTranscripts", []) transcripts.extend(call_transcripts) while len(transcripts) >= self.batch_size: yield transcripts[: self.batch_size] transcripts = transcripts[self.batch_size :] cursor = data.get("records", {}).get("cursor") if cursor: body["cursor"] = cursor else: break if transcripts: yield transcripts def _get_call_details_by_ids(self, call_ids: list[str]) -> dict: body = { "filter": {"callIds": call_ids}, "contentSelector": {"exposedFields": {"parties": True}}, } response = self._throttled_request( "POST", GongConnector.make_url("/v2/calls/extensive"), json=body ) response.raise_for_status() calls = response.json().get("calls") call_to_metadata = {} for call in calls: call_to_metadata[call["metaData"]["id"]] = call return call_to_metadata @staticmethod def _parse_parties(parties: list[dict]) -> dict[str, str]: id_mapping = {} for party in parties: name = party.get("name") email = party.get("emailAddress") if name and email: full_identifier = f"{name} ({email})" elif name: full_identifier = name elif email:
full_identifier = email else: full_identifier = "Unknown" id_mapping[party["speakerId"]] = full_identifier return id_mapping def _fetch_calls( self, start_datetime: str | None = None, end_datetime: str | None = None ) -> GenerateDocumentsOutput: num_calls = 0 for transcript_batch in self._get_transcript_batches( start_datetime, end_datetime ): doc_batch: list[Document | HierarchyNode] = [] transcript_call_ids = cast( list[str], [t.get("callId") for t in transcript_batch if t.get("callId")], ) call_details_map: dict[str, Any] = {} # The API appears to have a race condition: a transcript can reference a # call id that v2/calls/extensive does not return yet. Retrying with # exponential backoff has been observed to resolve this within ~2 minutes. # After max attempts, proceed with whatever we have; the per-call loop # below skips missing IDs (or raises if continue_on_fail is False). current_attempt = 0 while True: current_attempt += 1 call_details_map = self._get_call_details_by_ids(transcript_call_ids) if set(transcript_call_ids) == set(call_details_map.keys()): # we got all the ids we were expecting, so stop retrying break # we are missing some ids.
Log and retry with exponential backoff missing_call_ids = set(transcript_call_ids) - set( call_details_map.keys() ) logger.warning( f"_get_call_details_by_ids is missing call id's: " f"current_attempt={current_attempt} " f"missing_call_ids={missing_call_ids}" ) if current_attempt >= self.MAX_CALL_DETAILS_ATTEMPTS: logger.error( f"Giving up on missing call id's after " f"{self.MAX_CALL_DETAILS_ATTEMPTS} attempts: " f"missing_call_ids={missing_call_ids} — " f"proceeding with {len(call_details_map)} of " f"{len(transcript_call_ids)} calls" ) break wait_seconds = self.CALL_DETAILS_DELAY * pow(2, current_attempt - 1) logger.warning( f"_get_call_details_by_ids waiting to retry: " f"wait={wait_seconds}s " f"current_attempt={current_attempt} " f"next_attempt={current_attempt + 1} " f"max_attempts={self.MAX_CALL_DETAILS_ATTEMPTS}" ) time.sleep(wait_seconds) # now we can iterate per call/transcript for transcript in transcript_batch: call_id = transcript.get("callId") if not call_id or call_id not in call_details_map: # NOTE(rkuo): seeing odd behavior where call_ids from the transcript # don't have call details. adding error debugging logs to trace. 
logger.error( f"Couldn't get call information for Call ID: {call_id}" ) if call_id: logger.error( f"Call debug info: call_id={call_id} " f"call_ids={transcript_call_ids} " f"call_details_map={call_details_map.keys()}" ) if not self.continue_on_fail: raise RuntimeError( f"Couldn't get call information for Call ID: {call_id}" ) continue call_details = call_details_map[call_id] call_metadata = call_details["metaData"] call_time_str = call_metadata["started"] call_title = call_metadata["title"] logger.info( f"{num_calls + 1}: Indexing Gong call id {call_id} from {call_time_str.split('T', 1)[0]}: {call_title}" ) call_parties = cast(list[dict] | None, call_details.get("parties")) if call_parties is None: logger.error(f"Couldn't get parties for Call ID: {call_id}") call_parties = [] id_to_name_map = self._parse_parties(call_parties) # Keeping a separate dict here in case the parties info is incomplete speaker_to_name: dict[str, str] = {} transcript_text = "" call_purpose = call_metadata["purpose"] if call_purpose: transcript_text += f"Call Description: {call_purpose}\n\n" contents = transcript["transcript"] for segment in contents: speaker_id = segment.get("speakerId", "") if speaker_id not in speaker_to_name: if self.hide_user_info: speaker_to_name[speaker_id] = ( f"User {len(speaker_to_name) + 1}" ) else: speaker_to_name[speaker_id] = id_to_name_map.get( speaker_id, "Unknown" ) speaker_name = speaker_to_name[speaker_id] sentences = segment.get("sentences", {}) monolog = " ".join( [sentence.get("text", "") for sentence in sentences] ) transcript_text += f"{speaker_name}: {monolog}\n\n" metadata = {} if call_metadata.get("system"): metadata["client"] = call_metadata.get("system") # TODO calls have a clientUniqueId field, can pull that in later doc_batch.append( Document( id=call_id, sections=[ TextSection(link=call_metadata["url"], text=transcript_text) ], source=DocumentSource.GONG, # Should not ever be Untitled as a call cannot be made without a Title 
semantic_identifier=call_title or "Untitled", doc_updated_at=datetime.fromisoformat(call_time_str).astimezone( timezone.utc ), # use the dict built above so "client" is only set when Gong reports a system metadata=metadata, ) ) num_calls += 1 yield doc_batch logger.info(f"_fetch_calls finished: num_calls={num_calls}") def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: access_key = credentials.get("gong_access_key") access_key_secret = credentials.get("gong_access_key_secret") if not access_key or not access_key_secret: raise ConnectorMissingCredentialError("Gong") combined = f"{access_key}:{access_key_secret}" self.auth_token_basic = base64.b64encode(combined.encode("utf-8")).decode( "utf-8" ) self._session.headers.update( {"Authorization": f"Basic {self.auth_token_basic}"} ) return None def load_from_state(self) -> GenerateDocumentsOutput: return self._fetch_calls() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) # if this env variable is set, don't start from a timestamp before the specified # start time # TODO: remove this once this is globally available if GONG_CONNECTOR_START_TIME: special_start_datetime = datetime.fromisoformat(GONG_CONNECTOR_START_TIME) special_start_datetime = special_start_datetime.replace(tzinfo=timezone.utc) else: special_start_datetime = datetime.fromtimestamp(0, tz=timezone.utc) # don't let the special start dt be past the end time; otherwise the Gong API # rejects the request (`filter.fromDateTime: must be before toDateTime`) special_start_datetime = min(special_start_datetime, end_datetime) start_datetime = max( datetime.fromtimestamp(start, tz=timezone.utc), special_start_datetime ) # These are meeting start times, and a meeting must end and be processed before # its transcript is available, so look back an extra day from the start time and # fetch up to `end` (normally the current time) start_one_day_offset = start_datetime - timedelta(days=1) start_time = start_one_day_offset.isoformat() end_time = end_datetime.isoformat()
logger.info(f"Fetching Gong calls between {start_time} and {end_time}") return self._fetch_calls(start_time, end_time) if __name__ == "__main__": import os connector = GongConnector() connector.load_credentials( { "gong_access_key": os.environ["GONG_ACCESS_KEY"], "gong_access_key_secret": os.environ["GONG_ACCESS_KEY_SECRET"], } ) latest_docs = connector.load_from_state() print(next(latest_docs)) ================================================ FILE: backend/onyx/connectors/google_drive/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/google_drive/connector.py ================================================ import copy import json import os import sys import threading from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterator from datetime import datetime from enum import Enum from typing import Any from typing import cast from typing import Protocol from urllib.parse import parse_qs from urllib.parse import urlparse from urllib.parse import urlunparse from google.auth.exceptions import RefreshError from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from googleapiclient.errors import HttpError # type: ignore from typing_extensions import override from onyx.access.models import ExternalAccess from onyx.configs.app_configs import GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import MAX_DRIVE_WORKERS from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.google_drive.doc_conversion import build_slim_document from 
onyx.connectors.google_drive.doc_conversion import ( convert_drive_item_to_document, ) from onyx.connectors.google_drive.doc_conversion import onyx_document_id_from_drive_file from onyx.connectors.google_drive.doc_conversion import PermissionSyncContext from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files from onyx.connectors.google_drive.file_retrieval import DriveFileFieldType from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth from onyx.connectors.google_drive.file_retrieval import ( get_all_files_in_my_drive_and_shared, ) from onyx.connectors.google_drive.file_retrieval import get_external_access_for_folder from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive from onyx.connectors.google_drive.file_retrieval import get_folder_metadata from onyx.connectors.google_drive.file_retrieval import get_root_folder_id from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name from onyx.connectors.google_drive.file_retrieval import has_link_only_permission from onyx.connectors.google_drive.models import DriveRetrievalStage from onyx.connectors.google_drive.models import GoogleDriveCheckpoint from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.connectors.google_drive.models import RetrievedDriveFile from onyx.connectors.google_drive.models import StageCompletion from onyx.connectors.google_utils.google_auth import get_google_creds from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval from onyx.connectors.google_utils.google_utils import get_file_owners from onyx.connectors.google_utils.google_utils import GoogleFields from onyx.connectors.google_utils.resources import get_admin_service from onyx.connectors.google_utils.resources import get_drive_service from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) 
from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE from onyx.connectors.google_utils.shared_constants import USER_FIELDS from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import NormalizationResult from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import EntityFailure from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.db.enums import HierarchyNodeType from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder from onyx.utils.threadpool_concurrency import parallel_yield from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.threadpool_concurrency import ThreadSafeDict from onyx.utils.threadpool_concurrency import ThreadSafeSet logger = setup_logger() # TODO: Improve this by using the batch utility: https://googleapis.github.io/google-api-python-client/docs/batch.html # All file retrievals could be batched and made at once BATCHES_PER_CHECKPOINT = 1 DRIVE_BATCH_SIZE = 80 SHARED_DRIVE_PAGES_PER_CHECKPOINT = 2 MY_DRIVE_PAGES_PER_CHECKPOINT = 2 OAUTH_PAGES_PER_CHECKPOINT = 2 FOLDERS_PER_CHECKPOINT = 1 def _extract_str_list_from_comma_str(string: str | None) -> list[str]: if not string: return [] return [s.strip() for s in string.split(",") if 
s.strip()] def _extract_ids_from_urls(urls: list[str]) -> list[str]: return [urlparse(url).path.strip("/").split("/")[-1] for url in urls] def _clean_requested_drive_ids( requested_drive_ids: set[str], requested_folder_ids: set[str], all_drive_ids_available: set[str], ) -> tuple[list[str], list[str]]: invalid_requested_drive_ids = requested_drive_ids - all_drive_ids_available filtered_folder_ids = requested_folder_ids - all_drive_ids_available if invalid_requested_drive_ids: logger.warning( f"Some shared drive IDs were not found. IDs: {invalid_requested_drive_ids}" ) logger.warning("Checking for folder access instead...") filtered_folder_ids.update(invalid_requested_drive_ids) valid_requested_drive_ids = requested_drive_ids - invalid_requested_drive_ids return sorted(valid_requested_drive_ids), sorted(filtered_folder_ids) def _get_parent_id_from_file(drive_file: GoogleDriveFileType) -> str | None: """Extract the first parent ID from a drive file.""" parents = drive_file.get("parents") if parents and len(parents) > 0: return parents[0] # files have a unique parent return None def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool: """ Check if a folder is a verified shared drive root. For shared drives, we can verify using driveId: - If driveId is set and folder_id == driveId AND no parents, it's the shared drive root - If driveId is set but folder_id != driveId with empty parents, it's a permission issue Returns True only for verified shared drive roots. 
""" folder_id = folder.get("id") drive_id = folder.get("driveId") parents = folder.get("parents", []) # Must have no parents to be a root if parents: return False # For shared drive content, the root has id == driveId return bool(drive_id and folder_id == drive_id) def _public_access() -> ExternalAccess: return ExternalAccess( external_user_emails=set(), external_user_group_ids=set(), is_public=True, ) class CredentialedRetrievalMethod(Protocol): def __call__( self, field_type: DriveFileFieldType, checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: ... def add_retrieval_info( drive_files: Iterator[GoogleDriveFileType | str], user_email: str, completion_stage: DriveRetrievalStage, parent_id: str | None = None, ) -> Iterator[RetrievedDriveFile | str]: for file in drive_files: if isinstance(file, str): yield file continue yield RetrievedDriveFile( drive_file=file, user_email=user_email, parent_id=parent_id, completion_stage=completion_stage, ) class DriveIdStatus(Enum): AVAILABLE = "available" IN_PROGRESS = "in_progress" FINISHED = "finished" class GoogleDriveConnector( SlimConnectorWithPermSync, CheckpointedConnectorWithPermSync[GoogleDriveCheckpoint] ): def __init__( self, include_shared_drives: bool = False, include_my_drives: bool = False, include_files_shared_with_me: bool = False, shared_drive_urls: str | None = None, my_drive_emails: str | None = None, shared_folder_urls: str | None = None, specific_user_emails: str | None = None, exclude_domain_link_only: bool = False, batch_size: int = INDEX_BATCH_SIZE, # noqa: ARG002 # OLD PARAMETERS folder_paths: list[str] | None = None, include_shared: bool | None = None, follow_shortcuts: bool | None = None, only_org_public: bool | None = None, continue_on_failure: bool | None = None, ) -> None: # Check for old input parameters if folder_paths is not None: logger.warning( "The 'folder_paths' parameter is deprecated. 
Use 'shared_folder_urls' instead." ) if include_shared is not None: logger.warning( "The 'include_shared' parameter is deprecated. Use 'include_files_shared_with_me' instead." ) if follow_shortcuts is not None: logger.warning("The 'follow_shortcuts' parameter is deprecated.") if only_org_public is not None: logger.warning("The 'only_org_public' parameter is deprecated.") if continue_on_failure is not None: logger.warning("The 'continue_on_failure' parameter is deprecated.") if not any( ( include_shared_drives, include_my_drives, include_files_shared_with_me, shared_folder_urls, my_drive_emails, shared_drive_urls, ) ): raise ConnectorValidationError( "Nothing to index. Please specify at least one of the following: " "include_shared_drives, include_my_drives, include_files_shared_with_me, " "shared_folder_urls, shared_drive_urls, or my_drive_emails" ) specific_requests_made = False if bool(shared_drive_urls) or bool(my_drive_emails) or bool(shared_folder_urls): specific_requests_made = True self.specific_requests_made = specific_requests_made # NOTE: potentially modified in load_credentials if using service account self.include_files_shared_with_me = ( False if specific_requests_made else include_files_shared_with_me ) self.include_my_drives = False if specific_requests_made else include_my_drives self.include_shared_drives = ( False if specific_requests_made else include_shared_drives ) shared_drive_url_list = _extract_str_list_from_comma_str(shared_drive_urls) self._requested_shared_drive_ids = set( _extract_ids_from_urls(shared_drive_url_list) ) self._requested_my_drive_emails = set( _extract_str_list_from_comma_str(my_drive_emails) ) shared_folder_url_list = _extract_str_list_from_comma_str(shared_folder_urls) self._requested_folder_ids = set(_extract_ids_from_urls(shared_folder_url_list)) self._specific_user_emails = _extract_str_list_from_comma_str( specific_user_emails ) self.exclude_domain_link_only = exclude_domain_link_only self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None self._creds_dict: dict[str, Any] | None = None # ids of folders and shared drives that have been traversed self._retrieved_folder_and_drive_ids: set[str] = set() # Cache of known My Drive root IDs (user_email -> root_id) # Used to verify if a folder with no parents is actually a My Drive root # Thread-safe because multiple impersonation threads access this concurrently self._my_drive_root_id_cache: ThreadSafeDict[str, str] = ThreadSafeDict() self.allow_images = False self.size_threshold = GOOGLE_DRIVE_CONNECTOR_SIZE_THRESHOLD def set_allow_images(self, value: bool) -> None: self.allow_images = value @property def primary_admin_email(self) -> str: if self._primary_admin_email is None: raise RuntimeError( "Primary admin email missing, should not call this property before calling load_credentials" ) return self._primary_admin_email @property def google_domain(self) -> str: if self._primary_admin_email is None: raise RuntimeError( "Primary admin email missing, should not call this property before calling load_credentials" ) return self._primary_admin_email.split("@")[-1] @property def creds(self) -> OAuthCredentials | ServiceAccountCredentials: if self._creds is None: raise RuntimeError( "Creds missing, should not call this property before calling load_credentials" ) return self._creds @classmethod @override def normalize_url(cls, url: str) -> NormalizationResult: """Normalize a Google Drive URL to match the canonical Document.id format. Reuses the connector's existing document ID creation logic from onyx_document_id_from_drive_file. 
""" parsed = urlparse(url) netloc = parsed.netloc.lower() if not ( netloc.startswith("docs.google.com") or netloc.startswith("drive.google.com") ): return NormalizationResult(normalized_url=None, use_default=False) # Handle ?id= query parameter case query_params = parse_qs(parsed.query) doc_id = query_params.get("id", [None])[0] if doc_id: scheme = parsed.scheme or "https" netloc = "drive.google.com" path = f"/file/d/{doc_id}" params = "" query = "" fragment = "" normalized = urlunparse( (scheme, netloc, path, params, query, fragment) ).rstrip("/") return NormalizationResult(normalized_url=normalized, use_default=False) # Extract file ID and use connector's function path_parts = parsed.path.split("/") file_id = None for i, part in enumerate(path_parts): if part == "d" and i + 1 < len(path_parts): file_id = path_parts[i + 1] break if not file_id: return NormalizationResult(normalized_url=None, use_default=False) # Create minimal file object for connector function file_obj = {"webViewLink": url, "id": file_id} normalized = onyx_document_id_from_drive_file(file_obj).rstrip("/") return NormalizationResult(normalized_url=normalized, use_default=False) # TODO: ensure returned new_creds_dict is actually persisted when this is called? def load_credentials(self, credentials: dict[str, Any]) -> dict[str, str] | None: try: self._primary_admin_email = credentials[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] except KeyError: raise ValueError("Credentials json missing primary admin key") self._creds, new_creds_dict = get_google_creds( credentials=credentials, source=DocumentSource.GOOGLE_DRIVE, ) # Service account connectors don't have a specific setting determining whether # to include "shared with me" for each user, so we default to true unless the connector # is in specific folders/drives mode. Note that shared files are only picked up during # the My Drive stage, so this does nothing if the connector is set to only index shared drives. 
if ( isinstance(self._creds, ServiceAccountCredentials) and not self.specific_requests_made ): self.include_files_shared_with_me = True self._creds_dict = new_creds_dict return new_creds_dict def _update_traversed_parent_ids(self, folder_id: str) -> None: self._retrieved_folder_and_drive_ids.add(folder_id) def _get_all_user_emails(self) -> list[str]: if self._specific_user_emails: return self._specific_user_emails # Start with primary admin email user_emails = [self.primary_admin_email] # Only fetch additional users if using service account if isinstance(self.creds, OAuthCredentials): return user_emails admin_service = get_admin_service( creds=self.creds, user_email=self.primary_admin_email, ) # Get admins first since they're more likely to have access to most files for is_admin in [True, False]: query = "isAdmin=true" if is_admin else "isAdmin=false" for user in execute_paginated_retrieval( retrieval_function=admin_service.users().list, list_key="users", fields=USER_FIELDS, domain=self.google_domain, query=query, ): if email := user.get("primaryEmail"): if email not in user_emails: user_emails.append(email) return user_emails def _get_my_drive_root_id(self, user_email: str) -> str | None: """ Get the My Drive root folder ID for a user. Uses a cache to avoid repeated API calls. Returns None if the user doesn't have access to Drive APIs or the call fails. """ if user_email in self._my_drive_root_id_cache: return self._my_drive_root_id_cache[user_email] try: drive_service = get_drive_service(self.creds, user_email) root_id = get_root_folder_id(drive_service) self._my_drive_root_id_cache[user_email] = root_id return root_id except Exception: # User might not have access to Drive APIs return None def _is_my_drive_root( self, folder: GoogleDriveFileType, retriever_email: str ) -> bool: """ Check if a folder is a My Drive root. For My Drive folders (no driveId), we verify by comparing the folder ID to the actual My Drive root ID obtained via files().get(fileId='root'). 
""" folder_id = folder.get("id") drive_id = folder.get("driveId") parents = folder.get("parents", []) # If there are parents, this is not a root if parents: return False # If driveId is set, this is shared drive content, not My Drive if drive_id: return False # Get the My Drive root ID for this user and compare root_id = self._get_my_drive_root_id(retriever_email) if root_id and folder_id == root_id: return True # Also check with admin in case the retriever doesn't have access admin_root_id = self._get_my_drive_root_id(self.primary_admin_email) if admin_root_id and folder_id == admin_root_id: return True return False def _get_new_ancestors_for_files( self, files: list[RetrievedDriveFile], seen_hierarchy_node_raw_ids: ThreadSafeSet[str], fully_walked_hierarchy_node_raw_ids: ThreadSafeSet[str], permission_sync_context: PermissionSyncContext | None = None, add_prefix: bool = False, ) -> list[HierarchyNode]: """ Get all NEW ancestor hierarchy nodes for a batch of files. For each file, walks up the parent chain until reaching a root/drive (terminal node with no parent). Returns HierarchyNode objects for all new ancestors. The function tracks two separate sets: - seen_hierarchy_node_raw_ids: Nodes we've already yielded (to avoid duplicates) - fully_walked_hierarchy_node_raw_ids: Nodes where we've successfully walked to a terminal root. Only skip walking from a node if it's in this set. This separation ensures that if User A can access folder C but not its parent B, a later User B who has access to both can still complete the walk to the root. Args: files: List of retrieved drive files to get ancestors for seen_hierarchy_node_raw_ids: Set of already-yielded node IDs (modified in place) fully_walked_hierarchy_node_raw_ids: Set of node IDs where the walk to root succeeded (modified in place) permission_sync_context: If provided, permissions will be fetched for hierarchy nodes. Contains google_domain and primary_admin_email needed for permission syncing. 
add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path). Returns: List of HierarchyNode objects for new ancestors (ordered parent-first) """ service = get_drive_service(self.creds, self.primary_admin_email) field_type = ( DriveFileFieldType.WITH_PERMISSIONS if permission_sync_context else DriveFileFieldType.STANDARD ) new_nodes: list[HierarchyNode] = [] for file in files: parent_id = _get_parent_id_from_file(file.drive_file) if not parent_id: continue # Only skip if we've already successfully walked from this node to a root. # Don't skip just because it's "seen" - a previous user may have failed # to walk to the root, and this user might have better access. if parent_id in fully_walked_hierarchy_node_raw_ids: continue # Walk up the parent chain ancestors_to_add: list[HierarchyNode] = [] node_ids_in_walk: list[str] = [] current_id: str | None = parent_id reached_terminal = False while current_id: node_ids_in_walk.append(current_id) # If we hit a node that's already been fully walked, we know # the path from here to root is complete if current_id in fully_walked_hierarchy_node_raw_ids: reached_terminal = True break # Fetch folder metadata folder = self._get_folder_metadata( current_id, file.user_email, field_type ) if not folder: # Can't access this folder - stop climbing # Don't mark as fully walked since we didn't reach root break folder_parent_id = _get_parent_id_from_file(folder) # Create the node BEFORE marking as seen to avoid a race condition where: # 1. Thread A marks node as "seen" # 2. Thread A fails to create node (e.g., API error in get_external_access) # 3. Thread B sees node as "already seen" and skips it # 4. Result: node is never yielded # # By creating first and then atomically checking/marking, we ensure that # if creation fails, another thread can still try. If both succeed, # only one will add to ancestors_to_add (the one that wins check_and_add). 
if permission_sync_context: external_access = get_external_access_for_folder( folder, permission_sync_context.google_domain, service, add_prefix, ) else: external_access = _public_access() node = HierarchyNode( raw_node_id=current_id, raw_parent_id=folder_parent_id, display_name=folder.get("name", "Unknown Folder"), link=folder.get("webViewLink"), node_type=HierarchyNodeType.FOLDER, external_access=external_access, ) # Now atomically check and add - only append if we're the first thread # to successfully create this node already_seen = seen_hierarchy_node_raw_ids.check_and_add(current_id) if not already_seen: ancestors_to_add.append(node) # Check if this is a verified terminal node (actual root, not just # empty parents due to permission limitations) # Check shared drive root first (simple ID comparison) if _is_shared_drive_root(folder): # files().get() returns 'Drive' for shared drive roots; # fetch the real name via drives().get(). # Try both the retriever and admin since the admin may # not have access to private shared drives. 
drive_name = self._get_shared_drive_name( current_id, file.user_email ) if drive_name: node.display_name = drive_name node.node_type = HierarchyNodeType.SHARED_DRIVE reached_terminal = True break # Check if this is a My Drive root (requires API call, but cached) if self._is_my_drive_root(folder, file.user_email): reached_terminal = True break # If parents is empty but we couldn't verify it's a true root, # stop walking but don't mark as fully walked (another user # with better access might be able to continue) if folder_parent_id is None: break # Move to parent current_id = folder_parent_id # If we successfully reached a terminal node (or a fully-walked node), # mark all nodes in this walk as fully walked if reached_terminal: fully_walked_hierarchy_node_raw_ids.update(set(node_ids_in_walk)) new_nodes += ancestors_to_add return new_nodes def _get_folder_metadata( self, folder_id: str, retriever_email: str, field_type: DriveFileFieldType ) -> GoogleDriveFileType | None: """ Fetch metadata for a folder by ID. Important: When a user has access to a shared folder but NOT its parent, the Google Drive API returns the folder metadata WITHOUT the parent info. To handle this, if the retriever gets a folder without parents, we also try with admin who may have better access and can see the parent chain. 
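A minimal sketch of this fallback, assuming fetch is a hypothetical callable standing in for get_folder_metadata with a per-user Drive service (returning None on failure). A result that carries parents wins immediately; a parentless result is only kept as a fallback.

```python
from typing import Any, Callable, Optional

FolderDict = dict[str, Any]


def fetch_folder_with_parents(
    folder_id: str,
    retriever_email: str,
    admin_email: str,
    fetch: Callable[[str, str], Optional[FolderDict]],
) -> Optional[FolderDict]:
    best: Optional[FolderDict] = None
    # dict.fromkeys deduplicates (the retriever may be the admin) and keeps order
    for email in dict.fromkeys([retriever_email, admin_email]):
        folder = fetch(email, folder_id)
        if not folder:
            continue  # this user could not fetch the folder
        if folder.get("parents"):
            return folder  # parent chain is visible: use this result
        if best is None:
            best = folder  # parentless result: keep as a fallback
    return best
```

The sketch iterates dict.fromkeys(...) so the retriever is always tried first; iterating a set, as this method does, visits the two emails in an unspecified order.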
""" best_folder: GoogleDriveFileType | None = None # Use a set to deduplicate if retriever_email == primary_admin_email for email in {retriever_email, self.primary_admin_email}: service = get_drive_service(self.creds, email) folder = get_folder_metadata(service, folder_id, field_type) if not folder: logger.debug(f"Failed to fetch folder {folder_id} using {email}") continue logger.debug(f"Successfully fetched folder {folder_id} using {email}") # If this folder has parents, use it if folder.get("parents"): return folder # Folder has no parents - could be a root OR user lacks access to parent # Keep this as a fallback but try admin to see if they can see parents if best_folder is None: best_folder = folder logger.debug( f"Folder {folder_id} has no parents when fetched by {email}, will try admin to check for parent access" ) if best_folder: logger.debug( f"Successfully fetched folder {folder_id} but no parents found" ) return best_folder logger.debug( f"All attempts failed to fetch folder {folder_id} (tried {retriever_email} and {self.primary_admin_email})" ) return None def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None: """Fetch the name of a shared drive, trying both the retriever and admin.""" for email in {retriever_email, self.primary_admin_email}: svc = get_drive_service(self.creds, email) name = get_shared_drive_name(svc, drive_id) if name: return name return None def get_all_drive_ids(self) -> set[str]: return self._get_all_drives_for_user(self.primary_admin_email) def _get_all_drives_for_user(self, user_email: str) -> set[str]: drive_service = get_drive_service(self.creds, user_email) is_service_account = isinstance(self.creds, ServiceAccountCredentials) logger.info( f"Getting all drives for user {user_email} with service account: {is_service_account}" ) all_drive_ids: set[str] = set() for drive in execute_paginated_retrieval( retrieval_function=drive_service.drives().list, list_key="drives", 
useDomainAdminAccess=is_service_account, fields="drives(id),nextPageToken", ): all_drive_ids.add(drive["id"]) if not all_drive_ids: logger.warning( "No drives found even though indexing shared drives was requested." ) return all_drive_ids def make_drive_id_getter( self, drive_ids: list[str], checkpoint: GoogleDriveCheckpoint ) -> Callable[[str], str | None]: status_lock = threading.Lock() in_progress_drive_ids = { completion.current_folder_or_drive_id: user_email for user_email, completion in checkpoint.completion_map.items() if completion.stage == DriveRetrievalStage.SHARED_DRIVE_FILES and completion.current_folder_or_drive_id is not None } drive_id_status: dict[str, DriveIdStatus] = {} for drive_id in drive_ids: if drive_id in self._retrieved_folder_and_drive_ids: drive_id_status[drive_id] = DriveIdStatus.FINISHED elif drive_id in in_progress_drive_ids: drive_id_status[drive_id] = DriveIdStatus.IN_PROGRESS else: drive_id_status[drive_id] = DriveIdStatus.AVAILABLE def get_available_drive_id(thread_id: str) -> str | None: completion = checkpoint.completion_map[thread_id] with status_lock: future_work = None for drive_id, status in drive_id_status.items(): if drive_id in self._retrieved_folder_and_drive_ids: drive_id_status[drive_id] = DriveIdStatus.FINISHED continue if drive_id in completion.processed_drive_ids: continue if status == DriveIdStatus.AVAILABLE: # add to processed drive ids so if this user fails to retrieve once # they won't try again on the next checkpoint run completion.processed_drive_ids.add(drive_id) return drive_id elif status == DriveIdStatus.IN_PROGRESS: logger.debug(f"Drive id in progress: {drive_id}") future_work = drive_id if future_work: # in this case, all drive ids are either finished or in progress. # This thread will pick up one of the in progress ones in case it fails. # This is a much simpler approach than waiting for a failure picking it up, # at the cost of some repeated work until all shared drives are retrieved. 
# we avoid apocalyptic cases like all threads focusing on one huge drive # because the drive id is added to _retrieved_folder_and_drive_ids after any thread # manages to retrieve any file from it (unfortunately, this is also the reason we currently # sometimes fail to retrieve restricted access folders/files) completion.processed_drive_ids.add(future_work) return future_work return None # no work available, return None return get_available_drive_id def _impersonate_user_for_retrieval( self, user_email: str, field_type: DriveFileFieldType, checkpoint: GoogleDriveCheckpoint, get_new_drive_id: Callable[[str], str | None], sorted_filtered_folder_ids: list[str], start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: logger.info(f"Impersonating user {user_email}") curr_stage = checkpoint.completion_map[user_email] resuming = True if curr_stage.stage == DriveRetrievalStage.START: logger.info(f"Setting stage to {DriveRetrievalStage.MY_DRIVE_FILES.value}") curr_stage.stage = DriveRetrievalStage.MY_DRIVE_FILES resuming = False drive_service = get_drive_service(self.creds, user_email) # validate that the user has access to the drive APIs by performing a simple # request and checking for a 401 try: logger.debug(f"Getting root folder id for user {user_email}") # the default is ~17 mins of retries; use fewer here so we don't # waste 17 mins every time we run into a user without access to drive APIs retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service) except HttpError as e: if e.status_code == 401: # fail gracefully, let the other impersonations continue # one user without access shouldn't block the entire connector logger.warning( f"User '{user_email}' does not have access to the drive APIs."
) # mark this user as done so we don't try to retrieve anything for them # again curr_stage.stage = DriveRetrievalStage.DONE return raise except RefreshError as e: logger.warning( f"User '{user_email}' could not refresh their token. Error: {e}" ) # mark this user as done so we don't try to retrieve anything for them # again yield RetrievedDriveFile( completion_stage=DriveRetrievalStage.DONE, drive_file={}, user_email=user_email, error=e, ) curr_stage.stage = DriveRetrievalStage.DONE return # try to get the current user's My Drive if any of the following are true: # - include_my_drives is true # - the current user's email is in the requested emails if curr_stage.stage == DriveRetrievalStage.MY_DRIVE_FILES: if self.include_my_drives or user_email in self._requested_my_drive_emails: logger.info( f"Getting all files in my drive as '{user_email}'. Resuming: {resuming}. " f"Stage completed until: {curr_stage.completed_until}. " f"Next page token: {curr_stage.next_page_token}" ) for file_or_token in add_retrieval_info( get_all_files_in_my_drive_and_shared( service=drive_service, update_traversed_ids_func=self._update_traversed_parent_ids, field_type=field_type, include_shared_with_me=self.include_files_shared_with_me, max_num_pages=MY_DRIVE_PAGES_PER_CHECKPOINT, start=curr_stage.completed_until if resuming else start, end=end, cache_folders=not bool(curr_stage.completed_until), page_token=curr_stage.next_page_token, ), user_email, DriveRetrievalStage.MY_DRIVE_FILES, ): if isinstance(file_or_token, str): logger.debug(f"Done with max num pages for user {user_email}") checkpoint.completion_map[user_email].next_page_token = ( file_or_token ) return # done with the max num pages, return checkpoint yield file_or_token checkpoint.completion_map[user_email].next_page_token = None curr_stage.stage = DriveRetrievalStage.SHARED_DRIVE_FILES curr_stage.current_folder_or_drive_id = None return # resume from next stage on the next run if curr_stage.stage
== DriveRetrievalStage.SHARED_DRIVE_FILES: def _yield_from_drive( drive_id: str, drive_start: SecondsSinceUnixEpoch | None ) -> Iterator[RetrievedDriveFile | str]: yield from add_retrieval_info( get_files_in_shared_drive( service=drive_service, drive_id=drive_id, field_type=field_type, max_num_pages=SHARED_DRIVE_PAGES_PER_CHECKPOINT, update_traversed_ids_func=self._update_traversed_parent_ids, cache_folders=not bool( drive_start ), # only cache folders for 0 or None start=drive_start, end=end, page_token=curr_stage.next_page_token, ), user_email, DriveRetrievalStage.SHARED_DRIVE_FILES, parent_id=drive_id, ) # resume from a checkpoint if resuming and (drive_id := curr_stage.current_folder_or_drive_id): resume_start = curr_stage.completed_until for file_or_token in _yield_from_drive(drive_id, resume_start): if isinstance(file_or_token, str): checkpoint.completion_map[user_email].next_page_token = ( file_or_token ) return # done with the max num pages, return checkpoint yield file_or_token drive_id = get_new_drive_id(user_email) if drive_id: logger.info( f"Getting files in shared drive '{drive_id}' as '{user_email}'. Resuming: {resuming}" ) curr_stage.completed_until = 0 curr_stage.current_folder_or_drive_id = drive_id for file_or_token in _yield_from_drive(drive_id, start): if isinstance(file_or_token, str): checkpoint.completion_map[user_email].next_page_token = ( file_or_token ) return # done with the max num pages, return checkpoint yield file_or_token curr_stage.current_folder_or_drive_id = None return # get a new drive id on the next run checkpoint.completion_map[user_email].next_page_token = None curr_stage.stage = DriveRetrievalStage.FOLDER_FILES curr_stage.current_folder_or_drive_id = None return # resume from next stage on the next run # In the folder files section of service account retrieval we take extra care # to not retrieve duplicate docs.
In particular, we only add a folder to # retrieved_folder_and_drive_ids when all users are finished retrieving files # from that folder, and maintain a set of all file ids that have been retrieved # for each folder. This might get rather large; in practice we assume that the # specific folders users choose to index don't have too many files. if curr_stage.stage == DriveRetrievalStage.FOLDER_FILES: def _yield_from_folder_crawl( folder_id: str, folder_start: SecondsSinceUnixEpoch | None ) -> Iterator[RetrievedDriveFile]: for retrieved_file in crawl_folders_for_files( service=drive_service, parent_id=folder_id, field_type=field_type, user_email=user_email, traversed_parent_ids=self._retrieved_folder_and_drive_ids, update_traversed_ids_func=self._update_traversed_parent_ids, start=folder_start, end=end, ): yield retrieved_file # resume from a checkpoint last_processed_folder = None if resuming: folder_id = curr_stage.current_folder_or_drive_id if folder_id is None: logger.warning( f"folder id not set in checkpoint for user {user_email}. " "This happens occasionally when the connector is interrupted " "and resumed." ) else: resume_start = curr_stage.completed_until yield from _yield_from_folder_crawl(folder_id, resume_start) last_processed_folder = folder_id skipping_seen_folders = last_processed_folder is not None # NOTE: this assumes a small number of folders to crawl. If someone # really wants to specify a large number of folders, we should use # binary search to find the first unseen folder. 
num_completed_folders = 0 for folder_id in sorted_filtered_folder_ids: if skipping_seen_folders: skipping_seen_folders = folder_id != last_processed_folder continue if folder_id in self._retrieved_folder_and_drive_ids: continue curr_stage.completed_until = 0 curr_stage.current_folder_or_drive_id = folder_id if num_completed_folders >= FOLDERS_PER_CHECKPOINT: return # resume from this folder on the next run logger.info(f"Getting files in folder '{folder_id}' as '{user_email}'") yield from _yield_from_folder_crawl(folder_id, start) num_completed_folders += 1 curr_stage.stage = DriveRetrievalStage.DONE def _manage_service_account_retrieval( self, field_type: DriveFileFieldType, checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: """ The current implementation of the service account retrieval does some initial setup work using the primary admin email, then runs MAX_DRIVE_WORKERS concurrent threads, each of which impersonates a different user and retrieves files for that user. Technically, the actual work each thread does is "yield the next file retrieved by the user", at which point it returns to the thread pool; see parallel_yield for more details. 
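The sketch below is a hypothetical, simplified stand-in for parallel_yield (the real helper is not shown here and may differ). Each pool task advances one per-user generator by a single item, and that generator is resubmitted only after its item has been yielded, so no generator is ever stepped by two threads at once.

```python
from concurrent.futures import FIRST_COMPLETED, Future, ThreadPoolExecutor, wait
from typing import Iterator, TypeVar

T = TypeVar("T")

_DONE = object()  # sentinel returned when a generator is exhausted


def _next_item(gen: Iterator[T]) -> object:
    # Advance one generator by a single step on a worker thread.
    return next(gen, _DONE)


def parallel_yield_sketch(gens: list[Iterator[T]], max_workers: int) -> Iterator[T]:
    with ThreadPoolExecutor(max_workers=max_workers) as pool:
        # At most one in-flight future per generator.
        pending: dict[Future, Iterator[T]] = {
            pool.submit(_next_item, gen): gen for gen in gens
        }
        while pending:
            done, _ = wait(pending, return_when=FIRST_COMPLETED)
            for future in done:
                gen = pending.pop(future)
                item = future.result()  # re-raises errors from the generator
                if item is _DONE:
                    continue  # this user's generator is finished
                yield item  # type: ignore[misc]
                # Only now is it safe to step this generator again.
                pending[pool.submit(_next_item, gen)] = gen
```

Keeping one future per generator preserves each user's file order while still interleaving users, which mirrors the "yield the next file, then return to the pool" behaviour described above.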
""" if checkpoint.completion_stage == DriveRetrievalStage.START: checkpoint.completion_stage = DriveRetrievalStage.USER_EMAILS if checkpoint.completion_stage == DriveRetrievalStage.USER_EMAILS: all_org_emails: list[str] = self._get_all_user_emails() checkpoint.user_emails = all_org_emails checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS else: if checkpoint.user_emails is None: raise ValueError("user emails not set") all_org_emails = checkpoint.user_emails sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids( checkpoint, DriveRetrievalStage.MY_DRIVE_FILES ) # Setup initial completion map on first connector run for email in all_org_emails: # don't overwrite existing completion map on resuming runs if email in checkpoint.completion_map: continue checkpoint.completion_map[email] = StageCompletion( stage=DriveRetrievalStage.START, completed_until=0, processed_drive_ids=set(), ) # we've found all users and drives, now time to actually start # fetching stuff logger.info(f"Found {len(all_org_emails)} users to impersonate") logger.debug(f"Users: {all_org_emails}") logger.info(f"Found {len(sorted_drive_ids)} drives to retrieve") logger.debug(f"Drives: {sorted_drive_ids}") logger.info(f"Found {len(sorted_folder_ids)} folders to retrieve") logger.debug(f"Folders: {sorted_folder_ids}") drive_id_getter = self.make_drive_id_getter(sorted_drive_ids, checkpoint) # only process emails that we haven't already completed retrieval for non_completed_org_emails = [ user_email for user_email, stage_completion in checkpoint.completion_map.items() if stage_completion.stage != DriveRetrievalStage.DONE ] logger.debug(f"Non-completed users remaining: {len(non_completed_org_emails)}") # don't process too many emails before returning a checkpoint. This is # to resolve the case where there are a ton of emails that don't have access # to the drive APIs. Without this, we could loop through these emails for # more than 3 hours, causing a timeout and stalling progress. 
email_batch_takes_us_to_completion = True MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = MAX_DRIVE_WORKERS if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING: non_completed_org_emails = non_completed_org_emails[ :MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING ] email_batch_takes_us_to_completion = False user_retrieval_gens = [ self._impersonate_user_for_retrieval( email, field_type, checkpoint, drive_id_getter, sorted_folder_ids, start, end, ) for email in non_completed_org_emails ] yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS) # if there are more emails to process, don't mark as complete if not email_batch_takes_us_to_completion: return remaining_folders = ( set(sorted_drive_ids) | set(sorted_folder_ids) ) - self._retrieved_folder_and_drive_ids if remaining_folders: logger.warning( f"Some folders/drives were not retrieved. IDs: {remaining_folders}" ) if any( checkpoint.completion_map[user_email].stage != DriveRetrievalStage.DONE for user_email in all_org_emails ): logger.info( "some users did not complete retrieval, returning checkpoint for another run" ) return checkpoint.completion_stage = DriveRetrievalStage.DONE def _determine_retrieval_ids( self, checkpoint: GoogleDriveCheckpoint, next_stage: DriveRetrievalStage, ) -> tuple[list[str], list[str]]: all_drive_ids = self.get_all_drive_ids() sorted_drive_ids: list[str] = [] sorted_folder_ids: list[str] = [] if checkpoint.completion_stage == DriveRetrievalStage.DRIVE_IDS: if self._requested_shared_drive_ids or self._requested_folder_ids: ( sorted_drive_ids, sorted_folder_ids, ) = _clean_requested_drive_ids( requested_drive_ids=self._requested_shared_drive_ids, requested_folder_ids=self._requested_folder_ids, all_drive_ids_available=all_drive_ids, ) elif self.include_shared_drives: sorted_drive_ids = sorted(all_drive_ids) checkpoint.drive_ids_to_retrieve = sorted_drive_ids checkpoint.folder_ids_to_retrieve = sorted_folder_ids checkpoint.completion_stage = 
next_stage else: if checkpoint.drive_ids_to_retrieve is None: raise ValueError("drive ids to retrieve not set in checkpoint") if checkpoint.folder_ids_to_retrieve is None: raise ValueError("folder ids to retrieve not set in checkpoint") # When loading from a checkpoint, load the previously cached drive and folder ids sorted_drive_ids = checkpoint.drive_ids_to_retrieve sorted_folder_ids = checkpoint.folder_ids_to_retrieve return sorted_drive_ids, sorted_folder_ids def _oauth_retrieval_all_files( self, field_type: DriveFileFieldType, drive_service: GoogleDriveService, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, page_token: str | None = None, ) -> Iterator[RetrievedDriveFile | str]: if not self.include_files_shared_with_me and not self.include_my_drives: return logger.info( f"Getting shared files/my drive files for OAuth " f"with include_files_shared_with_me={self.include_files_shared_with_me}, " f"include_my_drives={self.include_my_drives}, " f"include_shared_drives={self.include_shared_drives}. " f"Using '{self.primary_admin_email}' as the account."
) yield from add_retrieval_info( get_all_files_for_oauth( service=drive_service, include_files_shared_with_me=self.include_files_shared_with_me, include_my_drives=self.include_my_drives, include_shared_drives=self.include_shared_drives, field_type=field_type, max_num_pages=OAUTH_PAGES_PER_CHECKPOINT, start=start, end=end, page_token=page_token, ), self.primary_admin_email, DriveRetrievalStage.OAUTH_FILES, ) def _oauth_retrieval_drives( self, field_type: DriveFileFieldType, drive_service: GoogleDriveService, drive_ids_to_retrieve: list[str], checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile | str]: def _yield_from_drive( drive_id: str, drive_start: SecondsSinceUnixEpoch | None ) -> Iterator[RetrievedDriveFile | str]: yield from add_retrieval_info( get_files_in_shared_drive( service=drive_service, drive_id=drive_id, field_type=field_type, max_num_pages=SHARED_DRIVE_PAGES_PER_CHECKPOINT, cache_folders=not bool( drive_start ), # only cache folders for 0 or None update_traversed_ids_func=self._update_traversed_parent_ids, start=drive_start, end=end, page_token=checkpoint.completion_map[ self.primary_admin_email ].next_page_token, ), self.primary_admin_email, DriveRetrievalStage.SHARED_DRIVE_FILES, parent_id=drive_id, ) # If we are resuming from a checkpoint, we need to finish retrieving the files from the last drive we retrieved if ( checkpoint.completion_map[self.primary_admin_email].stage == DriveRetrievalStage.SHARED_DRIVE_FILES ): drive_id = checkpoint.completion_map[ self.primary_admin_email ].current_folder_or_drive_id if drive_id is None: raise ValueError("drive id not set in checkpoint") resume_start = checkpoint.completion_map[ self.primary_admin_email ].completed_until for file_or_token in _yield_from_drive(drive_id, resume_start): if isinstance(file_or_token, str): checkpoint.completion_map[ self.primary_admin_email ].next_page_token = file_or_token return 
# done with the max num pages, return checkpoint yield file_or_token checkpoint.completion_map[self.primary_admin_email].next_page_token = None for drive_id in drive_ids_to_retrieve: if drive_id in self._retrieved_folder_and_drive_ids: logger.info( f"Skipping drive '{drive_id}' as it has already been retrieved" ) continue logger.info( f"Getting files in shared drive '{drive_id}' as '{self.primary_admin_email}'" ) for file_or_token in _yield_from_drive(drive_id, start): if isinstance(file_or_token, str): checkpoint.completion_map[ self.primary_admin_email ].next_page_token = file_or_token return # done with the max num pages, return checkpoint yield file_or_token checkpoint.completion_map[self.primary_admin_email].next_page_token = None def _oauth_retrieval_folders( self, field_type: DriveFileFieldType, drive_service: GoogleDriveService, drive_ids_to_retrieve: set[str], folder_ids_to_retrieve: set[str], checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: """ If there are any remaining folder ids to retrieve found earlier in the retrieval process, we recursively descend the file tree and retrieve all files in the folder(s). """ # Even if no folders were requested, we still check if any drives were requested # that could be folders. remaining_folders = ( folder_ids_to_retrieve - self._retrieved_folder_and_drive_ids ) def _yield_from_folder_crawl( folder_id: str, folder_start: SecondsSinceUnixEpoch | None ) -> Iterator[RetrievedDriveFile]: yield from crawl_folders_for_files( service=drive_service, parent_id=folder_id, field_type=field_type, user_email=self.primary_admin_email, traversed_parent_ids=self._retrieved_folder_and_drive_ids, update_traversed_ids_func=self._update_traversed_parent_ids, start=folder_start, end=end, ) # resume from a checkpoint # TODO: actually checkpoint folder retrieval. 
Since we moved towards returning from # generator functions to indicate when a checkpoint should be returned, this code # shouldn't be used currently. Unfortunately folder crawling is quite difficult to checkpoint # effectively (likely need separate folder crawling and file retrieval stages), # so we'll revisit this later. if checkpoint.completion_map[ self.primary_admin_email ].stage == DriveRetrievalStage.FOLDER_FILES and ( folder_id := checkpoint.completion_map[ self.primary_admin_email ].current_folder_or_drive_id ): resume_start = checkpoint.completion_map[ self.primary_admin_email ].completed_until yield from _yield_from_folder_crawl(folder_id, resume_start) # the times stored in the completion_map aren't used due to the crawling behavior # instead, the traversed_parent_ids are used to determine what we have left to retrieve for folder_id in remaining_folders: logger.info( f"Getting files in folder '{folder_id}' as '{self.primary_admin_email}'" ) yield from _yield_from_folder_crawl(folder_id, start) remaining_folders = ( drive_ids_to_retrieve | folder_ids_to_retrieve ) - self._retrieved_folder_and_drive_ids if remaining_folders: logger.warning( f"Some folders/drives were not retrieved. 
IDs: {remaining_folders}" ) def _checkpointed_retrieval( self, retrieval_method: CredentialedRetrievalMethod, field_type: DriveFileFieldType, checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: drive_files = retrieval_method( field_type=field_type, checkpoint=checkpoint, start=start, end=end, ) for file in drive_files: drive_file = file.drive_file or {} completion = checkpoint.completion_map[file.user_email] completed_until = completion.completed_until modified_time = drive_file.get(GoogleFields.MODIFIED_TIME.value) if isinstance(modified_time, str): try: completed_until = datetime.fromisoformat(modified_time).timestamp() except ValueError: logger.warning( "Invalid modifiedTime for file '%s' (stage=%s, user=%s).", drive_file.get("id"), file.completion_stage, file.user_email, ) completion.update( stage=file.completion_stage, completed_until=completed_until, current_folder_or_drive_id=file.parent_id, ) if file.error is not None or not drive_file: yield file continue try: document_id = onyx_document_id_from_drive_file(drive_file) except KeyError as exc: logger.warning( "Drive file missing id/webViewLink (stage=%s user=%s). Skipping.", file.completion_stage, file.user_email, ) if file.error is None: file.error = exc yield file continue logger.debug( f"Updating checkpoint for file: {drive_file.get('name')}. 
" f"Seen: {document_id in checkpoint.all_retrieved_file_ids}" ) if document_id in checkpoint.all_retrieved_file_ids: continue checkpoint.all_retrieved_file_ids.add(document_id) yield file def _manage_oauth_retrieval( self, field_type: DriveFileFieldType, checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: if checkpoint.completion_stage == DriveRetrievalStage.START: checkpoint.completion_stage = DriveRetrievalStage.OAUTH_FILES checkpoint.completion_map[self.primary_admin_email] = StageCompletion( stage=DriveRetrievalStage.START, completed_until=0, current_folder_or_drive_id=None, ) drive_service = get_drive_service(self.creds, self.primary_admin_email) if checkpoint.completion_stage == DriveRetrievalStage.OAUTH_FILES: completion = checkpoint.completion_map[self.primary_admin_email] all_files_start = start # if resuming from a checkpoint if completion.stage == DriveRetrievalStage.OAUTH_FILES: all_files_start = completion.completed_until for file_or_token in self._oauth_retrieval_all_files( field_type=field_type, drive_service=drive_service, start=all_files_start, end=end, page_token=checkpoint.completion_map[ self.primary_admin_email ].next_page_token, ): if isinstance(file_or_token, str): checkpoint.completion_map[ self.primary_admin_email ].next_page_token = file_or_token return # done with the max num pages, return checkpoint yield file_or_token checkpoint.completion_stage = DriveRetrievalStage.DRIVE_IDS checkpoint.completion_map[self.primary_admin_email].next_page_token = None return # create a new checkpoint all_requested = ( self.include_files_shared_with_me and self.include_my_drives and self.include_shared_drives ) if all_requested: # If all 3 are true, we already yielded from get_all_files_for_oauth checkpoint.completion_stage = DriveRetrievalStage.DONE return sorted_drive_ids, sorted_folder_ids = self._determine_retrieval_ids( checkpoint, 
DriveRetrievalStage.SHARED_DRIVE_FILES ) if checkpoint.completion_stage == DriveRetrievalStage.SHARED_DRIVE_FILES: for file_or_token in self._oauth_retrieval_drives( field_type=field_type, drive_service=drive_service, drive_ids_to_retrieve=sorted_drive_ids, checkpoint=checkpoint, start=start, end=end, ): if isinstance(file_or_token, str): checkpoint.completion_map[ self.primary_admin_email ].next_page_token = file_or_token return # done with the max num pages, return checkpoint yield file_or_token checkpoint.completion_stage = DriveRetrievalStage.FOLDER_FILES checkpoint.completion_map[self.primary_admin_email].next_page_token = None return # create a new checkpoint if checkpoint.completion_stage == DriveRetrievalStage.FOLDER_FILES: yield from self._oauth_retrieval_folders( field_type=field_type, drive_service=drive_service, drive_ids_to_retrieve=set(sorted_drive_ids), folder_ids_to_retrieve=set(sorted_folder_ids), checkpoint=checkpoint, start=start, end=end, ) checkpoint.completion_stage = DriveRetrievalStage.DONE def _fetch_drive_items( self, field_type: DriveFileFieldType, checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> Iterator[RetrievedDriveFile]: retrieval_method = ( self._manage_service_account_retrieval if isinstance(self.creds, ServiceAccountCredentials) else self._manage_oauth_retrieval ) return self._checkpointed_retrieval( retrieval_method=retrieval_method, field_type=field_type, checkpoint=checkpoint, start=start, end=end, ) def _convert_retrieved_files_to_documents( self, drive_files_iter: Iterator[RetrievedDriveFile], checkpoint: GoogleDriveCheckpoint, include_permissions: bool, ) -> Iterator[Document | ConnectorFailure | HierarchyNode]: """ Converts retrieved files to documents, yielding HierarchyNode objects for ancestor folders before the converted documents. 
""" permission_sync_context = ( PermissionSyncContext( primary_admin_email=self.primary_admin_email, google_domain=self.google_domain, ) if include_permissions else None ) files_batch: list[RetrievedDriveFile] = [] for retrieved_file in drive_files_iter: if self.exclude_domain_link_only and has_link_only_permission( retrieved_file.drive_file ): continue if retrieved_file.error is None: files_batch.append(retrieved_file) continue failure_stage = retrieved_file.completion_stage.value failure_message = f"retrieval failure during stage: {failure_stage}, " failure_message += f"user: {retrieved_file.user_email}, " failure_message += f"parent drive/folder: {retrieved_file.parent_id}, " failure_message += f"error: {retrieved_file.error}" logger.error(failure_message) yield ConnectorFailure( failed_entity=EntityFailure( entity_id=retrieved_file.drive_file.get("id", failure_stage), ), failure_message=failure_message, exception=retrieved_file.error, ) new_ancestors = self._get_new_ancestors_for_files( files=files_batch, seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids, fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids, permission_sync_context=permission_sync_context, add_prefix=True, ) if new_ancestors: logger.debug(f"Yielding {len(new_ancestors)} new hierarchy nodes") yield from new_ancestors func_with_args = [ ( self._convert_retrieved_file_to_document, (retrieved_file, permission_sync_context), ) for retrieved_file in files_batch ] raw_results = cast( list[Document | ConnectorFailure | None], run_functions_tuples_in_parallel(func_with_args, max_workers=8), ) results: list[Document | ConnectorFailure] = [ r for r in raw_results if r is not None ] logger.debug(f"batch has {len(results)} docs or failures") yield from results checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids def _convert_retrieved_file_to_document( self, retrieved_file: RetrievedDriveFile, permission_sync_context:
PermissionSyncContext | None, ) -> Document | ConnectorFailure | None: """ Converts a single retrieved file to a document. """ try: return convert_drive_item_to_document( self.creds, self.allow_images, self.size_threshold, permission_sync_context, [retrieved_file.user_email, self.primary_admin_email] + get_file_owners(retrieved_file.drive_file, self.primary_admin_email), retrieved_file.drive_file, ) except Exception as e: logger.exception( f"Error extracting document: " f"{retrieved_file.drive_file.get('name')} from Google Drive" ) return ConnectorFailure( failed_entity=EntityFailure( entity_id=retrieved_file.drive_file.get("id", "unknown"), ), failure_message=( f"Error extracting document: " f"{retrieved_file.drive_file.get('name')}" ), exception=e, ) def _load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GoogleDriveCheckpoint, include_permissions: bool, ) -> CheckpointOutput[GoogleDriveCheckpoint]: """ Entrypoint for the connector; first run is with an empty checkpoint. 
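A toy sketch of the checkpoint contract this method follows. It assumes a hypothetical `ToyCheckpoint` and fixed pages of file IDs in place of the real Drive stages. Each call works on a deep copy of the incoming checkpoint, yields IDs it has not seen, and returns the advanced checkpoint. The caller keeps calling until `has_more` is False.

```python
import copy
from collections.abc import Generator
from dataclasses import dataclass, field


@dataclass
class ToyCheckpoint:
    # Hypothetical stand-in for GoogleDriveCheckpoint
    stage: int = 0
    seen_ids: set[str] = field(default_factory=set)
    has_more: bool = True


# Hypothetical pages of file IDs; "b" shows up twice on purpose
PAGES = [["a", "b"], ["b", "c"], ["d"]]


def load_page(cp: ToyCheckpoint) -> Generator[str, None, ToyCheckpoint]:
    # Work on a copy so the caller's checkpoint survives a mid-page failure
    cp = copy.deepcopy(cp)
    for file_id in PAGES[cp.stage]:
        if file_id in cp.seen_ids:
            continue  # already yielded by an earlier call
        cp.seen_ids.add(file_id)
        yield file_id
    cp.stage += 1
    cp.has_more = cp.stage < len(PAGES)
    return cp


def run_all() -> list[str]:
    cp = ToyCheckpoint()  # first run starts from an empty checkpoint
    out: list[str] = []
    while cp.has_more:
        gen = load_page(cp)
        while True:
            try:
                out.append(next(gen))
            except StopIteration as stop:
                cp = stop.value  # the generator's return value
                break
    return out
```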
""" if self._creds is None or self._primary_admin_email is None: raise RuntimeError( "Credentials missing, should not call this method before calling load_credentials" ) logger.info( f"Loading from checkpoint with completion stage: {checkpoint.completion_stage}, " f"num retrieved ids: {len(checkpoint.all_retrieved_file_ids)}" ) checkpoint = copy.deepcopy(checkpoint) self._retrieved_folder_and_drive_ids = checkpoint.retrieved_folder_and_drive_ids try: field_type = ( DriveFileFieldType.WITH_PERMISSIONS if include_permissions or self.exclude_domain_link_only else DriveFileFieldType.STANDARD ) drive_files_iter = self._fetch_drive_items( field_type=field_type, checkpoint=checkpoint, start=start, end=end, ) yield from self._convert_retrieved_files_to_documents( drive_files_iter, checkpoint, include_permissions ) except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids logger.info( f"num drive files retrieved: {len(checkpoint.all_retrieved_file_ids)}" ) if checkpoint.completion_stage == DriveRetrievalStage.DONE: checkpoint.has_more = False return checkpoint @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GoogleDriveCheckpoint, ) -> CheckpointOutput[GoogleDriveCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=False ) @override def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: GoogleDriveCheckpoint, ) -> CheckpointOutput[GoogleDriveCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=True ) def _extract_slim_docs_from_google_drive( self, checkpoint: GoogleDriveCheckpoint, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) ->
GenerateSlimDocumentOutput: files_batch: list[RetrievedDriveFile] = [] slim_batch: list[SlimDocument | HierarchyNode] = [] def _yield_slim_batch() -> list[SlimDocument | HierarchyNode]: """Process files batch and return items to yield (hierarchy nodes + slim docs).""" nonlocal files_batch, slim_batch # Get new ancestor hierarchy nodes first permission_sync_context = PermissionSyncContext( primary_admin_email=self.primary_admin_email, google_domain=self.google_domain, ) new_ancestors = self._get_new_ancestors_for_files( files=files_batch, seen_hierarchy_node_raw_ids=checkpoint.seen_hierarchy_node_raw_ids, fully_walked_hierarchy_node_raw_ids=checkpoint.fully_walked_hierarchy_node_raw_ids, permission_sync_context=permission_sync_context, ) # Build slim documents for file in files_batch: if doc := build_slim_document( self.creds, file.drive_file, PermissionSyncContext( primary_admin_email=self.primary_admin_email, google_domain=self.google_domain, ), retriever_email=file.user_email, ): slim_batch.append(doc) # Combine: hierarchy nodes first, then slim docs result: list[SlimDocument | HierarchyNode] = [] result.extend(new_ancestors) result.extend(slim_batch) files_batch = [] slim_batch = [] return result for file in self._fetch_drive_items( field_type=DriveFileFieldType.SLIM, checkpoint=checkpoint, start=start, end=end, ): if file.error is not None: raise file.error if self.exclude_domain_link_only and has_link_only_permission( file.drive_file ): continue files_batch.append(file) if len(files_batch) >= SLIM_BATCH_SIZE: yield _yield_slim_batch() if callback: if callback.should_stop(): raise RuntimeError( "_extract_slim_docs_from_google_drive: Stop signal detected" ) callback.progress("_extract_slim_docs_from_google_drive", 1) # Yield remaining files if files_batch: yield _yield_slim_batch() def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) 
-> GenerateSlimDocumentOutput: try: checkpoint = self.build_dummy_checkpoint() while checkpoint.completion_stage != DriveRetrievalStage.DONE: yield from self._extract_slim_docs_from_google_drive( checkpoint=checkpoint, start=start, end=end, callback=callback, ) logger.info("Drive perm sync: Slim doc retrieval complete") except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e def validate_connector_settings(self) -> None: if self._creds is None: raise ConnectorMissingCredentialError( "Google Drive credentials not loaded." ) if self._primary_admin_email is None: raise ConnectorValidationError( "Primary admin email not found in credentials. Ensure DB_CREDENTIALS_PRIMARY_ADMIN_KEY is set." ) try: drive_service = get_drive_service(self._creds, self._primary_admin_email) drive_service.files().list(pageSize=1, fields="files(id)").execute() if isinstance(self._creds, ServiceAccountCredentials): # default is ~17mins of retries, don't do that here since this is called from # the UI retry_builder(tries=3, delay=0.1)(get_root_folder_id)(drive_service) except HttpError as e: status_code = e.resp.status if e.resp else None if status_code == 401: raise CredentialExpiredError( "Invalid or expired Google Drive credentials (401)." ) elif status_code == 403: raise InsufficientPermissionsError( "Google Drive app lacks required permissions (403). " "Please ensure the necessary scopes are granted and Drive " "apps are enabled." ) else: raise ConnectorValidationError( f"Unexpected Google Drive error (status={status_code}): {e}" ) except Exception as e: # Check for scope-related hints from the error message if MISSING_SCOPES_ERROR_STR in str(e): raise InsufficientPermissionsError( f"Google Drive credentials are missing required scopes. 
{ONYX_SCOPE_INSTRUCTIONS}" ) raise ConnectorValidationError( f"Unexpected error during Google Drive validation: {e}" ) @override def build_dummy_checkpoint(self) -> GoogleDriveCheckpoint: return GoogleDriveCheckpoint( retrieved_folder_and_drive_ids=set(), completion_stage=DriveRetrievalStage.START, completion_map=ThreadSafeDict(), all_retrieved_file_ids=set(), has_more=True, ) @override def validate_checkpoint_json(self, checkpoint_json: str) -> GoogleDriveCheckpoint: return GoogleDriveCheckpoint.model_validate_json(checkpoint_json) def get_credentials_from_env(email: str, oauth: bool) -> dict: if oauth: raw_credential_string = os.environ["GOOGLE_DRIVE_OAUTH_CREDENTIALS_JSON_STR"] else: raw_credential_string = os.environ["GOOGLE_DRIVE_SERVICE_ACCOUNT_JSON_STR"] refried_credential_string = json.dumps(json.loads(raw_credential_string)) # This is the Oauth token DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens" # This is the service account key DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" # The email saved for both auth types DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" cred_key = ( DB_CREDENTIALS_DICT_TOKEN_KEY if oauth else DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY ) return { cred_key: refried_credential_string, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, DB_CREDENTIALS_AUTHENTICATION_METHOD: "uploaded", } class CheckpointOutputWrapper: """ Wraps a CheckpointOutput generator to give things back in a more digestible format. The connector format is easier for the connector implementor (e.g. it enforces exactly one new checkpoint is returned AND that the checkpoint is at the end), thus the different formats. 
""" def __init__(self) -> None: self.next_checkpoint: GoogleDriveCheckpoint | None = None def __call__( self, checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint], ) -> Generator[ tuple[Document | None, ConnectorFailure | None, GoogleDriveCheckpoint | None], None, None, ]: # grabs the final return value and stores it in the `next_checkpoint` variable def _inner_wrapper( checkpoint_connector_generator: CheckpointOutput[GoogleDriveCheckpoint], ) -> CheckpointOutput[GoogleDriveCheckpoint]: self.next_checkpoint = yield from checkpoint_connector_generator return self.next_checkpoint # not used for document_or_failure in _inner_wrapper(checkpoint_connector_generator): if isinstance(document_or_failure, Document): yield document_or_failure, None, None elif isinstance(document_or_failure, ConnectorFailure): yield None, document_or_failure, None else: raise ValueError( f"Invalid document_or_failure type: {type(document_or_failure)}" ) if self.next_checkpoint is None: raise RuntimeError( "Checkpoint is None. This should never happen - the connector should always return a checkpoint." ) yield None, None, self.next_checkpoint def yield_all_docs_from_checkpoint_connector( connector: GoogleDriveConnector, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, ) -> Iterator[Document | ConnectorFailure]: num_iterations = 0 checkpoint = connector.build_dummy_checkpoint() while checkpoint.has_more: doc_batch_generator = CheckpointOutputWrapper()( connector.load_from_checkpoint(start, end, checkpoint) ) for document, failure, next_checkpoint in doc_batch_generator: if failure is not None: yield failure if document is not None: yield document if next_checkpoint is not None: checkpoint = next_checkpoint num_iterations += 1 if num_iterations > 100_000: raise RuntimeError("Too many iterations. 
Infinite loop?") if __name__ == "__main__": import time creds = get_credentials_from_env( os.environ["GOOGLE_DRIVE_PRIMARY_ADMIN_EMAIL"], False ) connector = GoogleDriveConnector( include_shared_drives=True, shared_drive_urls=None, include_my_drives=True, my_drive_emails=None, shared_folder_urls=None, include_files_shared_with_me=True, specific_user_emails=None, ) connector.load_credentials(creds) max_fsize = 0 biggest_fsize = 0 num_errors = 0 start_time = time.time() with open("stats.txt", "w") as f: for num, doc_or_failure in enumerate( yield_all_docs_from_checkpoint_connector(connector, 0, time.time()) ): if num % 200 == 0: f.write(f"Processed {num} files\n") f.write(f"Max file size: {max_fsize / 1_000_000:.2f} MB\n") f.write(f"Time so far: {time.time() - start_time:.2f} seconds\n") f.write( f"Docs per minute: {num / (time.time() - start_time) * 60:.2f}\n" ) biggest_fsize = max(biggest_fsize, max_fsize) max_fsize = 0 if isinstance(doc_or_failure, Document): max_fsize = max(max_fsize, sys.getsizeof(doc_or_failure)) elif isinstance(doc_or_failure, ConnectorFailure): num_errors += 1 biggest_fsize = max(biggest_fsize, max_fsize) print(f"Num errors: {num_errors}") print(f"Biggest file size: {biggest_fsize / 1_000_000:.2f} MB") print(f"Time taken: {time.time() - start_time:.2f} seconds") ================================================ FILE: backend/onyx/connectors/google_drive/constants.py ================================================ UNSUPPORTED_FILE_TYPE_CONTENT = "" # keep empty for now DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder" DRIVE_SHORTCUT_TYPE = "application/vnd.google-apps.shortcut" DRIVE_FILE_TYPE = "application/vnd.google-apps.file" ================================================ FILE: backend/onyx/connectors/google_drive/doc_conversion.py ================================================ import io from collections.abc import Callable from datetime import datetime from typing import Any from typing import cast from urllib.parse import urlparse from urllib.parse import urlunparse from
googleapiclient.errors import HttpError # type: ignore from googleapiclient.http import MediaIoBaseDownload # type: ignore from pydantic import BaseModel from onyx.access.models import ExternalAccess from onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE from onyx.connectors.google_drive.models import GDriveMimeType from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.connectors.google_drive.section_extraction import get_document_sections from onyx.connectors.google_drive.section_extraction import HEADING_DELIMITER from onyx.connectors.google_utils.resources import get_drive_service from onyx.connectors.google_utils.resources import get_google_docs_service from onyx.connectors.google_utils.resources import GoogleDocsService from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import ImageSection from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.extract_file_text import pptx_to_text from onyx.file_processing.extract_file_text import read_docx_file from onyx.file_processing.extract_file_text import read_pdf_file from onyx.file_processing.extract_file_text import xlsx_to_text from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.file_types import OnyxMimeTypes from onyx.file_processing.image_utils import store_image_and_create_section from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import ( 
fetch_versioned_implementation_with_fallback, ) from onyx.utils.variable_functionality import noop_fallback logger = setup_logger() # Cache for folder path lookups to avoid redundant API calls # Maps folder_id -> (folder_name, parent_id) _folder_cache: dict[str, tuple[str, str | None]] = {} def _get_folder_info( service: GoogleDriveService, folder_id: str ) -> tuple[str, str | None]: """Fetch folder name and parent ID, with caching.""" if folder_id in _folder_cache: return _folder_cache[folder_id] try: folder = ( service.files() .get( fileId=folder_id, fields="name, parents", supportsAllDrives=True, ) .execute() ) folder_name = folder.get("name", "Unknown") parents = folder.get("parents", []) parent_id = parents[0] if parents else None _folder_cache[folder_id] = (folder_name, parent_id) return folder_name, parent_id except HttpError as e: logger.warning(f"Failed to get folder info for {folder_id}: {e}") _folder_cache[folder_id] = ("Unknown", None) return "Unknown", None def _get_drive_name(service: GoogleDriveService, drive_id: str) -> str: """Fetch shared drive name.""" cache_key = f"drive_{drive_id}" if cache_key in _folder_cache: return _folder_cache[cache_key][0] try: drive = service.drives().get(driveId=drive_id).execute() drive_name = drive.get("name", f"Shared Drive {drive_id}") _folder_cache[cache_key] = (drive_name, None) return drive_name except HttpError as e: logger.warning(f"Failed to get drive name for {drive_id}: {e}") _folder_cache[cache_key] = (f"Shared Drive {drive_id}", None) return f"Shared Drive {drive_id}" def build_folder_path( file: GoogleDriveFileType, service: GoogleDriveService, drive_id: str | None = None, user_email: str | None = None, ) -> list[str]: """ Build the full folder path for a file by walking up the parent chain. Returns a list of folder names from root to immediate parent. 
Args: file: The Google Drive file object service: Google Drive service instance drive_id: Optional drive ID (will be extracted from file if not provided) user_email: Optional user email to check ownership for "My Drive" vs "Shared with me" """ path_parts: list[str] = [] # Get drive_id from file if not provided if drive_id is None: drive_id = file.get("driveId") # Check if file is owned by the user (for distinguishing "My Drive" vs "Shared with me") is_owned_by_user = False if user_email: owners = file.get("owners", []) is_owned_by_user = any( owner.get("emailAddress", "").lower() == user_email.lower() for owner in owners ) # Get the file's parent folder ID parents = file.get("parents", []) if not parents: # File is at root level if drive_id: return [_get_drive_name(service, drive_id)] # If not in a shared drive, check if it's owned by the user if is_owned_by_user: return ["My Drive"] else: return ["Shared with me"] parent_id: str | None = parents[0] # Walk up the folder hierarchy (limit to 50 levels to prevent infinite loops) visited: set[str] = set() for _ in range(50): if not parent_id or parent_id in visited: break visited.add(parent_id) folder_name, next_parent = _get_folder_info(service, parent_id) # Check if we've reached the root (parent is the drive itself or no parent) if next_parent is None: # This folder's name is either the drive root, My Drive, or Shared with me if drive_id: path_parts.insert(0, _get_drive_name(service, drive_id)) else: # Not in a shared drive - determine if it's "My Drive" or "Shared with me" if is_owned_by_user: path_parts.insert(0, "My Drive") else: path_parts.insert(0, "Shared with me") break else: path_parts.insert(0, folder_name) parent_id = next_parent # If we didn't find a root, determine the root based on ownership and drive if not path_parts: if drive_id: return [_get_drive_name(service, drive_id)] elif is_owned_by_user: return ["My Drive"] else: return ["Shared with me"] return path_parts # This is not a standard valid 
unicode char, it is used by the docs advanced API to # represent smart chips (elements like dates and doc links). SMART_CHIP_CHAR = "\ue907" WEB_VIEW_LINK_KEY = "webViewLink" # Fallback templates for generating web links when Drive omits webViewLink. _FALLBACK_WEB_VIEW_LINK_TEMPLATES = { GDriveMimeType.DOC.value: "https://docs.google.com/document/d/{}/view", GDriveMimeType.SPREADSHEET.value: "https://docs.google.com/spreadsheets/d/{}/view", GDriveMimeType.PPT.value: "https://docs.google.com/presentation/d/{}/view", } MAX_RETRIEVER_EMAILS = 20 CHUNK_SIZE_BUFFER = 64 # extra bytes past the limit to read # Mapping of Google Drive mime types to export formats GOOGLE_MIME_TYPES_TO_EXPORT = { GDriveMimeType.DOC.value: "text/plain", GDriveMimeType.SPREADSHEET.value: "text/csv", GDriveMimeType.PPT.value: "text/plain", } # Define Google MIME types mapping GOOGLE_MIME_TYPES = { GDriveMimeType.DOC.value: "text/plain", GDriveMimeType.SPREADSHEET.value: "text/csv", GDriveMimeType.PPT.value: "text/plain", } class PermissionSyncContext(BaseModel): """ This is the information that is needed to sync permissions for a document. """ primary_admin_email: str google_domain: str def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str: link = file.get(WEB_VIEW_LINK_KEY) if not link: file_id = file.get("id") if not file_id: raise KeyError( f"Google Drive file missing both '{WEB_VIEW_LINK_KEY}' and 'id' fields." ) mime_type = file.get("mimeType", "") template = _FALLBACK_WEB_VIEW_LINK_TEMPLATES.get(mime_type) if template is None: link = f"https://drive.google.com/file/d/{file_id}/view" else: link = template.format(file_id) logger.debug( "Missing webViewLink for Google Drive file with id %s. 
Falling back to constructed link %s", file_id, link, ) parsed_url = urlparse(link) parsed_url = parsed_url._replace(query="") # remove query parameters spl_path = parsed_url.path.split("/") if spl_path and (spl_path[-1] in ["edit", "view", "preview"]): spl_path.pop() parsed_url = parsed_url._replace(path="/".join(spl_path)) # Reconstruct the URL without the query string or trailing edit/view/preview segment return urlunparse(parsed_url) def download_request( service: GoogleDriveService, file_id: str, size_threshold: int ) -> bytes: """ Download the file from Google Drive. """ # Download the raw file bytes (for files that are not Google-native Docs/Sheets/Slides) request = service.files().get_media(fileId=file_id) return _download_request(request, file_id, size_threshold) _DOWNLOAD_NUM_RETRIES = 3 def _download_request(request: Any, file_id: str, size_threshold: int) -> bytes: response_bytes = io.BytesIO() downloader = MediaIoBaseDownload( response_bytes, request, chunksize=size_threshold + CHUNK_SIZE_BUFFER ) done = False while not done: # num_retries enables automatic retry with exponential backoff for transient errors download_progress, done = downloader.next_chunk( num_retries=_DOWNLOAD_NUM_RETRIES ) if download_progress.resumable_progress > size_threshold: logger.warning( f"File {file_id} exceeds size threshold of {size_threshold}. Skipping."
) return bytes() response = response_bytes.getvalue() if not response: logger.warning(f"Failed to download {file_id}") return bytes() return response def _download_and_extract_sections_basic( file: dict[str, str], service: GoogleDriveService, allow_images: bool, size_threshold: int, ) -> list[TextSection | ImageSection]: """Extract text and images from a Google Drive file.""" file_id = file["id"] file_name = file["name"] mime_type = file["mimeType"] link = file.get(WEB_VIEW_LINK_KEY, "") # For non-Google files, download the file # Use the correct API call for downloading files # lazy evaluation to only download the file if necessary def response_call() -> bytes: return download_request(service, file_id, size_threshold) if mime_type in OnyxMimeTypes.IMAGE_MIME_TYPES: # Skip images if not explicitly enabled if not allow_images: return [] # Store images for later processing sections: list[TextSection | ImageSection] = [] try: section, embedded_id = store_image_and_create_section( image_data=response_call(), file_id=file_id, display_name=file_name, media_type=mime_type, file_origin=FileOrigin.CONNECTOR, link=link, ) sections.append(section) except Exception as e: logger.error(f"Failed to process image {file_name}: {e}") return sections # For Google Docs, Sheets, and Slides, export as plain text if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT: export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type] # Use the correct API call for exporting files request = service.files().export_media( fileId=file_id, mimeType=export_mime_type ) response = _download_request(request, file_id, size_threshold) if not response: logger.warning(f"Failed to export {file_name} as {export_mime_type}") return [] text = response.decode("utf-8") return [TextSection(link=link, text=text)] # Process based on mime type if mime_type == "text/plain": try: text = response_call().decode("utf-8") return [TextSection(link=link, text=text)] except UnicodeDecodeError as e: logger.warning(f"Failed to extract text 
from {file_name}: {e}") return [] elif ( mime_type == "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ): text, _ = read_docx_file(io.BytesIO(response_call())) return [TextSection(link=link, text=text)] elif ( mime_type == "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" ): text = xlsx_to_text(io.BytesIO(response_call()), file_name=file_name) return [TextSection(link=link, text=text)] if text else [] elif ( mime_type == "application/vnd.openxmlformats-officedocument.presentationml.presentation" ): text = pptx_to_text(io.BytesIO(response_call()), file_name=file_name) return [TextSection(link=link, text=text)] if text else [] elif mime_type == "application/pdf": text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_call())) pdf_sections: list[TextSection | ImageSection] = [ TextSection(link=link, text=text) ] # Process embedded images in the PDF try: for idx, (img_data, img_name) in enumerate(images): section, embedded_id = store_image_and_create_section( image_data=img_data, file_id=f"{file_id}_img_{idx}", display_name=img_name or f"{file_name} - image {idx}", file_origin=FileOrigin.CONNECTOR, ) pdf_sections.append(section) except Exception as e: logger.error(f"Failed to process PDF images in {file_name}: {e}") return pdf_sections # Final attempt at extracting text file_ext = get_file_ext(file.get("name", "")) if file_ext not in OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS: logger.warning(f"Skipping file {file.get('name')} due to extension.") return [] try: text = extract_file_text(io.BytesIO(response_call()), file_name) return [TextSection(link=link, text=text)] except Exception as e: logger.warning(f"Failed to extract text from {file_name}: {e}") return [] def _find_nth(haystack: str, needle: str, n: int, start: int = 0) -> int: start = haystack.find(needle, start) while start >= 0 and n > 1: start = haystack.find(needle, start + len(needle)) n -= 1 return start def align_basic_advanced( basic_sections: 
list[TextSection | ImageSection], adv_sections: list[TextSection] ) -> list[TextSection | ImageSection]: """Align the basic sections with the advanced sections. In particular, the basic sections contain all content of the file, including smart chips like dates and doc links. The advanced sections are separated by section headers and contain header-based links that improve the user experience when users click on the source in the UI. There are edge cases in text matching (e.g. the heading is a smart chip, or a smart chip in the doc contains the actual heading text) that make the matching imperfect; this is hence done on a best-effort basis. """ if len(adv_sections) <= 1: return basic_sections # no benefit from aligning basic_full_text = "".join( [section.text for section in basic_sections if isinstance(section, TextSection)] ) new_sections: list[TextSection | ImageSection] = [] heading_start = 0 for adv_ind in range(1, len(adv_sections)): heading = adv_sections[adv_ind].text.split(HEADING_DELIMITER)[0] # retrieve the longest part of the heading that is not a smart chip heading_key = max(heading.split(SMART_CHIP_CHAR), key=len).strip() if heading_key == "": logger.warning( f"Cannot match heading: {heading}, its link will come from the following section" ) continue heading_offset = heading.find(heading_key) # count occurrences of heading str in previous section heading_count = adv_sections[adv_ind - 1].text.count(heading_key) prev_start = heading_start heading_start = ( _find_nth(basic_full_text, heading_key, heading_count, start=prev_start) - heading_offset ) if heading_start < 0: logger.warning( f"Heading key {heading_key} from heading {heading} not found in basic text" ) heading_start = prev_start continue new_sections.append( TextSection( link=adv_sections[adv_ind - 1].link, text=basic_full_text[prev_start:heading_start], ) ) # handle last section new_sections.append( TextSection(link=adv_sections[-1].link, text=basic_full_text[heading_start:]) )
return new_sections def _get_external_access_for_raw_gdrive_file( file: GoogleDriveFileType, company_domain: str, retriever_drive_service: GoogleDriveService | None, admin_drive_service: GoogleDriveService, fallback_user_email: str, add_prefix: bool = False, ) -> ExternalAccess: """ Get the external access for a raw Google Drive file. add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path where upsert_document_external_perms handles prefixing). fallback_user_email: When permission info can't be retrieved (e.g. externally-owned files), fall back to granting access to this user. """ external_access_fn = cast( Callable[ [ GoogleDriveFileType, str, GoogleDriveService | None, GoogleDriveService, str, bool, ], ExternalAccess, ], fetch_versioned_implementation_with_fallback( "onyx.external_permissions.google_drive.doc_sync", "get_external_access_for_raw_gdrive_file", fallback=noop_fallback, ), ) return external_access_fn( file, company_domain, retriever_drive_service, admin_drive_service, fallback_user_email, add_prefix, ) def convert_drive_item_to_document( creds: Any, allow_images: bool, size_threshold: int, # if not specified, we will not sync permissions # will also be a no-op if EE is not enabled permission_sync_context: PermissionSyncContext | None, retriever_emails: list[str], file: GoogleDriveFileType, ) -> Document | ConnectorFailure | None: """ Attempt to convert a drive item to a document with each retriever email in order (deduplicated, at most MAX_RETRIEVER_EMAILS). Returns as soon as a retrieval succeeds, the file is skipped (None), or an error other than HTTP 401/403/404 occurs; those three statuses fall through to the next retriever email. We used to always get the user email from the file owners when available, but this caused issues with shared folders where the owner was not included in the service account. Now we use the email of the account that successfully listed the file. Because a user that can list a file cannot always download it, we retry with the file owners and the admin email.
""" first_error = None doc_or_failure = None retriever_emails = retriever_emails[:MAX_RETRIEVER_EMAILS] # use seen instead of list(set()) to avoid re-ordering the retriever emails seen = set() for retriever_email in retriever_emails: if retriever_email in seen: continue seen.add(retriever_email) doc_or_failure = _convert_drive_item_to_document( creds, allow_images, size_threshold, retriever_email, file, permission_sync_context, ) # There are a variety of permissions-based errors that occasionally occur # when retrieving files. Often when these occur, there is another user # that can successfully retrieve the file, so we try the next user. if ( doc_or_failure is None or isinstance(doc_or_failure, Document) or not ( isinstance(doc_or_failure.exception, HttpError) and doc_or_failure.exception.status_code in [401, 403, 404] ) ): return doc_or_failure if first_error is None: first_error = doc_or_failure else: first_error.failure_message += f"\n\n{doc_or_failure.failure_message}" if ( first_error and isinstance(first_error.exception, HttpError) and first_error.exception.status_code == 403 ): # This SHOULD happen very rarely, and we don't want to break the indexing process when # a high volume of 403s occurs early. We leave a verbose log to help investigate. logger.error( f"Skipping file id: {file.get('id')} name: {file.get('name')} due to 403 error." f"Attempted to retrieve with {retriever_emails}," f"got the following errors: {first_error.failure_message}" ) return None return first_error def _convert_drive_item_to_document( creds: Any, allow_images: bool, size_threshold: int, retriever_email: str, file: GoogleDriveFileType, # if not specified, we will not sync permissions # will also be a no-op if EE is not enabled permission_sync_context: PermissionSyncContext | None, ) -> Document | ConnectorFailure | None: """ Main entry point for converting a Google Drive file => Document object. 
""" sections: list[TextSection | ImageSection] = [] # Only construct these services when needed def _get_drive_service() -> GoogleDriveService: return get_drive_service(creds, user_email=retriever_email) def _get_docs_service() -> GoogleDocsService: return get_google_docs_service(creds, user_email=retriever_email) doc_id = "unknown" try: # skip shortcuts or folders if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]: logger.info("Skipping shortcut/folder.") return None size_str = file.get("size") if size_str: try: size_int = int(size_str) except ValueError: logger.warning(f"Parsing string to int failed: size_str={size_str}") else: if size_int > size_threshold: logger.warning( f"{file.get('name')} exceeds size threshold of {size_threshold}. Skipping." ) return None # If it's a Google Doc, we might do advanced parsing if file.get("mimeType") == GDriveMimeType.DOC.value: try: logger.debug(f"starting advanced parsing for {file.get('name')}") # get_document_sections is the advanced approach for Google Docs doc_sections = get_document_sections( docs_service=_get_docs_service(), doc_id=file.get("id", ""), ) if doc_sections: sections = cast(list[TextSection | ImageSection], doc_sections) if any(SMART_CHIP_CHAR in section.text for section in doc_sections): logger.debug( f"found smart chips in {file.get('name')}, aligning with basic sections" ) basic_sections = _download_and_extract_sections_basic( file, _get_drive_service(), allow_images, size_threshold ) sections = align_basic_advanced(basic_sections, doc_sections) except Exception as e: logger.warning( f"Error in advanced parsing: {e}. Falling back to basic extraction." ) # Not Google Doc, attempt basic extraction else: sections = _download_and_extract_sections_basic( file, _get_drive_service(), allow_images, size_threshold ) # If we still don't have any sections, skip this file if not sections: logger.warning(f"No content extracted from {file.get('name')}. 
Skipping.") return None doc_id = onyx_document_id_from_drive_file(file) external_access = ( _get_external_access_for_raw_gdrive_file( file=file, company_domain=permission_sync_context.google_domain, # try both retriever_email and primary_admin_email if necessary retriever_drive_service=_get_drive_service(), admin_drive_service=get_drive_service( creds, user_email=permission_sync_context.primary_admin_email ), add_prefix=True, # Indexing path - prefix here fallback_user_email=retriever_email, ) if permission_sync_context else None ) # Build doc_metadata with hierarchy information file_name = file.get("name", "") mime_type = file.get("mimeType", "") drive_id = file.get("driveId") # Build full folder path by walking up the parent chain # Pass retriever_email to determine if file is in "My Drive" vs "Shared with me" source_path = build_folder_path( file, _get_drive_service(), drive_id, retriever_email ) doc_metadata = { "hierarchy": { "source_path": source_path, "drive_id": drive_id, "file_name": file_name, "mime_type": mime_type, } } # Create the document return Document( id=doc_id, sections=sections, source=DocumentSource.GOOGLE_DRIVE, semantic_identifier=file_name, doc_metadata=doc_metadata, metadata={ "owner_names": ", ".join( owner.get("displayName", "") for owner in file.get("owners", []) ), }, doc_updated_at=datetime.fromisoformat( file.get("modifiedTime", "").replace("Z", "+00:00") ), external_access=external_access, parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0], ) except Exception as e: doc_id = "unknown" try: doc_id = onyx_document_id_from_drive_file(file) except Exception as e2: logger.warning(f"Error getting document id from file: {e2}") file_name = file.get("name") error_str = ( f"Error converting file '{file_name}' to Document as {retriever_email}: {e}" ) if isinstance(e, HttpError) and e.status_code == 403: logger.warning( f"Uncommon permissions error while downloading file. 
User " f"{retriever_email} was able to see file {file_name} " "but cannot download it." ) logger.warning(error_str) return ConnectorFailure( failed_document=DocumentFailure( document_id=doc_id, document_link=( sections[0].link if sections else None ), # TODO: see if this is the best way to get a link ), failed_entity=None, failure_message=error_str, exception=e, ) def build_slim_document( creds: Any, file: GoogleDriveFileType, # if not specified, we will not sync permissions # will also be a no-op if EE is not enabled permission_sync_context: PermissionSyncContext | None, retriever_email: str, ) -> SlimDocument | None: if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]: return None owner_email = cast(str | None, file.get("owners", [{}])[0].get("emailAddress")) external_access = ( _get_external_access_for_raw_gdrive_file( file=file, company_domain=permission_sync_context.google_domain, retriever_drive_service=( get_drive_service( creds, user_email=owner_email, ) if owner_email else None ), admin_drive_service=get_drive_service( creds, user_email=permission_sync_context.primary_admin_email, ), fallback_user_email=retriever_email, ) if permission_sync_context else None ) return SlimDocument( id=onyx_document_id_from_drive_file(file), external_access=external_access, parent_hierarchy_raw_node_id=(file.get("parents") or [None])[0], ) ================================================ FILE: backend/onyx/connectors/google_drive/file_retrieval.py ================================================ from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from datetime import timezone from enum import Enum from typing import cast from urllib.parse import parse_qs from urllib.parse import urlparse from googleapiclient.discovery import Resource # type: ignore from googleapiclient.errors import HttpError # type: ignore from onyx.access.models import ExternalAccess from onyx.connectors.google_drive.constants import 
DRIVE_FOLDER_TYPE from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE from onyx.connectors.google_drive.models import DriveRetrievalStage from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.connectors.google_drive.models import RetrievedDriveFile from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval from onyx.connectors.google_utils.google_utils import ( execute_paginated_retrieval_with_max_pages, ) from onyx.connectors.google_utils.google_utils import GoogleFields from onyx.connectors.google_utils.google_utils import ORDER_BY_KEY from onyx.connectors.google_utils.google_utils import PAGE_TOKEN_KEY from onyx.connectors.google_utils.resources import GoogleDriveService from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) from onyx.utils.variable_functionality import noop_fallback logger = setup_logger() class DriveFileFieldType(Enum): """Enum to specify which fields to retrieve from Google Drive files""" SLIM = "slim" # Minimal fields for basic file info STANDARD = "standard" # Standard fields including content metadata WITH_PERMISSIONS = "with_permissions" # Full fields including permissions PERMISSION_FULL_DESCRIPTION = ( "permissions(id, emailAddress, type, domain, allowFileDiscovery, permissionDetails)" ) FILE_FIELDS = ( "nextPageToken, files(mimeType, id, name, driveId, parents, " "modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)" ) FILE_FIELDS_WITH_PERMISSIONS = ( f"nextPageToken, files(mimeType, id, name, driveId, parents, {PERMISSION_FULL_DESCRIPTION}, permissionIds, " "modifiedTime, webViewLink, shortcutDetails, owners(emailAddress), size)" ) SLIM_FILE_FIELDS = ( f"nextPageToken, files(mimeType, driveId, id, name, parents, {PERMISSION_FULL_DESCRIPTION}, " "permissionIds, webViewLink, owners(emailAddress), modifiedTime)" 
) FOLDER_FIELDS = "nextPageToken, files(id, name, permissions, modifiedTime, webViewLink, shortcutDetails)" HIERARCHY_FIELDS = "id, name, parents, webViewLink, mimeType, driveId" HIERARCHY_FIELDS_WITH_PERMISSIONS = ( "id, name, parents, webViewLink, mimeType, permissionIds, driveId" ) def generate_time_range_filter( start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> str: time_range_filter = "" if start is not None: time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat() time_range_filter += ( f" and {GoogleFields.MODIFIED_TIME.value} >= '{time_start}'" ) if end is not None: time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat() time_range_filter += f" and {GoogleFields.MODIFIED_TIME.value} <= '{time_stop}'" return time_range_filter LINK_ONLY_PERMISSION_TYPES = {"domain", "anyone"} def has_link_only_permission(file: GoogleDriveFileType) -> bool: """ Return True if any permission requires a direct link to access (allowFileDiscovery is explicitly false for supported types). 
""" permissions = file.get("permissions") or [] for permission in permissions: if permission.get("type") not in LINK_ONLY_PERMISSION_TYPES: continue if permission.get("allowFileDiscovery") is False: return True return False def _get_folders_in_parent( service: Resource, parent_id: str | None = None, ) -> Iterator[GoogleDriveFileType]: # Follow shortcuts to folders query = f"(mimeType = '{DRIVE_FOLDER_TYPE}' or mimeType = '{DRIVE_SHORTCUT_TYPE}')" query += " and trashed = false" if parent_id: query += f" and '{parent_id}' in parents" for file in execute_paginated_retrieval( retrieval_function=service.files().list, list_key="files", continue_on_404_or_403=True, corpora="allDrives", supportsAllDrives=True, includeItemsFromAllDrives=True, fields=FOLDER_FIELDS, q=query, ): yield file def get_folder_metadata( service: Resource, folder_id: str, field_type: DriveFileFieldType, ) -> GoogleDriveFileType | None: """Fetch metadata for a folder by ID.""" fields = _get_hierarchy_fields_for_file_type(field_type) try: return ( service.files() .get( fileId=folder_id, fields=fields, supportsAllDrives=True, ) .execute() ) except HttpError as e: if e.resp.status in (403, 404): logger.debug(f"Cannot access folder {folder_id}: {e}") else: raise e return None def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str: if field_type == DriveFileFieldType.WITH_PERMISSIONS: return HIERARCHY_FIELDS_WITH_PERMISSIONS else: return HIERARCHY_FIELDS def get_shared_drive_name( service: Resource, drive_id: str, ) -> str | None: """Fetch the actual name of a shared drive via the drives().get() API. The files().get() API returns 'Drive' as the name for shared drive root folders. Only drives().get() returns the real user-assigned name. 
""" try: drive = service.drives().get(driveId=drive_id, fields="name").execute() return drive.get("name") except HttpError as e: if e.resp.status in (403, 404): logger.debug(f"Cannot access drive {drive_id}: {e}") else: raise return None def get_external_access_for_folder( folder: GoogleDriveFileType, google_domain: str, drive_service: GoogleDriveService, add_prefix: bool = False, ) -> ExternalAccess: """ Extract ExternalAccess from a folder's permissions. This fetches permissions using the Drive API (via permissionIds) and extracts user emails, group emails, and public access status. Uses the EE implementation if available, otherwise returns public access (fallback for non-EE deployments). Args: folder: The folder metadata from Google Drive API (must include permissionIds field) google_domain: The company's Google Workspace domain (e.g., "company.com") drive_service: Google Drive service for fetching permission details add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path where upsert_document_external_perms handles prefixing). 
    Returns:
        ExternalAccess with extracted permission info
    """
    # Try to get the EE implementation
    get_folder_access_fn = cast(
        Callable[[GoogleDriveFileType, str, GoogleDriveService, bool], ExternalAccess],
        fetch_versioned_implementation_with_fallback(
            "onyx.external_permissions.google_drive.doc_sync",
            "get_external_access_for_folder",
            noop_fallback,
        ),
    )
    return get_folder_access_fn(folder, google_domain, drive_service, add_prefix)


def _get_fields_for_file_type(field_type: DriveFileFieldType) -> str:
    """Get the appropriate fields string based on the field type enum"""
    if field_type == DriveFileFieldType.SLIM:
        return SLIM_FILE_FIELDS
    elif field_type == DriveFileFieldType.WITH_PERMISSIONS:
        return FILE_FIELDS_WITH_PERMISSIONS
    else:  # DriveFileFieldType.STANDARD
        return FILE_FIELDS


def _get_files_in_parent(
    service: Resource,
    parent_id: str,
    field_type: DriveFileFieldType,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[GoogleDriveFileType]:
    query = f"mimeType != '{DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
    query += " and trashed = false"
    query += generate_time_range_filter(start, end)

    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}

    for file in execute_paginated_retrieval(
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="allDrives",
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=_get_fields_for_file_type(field_type),
        q=query,
        **kwargs,
    ):
        yield file


def crawl_folders_for_files(
    service: Resource,
    parent_id: str,
    field_type: DriveFileFieldType,
    user_email: str,
    traversed_parent_ids: set[str],
    update_traversed_ids_func: Callable[[str], None],
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
) -> Iterator[RetrievedDriveFile]:
    """
    This function starts crawling from any folder. It is slower than listing a
    whole drive at once, though, since it issues separate files().list queries
    for every folder it visits.
    """
    logger.info("Entered crawl_folders_for_files with parent_id: " + parent_id)
    if parent_id not in traversed_parent_ids:
        logger.info("Parent id not in traversed parent ids, getting files")
        found_files = False
        file = {}
        try:
            for file in _get_files_in_parent(
                service=service,
                parent_id=parent_id,
                field_type=field_type,
                start=start,
                end=end,
            ):
                logger.info(f"Found file: {file['name']}, user email: {user_email}")
                found_files = True
                yield RetrievedDriveFile(
                    drive_file=file,
                    user_email=user_email,
                    parent_id=parent_id,
                    completion_stage=DriveRetrievalStage.FOLDER_FILES,
                )
            # Only mark a folder as done if it was fully traversed without errors
            # This usually indicates that the owner of the folder was impersonated.
            # In cases where this never happens, most likely the folder owner is
            # not part of the google workspace in question (or for oauth, the authenticated
            # user doesn't own the folder)
            if found_files:
                update_traversed_ids_func(parent_id)
        except Exception as e:
            if isinstance(e, HttpError) and e.status_code == 403:
                # don't yield an error here because this is expected behavior
                # when a user doesn't have access to a folder
                logger.debug(f"Error getting files in parent {parent_id}: {e}")
            else:
                logger.error(f"Error getting files in parent {parent_id}: {e}")
                yield RetrievedDriveFile(
                    drive_file=file,
                    user_email=user_email,
                    parent_id=parent_id,
                    completion_stage=DriveRetrievalStage.FOLDER_FILES,
                    error=e,
                )
    else:
        logger.info(f"Skipping subfolder files since already traversed: {parent_id}")

    for subfolder in _get_folders_in_parent(
        service=service,
        parent_id=parent_id,
    ):
        logger.info("Fetching all files in subfolder: " + subfolder["name"])
        yield from crawl_folders_for_files(
            service=service,
            parent_id=subfolder["id"],
            field_type=field_type,
            user_email=user_email,
            traversed_parent_ids=traversed_parent_ids,
            update_traversed_ids_func=update_traversed_ids_func,
            start=start,
            end=end,
        )


def get_files_in_shared_drive(
    service: Resource,
    drive_id: str,
    field_type: DriveFileFieldType,
    max_num_pages: int,
    update_traversed_ids_func: Callable[[str], None] = lambda _: None,
    cache_folders: bool = True,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logger.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    if cache_folders:
        # If we know we are going to folder crawl later, we can cache the folders here
        # Get all folders being queried and add them to the traversed set
        folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
        folder_query += " and trashed = false"
        for folder in execute_paginated_retrieval(
            retrieval_function=service.files().list,
            list_key="files",
            continue_on_404_or_403=True,
            corpora="drive",
            driveId=drive_id,
            supportsAllDrives=True,
            includeItemsFromAllDrives=True,
            fields="nextPageToken, files(id)",
            q=folder_query,
        ):
            update_traversed_ids_func(folder["id"])

    # Get all files in the shared drive
    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    file_query += generate_time_range_filter(start, end)

    for file in execute_paginated_retrieval_with_max_pages(
        retrieval_function=service.files().list,
        max_num_pages=max_num_pages,
        list_key="files",
        continue_on_404_or_403=True,
        corpora="drive",
        driveId=drive_id,
        supportsAllDrives=True,
        includeItemsFromAllDrives=True,
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    ):
        # If we found any files, mark this drive as traversed. When a user has access to a drive,
        # they have access to all the files in the drive. Also not a huge deal if we re-traverse
        # empty drives.
        # NOTE: ^^ the above is not actually true due to folder restrictions:
        # https://support.google.com/a/users/answer/12380484?hl=en
        # So we may have to change this logic for people who use folder restrictions.
        update_traversed_ids_func(drive_id)
        yield file


def get_all_files_in_my_drive_and_shared(
    service: GoogleDriveService,
    update_traversed_ids_func: Callable,
    field_type: DriveFileFieldType,
    include_shared_with_me: bool,
    max_num_pages: int,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    cache_folders: bool = True,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logger.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    if cache_folders:
        # If we know we are going to folder crawl later, we can cache the folders here
        # Get all folders being queried and add them to the traversed set
        folder_query = f"mimeType = '{DRIVE_FOLDER_TYPE}'"
        folder_query += " and trashed = false"
        if not include_shared_with_me:
            folder_query += " and 'me' in owners"
        found_folders = False
        for folder in execute_paginated_retrieval(
            retrieval_function=service.files().list,
            list_key="files",
            corpora="user",
            fields=_get_fields_for_file_type(field_type),
            q=folder_query,
        ):
            update_traversed_ids_func(folder[GoogleFields.ID])
            found_folders = True
        if found_folders:
            update_traversed_ids_func(get_root_folder_id(service))

    # Then get the files
    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    if not include_shared_with_me:
        file_query += " and 'me' in owners"
    file_query += generate_time_range_filter(start, end)
    yield from execute_paginated_retrieval_with_max_pages(
        retrieval_function=service.files().list,
        max_num_pages=max_num_pages,
        list_key="files",
        continue_on_404_or_403=False,
        corpora="user",
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    )


def get_all_files_for_oauth(
    service: GoogleDriveService,
    include_files_shared_with_me: bool,
    include_my_drives: bool,
    # One of the above 2 should be true
    include_shared_drives: bool,
    field_type: DriveFileFieldType,
    max_num_pages: int,
    start: SecondsSinceUnixEpoch | None = None,
    end: SecondsSinceUnixEpoch | None = None,
    page_token: str | None = None,
) -> Iterator[GoogleDriveFileType | str]:
    kwargs = {ORDER_BY_KEY: GoogleFields.MODIFIED_TIME.value}
    if page_token:
        logger.info(f"Using page token: {page_token}")
        kwargs[PAGE_TOKEN_KEY] = page_token

    should_get_all = (
        include_shared_drives and include_my_drives and include_files_shared_with_me
    )
    corpora = "allDrives" if should_get_all else "user"

    file_query = f"mimeType != '{DRIVE_FOLDER_TYPE}'"
    file_query += " and trashed = false"
    file_query += generate_time_range_filter(start, end)

    if not should_get_all:
        if include_files_shared_with_me and not include_my_drives:
            file_query += " and not 'me' in owners"
        if not include_files_shared_with_me and include_my_drives:
            file_query += " and 'me' in owners"

    yield from execute_paginated_retrieval_with_max_pages(
        max_num_pages=max_num_pages,
        retrieval_function=service.files().list,
        list_key="files",
        continue_on_404_or_403=False,
        corpora=corpora,
        includeItemsFromAllDrives=should_get_all,
        supportsAllDrives=should_get_all,
        fields=_get_fields_for_file_type(field_type),
        q=file_query,
        **kwargs,
    )


# Just in case we need to get the root folder id
def get_root_folder_id(service: Resource) -> str:
    # we don't paginate here because there is only one root folder per user
    # https://developers.google.com/drive/api/guides/v2-to-v3-reference
    return (
        service.files()
        .get(fileId="root", fields=GoogleFields.ID.value)
        .execute()[GoogleFields.ID.value]
    )


def _extract_file_id_from_web_view_link(web_view_link: str) -> str:
    parsed = urlparse(web_view_link)
    path_parts = [part for part in parsed.path.split("/") if part]

    if "d" in path_parts:
        idx = path_parts.index("d")
        if idx + 1 < len(path_parts):
            return path_parts[idx + 1]

    query_params = parse_qs(parsed.query)
    for key in ("id", "fileId"):
        value = query_params.get(key)
        if value and value[0]:
            return value[0]

    raise ValueError(
        f"Unable to extract Drive file id from webViewLink: {web_view_link}"
    )
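

# Illustrative sketch (not part of the Onyx source): how the Drive `q` search
# string that _get_files_in_parent above sends to files().list is assembled,
# including the clause that generate_time_range_filter appends for start/end.
# Assumptions for this standalone copy: GoogleFields.MODIFIED_TIME.value is
# "modifiedTime" (the Drive v3 field name) and DRIVE_FOLDER_TYPE is the standard
# Drive folder MIME type. _sketch_files_in_parent_query is a hypothetical name.
from datetime import datetime
from datetime import timezone

_SKETCH_DRIVE_FOLDER_TYPE = "application/vnd.google-apps.folder"


def _sketch_files_in_parent_query(
    parent_id: str, start: float | None = None, end: float | None = None
) -> str:
    # Non-folder children of parent_id that are not in the trash
    query = f"mimeType != '{_SKETCH_DRIVE_FOLDER_TYPE}' and '{parent_id}' in parents"
    query += " and trashed = false"
    # Optional modification-time window, as ISO 8601 timestamps in UTC
    if start is not None:
        time_start = datetime.fromtimestamp(start, tz=timezone.utc).isoformat()
        query += f" and modifiedTime >= '{time_start}'"
    if end is not None:
        time_stop = datetime.fromtimestamp(end, tz=timezone.utc).isoformat()
        query += f" and modifiedTime <= '{time_stop}'"
    return query


# Example: _sketch_files_in_parent_query("folder123", start=0) returns
# "mimeType != 'application/vnd.google-apps.folder' and 'folder123' in parents
# and trashed = false and modifiedTime >= '1970-01-01T00:00:00+00:00'"
# (a single line; wrapped here for width). Because every clause is joined with
# "and", omitting start and end simply drops the time filter.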
def get_file_by_web_view_link(
    service: GoogleDriveService,
    web_view_link: str,
    fields: str,
) -> GoogleDriveFileType:
    """Retrieve a Google Drive file using its webViewLink."""
    file_id = _extract_file_id_from_web_view_link(web_view_link)
    return (
        service.files()
        .get(
            fileId=file_id,
            supportsAllDrives=True,
            fields=fields,
        )
        .execute()
    )

================================================
FILE: backend/onyx/connectors/google_drive/models.py
================================================
from enum import Enum
from typing import Any

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic import field_serializer
from pydantic import field_validator

from onyx.connectors.interfaces import ConnectorCheckpoint
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.utils.threadpool_concurrency import ThreadSafeDict
from onyx.utils.threadpool_concurrency import ThreadSafeSet


class GDriveMimeType(str, Enum):
    DOC = "application/vnd.google-apps.document"
    SPREADSHEET = "application/vnd.google-apps.spreadsheet"
    SPREADSHEET_OPEN_FORMAT = (
        "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
    )
    SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
    PDF = "application/pdf"
    WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
    PPT = "application/vnd.google-apps.presentation"
    POWERPOINT = (
        "application/vnd.openxmlformats-officedocument.presentationml.presentation"
    )
    PLAIN_TEXT = "text/plain"
    MARKDOWN = "text/markdown"


GoogleDriveFileType = dict[str, Any]


TOKEN_EXPIRATION_TIME = 3600  # 1 hour


# These correspond to the major stages of retrieval for Google Drive.
# The stages for the oauth flow are:
# get_all_files_for_oauth(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# crawl_folders_for_files()
#
# The stages for the service account flow are roughly:
# get_all_user_emails(),
# get_all_drive_ids(),
# get_files_in_shared_drive(),
# Then for each user:
#   get_files_in_my_drive()
#   get_files_in_shared_drive()
#   crawl_folders_for_files()
class DriveRetrievalStage(str, Enum):
    START = "start"
    DONE = "done"
    # OAuth specific stages
    OAUTH_FILES = "oauth_files"

    # Service account specific stages
    USER_EMAILS = "user_emails"
    MY_DRIVE_FILES = "my_drive_files"

    # Used for both oauth and service account flows
    DRIVE_IDS = "drive_ids"
    SHARED_DRIVE_FILES = "shared_drive_files"
    FOLDER_FILES = "folder_files"


class StageCompletion(BaseModel):
    """
    Describes the point in the retrieval+indexing process that the
    connector is at. completed_until is the timestamp of the latest
    file that has been retrieved or error that has been yielded.
    Optional fields are used for retrieval stages that need more information
    for resuming than just the timestamp of the latest file.
    """

    stage: DriveRetrievalStage
    completed_until: SecondsSinceUnixEpoch
    current_folder_or_drive_id: str | None = None
    next_page_token: str | None = None

    # only used for shared drives
    processed_drive_ids: set[str] = set()

    def update(
        self,
        stage: DriveRetrievalStage,
        completed_until: SecondsSinceUnixEpoch,
        current_folder_or_drive_id: str | None = None,
    ) -> None:
        self.stage = stage
        self.completed_until = completed_until
        self.current_folder_or_drive_id = current_folder_or_drive_id


class RetrievedDriveFile(BaseModel):
    """
    Describes a file that has been retrieved from google drive.
    user_email is the email of the user that the file was retrieved
    by impersonating. If an error worthy of being reported is encountered,
    error should be set and later propagated as a ConnectorFailure.
    """

    # The stage at which this file was retrieved
    completion_stage: DriveRetrievalStage

    # The file that was retrieved
    drive_file: GoogleDriveFileType

    # The email of the user that the file was retrieved by impersonating
    user_email: str

    # The id of the parent folder or drive of the file
    parent_id: str | None = None

    # Any unexpected error that occurred while retrieving the file.
    # In particular, this is not used for 403/404 errors, which are expected
    # in the context of impersonating all the users to try to retrieve all
    # files from all their Drives and Folders.
    error: Exception | None = None

    model_config = ConfigDict(arbitrary_types_allowed=True)


class GoogleDriveCheckpoint(ConnectorCheckpoint):
    # Checkpoint version of _retrieved_ids
    retrieved_folder_and_drive_ids: set[str]

    # Describes the point in the retrieval+indexing process that the
    # checkpoint is at. When this is set to a given stage, the connector
    # has finished yielding all values from the previous stage.
    completion_stage: DriveRetrievalStage

    # The latest timestamp of a file that has been retrieved per user email.
    # StageCompletion is used to track the completion of each stage, but the
    # timestamp part is not used for folder crawling.
    completion_map: ThreadSafeDict[str, StageCompletion]

    # all file ids that have been retrieved
    all_retrieved_file_ids: set[str] = set()

    # cached version of the drive and folder ids to retrieve
    drive_ids_to_retrieve: list[str] | None = None
    folder_ids_to_retrieve: list[str] | None = None

    # cached user emails
    user_emails: list[str] | None = None

    # Hierarchy node raw IDs that have already been yielded.
    # Used to avoid yielding duplicate hierarchy nodes across checkpoints.
    # Thread-safe because multiple impersonation threads access this concurrently.
    # Uses default_factory to ensure each checkpoint instance gets a fresh set.
    seen_hierarchy_node_raw_ids: ThreadSafeSet[str] = Field(
        default_factory=ThreadSafeSet
    )

    # Hierarchy node raw IDs where we have successfully walked up to a terminal
    # node (a drive root with no parent). This is separate from seen_hierarchy_node_raw_ids
    # because a node might be yielded before we've walked its full ancestry chain.
    # We only skip walking from a node if it's in this set, ensuring that if one user
    # fails to walk to the root, another user with better access can still complete the walk.
    # Thread-safe because multiple impersonation threads access this concurrently.
    # Uses default_factory to ensure each checkpoint instance gets a fresh set.
    fully_walked_hierarchy_node_raw_ids: ThreadSafeSet[str] = Field(
        default_factory=ThreadSafeSet
    )

    @field_serializer("completion_map")
    def serialize_completion_map(
        self, completion_map: ThreadSafeDict[str, StageCompletion], _info: Any
    ) -> dict[str, StageCompletion]:
        return completion_map._dict

    @field_serializer("seen_hierarchy_node_raw_ids")
    def serialize_seen_hierarchy(
        self, seen_hierarchy_node_raw_ids: ThreadSafeSet[str], _info: Any
    ) -> set[str]:
        return seen_hierarchy_node_raw_ids.copy()

    @field_serializer("fully_walked_hierarchy_node_raw_ids")
    def serialize_fully_walked_hierarchy(
        self, fully_walked_hierarchy_node_raw_ids: ThreadSafeSet[str], _info: Any
    ) -> set[str]:
        return fully_walked_hierarchy_node_raw_ids.copy()

    @field_validator("completion_map", mode="before")
    def validate_completion_map(cls, v: Any) -> ThreadSafeDict[str, StageCompletion]:
        assert isinstance(v, dict) or isinstance(v, ThreadSafeDict)
        return ThreadSafeDict(
            {k: StageCompletion.model_validate(val) for k, val in v.items()}
        )

    @field_validator("seen_hierarchy_node_raw_ids", mode="before")
    def validate_seen_hierarchy(cls, v: Any) -> ThreadSafeSet[str]:
        if isinstance(v, ThreadSafeSet):
            return v
        if isinstance(v, set):
            return ThreadSafeSet(v)
        if isinstance(v, list):
            return ThreadSafeSet(set(v))
        return ThreadSafeSet()
@field_validator("fully_walked_hierarchy_node_raw_ids", mode="before") def validate_fully_walked_hierarchy(cls, v: Any) -> ThreadSafeSet[str]: if isinstance(v, ThreadSafeSet): return v if isinstance(v, set): return ThreadSafeSet(v) if isinstance(v, list): return ThreadSafeSet(set(v)) return ThreadSafeSet() ================================================ FILE: backend/onyx/connectors/google_drive/section_extraction.py ================================================ from typing import Any from pydantic import BaseModel from onyx.connectors.google_utils.resources import GoogleDocsService from onyx.connectors.models import TextSection HEADING_DELIMITER = "\n" class CurrentHeading(BaseModel): id: str | None text: str def _build_gdoc_section_link(doc_id: str, tab_id: str, heading_id: str | None) -> str: """Builds a Google Doc link that jumps to a specific heading""" # NOTE: doesn't support docs with multiple tabs atm, if we need that ask # @Chris heading_str = f"#heading={heading_id}" if heading_id else "" return f"https://docs.google.com/document/d/{doc_id}/edit?tab={tab_id}{heading_str}" def _extract_id_from_heading(paragraph: dict[str, Any]) -> str: """Extracts the id from a heading paragraph element""" return paragraph["paragraphStyle"]["headingId"] def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str: """Extracts the text content from a paragraph element""" text_elements = [] for element in paragraph.get("elements", []): if "textRun" in element: text_elements.append(element["textRun"].get("content", "")) # Handle links if "textStyle" in element and "link" in element["textStyle"]: text_elements.append(f"({element['textStyle']['link'].get('url', '')})") if "person" in element: name = element["person"].get("personProperties", {}).get("name", "") email = element["person"].get("personProperties", {}).get("email", "") person_str = " str: """ Extracts the text content from a table element. 
""" row_strs = [] for row in table.get("tableRows", []): cells = row.get("tableCells", []) cell_strs = [] for cell in cells: child_elements = cell.get("content", {}) cell_str = [] for child_elem in child_elements: if "paragraph" not in child_elem: continue cell_str.append(_extract_text_from_paragraph(child_elem["paragraph"])) cell_strs.append("".join(cell_str)) row_strs.append(", ".join(cell_strs)) return "\n".join(row_strs) def get_document_sections( docs_service: GoogleDocsService, doc_id: str, ) -> list[TextSection]: """Extracts sections from a Google Doc, including their headings and content""" # Fetch the document structure http_request = docs_service.documents().get(documentId=doc_id) # Google has poor support for tabs in the docs api, see # https://cloud.google.com/python/docs/reference/cloudtasks/ # latest/google.cloud.tasks_v2.types.HttpRequest # https://developers.google.com/workspace/docs/api/how-tos/tabs # https://developers.google.com/workspace/docs/api/reference/rest/v1/documents/get # this is a hack to use the param mentioned in the rest api docs # TODO: check if it can be specified i.e. 
in documents() http_request.uri += "&includeTabsContent=true" doc = http_request.execute() # Get the content tabs = doc.get("tabs", {}) sections: list[TextSection] = [] for tab in tabs: sections.extend(get_tab_sections(tab, doc_id)) return sections def _is_heading(paragraph: dict[str, Any]) -> bool: """Checks if a paragraph (a block of text in a drive document) is a heading""" if not ( "paragraphStyle" in paragraph and "namedStyleType" in paragraph["paragraphStyle"] ): return False style = paragraph["paragraphStyle"]["namedStyleType"] is_heading = style.startswith("HEADING_") is_title = style.startswith("TITLE") return is_heading or is_title def _add_finished_section( sections: list[TextSection], doc_id: str, tab_id: str, current_heading: CurrentHeading, current_section: list[str], ) -> None: """Appends the finished section to `sections` in place if it has content. Nothing is appended when both the section body and the heading text are empty. Returns None; callers keep using the same `sections` list. """ if not (current_section or current_heading.text): return # If we were building a previous section, add it to the sections list. # Stripping newlines from the heading is unlikely to ever matter, but helps if the doc contains weird headings header_text = current_heading.text.replace(HEADING_DELIMITER, "") section_text = f"{header_text}{HEADING_DELIMITER}" + "\n".join(current_section) sections.append( TextSection( text=section_text.strip(), link=_build_gdoc_section_link(doc_id, tab_id, current_heading.id), ) ) def get_tab_sections(tab: dict[str, Any], doc_id: str) -> list[TextSection]: tab_id = tab["tabProperties"]["tabId"] content = tab.get("documentTab", {}).get("body", {}).get("content", []) sections: list[TextSection] = [] current_section: list[str] = [] current_heading = CurrentHeading(id=None, text="") for element in content: if "paragraph" in element: paragraph = element["paragraph"] # If this is not a heading, add content to current section if not 
_is_heading(paragraph): text = _extract_text_from_paragraph(paragraph)
if text.strip(): current_section.append(text) continue _add_finished_section( sections, doc_id, tab_id, current_heading, current_section ) current_section = [] # Start new heading heading_id = _extract_id_from_heading(paragraph) heading_text = _extract_text_from_paragraph(paragraph) current_heading = CurrentHeading( id=heading_id, text=heading_text, ) elif "table" in element: text = _extract_text_from_table(element["table"]) if text.strip(): current_section.append(text) # Don't forget to add the last section _add_finished_section(sections, doc_id, tab_id, current_heading, current_section) return sections ================================================ FILE: backend/onyx/connectors/google_site/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/google_site/connector.py ================================================ import os import re from typing import Any from typing import cast from bs4 import BeautifulSoup from bs4 import Tag from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import load_files_from_zip from onyx.file_processing.extract_file_text import read_text_file from onyx.file_processing.html_utils import web_html_cleanup from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger logger = setup_logger() def a_tag_text_to_path(atag: Tag) -> str: page_path = atag.text.strip().lower() page_path = re.sub(r"[^a-zA-Z0-9\s]", "", page_path) page_path = "-".join(page_path.split()) return page_path def find_google_sites_page_path_from_navbar( element: BeautifulSoup | Tag, path: str, 
depth: int ) -> str | None: lis = cast( list[Tag], element.find_all("li", attrs={"data-nav-level": f"{depth}"}), ) for li in lis: a = cast(Tag, li.find("a")) if a.get("aria-selected") == "true": return f"{path}/{a_tag_text_to_path(a)}" elif a.get("aria-expanded") == "true": sub_path = find_google_sites_page_path_from_navbar( element, f"{path}/{a_tag_text_to_path(a)}", depth + 1 ) if sub_path: return sub_path return None class GoogleSitesConnector(LoadConnector): def __init__( self, zip_path: str, base_url: str, batch_size: int = INDEX_BATCH_SIZE, ): self.zip_path = zip_path self.base_url = base_url self.batch_size = batch_size def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: pass def load_from_state(self) -> GenerateDocumentsOutput: documents: list[Document | HierarchyNode] = [] file_content_io = get_default_file_store().read_file(self.zip_path, mode="b") # load the HTML files files = load_files_from_zip(file_content_io) count = 0 for file_info, file_io in files: # skip non-published files if "/PUBLISHED/" not in file_info.filename: continue file_path, extension = os.path.splitext(file_info.filename) if extension != ".html": continue file_content, _ = read_text_file(file_io) soup = BeautifulSoup(file_content, "html.parser") # get the link out of the navbar header = cast(Tag, soup.find("header")) nav = cast(Tag, header.find("nav")) path = find_google_sites_page_path_from_navbar(nav, "", 1) if not path: count += 1 logger.error( f"Could not find path for '{file_info.filename}'. 
" + "This page will not have a working link.\n\n" + f"# of broken links so far - {count}" ) logger.info(f"Path to page: {path}") # cleanup the hidden `Skip to main content` and `Skip to navigation` that # appears at the top of every page for div in soup.find_all("div", attrs={"data-is-touch-wrapper": "true"}): div.extract() # get the body of the page parsed_html = web_html_cleanup( soup, additional_element_types_to_discard=["header", "nav"] ) title = parsed_html.title or file_path.split("/")[-1] documents.append( Document( id=f"{DocumentSource.GOOGLE_SITES.value}:{path}", source=DocumentSource.GOOGLE_SITES, semantic_identifier=title, sections=[ TextSection( link=( (self.base_url.rstrip("/") + "/" + path.lstrip("/")) if path else "" ), text=parsed_html.cleaned_text, ) ], metadata={}, ) ) if len(documents) >= self.batch_size: yield documents documents = [] if documents: yield documents if __name__ == "__main__": connector = GoogleSitesConnector( os.environ["GOOGLE_SITES_ZIP_PATH"], os.environ.get("GOOGLE_SITES_BASE_URL", ""), ) for doc_batch in connector.load_from_state(): for doc in doc_batch: print(doc) ================================================ FILE: backend/onyx/connectors/google_utils/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/google_utils/google_auth.py ================================================ import json from typing import Any from google.auth.transport.requests import Request from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_ID from onyx.configs.app_configs import OAUTH_GOOGLE_DRIVE_CLIENT_SECRET from onyx.configs.constants import DocumentSource from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_AUTHENTICATION_METHOD, ) from 
onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_TOKEN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) from onyx.utils.logger import setup_logger logger = setup_logger() def sanitize_oauth_credentials(oauth_creds: OAuthCredentials) -> str: """we really don't want to be persisting the client id and secret anywhere but the environment. Returns a string of serialized json. """ # strip the client id and secret oauth_creds_json_str = oauth_creds.to_json() oauth_creds_sanitized_json: dict[str, Any] = json.loads(oauth_creds_json_str) oauth_creds_sanitized_json.pop("client_id", None) oauth_creds_sanitized_json.pop("client_secret", None) oauth_creds_sanitized_json_str = json.dumps(oauth_creds_sanitized_json) return oauth_creds_sanitized_json_str def get_google_oauth_creds( token_json_str: str, source: DocumentSource ) -> OAuthCredentials | None: """creds_json only needs to contain client_id, client_secret and refresh_token to refresh the creds. expiry and token are optional ... however, if passing in expiry, token should also be passed in or else we may not return any creds. 
(probably a sign we should refactor the function) """ creds_json = json.loads(token_json_str) creds = OAuthCredentials.from_authorized_user_info( info=creds_json, scopes=GOOGLE_SCOPES[source], ) if creds.valid: return creds if creds.expired and creds.refresh_token: try: creds.refresh(Request()) if creds.valid: logger.notice("Refreshed Google Drive tokens.") return creds except Exception: logger.exception("Failed to refresh google drive access token") return None return None def get_google_creds( credentials: dict[str, str], source: DocumentSource, ) -> tuple[ServiceAccountCredentials | OAuthCredentials, dict[str, str] | None]: """Checks for two different types of credentials. (1) A credential which holds a token acquired via a user going through the Google OAuth flow. (2) A credential which holds a service account key JSON file, which can then be used to impersonate any user in the workspace. Return a tuple where: The first element is the requested credentials The second element is a new credentials dict that the caller should write back to the db. This happens if token rotation occurs while loading credentials. 
""" oauth_creds = None service_creds = None new_creds_dict = None if DB_CREDENTIALS_DICT_TOKEN_KEY in credentials: # OAUTH authentication_method: str = credentials.get( DB_CREDENTIALS_AUTHENTICATION_METHOD, GoogleOAuthAuthenticationMethod.UPLOADED.value, ) credentials_dict_str = credentials[DB_CREDENTIALS_DICT_TOKEN_KEY] credentials_dict = json.loads(credentials_dict_str) # only send what get_google_oauth_creds needs authorized_user_info = {} # oauth_interactive is sanitized and needs credentials from the environment if ( authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value ): authorized_user_info["client_id"] = OAUTH_GOOGLE_DRIVE_CLIENT_ID authorized_user_info["client_secret"] = OAUTH_GOOGLE_DRIVE_CLIENT_SECRET else: authorized_user_info["client_id"] = credentials_dict["client_id"] authorized_user_info["client_secret"] = credentials_dict["client_secret"] authorized_user_info["refresh_token"] = credentials_dict["refresh_token"] authorized_user_info["token"] = credentials_dict["token"] authorized_user_info["expiry"] = credentials_dict["expiry"] token_json_str = json.dumps(authorized_user_info) oauth_creds = get_google_oauth_creds( token_json_str=token_json_str, source=source ) # tell caller to update token stored in DB if the refresh token changed if oauth_creds: if oauth_creds.refresh_token != authorized_user_info["refresh_token"]: # if oauth_interactive, sanitize the credentials so they don't get stored in the db if ( authentication_method == GoogleOAuthAuthenticationMethod.OAUTH_INTERACTIVE.value ): oauth_creds_json_str = sanitize_oauth_credentials(oauth_creds) else: oauth_creds_json_str = oauth_creds.to_json() new_creds_dict = { DB_CREDENTIALS_DICT_TOKEN_KEY: oauth_creds_json_str, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: credentials[ DB_CREDENTIALS_PRIMARY_ADMIN_KEY ], DB_CREDENTIALS_AUTHENTICATION_METHOD: authentication_method, } elif DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY in credentials: # SERVICE ACCOUNT service_account_key_json_str = 
credentials[ DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY ] service_account_key = json.loads(service_account_key_json_str) service_creds = ServiceAccountCredentials.from_service_account_info( service_account_key, scopes=GOOGLE_SCOPES[source] ) if not service_creds.valid or not service_creds.expired: service_creds.refresh(Request()) if not service_creds.valid: raise PermissionError( f"Unable to access {source} - service account credentials are invalid." ) creds: ServiceAccountCredentials | OAuthCredentials | None = ( oauth_creds or service_creds ) if creds is None: raise PermissionError( f"Unable to access {source} - unknown credential structure." ) return creds, new_creds_dict ================================================ FILE: backend/onyx/connectors/google_utils/google_kv.py ================================================ import json from typing import cast from urllib.parse import parse_qs from urllib.parse import ParseResult from urllib.parse import urlparse from google.oauth2.credentials import Credentials as OAuthCredentials from google_auth_oauthlib.flow import InstalledAppFlow # type: ignore from sqlalchemy.orm import Session from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import DocumentSource from onyx.configs.constants import KV_CRED_KEY from onyx.configs.constants import KV_GMAIL_CRED_KEY from onyx.configs.constants import KV_GMAIL_SERVICE_ACCOUNT_KEY from onyx.configs.constants import KV_GOOGLE_DRIVE_CRED_KEY from onyx.configs.constants import KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY from onyx.connectors.google_utils.resources import get_drive_service from onyx.connectors.google_utils.resources import get_gmail_service from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_AUTHENTICATION_METHOD, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_TOKEN_KEY, ) from 
onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_PRIMARY_ADMIN_KEY, ) from onyx.connectors.google_utils.shared_constants import ( GOOGLE_SCOPES, ) from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) from onyx.connectors.google_utils.shared_constants import ( MISSING_SCOPES_ERROR_STR, ) from onyx.connectors.google_utils.shared_constants import ( ONYX_SCOPE_INSTRUCTIONS, ) from onyx.db.credentials import update_credential_json from onyx.db.models import User from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import unwrap_str from onyx.server.documents.models import CredentialBase from onyx.server.documents.models import GoogleAppCredentials from onyx.server.documents.models import GoogleServiceAccountKey from onyx.utils.logger import setup_logger logger = setup_logger() def _build_frontend_google_drive_redirect(source: DocumentSource) -> str: if source == DocumentSource.GOOGLE_DRIVE: return f"{WEB_DOMAIN}/admin/connectors/google-drive/auth/callback" elif source == DocumentSource.GMAIL: return f"{WEB_DOMAIN}/admin/connectors/gmail/auth/callback" else: raise ValueError(f"Unsupported source: {source}") def _get_current_oauth_user(creds: OAuthCredentials, source: DocumentSource) -> str: if source == DocumentSource.GOOGLE_DRIVE: drive_service = get_drive_service(creds) user_info = ( drive_service.about() .get( fields="user(emailAddress)", ) .execute() ) email = user_info.get("user", {}).get("emailAddress") elif source == DocumentSource.GMAIL: gmail_service = get_gmail_service(creds) user_info = ( gmail_service.users() .getProfile( userId="me", fields="emailAddress", ) .execute() ) email = user_info.get("emailAddress") else: raise ValueError(f"Unsupported source: {source}") return email def verify_csrf(credential_id: int, state: str) -> None: csrf = unwrap_str(get_kv_store().load(KV_CRED_KEY.format(str(credential_id)))) if csrf != state: raise PermissionError( "State 
from Google Drive Connector callback does not match expected" ) def update_credential_access_tokens( auth_code: str, credential_id: int, user: User, db_session: Session, source: DocumentSource, auth_method: GoogleOAuthAuthenticationMethod, ) -> OAuthCredentials | None: app_credentials = get_google_app_cred(source) flow = InstalledAppFlow.from_client_config( app_credentials.model_dump(), scopes=GOOGLE_SCOPES[source], redirect_uri=_build_frontend_google_drive_redirect(source), ) flow.fetch_token(code=auth_code) creds = flow.credentials token_json_str = creds.to_json() # Get user email from Google API so we know who # the primary admin is for this connector try: email = _get_current_oauth_user(creds, source) except Exception as e: if MISSING_SCOPES_ERROR_STR in str(e): raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e raise e new_creds_dict = { DB_CREDENTIALS_DICT_TOKEN_KEY: token_json_str, DB_CREDENTIALS_PRIMARY_ADMIN_KEY: email, DB_CREDENTIALS_AUTHENTICATION_METHOD: auth_method.value, } if not update_credential_json(credential_id, new_creds_dict, user, db_session): return None return creds def build_service_account_creds( source: DocumentSource, primary_admin_email: str | None = None, name: str | None = None, ) -> CredentialBase: service_account_key = get_service_account_key(source=source) credential_dict = { DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY: service_account_key.json(), } if primary_admin_email: credential_dict[DB_CREDENTIALS_PRIMARY_ADMIN_KEY] = primary_admin_email credential_dict[DB_CREDENTIALS_AUTHENTICATION_METHOD] = ( GoogleOAuthAuthenticationMethod.UPLOADED.value ) return CredentialBase( credential_json=credential_dict, admin_public=True, source=source, name=name, ) def get_auth_url(credential_id: int, source: DocumentSource) -> str: if source == DocumentSource.GOOGLE_DRIVE: creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)) elif source == DocumentSource.GMAIL: creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY)) else: raise 
ValueError(f"Unsupported source: {source}") credential_json = json.loads(creds_str) flow = InstalledAppFlow.from_client_config( credential_json, scopes=GOOGLE_SCOPES[source], redirect_uri=_build_frontend_google_drive_redirect(source), ) auth_url, _ = flow.authorization_url(prompt="consent") parsed_url = cast(ParseResult, urlparse(auth_url)) params = parse_qs(parsed_url.query) get_kv_store().store( KV_CRED_KEY.format(credential_id), {"value": params.get("state", [None])[0]}, encrypt=True, ) return str(auth_url) def get_google_app_cred(source: DocumentSource) -> GoogleAppCredentials: if source == DocumentSource.GOOGLE_DRIVE: creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_CRED_KEY)) elif source == DocumentSource.GMAIL: creds_str = str(get_kv_store().load(KV_GMAIL_CRED_KEY)) else: raise ValueError(f"Unsupported source: {source}") return GoogleAppCredentials(**json.loads(creds_str)) def upsert_google_app_cred( app_credentials: GoogleAppCredentials, source: DocumentSource ) -> None: if source == DocumentSource.GOOGLE_DRIVE: get_kv_store().store( KV_GOOGLE_DRIVE_CRED_KEY, app_credentials.json(), encrypt=True ) elif source == DocumentSource.GMAIL: get_kv_store().store(KV_GMAIL_CRED_KEY, app_credentials.json(), encrypt=True) else: raise ValueError(f"Unsupported source: {source}") def delete_google_app_cred(source: DocumentSource) -> None: if source == DocumentSource.GOOGLE_DRIVE: get_kv_store().delete(KV_GOOGLE_DRIVE_CRED_KEY) elif source == DocumentSource.GMAIL: get_kv_store().delete(KV_GMAIL_CRED_KEY) else: raise ValueError(f"Unsupported source: {source}") def get_service_account_key(source: DocumentSource) -> GoogleServiceAccountKey: if source == DocumentSource.GOOGLE_DRIVE: creds_str = str(get_kv_store().load(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY)) elif source == DocumentSource.GMAIL: creds_str = str(get_kv_store().load(KV_GMAIL_SERVICE_ACCOUNT_KEY)) else: raise ValueError(f"Unsupported source: {source}") return GoogleServiceAccountKey(**json.loads(creds_str)) def 
upsert_service_account_key( service_account_key: GoogleServiceAccountKey, source: DocumentSource ) -> None: if source == DocumentSource.GOOGLE_DRIVE: get_kv_store().store( KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True, ) elif source == DocumentSource.GMAIL: get_kv_store().store( KV_GMAIL_SERVICE_ACCOUNT_KEY, service_account_key.json(), encrypt=True ) else: raise ValueError(f"Unsupported source: {source}") def delete_service_account_key(source: DocumentSource) -> None: if source == DocumentSource.GOOGLE_DRIVE: get_kv_store().delete(KV_GOOGLE_DRIVE_SERVICE_ACCOUNT_KEY) elif source == DocumentSource.GMAIL: get_kv_store().delete(KV_GMAIL_SERVICE_ACCOUNT_KEY) else: raise ValueError(f"Unsupported source: {source}") ================================================ FILE: backend/onyx/connectors/google_utils/google_utils.py ================================================ import re import socket import time from collections.abc import Callable from collections.abc import Iterator from datetime import datetime from datetime import timezone from enum import Enum from typing import Any from googleapiclient.errors import HttpError # type: ignore from onyx.connectors.google_drive.models import GoogleDriveFileType from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder logger = setup_logger() _RATE_LIMIT_REASONS = {"userRateLimitExceeded", "rateLimitExceeded"} def _is_rate_limit_error(error: HttpError) -> bool: """Google sometimes returns rate-limit errors as 403 with reason 'userRateLimitExceeded' instead of 429. 
This helper detects both.""" if error.resp.status == 429: return True if error.resp.status != 403: return False error_details = getattr(error, "error_details", None) or [] for detail in error_details: if isinstance(detail, dict) and detail.get("reason") in _RATE_LIMIT_REASONS: return True return "userRateLimitExceeded" in str(error) or "rateLimitExceeded" in str(error) # Google Drive APIs are quite flakey and may 500 for an # extended period of time. This is now addressed by checkpointing. # # NOTE: We previously tried to combat this here by adding a very # long retry period (~20 minutes of trying, one request a minute.) # This is no longer necessary due to checkpointing. add_retries = retry_builder(tries=5, max_delay=10) NEXT_PAGE_TOKEN_KEY = "nextPageToken" PAGE_TOKEN_KEY = "pageToken" ORDER_BY_KEY = "orderBy" # See https://developers.google.com/drive/api/reference/rest/v3/files/list for more class GoogleFields(str, Enum): ID = "id" CREATED_TIME = "createdTime" MODIFIED_TIME = "modifiedTime" NAME = "name" SIZE = "size" PARENTS = "parents" def _execute_with_retry(request: Any) -> Any: max_attempts = 6 attempt = 1 while attempt < max_attempts: # Note for reasons unknown, the Google API will sometimes return a 429 # and even after waiting the retry period, it will return another 429. # It could be due to a few possibilities: # 1. Other things are also requesting from the Drive/Gmail API with the same key # 2. It's a rolling rate limit so the moment we get some amount of requests cleared, we hit it again very quickly # 3. The retry-after has a maximum and we've already hit the limit for the day # or it's something else... 
try: return request.execute() except HttpError as error: attempt += 1 if _is_rate_limit_error(error): # Attempt to get 'Retry-After' from headers retry_after = error.resp.get("Retry-After") if retry_after: sleep_time = int(retry_after) else: # Extract 'Retry after' timestamp from error message match = re.search( r"Retry after (\d{4}-\d{2}-\d{2}T\d{2}:\d{2}:\d{2}\.\d+Z)", str(error), ) if match: retry_after_timestamp = match.group(1) retry_after_dt = datetime.strptime( retry_after_timestamp, "%Y-%m-%dT%H:%M:%S.%fZ" ).replace(tzinfo=timezone.utc) current_time = datetime.now(timezone.utc) sleep_time = max( int((retry_after_dt - current_time).total_seconds()), 0, ) else: logger.error( f"No Retry-After header or timestamp found in error message: {error}" ) sleep_time = 60 sleep_time += 3 # Add a buffer to be safe logger.info( f"Rate limit exceeded. Attempt {attempt}/{max_attempts}. Sleeping for {sleep_time} seconds." ) time.sleep(sleep_time) else: raise # If we've exhausted all attempts raise Exception(f"Failed to execute request after {max_attempts} attempts") def get_file_owners(file: GoogleDriveFileType, primary_admin_email: str) -> list[str]: """ Get the owners of a file if the attribute is present. 
""" return [ email for owner in file.get("owners", []) if (email := owner.get("emailAddress")) and email.split("@")[-1] == primary_admin_email.split("@")[-1] ] def _execute_single_retrieval( retrieval_function: Callable, continue_on_404_or_403: bool = False, **request_kwargs: Any, ) -> GoogleDriveFileType: """Execute a single retrieval from Google Drive API""" try: results = retrieval_function(**request_kwargs).execute() except HttpError as e: if e.resp.status >= 500: results = add_retries( lambda: retrieval_function(**request_kwargs).execute() )() elif e.resp.status == 400: if ( "pageToken" in request_kwargs and "Invalid Value" in str(e) and "pageToken" in str(e) ): logger.warning( f"Invalid page token: {request_kwargs['pageToken']}, retrying from start of request" ) request_kwargs.pop("pageToken") return _execute_single_retrieval( retrieval_function, continue_on_404_or_403, **request_kwargs, ) logger.error(f"Error executing request: {e}") raise e elif _is_rate_limit_error(e): results = _execute_with_retry(retrieval_function(**request_kwargs)) elif e.resp.status == 404 or e.resp.status == 403: if continue_on_404_or_403: logger.debug(f"Error executing request: {e}") results = {} else: raise e else: logger.exception("Error executing request:") raise e except (TimeoutError, socket.timeout) as error: logger.warning( "Timed out executing Google API request; retrying with backoff. Details: %s", error, ) results = add_retries(lambda: retrieval_function(**request_kwargs).execute())() return results def execute_single_retrieval( retrieval_function: Callable, list_key: str | None = None, continue_on_404_or_403: bool = False, **request_kwargs: Any, ) -> Iterator[GoogleDriveFileType]: results = _execute_single_retrieval( retrieval_function, continue_on_404_or_403, **request_kwargs, ) if list_key: for item in results.get(list_key, []): yield item else: yield results # included for type purposes; caller should not need to handle # the str page token unless 
max_num_pages is specified.
Use # execute_paginated_retrieval_with_max_pages instead if you want # the early stop + yield next page token after max_num_pages behavior. def execute_paginated_retrieval( retrieval_function: Callable, list_key: str | None = None, continue_on_404_or_403: bool = False, **kwargs: Any, ) -> Iterator[GoogleDriveFileType]: for item in _execute_paginated_retrieval( retrieval_function, list_key, continue_on_404_or_403, **kwargs, ): if not isinstance(item, str): yield item def execute_paginated_retrieval_with_max_pages( retrieval_function: Callable, max_num_pages: int, list_key: str | None = None, continue_on_404_or_403: bool = False, **kwargs: Any, ) -> Iterator[GoogleDriveFileType | str]: yield from _execute_paginated_retrieval( retrieval_function, list_key, continue_on_404_or_403, max_num_pages=max_num_pages, **kwargs, ) def _execute_paginated_retrieval( retrieval_function: Callable, list_key: str | None = None, continue_on_404_or_403: bool = False, max_num_pages: int | None = None, **kwargs: Any, ) -> Iterator[GoogleDriveFileType | str]: """Execute a paginated retrieval from Google Drive API Args: retrieval_function: The specific list function to call (e.g., service.files().list) list_key: If specified, each object returned by the retrieval function will be accessed at the specified key and yielded from. continue_on_404_or_403: If True, the retrieval will continue even if the request returns a 404 or 403 error. max_num_pages: If specified, the retrieval will stop after the specified number of pages and yield the next page token (a str) so the caller can resume.
**kwargs: Arguments to pass to the list function """ if "fields" not in kwargs or "nextPageToken" not in kwargs["fields"]: raise ValueError( "fields must contain nextPageToken for execute_paginated_retrieval" ) next_page_token = kwargs.get(PAGE_TOKEN_KEY, "") num_pages = 0 while next_page_token is not None: if max_num_pages is not None and num_pages >= max_num_pages: yield next_page_token return num_pages += 1 request_kwargs = kwargs.copy() if next_page_token: request_kwargs[PAGE_TOKEN_KEY] = next_page_token results = _execute_single_retrieval( retrieval_function, continue_on_404_or_403, **request_kwargs, ) next_page_token = results.get(NEXT_PAGE_TOKEN_KEY) if list_key: for item in results.get(list_key, []): yield item else: yield results ================================================ FILE: backend/onyx/connectors/google_utils/resources.py ================================================ from collections.abc import Callable from typing import Any from google.auth.exceptions import RefreshError from google.oauth2.credentials import Credentials as OAuthCredentials from google.oauth2.service_account import Credentials as ServiceAccountCredentials from googleapiclient.discovery import build # type: ignore[import-untyped] from googleapiclient.discovery import Resource from onyx.utils.logger import setup_logger logger = setup_logger() class GoogleDriveService(Resource): pass class GoogleDocsService(Resource): pass class AdminService(Resource): pass class GmailService(Resource): pass class RefreshableDriveObject: """ Running Google drive service retrieval functions involves accessing methods of the service object (ie. files().list()) which can raise a RefreshError if the access token is expired. This class is a wrapper that propagates the ability to refresh the access token and retry the final retrieval function until execute() is called. 
""" def __init__( self, call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any], creds: ServiceAccountCredentials | OAuthCredentials, creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials], ): self.call_stack = call_stack self.creds = creds self.creds_getter = creds_getter def __getattr__(self, name: str) -> Any: if name == "execute": return self.make_refreshable_execute() return RefreshableDriveObject( lambda creds: getattr(self.call_stack(creds), name), self.creds, self.creds_getter, ) def __call__(self, *args: Any, **kwargs: Any) -> Any: return RefreshableDriveObject( lambda creds: self.call_stack(creds)(*args, **kwargs), self.creds, self.creds_getter, ) def make_refreshable_execute(self) -> Callable: def execute(*args: Any, **kwargs: Any) -> Any: try: return self.call_stack(self.creds).execute(*args, **kwargs) except RefreshError as e: logger.warning( f"RefreshError, going to attempt a creds refresh and retry: {e}" ) # Refresh the access token self.creds = self.creds_getter() return self.call_stack(self.creds).execute(*args, **kwargs) return execute def _get_google_service( service_name: str, service_version: str, creds: ServiceAccountCredentials | OAuthCredentials, user_email: str | None = None, ) -> GoogleDriveService | GoogleDocsService | AdminService | GmailService: service: Resource if isinstance(creds, ServiceAccountCredentials): # NOTE: https://developers.google.com/identity/protocols/oauth2/service-account#error-codes creds = creds.with_subject(user_email) service = build(service_name, service_version, credentials=creds) elif isinstance(creds, OAuthCredentials): service = build(service_name, service_version, credentials=creds) return service def get_google_docs_service( creds: ServiceAccountCredentials | OAuthCredentials, user_email: str | None = None, ) -> GoogleDocsService: return _get_google_service("docs", "v1", creds, user_email) def get_drive_service( creds: ServiceAccountCredentials | OAuthCredentials, 
user_email: str | None = None, ) -> GoogleDriveService: return _get_google_service("drive", "v3", creds, user_email) def get_admin_service( creds: ServiceAccountCredentials | OAuthCredentials, user_email: str | None = None, ) -> AdminService: return _get_google_service("admin", "directory_v1", creds, user_email) def get_gmail_service( creds: ServiceAccountCredentials | OAuthCredentials, user_email: str | None = None, ) -> GmailService: return _get_google_service("gmail", "v1", creds, user_email) ================================================ FILE: backend/onyx/connectors/google_utils/shared_constants.py ================================================ from enum import Enum as PyEnum from onyx.configs.constants import DocumentSource # NOTE: do not need https://www.googleapis.com/auth/documents.readonly # this is counted under `/auth/drive.readonly` GOOGLE_SCOPES = { DocumentSource.GOOGLE_DRIVE: [ "https://www.googleapis.com/auth/drive.readonly", "https://www.googleapis.com/auth/drive.metadata.readonly", "https://www.googleapis.com/auth/admin.directory.group.readonly", "https://www.googleapis.com/auth/admin.directory.user.readonly", ], DocumentSource.GMAIL: [ "https://www.googleapis.com/auth/gmail.readonly", "https://www.googleapis.com/auth/admin.directory.user.readonly", "https://www.googleapis.com/auth/admin.directory.group.readonly", ], } # This is the Oauth token DB_CREDENTIALS_DICT_TOKEN_KEY = "google_tokens" # This is the service account key DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY = "google_service_account_key" # The email saved for both auth types DB_CREDENTIALS_PRIMARY_ADMIN_KEY = "google_primary_admin" # https://developers.google.com/workspace/guides/create-credentials # Internally defined authentication method type. 
# The value must be one of "oauth_interactive" or "uploaded" # Used to disambiguate whether credentials have already been created via # certain methods and what actions we allow users to take DB_CREDENTIALS_AUTHENTICATION_METHOD = "authentication_method" class GoogleOAuthAuthenticationMethod(str, PyEnum): OAUTH_INTERACTIVE = "oauth_interactive" UPLOADED = "uploaded" USER_FIELDS = "nextPageToken, users(primaryEmail)" # Error message substrings MISSING_SCOPES_ERROR_STR = "client not authorized for any of the scopes requested" # Documentation and error messages SCOPE_DOC_URL = "https://docs.onyx.app/admins/connectors/official/google_drive/overview" ONYX_SCOPE_INSTRUCTIONS = ( "You have upgraded Onyx without updating the Google Auth scopes. " f"Please refer to the documentation to learn how to update the scopes: {SCOPE_DOC_URL}" ) # This is the maximum number of threads that can be retrieved at once SLIM_BATCH_SIZE = 500 ================================================ FILE: backend/onyx/connectors/guru/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/guru/connector.py ================================================ import json from datetime import datetime from datetime import timezone from typing import Any import requests from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import 
TextSection from onyx.file_processing.html_utils import parse_html_page_basic from onyx.utils.logger import setup_logger logger = setup_logger() # Potential Improvements # 1. Support fetching per collection via collection token (configured at connector creation) GURU_API_BASE = "https://api.getguru.com/api/v1/" GURU_QUERY_ENDPOINT = GURU_API_BASE + "search/query" GURU_CARDS_URL = "https://app.getguru.com/card/" def unixtime_to_guru_time_str(unix_time: SecondsSinceUnixEpoch) -> str: date_obj = datetime.fromtimestamp(unix_time, tz=timezone.utc) date_str = date_obj.strftime("%Y-%m-%dT%H:%M:%S.%f")[:-3] tz_str = date_obj.strftime("%z") return date_str + tz_str class GuruConnector(LoadConnector, PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, guru_user: str | None = None, guru_user_token: str | None = None, ) -> None: self.batch_size = batch_size self.guru_user = guru_user self.guru_user_token = guru_user_token def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.guru_user = credentials["guru_user"] self.guru_user_token = credentials["guru_user_token"] return None def _process_cards( self, start_str: str | None = None, end_str: str | None = None ) -> GenerateDocumentsOutput: if self.guru_user is None or self.guru_user_token is None: raise ConnectorMissingCredentialError("Guru") doc_batch: list[Document | HierarchyNode] = [] session = requests.Session() session.auth = (self.guru_user, self.guru_user_token) params: dict[str, str | int] = {"maxResults": self.batch_size} if start_str is not None and end_str is not None: params["q"] = f"lastModified >= {start_str} AND lastModified < {end_str}" current_url = GURU_QUERY_ENDPOINT # This is how they handle pagination, a different url will be provided while True: response = session.get(current_url, params=params) response.raise_for_status() if response.status_code == 204: break cards = json.loads(response.text) for card in cards: title = card["preferredPhrase"] link = 
GURU_CARDS_URL + card["slug"] content_text = parse_html_page_basic(card["content"]) last_updated = time_str_to_utc(card["lastModified"]) last_verified = ( time_str_to_utc(card.get("lastVerified")) if card.get("lastVerified") else None ) # For Onyx, we decay document score over time; either last_updated or # last_verified is a good enough signal for the document's recency latest_time = ( max(last_verified, last_updated) if last_verified else last_updated ) metadata_dict: dict[str, str | list[str]] = {} tags = [tag.get("value") for tag in card.get("tags", [])] if tags: metadata_dict["tags"] = tags boards = [board.get("title") for board in card.get("boards", [])] if boards: # In the Guru UI these are called Folders metadata_dict["folders"] = boards collection = card.get("collection", {}) if collection: metadata_dict["collection_name"] = collection.get("name", "") owner = card.get("owner", {}) author = None if owner: author = BasicExpertInfo( email=owner.get("email"), first_name=owner.get("firstName"), last_name=owner.get("lastName"), ) doc_batch.append( Document( id=card["id"], sections=[TextSection(link=link, text=content_text)], source=DocumentSource.GURU, semantic_identifier=title, doc_updated_at=latest_time, primary_owners=[author] if author is not None else None, # Can add verifiers and commenters later metadata=metadata_dict, ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if not hasattr(response, "links") or not response.links: break current_url = response.links["next-page"]["url"] if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: return self._process_cards() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_time = unixtime_to_guru_time_str(start) end_time = unixtime_to_guru_time_str(end) return self._process_cards(start_time, end_time) if __name__ == "__main__": import os connector = GuruConnector() connector.load_credentials( { "guru_user":
os.environ["GURU_USER"], "guru_user_token": os.environ["GURU_USER_TOKEN"], } ) latest_docs = connector.load_from_state() print(next(latest_docs)) ================================================ FILE: backend/onyx/connectors/highspot/__init__.py ================================================ """ Highspot connector package for Onyx. Enables integration with Highspot's knowledge base. """ ================================================ FILE: backend/onyx/connectors/highspot/client.py ================================================ import base64 from typing import Any from typing import Dict from typing import List from typing import Optional from urllib.parse import urljoin import requests from requests.adapters import HTTPAdapter from requests.exceptions import HTTPError from requests.exceptions import RequestException from requests.exceptions import Timeout from urllib3.util.retry import Retry from onyx.utils.logger import setup_logger logger = setup_logger() PAGE_SIZE = 100 class HighspotClientError(Exception): """Base exception for Highspot API client errors.""" def __init__(self, message: str, status_code: Optional[int] = None): self.message = message self.status_code = status_code super().__init__(self.message) class HighspotAuthenticationError(HighspotClientError): """Exception raised for authentication errors.""" class HighspotRateLimitError(HighspotClientError): """Exception raised when rate limit is exceeded.""" def __init__(self, message: str, retry_after: Optional[str] = None): self.retry_after = retry_after super().__init__(message) class HighspotClient: """ Client for interacting with the Highspot API. Uses basic authentication with provided key (username) and secret (password). Implements retry logic, error handling, and connection pooling. 
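Example: a small standalone sketch of the Basic authentication header that _setup_auth builds, i.e. "Basic " followed by base64("key:secret") per RFC 7617. The helper names basic_auth_header and decode_basic_auth are hypothetical and exist only for illustration.

```python
import base64


def basic_auth_header(key: str, secret: str) -> dict[str, str]:
    # HTTP Basic auth (RFC 7617): base64-encode "username:password".
    token = base64.b64encode(f"{key}:{secret}".encode()).decode()
    return {"Authorization": f"Basic {token}"}


def decode_basic_auth(header_value: str) -> tuple[str, str]:
    # Inverse operation, handy for checking what the server will receive.
    scheme, _, token = header_value.partition(" ")
    if scheme != "Basic":
        raise ValueError(f"unexpected auth scheme: {scheme!r}")
    # Split on the first colon only: the user-id may not contain one,
    # but the password may.
    key, _, secret = base64.b64decode(token).decode().partition(":")
    return key, secret


header = basic_auth_header("my-key", "my-secret")
print(decode_basic_auth(header["Authorization"]))  # ('my-key', 'my-secret')
```

The Guru connector in this same dump gets an equivalent header by setting session.auth = (user, token) and letting requests encode it; building the header by hand, as this client does, keeps it in session.headers alongside Content-Type and Accept.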
""" BASE_URL = "https://api-su2.highspot.com/v1.0/" def __init__( self, key: str, secret: str, base_url: str = BASE_URL, timeout: int = 30, max_retries: int = 3, backoff_factor: float = 0.5, status_forcelist: Optional[List[int]] = None, ): """ Initialize the Highspot API client. Args: key: API key (used as username) secret: API secret (used as password) base_url: Base URL for the Highspot API timeout: Request timeout in seconds max_retries: Maximum number of retries for failed requests backoff_factor: Backoff factor for retries status_forcelist: HTTP status codes to retry on """ if not key or not secret: raise ValueError("API key and secret are required") self.key = key self.secret = secret self.base_url = base_url.rstrip("/") + "/" self.timeout = timeout # Set up session with retry logic self.session = requests.Session() retry_strategy = Retry( total=max_retries, backoff_factor=backoff_factor, status_forcelist=status_forcelist or [429, 500, 502, 503, 504], allowed_methods=["GET", "POST", "PUT", "DELETE"], ) adapter = HTTPAdapter(max_retries=retry_strategy) self.session.mount("http://", adapter) self.session.mount("https://", adapter) # Set up authentication self._setup_auth() def _setup_auth(self) -> None: """Set up basic authentication for the session.""" auth = f"{self.key}:{self.secret}" encoded_auth = base64.b64encode(auth.encode()).decode() self.session.headers.update( { "Authorization": f"Basic {encoded_auth}", "Content-Type": "application/json", "Accept": "application/json", } ) def _make_request( self, method: str, endpoint: str, params: Optional[Dict[str, Any]] = None, data: Optional[Dict[str, Any]] = None, json_data: Optional[Dict[str, Any]] = None, headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: """ Make a request to the Highspot API. Args: method: HTTP method (GET, POST, etc.) 
endpoint: API endpoint params: URL parameters data: Form data json_data: JSON data headers: Additional headers Returns: API response as a dictionary Raises: HighspotClientError: On API errors HighspotAuthenticationError: On authentication errors HighspotRateLimitError: On rate limiting requests.exceptions.RequestException: On request failures """ url = urljoin(self.base_url, endpoint) request_headers = {} if headers: request_headers.update(headers) try: logger.debug(f"Making {method} request to {url}") response = self.session.request( method=method, url=url, params=params, data=data, json=json_data, headers=request_headers, timeout=self.timeout, ) response.raise_for_status() if response.content and response.content.strip(): return response.json() return {} except HTTPError as e: status_code = e.response.status_code error_msg = str(e) try: error_data = e.response.json() if isinstance(error_data, dict): error_msg = error_data.get("message", str(e)) except (ValueError, KeyError): pass if status_code == 401: raise HighspotAuthenticationError(f"Authentication failed: {error_msg}") elif status_code == 429: retry_after = e.response.headers.get("Retry-After") raise HighspotRateLimitError( f"Rate limit exceeded: {error_msg}", retry_after=retry_after ) else: raise HighspotClientError( f"API error {status_code}: {error_msg}", status_code=status_code ) except Timeout: raise HighspotClientError("Request timed out") except RequestException as e: raise HighspotClientError(f"Request failed: {str(e)}") def get_spots(self) -> List[Dict[str, Any]]: """ Get all available spots, paginated. 
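Example: a standalone sketch of the offset pagination loop used here, where fetch_page(offset, limit) is a hypothetical callable standing in for the GET spots request with its start/limit parameters.

```python
from collections.abc import Callable
from typing import Any


def fetch_all_offset_paginated(
    fetch_page: Callable[[int, int], list[Any]], page_size: int = 100
) -> list[Any]:
    # Request offsets 0, page_size, 2 * page_size, ... and stop at the first
    # page shorter than page_size, like the start/limit loop in get_spots.
    items: list[Any] = []
    offset = 0
    while True:
        page = fetch_page(offset, page_size)
        items.extend(page)
        if len(page) < page_size:
            return items
        offset += page_size


data = list(range(250))
requested_offsets: list[int] = []


def fake_fetch(offset: int, limit: int) -> list[int]:
    # Hypothetical backend holding 250 records.
    requested_offsets.append(offset)
    return data[offset:offset + limit]


print(len(fetch_all_offset_paginated(fake_fetch)), requested_offsets)  # 250 [0, 100, 200]
```

Stopping on a short page avoids needing a total count from the API; the trade-off is one extra, empty request whenever the total is an exact multiple of the page size.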
Returns: List of spots with their names and IDs """ all_spots = [] has_more = True current_offset = 0 while has_more: params = {"right": "view", "start": current_offset, "limit": PAGE_SIZE} response = self._make_request("GET", "spots", params=params) found_spots = response.get("collection", []) logger.info(f"Received {len(found_spots)} spots at offset {current_offset}") all_spots.extend(found_spots) if len(found_spots) < PAGE_SIZE: has_more = False else: current_offset += PAGE_SIZE logger.info(f"Total spots retrieved: {len(all_spots)}") return all_spots def get_spot(self, spot_id: str) -> Dict[str, Any]: """ Get details for a specific spot. Args: spot_id: ID of the spot Returns: Spot details """ if not spot_id: raise ValueError("spot_id is required") return self._make_request("GET", f"spots/{spot_id}") def get_spot_items( self, spot_id: str, offset: int = 0, page_size: int = PAGE_SIZE ) -> Dict[str, Any]: """ Get items in a specific spot. Args: spot_id: ID of the spot offset: offset number page_size: Number of items per page Returns: Items in the spot """ if not spot_id: raise ValueError("spot_id is required") params = {"spot": spot_id, "start": offset, "limit": page_size} return self._make_request("GET", "items", params=params) def get_item(self, item_id: str) -> Dict[str, Any]: """ Get details for a specific item. Args: item_id: ID of the item Returns: Item details """ if not item_id: raise ValueError("item_id is required") return self._make_request("GET", f"items/{item_id}") def get_item_content(self, item_id: str) -> bytes: """ Get the raw content of an item. Args: item_id: ID of the item Returns: Raw content bytes """ if not item_id: raise ValueError("item_id is required") url = urljoin(self.base_url, f"items/{item_id}/content") response = self.session.get(url, timeout=self.timeout) response.raise_for_status() return response.content def health_check(self) -> bool: """ Check if the API is accessible and credentials are valid. 
Returns: True if API is accessible, False otherwise """ try: self._make_request("GET", "spots", params={"limit": 1}) return True except (HighspotClientError, HighspotAuthenticationError): return False ================================================ FILE: backend/onyx/connectors/highspot/connector.py ================================================ import os from datetime import datetime from io import BytesIO from typing import Any from typing import Dict from typing import List from typing import Optional from pydantic import BaseModel from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.highspot.client import HighspotClient from onyx.connectors.highspot.client import HighspotClientError from onyx.connectors.highspot.utils import scrape_url_content from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.file_processing.extract_file_text import extract_file_text from onyx.file_processing.file_types import OnyxFileExtensions from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() _SLIM_BATCH_SIZE = 1000 class HighspotSpot(BaseModel): id: str name: str class HighspotConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): """ Connector for loading data from Highspot. Retrieves content from specified spots using the Highspot API. 
If no spots are specified, retrieves content from all available spots. """ def __init__( self, spot_names: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, ): """ Initialize the Highspot connector. Args: spot_names: List of spot names to retrieve content from (if empty, gets all spots) batch_size: Number of items to retrieve in each batch """ self.spot_names = spot_names or [] self.batch_size = batch_size self._client: Optional[HighspotClient] = None self.highspot_url: Optional[str] = None self.key: Optional[str] = None self.secret: Optional[str] = None @property def client(self) -> HighspotClient: if self._client is None: if not self.key or not self.secret: raise ConnectorMissingCredentialError("Highspot") # Ensure highspot_url is a string, use default if None base_url = ( self.highspot_url if self.highspot_url is not None else HighspotClient.BASE_URL ) self._client = HighspotClient(self.key, self.secret, base_url=base_url) return self._client def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: logger.info("Loading Highspot credentials") self.highspot_url = credentials.get("highspot_url") self.key = credentials.get("highspot_key") self.secret = credentials.get("highspot_secret") return None def _fetch_spots(self) -> list[HighspotSpot]: """ Fetch all available spots and return them as HighspotSpot objects (ID and title). Case-insensitive name matching happens in _fetch_spots_to_process. """ return [ HighspotSpot(id=spot["id"], name=spot["title"]) for spot in self.client.get_spots() ] def _fetch_spots_to_process(self) -> list[HighspotSpot]: """ Fetch spots to process based on the configured spot names. """ spots = self._fetch_spots() if not spots: raise ValueError("No spots found in Highspot.") if self.spot_names: lower_spot_names = [name.lower() for name in self.spot_names] spots_to_process = [ spot for spot in spots if spot.name.lower() in lower_spot_names ] if not spots_to_process: raise ValueError( f"No valid spots found in Highspot.
Found {spots} but {self.spot_names} were requested." ) return spots_to_process return spots def load_from_state(self) -> GenerateDocumentsOutput: """ Load content from configured spots in Highspot. If no spots are configured, loads from all spots. Yields: Batches of Document objects """ return self.poll_source(None, None) def poll_source( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> GenerateDocumentsOutput: """ Poll Highspot for content updated since the start time. Args: start: Start time as seconds since Unix epoch end: End time as seconds since Unix epoch Yields: Batches of Document objects """ spots_to_process = self._fetch_spots_to_process() doc_batch: list[Document | HierarchyNode] = [] try: for spot in spots_to_process: try: offset = 0 has_more = True while has_more: logger.info( f"Retrieving items from spot {spot.name}, offset {offset}" ) response = self.client.get_spot_items( spot_id=spot.id, offset=offset, page_size=self.batch_size ) items = response.get("collection", []) logger.info( f"Received {len(items)} items from spot {spot.name}" ) if not items: has_more = False continue for item in items: try: item_id = item.get("id") if not item_id: logger.warning("Item without ID found, skipping") continue item_details = self.client.get_item(item_id) if not item_details: logger.warning( f"Item {item_id} details not found, skipping" ) continue # Apply time filter if specified if start or end: updated_at = item_details.get("date_updated") if updated_at: # Convert to datetime for comparison try: updated_time = datetime.fromisoformat( updated_at.replace("Z", "+00:00") ) if ( start and updated_time.timestamp() < start ) or ( end and updated_time.timestamp() > end ): continue except (ValueError, TypeError): # Skip if date cannot be parsed logger.warning( f"Invalid date format for item {item_id}: {updated_at}" ) continue content = self._get_item_content(item_details) title = item_details.get("title", "") doc_batch.append( Document( 
id=f"HIGHSPOT_{item_id}", sections=[ TextSection( link=item_details.get( "url", f"https://www.highspot.com/items/{item_id}", ), text=content, ) ], source=DocumentSource.HIGHSPOT, semantic_identifier=title, metadata={ "spot_name": spot.name, "type": item_details.get( "content_type", "" ), "created_at": item_details.get( "date_added", "" ), "author": item_details.get("author", ""), "language": item_details.get( "language", "" ), "can_download": str( item_details.get("can_download", False) ), }, doc_updated_at=item_details.get("date_updated"), ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] except HighspotClientError as e: item_id = "ID" if not item_id else item_id logger.error( f"Error retrieving item {item_id}: {str(e)}" ) except Exception as e: item_id = "ID" if not item_id else item_id logger.error( f"Unexpected error for item {item_id}: {str(e)}" ) has_more = len(items) >= self.batch_size offset += self.batch_size except (HighspotClientError, ValueError) as e: logger.error(f"Error processing spot {spot.name}: {str(e)}") raise except Exception as e: logger.error( f"Unexpected error processing spot {spot.name}: {str(e)}" ) raise except Exception as e: logger.error(f"Error in Highspot connector: {str(e)}") raise if doc_batch: yield doc_batch def _get_item_content(self, item_details: Dict[str, Any]) -> str: """ Get the text content of an item. Args: item_details: Item details from the API Returns: Text content of the item """ item_id = item_details.get("id", "") content_name = item_details.get("content_name", "") is_valid_format = content_name and "." in content_name file_extension = content_name.split(".")[-1].lower() if is_valid_format else "" file_extension = "." 
+ file_extension if file_extension else "" can_download = item_details.get("can_download", False) content_type = item_details.get("content_type", "") # Extract title and description once at the beginning title, description = self._extract_title_and_description(item_details) default_content = f"{title}\n{description}" logger.info( f"Processing item {item_id} with extension {file_extension} and file name {content_name}" ) try: if content_type == "WebLink": url = item_details.get("url") if not url: return default_content content = scrape_url_content(url, True) return content if content else default_content elif ( is_valid_format and file_extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS and can_download ): content_response = self.client.get_item_content(item_id) # Process and extract text from binary content based on type if content_response: text_content = extract_file_text( BytesIO(content_response), content_name, False ) return text_content if text_content else default_content return default_content else: logger.warning( f"Item {item_id} has unsupported format: {file_extension}" ) return default_content except HighspotClientError as e: error_context = f"item {item_id}" if item_id else "(item id not found)" logger.warning(f"Could not retrieve content for {error_context}: {str(e)}") return default_content except ValueError as e: error_context = f"item {item_id}" if item_id else "(item id not found)" logger.error(f"Value error for {error_context}: {str(e)}") return default_content except Exception as e: error_context = f"item {item_id}" if item_id else "(item id not found)" logger.error( f"Unexpected error retrieving content for {error_context}: {str(e)}" ) return default_content def _extract_title_and_description( self, item_details: Dict[str, Any] ) -> tuple[str, str]: """ Extract the title and description from item details. 
Args: item_details: Item details from the API Returns: Tuple of title and description """ title = item_details.get("title", "") description = item_details.get("description", "") return title, description def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002 ) -> GenerateSlimDocumentOutput: """ Retrieve all document IDs from the configured spots. If no spots are configured, retrieves from all spots. Args: start: Optional start time filter end: Optional end time filter callback: Optional indexing heartbeat callback Yields: Batches of SlimDocument objects """ spots_to_process = self._fetch_spots_to_process() slim_doc_batch: list[SlimDocument | HierarchyNode] = [] try: for spot in spots_to_process: try: offset = 0 has_more = True while has_more: logger.info( f"Retrieving slim documents from spot {spot.name}, offset {offset}" ) response = self.client.get_spot_items( spot_id=spot.id, offset=offset, page_size=self.batch_size ) items = response.get("collection", []) if not items: has_more = False continue for item in items: item_id = item.get("id") if not item_id: logger.warning("Item without ID found, skipping") continue slim_doc_batch.append( SlimDocument(id=f"HIGHSPOT_{item_id}") ) if len(slim_doc_batch) >= _SLIM_BATCH_SIZE: yield slim_doc_batch slim_doc_batch = [] has_more = len(items) >= self.batch_size offset += self.batch_size except (HighspotClientError, ValueError): logger.exception( f"Error retrieving slim documents from spot {spot.name}" ) raise if slim_doc_batch: yield slim_doc_batch except Exception: logger.exception("Error in Highspot Slim Connector") raise def validate_credentials(self) -> bool: """ Validate that the provided credentials can access the Highspot API. 
Returns: True if credentials are valid, False otherwise """ try: return self.client.health_check() except Exception as e: logger.error(f"Failed to validate credentials: {str(e)}") return False if __name__ == "__main__": spot_names: List[str] = [] connector = HighspotConnector(spot_names) credentials = { "highspot_key": os.environ.get("HIGHSPOT_KEY"), "highspot_secret": os.environ.get("HIGHSPOT_SECRET"), } connector.load_credentials(credentials=credentials) for doc in connector.load_from_state(): print(doc) ================================================ FILE: backend/onyx/connectors/highspot/utils.py ================================================ from typing import Optional from urllib.parse import urlparse from bs4 import BeautifulSoup from playwright.sync_api import sync_playwright from onyx.file_processing.html_utils import web_html_cleanup from onyx.utils.logger import setup_logger logger = setup_logger() # Constants WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20 JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser" DEFAULT_TIMEOUT = 60000 # 60 seconds def scrape_url_content( url: str, scroll_before_scraping: bool = False, timeout_ms: int = DEFAULT_TIMEOUT ) -> Optional[str]: """ Scrapes content from a given URL and returns the cleaned text. 
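Example: a standalone sketch of the scroll-until-stable loop used below, with Playwright removed. get_height and scroll_to_bottom are hypothetical callables standing in for page.evaluate("document.body.scrollHeight") and page.evaluate("window.scrollTo(0, document.body.scrollHeight)"); the networkidle wait between scrolls is omitted.

```python
from collections.abc import Callable


def scroll_until_stable(
    get_height: Callable[[], int],
    scroll_to_bottom: Callable[[], None],
    max_attempts: int = 20,
) -> int:
    # Scroll to the bottom repeatedly; stop once a scroll no longer grows the
    # page or max_attempts growing scrolls have happened. Returns that count.
    previous_height = get_height()
    attempts = 0
    while attempts < max_attempts:
        scroll_to_bottom()
        new_height = get_height()
        if new_height == previous_height:
            break
        previous_height = new_height
        attempts += 1
    return attempts


class FakePage:
    # Hypothetical page whose height grows on each scroll until content runs out.
    def __init__(self, heights: list[int]) -> None:
        self.heights = heights
        self.index = 0

    def height(self) -> int:
        return self.heights[self.index]

    def scroll(self) -> None:
        self.index = min(self.index + 1, len(self.heights) - 1)


page = FakePage([1000, 2000, 3000])
print(scroll_until_stable(page.height, page.scroll))  # 2
```

The attempt cap (WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS in the real code) guards against infinite-scroll pages whose height never stops growing.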
Args: url: The URL to scrape scroll_before_scraping: Whether to scroll through the page to load lazy content timeout_ms: Timeout in milliseconds for page navigation and loading Returns: The cleaned text content of the page or None if scraping fails """ playwright = None browser = None try: validate_url(url) playwright = sync_playwright().start() browser = playwright.chromium.launch(headless=True) context = browser.new_context() page = context.new_page() logger.info(f"Navigating to URL: {url}") try: page.goto(url, timeout=timeout_ms) except Exception as e: logger.error(f"Failed to navigate to {url}: {str(e)}") return None if scroll_before_scraping: logger.debug("Scrolling page to load lazy content") scroll_attempts = 0 previous_height = page.evaluate("document.body.scrollHeight") while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS: page.evaluate("window.scrollTo(0, document.body.scrollHeight)") try: page.wait_for_load_state("networkidle", timeout=timeout_ms) except Exception as e: logger.warning(f"Network idle wait timed out: {str(e)}") break new_height = page.evaluate("document.body.scrollHeight") if new_height == previous_height: break previous_height = new_height scroll_attempts += 1 content = page.content() soup = BeautifulSoup(content, "html.parser") parsed_html = web_html_cleanup(soup) if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text: logger.debug("JavaScript disabled message detected, checking iframes") try: iframe_count = page.frame_locator("iframe").locator("html").count() if iframe_count > 0: iframe_texts = ( page.frame_locator("iframe").locator("html").all_inner_texts() ) iframe_content = "\n".join(iframe_texts) if len(parsed_html.cleaned_text) < 700: parsed_html.cleaned_text = iframe_content else: parsed_html.cleaned_text += "\n" + iframe_content except Exception as e: logger.warning(f"Error processing iframes: {str(e)}") return parsed_html.cleaned_text except Exception as e: logger.error(f"Error scraping URL {url}: {str(e)}") return None 
finally: if browser: try: browser.close() except Exception as e: logger.debug(f"Error closing browser: {str(e)}") if playwright: try: playwright.stop() except Exception as e: logger.debug(f"Error stopping playwright: {str(e)}") def validate_url(url: str) -> None: """ Validates that a URL is properly formatted. Args: url: The URL to validate Raises: ValueError: If URL is not valid """ parse = urlparse(url) if parse.scheme != "http" and parse.scheme != "https": raise ValueError("URL must be of scheme https?://") if not parse.hostname: raise ValueError("URL must include a hostname") ================================================ FILE: backend/onyx/connectors/hubspot/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/hubspot/connector.py ================================================ import re from collections.abc import Callable from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any from typing import cast from typing import TypeVar import requests from hubspot import HubSpot # type: ignore from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.hubspot.rate_limit import HubSpotRateLimiter from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger HUBSPOT_BASE_URL = "https://app.hubspot.com" HUBSPOT_API_URL = "https://api.hubapi.com/integrations/v1/me" AVAILABLE_OBJECT_TYPES = 
{"tickets", "companies", "deals", "contacts"} HUBSPOT_PAGE_SIZE = 100 T = TypeVar("T") logger = setup_logger() class HubSpotConnector(LoadConnector, PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, access_token: str | None = None, object_types: list[str] | None = None, ) -> None: self.batch_size = batch_size self._access_token = access_token self._portal_id: str | None = None self._rate_limiter = HubSpotRateLimiter() # Set object types to fetch, default to all available types if object_types is None: self.object_types = AVAILABLE_OBJECT_TYPES.copy() else: object_types_set = set(object_types) # Validate provided object types invalid_types = object_types_set - AVAILABLE_OBJECT_TYPES if invalid_types: raise ValueError( f"Invalid object types: {invalid_types}. Available types: {AVAILABLE_OBJECT_TYPES}" ) self.object_types = object_types_set.copy() @property def access_token(self) -> str: """Get the access token, raising an exception if not set.""" if self._access_token is None: raise ConnectorMissingCredentialError("HubSpot access token not set") return self._access_token @access_token.setter def access_token(self, value: str | None) -> None: """Set the access token.""" self._access_token = value @property def portal_id(self) -> str: """Get the portal ID, raising an exception if not set.""" if self._portal_id is None: raise ConnectorMissingCredentialError("HubSpot portal ID not set") return self._portal_id @portal_id.setter def portal_id(self, value: str | None) -> None: """Set the portal ID.""" self._portal_id = value def _call_hubspot(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: return self._rate_limiter.call(func, *args, **kwargs) def _paginated_results( self, fetch_page: Callable[..., Any], **kwargs: Any, ) -> Generator[Any, None, None]: base_kwargs = dict(kwargs) base_kwargs.setdefault("limit", HUBSPOT_PAGE_SIZE) after: str | None = None while True: page_kwargs = base_kwargs.copy() if after is not None: page_kwargs["after"] 
= after page = self._call_hubspot(fetch_page, **page_kwargs) results = getattr(page, "results", []) for result in results: yield result paging = getattr(page, "paging", None) next_page = getattr(paging, "next", None) if paging else None if next_page is None: break after = getattr(next_page, "after", None) if after is None: break def _clean_html_content(self, html_content: str) -> str: """Clean HTML content and extract raw text""" if not html_content: return "" # Remove HTML tags using regex clean_text = re.sub(r"<[^>]+>", "", html_content) # Decode common HTML entities clean_text = clean_text.replace("&nbsp;", " ") clean_text = clean_text.replace("&lt;", "<") clean_text = clean_text.replace("&gt;", ">") clean_text = clean_text.replace("&quot;", '"') clean_text = clean_text.replace("&#39;", "'") # Decode &amp; last so that escaped entities such as "&amp;lt;" are not double-decoded clean_text = clean_text.replace("&amp;", "&") # Clean up whitespace clean_text = " ".join(clean_text.split()) return clean_text.strip() def get_portal_id(self) -> str: headers = { "Authorization": f"Bearer {self.access_token}", "Content-Type": "application/json", } response = requests.get(HUBSPOT_API_URL, headers=headers) if response.status_code != 200: raise Exception(f"Error fetching portal ID: HTTP {response.status_code}") data = response.json() return str(data["portalId"]) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.access_token = cast(str, credentials["hubspot_access_token"]) self.portal_id = self.get_portal_id() return None def _get_object_url(self, object_type: str, object_id: str) -> str: """Generate HubSpot URL for different object types""" if object_type == "tickets": return ( f"{HUBSPOT_BASE_URL}/contacts/{self.portal_id}/record/0-5/{object_id}" ) elif object_type == "companies": return ( f"{HUBSPOT_BASE_URL}/contacts/{self.portal_id}/record/0-2/{object_id}" ) elif object_type == "deals": return ( f"{HUBSPOT_BASE_URL}/contacts/{self.portal_id}/record/0-3/{object_id}" ) elif object_type == "contacts": return (
f"{HUBSPOT_BASE_URL}/contacts/{self.portal_id}/record/0-1/{object_id}" ) elif object_type == "notes": return ( f"{HUBSPOT_BASE_URL}/contacts/{self.portal_id}/objects/0-4/{object_id}" ) else: return f"{HUBSPOT_BASE_URL}/contacts/{self.portal_id}/{object_type}/{object_id}" def _get_associated_objects( self, api_client: HubSpot, object_id: str, from_object_type: str, to_object_type: str, ) -> list[dict[str, Any]]: """Get associated objects for a given object""" try: associations_iter = self._paginated_results( api_client.crm.associations.v4.basic_api.get_page, object_type=from_object_type, object_id=object_id, to_object_type=to_object_type, ) object_ids = [assoc.to_object_id for assoc in associations_iter] associated_objects: list[dict[str, Any]] = [] if to_object_type == "contacts": for obj_id in object_ids: try: obj = self._call_hubspot( api_client.crm.contacts.basic_api.get_by_id, contact_id=obj_id, properties=[ "firstname", "lastname", "email", "company", "jobtitle", ], ) associated_objects.append(obj.to_dict()) except Exception as e: logger.warning(f"Failed to fetch contact {obj_id}: {e}") elif to_object_type == "companies": for obj_id in object_ids: try: obj = self._call_hubspot( api_client.crm.companies.basic_api.get_by_id, company_id=obj_id, properties=[ "name", "domain", "industry", "city", "state", ], ) associated_objects.append(obj.to_dict()) except Exception as e: logger.warning(f"Failed to fetch company {obj_id}: {e}") elif to_object_type == "deals": for obj_id in object_ids: try: obj = self._call_hubspot( api_client.crm.deals.basic_api.get_by_id, deal_id=obj_id, properties=[ "dealname", "amount", "dealstage", "closedate", "pipeline", ], ) associated_objects.append(obj.to_dict()) except Exception as e: logger.warning(f"Failed to fetch deal {obj_id}: {e}") elif to_object_type == "tickets": for obj_id in object_ids: try: obj = self._call_hubspot( api_client.crm.tickets.basic_api.get_by_id, ticket_id=obj_id, properties=["subject", "content", 
"hs_ticket_priority"], ) associated_objects.append(obj.to_dict()) except Exception as e: logger.warning(f"Failed to fetch ticket {obj_id}: {e}") return associated_objects except Exception as e: logger.warning( f"Failed to get associations from {from_object_type} to {to_object_type}: {e}" ) return [] def _get_associated_notes( self, api_client: HubSpot, object_id: str, object_type: str, ) -> list[dict[str, Any]]: """Get notes associated with a given object""" try: associations_iter = self._paginated_results( api_client.crm.associations.v4.basic_api.get_page, object_type=object_type, object_id=object_id, to_object_type="notes", ) note_ids = [assoc.to_object_id for assoc in associations_iter] associated_notes = [] for note_id in note_ids: try: # Notes are engagements in HubSpot, use the engagements API note = self._call_hubspot( api_client.crm.objects.notes.basic_api.get_by_id, note_id=note_id, properties=[ "hs_note_body", "hs_timestamp", "hs_created_by", "hubspot_owner_id", ], ) associated_notes.append(note.to_dict()) except Exception as e: logger.warning(f"Failed to fetch note {note_id}: {e}") return associated_notes except Exception as e: logger.warning(f"Failed to get notes for {object_type} {object_id}: {e}") return [] def _create_object_section( self, obj: dict[str, Any], object_type: str ) -> TextSection: """Create a TextSection for an associated object""" obj_id = obj.get("id", "") properties = obj.get("properties", {}) if object_type == "contacts": name_parts = [] if properties.get("firstname"): name_parts.append(properties["firstname"]) if properties.get("lastname"): name_parts.append(properties["lastname"]) if name_parts: name = " ".join(name_parts) elif properties.get("email"): # Use email as fallback if no first/last name name = properties["email"] else: name = "Unknown Contact" content_parts = [f"Contact: {name}"] if properties.get("email"): content_parts.append(f"Email: {properties['email']}") if properties.get("company"): 
content_parts.append(f"Company: {properties['company']}") if properties.get("jobtitle"): content_parts.append(f"Job Title: {properties['jobtitle']}") elif object_type == "companies": name = properties.get("name", "Unknown Company") content_parts = [f"Company: {name}"] if properties.get("domain"): content_parts.append(f"Domain: {properties['domain']}") if properties.get("industry"): content_parts.append(f"Industry: {properties['industry']}") if properties.get("city") and properties.get("state"): content_parts.append( f"Location: {properties['city']}, {properties['state']}" ) elif object_type == "deals": name = properties.get("dealname", "Unknown Deal") content_parts = [f"Deal: {name}"] if properties.get("amount"): content_parts.append(f"Amount: ${properties['amount']}") if properties.get("dealstage"): content_parts.append(f"Stage: {properties['dealstage']}") if properties.get("closedate"): content_parts.append(f"Close Date: {properties['closedate']}") if properties.get("pipeline"): content_parts.append(f"Pipeline: {properties['pipeline']}") elif object_type == "tickets": name = properties.get("subject", "Unknown Ticket") content_parts = [f"Ticket: {name}"] if properties.get("content"): content_parts.append(f"Content: {properties['content']}") if properties.get("hs_ticket_priority"): content_parts.append(f"Priority: {properties['hs_ticket_priority']}") elif object_type == "notes": # Notes have a body property that contains the note content body = properties.get("hs_note_body", "") timestamp = properties.get("hs_timestamp", "") # Clean HTML content to get raw text clean_body = self._clean_html_content(body) # Use full content, not truncated content_parts = [f"Note: {clean_body}"] if timestamp: content_parts.append(f"Created: {timestamp}") else: content_parts = [f"{object_type.capitalize()}: {obj_id}"] content = "\n".join(content_parts) link = self._get_object_url(object_type, obj_id) return TextSection(link=link, text=content) def _process_tickets( self, start: 
datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: api_client = HubSpot(access_token=self.access_token) tickets_iter = self._paginated_results( api_client.crm.tickets.basic_api.get_page, properties=[ "subject", "content", "hs_ticket_priority", "createdate", "hs_lastmodifieddate", ], associations=["contacts", "companies", "deals"], ) doc_batch: list[Document | HierarchyNode] = [] for ticket in tickets_iter: updated_at = ticket.updated_at.replace(tzinfo=None) if start is not None and updated_at < start.replace(tzinfo=None): continue if end is not None and updated_at > end.replace(tzinfo=None): continue title = ticket.properties.get("subject") or f"Ticket {ticket.id}" link = self._get_object_url("tickets", ticket.id) content_text = ticket.properties.get("content") or "" # Main ticket section sections = [TextSection(link=link, text=content_text)] # Metadata with parent object IDs metadata: dict[str, str | list[str]] = { "object_type": "ticket", } if ticket.properties.get("hs_ticket_priority"): metadata["priority"] = ticket.properties["hs_ticket_priority"] # Add associated objects as sections associated_contact_ids = [] associated_company_ids = [] associated_deal_ids = [] # Get associated contacts associated_contacts = self._get_associated_objects( api_client, ticket.id, "tickets", "contacts" ) for contact in associated_contacts: sections.append(self._create_object_section(contact, "contacts")) associated_contact_ids.append(contact["id"]) # Get associated companies associated_companies = self._get_associated_objects( api_client, ticket.id, "tickets", "companies" ) for company in associated_companies: sections.append(self._create_object_section(company, "companies")) associated_company_ids.append(company["id"]) # Get associated deals associated_deals = self._get_associated_objects( api_client, ticket.id, "tickets", "deals" ) for deal in associated_deals: sections.append(self._create_object_section(deal, "deals")) 
associated_deal_ids.append(deal["id"]) # Get associated notes associated_notes = self._get_associated_notes( api_client, ticket.id, "tickets" ) for note in associated_notes: sections.append(self._create_object_section(note, "notes")) # Add association IDs to metadata if associated_contact_ids: metadata["associated_contact_ids"] = associated_contact_ids if associated_company_ids: metadata["associated_company_ids"] = associated_company_ids if associated_deal_ids: metadata["associated_deal_ids"] = associated_deal_ids doc_batch.append( Document( id=f"hubspot_ticket_{ticket.id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.HUBSPOT, semantic_identifier=title, doc_updated_at=ticket.updated_at.replace(tzinfo=timezone.utc), metadata=metadata, doc_metadata={ "hierarchy": { "source_path": ["Tickets"], "object_type": "ticket", "object_id": ticket.id, } }, ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def _process_companies( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: api_client = HubSpot(access_token=self.access_token) companies_iter = self._paginated_results( api_client.crm.companies.basic_api.get_page, properties=[ "name", "domain", "industry", "city", "state", "description", "createdate", "hs_lastmodifieddate", ], associations=["contacts", "deals", "tickets"], ) doc_batch: list[Document | HierarchyNode] = [] for company in companies_iter: updated_at = company.updated_at.replace(tzinfo=None) if start is not None and updated_at < start.replace(tzinfo=None): continue if end is not None and updated_at > end.replace(tzinfo=None): continue title = company.properties.get("name") or f"Company {company.id}" link = self._get_object_url("companies", company.id) # Build main content content_parts = [f"Company: {title}"] if company.properties.get("domain"): content_parts.append(f"Domain: {company.properties['domain']}") if 
company.properties.get("industry"): content_parts.append(f"Industry: {company.properties['industry']}") if company.properties.get("city") and company.properties.get("state"): content_parts.append( f"Location: {company.properties['city']}, {company.properties['state']}" ) if company.properties.get("description"): content_parts.append( f"Description: {company.properties['description']}" ) content_text = "\n".join(content_parts) # Main company section sections = [TextSection(link=link, text=content_text)] # Metadata with parent object IDs metadata: dict[str, str | list[str]] = { "company_id": company.id, "object_type": "company", } if company.properties.get("industry"): metadata["industry"] = company.properties["industry"] if company.properties.get("domain"): metadata["domain"] = company.properties["domain"] # Add associated objects as sections associated_contact_ids = [] associated_deal_ids = [] associated_ticket_ids = [] # Get associated contacts associated_contacts = self._get_associated_objects( api_client, company.id, "companies", "contacts" ) for contact in associated_contacts: sections.append(self._create_object_section(contact, "contacts")) associated_contact_ids.append(contact["id"]) # Get associated deals associated_deals = self._get_associated_objects( api_client, company.id, "companies", "deals" ) for deal in associated_deals: sections.append(self._create_object_section(deal, "deals")) associated_deal_ids.append(deal["id"]) # Get associated tickets associated_tickets = self._get_associated_objects( api_client, company.id, "companies", "tickets" ) for ticket in associated_tickets: sections.append(self._create_object_section(ticket, "tickets")) associated_ticket_ids.append(ticket["id"]) # Get associated notes associated_notes = self._get_associated_notes( api_client, company.id, "companies" ) for note in associated_notes: sections.append(self._create_object_section(note, "notes")) # Add association IDs to metadata if associated_contact_ids: 
metadata["associated_contact_ids"] = associated_contact_ids if associated_deal_ids: metadata["associated_deal_ids"] = associated_deal_ids if associated_ticket_ids: metadata["associated_ticket_ids"] = associated_ticket_ids doc_batch.append( Document( id=f"hubspot_company_{company.id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.HUBSPOT, semantic_identifier=title, doc_updated_at=company.updated_at.replace(tzinfo=timezone.utc), metadata=metadata, doc_metadata={ "hierarchy": { "source_path": ["Companies"], "object_type": "company", "object_id": company.id, } }, ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def _process_deals( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: api_client = HubSpot(access_token=self.access_token) deals_iter = self._paginated_results( api_client.crm.deals.basic_api.get_page, properties=[ "dealname", "amount", "dealstage", "closedate", "pipeline", "description", "createdate", "hs_lastmodifieddate", ], associations=["contacts", "companies", "tickets"], ) doc_batch: list[Document | HierarchyNode] = [] for deal in deals_iter: updated_at = deal.updated_at.replace(tzinfo=None) if start is not None and updated_at < start.replace(tzinfo=None): continue if end is not None and updated_at > end.replace(tzinfo=None): continue title = deal.properties.get("dealname") or f"Deal {deal.id}" link = self._get_object_url("deals", deal.id) # Build main content content_parts = [f"Deal: {title}"] if deal.properties.get("amount"): content_parts.append(f"Amount: ${deal.properties['amount']}") if deal.properties.get("dealstage"): content_parts.append(f"Stage: {deal.properties['dealstage']}") if deal.properties.get("closedate"): content_parts.append(f"Close Date: {deal.properties['closedate']}") if deal.properties.get("pipeline"): content_parts.append(f"Pipeline: {deal.properties['pipeline']}") if 
deal.properties.get("description"): content_parts.append(f"Description: {deal.properties['description']}") content_text = "\n".join(content_parts) # Main deal section sections = [TextSection(link=link, text=content_text)] # Metadata with parent object IDs metadata: dict[str, str | list[str]] = { "deal_id": deal.id, "object_type": "deal", } if deal.properties.get("dealstage"): metadata["deal_stage"] = deal.properties["dealstage"] if deal.properties.get("pipeline"): metadata["pipeline"] = deal.properties["pipeline"] if deal.properties.get("amount"): metadata["amount"] = deal.properties["amount"] # Add associated objects as sections associated_contact_ids = [] associated_company_ids = [] associated_ticket_ids = [] # Get associated contacts associated_contacts = self._get_associated_objects( api_client, deal.id, "deals", "contacts" ) for contact in associated_contacts: sections.append(self._create_object_section(contact, "contacts")) associated_contact_ids.append(contact["id"]) # Get associated companies associated_companies = self._get_associated_objects( api_client, deal.id, "deals", "companies" ) for company in associated_companies: sections.append(self._create_object_section(company, "companies")) associated_company_ids.append(company["id"]) # Get associated tickets associated_tickets = self._get_associated_objects( api_client, deal.id, "deals", "tickets" ) for ticket in associated_tickets: sections.append(self._create_object_section(ticket, "tickets")) associated_ticket_ids.append(ticket["id"]) # Get associated notes associated_notes = self._get_associated_notes(api_client, deal.id, "deals") for note in associated_notes: sections.append(self._create_object_section(note, "notes")) # Add association IDs to metadata if associated_contact_ids: metadata["associated_contact_ids"] = associated_contact_ids if associated_company_ids: metadata["associated_company_ids"] = associated_company_ids if associated_ticket_ids: metadata["associated_ticket_ids"] = 
associated_ticket_ids doc_batch.append( Document( id=f"hubspot_deal_{deal.id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.HUBSPOT, semantic_identifier=title, doc_updated_at=deal.updated_at.replace(tzinfo=timezone.utc), metadata=metadata, doc_metadata={ "hierarchy": { "source_path": ["Deals"], "object_type": "deal", "object_id": deal.id, } }, ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def _process_contacts( self, start: datetime | None = None, end: datetime | None = None ) -> GenerateDocumentsOutput: api_client = HubSpot(access_token=self.access_token) contacts_iter = self._paginated_results( api_client.crm.contacts.basic_api.get_page, properties=[ "firstname", "lastname", "email", "company", "jobtitle", "phone", "city", "state", "createdate", "lastmodifieddate", ], associations=["companies", "deals", "tickets"], ) doc_batch: list[Document | HierarchyNode] = [] for contact in contacts_iter: updated_at = contact.updated_at.replace(tzinfo=None) if start is not None and updated_at < start.replace(tzinfo=None): continue if end is not None and updated_at > end.replace(tzinfo=None): continue # Build contact name name_parts = [] if contact.properties.get("firstname"): name_parts.append(contact.properties["firstname"]) if contact.properties.get("lastname"): name_parts.append(contact.properties["lastname"]) if name_parts: title = " ".join(name_parts) elif contact.properties.get("email"): # Use email as fallback if no first/last name title = contact.properties["email"] else: title = f"Contact {contact.id}" link = self._get_object_url("contacts", contact.id) # Build main content content_parts = [f"Contact: {title}"] if contact.properties.get("email"): content_parts.append(f"Email: {contact.properties['email']}") if contact.properties.get("company"): content_parts.append(f"Company: {contact.properties['company']}") if contact.properties.get("jobtitle"): content_parts.append(f"Job 
Title: {contact.properties['jobtitle']}") if contact.properties.get("phone"): content_parts.append(f"Phone: {contact.properties['phone']}") if contact.properties.get("city") and contact.properties.get("state"): content_parts.append( f"Location: {contact.properties['city']}, {contact.properties['state']}" ) content_text = "\n".join(content_parts) # Main contact section sections = [TextSection(link=link, text=content_text)] # Metadata with parent object IDs metadata: dict[str, str | list[str]] = { "contact_id": contact.id, "object_type": "contact", } if contact.properties.get("email"): metadata["email"] = contact.properties["email"] if contact.properties.get("company"): metadata["company"] = contact.properties["company"] if contact.properties.get("jobtitle"): metadata["job_title"] = contact.properties["jobtitle"] # Add associated objects as sections associated_company_ids = [] associated_deal_ids = [] associated_ticket_ids = [] # Get associated companies associated_companies = self._get_associated_objects( api_client, contact.id, "contacts", "companies" ) for company in associated_companies: sections.append(self._create_object_section(company, "companies")) associated_company_ids.append(company["id"]) # Get associated deals associated_deals = self._get_associated_objects( api_client, contact.id, "contacts", "deals" ) for deal in associated_deals: sections.append(self._create_object_section(deal, "deals")) associated_deal_ids.append(deal["id"]) # Get associated tickets associated_tickets = self._get_associated_objects( api_client, contact.id, "contacts", "tickets" ) for ticket in associated_tickets: sections.append(self._create_object_section(ticket, "tickets")) associated_ticket_ids.append(ticket["id"]) # Get associated notes associated_notes = self._get_associated_notes( api_client, contact.id, "contacts" ) for note in associated_notes: sections.append(self._create_object_section(note, "notes")) # Add association IDs to metadata if associated_company_ids: 
metadata["associated_company_ids"] = associated_company_ids if associated_deal_ids: metadata["associated_deal_ids"] = associated_deal_ids if associated_ticket_ids: metadata["associated_ticket_ids"] = associated_ticket_ids doc_batch.append( Document( id=f"hubspot_contact_{contact.id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.HUBSPOT, semantic_identifier=title, doc_updated_at=contact.updated_at.replace(tzinfo=timezone.utc), metadata=metadata, doc_metadata={ "hierarchy": { "source_path": ["Contacts"], "object_type": "contact", "object_id": contact.id, } }, ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: """Load all HubSpot objects (tickets, companies, deals, contacts)""" # Process each object type based on configuration if "tickets" in self.object_types: yield from self._process_tickets() if "companies" in self.object_types: yield from self._process_companies() if "deals" in self.object_types: yield from self._process_deals() if "contacts" in self.object_types: yield from self._process_contacts() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_datetime = datetime.fromtimestamp(start, tz=timezone.utc) end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) # Process each object type with time filtering based on configuration if "tickets" in self.object_types: yield from self._process_tickets(start_datetime, end_datetime) if "companies" in self.object_types: yield from self._process_companies(start_datetime, end_datetime) if "deals" in self.object_types: yield from self._process_deals(start_datetime, end_datetime) if "contacts" in self.object_types: yield from self._process_contacts(start_datetime, end_datetime) if __name__ == "__main__": import os connector = HubSpotConnector() connector.load_credentials( {"hubspot_access_token": 
os.environ["HUBSPOT_ACCESS_TOKEN"]} ) # Run the first example document_batches = connector.load_from_state() first_batch = next(document_batches) for doc in first_batch: print(doc.model_dump_json(indent=2)) ================================================ FILE: backend/onyx/connectors/hubspot/rate_limit.py ================================================ from __future__ import annotations import time from collections.abc import Callable from typing import Any from typing import TypeVar from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( RateLimitTriedTooManyTimesError, ) from onyx.utils.logger import setup_logger logger = setup_logger() T = TypeVar("T") # HubSpot exposes a ten second rolling window (x-hubspot-ratelimit-interval-milliseconds) # with a maximum of 190 requests, and a per-second limit of 19 requests. _HUBSPOT_TEN_SECOND_LIMIT = 190 _HUBSPOT_TEN_SECOND_PERIOD = 10 # seconds _HUBSPOT_SECONDLY_LIMIT = 19 _HUBSPOT_SECONDLY_PERIOD = 1 # second _DEFAULT_SLEEP_SECONDS = 10 _SLEEP_PADDING_SECONDS = 1.0 _MAX_RATE_LIMIT_RETRIES = 5 def _extract_header(headers: Any, key: str) -> str | None: if headers is None: return None getter = getattr(headers, "get", None) if callable(getter): value = getter(key) if value is not None: return value if isinstance(headers, dict): value = headers.get(key) if value is not None: return value return None def is_rate_limit_error(exception: Exception) -> bool: status = getattr(exception, "status", None) if status == 429: return True headers = getattr(exception, "headers", None) if headers is not None: remaining = _extract_header(headers, "x-hubspot-ratelimit-remaining") if remaining == "0": return True secondly_remaining = _extract_header( headers, "x-hubspot-ratelimit-secondly-remaining" ) if secondly_remaining == "0": return True message = str(exception) return "RATE_LIMIT" in message or "Too Many Requests" in message def 
get_rate_limit_retry_delay_seconds(exception: Exception) -> float: headers = getattr(exception, "headers", None) retry_after = _extract_header(headers, "Retry-After") if retry_after: try: return float(retry_after) + _SLEEP_PADDING_SECONDS except ValueError: logger.debug( "Failed to parse Retry-After header '%s' as float", retry_after ) interval_ms = _extract_header(headers, "x-hubspot-ratelimit-interval-milliseconds") if interval_ms: try: return float(interval_ms) / 1000.0 + _SLEEP_PADDING_SECONDS except ValueError: logger.debug( "Failed to parse x-hubspot-ratelimit-interval-milliseconds '%s' as float", interval_ms, ) secondly_limit = _extract_header(headers, "x-hubspot-ratelimit-secondly") if secondly_limit: try: per_second = max(float(secondly_limit), 1.0) return (1.0 / per_second) + _SLEEP_PADDING_SECONDS except ValueError: logger.debug( "Failed to parse x-hubspot-ratelimit-secondly '%s' as float", secondly_limit, ) return _DEFAULT_SLEEP_SECONDS + _SLEEP_PADDING_SECONDS class HubSpotRateLimiter: def __init__( self, *, ten_second_limit: int = _HUBSPOT_TEN_SECOND_LIMIT, ten_second_period: int = _HUBSPOT_TEN_SECOND_PERIOD, secondly_limit: int = _HUBSPOT_SECONDLY_LIMIT, secondly_period: int = _HUBSPOT_SECONDLY_PERIOD, max_retries: int = _MAX_RATE_LIMIT_RETRIES, ) -> None: self._max_retries = max_retries @rate_limit_builder(max_calls=secondly_limit, period=secondly_period) @rate_limit_builder(max_calls=ten_second_limit, period=ten_second_period) def _execute(callable_: Callable[[], T]) -> T: return callable_() self._execute = _execute def call(self, func: Callable[..., T], *args: Any, **kwargs: Any) -> T: attempts = 0 while True: try: return self._execute(lambda: func(*args, **kwargs)) except Exception as exc: # pylint: disable=broad-except if not is_rate_limit_error(exc): raise attempts += 1 if attempts > self._max_retries: raise RateLimitTriedTooManyTimesError( "Exceeded configured HubSpot rate limit retries" ) from exc wait_time = 
get_rate_limit_retry_delay_seconds(exc) logger.notice( "HubSpot rate limit reached. Sleeping %.2f seconds before retrying.", wait_time, ) time.sleep(wait_time) ================================================ FILE: backend/onyx/connectors/imap/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/imap/connector.py ================================================ import copy import email import imaplib import os import re from datetime import datetime from datetime import timezone from email.message import Message from email.utils import parseaddr from enum import Enum from typing import Any from typing import cast import bs4 from pydantic import BaseModel from onyx.access.models import ExternalAccess from onyx.configs.constants import DocumentSource from onyx.connectors.imap.models import EmailHeaders from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import CredentialsConnector from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import Document from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() _DEFAULT_IMAP_PORT_NUMBER = int(os.environ.get("IMAP_PORT", 993)) _IMAP_OKAY_STATUS = "OK" _PAGE_SIZE = 100 _USERNAME_KEY = "imap_username" _PASSWORD_KEY = "imap_password" class CurrentMailbox(BaseModel): mailbox: str todo_email_ids: list[str] # An email account has a list of mailboxes. # Each mailbox has a list of email-ids inside of it. # # Usage: # To use this checkpointer, first fetch all the mailboxes. # Then, pop a mailbox and fetch all of its email-ids.
# Then, pop each email-id and fetch its content (and parse it, etc.). # When you have popped all email-ids for this mailbox, pop the next mailbox and repeat the above process until you're done. # # For initial checkpointing, set both fields to `None`. class ImapCheckpoint(ConnectorCheckpoint): todo_mailboxes: list[str] | None = None current_mailbox: CurrentMailbox | None = None class LoginState(str, Enum): LoggedIn = "logged_in" LoggedOut = "logged_out" class ImapConnector( CredentialsConnector, CheckpointedConnectorWithPermSync[ImapCheckpoint], ): def __init__( self, host: str, port: int = _DEFAULT_IMAP_PORT_NUMBER, mailboxes: list[str] | None = None, ) -> None: self._host = host self._port = port self._mailboxes = mailboxes self._credentials: dict[str, Any] | None = None @property def credentials(self) -> dict[str, Any]: if not self._credentials: raise RuntimeError( "Credentials have not been initialized; call `set_credentials_provider` first" ) return self._credentials def _get_mail_client(self) -> imaplib.IMAP4_SSL: """ Returns a new `imaplib.IMAP4_SSL` instance. The `imaplib.IMAP4_SSL` object is supposed to be an "ephemeral" object; once you log in and then log out, you cannot log back in with the same instance. I.e., the following will fail: ```py mail_client.login(..) mail_client.logout(); mail_client.login(..) ``` Therefore, you need a fresh instance for each IMAP session. This function gives one to you. # Notes This function will throw an error if the credentials have not yet been set.
""" def get_or_raise(name: str) -> str: value = self.credentials.get(name) if not value: raise RuntimeError(f"Credential item {name=} was not found") if not isinstance(value, str): raise RuntimeError( f"Credential item {name=} must be of type str, instead received {type(value)=}" ) return value username = get_or_raise(_USERNAME_KEY) password = get_or_raise(_PASSWORD_KEY) mail_client = imaplib.IMAP4_SSL(host=self._host, port=self._port) status, _data = mail_client.login(user=username, password=password) if status != _IMAP_OKAY_STATUS: raise RuntimeError(f"Failed to log into IMAP server; {status=}") return mail_client def _load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: ImapCheckpoint, include_perm_sync: bool, ) -> CheckpointOutput[ImapCheckpoint]: checkpoint = cast(ImapCheckpoint, copy.deepcopy(checkpoint)) checkpoint.has_more = True mail_client = self._get_mail_client() if checkpoint.todo_mailboxes is None: # This is the dummy checkpoint. # Fill it with mailboxes first.
if self._mailboxes: checkpoint.todo_mailboxes = _sanitize_mailbox_names(self._mailboxes) else: fetched_mailboxes = _fetch_all_mailboxes_for_email_account( mail_client=mail_client ) if not fetched_mailboxes: raise RuntimeError( "Failed to find any mailboxes for this email account" ) checkpoint.todo_mailboxes = _sanitize_mailbox_names(fetched_mailboxes) return checkpoint if ( not checkpoint.current_mailbox or not checkpoint.current_mailbox.todo_email_ids ): if not checkpoint.todo_mailboxes: checkpoint.has_more = False return checkpoint mailbox = checkpoint.todo_mailboxes.pop() email_ids = _fetch_email_ids_in_mailbox( mail_client=mail_client, mailbox=mailbox, start=start, end=end, ) checkpoint.current_mailbox = CurrentMailbox( mailbox=mailbox, todo_email_ids=email_ids, ) _select_mailbox( mail_client=mail_client, mailbox=checkpoint.current_mailbox.mailbox ) current_todos = cast( list, copy.deepcopy(checkpoint.current_mailbox.todo_email_ids[:_PAGE_SIZE]) ) checkpoint.current_mailbox.todo_email_ids = ( checkpoint.current_mailbox.todo_email_ids[_PAGE_SIZE:] ) for email_id in current_todos: email_msg = _fetch_email(mail_client=mail_client, email_id=email_id) if not email_msg: logger.warning(f"Failed to fetch message {email_id=}; skipping") continue email_headers = EmailHeaders.from_email_msg(email_msg=email_msg) yield _convert_email_headers_and_body_into_document( email_msg=email_msg, email_headers=email_headers, include_perm_sync=include_perm_sync, ) return checkpoint # impls for BaseConnector def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: raise NotImplementedError("Use `set_credentials_provider` instead") def validate_connector_settings(self) -> None: self._get_mail_client() # impls for CredentialsConnector def set_credentials_provider( self, credentials_provider: CredentialsProviderInterface ) -> None: self._credentials = credentials_provider.get_credentials() # impls for CheckpointedConnector def load_from_checkpoint( self, start:
SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: ImapCheckpoint, ) -> CheckpointOutput[ImapCheckpoint]: return self._load_from_checkpoint( start=start, end=end, checkpoint=checkpoint, include_perm_sync=False ) def build_dummy_checkpoint(self) -> ImapCheckpoint: return ImapCheckpoint(has_more=True) def validate_checkpoint_json(self, checkpoint_json: str) -> ImapCheckpoint: return ImapCheckpoint.model_validate_json(json_data=checkpoint_json) # impls for CheckpointedConnectorWithPermSync def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: ImapCheckpoint, ) -> CheckpointOutput[ImapCheckpoint]: return self._load_from_checkpoint( start=start, end=end, checkpoint=checkpoint, include_perm_sync=True ) def _fetch_all_mailboxes_for_email_account(mail_client: imaplib.IMAP4_SSL) -> list[str]: status, mailboxes_data = mail_client.list(directory="*", pattern="*") if status != _IMAP_OKAY_STATUS: raise RuntimeError(f"Failed to fetch mailboxes; {status=}") mailboxes = [] for mailboxes_raw in mailboxes_data: if isinstance(mailboxes_raw, bytes): mailboxes_str = mailboxes_raw.decode() elif isinstance(mailboxes_raw, str): mailboxes_str = mailboxes_raw else: logger.warning( f"Expected the mailbox data to be of type bytes or str, instead got {type(mailboxes_raw)=} {mailboxes_raw}; skipping" ) continue # The mailbox LIST response output can be found here: # https://www.rfc-editor.org/rfc/rfc3501.html#section-7.2.2 # # The general format is: # `(name-attributes) "hierarchy-delimiter" mailbox-name`, e.g. `(\Noselect) "/" ~/Mail/foo` # # The below regex matches on that pattern; the mailbox-name is its second capture group (`match.group(2)`).
match = re.match(r'\([^)]*\)\s+"([^"]+)"\s+"?(.+?)"?$', mailboxes_str) if not match: logger.warning( f"Invalid mailbox-data formatting structure: {mailboxes_str=}; skipping" ) continue mailbox = match.group(2) mailboxes.append(mailbox) return mailboxes def _select_mailbox(mail_client: imaplib.IMAP4_SSL, mailbox: str) -> None: status, _ids = mail_client.select(mailbox=mailbox, readonly=True) if status != _IMAP_OKAY_STATUS: raise RuntimeError(f"Failed to select {mailbox=}") def _fetch_email_ids_in_mailbox( mail_client: imaplib.IMAP4_SSL, mailbox: str, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, ) -> list[str]: _select_mailbox(mail_client=mail_client, mailbox=mailbox) start_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime("%d-%b-%Y") end_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime("%d-%b-%Y") search_criteria = f'(SINCE "{start_str}" BEFORE "{end_str}")' status, email_ids_byte_array = mail_client.search(None, search_criteria) if status != _IMAP_OKAY_STATUS or not email_ids_byte_array: raise RuntimeError(f"Failed to fetch email ids; {status=}") email_ids: bytes = email_ids_byte_array[0] return [email_id.decode() for email_id in email_ids.split()] def _fetch_email(mail_client: imaplib.IMAP4_SSL, email_id: str) -> Message | None: status, msg_data = mail_client.fetch(message_set=email_id, message_parts="(RFC822)") if status != _IMAP_OKAY_STATUS or not msg_data: return None data = msg_data[0] if not isinstance(data, tuple): raise RuntimeError( f"Message data should be a tuple; instead got a {type(data)=} {data=}" ) _metadata, raw_email = data return email.message_from_bytes(raw_email) def _convert_email_headers_and_body_into_document( email_msg: Message, email_headers: EmailHeaders, include_perm_sync: bool, ) -> Document: sender_name, sender_addr = _parse_singular_addr(raw_header=email_headers.sender) parsed_recipients = ( _parse_addrs(raw_header=email_headers.recipients) if email_headers.recipients else [] ) expert_info_map = {
recipient_addr: BasicExpertInfo( display_name=recipient_name, email=recipient_addr ) for recipient_name, recipient_addr in parsed_recipients } if sender_addr not in expert_info_map: expert_info_map[sender_addr] = BasicExpertInfo( display_name=sender_name, email=sender_addr ) email_body = _parse_email_body(email_msg=email_msg, email_headers=email_headers) primary_owners = list(expert_info_map.values()) external_access = ( ExternalAccess( external_user_emails=set(expert_info_map.keys()), external_user_group_ids=set(), is_public=False, ) if include_perm_sync else None ) return Document( id=email_headers.id, title=email_headers.subject, semantic_identifier=email_headers.subject, metadata={}, source=DocumentSource.IMAP, sections=[TextSection(text=email_body)], primary_owners=primary_owners, external_access=external_access, ) def _parse_email_body( email_msg: Message, email_headers: EmailHeaders, ) -> str: body = None for part in email_msg.walk(): if part.is_multipart(): # Multipart parts are *containers* for other parts, not the actual content itself. # Therefore, we skip until we find the individual parts instead. continue charset = part.get_content_charset() or "utf-8" try: raw_payload = part.get_payload(decode=True) if not isinstance(raw_payload, bytes): logger.warning( "Payload section from email was expected to be an array of bytes, instead got " f"{type(raw_payload)=}, {raw_payload=}" ) continue body = raw_payload.decode(charset) break except (UnicodeDecodeError, LookupError) as e: logger.warning(f"Could not decode part with charset {charset}.
Error: {e}") continue if not body: logger.warning( f"Email with {email_headers.id=} has an empty body; returning an empty string" ) return "" soup = bs4.BeautifulSoup(markup=body, features="html.parser") return " ".join(str_section for str_section in soup.stripped_strings) def _sanitize_mailbox_names(mailboxes: list[str]) -> list[str]: """ Mailboxes with special characters in them must be enclosed by double-quotes, as per the IMAP protocol. Just to be safe, we wrap *all* mailboxes with double-quotes. """ return [f'"{mailbox}"' for mailbox in mailboxes if mailbox] def _parse_addrs(raw_header: str) -> list[tuple[str, str]]: addrs = raw_header.split(",") name_addr_pairs = [parseaddr(addr=addr) for addr in addrs if addr] return [(name, addr) for name, addr in name_addr_pairs if addr] def _parse_singular_addr(raw_header: str) -> tuple[str, str]: addrs = _parse_addrs(raw_header=raw_header) if not addrs: raise RuntimeError( f"Parsing email header resulted in no addresses being found; {raw_header=}" ) elif len(addrs) >= 2: raise RuntimeError( f"Expected a singular address, but instead got multiple; {raw_header=} {addrs=}" ) return addrs[0] if __name__ == "__main__": import time from tests.daily.connectors.utils import load_all_from_connector from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider host = os.environ.get("IMAP_HOST") mailboxes_str = os.environ.get("IMAP_MAILBOXES") username = os.environ.get("IMAP_USERNAME") password = os.environ.get("IMAP_PASSWORD") mailboxes = ( [mailbox.strip() for mailbox in mailboxes_str.split(",")] if mailboxes_str else [] ) if not host: raise RuntimeError("`IMAP_HOST` must be set") imap_connector = ImapConnector( host=host, mailboxes=mailboxes, ) imap_connector.set_credentials_provider( OnyxStaticCredentialsProvider( tenant_id=None, connector_name=DocumentSource.IMAP, credential_json={ _USERNAME_KEY: username, _PASSWORD_KEY: password, }, ) ) for doc in load_all_from_connector( connector=imap_connector, start=0,
end=time.time(), ).documents: print(doc) ================================================ FILE: backend/onyx/connectors/imap/models.py ================================================ import email.header import email.utils from datetime import datetime from email.message import Message from enum import Enum from pydantic import BaseModel class Header(str, Enum): SUBJECT_HEADER = "subject" FROM_HEADER = "from" TO_HEADER = "to" DELIVERED_TO_HEADER = ( "Delivered-To" # Used in mailing lists instead of the "to" header. ) DATE_HEADER = "date" MESSAGE_ID_HEADER = "Message-ID" class EmailHeaders(BaseModel): """ Model for email headers extracted from IMAP messages. """ id: str subject: str sender: str recipients: str | None date: datetime @classmethod def from_email_msg(cls, email_msg: Message) -> "EmailHeaders": def _decode(header: str, default: str | None = None) -> str | None: value = email_msg.get(header, default) if not value: return None decoded_value, encoding = email.header.decode_header(value)[0] if isinstance(decoded_value, bytes): encoding = encoding or "utf-8" return decoded_value.decode(encoding, errors="replace") elif isinstance(decoded_value, str): return decoded_value else: return None def _parse_date(date_str: str | None) -> datetime | None: if not date_str: return None try: return email.utils.parsedate_to_datetime(date_str) except (TypeError, ValueError): return None message_id = _decode(header=Header.MESSAGE_ID_HEADER) # It's possible for the subject line to not exist or be an empty string. subject = _decode(header=Header.SUBJECT_HEADER) or "Unknown Subject" from_ = _decode(header=Header.FROM_HEADER) to = _decode(header=Header.TO_HEADER) if not to: to = _decode(header=Header.DELIVERED_TO_HEADER) date_str = _decode(header=Header.DATE_HEADER) date = _parse_date(date_str=date_str) # If any of the required fields above (everything except `recipients`) is `None`, model validation will fail. # Therefore, no guards (i.e.: `if value
is None: raise RuntimeError(..)`) were written. return cls.model_validate( { "id": message_id, "subject": subject, "sender": from_, "recipients": to, "date": date, } ) ================================================ FILE: backend/onyx/connectors/interfaces.py ================================================ import abc from collections.abc import Generator from collections.abc import Iterator from types import TracebackType from typing import Any from typing import Generic from typing import TypeAlias from typing import TypeVar from pydantic import BaseModel from onyx.configs.constants import DocumentSource from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop SecondsSinceUnixEpoch = float # Output types that can include HierarchyNode alongside Documents/SlimDocuments GenerateDocumentsOutput = Iterator[list[Document | HierarchyNode]] GenerateSlimDocumentOutput = Iterator[list[SlimDocument | HierarchyNode]] CT = TypeVar("CT", bound=ConnectorCheckpoint) class NormalizationResult(BaseModel): """Result of URL normalization attempt. Attributes: normalized_url: The normalized URL string, or None if normalization failed use_default: If True, fall back to default normalizer. If False, return None. 
""" normalized_url: str | None use_default: bool = False class BaseConnector(abc.ABC, Generic[CT]): REDIS_KEY_PREFIX = "da_connector_data:" @abc.abstractmethod def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: raise NotImplementedError @staticmethod def parse_metadata(metadata: dict[str, Any]) -> list[str]: """Parse the metadata for a document/chunk into a string to pass to Generative AI as additional context""" custom_parser_req_msg = ( "Specific metadata parsing required, connector has not implemented it." ) metadata_lines = [] for metadata_key, metadata_value in metadata.items(): if isinstance(metadata_value, str): metadata_lines.append(f"{metadata_key}: {metadata_value}") elif isinstance(metadata_value, list): if not all([isinstance(val, str) for val in metadata_value]): raise RuntimeError(custom_parser_req_msg) metadata_lines.append(f"{metadata_key}: {', '.join(metadata_value)}") else: raise RuntimeError(custom_parser_req_msg) return metadata_lines def validate_connector_settings(self) -> None: """ Override this if your connector needs to validate credentials or settings. Raise an exception if invalid, otherwise do nothing. Default is a no-op (always successful). """ def validate_perm_sync(self) -> None: """ Don't override this; add a function to perm_sync_valid.py in the ee package to do permission sync validation """ validate_connector_settings_fn = fetch_ee_implementation_or_noop( "onyx.connectors.perm_sync_valid", "validate_perm_sync", noop_return_value=None, ) validate_connector_settings_fn(self) def set_allow_images(self, value: bool) -> None: """Implement if the underlying connector wants to skip/allow image downloading based on the application level image analysis setting.""" @classmethod def normalize_url(cls, url: str) -> "NormalizationResult": # noqa: ARG003 """Normalize a URL to match the canonical Document.id format used during ingestion. Connectors that use URLs as document IDs should override this method. 
Returns NormalizationResult with use_default=True if not implemented. """ return NormalizationResult(normalized_url=None, use_default=True) def build_dummy_checkpoint(self) -> CT: # TODO: find a way to make this work without type: ignore return ConnectorCheckpoint(has_more=True) # type: ignore # Large set update or reindex, generally pulling a complete state or from a savestate file class LoadConnector(BaseConnector): @abc.abstractmethod def load_from_state(self) -> GenerateDocumentsOutput: raise NotImplementedError # Small set updates by time class PollConnector(BaseConnector): @abc.abstractmethod def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: raise NotImplementedError # Slim connectors retrieve just the ids of documents class SlimConnector(BaseConnector): @abc.abstractmethod def retrieve_all_slim_docs( self, ) -> GenerateSlimDocumentOutput: raise NotImplementedError # Slim connectors retrieve both the ids AND # permission syncing information for connected documents class SlimConnectorWithPermSync(BaseConnector): @abc.abstractmethod def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: raise NotImplementedError class OAuthConnector(BaseConnector): class AdditionalOauthKwargs(BaseModel): # if overridden, all fields should be str type pass @classmethod @abc.abstractmethod def oauth_id(cls) -> DocumentSource: raise NotImplementedError @classmethod @abc.abstractmethod def oauth_authorization_url( cls, base_domain: str, state: str, additional_kwargs: dict[str, str], ) -> str: raise NotImplementedError @classmethod @abc.abstractmethod def oauth_code_to_token( cls, base_domain: str, code: str, additional_kwargs: dict[str, str], ) -> dict[str, Any]: raise NotImplementedError T = TypeVar("T", bound="CredentialsProviderInterface") class 
CredentialsProviderInterface(abc.ABC, Generic[T]): @abc.abstractmethod def __enter__(self) -> T: raise NotImplementedError @abc.abstractmethod def __exit__( self, exc_type: type[BaseException] | None, exc_value: BaseException | None, traceback: TracebackType | None, ) -> None: raise NotImplementedError @abc.abstractmethod def get_tenant_id(self) -> str | None: raise NotImplementedError @abc.abstractmethod def get_provider_key(self) -> str: """a unique key that the connector can use to lock around a credential that might be used simultaneously. Will typically be the credential id, but can also just be something random in cases when there is nothing to lock (aka static credentials) """ raise NotImplementedError @abc.abstractmethod def get_credentials(self) -> dict[str, Any]: raise NotImplementedError @abc.abstractmethod def set_credentials(self, credential_json: dict[str, Any]) -> None: raise NotImplementedError @abc.abstractmethod def is_dynamic(self) -> bool: """If dynamic, the credentials may change during usage ... meaning the client needs to use the locking features of the credentials provider to operate correctly. If static, the client can simply reference the credentials once and use them through the entire indexing run. """ raise NotImplementedError class CredentialsConnector(BaseConnector): """Implement this if the connector needs to be able to read and write credentials on the fly. 
Typically used with shared credentials/tokens that might be renewed at any time.""" @abc.abstractmethod def set_credentials_provider( self, credentials_provider: CredentialsProviderInterface ) -> None: raise NotImplementedError # Event driven class EventConnector(BaseConnector): @abc.abstractmethod def handle_event(self, event: Any) -> GenerateDocumentsOutput: raise NotImplementedError CheckpointOutput: TypeAlias = Generator[ Document | HierarchyNode | ConnectorFailure, None, CT ] HierarchyOutput: TypeAlias = Generator[HierarchyNode, None, None] class CheckpointedConnector(BaseConnector[CT]): @abc.abstractmethod def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: CT, ) -> CheckpointOutput[CT]: """Yields back documents, hierarchy nodes, or failures. Final return is the new checkpoint. The final return value can be accessed via either: ``` generator = connector.load_from_checkpoint(start, end, checkpoint) while True: try: document_or_failure = next(generator) except StopIteration as e: checkpoint = e.value # Extracting the return value print(checkpoint) break print(document_or_failure) ``` OR, from inside another generator: ``` checkpoint = yield from connector.load_from_checkpoint(start, end, checkpoint) ``` """ raise NotImplementedError @abc.abstractmethod def build_dummy_checkpoint(self) -> CT: raise NotImplementedError @abc.abstractmethod def validate_checkpoint_json(self, checkpoint_json: str) -> CT: """Validate the checkpoint json and return the checkpoint object""" raise NotImplementedError class CheckpointedConnectorWithPermSync(CheckpointedConnector[CT]): @abc.abstractmethod def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: CT, ) -> CheckpointOutput[CT]: raise NotImplementedError class HierarchyConnector(BaseConnector): @abc.abstractmethod def load_hierarchy( self, start: SecondsSinceUnixEpoch, # may be unused if the connector must load the full hierarchy each time end: SecondsSinceUnixEpoch, ) -> HierarchyOutput: raise NotImplementedError
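The `load_from_checkpoint` contract above (yield documents, then `return` the next checkpoint) is easy to misuse, because a plain `for` loop silently swallows the `StopIteration` that carries the generator's return value. The following is a self-contained toy sketch, not Onyx code: `ToyCheckpoint`, `MAILBOXES`, `PAGE_SIZE`, and `run_to_completion` are hypothetical names, and plain id strings stand in for `Document` objects. It loosely mirrors the shape of `ImapConnector._load_from_checkpoint`: the first (dummy) checkpoint only discovers mailboxes, each later call pops a mailbox once the current one is exhausted and yields at most one page of ids, and `has_more=False` ends the outer loop.

```python
import copy
from collections.abc import Generator
from dataclasses import dataclass, field
from typing import Optional

# Hypothetical page size and mailbox contents; a real connector would query IMAP.
PAGE_SIZE = 2
MAILBOXES: dict[str, list[str]] = {"INBOX": ["1", "2", "3"], "Sent": ["4"]}


@dataclass
class ToyCheckpoint:
    """Hypothetical stand-in for a ConnectorCheckpoint subclass such as ImapCheckpoint."""

    has_more: bool = True
    todo_mailboxes: Optional[list[str]] = None
    todo_ids: list[str] = field(default_factory=list)


def load_from_checkpoint(
    checkpoint: ToyCheckpoint,
) -> Generator[str, None, ToyCheckpoint]:
    """Yield at most one page of ids, then *return* the updated checkpoint."""
    checkpoint = copy.deepcopy(checkpoint)  # never mutate the caller's checkpoint
    if checkpoint.todo_mailboxes is None:
        # Dummy checkpoint: this pass only discovers the mailboxes.
        checkpoint.todo_mailboxes = sorted(MAILBOXES)
        return checkpoint
    if not checkpoint.todo_ids:
        if not checkpoint.todo_mailboxes:
            checkpoint.has_more = False  # no ids and no mailboxes left
            return checkpoint
        mailbox = checkpoint.todo_mailboxes.pop()
        checkpoint.todo_ids = list(MAILBOXES[mailbox])
    page = checkpoint.todo_ids[:PAGE_SIZE]
    checkpoint.todo_ids = checkpoint.todo_ids[PAGE_SIZE:]
    for doc_id in page:
        yield doc_id
    return checkpoint


def run_to_completion() -> list[str]:
    """Drive the generator repeatedly, capturing each returned checkpoint."""
    checkpoint = ToyCheckpoint()
    docs: list[str] = []
    while checkpoint.has_more:
        gen = load_from_checkpoint(checkpoint)
        while True:
            try:
                docs.append(next(gen))
            except StopIteration as stop:
                checkpoint = stop.value  # the generator's `return` value
                break
    return docs


print(run_to_completion())  # ['4', '1', '2', '3']
```

Deep-copying the incoming checkpoint, as the IMAP connector also does, means a pass never mutates the checkpoint it was handed, which makes it safe to retry a failed pass from the previous checkpoint.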
================================================ FILE: backend/onyx/connectors/jira/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/jira/access.py ================================================ """ Permissioning / AccessControl logic for JIRA Projects + Issues. """ from collections.abc import Callable from typing import cast from jira import JIRA from onyx.access.models import ExternalAccess from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import global_version def get_project_permissions( jira_client: JIRA, jira_project: str, add_prefix: bool = False, ) -> ExternalAccess | None: """ Fetch the project + issue level permissions / access-control. This functionality requires Enterprise Edition. Args: jira_client: The JIRA client instance. jira_project: The JIRA project string. add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path where upsert_document_external_perms handles prefixing). Returns: ExternalAccess object for the page. None if EE is not enabled or no restrictions found. 
""" # Check if EE is enabled if not global_version.is_ee_version(): return None ee_get_project_permissions = cast( Callable[ [JIRA, str, bool], ExternalAccess | None, ], fetch_versioned_implementation( "onyx.external_permissions.jira.page_access", "get_project_permissions" ), ) return ee_get_project_permissions( jira_client, jira_project, add_prefix, ) ================================================ FILE: backend/onyx/connectors/jira/connector.py ================================================ import copy import json import os from collections.abc import Callable from collections.abc import Generator from collections.abc import Iterable from collections.abc import Iterator from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any import requests from jira import JIRA from jira.exceptions import JIRAError from jira.resources import Issue from more_itertools import chunked from typing_extensions import override from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE from onyx.configs.app_configs import JIRA_SLIM_PAGE_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( is_atlassian_date_error, ) from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from 
onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.jira.access import get_project_permissions from onyx.connectors.jira.utils import best_effort_basic_expert_info from onyx.connectors.jira.utils import best_effort_get_field_from_issue from onyx.connectors.jira.utils import build_jira_client from onyx.connectors.jira.utils import build_jira_url from onyx.connectors.jira.utils import extract_text_from_adf from onyx.connectors.jira.utils import get_comment_strs from onyx.connectors.jira.utils import JIRA_CLOUD_API_VERSION from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.db.enums import HierarchyNodeType from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() ONE_HOUR = 3600 _MAX_RESULTS_FETCH_IDS = 5000 # 5000 _JIRA_FULL_PAGE_SIZE = 50 # Constants for Jira field names _FIELD_REPORTER = "reporter" _FIELD_ASSIGNEE = "assignee" _FIELD_PRIORITY = "priority" _FIELD_STATUS = "status" _FIELD_RESOLUTION = "resolution" _FIELD_LABELS = "labels" _FIELD_KEY = "key" _FIELD_CREATED = "created" _FIELD_DUEDATE = "duedate" _FIELD_ISSUETYPE = "issuetype" _FIELD_PARENT = "parent" _FIELD_ASSIGNEE_EMAIL = "assignee_email" _FIELD_REPORTER_EMAIL = "reporter_email" _FIELD_PROJECT = "project" _FIELD_PROJECT_NAME = "project_name" _FIELD_UPDATED = "updated" _FIELD_RESOLUTION_DATE = "resolutiondate" _FIELD_RESOLUTION_DATE_KEY = "resolution_date" def _is_cloud_client(jira_client: JIRA) -> bool: return jira_client._options["rest_api_version"] == JIRA_CLOUD_API_VERSION def _perform_jql_search( jira_client: JIRA, jql: 
str, start: int, max_results: int, fields: str | None = None, all_issue_ids: list[list[str]] | None = None, checkpoint_callback: ( Callable[[Iterator[list[str]], str | None], None] | None ) = None, nextPageToken: str | None = None, ids_done: bool = False, ) -> Iterable[Issue]: """ The caller should expect a) this function returns an iterable of issues of length 0 < len(issues) <= max_results. - caveat; if all_issue_ids is provided, the iterable will be the size of some sub-list. - the only other exception to this bound is when a recent deployment changed max_results. IF the v3 API is used (i.e. the jira instance is a cloud instance), then the caller should expect: b) this function will call checkpoint_callback ONCE after at least one of the following has happened: - a new batch of ids has been fetched via enhanced search - a batch of issues has been bulk-fetched c) checkpoint_callback is called with the new all_issue_ids and the pageToken of the enhanced search request. We pass in a pageToken of None once we've fetched all the issue ids. Note: nextPageToken is valid for 7 days according to a post from a year ago, so for now we won't add any handling for restarting (just re-index, since there's no easy way to recover from this). """ # it would be preferable to use one approach for both versions, but # v2 doesn't have the bulk fetch api and v3 has fully deprecated the search # api that v2 uses if _is_cloud_client(jira_client): if all_issue_ids is None: raise ValueError("all_issue_ids is required for v3") return _perform_jql_search_v3( jira_client, jql, max_results, all_issue_ids, fields=fields, checkpoint_callback=checkpoint_callback, nextPageToken=nextPageToken, ids_done=ids_done, ) else: return _perform_jql_search_v2(jira_client, jql, start, max_results, fields) def _handle_jira_search_error(e: Exception, jql: str) -> None: """Handle common Jira search errors and raise appropriate exceptions.
Args: e: The exception raised by the Jira API jql: The JQL query that caused the error Raises: ConnectorValidationError: For HTTP 400 errors (invalid JQL or project) CredentialExpiredError: For HTTP 401 errors InsufficientPermissionsError: For HTTP 403 errors Exception: Re-raises the original exception for other error types """ # Extract error information from the exception error_text = "" status_code = None def _format_error_text(error_payload: Any) -> str: error_messages = ( error_payload.get("errorMessages", []) if isinstance(error_payload, dict) else [] ) if error_messages: return ( "; ".join(error_messages) if isinstance(error_messages, list) else str(error_messages) ) return str(error_payload) # Try to get status code and error text from JIRAError or requests response if hasattr(e, "status_code"): status_code = e.status_code raw_text = getattr(e, "text", "") if isinstance(raw_text, str): try: error_text = _format_error_text(json.loads(raw_text)) except Exception: error_text = raw_text else: error_text = str(raw_text) elif hasattr(e, "response") and e.response is not None: status_code = e.response.status_code # Try JSON first, fall back to text try: error_json = e.response.json() error_text = _format_error_text(error_json) except Exception: error_text = e.response.text # Handle specific status codes if status_code == 400: if "does not exist for the field 'project'" in error_text: raise ConnectorValidationError( f"The specified Jira project does not exist or you don't have access to it. JQL query: {jql}. Error: {error_text}" ) raise ConnectorValidationError( f"Invalid JQL query. JQL: {jql}. Error: {error_text}" ) elif status_code == 401: raise CredentialExpiredError( "Jira credentials are expired or invalid (HTTP 401)." ) elif status_code == 403: raise InsufficientPermissionsError( f"Insufficient permissions to execute JQL query. 
JQL: {jql}" ) # Re-raise for other error types raise e def enhanced_search_ids( jira_client: JIRA, jql: str, nextPageToken: str | None = None ) -> tuple[list[str], str | None]: # https://community.atlassian.com/forums/Jira-articles/ # Avoiding-Pitfalls-A-Guide-to-Smooth-Migration-to-Enhanced-JQL/ba-p/2985433 # For cloud, it's recommended that we fetch all ids first then use the bulk fetch API. # The enhanced search isn't currently supported by our python library, so we have to # do this janky thing where we use the session directly. enhanced_search_path = jira_client._get_url("search/jql") params: dict[str, str | int | None] = { "jql": jql, "maxResults": _MAX_RESULTS_FETCH_IDS, "nextPageToken": nextPageToken, "fields": "id", } try: response = jira_client._session.get(enhanced_search_path, params=params) response.raise_for_status() response_json = response.json() except Exception as e: _handle_jira_search_error(e, jql) raise # Explicitly re-raise for type checker, should never reach here return [str(issue["id"]) for issue in response_json["issues"]], response_json.get( "nextPageToken" ) def _bulk_fetch_request( jira_client: JIRA, issue_ids: list[str], fields: str | None ) -> list[dict[str, Any]]: """Raw POST to the bulkfetch endpoint. Returns the list of raw issue dicts.""" bulk_fetch_path = jira_client._get_url("issue/bulkfetch") # Prepare the payload according to Jira API v3 specification payload: dict[str, Any] = {"issueIdsOrKeys": issue_ids} # Only restrict fields if specified, might want to explicitly do this in the future # to avoid reading unnecessary data payload["fields"] = fields.split(",") if fields else ["*all"] resp = jira_client._session.post(bulk_fetch_path, json=payload) return resp.json()["issues"] def bulk_fetch_issues( jira_client: JIRA, issue_ids: list[str], fields: str | None = None ) -> list[Issue]: # TODO(evan): move away from this jira library if they continue to not support # the endpoints we need. 
Using private fields is not ideal, but # is likely fine for now since we pin the library version try: raw_issues = _bulk_fetch_request(jira_client, issue_ids, fields) except requests.exceptions.JSONDecodeError: if len(issue_ids) <= 1: logger.exception( f"Jira bulk-fetch response for issue(s) {issue_ids} could not " f"be decoded as JSON (response too large or truncated)." ) raise mid = len(issue_ids) // 2 logger.warning( f"Jira bulk-fetch JSON decode failed for batch of {len(issue_ids)} issues. " f"Splitting into sub-batches of {mid} and {len(issue_ids) - mid}." ) left = bulk_fetch_issues(jira_client, issue_ids[:mid], fields) right = bulk_fetch_issues(jira_client, issue_ids[mid:], fields) return left + right except Exception as e: logger.error(f"Error fetching issues: {e}") raise return [ Issue(jira_client._options, jira_client._session, raw=issue) for issue in raw_issues ] def _perform_jql_search_v3( jira_client: JIRA, jql: str, max_results: int, all_issue_ids: list[list[str]], fields: str | None = None, checkpoint_callback: ( Callable[[Iterator[list[str]], str | None], None] | None ) = None, nextPageToken: str | None = None, ids_done: bool = False, ) -> Iterable[Issue]: """ The way this works is we get all the issue ids and bulk fetch them in batches. However, for really large deployments we can't do these operations sequentially, as it might take several hours to fetch all the issue ids. So, each run of this function does at least one of: - fetch a batch of issue ids - bulk fetch a batch of issues If all_issue_ids is not None, we use it to bulk fetch issues. """ # with some careful synchronization these steps can be done in parallel, # leaving that out for now to avoid rate limit issues if not ids_done: new_ids, pageToken = enhanced_search_ids(jira_client, jql, nextPageToken) if checkpoint_callback is not None: checkpoint_callback(chunked(new_ids, max_results), pageToken) # bulk fetch issues from ids. 
Note that the above callback MAY mutate all_issue_ids, # but this fetch always just takes the last id batch. if all_issue_ids: yield from bulk_fetch_issues(jira_client, all_issue_ids.pop(), fields) def _perform_jql_search_v2( jira_client: JIRA, jql: str, start: int, max_results: int, fields: str | None = None, ) -> Iterable[Issue]: """ Unfortunately, jira server/data center will forever use the v2 APIs that are now deprecated. """ logger.debug( f"Fetching Jira issues with JQL: {jql}, starting at {start}, max results: {max_results}" ) try: issues = jira_client.search_issues( jql_str=jql, startAt=start, maxResults=max_results, fields=fields, ) except JIRAError as e: _handle_jira_search_error(e, jql) raise # Explicitly re-raise for type checker, should never reach here for issue in issues: if isinstance(issue, Issue): yield issue else: raise RuntimeError(f"Found Jira object not of type Issue: {issue}") def process_jira_issue( jira_base_url: str, issue: Issue, comment_email_blacklist: tuple[str, ...] = (), labels_to_skip: set[str] | None = None, parent_hierarchy_raw_node_id: str | None = None, ) -> Document | None: if labels_to_skip: if any(label in issue.fields.labels for label in labels_to_skip): logger.info( f"Skipping {issue.key} because it has a label to skip. Found " f"labels: {issue.fields.labels}. Labels to skip: {labels_to_skip}." ) return None if isinstance(issue.fields.description, str): description = issue.fields.description else: description = extract_text_from_adf(issue.raw["fields"]["description"]) comments = get_comment_strs( issue=issue, comment_email_blacklist=comment_email_blacklist, ) ticket_content = f"{description}\n" + "\n".join( [f"Comment: {comment}" for comment in comments if comment] ) # Check ticket size if len(ticket_content.encode("utf-8")) > JIRA_CONNECTOR_MAX_TICKET_SIZE: logger.info( f"Skipping {issue.key} because it exceeds the maximum size of {JIRA_CONNECTOR_MAX_TICKET_SIZE} bytes." 
) return None page_url = build_jira_url(jira_base_url, issue.key) metadata_dict: dict[str, str | list[str]] = {} people = set() creator = best_effort_get_field_from_issue(issue, _FIELD_REPORTER) if creator is not None and ( basic_expert_info := best_effort_basic_expert_info(creator) ): people.add(basic_expert_info) metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name() if email := basic_expert_info.get_email(): metadata_dict[_FIELD_REPORTER_EMAIL] = email assignee = best_effort_get_field_from_issue(issue, _FIELD_ASSIGNEE) if assignee is not None and ( basic_expert_info := best_effort_basic_expert_info(assignee) ): people.add(basic_expert_info) metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name() if email := basic_expert_info.get_email(): metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email metadata_dict[_FIELD_KEY] = issue.key if priority := best_effort_get_field_from_issue(issue, _FIELD_PRIORITY): metadata_dict[_FIELD_PRIORITY] = priority.name if status := best_effort_get_field_from_issue(issue, _FIELD_STATUS): metadata_dict[_FIELD_STATUS] = status.name if resolution := best_effort_get_field_from_issue(issue, _FIELD_RESOLUTION): metadata_dict[_FIELD_RESOLUTION] = resolution.name if labels := best_effort_get_field_from_issue(issue, _FIELD_LABELS): metadata_dict[_FIELD_LABELS] = labels if created := best_effort_get_field_from_issue(issue, _FIELD_CREATED): metadata_dict[_FIELD_CREATED] = created if updated := best_effort_get_field_from_issue(issue, _FIELD_UPDATED): metadata_dict[_FIELD_UPDATED] = updated if duedate := best_effort_get_field_from_issue(issue, _FIELD_DUEDATE): metadata_dict[_FIELD_DUEDATE] = duedate if issuetype := best_effort_get_field_from_issue(issue, _FIELD_ISSUETYPE): metadata_dict[_FIELD_ISSUETYPE] = issuetype.name if resolutiondate := best_effort_get_field_from_issue( issue, _FIELD_RESOLUTION_DATE ): metadata_dict[_FIELD_RESOLUTION_DATE_KEY] = resolutiondate parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT) 
if parent is not None: metadata_dict[_FIELD_PARENT] = parent.key project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT) if project is not None: metadata_dict[_FIELD_PROJECT_NAME] = project.name metadata_dict[_FIELD_PROJECT] = project.key else: logger.error(f"Project should exist but does not for {issue.key}") return Document( id=page_url, sections=[TextSection(link=page_url, text=ticket_content)], source=DocumentSource.JIRA, semantic_identifier=f"{issue.key}: {issue.fields.summary}", title=f"{issue.key} {issue.fields.summary}", doc_updated_at=time_str_to_utc(issue.fields.updated), primary_owners=list(people) or None, metadata=metadata_dict, parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id, ) class JiraConnectorCheckpoint(ConnectorCheckpoint): # used for v3 (cloud) endpoint all_issue_ids: list[list[str]] = [] ids_done: bool = False cursor: str | None = None # deprecated # Used for v2 endpoint (server/data center) offset: int | None = None # Track hierarchy nodes we've already yielded to avoid duplicates across restarts seen_hierarchy_node_ids: list[str] = [] class JiraConnector( CheckpointedConnectorWithPermSync[JiraConnectorCheckpoint], SlimConnectorWithPermSync, ): def __init__( self, jira_base_url: str, project_key: str | None = None, comment_email_blacklist: list[str] | None = None, batch_size: int = INDEX_BATCH_SIZE, # if a ticket has one of the labels specified in this list, we will just # skip it. This is generally used to avoid indexing extra sensitive # tickets. labels_to_skip: list[str] = JIRA_CONNECTOR_LABELS_TO_SKIP, # Custom JQL query to filter Jira issues jql_query: str | None = None, scoped_token: bool = False, ) -> None: self.batch_size = batch_size # dealing with scoped tokens is a bit tricky because we need to hit api.atlassian.net # when making jira requests but still want correct links to issues in the UI. # So, the user's base url is stored here, but converted to a scoped url when passed # to the jira client.
self.jira_base = jira_base_url.rstrip("/") # Remove trailing slash if present self.jira_project = project_key self._comment_email_blacklist = comment_email_blacklist or [] self.labels_to_skip = set(labels_to_skip) self.jql_query = jql_query self.scoped_token = scoped_token self._jira_client: JIRA | None = None # Cache project permissions to avoid fetching them repeatedly across runs self._project_permissions_cache: dict[str, Any] = {} @property def comment_email_blacklist(self) -> tuple: return tuple(email.strip() for email in self._comment_email_blacklist) @property def jira_client(self) -> JIRA: if self._jira_client is None: raise ConnectorMissingCredentialError("Jira") return self._jira_client @property def quoted_jira_project(self) -> str: # Quote the project name to handle reserved words if not self.jira_project: return "" return f'"{self.jira_project}"' def _get_project_permissions( self, project_key: str, add_prefix: bool = False ) -> Any: """Get project permissions with caching. Args: project_key: The Jira project key add_prefix: When True, prefix group IDs with source type (for indexing path). When False (default), leave unprefixed (for permission sync path). Returns: The external access permissions for the project """ # Use different cache keys for prefixed vs unprefixed to avoid mixing cache_key = f"{project_key}:{'prefixed' if add_prefix else 'unprefixed'}" if cache_key not in self._project_permissions_cache: self._project_permissions_cache[cache_key] = get_project_permissions( jira_client=self.jira_client, jira_project=project_key, add_prefix=add_prefix, ) return self._project_permissions_cache[cache_key] def _is_epic(self, issue: Issue) -> bool: """Check if issue is an Epic.""" issuetype = best_effort_get_field_from_issue(issue, _FIELD_ISSUETYPE) if issuetype is None: return False return issuetype.name.lower() == "epic" def _is_parent_epic(self, parent: Any) -> bool: """Check if a parent reference is an Epic. 
The parent object from issue.fields.parent has a different structure than a full Issue, so we handle it separately. """ parent_issuetype = ( getattr(parent.fields, "issuetype", None) if hasattr(parent, "fields") else None ) if parent_issuetype is None: return False return parent_issuetype.name.lower() == "epic" def _yield_project_hierarchy_node( self, project_key: str, project_name: str | None, seen_hierarchy_node_ids: set[str], ) -> Generator[HierarchyNode, None, None]: """Yield a hierarchy node for a project if not already yielded.""" if project_key in seen_hierarchy_node_ids: return seen_hierarchy_node_ids.add(project_key) yield HierarchyNode( raw_node_id=project_key, raw_parent_id=None, # Parent is SOURCE display_name=project_name or project_key, link=f"{self.jira_base}/projects/{project_key}", node_type=HierarchyNodeType.PROJECT, ) def _yield_epic_hierarchy_node( self, issue: Issue, project_key: str, seen_hierarchy_node_ids: set[str], ) -> Generator[HierarchyNode, None, None]: """Yield a hierarchy node for an Epic issue.""" issue_key = issue.key if issue_key in seen_hierarchy_node_ids: return seen_hierarchy_node_ids.add(issue_key) yield HierarchyNode( raw_node_id=issue_key, raw_parent_id=project_key, display_name=f"{issue_key}: {issue.fields.summary}", link=build_jira_url(self.jira_base, issue_key), node_type=HierarchyNodeType.FOLDER, # don't have a separate epic node type ) def _yield_parent_hierarchy_node_if_epic( self, parent: Any, project_key: str, seen_hierarchy_node_ids: set[str], ) -> Generator[HierarchyNode, None, None]: """Yield hierarchy node for parent issue if it's an Epic we haven't seen.""" parent_key = parent.key if parent_key in seen_hierarchy_node_ids: return if not self._is_parent_epic(parent): # Not an epic, don't create hierarchy node for it return seen_hierarchy_node_ids.add(parent_key) # Get summary if available parent_summary = ( getattr(parent.fields, "summary", None) if hasattr(parent, "fields") else None ) display_name = ( 
f"{parent_key}: {parent_summary}" if parent_summary else parent_key ) yield HierarchyNode( raw_node_id=parent_key, raw_parent_id=project_key, display_name=display_name, link=build_jira_url(self.jira_base, parent_key), node_type=HierarchyNodeType.FOLDER, # don't have a separate epic node type ) def _get_parent_hierarchy_raw_node_id(self, issue: Issue, project_key: str) -> str: """Determine the parent hierarchy node ID for an issue. Returns: - Epic key if issue's parent is an Epic - Project key otherwise (for top-level issues or non-epic parents) """ parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT) if parent is None: # No parent, directly under project return project_key if self._is_parent_epic(parent): return parent.key # For non-epic parents (e.g., story with subtasks), # the document belongs directly under the project in the hierarchy return project_key def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self._jira_client = build_jira_client( credentials=credentials, jira_base=self.jira_base, scoped_token=self.scoped_token, ) return None def _get_jql_query( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> str: """Get the JQL query based on configuration and time range If a custom JQL query is provided, it will be used and combined with time constraints. Otherwise, the query will be constructed based on project key (if provided). 
""" start_date_str = datetime.fromtimestamp(start, tz=timezone.utc).strftime( "%Y-%m-%d %H:%M" ) end_date_str = datetime.fromtimestamp(end, tz=timezone.utc).strftime( "%Y-%m-%d %H:%M" ) time_jql = f"updated >= '{start_date_str}' AND updated <= '{end_date_str}'" # If custom JQL query is provided, use it and combine with time constraints if self.jql_query: return f"({self.jql_query}) AND {time_jql}" # Otherwise, use project key if provided if self.jira_project: base_jql = f"project = {self.quoted_jira_project}" return f"{base_jql} AND {time_jql}" return time_jql def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: JiraConnectorCheckpoint, ) -> CheckpointOutput[JiraConnectorCheckpoint]: jql = self._get_jql_query(start, end) try: return self._load_from_checkpoint( jql, checkpoint, include_permissions=False ) except Exception as e: if is_atlassian_date_error(e): jql = self._get_jql_query(start - ONE_HOUR, end) return self._load_from_checkpoint( jql, checkpoint, include_permissions=False ) raise e def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: JiraConnectorCheckpoint, ) -> CheckpointOutput[JiraConnectorCheckpoint]: """Load documents from checkpoint with permission information included.""" jql = self._get_jql_query(start, end) try: return self._load_from_checkpoint(jql, checkpoint, include_permissions=True) except Exception as e: if is_atlassian_date_error(e): jql = self._get_jql_query(start - ONE_HOUR, end) return self._load_from_checkpoint( jql, checkpoint, include_permissions=True ) raise e def _load_from_checkpoint( self, jql: str, checkpoint: JiraConnectorCheckpoint, include_permissions: bool ) -> CheckpointOutput[JiraConnectorCheckpoint]: # Get the current offset from checkpoint or start at 0 starting_offset = checkpoint.offset or 0 current_offset = starting_offset new_checkpoint = copy.deepcopy(checkpoint) # Convert checkpoint list to set for 
efficient lookups seen_hierarchy_node_ids = set(new_checkpoint.seen_hierarchy_node_ids) checkpoint_callback = make_checkpoint_callback(new_checkpoint) for issue in _perform_jql_search( jira_client=self.jira_client, jql=jql, start=current_offset, max_results=_JIRA_FULL_PAGE_SIZE, all_issue_ids=new_checkpoint.all_issue_ids, checkpoint_callback=checkpoint_callback, nextPageToken=new_checkpoint.cursor, ids_done=new_checkpoint.ids_done, ): issue_key = issue.key try: # Get project info for hierarchy project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT) project_key = project.key if project else None project_name = project.name if project else None # Yield hierarchy nodes BEFORE the document (parent-before-child) if project_key: # 1. Yield project hierarchy node (if not already yielded) yield from self._yield_project_hierarchy_node( project_key, project_name, seen_hierarchy_node_ids ) # 2. If parent is an Epic, yield hierarchy node for it parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT) if parent: yield from self._yield_parent_hierarchy_node_if_epic( parent, project_key, seen_hierarchy_node_ids ) # 3. 
If this issue IS an Epic, yield it as hierarchy node if self._is_epic(issue): yield from self._yield_epic_hierarchy_node( issue, project_key, seen_hierarchy_node_ids ) # Determine parent hierarchy node ID for the document parent_hierarchy_raw_node_id = ( self._get_parent_hierarchy_raw_node_id(issue, project_key) if project_key else None ) if document := process_jira_issue( jira_base_url=self.jira_base, issue=issue, comment_email_blacklist=self.comment_email_blacklist, labels_to_skip=self.labels_to_skip, parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id, ): # Add permission information to the document if requested if include_permissions: document.external_access = self._get_project_permissions( project_key, add_prefix=True, # Indexing path - prefix here ) yield document except Exception as e: yield ConnectorFailure( failed_document=DocumentFailure( document_id=issue_key, document_link=build_jira_url(self.jira_base, issue_key), ), failure_message=f"Failed to process Jira issue: {str(e)}", exception=e, ) current_offset += 1 # Update checkpoint with seen hierarchy nodes new_checkpoint.seen_hierarchy_node_ids = list(seen_hierarchy_node_ids) # Update checkpoint self.update_checkpoint_for_next_run( new_checkpoint, current_offset, starting_offset, _JIRA_FULL_PAGE_SIZE ) return new_checkpoint def update_checkpoint_for_next_run( self, checkpoint: JiraConnectorCheckpoint, current_offset: int, starting_offset: int, page_size: int, ) -> None: if _is_cloud_client(self.jira_client): # other updates done in the checkpoint callback checkpoint.has_more = ( len(checkpoint.all_issue_ids) > 0 or not checkpoint.ids_done ) else: checkpoint.offset = current_offset # if we didn't retrieve a full batch, we're done checkpoint.has_more = current_offset - starting_offset == page_size def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002 ) -> 
GenerateSlimDocumentOutput: one_day = timedelta(hours=24).total_seconds() start = start or 0 end = ( end or datetime.now().timestamp() + one_day ) # we add one day to account for any potential timezone issues jql = self._get_jql_query(start, end) checkpoint = self.build_dummy_checkpoint() checkpoint_callback = make_checkpoint_callback(checkpoint) prev_offset = 0 current_offset = 0 slim_doc_batch: list[SlimDocument | HierarchyNode] = [] # Track seen hierarchy nodes within this sync run seen_hierarchy_node_ids: set[str] = set() while checkpoint.has_more: for issue in _perform_jql_search( jira_client=self.jira_client, jql=jql, start=current_offset, max_results=JIRA_SLIM_PAGE_SIZE, all_issue_ids=checkpoint.all_issue_ids, checkpoint_callback=checkpoint_callback, nextPageToken=checkpoint.cursor, ids_done=checkpoint.ids_done, ): # Get project info project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT) project_key = project.key if project else None project_name = project.name if project else None if not project_key: continue # Yield hierarchy nodes BEFORE the slim document (parent-before-child) # 1. Yield project hierarchy node (if not already yielded) for node in self._yield_project_hierarchy_node( project_key, project_name, seen_hierarchy_node_ids ): slim_doc_batch.append(node) # 2. If parent is an Epic, yield hierarchy node for it parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT) if parent: for node in self._yield_parent_hierarchy_node_if_epic( parent, project_key, seen_hierarchy_node_ids ): slim_doc_batch.append(node) # 3. 
If this issue IS an Epic, yield it as hierarchy node if self._is_epic(issue): for node in self._yield_epic_hierarchy_node( issue, project_key, seen_hierarchy_node_ids ): slim_doc_batch.append(node) # Now add the slim document issue_key = best_effort_get_field_from_issue(issue, _FIELD_KEY) doc_id = build_jira_url(self.jira_base, issue_key) slim_doc_batch.append( SlimDocument( id=doc_id, # Permission sync path - don't prefix, upsert_document_external_perms handles it external_access=self._get_project_permissions( project_key, add_prefix=False ), parent_hierarchy_raw_node_id=( self._get_parent_hierarchy_raw_node_id(issue, project_key) if project_key else None ), ) ) current_offset += 1 if len(slim_doc_batch) >= JIRA_SLIM_PAGE_SIZE: yield slim_doc_batch slim_doc_batch = [] self.update_checkpoint_for_next_run( checkpoint, current_offset, prev_offset, JIRA_SLIM_PAGE_SIZE ) prev_offset = current_offset if slim_doc_batch: yield slim_doc_batch def validate_connector_settings(self) -> None: if self._jira_client is None: raise ConnectorMissingCredentialError("Jira") # If a custom JQL query is set, validate it's valid if self.jql_query: try: # Try to execute the JQL query with a small limit to validate its syntax # Use next(iter(...), None) to get just the first result without # forcing evaluation of all results next( iter( _perform_jql_search( jira_client=self.jira_client, jql=self.jql_query, start=0, max_results=1, all_issue_ids=[], ) ), None, ) except Exception as e: self._handle_jira_connector_settings_error(e) # If a specific project is set, validate it exists elif self.jira_project: try: self.jira_client.project(self.jira_project) except Exception as e: self._handle_jira_connector_settings_error(e) else: # If neither JQL nor project specified, validate we can access the Jira API try: # Try to list projects to validate access self.jira_client.projects() except Exception as e: self._handle_jira_connector_settings_error(e) def _handle_jira_connector_settings_error(self, e: 
Exception) -> None: """Helper method to handle Jira API errors consistently. Extracts error messages from the Jira API response for all status codes when possible, providing more user-friendly error messages. Args: e: The exception raised by the Jira API Raises: CredentialExpiredError: If the status code is 401 InsufficientPermissionsError: If the status code is 403 ConnectorValidationError: For other HTTP errors with extracted error messages """ status_code = getattr(e, "status_code", None) logger.error(f"Jira API error during validation: {e}") # Handle specific status codes with appropriate exceptions if status_code == 401: raise CredentialExpiredError( "Jira credential appears to be expired or invalid (HTTP 401)." ) elif status_code == 403: raise InsufficientPermissionsError( "Your Jira token does not have sufficient permissions for this configuration (HTTP 403)." ) elif status_code == 429: raise ConnectorValidationError( "Validation failed due to Jira rate-limits being exceeded. Please try again later." 
) # Try to extract original error message from the response error_message = getattr(e, "text", None) if error_message is None: raise UnexpectedValidationError( f"Unexpected Jira error during validation: {e}" ) raise ConnectorValidationError( f"Validation failed due to Jira error: {error_message}" ) @override def validate_checkpoint_json(self, checkpoint_json: str) -> JiraConnectorCheckpoint: return JiraConnectorCheckpoint.model_validate_json(checkpoint_json) @override def build_dummy_checkpoint(self) -> JiraConnectorCheckpoint: return JiraConnectorCheckpoint( has_more=True, ) def make_checkpoint_callback( checkpoint: JiraConnectorCheckpoint, ) -> Callable[[Iterator[list[str]], str | None], None]: def checkpoint_callback( issue_ids: Iterator[list[str]], pageToken: str | None ) -> None: for id_batch in issue_ids: checkpoint.all_issue_ids.append(id_batch) checkpoint.cursor = pageToken # pageToken starts out as None and is only None once we've fetched all the issue ids checkpoint.ids_done = pageToken is None return checkpoint_callback if __name__ == "__main__": import os from onyx.utils.variable_functionality import global_version from tests.daily.connectors.utils import load_all_from_connector # For connector permission testing, set EE to true. 
global_version.set_ee() connector = JiraConnector( jira_base_url=os.environ["JIRA_BASE_URL"], project_key=os.environ.get("JIRA_PROJECT_KEY"), comment_email_blacklist=[], ) connector.load_credentials( { "jira_user_email": os.environ["JIRA_USER_EMAIL"], "jira_api_token": os.environ["JIRA_API_TOKEN"], } ) start = 0 end = datetime.now().timestamp() for slim_doc in connector.retrieve_all_slim_docs_perm_sync( start=start, end=end, ): print(slim_doc) for doc in load_all_from_connector( connector=connector, start=start, end=end, ).documents: print(doc) ================================================ FILE: backend/onyx/connectors/jira/utils.py ================================================ """Module with custom fields processing functions""" import os from typing import Any from typing import List from urllib.parse import urlparse from jira import JIRA from jira.resources import CustomFieldOption from jira.resources import Issue from jira.resources import User from onyx.connectors.cross_connector_utils.miscellaneous_utils import scoped_url from onyx.connectors.models import BasicExpertInfo from onyx.utils.logger import setup_logger logger = setup_logger() PROJECT_URL_PAT = "projects" JIRA_SERVER_API_VERSION = os.environ.get("JIRA_SERVER_API_VERSION") or "2" JIRA_CLOUD_API_VERSION = os.environ.get("JIRA_CLOUD_API_VERSION") or "3" def best_effort_basic_expert_info(obj: Any) -> BasicExpertInfo | None: display_name = None email = None try: if hasattr(obj, "displayName"): display_name = obj.displayName else: display_name = obj.get("displayName") if hasattr(obj, "emailAddress"): email = obj.emailAddress else: email = obj.get("emailAddress") except Exception: return None if not email and not display_name: return None return BasicExpertInfo(display_name=display_name, email=email) def best_effort_get_field_from_issue(jira_issue: Issue, field: str) -> Any: if hasattr(jira_issue, field): return getattr(jira_issue, field) if hasattr(jira_issue, "fields") and 
hasattr(jira_issue.fields, field): return getattr(jira_issue.fields, field) try: return jira_issue.raw["fields"][field] except Exception: return None def extract_text_from_adf(adf: dict | None) -> str: """Extracts plain text from Atlassian Document Format: https://developer.atlassian.com/cloud/jira/platform/apis/document/structure/ WARNING: This function is incomplete and will e.g. skip lists! """ # TODO: complete this function texts = [] if adf is not None and "content" in adf: for block in adf["content"]: if "content" in block: for item in block["content"]: if item["type"] == "text": texts.append(item["text"]) return " ".join(texts) def build_jira_url(jira_base_url: str, issue_key: str) -> str: """ Get the url used to access an issue in the UI. """ return f"{jira_base_url}/browse/{issue_key}" def build_jira_client( credentials: dict[str, Any], jira_base: str, scoped_token: bool = False ) -> JIRA: jira_base = scoped_url(jira_base, "jira") if scoped_token else jira_base api_token = credentials["jira_api_token"] # if the user provides an email, we assume it's Jira Cloud if "jira_user_email" in credentials: email = credentials["jira_user_email"] return JIRA( basic_auth=(email, api_token), server=jira_base, options={"rest_api_version": JIRA_CLOUD_API_VERSION}, ) else: return JIRA( token_auth=api_token, server=jira_base, options={"rest_api_version": JIRA_SERVER_API_VERSION}, ) def extract_jira_project(url: str) -> tuple[str, str]: parsed_url = urlparse(url) jira_base = parsed_url.scheme + "://" + parsed_url.netloc # Split the path by '/' and find the position of 'projects' to get the project name split_path = parsed_url.path.split("/") if PROJECT_URL_PAT in split_path: project_pos = split_path.index(PROJECT_URL_PAT) if len(split_path) > project_pos + 1: jira_project = split_path[project_pos + 1] else: raise ValueError("No project name found in the URL") else: raise ValueError("'projects' not found in the URL") return jira_base, jira_project def get_comment_strs( issue: Issue,
comment_email_blacklist: tuple[str, ...] = () ) -> list[str]: comment_strs = [] for comment in issue.fields.comment.comments: try: if isinstance(comment.body, str): body_text = comment.body else: body_text = extract_text_from_adf(comment.raw["body"]) if ( hasattr(comment, "author") and hasattr(comment.author, "emailAddress") and comment.author.emailAddress in comment_email_blacklist ): continue # Skip adding comment if author's email is in blacklist comment_strs.append(body_text) except Exception as e: logger.error(f"Failed to process comment due to an error: {e}") continue return comment_strs def get_jira_project_key_from_issue(issue: Issue) -> str | None: if not hasattr(issue, "fields"): return None if not hasattr(issue.fields, "project"): return None if not hasattr(issue.fields.project, "key"): return None return issue.fields.project.key class CustomFieldExtractor: @staticmethod def _process_custom_field_value(value: Any) -> str: """ Process a custom field value to a string """ try: if isinstance(value, str): return value elif isinstance(value, CustomFieldOption): return value.value elif isinstance(value, User): return value.displayName elif isinstance(value, List): return " ".join( [CustomFieldExtractor._process_custom_field_value(v) for v in value] ) else: return str(value) except Exception as e: logger.error(f"Error processing custom field value {value}: {e}") return "" @staticmethod def get_issue_custom_fields( jira: Issue, custom_fields: dict, max_value_length: int = 250 ) -> dict: """ Process all custom fields of an issue to a dictionary of strings :param jira: jira_issue, bug or similar :param custom_fields: custom fields dictionary :param max_value_length: maximum length of a processed value; values of this length or longer are skipped (not truncated) """ issue_custom_fields = { custom_fields[key]: value for key, value in jira.fields.__dict__.items() if value and key in custom_fields.keys() } processed_fields = {} if issue_custom_fields: for key, value in
issue_custom_fields.items(): processed = CustomFieldExtractor._process_custom_field_value(value) # We need a max length parameter because some plugins often produce very long descriptions # that contain only technical information, so we skip long values if len(processed) < max_value_length: processed_fields[key] = processed return processed_fields @staticmethod def get_all_custom_fields(jira_client: JIRA) -> dict: """Get all custom fields from Jira""" fields = jira_client.fields() fields_dct = { field["id"]: field["name"] for field in fields if field["custom"] is True } return fields_dct class CommonFieldExtractor: @staticmethod def get_issue_common_fields(jira: Issue) -> dict: return { "Priority": jira.fields.priority.name if jira.fields.priority else None, "Reporter": ( jira.fields.reporter.displayName if jira.fields.reporter else None ), "Assignee": ( jira.fields.assignee.displayName if jira.fields.assignee else None ), "Status": jira.fields.status.name if jira.fields.status else None, "Resolution": ( jira.fields.resolution.name if jira.fields.resolution else None ), } ================================================ FILE: backend/onyx/connectors/linear/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/linear/connector.py ================================================ import os import re from datetime import datetime from datetime import timezone from typing import Any from typing import cast from urllib.parse import urlparse import requests from typing_extensions import override from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import LINEAR_CLIENT_ID from onyx.configs.app_configs import LINEAR_CLIENT_SECRET from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_oauth_callback_uri, ) from onyx.connectors.cross_connector_utils.miscellaneous_utils import
time_str_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import NormalizationResult from onyx.connectors.interfaces import OAuthConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import request_with_retries logger = setup_logger() _NUM_RETRIES = 5 _TIMEOUT = 60 _LINEAR_GRAPHQL_URL = "https://api.linear.app/graphql" def _make_query(request_body: dict[str, Any], api_key: str) -> requests.Response: headers = { "Authorization": api_key, "Content-Type": "application/json", } for i in range(_NUM_RETRIES): try: response = requests.post( _LINEAR_GRAPHQL_URL, headers=headers, json=request_body, timeout=_TIMEOUT, ) if not response.ok: raise RuntimeError( f"Error fetching issues from Linear: {response.text}" ) return response except Exception as e: if i == _NUM_RETRIES - 1: raise e logger.warning(f"A Linear GraphQL error occurred: {e}. Retrying...") raise RuntimeError( "Unexpected execution when querying Linear. This should never happen." 
    )


class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.batch_size = batch_size
        self.linear_api_key: str | None = None

    @classmethod
    def oauth_id(cls) -> DocumentSource:
        return DocumentSource.LINEAR

    @classmethod
    def oauth_authorization_url(
        cls,
        base_domain: str,
        state: str,
        additional_kwargs: dict[str, str],  # noqa: ARG003
    ) -> str:
        if not LINEAR_CLIENT_ID:
            raise ValueError("LINEAR_CLIENT_ID environment variable must be set")

        callback_uri = get_oauth_callback_uri(base_domain, DocumentSource.LINEAR.value)
        return (
            f"https://linear.app/oauth/authorize"
            f"?client_id={LINEAR_CLIENT_ID}"
            f"&redirect_uri={callback_uri}"
            f"&response_type=code"
            f"&scope=read"
            f"&state={state}"
            f"&prompt=consent"  # prompts user for access; allows choosing workspace
        )

    @classmethod
    def oauth_code_to_token(
        cls,
        base_domain: str,
        code: str,
        additional_kwargs: dict[str, str],  # noqa: ARG003
    ) -> dict[str, Any]:
        data = {
            "code": code,
            "redirect_uri": get_oauth_callback_uri(
                base_domain, DocumentSource.LINEAR.value
            ),
            "client_id": LINEAR_CLIENT_ID,
            "client_secret": LINEAR_CLIENT_SECRET,
            "grant_type": "authorization_code",
        }
        headers = {"Content-Type": "application/x-www-form-urlencoded"}

        response = request_with_retries(
            method="POST",
            url="https://api.linear.app/oauth/token",
            data=data,
            headers=headers,
            backoff=0,
            delay=0.1,
        )
        if not response.ok:
            raise RuntimeError(f"Failed to exchange code for token: {response.text}")

        token_data = response.json()

        return {
            "access_token": token_data["access_token"],
        }

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        if "linear_api_key" in credentials:
            self.linear_api_key = cast(str, credentials["linear_api_key"])
        elif "access_token" in credentials:
            self.linear_api_key = "Bearer " + cast(str, credentials["access_token"])
        else:
            # May need to handle case in the future if the OAuth flow expires
            raise ConnectorMissingCredentialError("Linear")

        return None

    def _process_issues(
        self, start_str: datetime | None = None, end_str: datetime | None = None
    ) -> GenerateDocumentsOutput:
        if self.linear_api_key is None:
            raise ConnectorMissingCredentialError("Linear")

        lte_filter = f'lte: "{end_str}"' if end_str else ""
        gte_filter = f'gte: "{start_str}"' if start_str else ""
        updatedAtFilter = f"""
        {lte_filter}
        {gte_filter}
        """

        query = (
            """
            query IterateIssueBatches($first: Int, $after: String) {
                issues(
                    orderBy: updatedAt,
                    first: $first,
                    after: $after,
                    filter: {
                        updatedAt: {
            """
            + updatedAtFilter
            + """
                        },
                    }
                ) {
                    edges {
                        node {
                            id
                            createdAt
                            updatedAt
                            archivedAt
                            number
                            title
                            priority
                            estimate
                            sortOrder
                            startedAt
                            completedAt
                            startedTriageAt
                            triagedAt
                            canceledAt
                            autoClosedAt
                            autoArchivedAt
                            dueDate
                            slaStartedAt
                            slaBreachesAt
                            trashed
                            snoozedUntilAt
                            team {
                                name
                            }
                            creator {
                                name
                                email
                            }
                            assignee {
                                name
                                email
                            }
                            previousIdentifiers
                            subIssueSortOrder
                            priorityLabel
                            identifier
                            url
                            branchName
                            state {
                                id
                                name
                            }
                            customerTicketCount
                            description
                            comments {
                                nodes {
                                    url
                                    body
                                }
                            }
                        }
                    }
                    pageInfo {
                        hasNextPage
                        endCursor
                    }
                }
            }
            """
        )

        has_more = True
        endCursor = None
        while has_more:
            graphql_query = {
                "query": query,
                "variables": {
                    "first": self.batch_size,
                    "after": endCursor,
                },
            }
            logger.debug(f"Requesting issues from Linear with query: {graphql_query}")

            response = _make_query(graphql_query, self.linear_api_key)
            response_json = response.json()
            logger.debug(f"Raw response from Linear: {response_json}")
            edges = response_json["data"]["issues"]["edges"]

            documents: list[Document | HierarchyNode] = []
            for edge in edges:
                node = edge["node"]
                # Create sections for description and comments
                sections = [
                    TextSection(
                        link=node["url"],
                        text=node["description"] or "",
                    )
                ]

                # Add comment sections
                for comment in node["comments"]["nodes"]:
                    sections.append(
                        TextSection(
                            link=node["url"],
                            text=comment["body"] or "",
                        )
                    )

                # Cast the sections list to the expected type
                typed_sections = cast(list[TextSection | ImageSection], sections)

                # Extract team name for hierarchy
                team_name = (node.get("team") or {}).get("name") or "Unknown Team"
                identifier = node.get("identifier", node["id"])

                documents.append(
                    Document(
                        id=node["id"],
                        sections=typed_sections,
                        source=DocumentSource.LINEAR,
                        semantic_identifier=f"[{node['identifier']}] {node['title']}",
                        title=node["title"],
                        doc_updated_at=time_str_to_utc(node["updatedAt"]),
                        doc_metadata={
                            "hierarchy": {
                                "source_path": [team_name],
                                "team_name": team_name,
                                "identifier": identifier,
                            }
                        },
                        metadata={
                            k: str(v)
                            for k, v in {
                                "team": (node.get("team") or {}).get("name"),
                                "creator": node.get("creator"),
                                "assignee": node.get("assignee"),
                                "state": (node.get("state") or {}).get("name"),
                                "priority": node.get("priority"),
                                "estimate": node.get("estimate"),
                                "started_at": node.get("startedAt"),
                                "completed_at": node.get("completedAt"),
                                "created_at": node.get("createdAt"),
                                "due_date": node.get("dueDate"),
                            }.items()
                            if v is not None
                        },
                    )
                )

            yield documents

            endCursor = response_json["data"]["issues"]["pageInfo"]["endCursor"]
            has_more = response_json["data"]["issues"]["pageInfo"]["hasNextPage"]

    def load_from_state(self) -> GenerateDocumentsOutput:
        yield from self._process_issues()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_time = datetime.fromtimestamp(start, tz=timezone.utc)
        end_time = datetime.fromtimestamp(end, tz=timezone.utc)

        yield from self._process_issues(start_str=start_time, end_str=end_time)

    @classmethod
    @override
    def normalize_url(cls, url: str) -> NormalizationResult:
        """Extract the Linear issue identifier from a URL.

        Linear URLs look like: https://linear.app/team/issue/IDENTIFIER/...
        Returns the identifier (e.g., "DAN-2327"), which can be used to match Document.link.
        """
        parsed = urlparse(url)
        netloc = parsed.netloc.lower()

        if "linear.app" not in netloc:
            return NormalizationResult(normalized_url=None, use_default=False)

        # Extract identifier from path: /team/issue/IDENTIFIER/...
        # Pattern: /{team}/issue/{identifier}/...
        path_parts = [p for p in parsed.path.split("/") if p]
        if len(path_parts) >= 3 and path_parts[1] == "issue":
            identifier = path_parts[2]
            # Validate identifier format (e.g., "DAN-2327")
            if re.match(r"^[A-Z]+-\d+$", identifier):
                return NormalizationResult(normalized_url=identifier, use_default=False)

        return NormalizationResult(normalized_url=None, use_default=False)


if __name__ == "__main__":
    connector = LinearConnector()
    connector.load_credentials({"linear_api_key": os.environ["LINEAR_API_KEY"]})
    document_batches = connector.load_from_state()
    print(next(document_batches))

================================================
FILE: backend/onyx/connectors/loopio/__init__.py
================================================

================================================
FILE: backend/onyx/connectors/loopio/connector.py
================================================
import json
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any

from oauthlib.oauth2 import BackendApplicationClient
from requests_oauthlib import OAuth2Session  # type: ignore

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
from onyx.utils.logger import setup_logger

LOOPIO_API_BASE = "https://api.loopio.com/"
LOOPIO_AUTH_URL = LOOPIO_API_BASE + "oauth2/access_token"
LOOPIO_DATA_URL = LOOPIO_API_BASE + "data/"

logger = setup_logger()


class LoopioConnector(LoadConnector, PollConnector):
    def __init__(
        self,
        loopio_stack_name: str | None = None,
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        self.batch_size = batch_size
        self.loopio_client_id: str | None = None
        self.loopio_client_token: str | None = None
        self.loopio_stack_name = loopio_stack_name

    def _fetch_data(
        self, resource: str, params: dict[str, str | int]
    ) -> Generator[dict[str, Any], None, None]:
        client = BackendApplicationClient(
            client_id=self.loopio_client_id, scope=["library:read"]
        )
        session = OAuth2Session(client=client)
        session.fetch_token(
            token_url=LOOPIO_AUTH_URL,
            client_id=self.loopio_client_id,
            client_secret=self.loopio_client_token,
        )
        page = 0
        stop_at_page = 1
        while (page := page + 1) <= stop_at_page:
            params["page"] = page
            response = session.request(
                "GET",
                LOOPIO_DATA_URL + resource,
                headers={"Accept": "application/json"},
                params=params,
            )
            if response.status_code == 400:
                logger.error(
                    f"Loopio API returned 400 for {resource} with params {params}",
                )
                logger.error(response.text)
            response.raise_for_status()
            response_data = json.loads(response.text)
            stop_at_page = response_data.get("totalPages", 1)
            yield response_data

    def _build_search_filter(
        self, stack_name: str | None, start: str | None, end: str | None
    ) -> dict[str, Any]:
        filter: dict[str, Any] = {}
        if start is not None and end is not None:
            filter["lastUpdatedDate"] = {"gte": start, "lt": end}

        if stack_name is not None:
            # Right now this is fetching the stacks every time, which is not ideal.
            # We should update this later to store the ID when we create the Connector
            for stack in self._fetch_data(resource="v2/stacks", params={}):
                for item in stack["items"]:
                    if item["name"] == stack_name:
                        filter["locations"] = [{"stackID": item["id"]}]
                        break
            if "locations" not in filter:
                raise ValueError(f"Stack {stack_name} not found in Loopio")
        return filter

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        self.loopio_subdomain = credentials["loopio_subdomain"]
        self.loopio_client_id = credentials["loopio_client_id"]
        self.loopio_client_token = credentials["loopio_client_token"]
        return None

    def _process_entries(
        self, start: str | None = None, end: str | None = None
    ) -> GenerateDocumentsOutput:
        if self.loopio_client_id is None or self.loopio_client_token is None:
            raise ConnectorMissingCredentialError("Loopio")

        filter = self._build_search_filter(
            stack_name=self.loopio_stack_name, start=start, end=end
        )
        params: dict[str, str | int] = {"pageSize": self.batch_size}
        params["filter"] = json.dumps(filter)

        doc_batch: list[Document | HierarchyNode] = []
        for library_entries in self._fetch_data(
            resource="v2/libraryEntries", params=params
        ):
            for entry in library_entries.get("items", []):
                link = f"https://{self.loopio_subdomain}.loopio.com/library?entry={entry['id']}"
                topic = "/".join(
                    part["name"] for part in entry["location"].values() if part
                )

                answer_text = entry.get("answer", {}).get("text", "")
                if not answer_text:
                    logger.warning(
                        f"The Library entry {entry['id']} has no answer text. Skipping."
                    )
                    continue

                try:
                    answer = parse_html_page_basic(answer_text)
                except Exception as e:
                    logger.error(f"Error parsing HTML for entry {entry['id']}: {e}")
                    continue

                questions = [
                    question.get("text").replace("\xa0", " ")
                    for question in entry["questions"]
                    if question.get("text")
                ]
                questions_string = strip_excessive_newlines_and_spaces(
                    "\n".join(questions)
                )
                content_text = f"{answer}\n\nRelated Questions: {questions_string}"
                content_text = strip_excessive_newlines_and_spaces(
                    content_text.replace("\xa0", " ")
                )

                last_updated = time_str_to_utc(entry["lastUpdatedDate"])
                last_reviewed = (
                    time_str_to_utc(entry["lastReviewedDate"])
                    if entry.get("lastReviewedDate")
                    else None
                )

                # For Onyx, we decay document scores over time; either last_updated or
                # last_reviewed is a good enough signal for the document's recency
                latest_time = (
                    max(last_reviewed, last_updated) if last_reviewed else last_updated
                )
                creator = entry.get("creator")
                last_updated_by = entry.get("lastUpdatedBy")
                last_reviewed_by = entry.get("lastReviewedBy")

                primary_owners: list[BasicExpertInfo] = [
                    BasicExpertInfo(display_name=owner.get("name"))
                    for owner in [creator, last_updated_by]
                    if owner is not None
                ]
                secondary_owners: list[BasicExpertInfo] = [
                    BasicExpertInfo(display_name=owner.get("name"))
                    for owner in [last_reviewed_by]
                    if owner is not None
                ]
                doc_batch.append(
                    Document(
                        id=str(entry["id"]),
                        sections=[TextSection(link=link, text=content_text)],
                        source=DocumentSource.LOOPIO,
                        semantic_identifier=questions[0],
                        doc_updated_at=latest_time,
                        primary_owners=primary_owners,
                        secondary_owners=secondary_owners,
                        metadata={
                            "topic": topic,
                            "questions": "\n".join(questions),
                            "creator": creator.get("name") if creator else "",
                        },
                    )
                )

                if len(doc_batch) >= self.batch_size:
                    yield doc_batch
                    doc_batch = []

        if len(doc_batch) > 0:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        return self._process_entries()

    def poll_source(
        self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch
    ) -> GenerateDocumentsOutput:
        start_time = datetime.fromtimestamp(start, tz=timezone.utc).isoformat(
            timespec="seconds"
        )
        end_time = datetime.fromtimestamp(end, tz=timezone.utc).isoformat(
            timespec="seconds"
        )

        return self._process_entries(start_time, end_time)


if __name__ == "__main__":
    import os

    connector = LoopioConnector(
        loopio_stack_name=os.environ.get("LOOPIO_STACK_NAME", None)
    )
    connector.load_credentials(
        {
            "loopio_client_id": os.environ["LOOPIO_CLIENT_ID"],
            "loopio_client_token": os.environ["LOOPIO_CLIENT_TOKEN"],
            "loopio_subdomain": os.environ["LOOPIO_SUBDOMAIN"],
        }
    )
    latest_docs = connector.load_from_state()
    print(next(latest_docs))

================================================
FILE: backend/onyx/connectors/mediawiki/__init__.py
================================================

================================================
FILE: backend/onyx/connectors/mediawiki/family.py
================================================
from __future__ import annotations

import builtins
import functools
import itertools
import tempfile
from typing import Any
from unittest import mock
from urllib.parse import urlparse
from urllib.parse import urlunparse

from pywikibot import family  # type: ignore[import-untyped]
from pywikibot import pagegenerators
from pywikibot.scripts import generate_family_file  # type: ignore[import-untyped]
from pywikibot.scripts.generate_user_files import pywikibot  # type: ignore[import-untyped]

from onyx.utils.logger import setup_logger

logger = setup_logger()

pywikibot.config.base_dir = tempfile.TemporaryDirectory().name


@mock.patch.object(
    builtins, "print", lambda *args: logger.info("\t".join(map(str, args)))
)
class FamilyFileGeneratorInMemory(generate_family_file.FamilyFileGenerator):
    """A subclass of FamilyFileGenerator that writes the family file to memory instead of to disk."""

    def __init__(
        self,
        url: str,
        name: str,
        dointerwiki: str | bool = True,
        verify: str | bool = True,
    ):
        """Initialize the FamilyFileGeneratorInMemory."""

        url_parse = urlparse(url, "https")
        if not url_parse.netloc and url_parse.path:
            url = urlunparse(
                (url_parse.scheme, url_parse.path, url_parse.netloc, *url_parse[3:])
            )
        else:
            url = urlunparse(url_parse)
        assert isinstance(url, str)

        if any(x not in generate_family_file.NAME_CHARACTERS for x in name):
            raise ValueError(
                f'ERROR: Name of family "{name}" must be ASCII letters and digits [a-zA-Z0-9]',
            )

        if isinstance(dointerwiki, bool):
            dointerwiki = "Y" if dointerwiki else "N"
        assert isinstance(dointerwiki, str)

        if isinstance(verify, bool):
            verify = "Y" if verify else "N"
        assert isinstance(verify, str)

        super().__init__(url, name, dointerwiki, verify)
        self.family_definition: type[family.Family] | None = None

    def get_params(self) -> bool:
        """Get the parameters for the family class definition.

        This override prevents the method from prompting the user for input (which would be impossible in this context).
        We do all the input validation in the constructor.
        """
        return True

    def writefile(self, verify: Any) -> None:  # noqa: ARG002
        """Write the family file.

        This overrides the method in the parent class to write the family definition to memory instead of to disk.

        Args:
            verify: unused argument necessary to match the signature of the method in the parent class.
        """
        code_hostname_pairs = {
            f"{k}": f"{urlparse(w.server).netloc}" for k, w in self.wikis.items()
        }

        code_path_pairs = {f"{k}": f"{w.scriptpath}" for k, w in self.wikis.items()}

        code_protocol_pairs = {
            f"{k}": f"{urlparse(w.server).scheme}" for k, w in self.wikis.items()
        }

        class Family(family.Family):  # noqa: D101
            """The family definition for the wiki."""

            name = "%(name)s"
            langs = code_hostname_pairs

            def scriptpath(self, code: str) -> str:
                return code_path_pairs[code]

            def protocol(self, code: str) -> str:
                return code_protocol_pairs[code]

        self.family_definition = Family


@functools.lru_cache(maxsize=None)
def generate_family_class(url: str, name: str) -> type[family.Family]:
    """Generate a family file for a given URL and name.

    Args:
        url: The URL of the wiki.
        name: The short name of the wiki (customizable by the user).

    Returns:
        The family definition.

    Raises:
        ValueError: If the family definition was not generated.
    """

    generator = FamilyFileGeneratorInMemory(url, name, "Y", "Y")
    generator.run()
    if generator.family_definition is None:
        raise ValueError("Family definition was not generated.")
    return generator.family_definition


def family_class_dispatch(url: str, name: str) -> type[family.Family]:
    """Find or generate a family class for a given URL and name.

    Args:
        url: The URL of the wiki.
        name: The short name of the wiki (customizable by the user).

    """
    if "wikipedia" in url:
        import pywikibot.families.wikipedia_family  # type: ignore[import-untyped]

        return pywikibot.families.wikipedia_family.Family
    # TODO: Support additional families pre-defined in `pywikibot.families.*_family.py` files
    return generate_family_class(url, name)


if __name__ == "__main__":
    url = "fallout.fandom.com/wiki/Fallout_Wiki"
    name = "falloutfandom"

    categories: list[str] = []
    pages = ["Fallout: New Vegas"]
    recursion_depth = 1
    family_type = generate_family_class(url, name)

    site = pywikibot.Site(fam=family_type(), code="en")
    categories = [
        pywikibot.Category(site, f"Category:{category.replace(' ', '_')}")
        for category in categories
    ]
    pages = [pywikibot.Page(site, page) for page in pages]
    all_pages = itertools.chain(
        pages,
        *[
            pagegenerators.CategorizedPageGenerator(category, recurse=recursion_depth)
            for category in categories
        ],
    )
    for page in all_pages:
        print(page.title())
        print(page.text[:1000])

================================================
FILE: backend/onyx/connectors/mediawiki/wiki.py
================================================
from __future__ import annotations

import datetime
import itertools
import tempfile
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import ClassVar

import pywikibot.time  # type: ignore[import-untyped]
from pywikibot import pagegenerators
from pywikibot import textlib

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.mediawiki.family import family_class_dispatch
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger

logger = setup_logger()

pywikibot.config.base_dir = tempfile.TemporaryDirectory().name


def pywikibot_timestamp_to_utc_datetime(
    timestamp: pywikibot.time.Timestamp,
) -> datetime.datetime:
    """Convert a pywikibot timestamp to a datetime object in UTC.

    Args:
        timestamp: The pywikibot timestamp to convert.

    Returns:
        A datetime object in UTC.
    """
    return datetime.datetime.astimezone(timestamp, tz=datetime.timezone.utc)


def get_doc_from_page(
    page: pywikibot.Page, site: pywikibot.Site | None, source_type: DocumentSource
) -> Document:
    """Generate Onyx Document from a MediaWiki page object.

    Args:
        page: Page from a MediaWiki site.
        site: MediaWiki site (used to parse the sections of the page using the site template, if available).
        source_type: Source of the document.

    Returns:
        Generated document.
    """
    page_text = page.text
    sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)

    sections = [
        TextSection(
            link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
            text=section.title + section.content,
        )
        for section in sections_extracted.sections
    ]
    sections.append(
        TextSection(
            link=page.full_url(),
            text=sections_extracted.header,
        )
    )

    return Document(
        source=source_type,
        title=page.title(),
        doc_updated_at=pywikibot_timestamp_to_utc_datetime(
            page.latest_revision.timestamp
        ),
        sections=cast(list[TextSection | ImageSection], sections),
        semantic_identifier=page.title(),
        metadata={"categories": [category.title() for category in page.categories()]},
        id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",
    )


class MediaWikiConnector(LoadConnector, PollConnector):
    """A connector for MediaWiki wikis.

    Args:
        hostname: The hostname of the wiki.
        categories: The categories to include in the index.
        pages: The pages to include in the index.
        recurse_depth: The depth to recurse into categories. -1 means unbounded recursion.
        language_code: The language code of the wiki.
        batch_size: The batch size for loading documents.

    Raises:
        ValueError: If `recurse_depth` is not an integer greater than or equal to -1.
    """

    document_source_type: ClassVar[DocumentSource] = DocumentSource.MEDIAWIKI
    """DocumentSource type for all documents generated by instances of this class. Can be overridden for connectors tailored for specific sites."""

    def __init__(
        self,
        hostname: str,
        categories: list[str],
        pages: list[str],
        recurse_depth: int,
        language_code: str = "en",
        batch_size: int = INDEX_BATCH_SIZE,
    ) -> None:
        if recurse_depth < -1:
            raise ValueError(
                f"recurse_depth must be an integer greater than or equal to -1. Got {recurse_depth} instead."
            )
        # -1 means infinite recursion, which `pywikibot` will only do with `True`
        self.recurse_depth: bool | int = True if recurse_depth == -1 else recurse_depth

        self.batch_size = batch_size

        # short names can only have ascii letters and digits
        self.family = family_class_dispatch(hostname, "WikipediaConnector")()
        self.site = pywikibot.Site(fam=self.family, code=language_code)
        self.categories = [
            pywikibot.Category(
                self.site,
                (
                    f"{category.replace(' ', '_')}"
                    if category.startswith("Category:")
                    else f"Category:{category.replace(' ', '_')}"
                ),
            )
            for category in categories
        ]

        self.pages = []
        for page in pages:
            if not page:
                continue
            self.pages.append(pywikibot.Page(self.site, page))

    def load_credentials(
        self,
        credentials: dict[str, Any],  # noqa: ARG002
    ) -> dict[str, Any] | None:
        """Load credentials for a MediaWiki site.

        Note:
            For most read-only operations, MediaWiki API credentials are not necessary.
            This method can be overridden in the event that a particular MediaWiki site requires credentials.
        """
        return None

    def _get_doc_batch(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,
    ) -> GenerateDocumentsOutput:
        """Request batches of pages from a MediaWiki site.

        Args:
            start: The beginning of the time period of pages to request.
            end: The end of the time period of pages to request.

        Yields:
            Lists of Documents containing each parsed page in a batch.
        """
        doc_batch: list[Document | HierarchyNode] = []

        # Pywikibot can handle batching for us, including only loading page contents when we finally request them.
        category_pages = [
            pagegenerators.PreloadingGenerator(
                pagegenerators.EdittimeFilterPageGenerator(
                    pagegenerators.CategorizedPageGenerator(
                        category, recurse=self.recurse_depth
                    ),
                    last_edit_start=(
                        datetime.datetime.fromtimestamp(start) if start else None
                    ),
                    last_edit_end=datetime.datetime.fromtimestamp(end) if end else None,
                ),
                groupsize=self.batch_size,
            )
            for category in self.categories
        ]

        # Since we can specify both individual pages and categories, we need to iterate over all of them.
        all_pages: Iterator[pywikibot.Page] = itertools.chain(
            self.pages, *category_pages
        )
        for page in all_pages:
            logger.info(
                f"MediaWikiConnector: title='{page.title()}' url={page.full_url()}"
            )
            doc_batch.append(
                get_doc_from_page(page, self.site, self.document_source_type)
            )
            if len(doc_batch) >= self.batch_size:
                yield doc_batch
                doc_batch = []
        if doc_batch:
            yield doc_batch

    def load_from_state(self) -> GenerateDocumentsOutput:
        """Load all documents from the source.

        Returns:
            A generator of documents.
        """
        return self.poll_source(None, None)

    def poll_source(
        self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None
    ) -> GenerateDocumentsOutput:
        """Poll the source for new documents.

        Args:
            start: The start of the time range to poll.
            end: The end of the time range to poll.

        Returns:
            A generator of documents.
        """
        return self._get_doc_batch(start, end)


if __name__ == "__main__":
    HOSTNAME = "fallout.fandom.com"
    test_connector = MediaWikiConnector(
        hostname=HOSTNAME,
        categories=["Fallout:_New_Vegas_factions"],
        pages=["Fallout: New Vegas"],
        recurse_depth=1,
    )

    all_docs = list(test_connector.load_from_state())
    print("All docs", all_docs)
    current = datetime.datetime.now().timestamp()
    thirty_days_ago = current - 30 * 24 * 60 * 60  # 30 days

    latest_docs = list(test_connector.poll_source(thirty_days_ago, current))

    print("Latest docs", latest_docs)

================================================
FILE: backend/onyx/connectors/microsoft_graph_env.py
================================================
"""Inverse mapping from user-facing Microsoft host URLs to the SDK's AzureEnvironment.

The office365 library's GraphClient requires an ``AzureEnvironment`` string
(e.g. ``"Global"``, ``"GCC High"``) to route requests to the correct national
cloud. Our connectors instead expose free-text ``authority_host`` and
``graph_api_host`` fields so the frontend doesn't need to know about SDK
internals.

This module bridges the gap: given the two host URLs the user configured, it
resolves the matching ``AzureEnvironment`` value (and the implied SharePoint
domain suffix) so callers can pass ``environment=…`` to ``GraphClient``.
"""

from office365.graph_client import AzureEnvironment  # type: ignore[import-untyped]
from pydantic import BaseModel

from onyx.connectors.exceptions import ConnectorValidationError


class MicrosoftGraphEnvironment(BaseModel):
    """One row of the inverse mapping."""

    environment: str
    graph_host: str
    authority_host: str
    sharepoint_domain_suffix: str


_ENVIRONMENTS: list[MicrosoftGraphEnvironment] = [
    MicrosoftGraphEnvironment(
        environment=AzureEnvironment.Global,
        graph_host="https://graph.microsoft.com",
        authority_host="https://login.microsoftonline.com",
        sharepoint_domain_suffix="sharepoint.com",
    ),
    MicrosoftGraphEnvironment(
        environment=AzureEnvironment.USGovernmentHigh,
        graph_host="https://graph.microsoft.us",
        authority_host="https://login.microsoftonline.us",
        sharepoint_domain_suffix="sharepoint.us",
    ),
    MicrosoftGraphEnvironment(
        environment=AzureEnvironment.USGovernmentDoD,
        graph_host="https://dod-graph.microsoft.us",
        authority_host="https://login.microsoftonline.us",
        sharepoint_domain_suffix="sharepoint.us",
    ),
    MicrosoftGraphEnvironment(
        environment=AzureEnvironment.China,
        graph_host="https://microsoftgraph.chinacloudapi.cn",
        authority_host="https://login.chinacloudapi.cn",
        sharepoint_domain_suffix="sharepoint.cn",
    ),
    MicrosoftGraphEnvironment(
        environment=AzureEnvironment.Germany,
        graph_host="https://graph.microsoft.de",
        authority_host="https://login.microsoftonline.de",
        sharepoint_domain_suffix="sharepoint.de",
    ),
]

_GRAPH_HOST_INDEX: dict[str, MicrosoftGraphEnvironment] = {
    env.graph_host: env for env in _ENVIRONMENTS
}


def resolve_microsoft_environment(
    graph_api_host: str,
    authority_host: str,
) -> MicrosoftGraphEnvironment:
    """Return the ``MicrosoftGraphEnvironment`` that matches the supplied hosts.

    Raises ``ConnectorValidationError`` when the combination is unknown or
    internally inconsistent (e.g. a GCC-High graph host paired with a
    commercial authority host).
    """
    graph_api_host = graph_api_host.rstrip("/")
    authority_host = authority_host.rstrip("/")

    env = _GRAPH_HOST_INDEX.get(graph_api_host)
    if env is None:
        known = ", ".join(sorted(_GRAPH_HOST_INDEX))
        raise ConnectorValidationError(
            f"Unsupported Microsoft Graph API host '{graph_api_host}'. Recognised hosts: {known}"
        )

    if env.authority_host != authority_host:
        raise ConnectorValidationError(
            f"Authority host '{authority_host}' is inconsistent with "
            f"graph API host '{graph_api_host}'. "
            f"Expected authority host '{env.authority_host}' "
            f"for the {env.environment} environment."
        )

    return env

================================================
FILE: backend/onyx/connectors/mock_connector/connector.py
================================================
from typing import Any

import httpx
from pydantic import BaseModel
from typing_extensions import override

from onyx.access.models import ExternalAccess
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.utils.logger import setup_logger

logger = setup_logger()

EXTERNAL_USER_EMAILS = {"test@example.com", "admin@example.com"}
EXTERNAL_USER_GROUP_IDS = {"mock-group-1", "mock-group-2"}


class MockConnectorCheckpoint(ConnectorCheckpoint):
    last_document_id: str | None = None


class SingleConnectorYield(BaseModel):
    documents: list[Document]
    checkpoint: MockConnectorCheckpoint
    failures: list[ConnectorFailure]
    unhandled_exception: str | None = None


class MockConnector(CheckpointedConnectorWithPermSync[MockConnectorCheckpoint]):
    def __init__(
        self,
        mock_server_host: str,
        mock_server_port: int,
    ) -> None:
        self.mock_server_host = mock_server_host
        self.mock_server_port = mock_server_port
        self.client = httpx.Client(timeout=30.0)

        self.connector_yields: list[SingleConnectorYield] | None = None
        self.current_yield_index: int = 0

    def load_credentials(
        self,
        credentials: dict[str, Any],  # noqa: ARG002
    ) -> dict[str, Any] | None:
        response = self.client.get(self._get_mock_server_url("get-documents"))
        response.raise_for_status()
        data = response.json()

        self.connector_yields = [
            SingleConnectorYield(**yield_data) for yield_data in data
        ]
        return None

    def _get_mock_server_url(self, endpoint: str) -> str:
        return f"http://{self.mock_server_host}:{self.mock_server_port}/{endpoint}"

    def _save_checkpoint(self, checkpoint: MockConnectorCheckpoint) -> None:
        response = self.client.post(
            self._get_mock_server_url("add-checkpoint"),
            json=checkpoint.model_dump(mode="json"),
        )
        response.raise_for_status()

    def _load_from_checkpoint_common(
        self,
        start: SecondsSinceUnixEpoch,  # noqa: ARG002
        end: SecondsSinceUnixEpoch,  # noqa: ARG002
        checkpoint: MockConnectorCheckpoint,
        include_permissions: bool = False,
    ) -> CheckpointOutput[MockConnectorCheckpoint]:
        if self.connector_yields is None:
            raise ValueError("No connector yields configured")

        # Save the checkpoint to the mock server
        self._save_checkpoint(checkpoint)

        yield_index = self.current_yield_index
        self.current_yield_index += 1
        current_yield = self.connector_yields[yield_index]

        # If the current yield has an unhandled exception, raise it
        # This is used to simulate an unhandled failure in the connector.
        if current_yield.unhandled_exception:
            raise RuntimeError(current_yield.unhandled_exception)

        # yield all documents
        for document in current_yield.documents:
            # If permissions are requested and not already set, add mock permissions
            if include_permissions and document.external_access is None:
                # Add mock permissions - make documents accessible to specific users/groups
                document.external_access = ExternalAccess(
                    external_user_emails=EXTERNAL_USER_EMAILS,
                    external_user_group_ids=EXTERNAL_USER_GROUP_IDS,
                    is_public=False,
                )
            yield document

        for failure in current_yield.failures:
            yield failure

        return current_yield.checkpoint

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: MockConnectorCheckpoint,
    ) -> CheckpointOutput[MockConnectorCheckpoint]:
        return self._load_from_checkpoint_common(
            start, end, checkpoint, include_permissions=False
        )

    @override
    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: MockConnectorCheckpoint,
    ) -> CheckpointOutput[MockConnectorCheckpoint]:
        return self._load_from_checkpoint_common(
            start, end, checkpoint, include_permissions=True
        )

    @override
    def build_dummy_checkpoint(self) -> MockConnectorCheckpoint:
        return MockConnectorCheckpoint(
            has_more=True,
            last_document_id=None,
        )

    def validate_checkpoint_json(self, checkpoint_json: str) -> MockConnectorCheckpoint:
        return MockConnectorCheckpoint.model_validate_json(checkpoint_json)

================================================
FILE: backend/onyx/connectors/models.py
================================================
import sys
from datetime import datetime
from enum import Enum
from typing import Any
from typing import cast

from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from pydantic import model_validator

from onyx.access.models import ExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.db.enums import HierarchyNodeType
from onyx.db.enums import IndexModelStatus
from onyx.utils.text_processing import make_url_compatible


class InputType(str, Enum):
    LOAD_STATE = "load_state"  # e.g. loading a current full state or a save state, such as from a file
    POLL = "poll"  # e.g. calling an API to get all documents in the last hour
    EVENT = "event"  # e.g. registered an endpoint as a listener, and processing connector events
    SLIM_RETRIEVAL = "slim_retrieval"


class ConnectorMissingCredentialError(PermissionError):
    def __init__(self, connector_name: str) -> None:
        connector_name = connector_name or "Unknown"
        super().__init__(
            f"{connector_name} connector missing credentials, was load_credentials called?"
        )


class Section(BaseModel):
    """Base section class with common attributes"""

    link: str | None = None
    text: str | None = None
    image_file_id: str | None = None


class TextSection(Section):
    """Section containing text content"""

    text: str

    def __sizeof__(self) -> int:
        return sys.getsizeof(self.text) + sys.getsizeof(self.link)


class ImageSection(Section):
    """Section containing an image reference"""

    image_file_id: str

    def __sizeof__(self) -> int:
        return sys.getsizeof(self.image_file_id) + sys.getsizeof(self.link)


class BasicExpertInfo(BaseModel):
    """Basic information about the owner of a document; any of the fields can be left as None.

    Display fallback goes as follows:
    - first_name + (optional middle_initial) + last_name
    - display_name
    - email
    - first_name
    """

    display_name: str | None = None
    first_name: str | None = None
    middle_initial: str | None = None
    last_name: str | None = None
    email: str | None = None

    def get_semantic_name(self) -> str:
        if self.first_name and self.last_name:
            name_parts = [self.first_name]
            if self.middle_initial:
                name_parts.append(self.middle_initial + ".")
            name_parts.append(self.last_name)
            return " ".join([name_part.capitalize() for name_part in name_parts])

        if self.display_name:
return self.display_name if self.email: return self.email if self.first_name: return self.first_name.capitalize() return "Unknown" def get_email(self) -> str | None: return self.email or None def __eq__(self, other: Any) -> bool: if not isinstance(other, BasicExpertInfo): return False return ( self.display_name, self.first_name, self.middle_initial, self.last_name, self.email, ) == ( other.display_name, other.first_name, other.middle_initial, other.last_name, other.email, ) def __hash__(self) -> int: return hash( ( self.display_name, self.first_name, self.middle_initial, self.last_name, self.email, ) ) def __sizeof__(self) -> int: size = sys.getsizeof(self.display_name) size += sys.getsizeof(self.first_name) size += sys.getsizeof(self.middle_initial) size += sys.getsizeof(self.last_name) size += sys.getsizeof(self.email) return size @classmethod def from_dict(cls, model_dict: dict[str, Any]) -> "BasicExpertInfo": first_name = cast(str, model_dict.get("FirstName")) last_name = cast(str, model_dict.get("LastName")) email = cast(str, model_dict.get("Email")) display_name = cast(str, model_dict.get("Name")) # Check if all fields are None if ( first_name is None and last_name is None and email is None and display_name is None ): raise ValueError("No identifying information found for user") return cls( first_name=first_name, last_name=last_name, email=email, display_name=display_name, ) class DocumentBase(BaseModel): """Used for Onyx ingestion api, the ID is inferred before use if not provided""" id: str | None = None sections: list[TextSection | ImageSection] source: DocumentSource | None = None semantic_identifier: str # displayed in the UI as the main identifier for the doc # TODO(andrei): Ideally we could improve this to where each value is just a # list of strings. 
    metadata: dict[str, str | list[str]]

    @field_validator("metadata", mode="before")
    @classmethod
    def _coerce_metadata_values(cls, v: dict[str, Any]) -> dict[str, str | list[str]]:
        return {
            key: [str(item) for item in val] if isinstance(val, list) else str(val)
            for key, val in v.items()
        }

    # UTC time
    doc_updated_at: datetime | None = None
    chunk_count: int | None = None

    # Owner, creator, etc.
    primary_owners: list[BasicExpertInfo] | None = None
    # Assignee, space owner, etc.
    secondary_owners: list[BasicExpertInfo] | None = None
    # title is used for search, whereas semantic_identifier is used for display in the UI.
    # They differ because a Slack message may display as #general, but "general" should not
    # be part of the search, at least not in the way a document title should be for, e.g.,
    # Confluence. The title defaults to semantic_identifier unless otherwise specified.
    title: str | None = None
    from_ingestion_api: bool = False
    # Anything else that may be useful that is specific to this particular connector type that other
    # parts of the code may need. If you're unsure, this can be left as None
    additional_info: Any = None

    # only filled in EE for connectors w/ permission sync enabled
    external_access: ExternalAccess | None = None
    doc_metadata: dict[str, Any] | None = None

    # Parent hierarchy node raw ID - the folder/space/page containing this document
    # If None, document's hierarchy position is unknown or connector doesn't support hierarchy
    parent_hierarchy_raw_node_id: str | None = None

    # Resolved database ID of the parent hierarchy node
    # Set during docfetching after hierarchy nodes are cached
    parent_hierarchy_node_id: int | None = None

    def get_title_for_document_index(
        self,
    ) -> str | None:
        # If title is explicitly empty, return a None here for embedding purposes
        if self.title == "":
            return None
        replace_chars = set(RETURN_SEPARATOR)
        title = self.semantic_identifier if self.title is None else self.title
        for char in replace_chars:
            title = title.replace(char, " ")
        title = title.strip()
        return title

    def get_metadata_str_attributes(self) -> list[str] | None:
        if not self.metadata:
            return None
        # Combined string for the key/value for easy filtering
        return convert_metadata_dict_to_list_of_strings(self.metadata)

    def __sizeof__(self) -> int:
        size = sys.getsizeof(self.id)
        for section in self.sections:
            size += sys.getsizeof(section)
        size += sys.getsizeof(self.source)
        size += sys.getsizeof(self.semantic_identifier)
        size += sys.getsizeof(self.doc_updated_at)
        size += sys.getsizeof(self.chunk_count)

        if self.primary_owners is not None:
            for primary_owner in self.primary_owners:
                size += sys.getsizeof(primary_owner)
        else:
            size += sys.getsizeof(self.primary_owners)

        if self.secondary_owners is not None:
            for secondary_owner in self.secondary_owners:
                size += sys.getsizeof(secondary_owner)
        else:
            size += sys.getsizeof(self.secondary_owners)

        size += sys.getsizeof(self.title)
        size += sys.getsizeof(self.from_ingestion_api)
        size += sys.getsizeof(self.additional_info)
        return size

    def get_text_content(self) -> str:
        return " ".join([section.text for section in self.sections if section.text])


def convert_metadata_dict_to_list_of_strings(
    metadata: dict[str, str | list[str]],
) -> list[str]:
    """Converts a metadata dict to a list of strings.

    Each string is a key-value pair separated by the INDEX_SEPARATOR. If a key
    points to a list of values, each value generates a unique pair.

    NOTE: Whatever formatting strategy is used here to generate a key-value
    string must be replicated when constructing query filters.

    Args:
        metadata: The metadata dict to convert where values can be either a
            string or a list of strings.

    Returns:
        A list of strings where each string is a key-value pair separated by the
            INDEX_SEPARATOR.
    """
    attributes: list[str] = []
    for k, v in metadata.items():
        if isinstance(v, list):
            attributes.extend([k + INDEX_SEPARATOR + vi for vi in v])
        else:
            attributes.append(k + INDEX_SEPARATOR + v)
    return attributes


def convert_metadata_list_of_strings_to_dict(
    metadata_list: list[str],
) -> dict[str, str | list[str]]:
    """
    Converts a list of strings to a metadata dict. The inverse of
    convert_metadata_dict_to_list_of_strings.

    Assumes the input strings are formatted as in the output of
    convert_metadata_dict_to_list_of_strings.

    The schema of the output metadata dict is suboptimal yet bound to legacy
    code. Ideally each key would just point to a list of strings, where each
    list might contain just one element.

    Args:
        metadata_list: The list of strings to convert to a metadata dict.

    Returns:
        A metadata dict where values can be either a string or a list of
            strings.
    """
    metadata: dict[str, str | list[str]] = {}
    for item in metadata_list:
        key, value = item.split(INDEX_SEPARATOR, 1)
        if key in metadata:
            # We have already seen this key, so it must point to a list.
            if isinstance(metadata[key], list):
                cast(list[str], metadata[key]).append(value)
            else:
                metadata[key] = [cast(str, metadata[key]), value]
        else:
            metadata[key] = value
    return metadata


class Document(DocumentBase):
    """Used for Onyx ingestion api, the ID is required"""

    id: str
    source: DocumentSource

    def to_short_descriptor(self) -> str:
        """Used when logging the identity of a document"""
        return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"

    @classmethod
    def from_base(cls, base: DocumentBase) -> "Document":
        return cls(
            id=(
                make_url_compatible(base.id)
                if base.id
                else "ingestion_api_" + make_url_compatible(base.semantic_identifier)
            ),
            sections=base.sections,
            source=base.source or DocumentSource.INGESTION_API,
            semantic_identifier=base.semantic_identifier,
            metadata=base.metadata,
            doc_updated_at=base.doc_updated_at,
            primary_owners=base.primary_owners,
            secondary_owners=base.secondary_owners,
            title=base.title,
            from_ingestion_api=base.from_ingestion_api,
        )

    def __sizeof__(self) -> int:
        size = super().__sizeof__()
        size += sys.getsizeof(self.id)
        size += sys.getsizeof(self.source)
        return size


class IndexingDocument(Document):
    """Document with processed sections for indexing"""

    processed_sections: list[Section] = []

    def get_total_char_length(self) -> int:
        """Get the total character length of the document including processed sections"""
        title_len = len(self.title or self.semantic_identifier)

        # Use processed_sections if available, otherwise fall back to original sections
        if self.processed_sections:
            section_len = sum(
                len(section.text) if section.text is not None else 0
                for section in self.processed_sections
            )
        else:
            section_len = sum(
                (
                    len(section.text)
                    if isinstance(section, TextSection) and section.text is not None
                    else 0
                )
                for section in self.sections
            )

        return title_len + section_len


class SlimDocument(BaseModel):
    id: str
    external_access: ExternalAccess | None = None
    parent_hierarchy_raw_node_id: str | None = None


class HierarchyNode(BaseModel):
    """
    Hierarchy node yielded by connectors.

    This is the Pydantic model used by connectors, distinct from the SQLAlchemy
    HierarchyNode model in db/models.py. The connector runner layer converts
    this to the DB model when persisting to Postgres.
    """

    # Raw identifier from the source system
    # e.g., "1h7uWUR2BYZjtMfEXFt43tauj-Gp36DTPtwnsNuA665I" for Google Drive
    raw_node_id: str

    # Raw ID of parent node, or None for SOURCE-level children (direct children of the source root)
    raw_parent_id: str | None = None

    # Human-readable name for display
    display_name: str

    # Link to view this node in the source system
    link: str | None = None

    # What kind of structural node this is (folder, space, page, etc.)
    node_type: HierarchyNodeType

    # If this hierarchy node represents a document (e.g., a Confluence page),
    # the DB model stores that doc's document_id. This gets set during docprocessing
    # after the document row is created; matching is done by raw_node_id matching document.id.
    # Connectors therefore may not specify this field, as the value would be unused:
    # document_id: str | None = None

    # External access information for the node
    external_access: ExternalAccess | None = None


class IndexAttemptMetadata(BaseModel):
    connector_id: int
    credential_id: int
    batch_num: int | None = None
    attempt_id: int | None = None
    request_id: str | None = None

    # Work in progress: will likely contain metadata about cc pair / index attempt
    structured_id: str | None = None


class ConnectorCheckpoint(BaseModel):
    # TODO: maybe move this to something disk-based to handle extremely large checkpoints?
    has_more: bool

    def __str__(self) -> str:
        """String representation of the checkpoint, with truncation for large checkpoint content."""
        MAX_CHECKPOINT_CONTENT_CHARS = 1000

        content_str = self.model_dump_json()
        if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
            content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
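        # Worked example: a checkpoint whose JSON dump is 1,500 chars long is returned
        # as its first 997 chars followed by "...", i.e. exactly
        # MAX_CHECKPOINT_CONTENT_CHARS (1,000) chars; dumps of 1,000 chars or fewer
        # are returned unchanged.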
        return content_str


class DocumentFailure(BaseModel):
    document_id: str
    document_link: str | None = None


class EntityFailure(BaseModel):
    entity_id: str
    missed_time_range: tuple[datetime, datetime] | None = None


class ConnectorFailure(BaseModel):
    failed_document: DocumentFailure | None = None
    failed_entity: EntityFailure | None = None
    failure_message: str
    exception: Exception | None = Field(default=None, exclude=True)

    model_config = {"arbitrary_types_allowed": True}

    @model_validator(mode="before")
    def check_failed_fields(cls, values: dict) -> dict:
        failed_document = values.get("failed_document")
        failed_entity = values.get("failed_entity")
        if (failed_document is None and failed_entity is None) or (
            failed_document is not None and failed_entity is not None
        ):
            raise ValueError(
                "Exactly one of 'failed_document' or 'failed_entity' must be specified."
            )
        return values


class ConnectorStopSignal(Exception):
    """A custom exception used to signal a stop in processing."""


class OnyxMetadata(BaseModel):
    # Be careful when overriding document_id; doing so may cause visual issues in the UI.
    # Kept here mostly for API-based use cases.
    document_id: str | None = None
    source_type: DocumentSource | None = None
    link: str | None = None
    file_display_name: str | None = None
    primary_owners: list[BasicExpertInfo] | None = None
    secondary_owners: list[BasicExpertInfo] | None = None
    doc_updated_at: datetime | None = None
    title: str | None = None


class DocExtractionContext(BaseModel):
    index_name: str
    cc_pair_id: int
    connector_id: int
    credential_id: int
    source: DocumentSource
    earliest_index_time: float
    from_beginning: bool
    is_primary: bool
    should_fetch_permissions_during_indexing: bool
    search_settings_status: IndexModelStatus
    doc_extraction_complete_batch_num: int | None


class DocIndexingContext(BaseModel):
    batches_done: int
    total_failures: int
    net_doc_change: int
    total_chunks: int

================================================
FILE: backend/onyx/connectors/notion/__init__.py
================================================


================================================
FILE: backend/onyx/connectors/notion/connector.py
================================================
import re
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from typing import Optional
from urllib.parse import parse_qs
from urllib.parse import urlparse

import requests
from pydantic import BaseModel
from retry import retry
from typing_extensions import override

from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
    rl_requests,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import NormalizationResult
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.enums import HierarchyNodeType
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger

logger = setup_logger()

_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30  # 30 seconds
_MAX_PAGES = 1000


# TODO: Tables need to be ingested, Pages need to have their metadata ingested


class NotionPage(BaseModel):
    """Represents a Notion Page object"""

    id: str
    created_time: str
    last_edited_time: str
    in_trash: bool
    properties: dict[str, Any]
    url: str

    database_name: str | None = None  # Only applicable to the database type page (wiki)
    parent: dict[str, Any] | None = (
        None  # Raw parent object from API for hierarchy tracking
    )


class NotionDataSource(BaseModel):
    """Represents a Notion Data Source within a database."""

    id: str
    name: str = ""


class NotionBlock(BaseModel):
    """Represents a Notion Block object"""

    id: str  # Used for the URL
    text: str
    # In a plaintext representation of the page, how this block should be joined
    # with the existing text up to this point, separated out from text for clarity
    prefix: str


class NotionSearchResponse(BaseModel):
    """Represents the response from the Notion Search API"""

    results: list[dict[str, Any]]
    next_cursor: Optional[str]
    has_more: bool = False


class BlockReadOutput(BaseModel):
    """Output from reading blocks of a page."""

    blocks: list[NotionBlock]
    child_page_ids: list[str]
    hierarchy_nodes: list[HierarchyNode]


class NotionConnector(LoadConnector, PollConnector):
    """Notion Page connector that reads all Notion pages
    this integration has been granted access to.

    Arguments:
        batch_size (int): Number of objects to index in a batch
    """

    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        recursive_index_enabled: bool = not NOTION_CONNECTOR_DISABLE_RECURSIVE_PAGE_LOOKUP,
        root_page_id: str | None = None,
    ) -> None:
        """Initialize with parameters."""
        self.batch_size = batch_size
        self.headers = {
            "Content-Type": "application/json",
            "Notion-Version": "2026-03-11",
        }
        self.indexed_pages: set[str] = set()
        self.root_page_id = root_page_id
        # If enabled, child pages are recursively indexed as they are found rather
        # than relying entirely on the `search` API. We have received reports that
        # the `search` API misses many pages; in those cases, this may need to be
        # turned on. It is not currently known why/when this is required.
        # NOTE: this also removes all benefits of polling, since we need to traverse
        # all pages regardless of whether they have been updated. If the Notion
        # workspace is very large, this may not be practical.
        self.recursive_index_enabled = recursive_index_enabled or self.root_page_id

        # Hierarchy tracking state
        self.seen_hierarchy_node_raw_ids: set[str] = set()
        self.workspace_id: str | None = None
        self.workspace_name: str | None = None
        # Maps child page IDs to their containing page ID (discovered in _read_blocks).
        # Used to resolve block_id parent types to the actual containing page.
        self._child_page_parent_map: dict[str, str] = {}
        # Maps data_source_id -> database_id (populated in _read_pages_from_database).
        # Used to resolve data_source_id parent types back to the database.
        self._data_source_to_database_map: dict[str, str] = {}

    @classmethod
    @override
    def normalize_url(cls, url: str) -> NormalizationResult:
        """Normalize a Notion URL to extract the page ID (UUID format)."""
        parsed = urlparse(url)
        netloc = parsed.netloc.lower()

        if not ("notion.so" in netloc or "notion.site" in netloc):
            return NormalizationResult(normalized_url=None, use_default=False)

        # Extract page ID from path (format: "Title-PageID")
        path_last = parsed.path.split("/")[-1]
        candidate = path_last.split("-")[-1] if "-" in path_last else path_last

        # Clean and format as UUID
        candidate = re.sub(r"[^0-9a-fA-F-]", "", candidate)
        cleaned = candidate.replace("-", "")

        if len(cleaned) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", cleaned):
            normalized_uuid = (
                f"{cleaned[0:8]}-{cleaned[8:12]}-{cleaned[12:16]}-{cleaned[16:20]}-{cleaned[20:]}"
            ).lower()
            return NormalizationResult(
                normalized_url=normalized_uuid, use_default=False
            )

        # Try query params
        params = parse_qs(parsed.query)
        for key in ("p", "page_id"):
            if key in params and params[key]:
                candidate = params[key][0].replace("-", "")
                if len(candidate) == 32 and re.fullmatch(r"[0-9a-fA-F]{32}", candidate):
                    normalized_uuid = (
                        f"{candidate[0:8]}-{candidate[8:12]}-{candidate[12:16]}-{candidate[16:20]}-{candidate[20:]}"
                    ).lower()
                    return NormalizationResult(
                        normalized_url=normalized_uuid, use_default=False
                    )

        return NormalizationResult(normalized_url=None, use_default=False)

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_child_blocks(
        self, block_id: str, cursor: str | None = None
    ) -> dict[str, Any] | None:
        """Fetch all child blocks via the Notion API."""
        logger.debug(f"Fetching children of block with ID '{block_id}'")
        block_url = f"https://api.notion.com/v1/blocks/{block_id}/children"
        query_params = None if not cursor else {"start_cursor": cursor}
        res = rl_requests.get(
            block_url,
            headers=self.headers,
            params=query_params,
            timeout=_NOTION_CALL_TIMEOUT,
        )
        try:
            res.raise_for_status()
        except Exception as e:
            if res.status_code == 404:
                # This happens when a page is not shared with the integration;
                # in this case, we just ignore the page.
                logger.error(
                    f"Unable to access block with ID '{block_id}'. "
                    f"This is likely due to the block not being shared "
                    f"with the Onyx integration. Exact exception:\n\n{e}"
                )
            else:
                logger.exception(
                    f"Error fetching blocks with status code {res.status_code}: {res.json()}"
                )

            # This can occasionally happen; the reason is unknown and it cannot be reproduced
            # on our internal Notion. We assume this will not be a critical loss of data.
            return None
        return res.json()

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_page(self, page_id: str) -> NotionPage:
        """Fetch a page from its ID via the Notion API, retry with database if page fetch fails."""
        logger.debug(f"Fetching page for ID '{page_id}'")
        page_url = f"https://api.notion.com/v1/pages/{page_id}"
        res = rl_requests.get(
            page_url,
            headers=self.headers,
            timeout=_NOTION_CALL_TIMEOUT,
        )
        try:
            res.raise_for_status()
        except Exception as e:
            logger.warning(
                f"Failed to fetch page, trying database for ID '{page_id}'. Exception: {e}"
            )
            # If the page fetch fails, try fetching it as a database. This happens when
            # the page has been turned into a wiki, which Notion treats as a database.
            return self._fetch_database_as_page(page_id)
        return NotionPage(**res.json())

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_database_as_page(self, database_id: str) -> NotionPage:
        """Attempt to fetch a database as a page.

        Note: As of API 2025-09-03, database objects no longer include
        `properties` (schema moved to individual data sources).
        """
        logger.debug(f"Fetching database for ID '{database_id}' as a page")
        database_url = f"https://api.notion.com/v1/databases/{database_id}"
        res = rl_requests.get(
            database_url,
            headers=self.headers,
            timeout=_NOTION_CALL_TIMEOUT,
        )
        try:
            res.raise_for_status()
        except Exception as e:
            logger.exception(f"Error fetching database as page - {res.json()}")
            raise e
        db_data = res.json()
        database_name = db_data.get("title")
        database_name = (
            database_name[0].get("text", {}).get("content") if database_name else None
        )

        db_data.setdefault("properties", {})

        return NotionPage(**db_data, database_name=database_name)

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_data_sources_for_database(
        self, database_id: str
    ) -> list[NotionDataSource]:
        """Fetch the list of data sources for a database."""
        logger.debug(f"Fetching data sources for database '{database_id}'")
        res = rl_requests.get(
            f"https://api.notion.com/v1/databases/{database_id}",
            headers=self.headers,
            timeout=_NOTION_CALL_TIMEOUT,
        )
        try:
            res.raise_for_status()
        except Exception as e:
            if res.status_code in (403, 404):
                logger.error(
                    f"Unable to access database with ID '{database_id}'. "
                    f"This is likely due to the database not being shared "
                    f"with the Onyx integration. Exact exception:\n{e}"
                )
                return []
            logger.exception(f"Error fetching database - {res.json()}")
            raise e

        db_data = res.json()
        data_sources = db_data.get("data_sources", [])
        return [
            NotionDataSource(id=ds["id"], name=ds.get("name", ""))
            for ds in data_sources
            if ds.get("id")
        ]

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_data_source(
        self, data_source_id: str, cursor: str | None = None
    ) -> dict[str, Any]:
        """Query a data source via POST /v1/data_sources/{id}/query."""
        logger.debug(f"Querying data source '{data_source_id}'")
        url = f"https://api.notion.com/v1/data_sources/{data_source_id}/query"
        body = None if not cursor else {"start_cursor": cursor}
        res = rl_requests.post(
            url,
            headers=self.headers,
            json=body,
            timeout=_NOTION_CALL_TIMEOUT,
        )
        try:
            res.raise_for_status()
        except Exception as e:
            if res.status_code in (403, 404):
                logger.error(
                    f"Unable to access data source with ID '{data_source_id}'. "
                    f"This is likely due to it not being shared "
                    f"with the Onyx integration. Exact exception:\n{e}"
                )
                return {"results": [], "next_cursor": None}
            logger.exception(f"Error querying data source - {res.json()}")
            raise e
        return res.json()

    @retry(tries=3, delay=1, backoff=2)
    def _fetch_workspace_info(self) -> tuple[str, str]:
        """Fetch workspace ID and name from the bot user endpoint."""
        res = rl_requests.get(
            "https://api.notion.com/v1/users/me",
            headers=self.headers,
            timeout=_NOTION_CALL_TIMEOUT,
        )
        res.raise_for_status()
        data = res.json()
        bot = data.get("bot", {})
        # workspace_id may be in the bot object; fall back to the user ID
        workspace_id = bot.get("workspace_id", data.get("id"))
        workspace_name = bot.get("workspace_name", "Notion Workspace")
        return workspace_id, workspace_name

    def _get_workspace_hierarchy_node(self) -> HierarchyNode | None:
        """Get the workspace hierarchy node, fetching workspace info if needed.

        Returns None if the workspace node has already been yielded.
        """
        if self.workspace_id is None:
            self.workspace_id, self.workspace_name = self._fetch_workspace_info()

        if self.workspace_id in self.seen_hierarchy_node_raw_ids:
            return None

        self.seen_hierarchy_node_raw_ids.add(self.workspace_id)
        return HierarchyNode(
            raw_node_id=self.workspace_id,
            raw_parent_id=None,  # Parent is SOURCE (auto-created by system)
            display_name=self.workspace_name or "Notion Workspace",
            link=f"https://notion.so/{self.workspace_id.replace('-', '')}",
            node_type=HierarchyNodeType.WORKSPACE,
        )

    def _get_parent_raw_id(
        self, parent: dict[str, Any] | None, page_id: str | None = None
    ) -> str | None:
        """Get the parent raw ID for hierarchy tracking.

        Returns workspace_id for top-level pages, or the direct parent ID for nested pages.

        Args:
            parent: The parent object from the Notion API
            page_id: The page's own ID, used to look up block_id parents in our cache
        """
        if not parent:
            return self.workspace_id  # Default to workspace if no parent info

        parent_type = parent.get("type")

        if parent_type == "workspace":
            return self.workspace_id
        elif parent_type == "block_id":
            # Inline page in a block - resolve to the containing page if we discovered it
            if page_id and page_id in self._child_page_parent_map:
                return self._child_page_parent_map[page_id]
            # Fallback to workspace if we don't know the parent
            return self.workspace_id
        elif parent_type == "data_source_id":
            ds_id = parent.get("data_source_id")
            if ds_id:
                return self._data_source_to_database_map.get(ds_id, self.workspace_id)
        elif parent_type in ["page_id", "database_id"]:
            return parent.get(parent_type)

        return self.workspace_id

    def _maybe_yield_hierarchy_node(
        self,
        raw_node_id: str,
        raw_parent_id: str | None,
        display_name: str,
        link: str | None,
        node_type: HierarchyNodeType,
    ) -> HierarchyNode | None:
        """Create and return a hierarchy node if not already yielded.

        Args:
            raw_node_id: The raw ID of the node
            raw_parent_id: The raw ID of the parent node
            display_name: Human-readable name
            link: URL to the node in Notion
            node_type: Type of hierarchy node

        Returns:
            HierarchyNode if new, None if already yielded
        """
        if raw_node_id in self.seen_hierarchy_node_raw_ids:
            return None

        self.seen_hierarchy_node_raw_ids.add(raw_node_id)
        return HierarchyNode(
            raw_node_id=raw_node_id,
            raw_parent_id=raw_parent_id,
            display_name=display_name,
            link=link,
            node_type=node_type,
        )

    @staticmethod
    def _properties_to_str(properties: dict[str, Any]) -> str:
        """Converts Notion properties to a string"""

        def _recurse_list_properties(inner_list: list[Any]) -> str | None:
            list_properties: list[str | None] = []
            for item in inner_list:
                if item and isinstance(item, dict):
                    list_properties.append(_recurse_properties(item))
                elif item and isinstance(item, list):
                    list_properties.append(_recurse_list_properties(item))
                else:
                    list_properties.append(str(item))
            return (
                ", ".join(
                    [
                        list_property
                        for list_property in list_properties
                        if list_property
                    ]
                )
                or None
            )

        def _recurse_properties(inner_dict: dict[str, Any]) -> str | None:
            sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
            while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
                type_name = sub_inner_dict["type"]
                sub_inner_dict = sub_inner_dict[type_name]

                # If the innermost layer is None, the value is not set
                if not sub_inner_dict:
                    return None

            # TODO there may be more types to handle here
            if isinstance(sub_inner_dict, list):
                return _recurse_list_properties(sub_inner_dict)
            elif isinstance(sub_inner_dict, str):
                # For some objects the innermost value is just a string; it is not clear what causes this
                return sub_inner_dict
            elif isinstance(sub_inner_dict, dict):
                if "name" in sub_inner_dict:
                    return sub_inner_dict["name"]
                if "content" in sub_inner_dict:
                    return sub_inner_dict["content"]
                start = sub_inner_dict.get("start")
                end = sub_inner_dict.get("end")
                if start is not None:
                    if end is not None:
                        return f"{start} - {end}"
                    return start
                elif end is not None:
                    return f"Until {end}"

                if "id" in sub_inner_dict:
                    # This is not useful to index: it's a reference to another Notion object,
                    # and this ID value in plaintext is useless outside of the Notion context
                    logger.debug("Skipping Notion object id field property")
                    return None

            logger.debug(f"Unreadable property from innermost prop: {sub_inner_dict}")
            return None

        result = ""
        for prop_name, prop in properties.items():
            if not prop or not isinstance(prop, dict):
                continue

            try:
                inner_value = _recurse_properties(prop)
            except Exception as e:
                # This is not a critical failure: these properties are not the actual contents
                # of the page; they are closer to metadata
                logger.warning(f"Error recursing properties for {prop_name}: {e}")
                continue
            # Not a perfect way to format Notion database tables, but there's no perfect
            # representation since this must be represented as plaintext
            if inner_value:
                result += f"{prop_name}: {inner_value}\t"

        return result

    def _read_pages_from_database(
        self,
        database_id: str,
        database_parent_raw_id: str | None = None,
        database_name: str | None = None,
    ) -> BlockReadOutput:
        """Returns blocks, page IDs, and hierarchy nodes from a database.

        Args:
            database_id: The ID of the database
            database_parent_raw_id: The raw ID of the database's parent (containing page or workspace)
            database_name: The name of the database (from child_database block title)
        """
        result_blocks: list[NotionBlock] = []
        result_pages: list[str] = []
        hierarchy_nodes: list[HierarchyNode] = []

        # Create hierarchy node for this database if not already yielded.
        # Notion URLs omit dashes from UUIDs: https://notion.so/17ab3186873d418fb899c3f6a43f68de
        db_node = self._maybe_yield_hierarchy_node(
            raw_node_id=database_id,
            raw_parent_id=database_parent_raw_id or self.workspace_id,
            display_name=database_name or f"Database {database_id}",
            link=f"https://notion.so/{database_id.replace('-', '')}",
            node_type=HierarchyNodeType.DATABASE,
        )
        if db_node:
            hierarchy_nodes.append(db_node)

        # Discover all data sources under this database, then query each one.
        # Even legacy single-source databases have one entry in the array.
        data_sources = self._fetch_data_sources_for_database(database_id)
        if not data_sources:
            logger.warning(
                f"Database '{database_id}' returned zero data sources — "
                f"no pages will be indexed from this database."
            )

        for ds in data_sources:
            self._data_source_to_database_map[ds.id] = database_id
            cursor = None
            while True:
                data = self._fetch_data_source(ds.id, cursor)

                for result in data["results"]:
                    obj_id = result["id"]
                    obj_type = result["object"]
                    text = self._properties_to_str(result.get("properties", {}))

                    if text:
                        result_blocks.append(
                            NotionBlock(id=obj_id, text=text, prefix="\n")
                        )

                    if not self.recursive_index_enabled:
                        continue

                    if obj_type == "page":
                        logger.debug(
                            f"Found page with ID '{obj_id}' in database '{database_id}'"
                        )
                        result_pages.append(result["id"])
                    elif obj_type == "database":
                        logger.debug(
                            f"Found database with ID '{obj_id}' in database '{database_id}'"
                        )
                        nested_db_title = result.get("title", [])
                        nested_db_name = None
                        if nested_db_title and len(nested_db_title) > 0:
                            nested_db_name = (
                                nested_db_title[0].get("text", {}).get("content")
                            )
                        nested_output = self._read_pages_from_database(
                            obj_id,
                            database_parent_raw_id=database_id,
                            database_name=nested_db_name,
                        )
                        result_pages.extend(nested_output.child_page_ids)
                        hierarchy_nodes.extend(nested_output.hierarchy_nodes)

                if data["next_cursor"] is None:
                    break

                cursor = data["next_cursor"]

        return BlockReadOutput(
            blocks=result_blocks,
            child_page_ids=result_pages,
            hierarchy_nodes=hierarchy_nodes,
        )

    def _read_blocks(
        self, base_block_id: str, containing_page_id: str | None = None
    ) -> BlockReadOutput:
        """Reads all child blocks for the specified block.

        Args:
            base_block_id: The block ID to read children from
            containing_page_id: The ID of the page that contains this block tree.
                Used to correctly map child pages/databases to their parent page
                rather than intermediate block IDs.
        """
        # If no containing_page_id provided, assume base_block_id is the page itself
        page_id = containing_page_id or base_block_id
        result_blocks: list[NotionBlock] = []
        child_pages: list[str] = []
        hierarchy_nodes: list[HierarchyNode] = []
        cursor = None
        while True:
            data = self._fetch_child_blocks(base_block_id, cursor)

            # this happens when a block is not shared with the integration
            if data is None:
                return BlockReadOutput(
                    blocks=result_blocks,
                    child_page_ids=child_pages,
                    hierarchy_nodes=hierarchy_nodes,
                )

            for result in data["results"]:
                logger.debug(
                    f"Found child block for block with ID '{base_block_id}': {result}"
                )
                result_block_id = result["id"]
                result_type = result["type"]
                result_obj = result[result_type]

                if result_type == "ai_block":
                    logger.warning(
                        f"Skipping 'ai_block' ('{result_block_id}') for base block '{base_block_id}': "
                        f"Notion API does not currently support reading AI blocks (as of 24/02/09) "
                        f"(discussion: https://github.com/onyx-dot-app/onyx/issues/1053)"
                    )
                    continue

                if result_type == "unsupported":
                    logger.warning(
                        f"Skipping unsupported block type '{result_type}' "
                        f"('{result_block_id}') for base block '{base_block_id}': "
                        f"(discussion: https://github.com/onyx-dot-app/onyx/issues/1230)"
                    )
                    continue

                if result_type == "external_object_instance_page":
                    logger.warning(
                        f"Skipping 'external_object_instance_page' ('{result_block_id}') for base block '{base_block_id}': "
                        f"Notion API does not currently support reading external blocks (as of 24/07/03) "
                        f"(discussion: https://github.com/onyx-dot-app/onyx/issues/1761)"
                    )
                    continue

                cur_result_text_arr = []
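                # Illustrative example (simplified Notion objects): for a paragraph block whose
                # result_obj is
                #   {"rich_text": [{"type": "text", "text": {"content": "Hello "}},
                #                  {"type": "mention", "plain_text": "@Ada"}]}
                # only "Hello " is collected below; the mention entry has no "text" key and
                # is skipped. The collected fragments are later joined with "\n".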
if "rich_text" in result_obj: for rich_text in result_obj["rich_text"]: # skip if doesn't have text object if "text" in rich_text: text = rich_text["text"]["content"] cur_result_text_arr.append(text) if result["has_children"]: if result_type == "child_page": # Child pages will not be included at this top level, it will be a separate document. # Track parent page so we can resolve block_id parents later. # Use page_id (not base_block_id) to ensure we map to the containing page, # not an intermediate block like a toggle or callout. child_pages.append(result_block_id) self._child_page_parent_map[result_block_id] = page_id else: logger.debug(f"Entering sub-block: {result_block_id}") sub_output = self._read_blocks(result_block_id, page_id) logger.debug(f"Finished sub-block: {result_block_id}") result_blocks.extend(sub_output.blocks) child_pages.extend(sub_output.child_page_ids) hierarchy_nodes.extend(sub_output.hierarchy_nodes) if result_type == "child_database": # Extract database name from the child_database block db_title = result_obj.get("title", "") db_output = self._read_pages_from_database( result_block_id, database_parent_raw_id=page_id, # Parent is the containing page database_name=db_title or None, ) # A database on a page often looks like a table, we need to include it for the contents # of the page but the children (cells) should be processed as other Documents result_blocks.extend(db_output.blocks) hierarchy_nodes.extend(db_output.hierarchy_nodes) if self.recursive_index_enabled: child_pages.extend(db_output.child_page_ids) if cur_result_text_arr: new_block = NotionBlock( id=result_block_id, text="\n".join(cur_result_text_arr), prefix="\n", ) result_blocks.append(new_block) if data["next_cursor"] is None: break cursor = data["next_cursor"] return BlockReadOutput( blocks=result_blocks, child_page_ids=child_pages, hierarchy_nodes=hierarchy_nodes, ) def _read_page_title(self, page: NotionPage) -> str | None: """Extracts the title from a Notion page""" 
page_title = None if hasattr(page, "database_name") and page.database_name: return page.database_name for _, prop in page.properties.items(): if prop["type"] == "title" and len(prop["title"]) > 0: page_title = " ".join([t["plain_text"] for t in prop["title"]]).strip() break return page_title def _read_pages( self, pages: list[NotionPage], ) -> Generator[Document | HierarchyNode, None, None]: """Reads pages for rich text content and generates Documents and HierarchyNodes Note that a page which is turned into a "wiki" becomes a database but both top level pages and top level databases do not seem to have any properties associated with them. Pages that are part of a database can have properties which are like the values of the row in the "database" table in which they exist This is not clearly outlined in the Notion API docs but it is observable empirically. https://developers.notion.com/docs/working-with-page-content """ all_child_page_ids: list[str] = [] for page in pages: if page.id in self.indexed_pages: logger.debug(f"Already indexed page with ID '{page.id}'. 
Skipping.") continue logger.info(f"Reading page with ID '{page.id}', with url {page.url}") block_output = self._read_blocks(page.id) all_child_page_ids.extend(block_output.child_page_ids) # okay to mark here since there's no way for this to not succeed # without a critical failure self.indexed_pages.add(page.id) raw_page_title = self._read_page_title(page) page_title = raw_page_title or f"Untitled Page with ID {page.id}" parent_raw_id = self._get_parent_raw_id(page.parent, page_id=page.id) # If this page has children (pages or databases), yield it as a hierarchy node FIRST # This ensures parent nodes are created before child documents reference them if block_output.child_page_ids or block_output.hierarchy_nodes: hierarchy_node = self._maybe_yield_hierarchy_node( raw_node_id=page.id, raw_parent_id=parent_raw_id, display_name=page_title, link=page.url, node_type=HierarchyNodeType.PAGE, ) if hierarchy_node: yield hierarchy_node # Yield database hierarchy nodes discovered in this page's blocks for db_node in block_output.hierarchy_nodes: yield db_node if not block_output.blocks: if not raw_page_title: logger.warning( f"No blocks OR title found for page with ID '{page.id}'. Skipping." 
) continue logger.debug(f"No blocks found for page with ID '{page.id}'") """ Something like: TITLE PROP1: PROP1_VALUE PROP2: PROP2_VALUE """ text = page_title if page.properties: text += "\n\n" + "\n".join( [f"{key}: {value}" for key, value in page.properties.items()] ) sections = [ TextSection( link=f"{page.url}", text=text, ) ] else: sections = [ TextSection( link=f"{page.url}#{block.id.replace('-', '')}", text=block.prefix + block.text, ) for block in block_output.blocks ] yield ( Document( id=page.id, sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.NOTION, semantic_identifier=page_title, doc_updated_at=datetime.fromisoformat( page.last_edited_time ).astimezone(timezone.utc), metadata={}, parent_hierarchy_raw_node_id=parent_raw_id, ) ) self.indexed_pages.add(page.id) if self.recursive_index_enabled and all_child_page_ids: # NOTE: checking if page_id is in self.indexed_pages to prevent extra # calls to `_fetch_page` for pages we've already indexed for child_page_batch_ids in batch_generator( all_child_page_ids, batch_size=INDEX_BATCH_SIZE ): child_page_batch = [ self._fetch_page(page_id) for page_id in child_page_batch_ids if page_id not in self.indexed_pages ] yield from self._read_pages(child_page_batch) @retry(tries=3, delay=1, backoff=2) def _search_notion(self, query_dict: dict[str, Any]) -> NotionSearchResponse: """Searches the Notion workspace via the search endpoint (pages or data sources, depending on the filter in query_dict). Includes a small number of retries to handle miscellaneous, flaky failures.""" logger.debug(f"Searching for pages in Notion with query_dict: {query_dict}") res = rl_requests.post( "https://api.notion.com/v1/search", headers=self.headers, json=query_dict, timeout=_NOTION_CALL_TIMEOUT, ) res.raise_for_status() return NotionSearchResponse(**res.json()) # The | Document is needed for mypy type checking def _yield_database_hierarchy_nodes( self, ) -> Generator[HierarchyNode | Document, None, None]: """Search for all data sources and yield hierarchy nodes for their parent databases.
This must be called BEFORE page indexing so that database hierarchy nodes exist when pages inside databases reference them as parents. With the new API, search returns data source objects instead of databases. Multiple data sources can share the same parent database, so we use database_id as the hierarchy node key and deduplicate via _maybe_yield_hierarchy_node. """ query_dict: dict[str, Any] = { "filter": {"property": "object", "value": "data_source"}, "page_size": _NOTION_PAGE_SIZE, } pages_seen = 0 while pages_seen < _MAX_PAGES: db_res = self._search_notion(query_dict) for ds in db_res.results: # Extract the parent database_id from the data source's parent ds_parent = ds.get("parent", {}) db_id = ds_parent.get("database_id") if not db_id: continue # Populate the mapping so _get_parent_raw_id can resolve later ds_id = ds.get("id") if not ds_id: continue self._data_source_to_database_map[ds_id] = db_id # Fetch the database to get its actual name and parent try: db_page = self._fetch_database_as_page(db_id) db_name = db_page.database_name or f"Database {db_id}" parent_raw_id = self._get_parent_raw_id(db_page.parent) db_url = ( db_page.url or f"https://notion.so/{db_id.replace('-', '')}" ) except requests.exceptions.RequestException as e: logger.warning( f"Could not fetch database '{db_id}', " f"defaulting to workspace root. Error: {e}" ) db_name = f"Database {db_id}" parent_raw_id = self.workspace_id db_url = f"https://notion.so/{db_id.replace('-', '')}" # _maybe_yield_hierarchy_node deduplicates by raw_node_id, # so multiple data sources under one database produce one node. 
node = self._maybe_yield_hierarchy_node( raw_node_id=db_id, raw_parent_id=parent_raw_id or self.workspace_id, display_name=db_name, link=db_url, node_type=HierarchyNodeType.DATABASE, ) if node: yield node if not db_res.has_more: break query_dict["start_cursor"] = db_res.next_cursor pages_seen += 1 def _filter_pages_by_time( self, pages: list[dict[str, Any]], start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, filter_field: str = "last_edited_time", ) -> list[NotionPage]: """A helper function to filter out pages outside of a time range. This functionality doesn't yet exist in the Notion Search API, but when it does, this approach can be deprecated. Arguments: pages (list[dict]) - Pages to filter start (float) - start epoch time to filter from end (float) - end epoch time to filter to filter_field (str) - the attribute on the page to apply the filter """ filtered_pages: list[NotionPage] = [] for page in pages: # Parse ISO 8601 timestamp and convert to UTC epoch time timestamp = page[filter_field].replace(".000Z", "+00:00") compare_time = datetime.fromisoformat(timestamp).timestamp() if compare_time > start and compare_time <= end: filtered_pages += [NotionPage(**page)] return filtered_pages def _recursive_load(self) -> GenerateDocumentsOutput: if self.root_page_id is None or not self.recursive_index_enabled: raise RuntimeError( "Recursive page lookup is not enabled, but we are trying to recursively load pages. This should never happen." 
) # Yield workspace hierarchy node FIRST before any pages workspace_node = self._get_workspace_hierarchy_node() if workspace_node: yield [workspace_node] logger.info( f"Recursively loading pages from Notion based on root page with ID: {self.root_page_id}" ) pages = [self._fetch_page(page_id=self.root_page_id)] yield from batch_generator(self._read_pages(pages), self.batch_size) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: """Applies integration token to headers""" self.headers["Authorization"] = ( f"Bearer {credentials['notion_integration_token']}" ) return None def load_from_state(self) -> GenerateDocumentsOutput: """Loads all page data from a Notion workspace. Yields: batches of Documents and HierarchyNodes. """ # TODO: remove once Notion search issue is resolved if self.recursive_index_enabled and self.root_page_id: yield from self._recursive_load() return # Yield workspace hierarchy node FIRST before any pages workspace_node = self._get_workspace_hierarchy_node() if workspace_node: yield [workspace_node] # Yield database hierarchy nodes BEFORE pages so parent references resolve yield from batch_generator( self._yield_database_hierarchy_nodes(), self.batch_size ) query_dict: dict[str, Any] = { "filter": {"property": "object", "value": "page"}, "page_size": _NOTION_PAGE_SIZE, } while True: db_res = self._search_notion(query_dict) pages = [NotionPage(**page) for page in db_res.results] yield from batch_generator(self._read_pages(pages), self.batch_size) if db_res.has_more: query_dict["start_cursor"] = db_res.next_cursor else: break def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: """Uses the Notion search API to fetch updated pages within a time period. Unfortunately the search API doesn't yet support filtering by time, so until it does, we page through results until we reach ones that are older than our search criteria.
""" # TODO: remove once Notion search issue is resolved if self.recursive_index_enabled and self.root_page_id: yield from self._recursive_load() return # Yield workspace hierarchy node FIRST before any pages workspace_node = self._get_workspace_hierarchy_node() if workspace_node: yield [workspace_node] # Yield database hierarchy nodes BEFORE pages so parent references resolve. # We yield all databases without time filtering because a page's parent # database might not have been edited even if the page was. yield from batch_generator( self._yield_database_hierarchy_nodes(), self.batch_size ) query_dict: dict[str, Any] = { "page_size": _NOTION_PAGE_SIZE, "sort": {"timestamp": "last_edited_time", "direction": "descending"}, "filter": {"property": "object", "value": "page"}, } while True: db_res = self._search_notion(query_dict) pages = self._filter_pages_by_time( db_res.results, start, end, filter_field="last_edited_time" ) if len(pages) > 0: yield from batch_generator(self._read_pages(pages), self.batch_size) if db_res.has_more: query_dict["start_cursor"] = db_res.next_cursor else: break else: break def validate_connector_settings(self) -> None: if not self.headers.get("Authorization"): raise ConnectorMissingCredentialError("Notion credentials not loaded.") try: # Make a minimal request to confirm the token can access Notion if self.root_page_id: # If root_page_id is set, fetch the specific page res = rl_requests.get( f"https://api.notion.com/v1/pages/{self.root_page_id}", headers=self.headers, timeout=_NOTION_CALL_TIMEOUT, ) else: # If root_page_id is not set, perform a minimal search (page_size=1) test_query = { "filter": {"property": "object", "value": "page"}, "page_size": 1, } res = rl_requests.post( "https://api.notion.com/v1/search", headers=self.headers, json=test_query, timeout=_NOTION_CALL_TIMEOUT, ) res.raise_for_status() except requests.exceptions.HTTPError as http_err: # NOTE: compare against None explicitly; a Response is falsy for 4xx/5xx statuses status_code = http_err.response.status_code if http_err.response is not None else None if status_code
== 401: raise CredentialExpiredError( "Notion credential appears to be invalid or expired (HTTP 401)." ) elif status_code == 403: raise InsufficientPermissionsError( "Your Notion token does not have sufficient permissions (HTTP 403)." ) elif status_code == 404: # Typically means the resource was not found or not shared; root_page_id may be invalid. raise ConnectorValidationError( "Notion resource not found or not shared with the integration (HTTP 404)." ) elif status_code == 429: raise ConnectorValidationError( "Validation failed due to Notion rate limits being exceeded (HTTP 429). Please try again later." ) else: raise UnexpectedValidationError( f"Unexpected Notion HTTP error (status={status_code}): {http_err}" ) from http_err except Exception as exc: raise UnexpectedValidationError( f"Unexpected error during Notion settings validation: {exc}" ) from exc if __name__ == "__main__": import os root_page_id = os.environ.get("NOTION_ROOT_PAGE_ID") connector = NotionConnector(root_page_id=root_page_id) connector.load_credentials( {"notion_integration_token": os.environ.get("NOTION_INTEGRATION_TOKEN")} ) document_batches = connector.load_from_state() for doc_batch in document_batches: for doc in doc_batch: print(doc) ================================================ FILE: backend/onyx/connectors/outline/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/outline/client.py ================================================ from typing import Any import requests from requests.exceptions import ConnectionError as RequestsConnectionError from requests.exceptions import RequestException from requests.exceptions import Timeout from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS class OutlineClientRequestFailedError(ConnectionError): """Custom error class for handling failed requests to the Outline API with status code and error message""" def __init__(self, status: int, error: str) -> None:
self.status_code = status self.error = error super().__init__(f"Outline Client request failed with status {status}: {error}") class OutlineApiClient: """Client for interacting with the Outline API. Handles authentication and making HTTP requests.""" def __init__( self, api_token: str, base_url: str, ) -> None: self.base_url = base_url.rstrip("/") self.api_token = api_token def post(self, endpoint: str, data: dict[str, Any] | None = None) -> dict[str, Any]: if data is None: data = {} url: str = self._build_url(endpoint) headers = self._build_headers() try: response = requests.post( url, headers=headers, json=data, timeout=REQUEST_TIMEOUT_SECONDS ) except Timeout: raise OutlineClientRequestFailedError( 408, f"Request timed out - server did not respond within {REQUEST_TIMEOUT_SECONDS} seconds", ) except RequestsConnectionError as e: raise OutlineClientRequestFailedError( -1, f"Connection error - unable to reach Outline server: {e}" ) except RequestException as e: raise OutlineClientRequestFailedError(-1, f"Network error occurred: {e}") if response.status_code >= 300: error = response.reason try: response_json = response.json() if isinstance(response_json, dict): response_error = response_json.get("error", {}).get("message", "") if response_error: error = response_error except Exception: # If JSON parsing fails, fall back to response.text for better debugging if response.text.strip(): error = f"{response.reason}: {response.text.strip()}" raise OutlineClientRequestFailedError(response.status_code, error) try: return response.json() except Exception: raise OutlineClientRequestFailedError( response.status_code, f"Response was successful but contained invalid JSON: {response.text}", ) def _build_headers(self) -> dict[str, str]: return { "Authorization": f"Bearer {self.api_token}", "Accept": "application/json", "Content-Type": "application/json", } def _build_url(self, endpoint: str) -> str: return self.base_url.rstrip("/") + "/api/" + endpoint.lstrip("/") def 
build_app_url(self, endpoint: str) -> str: return self.base_url.rstrip("/") + "/" + endpoint.lstrip("/") ================================================ FILE: backend/onyx/connectors/outline/connector.py ================================================ import html import time from collections.abc import Callable from typing import Any from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.connectors.outline.client import OutlineApiClient from onyx.connectors.outline.client import OutlineClientRequestFailedError class OutlineConnector(LoadConnector, PollConnector): """Connector for Outline knowledge base. Handles authentication, document loading and polling. Implements both LoadConnector for initial state loading and PollConnector for incremental updates. 
""" def __init__( self, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.batch_size = batch_size self.outline_client: OutlineApiClient | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: required_keys = ["outline_api_token", "outline_base_url"] for key in required_keys: if key not in credentials: raise ConnectorMissingCredentialError("Outline") self.outline_client = OutlineApiClient( api_token=credentials["outline_api_token"], base_url=credentials["outline_base_url"], ) return None @staticmethod def _get_doc_batch( batch_size: int, outline_client: OutlineApiClient, endpoint: str, transformer: Callable[[OutlineApiClient, dict], Document], start_ind: int, ) -> tuple[list[Document], int]: data = { "limit": batch_size, "offset": start_ind, } batch = outline_client.post(endpoint, data=data).get("data", []) doc_batch = [transformer(outline_client, item) for item in batch] return doc_batch, len(batch) @staticmethod def _collection_to_document( outline_client: OutlineApiClient, collection: dict[str, Any] ) -> Document: url = outline_client.build_app_url(f"/collection/{collection.get('id')}") title = str(collection.get("name", "")) name = collection.get("name") or "" description = collection.get("description") or "" text = name + "\n" + description updated_at_str = ( str(collection.get("updatedAt")) if collection.get("updatedAt") is not None else None ) return Document( id="outline_collection__" + str(collection.get("id")), sections=[TextSection(link=url, text=html.unescape(text))], source=DocumentSource.OUTLINE, semantic_identifier="Collection: " + title, title=title, doc_updated_at=( time_str_to_utc(updated_at_str) if updated_at_str is not None else None ), metadata={"type": "collection"}, ) @staticmethod def _document_to_document( outline_client: OutlineApiClient, document: dict[str, Any] ) -> Document: url = outline_client.build_app_url(f"/doc/{document.get('id')}") title = str(document.get("title", "")) doc_title = 
document.get("title") or "" doc_text = document.get("text") or "" text = doc_title + "\n" + doc_text updated_at_str = ( str(document.get("updatedAt")) if document.get("updatedAt") is not None else None ) return Document( id="outline_document__" + str(document.get("id")), sections=[TextSection(link=url, text=html.unescape(text))], source=DocumentSource.OUTLINE, semantic_identifier="Document: " + title, title=title, doc_updated_at=( time_str_to_utc(updated_at_str) if updated_at_str is not None else None ), metadata={"type": "document"}, ) def load_from_state(self) -> GenerateDocumentsOutput: if self.outline_client is None: raise ConnectorMissingCredentialError("Outline") return self._fetch_documents() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: if self.outline_client is None: raise ConnectorMissingCredentialError("Outline") # Outline API does not support date-based filtering natively, # so we implement client-side filtering after fetching documents def time_filter(doc: Document) -> bool: if doc.doc_updated_at is None: return False doc_timestamp = doc.doc_updated_at.timestamp() if doc_timestamp < start: return False if doc_timestamp > end: return False return True return self._fetch_documents(time_filter) def _fetch_documents( self, time_filter: Callable[[Document], bool] | None = None ) -> GenerateDocumentsOutput: if self.outline_client is None: raise ConnectorMissingCredentialError("Outline") transform_by_endpoint: dict[ str, Callable[[OutlineApiClient, dict], Document] ] = { "documents.list": self._document_to_document, "collections.list": self._collection_to_document, } for endpoint, transform in transform_by_endpoint.items(): start_ind = 0 while True: doc_batch, num_results = self._get_doc_batch( batch_size=self.batch_size, outline_client=self.outline_client, endpoint=endpoint, transformer=transform, start_ind=start_ind, ) # Apply time filtering if specified filtered_batch: list[Document | 
HierarchyNode] = [] for doc in doc_batch: if time_filter is None or time_filter(doc): filtered_batch.append(doc) start_ind += num_results if filtered_batch: yield filtered_batch if num_results < self.batch_size: break else: time.sleep(0.2) def validate_connector_settings(self) -> None: """ Validate that the Outline credentials and connector settings are correct. Specifically checks that we can make an authenticated request to Outline. """ if not self.outline_client: raise ConnectorMissingCredentialError("Outline") try: # Use auth.info endpoint for validation _ = self.outline_client.post("auth.info", data={}) except OutlineClientRequestFailedError as e: # Check for HTTP status codes if e.status_code == 401: raise CredentialExpiredError( "Your Outline credentials appear to be invalid or expired (HTTP 401)." ) from e elif e.status_code == 403: raise InsufficientPermissionsError( "The configured Outline token does not have sufficient permissions (HTTP 403)." ) from e else: raise ConnectorValidationError( f"Unexpected Outline error (status={e.status_code}): {e}" ) from e except Exception as exc: raise ConnectorValidationError( f"Unexpected error while validating Outline connector settings: {exc}" ) from exc ================================================ FILE: backend/onyx/connectors/productboard/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/productboard/connector.py ================================================ from collections.abc import Generator from itertools import chain from typing import Any from typing import cast import requests from bs4 import BeautifulSoup from dateutil import parser from retry import retry from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput from 
onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() _PRODUCT_BOARD_BASE_URL = "https://api.productboard.com" class ProductboardApiError(Exception): pass class ProductboardConnector(PollConnector): def __init__( self, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.batch_size = batch_size self.access_token: str | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: self.access_token = credentials["productboard_access_token"] return None def _build_headers(self) -> dict[str, str]: return { "Authorization": f"Bearer {self.access_token}", "X-Version": "1", } @staticmethod def _parse_description_html(description_html: str) -> str: soup = BeautifulSoup(description_html, "html.parser") return soup.get_text() @staticmethod def _get_owner_email(productboard_obj: dict[str, Any]) -> str | None: owner_dict = cast(dict[str, str] | None, productboard_obj.get("owner")) if not owner_dict: return None return owner_dict.get("email") def _fetch_documents( self, initial_link: str, ) -> Generator[dict[str, Any], None, None]: headers = self._build_headers() @retry(tries=3, delay=1, backoff=2) def fetch(link: str) -> dict[str, Any]: response = requests.get(link, headers=headers) if not response.ok: # rate-limiting is at 50 requests per second. # The delay in this retry should handle this while this is # not parallelized. 
raise ProductboardApiError( f"Failed to fetch from productboard - status code: {response.status_code} - response: {response.text}" ) return response.json() curr_link = initial_link while True: response_json = fetch(curr_link) for entity in response_json["data"]: yield entity curr_link = response_json.get("links", {}).get("next") if not curr_link: break def _get_features(self) -> Generator[Document, None, None]: """A Feature is like a ticket in Jira""" for feature in self._fetch_documents( initial_link=f"{_PRODUCT_BOARD_BASE_URL}/features" ): owner = self._get_owner_email(feature) experts = [BasicExpertInfo(email=owner)] if owner else None metadata: dict[str, str | list[str]] = {} entity_type = feature.get("type", "feature") if entity_type: metadata["entity_type"] = str(entity_type) status = feature.get("status", {}).get("name") if status: metadata["status"] = str(status) yield Document( id=feature["id"], sections=[ TextSection( link=feature["links"]["html"], text=self._parse_description_html(feature["description"]), ) ], semantic_identifier=feature["name"], source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(feature["updatedAt"]), primary_owners=experts, metadata=metadata, ) def _get_components(self) -> Generator[Document, None, None]: """A Component is like an epic in Jira. 
It contains Features""" for component in self._fetch_documents( initial_link=f"{_PRODUCT_BOARD_BASE_URL}/components" ): owner = self._get_owner_email(component) experts = [BasicExpertInfo(email=owner)] if owner else None yield Document( id=component["id"], sections=[ TextSection( link=component["links"]["html"], text=self._parse_description_html(component["description"]), ) ], semantic_identifier=component["name"], source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(component["updatedAt"]), primary_owners=experts, metadata={ "entity_type": "component", }, ) def _get_products(self) -> Generator[Document, None, None]: """A Product is the highest level of organization. A Product contains components, which contains features.""" for product in self._fetch_documents( initial_link=f"{_PRODUCT_BOARD_BASE_URL}/products" ): owner = self._get_owner_email(product) experts = [BasicExpertInfo(email=owner)] if owner else None yield Document( id=product["id"], sections=[ TextSection( link=product["links"]["html"], text=self._parse_description_html(product["description"]), ) ], semantic_identifier=product["name"], source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(product["updatedAt"]), primary_owners=experts, metadata={ "entity_type": "product", }, ) def _get_objectives(self) -> Generator[Document, None, None]: for objective in self._fetch_documents( initial_link=f"{_PRODUCT_BOARD_BASE_URL}/objectives" ): owner = self._get_owner_email(objective) experts = [BasicExpertInfo(email=owner)] if owner else None metadata: dict[str, str | list[str]] = { "entity_type": "objective", } if objective.get("state"): metadata["state"] = str(objective["state"]) yield Document( id=objective["id"], sections=[ TextSection( link=objective["links"]["html"], text=self._parse_description_html(objective["description"]), ) ], semantic_identifier=objective["name"], source=DocumentSource.PRODUCTBOARD, doc_updated_at=time_str_to_utc(objective["updatedAt"]), 
primary_owners=experts, metadata=metadata, ) def _is_updated_at_out_of_time_range( self, document: Document, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, ) -> bool: # updatedAt is stored on doc_updated_at by the _get_* methods, not in metadata updated_at_datetime = document.doc_updated_at if updated_at_datetime is not None: if ( updated_at_datetime.timestamp() < start or updated_at_datetime.timestamp() > end ): return True else: logger.debug(f"Unable to find updated_at for document '{document.id}'") return False def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: if self.access_token is None: raise PermissionError( "Access token is not set up, was load_credentials called?" ) document_batch: list[Document | HierarchyNode] = [] # NOTE: there is a concept of a "Note" in Productboard, however # there is currently no read API for it. Additionally, comments are not # included with features. Finally, "Releases" are not fetched currently, # since they do not provide an updatedAt.
feature_documents = self._get_features() component_documents = self._get_components() product_documents = self._get_products() objective_documents = self._get_objectives() for document in chain( feature_documents, component_documents, product_documents, objective_documents, ): # skip documents that are not in the time range if self._is_updated_at_out_of_time_range(document, start, end): continue document_batch.append(document) if len(document_batch) >= self.batch_size: yield document_batch document_batch = [] if document_batch: yield document_batch if __name__ == "__main__": import os import time connector = ProductboardConnector() connector.load_credentials( { "productboard_access_token": os.environ["PRODUCTBOARD_ACCESS_TOKEN"], } ) current = time.time() one_year_ago = current - 24 * 60 * 60 * 365 latest_docs = connector.poll_source(one_year_ago, current) print(next(latest_docs)) ================================================ FILE: backend/onyx/connectors/registry.py ================================================ """Registry mapping for connector classes.""" from pydantic import BaseModel from onyx.configs.constants import DocumentSource class ConnectorMapping(BaseModel): module_path: str class_name: str # Mapping of DocumentSource to connector details for lazy loading CONNECTOR_CLASS_MAP = { DocumentSource.WEB: ConnectorMapping( module_path="onyx.connectors.web.connector", class_name="WebConnector", ), DocumentSource.FILE: ConnectorMapping( module_path="onyx.connectors.file.connector", class_name="LocalFileConnector", ), DocumentSource.SLACK: ConnectorMapping( module_path="onyx.connectors.slack.connector", class_name="SlackConnector", ), DocumentSource.GITHUB: ConnectorMapping( module_path="onyx.connectors.github.connector", class_name="GithubConnector", ), DocumentSource.GMAIL: ConnectorMapping( module_path="onyx.connectors.gmail.connector", class_name="GmailConnector", ), DocumentSource.GITLAB: ConnectorMapping(
module_path="onyx.connectors.gitlab.connector", class_name="GitlabConnector", ), DocumentSource.GITBOOK: ConnectorMapping( module_path="onyx.connectors.gitbook.connector", class_name="GitbookConnector", ), DocumentSource.GOOGLE_DRIVE: ConnectorMapping( module_path="onyx.connectors.google_drive.connector", class_name="GoogleDriveConnector", ), DocumentSource.BOOKSTACK: ConnectorMapping( module_path="onyx.connectors.bookstack.connector", class_name="BookstackConnector", ), DocumentSource.OUTLINE: ConnectorMapping( module_path="onyx.connectors.outline.connector", class_name="OutlineConnector", ), DocumentSource.CONFLUENCE: ConnectorMapping( module_path="onyx.connectors.confluence.connector", class_name="ConfluenceConnector", ), DocumentSource.JIRA: ConnectorMapping( module_path="onyx.connectors.jira.connector", class_name="JiraConnector", ), DocumentSource.PRODUCTBOARD: ConnectorMapping( module_path="onyx.connectors.productboard.connector", class_name="ProductboardConnector", ), DocumentSource.SLAB: ConnectorMapping( module_path="onyx.connectors.slab.connector", class_name="SlabConnector", ), DocumentSource.CODA: ConnectorMapping( module_path="onyx.connectors.coda.connector", class_name="CodaConnector", ), DocumentSource.CANVAS: ConnectorMapping( module_path="onyx.connectors.canvas.connector", class_name="CanvasConnector", ), DocumentSource.NOTION: ConnectorMapping( module_path="onyx.connectors.notion.connector", class_name="NotionConnector", ), DocumentSource.ZULIP: ConnectorMapping( module_path="onyx.connectors.zulip.connector", class_name="ZulipConnector", ), DocumentSource.GURU: ConnectorMapping( module_path="onyx.connectors.guru.connector", class_name="GuruConnector", ), DocumentSource.LINEAR: ConnectorMapping( module_path="onyx.connectors.linear.connector", class_name="LinearConnector", ), DocumentSource.HUBSPOT: ConnectorMapping( module_path="onyx.connectors.hubspot.connector", class_name="HubSpotConnector", ), DocumentSource.DOCUMENT360: ConnectorMapping( 
module_path="onyx.connectors.document360.connector", class_name="Document360Connector", ), DocumentSource.GONG: ConnectorMapping( module_path="onyx.connectors.gong.connector", class_name="GongConnector", ), DocumentSource.GOOGLE_SITES: ConnectorMapping( module_path="onyx.connectors.google_site.connector", class_name="GoogleSitesConnector", ), DocumentSource.ZENDESK: ConnectorMapping( module_path="onyx.connectors.zendesk.connector", class_name="ZendeskConnector", ), DocumentSource.LOOPIO: ConnectorMapping( module_path="onyx.connectors.loopio.connector", class_name="LoopioConnector", ), DocumentSource.DROPBOX: ConnectorMapping( module_path="onyx.connectors.dropbox.connector", class_name="DropboxConnector", ), DocumentSource.SHAREPOINT: ConnectorMapping( module_path="onyx.connectors.sharepoint.connector", class_name="SharepointConnector", ), DocumentSource.TEAMS: ConnectorMapping( module_path="onyx.connectors.teams.connector", class_name="TeamsConnector", ), DocumentSource.SALESFORCE: ConnectorMapping( module_path="onyx.connectors.salesforce.connector", class_name="SalesforceConnector", ), DocumentSource.DISCOURSE: ConnectorMapping( module_path="onyx.connectors.discourse.connector", class_name="DiscourseConnector", ), DocumentSource.AXERO: ConnectorMapping( module_path="onyx.connectors.axero.connector", class_name="AxeroConnector", ), DocumentSource.CLICKUP: ConnectorMapping( module_path="onyx.connectors.clickup.connector", class_name="ClickupConnector", ), DocumentSource.MEDIAWIKI: ConnectorMapping( module_path="onyx.connectors.mediawiki.wiki", class_name="MediaWikiConnector", ), DocumentSource.WIKIPEDIA: ConnectorMapping( module_path="onyx.connectors.wikipedia.connector", class_name="WikipediaConnector", ), DocumentSource.ASANA: ConnectorMapping( module_path="onyx.connectors.asana.connector", class_name="AsanaConnector", ), DocumentSource.S3: ConnectorMapping( module_path="onyx.connectors.blob.connector", class_name="BlobStorageConnector", ), DocumentSource.R2: 
ConnectorMapping(
        module_path="onyx.connectors.blob.connector",
        class_name="BlobStorageConnector",
    ),
    DocumentSource.GOOGLE_CLOUD_STORAGE: ConnectorMapping(
        module_path="onyx.connectors.blob.connector",
        class_name="BlobStorageConnector",
    ),
    DocumentSource.OCI_STORAGE: ConnectorMapping(
        module_path="onyx.connectors.blob.connector",
        class_name="BlobStorageConnector",
    ),
    DocumentSource.XENFORO: ConnectorMapping(
        module_path="onyx.connectors.xenforo.connector",
        class_name="XenforoConnector",
    ),
    DocumentSource.DISCORD: ConnectorMapping(
        module_path="onyx.connectors.discord.connector",
        class_name="DiscordConnector",
    ),
    DocumentSource.FRESHDESK: ConnectorMapping(
        module_path="onyx.connectors.freshdesk.connector",
        class_name="FreshdeskConnector",
    ),
    DocumentSource.FIREFLIES: ConnectorMapping(
        module_path="onyx.connectors.fireflies.connector",
        class_name="FirefliesConnector",
    ),
    DocumentSource.EGNYTE: ConnectorMapping(
        module_path="onyx.connectors.egnyte.connector",
        class_name="EgnyteConnector",
    ),
    DocumentSource.AIRTABLE: ConnectorMapping(
        module_path="onyx.connectors.airtable.airtable_connector",
        class_name="AirtableConnector",
    ),
    DocumentSource.HIGHSPOT: ConnectorMapping(
        module_path="onyx.connectors.highspot.connector",
        class_name="HighspotConnector",
    ),
    DocumentSource.DRUPAL_WIKI: ConnectorMapping(
        module_path="onyx.connectors.drupal_wiki.connector",
        class_name="DrupalWikiConnector",
    ),
    DocumentSource.IMAP: ConnectorMapping(
        module_path="onyx.connectors.imap.connector",
        class_name="ImapConnector",
    ),
    DocumentSource.BITBUCKET: ConnectorMapping(
        module_path="onyx.connectors.bitbucket.connector",
        class_name="BitbucketConnector",
    ),
    DocumentSource.TESTRAIL: ConnectorMapping(
        module_path="onyx.connectors.testrail.connector",
        class_name="TestRailConnector",
    ),
    # just for integration tests
    DocumentSource.MOCK_CONNECTOR: ConnectorMapping(
        module_path="onyx.connectors.mock_connector.connector",
        class_name="MockConnector",
    ),
}

================================================
FILE: backend/onyx/connectors/requesttracker/.gitignore ================================================ .env ================================================ FILE: backend/onyx/connectors/requesttracker/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/requesttracker/connector.py ================================================ # from datetime import datetime # from datetime import timezone # from logging import DEBUG as LOG_LVL_DEBUG # from typing import Any # from typing import List # from typing import Optional # from rt.rest1 import ALL_QUEUES # from rt.rest1 import Rt # from onyx.configs.app_configs import INDEX_BATCH_SIZE # from onyx.configs.constants import DocumentSource # from onyx.connectors.interfaces import GenerateDocumentsOutput # from onyx.connectors.interfaces import PollConnector # from onyx.connectors.interfaces import SecondsSinceUnixEpoch # from onyx.connectors.models import ConnectorMissingCredentialError # from onyx.connectors.models import Document # from onyx.connectors.models import Section # from onyx.utils.logger import setup_logger # logger = setup_logger() # class RequestTrackerError(Exception): # pass # class RequestTrackerConnector(PollConnector): # def __init__( # self, # batch_size: int = INDEX_BATCH_SIZE, # ) -> None: # self.batch_size = batch_size # def txn_link(self, tid: int, txn: int) -> str: # return f"{self.rt_base_url}/Ticket/Display.html?id={tid}&txn={txn}" # def build_doc_sections_from_txn( # self, connection: Rt, ticket_id: int # ) -> List[Section]: # Sections: List[Section] = [] # get_history_resp = connection.get_history(ticket_id) # if get_history_resp is None: # raise RequestTrackerError(f"Ticket {ticket_id} cannot be found") # for tx in get_history_resp: # Sections.append( # Section( # link=self.txn_link(ticket_id, int(tx["id"])), # text="\n".join( # [ # f"{k}:\n{v}\n" if k != "Attachments" else "" # for (k, v) in tx.items() # 
] # ), # ) # ) # return Sections # def load_credentials(self, credentials: dict[str, Any]) -> Optional[dict[str, Any]]: # self.rt_username = credentials.get("requesttracker_username") # self.rt_password = credentials.get("requesttracker_password") # self.rt_base_url = credentials.get("requesttracker_base_url") # return None # # This does not include RT file attachments yet. # def _process_tickets( # self, start: datetime, end: datetime # ) -> GenerateDocumentsOutput: # if any([self.rt_username, self.rt_password, self.rt_base_url]) is None: # raise ConnectorMissingCredentialError("requesttracker") # Rt0 = Rt( # f"{self.rt_base_url}/REST/1.0/", # self.rt_username, # self.rt_password, # ) # Rt0.login() # d0 = start.strftime("%Y-%m-%d %H:%M:%S") # d1 = end.strftime("%Y-%m-%d %H:%M:%S") # tickets = Rt0.search( # Queue=ALL_QUEUES, # raw_query=f"Updated > '{d0}' AND Updated < '{d1}'", # ) # doc_batch: List[Document] = [] # for ticket in tickets: # ticket_keys_to_omit = ["id", "Subject"] # tid: int = int(ticket["numerical_id"]) # ticketLink: str = f"{self.rt_base_url}/Ticket/Display.html?id={tid}" # logger.info(f"Processing ticket {tid}") # doc = Document( # id=ticket["id"], # # Will add title to the first section later in processing # sections=[Section(link=ticketLink, text="")] # + self.build_doc_sections_from_txn(Rt0, tid), # source=DocumentSource.REQUESTTRACKER, # semantic_identifier=ticket["Subject"], # metadata={ # key: value # for key, value in ticket.items() # if key not in ticket_keys_to_omit # }, # ) # doc_batch.append(doc) # if len(doc_batch) >= self.batch_size: # yield doc_batch # doc_batch = [] # if doc_batch: # yield doc_batch # def poll_source( # self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch # ) -> GenerateDocumentsOutput: # # Keep query short, only look behind 1 day at maximum # one_day_ago: float = end - (24 * 60 * 60) # _start: float = start if start > one_day_ago else one_day_ago # start_datetime = datetime.fromtimestamp(_start, 
tz=timezone.utc) # end_datetime = datetime.fromtimestamp(end, tz=timezone.utc) # yield from self._process_tickets(start_datetime, end_datetime) # if __name__ == "__main__": # import time # import os # from dotenv import load_dotenv # load_dotenv() # logger.setLevel(LOG_LVL_DEBUG) # rt_connector = RequestTrackerConnector() # rt_connector.load_credentials( # { # "requesttracker_username": os.getenv("RT_USERNAME"), # "requesttracker_password": os.getenv("RT_PASSWORD"), # "requesttracker_base_url": os.getenv("RT_BASE_URL"), # } # ) # current = time.time() # one_day_ago = current - (24 * 60 * 60) # 1 days # latest_docs = rt_connector.poll_source(one_day_ago, current) # for doc in latest_docs: # print(doc) ================================================ FILE: backend/onyx/connectors/salesforce/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/salesforce/blacklist.py ================================================ # NOTE(rkuo): I can't find an actual API that allows us to distinguish # broken/incompatible objects from regular ones. 
# taking hints from
# https://docs.resco.net/wiki/Salesforce_object_blacklist

SALESFORCE_BLACKLISTED_PREFIXES: set[str] = set(
    [
        "process",
        "aura",
        "app",
        "auth",
        "duplicate",
        "secure",
        "data",
        "listemail",
        "fsl__optimization",
        "fsl_scheduling",
        "feed",
        "chatter",
    ]
)

SALESFORCE_BLACKLISTED_SUFFIXES: set[str] = set(
    [
        "history",
        "share",
        "__tag",
        "__hd",
        "feed",
        "changeevent",
        "__ka",
        "__votestat",
        "__viewstat",
        "__kav",
        "__datacategoryselection",
        "subscription",
        "definition",
        "eventstream",
        "__mdt",
    ]
)

SALESFORCE_BLACKLISTED_OBJECTS: set[str] = set(
    [
        "acceptedeventrelation",
        "accountchangeevent",
        "accountcontactrole",
        "accountcontactrolechangeevent",
        "accounthistory",
        "accountshare",
        "actionlinkgrouptemplate",
        "actionlinktemplate",
        "activityhistory",
        "adminsetupevent",
        "aggregateresult",
        "announcement",
        "apexclass",
        "apexcomponent",
        "apexemailnotification",
        "apexlog",
        "apexpage",
        "apexpageinfo",
        "apextestqueueitem",
        "apextestresult",
        "apextestresultlimits",
        "apextestrunresult",
        "apextestsuite",
        "apextrigger",
        "apievent",
        "apptabmember",
        "assetchangeevent",
        "assethistory",
        "assetrelationshiphistory",
        "assettokenevent",
        "assignmentrule",
        "asyncapexjob",
        "backgroundoperation",
        "backgroundoperationresult",
        "batchapexerrorevent",
        "brandingset",
        "brandingsetproperty",
        "brandtemplate",
        "businessprocess",
        "campaignchangeevent",
        "campaignhistory",
        "campaignshare",
        "casechangeevent",
        "caseexternaldocument",
        "casehistory",
        "caseshare",
        "clientbrowser",
        "collaborationgroup",
        "collaborationgroupmember",
        "collaborationgroupmemberrequest",
        "collaborationinvitation",
        "connectedapplication",
        "contactchangeevent",
        "contacthistory",
        "contactrequest",
        "contactrequestshare",
        "contactshare",
        "contentasset",
        "contentbody",
        "contentdocumenthistory",
        "contenthubrepository",
        "contenttagsubscription",
        "contentusersubscription",
        "contentversionhistory",
        "contracthistory",
        "corswhitelistentry",
        "cronjobdetail",
        "crontrigger",
        "csptrustedsite",
        "custombrand",
        "custombrandasset",
"customhelpmenuitem", "customhelpmenusection", "customhttpheader", "customobjectuserlicensemetrics", "custompermission", "custompermissiondependency", "dandbcompany", "dashboard", "dashboardcomponent", "digitalsignature", "documentattachmentmap", "domain", "domainsite", "emailcapture", "emaildomainfilter", "emaildomainkey", "emailrelay", "emailservicesaddress", "emailservicesfunction", "emailstatus", "emailtemplate", "embeddedservicedetail", "embeddedservicelabel", "entityparticle", "eventbussubscriber", "eventchangeevent", "eventlogfile", "eventrelationchangeevent", "expressionfilter", "expressionfiltercriteria", "externaldatasource", "externaldatauserauth", "fieldhistoryarchive", "fieldpermissions", "fieldservicemobilesettings", "filesearchactivity", "fiscalyearsettings", "flexqueueitem", "flowinterview", "flowinterviewshare", "flowrecordrelation", "flowstagerelation", "forecastingshare", "forecastshare", "fsl__criteria__c", "fsl__gantt_filter__c", "fsl__ganttpalette__c", "fsl__service_goal__c", "fsl__slr_cache__c", "fsl__territory_optimization_request__c", "goalhistory", "goalshare", "grantedbylicense", "idpeventlog", "iframewhitelisturl", "image", "imageshare", "installedmobileapp", "leadchangeevent", "leadhistory", "leadshare", "lightningexitbypagemetrics", "lightningexperiencetheme", "lightningtogglemetrics", "lightningusagebyapptypemetrics", "lightningusagebybrowsermetrics", "lightningusagebyflexipagemetrics", "lightningusagebypagemetrics", "linkedarticle", "listemailchangeevent", "listemailshare", "listview", "listviewchart", "listviewchartinstance", "listviewevent", "loginasevent", "loginevent", "logingeo", "loginhistory", "loginip", "logoutevent", "lookedupfromactivity", "macro", "macrohistory", "macroinstruction", "macroshare", "mailmergetemplate", "matchingrule", "matchingruleitem", "metricdatalinkhistory", "metrichistory", "metricshare", "mobilesettingsassignment", "mydomaindiscoverablelogin", "name", "namedcredential", "noteandattachment", 
"notificationmember", "oauthtoken", "objectpermissions", "onboardingmetrics", "openactivity", "opportunitychangeevent", "opportunitycontactrolechangeevent", "opportunityfieldhistory", "opportunityhistory", "opportunityshare", "orderchangeevent", "orderhistory", "orderitemchangeevent", "orderitemhistory", "ordershare", "orgdeleterequest", "orgdeleterequestshare", "orglifecyclenotification", "orgwideemailaddress", "outgoingemail", "outgoingemailrelation", "ownerchangeoptioninfo", "packagelicense", "period", "permissionsetlicense", "permissionsetlicenseassign", "permissionsettabsetting", "person", "picklistvalueinfo", "platformaction", "platformcachepartition", "platformcachepartitiontype", "platformstatusalertevent", "pricebook2history", "processinstancehistory", "product2changeevent", "product2history", "publisher", "pushtopic", "pushupgradeexcludedorg", "quicktexthistory", "quicktextshare", "quotetemplaterichtextdata", "recordaction", "recordactionhistory", "recordvisibility", "relationshipdomain", "relationshipinfo", "reportevent", "samlssoconfig", "scontrol", "searchactivity", "searchlayout", "searchpromotionrule", "securitycustombaseline", "servicereportlayout", "sessionpermsetactivation", "setupaudittrail", "setupentityaccess", "site", "sitedetail", "sitehistory", "siteiframewhitelisturl", "solutionhistory", "sosdeployment", "sossession", "sossessionactivity", "sossessionhistory", "sossessionshare", "staticresource", "streamingchannel", "streamingchannelshare", "subscriberpackage", "subscriberpackageversion", "taskchangeevent", "tenantusageentitlement", "testsuitemembership", "thirdpartyaccountlink", "todaygoal", "todaygoalshare", "transactionsecuritypolicy", "twofactorinfo", "twofactormethodsinfo", "twofactortempcode", "urievent", "userappinfo", "userappmenucustomization", "userappmenucustomizationshare", "userappmenuitem", "userchangeevent", "useremailpreferredperson", "useremailpreferredpersonshare", "userentityaccess", "userfieldaccess", "userlicense", 
"userlistview", "userlistviewcriterion", "userlogin", "userpackagelicense", "userpermissionaccess", "userpreference", "userprovaccount", "userprovaccountstaging", "userprovisioningconfig", "userprovisioninglog", "userprovisioningrequest", "userprovisioningrequestshare", "userprovmocktarget", "userrecordaccess", "usershare", "verificationhistory", "visibilitychangenotification", "visualforceaccessmetrics", "waveautoinstallrequest", "wavecompatibilitycheckitem", "weblink", "workcoachinghistory", "workcoachingshare", "workfeedbackhistory", "workfeedbackquestion", "workfeedbackquestionhistory", "workfeedbackquestionsethistory", "workfeedbackquestionsetshare", "workfeedbackquestionshare", "workfeedbackrequesthistory", "workfeedbackrequestshare", "workfeedbackshare", "workfeedbacktemplateshare", "workperformancecyclehistory", "workperformancecycleshare", ] ) ================================================ FILE: backend/onyx/connectors/salesforce/connector.py ================================================ import csv import gc import json import os import sys import tempfile import time from collections import defaultdict from collections.abc import Callable from pathlib import Path from typing import Any from typing import cast from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from 
onyx.connectors.models import TextSection
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import convert_sf_query_result_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite
from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE
from onyx.connectors.salesforce.utils import ID_FIELD
from onyx.connectors.salesforce.utils import MODIFIED_FIELD
from onyx.connectors.salesforce.utils import NAME_FIELD
from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _convert_to_metadata_value(value: Any) -> str | list[str]:
    """Convert a Salesforce field value to a valid metadata value.

    Document metadata expects str | list[str], but Salesforce returns
    various types (bool, float, int, etc.). This function ensures all
    values are properly converted to strings.
""" if isinstance(value, list): return [str(item) for item in value] return str(value) _DEFAULT_PARENT_OBJECT_TYPES = [ACCOUNT_OBJECT_TYPE] _DEFAULT_ATTRIBUTES_TO_KEEP: dict[str, dict[str, str]] = { "Opportunity": { ACCOUNT_OBJECT_TYPE: "account", "FiscalQuarter": "fiscal_quarter", "FiscalYear": "fiscal_year", "IsClosed": "is_closed", NAME_FIELD: "name", "StageName": "stage_name", "Type": "type", "Amount": "amount", "CloseDate": "close_date", "Probability": "probability", "CreatedDate": "created_date", MODIFIED_FIELD: "last_modified_date", }, "Contact": { ACCOUNT_OBJECT_TYPE: "account", "CreatedDate": "created_date", MODIFIED_FIELD: "last_modified_date", }, } class SalesforceCheckpoint(ConnectorCheckpoint): initial_sync_complete: bool current_timestamp: SecondsSinceUnixEpoch class SalesforceConnectorContext: parent_types: set[str] = set() child_types: set[str] = set() parent_to_child_types: dict[str, set[str]] = {} # map from parent to child types child_to_parent_types: dict[str, set[str]] = {} # map from child to parent types parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {} type_to_queryable_fields: dict[str, set[str]] = {} prefix_to_type: dict[str, str] = {} # infer the object type of an id immediately parent_to_child_relationships: dict[str, set[str]] = ( {} ) # map from parent to child relationships parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = ( {} ) # map from relationship to queryable fields parent_child_names_to_relationships: dict[str, str] = {} def _extract_fields_and_associations_from_config( config: dict[str, Any], object_type: str ) -> tuple[list[str] | None, dict[str, list[str]]]: """ Extract fields and associations for a specific object type from custom config. 
    Returns:
        tuple of (fields_list, associations_dict)
        - fields_list: List of fields to query, or None if not specified (use all)
        - associations_dict: Dict mapping association names to their config
    """
    if object_type not in config:
        return None, {}

    obj_config = config[object_type]
    fields = obj_config.get("fields")
    associations = obj_config.get("associations", {})

    return fields, associations


def _validate_custom_query_config(config: dict[str, Any]) -> None:
    """
    Validate the structure of the custom query configuration.
    """

    for object_type, obj_config in config.items():
        if not isinstance(obj_config, dict):
            raise ValueError(
                f"top level object {object_type} must be mapped to a dictionary"
            )

        # Check if fields is a list when present
        if "fields" in obj_config:
            if not isinstance(obj_config["fields"], list):
                raise ValueError("if fields key exists, value must be a list")
            for v in obj_config["fields"]:
                if not isinstance(v, str):
                    raise ValueError(f"fields list value {v} is not a string")

        # Check if associations is a dict when present
        if "associations" in obj_config:
            if not isinstance(obj_config["associations"], dict):
                raise ValueError(
                    "if associations key exists, value must be a dictionary"
                )
            for assoc_name, assoc_fields in obj_config["associations"].items():
                if not isinstance(assoc_fields, list):
                    raise ValueError(
                        f"associations list value {assoc_fields} for key {assoc_name} is not a list"
                    )
                for v in assoc_fields:
                    if not isinstance(v, str):
                        raise ValueError(
                            f"associations list value {v} is not a string"
                        )


class SalesforceConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync):
    """Approach outline

    Goal
    - get data for every record of every parent object type
    - The data should consist of the parent object record and all direct child relationship objects

    Initial sync
    - Does a full sync, then indexes each parent object + children as a document via
    the local sqlite db

    - get the first level children object types of parent object types
    - bulk export all object types to
    CSV
    -- NOTE: bulk exports of an object type contain parent IDs, but not child IDs
    - Load all CSVs to the DB
    - generate all parent object types as documents and yield them

    - Initial syncs must always be for the entire dataset.
    Otherwise, some records may relate to other records that were updated more
    recently than the sync window, and those more recently updated records will
    not be pulled down by the query.

    Delta syncs
    - delta syncs detect changes in parent objects, then perform a full sync of
    each parent object and its children

    If loading the entire db, this approach is much slower. For deltas, it works well.
    - query all changed records (includes children and parents)
    - extrapolate all changed parent objects
    - for each parent object, construct a query and yield the result back

    - Delta syncs can be done object by object by identifying the parent id of any changed
    record, and querying a single record at a time to get all the updated data. In this way,
    we avoid having to keep a locally synchronized copy of the entire salesforce db.

    TODO: verify record to doc conversion
    figure out why sometimes the field names are missing.
""" MAX_BATCH_BYTES = 1024 * 1024 LOG_INTERVAL = 10.0 # how often to log stats in loop heavy parts of the connector def __init__( self, batch_size: int = INDEX_BATCH_SIZE, requested_objects: list[str] = [], custom_query_config: str | None = None, ) -> None: self.batch_size = batch_size self._sf_client: OnyxSalesforce | None = None # Validate and store custom query config if custom_query_config: config_json = json.loads(custom_query_config) self.custom_query_config: dict[str, Any] | None = config_json # If custom query config is provided, use the object types from it self.parent_object_list = list(config_json.keys()) else: self.custom_query_config = None # Use the traditional requested_objects approach self.parent_object_list = ( [obj.strip().capitalize() for obj in requested_objects] if requested_objects else _DEFAULT_PARENT_OBJECT_TYPES ) def load_credentials( self, credentials: dict[str, Any], ) -> dict[str, Any] | None: domain = "test" if credentials.get("is_sandbox") else None self._sf_client = OnyxSalesforce( username=credentials["sf_username"], password=credentials["sf_password"], security_token=credentials["sf_security_token"], domain=domain, ) return None @property def sf_client(self) -> OnyxSalesforce: if self._sf_client is None: raise ConnectorMissingCredentialError("Salesforce") return self._sf_client @staticmethod def reconstruct_object_types(directory: str) -> dict[str, list[str] | None]: """ Scans the given directory for all CSV files and reconstructs the available object types. Assumes filenames are formatted as "ObjectType.filename.csv" or "ObjectType.csv". Args: directory (str): The path to the directory containing CSV files. Returns: dict[str, list[str]]: A dictionary mapping object types to lists of file paths. 
""" object_types = defaultdict(list) for filename in os.listdir(directory): if filename.endswith(".csv"): parts = filename.split(".", 1) # Split on the first period object_type = parts[0] # Take the first part as the object type object_types[object_type].append(os.path.join(directory, filename)) return dict(object_types) @staticmethod def _download_object_csvs( all_types_to_filter: dict[str, bool], queryable_fields_by_type: dict[str, set[str]], directory: str, sf_client: OnyxSalesforce, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> None: # checkpoint - we've found all object types, now time to fetch the data logger.info("Fetching CSVs for all object types") # This takes like 30 minutes first time and <2 minutes for updates object_type_to_csv_path = fetch_all_csvs_in_parallel( sf_client=sf_client, all_types_to_filter=all_types_to_filter, queryable_fields_by_type=queryable_fields_by_type, start=start, end=end, target_dir=directory, ) # print useful information num_csvs = 0 num_bytes = 0 for object_type, csv_paths in object_type_to_csv_path.items(): if not csv_paths: continue for csv_path in csv_paths: if not csv_path: continue file_path = Path(csv_path) file_size = file_path.stat().st_size num_csvs += 1 num_bytes += file_size logger.info( f"CSV download: object_type={object_type} path={csv_path} bytes={file_size}" ) logger.info( f"CSV download total: total_csvs={num_csvs} total_bytes={num_bytes}" ) @staticmethod def _load_csvs_to_db( csv_directory: str, remove_ids: bool, sf_db: OnyxSalesforceSQLite ) -> dict[str, str]: """ Returns a dict of id to object type. Each id is a newly seen row in salesforce. 
""" updated_ids: dict[str, str] = {} object_type_to_csv_path = SalesforceConnector.reconstruct_object_types( csv_directory ) # NOTE(rkuo): this timing note is meaningless without a reference point in terms # of number of records, etc # This takes like 10 seconds # This is for testing the rest of the functionality if data has # already been fetched and put in sqlite # from import onyx.connectors.salesforce.sf_db.sqlite_functions find_ids_by_type # for object_type in self.parent_object_list: # updated_ids.update(list(find_ids_by_type(object_type))) # This takes 10-70 minutes first time (idk why the range is so big) total_types = len(object_type_to_csv_path) logger.info(f"Starting to process {total_types} object types") for i, (object_type, csv_paths) in enumerate( object_type_to_csv_path.items(), 1 ): logger.info(f"Processing object type {object_type} ({i}/{total_types})") # If path is None, it means it failed to fetch the csv if csv_paths is None: continue # Go through each csv path and use it to update the db for csv_path in csv_paths: num_records = 0 logger.debug( f"Processing CSV: object_type={object_type} " f"csv={csv_path} " f"len={Path(csv_path).stat().st_size} " f"records={num_records}" ) with open(csv_path, "r", newline="", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: num_records += 1 new_ids = sf_db.update_from_csv( object_type=object_type, csv_download_path=csv_path, remove_ids=remove_ids, ) for new_id in new_ids: updated_ids[new_id] = object_type sf_db.flush() logger.debug( f"Added {len(new_ids)} new/updated records for {object_type}" ) logger.info( f"Processed CSV: object_type={object_type} " f"csv={csv_path} " f"len={Path(csv_path).stat().st_size} " f"records={num_records} " f"db_len={sf_db.file_size}" ) os.remove(csv_path) return updated_ids # @staticmethod # def _get_child_types( # parent_types: list[str], sf_client: OnyxSalesforce # ) -> set[str]: # all_types: set[str] = set(parent_types) # # Step 1 - get all object types # 
logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}") # # This takes like 20 seconds # for parent_object_type in parent_types: # child_types = sf_client.get_children_of_sf_type(parent_object_type) # logger.debug( # f"Found {len(child_types)} child types for {parent_object_type}" # ) # all_types.update(child_types.keys()) # # Always want to make sure user is grabbed for permissioning purposes # all_types.add(USER_OBJECT_TYPE) # # Always want to make sure account is grabbed for reference purposes # all_types.add(ACCOUNT_OBJECT_TYPE) # logger.info(f"All object types: num={len(all_types)} list={all_types}") # # gc.collect() # return all_types # @staticmethod # def _get_all_types(parent_types: list[str], sf_client: Salesforce) -> set[str]: # all_types: set[str] = set(parent_types) # # Step 1 - get all object types # logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}") # # This takes like 20 seconds # for parent_object_type in parent_types: # child_types = get_children_of_sf_type(sf_client, parent_object_type) # logger.debug( # f"Found {len(child_types)} child types for {parent_object_type}" # ) # all_types.update(child_types) # # Always want to make sure user is grabbed for permissioning purposes # all_types.add(USER_OBJECT_TYPE) # logger.info(f"All object types: num={len(all_types)} list={all_types}") # # gc.collect() # return all_types def _yield_doc_batches( self, sf_db: OnyxSalesforceSQLite, type_to_processed: dict[str, int], changed_ids_to_type: dict[str, str], parent_types: set[str], increment_parents_changed: Callable[[], None], ) -> GenerateDocumentsOutput: """ """ docs_to_yield: list[Document | HierarchyNode] = [] docs_to_yield_bytes = 0 last_log_time = 0.0 for ( parent_type, parent_id, examined_ids, ) in sf_db.get_changed_parent_ids_by_type( changed_ids=list(changed_ids_to_type.keys()), parent_types=parent_types, ): now = time.monotonic() processed = examined_ids - 1 if now - last_log_time > 
SalesforceConnector.LOG_INTERVAL: logger.info( f"Processing stats: {type_to_processed} " f"file_size={sf_db.file_size} " f"processed={processed} " f"remaining={len(changed_ids_to_type) - processed}" ) last_log_time = now type_to_processed[parent_type] = type_to_processed.get(parent_type, 0) + 1 parent_object = sf_db.get_record(parent_id, parent_type) if not parent_object: logger.warning( f"Failed to get parent object {parent_id} for {parent_type}" ) continue # use the db to create a document we can yield doc = convert_sf_object_to_doc( sf_db, sf_object=parent_object, sf_instance=self.sf_client.sf_instance, ) doc.metadata["object_type"] = parent_type # Add default attributes to the metadata for ( sf_attribute, canonical_attribute, ) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(parent_type, {}).items(): if sf_attribute in parent_object.data: doc.metadata[canonical_attribute] = _convert_to_metadata_value( parent_object.data[sf_attribute] ) doc_sizeof = sys.getsizeof(doc) docs_to_yield_bytes += doc_sizeof docs_to_yield.append(doc) increment_parents_changed() # memory usage is sensitive to the input length, so we're yielding immediately # if the batch exceeds a certain byte length if ( len(docs_to_yield) >= self.batch_size or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES ): yield docs_to_yield docs_to_yield = [] docs_to_yield_bytes = 0 # observed a memory leak / size issue with the account table if we don't gc.collect here. 
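The batching rule used here (flush when the batch reaches `self.batch_size` documents or its accumulated `sys.getsizeof` total exceeds `MAX_BATCH_BYTES`) can be sketched as a standalone generator. This is a minimal illustration under that reading of the code, not the connector's implementation; the name `batch_by_count_and_bytes` is hypothetical. Note that `sys.getsizeof` reports only an object's own footprint, not the strings and lists it references, so the byte limit acts as a rough guard rather than an exact memory bound.

```python
import sys
from collections.abc import Iterable, Iterator
from typing import TypeVar

T = TypeVar("T")


def batch_by_count_and_bytes(
    items: Iterable[T], max_items: int, max_bytes: int
) -> Iterator[list[T]]:
    """Yield lists of items, flushing when either limit is reached.

    Hypothetical helper: mirrors the "count >= max_items or bytes > max_bytes"
    check in _yield_doc_batches.
    """
    batch: list[T] = []
    batch_bytes = 0
    for item in items:
        batch.append(item)
        # shallow size only: sys.getsizeof does not follow references
        batch_bytes += sys.getsizeof(item)
        if len(batch) >= max_items or batch_bytes > max_bytes:
            yield batch
            batch = []
            batch_bytes = 0
    # the connector yields its trailing batch unconditionally (even if empty);
    # this sketch skips an empty remainder
    if batch:
        yield batch
```

For example, `list(batch_by_count_and_bytes(range(5), max_items=2, max_bytes=10**9))` is limited only by count and yields `[[0, 1], [2, 3], [4]]`.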
gc.collect() yield docs_to_yield def _full_sync( self, temp_dir: str, ) -> GenerateDocumentsOutput: type_to_processed: dict[str, int] = {} logger.info("_fetch_from_salesforce starting (full sync).") if not self._sf_client: raise RuntimeError("self._sf_client is None!") changed_ids_to_type: dict[str, str] = {} parents_changed = 0 examined_ids = 0 sf_db = OnyxSalesforceSQLite(os.path.join(temp_dir, "salesforce_db.sqlite")) sf_db.connect() try: sf_db.apply_schema() sf_db.log_stats() ctx = self._make_context( None, None, temp_dir, self.parent_object_list, self._sf_client ) gc.collect() # Step 2 - load CSV's to sqlite object_type_to_csv_paths = SalesforceConnector.reconstruct_object_types( temp_dir ) total_types = len(object_type_to_csv_paths) logger.info(f"Starting to process {total_types} object types") for i, (object_type, csv_paths) in enumerate( object_type_to_csv_paths.items(), 1 ): logger.info(f"Processing object type {object_type} ({i}/{total_types})") # If path is None, it means it failed to fetch the csv if csv_paths is None: continue # Go through each csv path and use it to update the db for csv_path in csv_paths: num_records = 0 with open(csv_path, "r", newline="", encoding="utf-8") as f: reader = csv.DictReader(f) for row in reader: num_records += 1 logger.debug( f"Processing CSV: object_type={object_type} " f"csv={csv_path} " f"len={Path(csv_path).stat().st_size} " f"records={num_records}" ) new_ids = sf_db.update_from_csv( object_type=object_type, csv_download_path=csv_path, ) for new_id in new_ids: changed_ids_to_type[new_id] = object_type sf_db.flush() logger.debug( f"Added {len(new_ids)} new/updated records for {object_type}" ) logger.info( f"Processed CSV: object_type={object_type} " f"csv={csv_path} " f"len={Path(csv_path).stat().st_size} " f"records={num_records} " f"db_len={sf_db.file_size}" ) os.remove(csv_path) gc.collect() gc.collect() logger.info(f"Found {len(changed_ids_to_type)} total updated records") logger.info( f"Starting to process 
parent objects of types: {ctx.parent_types}" ) # Step 3 - extract and index docs def increment_parents_changed() -> None: nonlocal parents_changed parents_changed += 1 yield from self._yield_doc_batches( sf_db, type_to_processed, changed_ids_to_type, ctx.parent_types, increment_parents_changed, ) except Exception: logger.exception("Unexpected exception") raise finally: logger.info( f"Final processing stats: " f"examined={examined_ids} " f"parents_changed={parents_changed} " f"remaining={len(changed_ids_to_type) - examined_ids}" ) logger.info(f"Top level object types processed: {type_to_processed}") sf_db.close() def _delta_sync( self, temp_dir: str, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateDocumentsOutput: type_to_processed: dict[str, int] = {} logger.info("_fetch_from_salesforce starting (delta sync).") if not self._sf_client: raise RuntimeError("self._sf_client is None!") changed_ids_to_type: dict[str, str] = {} parents_changed = 0 processed = 0 sf_db = OnyxSalesforceSQLite(os.path.join(temp_dir, "salesforce_db.sqlite")) sf_db.connect() try: sf_db.apply_schema() sf_db.log_stats() ctx = self._make_context( start, end, temp_dir, self.parent_object_list, self._sf_client ) gc.collect() # Step 2 - load CSV's to sqlite changed_ids_to_type = SalesforceConnector._load_csvs_to_db( temp_dir, False, sf_db ) gc.collect() logger.info(f"Found {len(changed_ids_to_type)} total updated records") logger.info( f"Starting to process parent objects of types: {ctx.parent_types}" ) # Step 3 - extract and index docs docs_to_yield: list[Document | HierarchyNode] = [] docs_to_yield_bytes = 0 last_log_time = 0.0 # this is a partial sync, so all changed parent ids must be retrieved from salesforce # NOTE: it may be possible to identify the object type of an id by its prefix, # but unfortunately some object types have no prefix, # so that would work in many important cases, but not all.
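The prefix idea mentioned in the NOTE can be sketched as follows, assuming 15- or 18-character record Ids whose first three characters are the object's key prefix, and a prefix map shaped like the one `_make_context` builds from `describe()["sobjects"]`. Both helper names are hypothetical. The `None` return is exactly the case the NOTE warns about: types without a key prefix cannot be resolved this way, so a caller would still need a fallback such as checking each candidate type.

```python
from __future__ import annotations

from typing import Any


def build_prefix_map(global_describe: dict[str, Any]) -> dict[str, str]:
    """Map 3-character key prefixes to sObject names, skipping types without one."""
    return {
        sobject["keyPrefix"]: sobject["name"]
        for sobject in global_describe["sobjects"]
        if sobject.get("keyPrefix")
    }


def infer_type_from_id(record_id: str, prefix_to_type: dict[str, str]) -> str | None:
    """Guess the sObject type of a record Id from its key prefix.

    Returns None for malformed Ids and for prefixes missing from the map;
    the caller must then fall back to another lookup strategy.
    """
    if len(record_id) not in (15, 18):
        return None
    return prefix_to_type.get(record_id[:3])
```

With a map built from a describe payload listing `Account` under `001`, `infer_type_from_id("001000000000001AAA", prefix_map)` returns `"Account"`.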
for ( parent_id, actual_parent_type, num_examined, ) in sf_db.get_changed_parent_ids_by_type_2( changed_ids=changed_ids_to_type, parent_types=ctx.parent_types, parent_relationship_fields_by_type=ctx.parent_reference_fields_by_type, prefix_to_type=ctx.prefix_to_type, ): # this yields back each changed parent record, where changed means # the parent record itself or a child record was updated. now = time.monotonic() # query salesforce for the changed parent id record # NOTE(rkuo): we only know the record id and its possible types, # so we actually need to check each type until we succeed # to be entirely correct # this may be a source of inefficiency and thinking about # caching the most likely parent record type might be helpful # actual_parent_type: str | None = None # for possible_parent_type in possible_parent_types: # queryable_fields = ctx.queryable_fields_by_type[ # possible_parent_type # ] # query = _get_object_by_id_query( # parent_id, possible_parent_type, queryable_fields # ) # result = self._sf_client.query(query) # if result: # actual_parent_type = possible_parent_type # print(result) # break # get the parent record fields record = self._sf_client.query_object( actual_parent_type, parent_id, ctx.type_to_queryable_fields ) if not record: continue # queryable_fields = ctx.type_to_queryable_fields[ # actual_parent_type # ] # query = get_object_by_id_query( # parent_id, actual_parent_type, queryable_fields # ) # result = self._sf_client.query(query) # if not result: # continue # # print(result) # record: dict[str, Any] = {} # record_0 = result["records"][0] # for record_key, record_value in record_0.items(): # if record_key == "attributes": # continue # record[record_key] = record_value # for this parent type, increment the counter on the stats object type_to_processed[actual_parent_type] = ( type_to_processed.get(actual_parent_type, 0) + 1 ) # get the child records child_relationships = ctx.parent_to_child_relationships[ actual_parent_type ] 
relationship_to_queryable_fields = ( ctx.parent_to_relationship_queryable_fields[actual_parent_type] ) child_records = self.sf_client.get_child_objects_by_id( parent_id, actual_parent_type, list(child_relationships), relationship_to_queryable_fields, ) # NOTE(rkuo): does using the parent last modified make sense if the update # is being triggered because a child object changed? primary_owner_list: list[BasicExpertInfo] | None = None if "LastModifiedById" in record: try: last_modified_by_id = record["LastModifiedById"] user_record = self.sf_client.query_object( USER_OBJECT_TYPE, last_modified_by_id, ctx.type_to_queryable_fields, ) if user_record: primary_owner = BasicExpertInfo.from_dict(user_record) primary_owner_list = [primary_owner] except Exception: pass # for child_record_key, child_record in child_records.items(): # if not child_record: # continue # child_text_section = _extract_section( # child_record, # f"https://{self._sf_client.sf_instance}/{child_record_key}", # ) # sections.append(child_text_section) # for parent_relationship_field in parent_relationship_fields: # parent_relationship_id # json.loads(parent_object.data) # create and yield a document from the salesforce query doc = convert_sf_query_result_to_doc( parent_id, record, child_records, primary_owner_list, self._sf_client, ) # doc = Document( # id=ID_PREFIX + parent_id, # sections=cast(list[TextSection | ImageSection], sections), # source=DocumentSource.SALESFORCE, # semantic_identifier=parent_semantic_identifier, # doc_updated_at=time_str_to_utc(parent_last_modified_date), # primary_owners=primary_owner_list, # metadata={}, # ) # Add default attributes to the metadata for ( sf_attribute, canonical_attribute, ) in _DEFAULT_ATTRIBUTES_TO_KEEP.get(actual_parent_type, {}).items(): if sf_attribute in record: doc.metadata[canonical_attribute] = _convert_to_metadata_value( record[sf_attribute] ) doc_sizeof = sys.getsizeof(doc) docs_to_yield_bytes += doc_sizeof docs_to_yield.append(doc) parents_changed 
+= 1 # memory usage is sensitive to the input length, so we're yielding immediately # if the batch exceeds a certain byte length if ( len(docs_to_yield) >= self.batch_size or docs_to_yield_bytes > SalesforceConnector.MAX_BATCH_BYTES ): yield docs_to_yield docs_to_yield = [] docs_to_yield_bytes = 0 # observed a memory leak / size issue with the account table if we don't gc.collect here. gc.collect() processed = num_examined if now - last_log_time > SalesforceConnector.LOG_INTERVAL: logger.info( f"Processing stats: {type_to_processed} " f"processed={processed} " f"remaining={len(changed_ids_to_type) - processed}" ) last_log_time = now yield docs_to_yield except Exception: logger.exception("Unexpected exception") raise finally: logger.info( f"Final processing stats: " f"processed={processed} " f"remaining={len(changed_ids_to_type) - processed} " f"parents_changed={parents_changed}" ) logger.info(f"Top level object types processed: {type_to_processed}") sf_db.close() def _make_context( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, temp_dir: str, parent_object_list: list[str], sf_client: OnyxSalesforce, ) -> SalesforceConnectorContext: """NOTE: I suspect we're doing way too many queries here. 
Likely fewer queries and just parsing all the info we need in fewer passes will work.""" parent_types = set(parent_object_list) child_types: set[str] = set() parent_to_child_types: dict[str, set[str]] = ( {} ) # map from parent to child types child_to_parent_types: dict[str, set[str]] = ( {} ) # map from child to parent types parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = ( {} ) # for a given object, the fields that reference parent objects type_to_queryable_fields: dict[str, set[str]] = {} prefix_to_type: dict[str, str] = {} parent_to_child_relationships: dict[str, set[str]] = ( {} ) # map from parent to child relationships # relationship keys are formatted as "parent__relationship" # we have to do this because relationship names are not unique! # values are a dict of relationship names to a set of queryable fields parent_to_relationship_queryable_fields: dict[str, dict[str, set[str]]] = {} parent_child_names_to_relationships: dict[str, str] = {} full_sync = start is None and end is None # Step 1 - make a list of all the types to download (parent + direct child + USER_OBJECT_TYPE + ACCOUNT_OBJECT_TYPE) # prefixes = {} global_description = sf_client.describe() if not global_description: raise RuntimeError("sf_client.describe failed") for sobject in global_description["sobjects"]: if sobject["keyPrefix"]: prefix_to_type[sobject["keyPrefix"]] = sobject["name"] # prefixes[sobject['keyPrefix']] = { # 'object_name': sobject['name'], # 'label': sobject['label'], # 'is_custom': sobject['custom'] # } logger.info(f"Describe: num_prefixes={len(prefix_to_type)}") logger.info(f"Parent object types: num={len(parent_types)} list={parent_types}") for parent_type in parent_types: # parent_onyx_sf_type = OnyxSalesforceType(parent_type, sf_client) custom_fields: list[str] | None = [] associations_config: dict[str, list[str]] | None = None # Set queryable fields for parent type if self.custom_query_config: custom_fields, associations_config = ( _extract_fields_and_associations_from_config(
self.custom_query_config, parent_type ) ) custom_fields = custom_fields or [] # Get custom fields for parent type field_set = set(custom_fields) # used during doc conversion # field_set.add(NAME_FIELD) # does not always exist field_set.add(ID_FIELD) field_set.add(MODIFIED_FIELD) # Use only the specified fields type_to_queryable_fields[parent_type] = field_set logger.info(f"Using custom fields for {parent_type}: {field_set}") else: # Use all queryable fields type_to_queryable_fields[parent_type] = ( sf_client.get_queryable_fields_by_type(parent_type) ) logger.info(f"Using all fields for {parent_type}") child_types_all = sf_client.get_children_of_sf_type(parent_type) logger.debug(f"Found {len(child_types_all)} child types for {parent_type}") logger.debug(f"child types: {child_types_all}") child_types_working = child_types_all.copy() if associations_config is not None: child_types_working = { k: v for k, v in child_types_all.items() if k in associations_config } any_not_found = False for k in associations_config: if k not in child_types_working: any_not_found = True logger.warning(f"Association {k} not found in {parent_type}") if any_not_found: queryable_fields = sf_client.get_queryable_fields_by_type( parent_type ) raise RuntimeError( f"Associations {associations_config} not found in {parent_type}. " "Make sure your parent-child associations are in the right order." # f"with child objects {child_types_all}" # f" and fields {queryable_fields}" ) parent_to_child_relationships[parent_type] = set() parent_to_child_types[parent_type] = set() parent_to_relationship_queryable_fields[parent_type] = {} for child_type, child_relationship in child_types_working.items(): child_type = cast(str, child_type) # onyx_sf_type = OnyxSalesforceType(child_type, sf_client) # map parent name to child name parent_to_child_types[parent_type].add(child_type) # reverse map child name to parent name if child_type not in child_to_parent_types: child_to_parent_types[child_type] = set()
child_to_parent_types[child_type].add(parent_type) # map parent name to child relationship parent_to_child_relationships[parent_type].add(child_relationship) # map relationship to queryable fields of the target table if config_fields := ( associations_config and associations_config.get(child_type) ): field_set = set(config_fields) # these are expected and used during doc conversion # field_set.add(NAME_FIELD) # does not always exist field_set.add(ID_FIELD) field_set.add(MODIFIED_FIELD) queryable_fields = field_set else: queryable_fields = sf_client.get_queryable_fields_by_type( child_type ) if child_relationship in parent_to_relationship_queryable_fields[parent_type]: raise RuntimeError(f"{child_relationship=} already exists") parent_to_relationship_queryable_fields[parent_type][ child_relationship ] = queryable_fields type_to_queryable_fields[child_type] = queryable_fields parent_child_names_to_relationships[f"{parent_type}__{child_type}"] = ( child_relationship ) child_types.update(child_types_working.keys()) logger.info( f"Child object types: parent={parent_type} num={len(child_types_working)} list={child_types_working.keys()}" ) logger.info( f"Final child object types: num={len(child_types)} list={child_types}" ) all_types: set[str] = set(parent_types) all_types.update(child_types) # NOTE(rkuo): should this be an implicit parent type? all_types.add(USER_OBJECT_TYPE) # Always add User for permissioning purposes all_types.add(ACCOUNT_OBJECT_TYPE) # Always add Account for reference purposes logger.info(f"All object types: num={len(all_types)} list={all_types}") # Ensure User and Account have queryable fields if they weren't already processed essential_types = [USER_OBJECT_TYPE, ACCOUNT_OBJECT_TYPE] for essential_type in essential_types: if essential_type not in type_to_queryable_fields: type_to_queryable_fields[essential_type] = ( sf_client.get_queryable_fields_by_type(essential_type) ) # 1.1 - Detect all fields in child types which reference a parent type.
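Step 1.1 depends on finding the fields in each child type that point back at a parent type. The connector delegates this to `sf_client.get_parent_reference_fields`, whose implementation is not shown in this section; the sketch below is a hypothetical stand-in that works on a per-object describe payload. It assumes the standard Salesforce describe shape, in which each entry of `fields` carries `name`, `type` (`"reference"` for lookups) and a `referenceTo` list of target object names, and it assumes the connector's `dict[str, list[str]]` value maps parent type to field names.

```python
from typing import Any


def find_parent_reference_fields(
    object_describe: dict[str, Any], parent_types: set[str]
) -> dict[str, list[str]]:
    """Map each parent type to the fields of this object that reference it.

    Hypothetical helper operating on an sObject describe() result.
    """
    result: dict[str, list[str]] = {}
    for field in object_describe.get("fields", []):
        if field.get("type") != "reference":
            continue
        # polymorphic lookups (e.g. Task.WhatId) can target several types
        for target in field.get("referenceTo") or []:
            if target in parent_types:
                result.setdefault(target, []).append(field["name"])
    return result
```

Keeping every matching field, rather than only the first, matters for polymorphic lookups, where one field can link a child to more than one configured parent type.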
# build dicts to detect relationships between parent and child for child_type in child_types.union(essential_types): # onyx_sf_type = OnyxSalesforceType(child_type, sf_client) parent_reference_fields = sf_client.get_parent_reference_fields( child_type, parent_types ) parent_reference_fields_by_type[child_type] = parent_reference_fields # Only add time filter if there is at least one object of the type # in the database. We aren't worried about partially completed object update runs # because this occurs after we check for existing csvs which covers this case # NOTE(rkuo): all_types_to_filter: dict[str, bool] = {} for sf_type in all_types: # onyx_sf_type = OnyxSalesforceType(sf_type, sf_client) # NOTE(rkuo): I'm not convinced it makes sense to restrict filtering at all # all_types_to_filter[sf_type] = sf_db.object_type_count(sf_type) > 0 all_types_to_filter[sf_type] = not full_sync # Step 1.2 - bulk download the CSV's for each object type SalesforceConnector._download_object_csvs( all_types_to_filter, type_to_queryable_fields, temp_dir, sf_client, start, end, ) return_context = SalesforceConnectorContext() return_context.parent_types = parent_types return_context.child_types = child_types return_context.parent_to_child_types = parent_to_child_types return_context.child_to_parent_types = child_to_parent_types return_context.parent_reference_fields_by_type = parent_reference_fields_by_type return_context.type_to_queryable_fields = type_to_queryable_fields return_context.prefix_to_type = prefix_to_type return_context.parent_to_child_relationships = parent_to_child_relationships return_context.parent_to_relationship_queryable_fields = ( parent_to_relationship_queryable_fields ) return_context.parent_child_names_to_relationships = ( parent_child_names_to_relationships ) return return_context def load_from_state(self) -> GenerateDocumentsOutput: # Always use a temp directory for SQLite - the database is rebuilt # from scratch each time via CSV downloads, so there's no 
caching benefit # from persisting it. Using temp dirs also avoids collisions between # multiple CC pairs and eliminates stale WAL/SHM file issues. # TODO(evan): make this thing checkpointed and persist/load db from filestore with tempfile.TemporaryDirectory() as temp_dir: yield from self._full_sync(temp_dir) def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: """Poll source will synchronize updated parent objects one by one.""" # Always use a temp directory - see comment in load_from_state() with tempfile.TemporaryDirectory() as temp_dir: yield from self._delta_sync(temp_dir, start, end) def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002 ) -> GenerateSlimDocumentOutput: doc_metadata_list: list[SlimDocument | HierarchyNode] = [] for parent_object_type in self.parent_object_list: query = f"SELECT Id FROM {parent_object_type}" query_result = self.sf_client.safe_query_all(query) doc_metadata_list.extend( SlimDocument( id=f"{ID_PREFIX}{instance_dict.get('Id', '')}", external_access=None, ) for instance_dict in query_result["records"] ) yield doc_metadata_list def validate_connector_settings(self) -> None: """ Validate that the Salesforce credentials and connector settings are correct. Specifically checks that we can make an authenticated request to Salesforce. """ try: # Make an authenticated describe() call to verify credentials self.sf_client.describe() except Exception as e: raise ConnectorMissingCredentialError( f"Failed to validate Salesforce credentials. Please check your credentials and try again.
Error: {e}" ) if self.custom_query_config: try: _validate_custom_query_config(self.custom_query_config) except Exception as e: raise ConnectorMissingCredentialError( f"Failed to validate Salesforce custom query config. Please check your config and try again. Error: {e}" ) logger.info("Salesforce credentials validated successfully.") # @override # def load_from_checkpoint( # self, # start: SecondsSinceUnixEpoch, # end: SecondsSinceUnixEpoch, # checkpoint: SalesforceCheckpoint, # ) -> CheckpointOutput[SalesforceCheckpoint]: # try: # return self._fetch_document_batches(checkpoint, start, end) # except Exception as e: # if _should_propagate_error(e) and start is not None: # logger.warning( # "Confluence says we provided an invalid 'updated' field. This may indicate " # "a real issue, but can also appear during edge cases like daylight " # f"saving time changes. Retrying with a 1 hour offset. Error: {e}" # ) # return self._fetch_document_batches(checkpoint, start - ONE_HOUR, end) # raise # @override # def build_dummy_checkpoint(self) -> SalesforceCheckpoint: # return SalesforceCheckpoint(last_updated=0, has_more=True, last_seen_doc_ids=[]) # @override # def validate_checkpoint_json(self, checkpoint_json: str) -> SalesforceCheckpoint: # return SalesforceCheckpoint.model_validate_json(checkpoint_json) if __name__ == "__main__": connector = SalesforceConnector(requested_objects=[ACCOUNT_OBJECT_TYPE]) connector.load_credentials( { "sf_username": os.environ["SF_USERNAME"], "sf_password": os.environ["SF_PASSWORD"], "sf_security_token": os.environ["SF_SECURITY_TOKEN"], } ) start_time = time.monotonic() doc_count = 0 section_count = 0 text_count = 0 for doc_batch in connector.load_from_state(): doc_count += len(doc_batch) print(f"doc_count: {doc_count}") for doc in doc_batch: if isinstance(doc, HierarchyNode): continue section_count += len(doc.sections) for section in doc.sections: if isinstance(section, TextSection) and section.text is not None: text_count += len(section.text)
end_time = time.monotonic() print(f"Doc count: {doc_count}") print(f"Section count: {section_count}") print(f"Text count: {text_count}") print(f"Time taken: {end_time - start_time}") ================================================ FILE: backend/onyx/connectors/salesforce/doc_conversion.py ================================================ import re from typing import Any from typing import cast from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import ImageSection from onyx.connectors.models import TextSection from onyx.connectors.salesforce.onyx_salesforce import OnyxSalesforce from onyx.connectors.salesforce.sqlite_functions import OnyxSalesforceSQLite from onyx.connectors.salesforce.utils import ID_FIELD from onyx.connectors.salesforce.utils import MODIFIED_FIELD from onyx.connectors.salesforce.utils import NAME_FIELD from onyx.connectors.salesforce.utils import SalesforceObject from onyx.utils.logger import setup_logger logger = setup_logger() ID_PREFIX = "SALESFORCE_" # All of these types of keys are handled by specific fields in the doc # conversion process (E.g. URLs) or are not useful for the user (E.g. UUIDs) _SF_JSON_FILTER = r"Id$|Date$|stamp$|url$" def _clean_salesforce_dict(data: dict | list) -> dict | list: """Clean and transform Salesforce API response data by recursively: 1. Extracting records from the response if present 2. Merging attributes into the main dictionary 3. Filtering out keys matching certain patterns (Id, Date, stamp, url) 4. Removing '__c' suffix from custom field names 5. 
Removing None values and empty containers Args: data: A dictionary or list from Salesforce API response Returns: Cleaned dictionary or list with transformed keys and filtered values """ if isinstance(data, dict): if "records" in data.keys(): data = data["records"] if isinstance(data, dict): if "attributes" in data.keys(): if isinstance(data["attributes"], dict): data.update(data.pop("attributes")) if isinstance(data, dict): filtered_dict = {} for key, value in data.items(): if not re.search(_SF_JSON_FILTER, key, re.IGNORECASE): # remove the custom field suffix for display if key.endswith("__c"): key = key[:-3] if isinstance(value, (dict, list)): filtered_value = _clean_salesforce_dict(value) # Only add non-empty dictionaries or lists if filtered_value: filtered_dict[key] = filtered_value elif value is not None: filtered_dict[key] = value return filtered_dict if isinstance(data, list): filtered_list = [] for item in data: filtered_item: dict | list if isinstance(item, (dict, list)): filtered_item = _clean_salesforce_dict(item) # Only add non-empty dictionaries or lists if filtered_item: filtered_list.append(filtered_item) elif item is not None: filtered_list.append(item) return filtered_list return data def _json_to_natural_language(data: dict | list, indent: int = 0) -> str: """Convert a nested dictionary or list into a human-readable string format.
Recursively traverses the data structure and formats it with: - Key-value pairs on separate lines - Nested structures indented for readability - Lists and dictionaries handled with appropriate formatting Args: data: The dictionary or list to convert indent: Number of spaces to indent (default: 0) Returns: A formatted string representation of the data structure """ result = [] indent_str = " " * indent if isinstance(data, dict): for key, value in data.items(): if isinstance(value, (dict, list)): result.append(f"{indent_str}{key}:") result.append(_json_to_natural_language(value, indent + 2)) else: result.append(f"{indent_str}{key}: {value}") elif isinstance(data, list): for item in data: result.append(_json_to_natural_language(item, indent + 2)) return "\n".join(result) def _extract_section(salesforce_object_data: dict[str, Any], link: str) -> TextSection: """Converts a dict to a TextSection""" # Extract text from a Salesforce API response dictionary by: # 1. Cleaning the dictionary # 2. Converting the cleaned dictionary to natural language processed_dict = _clean_salesforce_dict(salesforce_object_data) natural_language_for_dict = _json_to_natural_language(processed_dict) return TextSection( text=natural_language_for_dict, link=link, ) def _extract_primary_owner( sf_db: OnyxSalesforceSQLite, sf_object: SalesforceObject, ) -> BasicExpertInfo | None: object_dict = sf_object.data if not (last_modified_by_id := object_dict.get("LastModifiedById")): logger.warning(f"No LastModifiedById found for {sf_object.id}") return None if not (last_modified_by := sf_db.get_record(last_modified_by_id)): logger.warning(f"No LastModifiedBy found for {last_modified_by_id}") return None user_data = last_modified_by.data expert_info = BasicExpertInfo( first_name=user_data.get("FirstName"), last_name=user_data.get("LastName"), email=user_data.get("Email"), display_name=user_data.get(NAME_FIELD), ) # Check if all fields are None if ( expert_info.first_name is None and expert_info.last_name 
is None and expert_info.email is None and expert_info.display_name is None ): logger.warning(f"No identifying information found for user {user_data}") return None return expert_info def convert_sf_query_result_to_doc( record_id: str, record: dict[str, Any], child_records: dict[str, dict[str, Any]], primary_owner_list: list[BasicExpertInfo] | None, sf_client: OnyxSalesforce, ) -> Document: """Generates a yieldable Document from query results""" base_url = f"https://{sf_client.sf_instance}" extracted_doc_updated_at = time_str_to_utc(record[MODIFIED_FIELD]) extracted_semantic_identifier = record.get(NAME_FIELD) or record.get( ID_FIELD, "Unknown Object" ) sections = [_extract_section(record, f"{base_url}/{record_id}")] for child_record_key, child_record in child_records.items(): if not child_record: continue key_fields = child_record_key.split(":") child_record_id = key_fields[1] child_text_section = _extract_section( child_record, f"{base_url}/{child_record_id}", ) sections.append(child_text_section) doc = Document( id=f"{ID_PREFIX}{record_id}", sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.SALESFORCE, semantic_identifier=extracted_semantic_identifier, doc_updated_at=extracted_doc_updated_at, primary_owners=primary_owner_list, metadata={}, ) return doc def convert_sf_object_to_doc( sf_db: OnyxSalesforceSQLite, sf_object: SalesforceObject, sf_instance: str, ) -> Document: """Build a Document from a parent record stored in the local SQLite db, with one section for the parent and one for each of its child records.""" object_dict = sf_object.data salesforce_id = object_dict[ID_FIELD] onyx_salesforce_id = f"{ID_PREFIX}{salesforce_id}" base_url = f"https://{sf_instance}" extracted_doc_updated_at = time_str_to_utc(object_dict[MODIFIED_FIELD]) extracted_semantic_identifier = object_dict.get(NAME_FIELD) or object_dict.get( ID_FIELD, "Unknown Object" ) sections = [_extract_section(sf_object.data, f"{base_url}/{sf_object.id}")] for child_id in sf_db.get_child_ids(sf_object.id): if not (child_object := sf_db.get_record(child_id, isChild=True)):
continue sections.append( _extract_section(child_object.data, f"{base_url}/{child_object.id}") ) # NOTE(rkuo): does using the parent last modified make sense if the update # is being triggered because a child object changed? primary_owner_list: list[BasicExpertInfo] | None = None primary_owner = sf_db.make_basic_expert_info_from_record(sf_object) if primary_owner: primary_owner_list = [primary_owner] doc = Document( id=onyx_salesforce_id, sections=cast(list[TextSection | ImageSection], sections), source=DocumentSource.SALESFORCE, semantic_identifier=extracted_semantic_identifier, doc_updated_at=extracted_doc_updated_at, primary_owners=primary_owner_list, metadata={}, ) return doc ================================================ FILE: backend/onyx/connectors/salesforce/onyx_salesforce.py ================================================ import time from typing import Any from simple_salesforce import Salesforce from simple_salesforce import SFType from simple_salesforce.exceptions import SalesforceRefusedRequest from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_OBJECTS from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_PREFIXES from onyx.connectors.salesforce.blacklist import SALESFORCE_BLACKLISTED_SUFFIXES from onyx.connectors.salesforce.salesforce_calls import get_object_by_id_query from onyx.connectors.salesforce.utils import ID_FIELD from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder logger = setup_logger() def is_salesforce_rate_limit_error(exception: Exception) -> bool: """Check if an exception is a Salesforce rate limit error.""" return isinstance( exception, SalesforceRefusedRequest ) and "REQUEST_LIMIT_EXCEEDED" in str(exception) class OnyxSalesforce(Salesforce): SOQL_MAX_SUBQUERIES = 20 def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) 
self.parent_types: set[str] = set() self.child_types: set[str] = set() self.parent_to_child_types: dict[str, set[str]] = ( {} ) # map from parent to child types self.child_to_parent_types: dict[str, set[str]] = ( {} ) # map from child to parent types self.parent_reference_fields_by_type: dict[str, dict[str, list[str]]] = {} self.queryable_fields_by_type: dict[str, list[str]] = {} self.prefix_to_type: dict[str, str] = ( {} ) # infer the object type of an id immediately def initialize(self) -> bool: """Eventually cache all first run client state with this method""" return True def is_blacklisted(self, object_type: str) -> bool: """Returns True if the object type is blacklisted.""" object_type_lower = object_type.lower() if object_type_lower in SALESFORCE_BLACKLISTED_OBJECTS: return True for prefix in SALESFORCE_BLACKLISTED_PREFIXES: if object_type_lower.startswith(prefix): return True for suffix in SALESFORCE_BLACKLISTED_SUFFIXES: if object_type_lower.endswith(suffix): return True return False @retry_builder( tries=6, delay=20, backoff=1.5, max_delay=60, exceptions=(SalesforceRefusedRequest,), ) @rate_limit_builder(max_calls=50, period=60) def safe_query(self, query: str, **kwargs: Any) -> dict[str, Any]: """Wrapper around the original query method with retry logic and rate limiting.""" try: return super().query(query, **kwargs) except SalesforceRefusedRequest as e: if is_salesforce_rate_limit_error(e): logger.warning( f"Salesforce rate limit exceeded for query: {query[:100]}..." 
) # Add additional delay for rate limit errors time.sleep(5) raise @retry_builder( tries=5, delay=20, backoff=1.5, max_delay=60, exceptions=(SalesforceRefusedRequest,), ) @rate_limit_builder(max_calls=50, period=60) def safe_query_all(self, query: str, **kwargs: Any) -> dict[str, Any]: """Wrapper around the original query_all method with retry logic and rate limiting.""" try: return super().query_all(query, **kwargs) except SalesforceRefusedRequest as e: if is_salesforce_rate_limit_error(e): logger.warning( f"Salesforce rate limit exceeded for query_all: {query[:100]}..." ) # Add additional delay for rate limit errors time.sleep(5) raise @staticmethod def _make_child_objects_by_id_query( object_id: str, sf_type: str, child_relationships: list[str], relationships_to_fields: dict[str, set[str]], ) -> str: """Returns a SOQL query given the object id, type and child relationships. object_id: the id of the parent object sf_type: the object name/type of the parent object child_relationships: a list of the child object names/types to retrieve relationships_to_fields: a mapping of objects to their queryable fields When the query is executed, it comes back as result.records[0][child_relationship] """ # supposedly the real limit is 200? But we limit to 10 for practical reasons SUBQUERY_LIMIT = 10 query = "SELECT " for child_relationship in child_relationships: # TODO(rkuo): what happens if there is a very large list of child records? # is that possible problem? # NOTE: we actually have to list out the subqueries we want. 
# We can't use the following shortcuts: # FIELDS(ALL) can include binary fields, so don't use that # FIELDS(CUSTOM) can include aggregate queries, so don't use that fields = relationships_to_fields[child_relationship] fields_fragment = ",".join(fields) query += f"(SELECT {fields_fragment} FROM {child_relationship} LIMIT {SUBQUERY_LIMIT}), " query = query.rstrip(", ") query += f" FROM {sf_type} WHERE Id = '{object_id}'" return query def query_object( self, object_type: str, object_id: str, type_to_queryable_fields: dict[str, set[str]], ) -> dict[str, Any] | None: record: dict[str, Any] = {} queryable_fields = type_to_queryable_fields[object_type] query = get_object_by_id_query(object_id, object_type, queryable_fields) result = self.safe_query(query) if not result or not result.get("records"): return None record_0 = result["records"][0] for record_key, record_value in record_0.items(): if record_key == "attributes": continue record[record_key] = record_value return record def get_child_objects_by_id( self, object_id: str, sf_type: str, child_relationships: list[str], relationships_to_fields: dict[str, set[str]], ) -> dict[str, dict[str, Any]]: """Returns child records keyed by "relationship:child_id". Child relationships are queried in batches because SOQL limits the number of subqueries in a single query.""" child_records: dict[str, dict[str, Any]] = {} child_relationships_batch: list[str] = [] remaining_child_relationships = list(child_relationships) while True: process_batch = False if ( len(remaining_child_relationships) == 0 and len(child_relationships_batch) == 0 ): break if len(child_relationships_batch) >= OnyxSalesforce.SOQL_MAX_SUBQUERIES: process_batch = True if len(remaining_child_relationships) == 0: process_batch = True if process_batch: if len(child_relationships_batch) == 0: break query = OnyxSalesforce._make_child_objects_by_id_query( object_id, sf_type, child_relationships_batch, relationships_to_fields, ) try: result = self.safe_query(query) except Exception: logger.exception(f"Query failed: {query=}") else: for child_record_key, child_result in
result["records"][0].items(): if child_record_key == "attributes": continue if not child_result: continue for child_record in child_result["records"]: child_record_id = child_record[ID_FIELD] if not child_record_id: logger.warning("Child record has no id") continue child_records[f"{child_record_key}:{child_record_id}"] = ( child_record ) finally: child_relationships_batch.clear() continue if len(remaining_child_relationships) == 0: break child_relationship = remaining_child_relationships.pop(0) # this is binary content, skip it if child_relationship == "Attachments": continue child_relationships_batch.append(child_relationship) return child_records @retry_builder( tries=3, delay=1, backoff=2, exceptions=(SalesforceRefusedRequest,), ) def describe_type(self, name: str) -> Any: sf_object = SFType(name, self.session_id, self.sf_instance) try: result = sf_object.describe() return result except SalesforceRefusedRequest as e: if is_salesforce_rate_limit_error(e): logger.warning( f"Salesforce rate limit exceeded for describe_type: {name}" ) # Add additional delay for rate limit errors time.sleep(3) raise def get_queryable_fields_by_type(self, name: str) -> set[str]: object_description = self.describe_type(name) if object_description is None: return set() fields: list[dict[str, Any]] = object_description["fields"] valid_fields: set[str] = set() field_names_to_remove: set[str] = set() for field in fields: if compound_field_name := field.get("compoundFieldName"): # We do want to get name fields even if they are compound if not field.get("nameField"): field_names_to_remove.add(compound_field_name) field_name = field.get("name") field_type = field.get("type") if field_type in ["base64", "blob", "encryptedstring"]: continue if field_name: valid_fields.add(field_name) return valid_fields - field_names_to_remove def get_children_of_sf_type(self, sf_type: str) -> dict[str, str]: """Returns a dict of child object names to relationship names. 
Relationship names (not object names) are used in subqueries! """ names_to_relationships: dict[str, str] = {} object_description = self.describe_type(sf_type) index = 0 len_relationships = len(object_description["childRelationships"]) for child_relationship in object_description["childRelationships"]: child_name = child_relationship["childSObject"] index += 1 valid, reason = self._is_valid_child_object(child_relationship) if not valid: logger.debug( f"{index}/{len_relationships} - Invalid child object: " f"parent={sf_type} child={child_name} child_field_backreference={child_relationship['field']} {reason=}" ) continue logger.debug( f"{index}/{len_relationships} - Found valid child object: " f"parent={sf_type} child={child_name} child_field_backreference={child_relationship['field']}" ) name = child_name relationship = child_relationship["relationshipName"] names_to_relationships[name] = relationship return names_to_relationships def _is_valid_child_object( self, child_relationship: dict[str, Any] ) -> tuple[bool, str]: if not child_relationship["childSObject"]: return False, "childSObject is None" child_name = child_relationship["childSObject"] if self.is_blacklisted(child_name): return False, f"{child_name=} is blacklisted." if not child_relationship["relationshipName"]: return False, f"{child_name=} has no relationshipName." object_description = self.describe_type(child_relationship["childSObject"]) if not object_description["queryable"]: return False, f"{child_name=} is not queryable." if not child_relationship["field"]: return False, f"{child_name=} has no relationship field." if child_relationship["field"] == "RelatedToId": return False, f"{child_name=} field is RelatedToId and blacklisted." 
return True, "" def get_parent_reference_fields( self, sf_type: str, parent_types: set[str] ) -> dict[str, list[str]]: """ sf_type: the type in which to find parent reference fields parent_types: the set of parent object types we are actually interested in. Other parent types will not be returned. Given an object type, returns a dict of field names to a list of referenced parent object types. (Yes, it is possible for a field to reference one of multiple object types, although this seems very unlikely.) Returns an empty dict if there are no parent reference fields. """ parent_reference_fields: dict[str, list[str]] = {} object_description = self.describe_type(sf_type) for field in object_description["fields"]: if field["type"] == "reference": for reference_to in field["referenceTo"]: if reference_to in parent_types: if field["name"] not in parent_reference_fields: parent_reference_fields[field["name"]] = [] parent_reference_fields[field["name"]].append( reference_to ) return parent_reference_fields ================================================ FILE: backend/onyx/connectors/salesforce/salesforce_calls.py ================================================ import gc import os import time from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pytz import UTC from simple_salesforce import Salesforce from simple_salesforce.bulk2 import SFBulk2Handler from simple_salesforce.bulk2 import SFBulk2Type from simple_salesforce.exceptions import SalesforceRefusedRequest from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.salesforce.utils import MODIFIED_FIELD from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder logger = setup_logger() def is_salesforce_rate_limit_error(exception: Exception) -> bool: """Check if an exception is a Salesforce rate limit error.""" return
isinstance( exception, SalesforceRefusedRequest ) and "REQUEST_LIMIT_EXCEEDED" in str(exception) def _build_last_modified_time_filter_for_salesforce( start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> str: if start is None or end is None: return "" start_datetime = datetime.fromtimestamp(start, UTC) end_datetime = datetime.fromtimestamp(end, UTC) return f" WHERE LastModifiedDate > {start_datetime.isoformat()} AND LastModifiedDate < {end_datetime.isoformat()}" def _build_created_date_time_filter_for_salesforce( start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None ) -> str: if start is None or end is None: return "" start_datetime = datetime.fromtimestamp(start, UTC) end_datetime = datetime.fromtimestamp(end, UTC) return f" WHERE CreatedDate > {start_datetime.isoformat()} AND CreatedDate < {end_datetime.isoformat()}" def _make_time_filter_for_sf_type( queryable_fields: set[str], start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, ) -> str | None: if MODIFIED_FIELD in queryable_fields: return _build_last_modified_time_filter_for_salesforce(start, end) if "CreatedDate" in queryable_fields: return _build_created_date_time_filter_for_salesforce(start, end) return None def _make_time_filtered_query( queryable_fields: set[str], sf_type: str, time_filter: str ) -> str: query = f"SELECT {', '.join(queryable_fields)} FROM {sf_type}{time_filter}" return query def get_object_by_id_query( object_id: str, sf_type: str, queryable_fields: set[str] ) -> str: query = ( f"SELECT {', '.join(queryable_fields)} FROM {sf_type} WHERE Id = '{object_id}'" ) return query @retry_builder( tries=5, delay=2, backoff=2, max_delay=60, exceptions=(SalesforceRefusedRequest,), ) @rate_limit_builder(max_calls=50, period=60) def _object_type_has_api_data( sf_client: Salesforce, sf_type: str, time_filter: str ) -> bool: """ Use the rest api to check to make sure the query will result in a non-empty response. 
""" try: query = f"SELECT Count() FROM {sf_type}{time_filter} LIMIT 1" result = sf_client.query(query) if result["totalSize"] == 0: return False except SalesforceRefusedRequest as e: if is_salesforce_rate_limit_error(e): logger.warning( f"Salesforce rate limit exceeded for object type check: {sf_type}" ) # Add additional delay for rate limit errors time.sleep(3) raise except Exception as e: if "OPERATION_TOO_LARGE" not in str(e): logger.warning(f"Object type {sf_type} doesn't support query: {e}") return False return True def _bulk_retrieve_from_salesforce( sf_type: str, query: str, target_dir: str, sf_client: Salesforce, ) -> tuple[str, list[str] | None]: """Returns a tuple of 1. the salesforce object type (NOTE: seems redundant) 2. the list of CSV's written into the target directory """ bulk_2_handler: SFBulk2Handler | None = SFBulk2Handler( session_id=sf_client.session_id, bulk2_url=sf_client.bulk2_url, proxies=sf_client.proxies, session=sf_client.session, ) if not bulk_2_handler: return sf_type, None # NOTE(rkuo): there are signs this download is allocating large # amounts of memory instead of streaming the results to disk. # we're doing a gc.collect to try and mitigate this. 
# see https://github.com/simple-salesforce/simple-salesforce/issues/428 for a # possible solution bulk_2_type: SFBulk2Type | None = SFBulk2Type( object_name=sf_type, bulk2_url=bulk_2_handler.bulk2_url, headers=bulk_2_handler.headers, session=bulk_2_handler.session, ) if not bulk_2_type: return sf_type, None logger.info(f"Downloading {sf_type}") logger.debug(f"Query: {query}") try: # This downloads the results into one or more randomly named files in the target directory results = bulk_2_type.download( query=query, path=target_dir, max_records=500000, ) # prepend each downloaded csv with the object type (delimiter = '.') all_download_paths: list[str] = [] for result in results: original_file_path = result["file"] directory, filename = os.path.split(original_file_path) new_filename = f"{sf_type}.{filename}" new_file_path = os.path.join(directory, new_filename) os.rename(original_file_path, new_file_path) all_download_paths.append(new_file_path) except Exception as e: logger.error( f"Failed to download salesforce csv for object type {sf_type}: {e}" ) logger.warning(f"Failed query for object type {sf_type}: {query}") return sf_type, None finally: bulk_2_handler = None bulk_2_type = None gc.collect() logger.info(f"Downloaded {sf_type} to {all_download_paths}") return sf_type, all_download_paths def fetch_all_csvs_in_parallel( sf_client: Salesforce, all_types_to_filter: dict[str, bool], queryable_fields_by_type: dict[str, set[str]], start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, target_dir: str, ) -> dict[str, list[str] | None]: """ Fetches all the csvs in parallel for the given object types. Returns a dict mapping each sf_type to its list of downloaded CSV paths (None if the download failed). NOTE: We can probably lift the _object_type_has_api_data check out of here """ type_to_query = {} # query the available fields for each object type and determine how to filter for sf_type, apply_filter in all_types_to_filter.items(): queryable_fields = queryable_fields_by_type[sf_type] time_filter = "" while True: if not
apply_filter: break if start is not None and end is not None: time_filter_temp = _make_time_filter_for_sf_type( queryable_fields, start, end ) if time_filter_temp is None: logger.warning( f"Object type not filterable: type={sf_type} fields={queryable_fields}" ) time_filter = "" else: logger.info( f"Object type filterable: type={sf_type} filter={time_filter_temp}" ) time_filter = time_filter_temp break if not _object_type_has_api_data(sf_client, sf_type, time_filter): logger.warning(f"Object type skipped (no data available): type={sf_type}") continue query = _make_time_filtered_query(queryable_fields, sf_type, time_filter) type_to_query[sf_type] = query logger.info( f"Object types to query: initial={len(all_types_to_filter)} queryable={len(type_to_query)}" ) # Run the bulk retrieve in parallel # limit to 4 to help with memory usage with ThreadPoolExecutor(max_workers=4) as executor: results = executor.map( lambda object_type: _bulk_retrieve_from_salesforce( sf_type=object_type, query=type_to_query[object_type], target_dir=target_dir, sf_client=sf_client, ), type_to_query.keys(), ) return dict(results) ================================================ FILE: backend/onyx/connectors/salesforce/shelve_stuff/old_test_salesforce_shelves.py ================================================ import csv import os import shutil from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type from onyx.connectors.salesforce.shelve_stuff.shelve_functions import ( get_affected_parent_ids_by_type, ) from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record from onyx.connectors.salesforce.shelve_stuff.shelve_functions import ( update_sf_db_with_csv, ) from onyx.connectors.salesforce.utils import BASE_DATA_PATH from onyx.connectors.salesforce.utils import get_object_type_path _VALID_SALESFORCE_IDS = [ "001bm00000fd9Z3AAI", "001bm00000fdYTdAAM", 
"001bm00000fdYTeAAM", "001bm00000fdYTfAAM", "001bm00000fdYTgAAM", "001bm00000fdYThAAM", "001bm00000fdYTiAAM", "001bm00000fdYTjAAM", "001bm00000fdYTkAAM", "001bm00000fdYTlAAM", "001bm00000fdYTmAAM", "001bm00000fdYTnAAM", "001bm00000fdYToAAM", "500bm00000XoOxtAAF", "500bm00000XoOxuAAF", "500bm00000XoOxvAAF", "500bm00000XoOxwAAF", "500bm00000XoOxxAAF", "500bm00000XoOxyAAF", "500bm00000XoOxzAAF", "500bm00000XoOy0AAF", "500bm00000XoOy1AAF", "500bm00000XoOy2AAF", "500bm00000XoOy3AAF", "500bm00000XoOy4AAF", "500bm00000XoOy5AAF", "500bm00000XoOy6AAF", "500bm00000XoOy7AAF", "500bm00000XoOy8AAF", "500bm00000XoOy9AAF", "500bm00000XoOyAAAV", "500bm00000XoOyBAAV", "500bm00000XoOyCAAV", "500bm00000XoOyDAAV", "500bm00000XoOyEAAV", "500bm00000XoOyFAAV", "500bm00000XoOyGAAV", "500bm00000XoOyHAAV", "500bm00000XoOyIAAV", "003bm00000EjHCjAAN", "003bm00000EjHCkAAN", "003bm00000EjHClAAN", "003bm00000EjHCmAAN", "003bm00000EjHCnAAN", "003bm00000EjHCoAAN", "003bm00000EjHCpAAN", "003bm00000EjHCqAAN", "003bm00000EjHCrAAN", "003bm00000EjHCsAAN", "003bm00000EjHCtAAN", "003bm00000EjHCuAAN", "003bm00000EjHCvAAN", "003bm00000EjHCwAAN", "003bm00000EjHCxAAN", "003bm00000EjHCyAAN", "003bm00000EjHCzAAN", "003bm00000EjHD0AAN", "003bm00000EjHD1AAN", "003bm00000EjHD2AAN", "550bm00000EXc2tAAD", "006bm000006kyDpAAI", "006bm000006kyDqAAI", "006bm000006kyDrAAI", "006bm000006kyDsAAI", "006bm000006kyDtAAI", "006bm000006kyDuAAI", "006bm000006kyDvAAI", "006bm000006kyDwAAI", "006bm000006kyDxAAI", "006bm000006kyDyAAI", "006bm000006kyDzAAI", "006bm000006kyE0AAI", "006bm000006kyE1AAI", "006bm000006kyE2AAI", "006bm000006kyE3AAI", "006bm000006kyE4AAI", "006bm000006kyE5AAI", "006bm000006kyE6AAI", "006bm000006kyE7AAI", "006bm000006kyE8AAI", "006bm000006kyE9AAI", "006bm000006kyEAAAY", "006bm000006kyEBAAY", "006bm000006kyECAAY", "006bm000006kyEDAAY", "006bm000006kyEEAAY", "006bm000006kyEFAAY", "006bm000006kyEGAAY", "006bm000006kyEHAAY", "006bm000006kyEIAAY", "006bm000006kyEJAAY", "005bm000009zy0TAAQ", 
"005bm000009zy25AAA", "005bm000009zy26AAA", "005bm000009zy28AAA", "005bm000009zy29AAA", "005bm000009zy2AAAQ", "005bm000009zy2BAAQ", ] def clear_sf_db() -> None: """ Clears the SF DB by deleting the data directory and all of its contents. """ shutil.rmtree(BASE_DATA_PATH) def create_csv_file( object_type: str, records: list[dict], filename: str = "test_data.csv" ) -> None: """ Creates a CSV file for the given object type and records. Args: object_type: The Salesforce object type (e.g. "Account", "Contact") records: List of dictionaries containing the record data filename: Name of the CSV file to create (default: test_data.csv) """ if not records: return # Get all unique fields from records fields: set[str] = set() for record in records: fields.update(record.keys()) fieldnames = sorted(fields) # Sort for a consistent column order # Create CSV file csv_path = os.path.join(get_object_type_path(object_type), filename) with open(csv_path, "w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=fieldnames) writer.writeheader() for record in records: writer.writerow(record) # Update the database with the CSV update_sf_db_with_csv(object_type, csv_path) def create_csv_with_example_data() -> None: """ Creates CSV files with example data, organized by object type.
""" example_data: dict[str, list[dict]] = { "Account": [ { "Id": _VALID_SALESFORCE_IDS[0], "Name": "Acme Inc.", "BillingCity": "New York", "Industry": "Technology", }, { "Id": _VALID_SALESFORCE_IDS[1], "Name": "Globex Corp", "BillingCity": "Los Angeles", "Industry": "Manufacturing", }, { "Id": _VALID_SALESFORCE_IDS[2], "Name": "Initech", "BillingCity": "Austin", "Industry": "Software", }, { "Id": _VALID_SALESFORCE_IDS[3], "Name": "TechCorp Solutions", "BillingCity": "San Francisco", "Industry": "Software", "AnnualRevenue": 5000000, }, { "Id": _VALID_SALESFORCE_IDS[4], "Name": "BioMed Research", "BillingCity": "Boston", "Industry": "Healthcare", "AnnualRevenue": 12000000, }, { "Id": _VALID_SALESFORCE_IDS[5], "Name": "Green Energy Co", "BillingCity": "Portland", "Industry": "Energy", "AnnualRevenue": 8000000, }, { "Id": _VALID_SALESFORCE_IDS[6], "Name": "DataFlow Analytics", "BillingCity": "Seattle", "Industry": "Technology", "AnnualRevenue": 3000000, }, { "Id": _VALID_SALESFORCE_IDS[7], "Name": "Cloud Nine Services", "BillingCity": "Denver", "Industry": "Cloud Computing", "AnnualRevenue": 7000000, }, ], "Contact": [ { "Id": _VALID_SALESFORCE_IDS[40], "FirstName": "John", "LastName": "Doe", "Email": "john.doe@acme.com", "Title": "CEO", }, { "Id": _VALID_SALESFORCE_IDS[41], "FirstName": "Jane", "LastName": "Smith", "Email": "jane.smith@acme.com", "Title": "CTO", }, { "Id": _VALID_SALESFORCE_IDS[42], "FirstName": "Bob", "LastName": "Johnson", "Email": "bob.j@globex.com", "Title": "Sales Director", }, { "Id": _VALID_SALESFORCE_IDS[43], "FirstName": "Sarah", "LastName": "Chen", "Email": "sarah.chen@techcorp.com", "Title": "Product Manager", "Phone": "415-555-0101", }, { "Id": _VALID_SALESFORCE_IDS[44], "FirstName": "Michael", "LastName": "Rodriguez", "Email": "m.rodriguez@biomed.com", "Title": "Research Director", "Phone": "617-555-0202", }, { "Id": _VALID_SALESFORCE_IDS[45], "FirstName": "Emily", "LastName": "Green", "Email": "emily.g@greenenergy.com", "Title": 
"Sustainability Lead", "Phone": "503-555-0303", }, { "Id": _VALID_SALESFORCE_IDS[46], "FirstName": "David", "LastName": "Kim", "Email": "david.kim@dataflow.com", "Title": "Data Scientist", "Phone": "206-555-0404", }, { "Id": _VALID_SALESFORCE_IDS[47], "FirstName": "Rachel", "LastName": "Taylor", "Email": "r.taylor@cloudnine.com", "Title": "Cloud Architect", "Phone": "303-555-0505", }, ], "Opportunity": [ { "Id": _VALID_SALESFORCE_IDS[62], "Name": "Acme Server Upgrade", "Amount": 50000, "Stage": "Prospecting", "CloseDate": "2024-06-30", }, { "Id": _VALID_SALESFORCE_IDS[63], "Name": "Globex Manufacturing Line", "Amount": 150000, "Stage": "Negotiation", "CloseDate": "2024-03-15", }, { "Id": _VALID_SALESFORCE_IDS[64], "Name": "Initech Software License", "Amount": 75000, "Stage": "Closed Won", "CloseDate": "2024-01-30", }, { "Id": _VALID_SALESFORCE_IDS[65], "Name": "TechCorp AI Implementation", "Amount": 250000, "Stage": "Needs Analysis", "CloseDate": "2024-08-15", "Probability": 60, }, { "Id": _VALID_SALESFORCE_IDS[66], "Name": "BioMed Lab Equipment", "Amount": 500000, "Stage": "Value Proposition", "CloseDate": "2024-09-30", "Probability": 75, }, { "Id": _VALID_SALESFORCE_IDS[67], "Name": "Green Energy Solar Project", "Amount": 750000, "Stage": "Proposal", "CloseDate": "2024-07-15", "Probability": 80, }, { "Id": _VALID_SALESFORCE_IDS[68], "Name": "DataFlow Analytics Platform", "Amount": 180000, "Stage": "Negotiation", "CloseDate": "2024-05-30", "Probability": 90, }, { "Id": _VALID_SALESFORCE_IDS[69], "Name": "Cloud Nine Infrastructure", "Amount": 300000, "Stage": "Qualification", "CloseDate": "2024-10-15", "Probability": 40, }, ], } # Create CSV files for each object type for object_type, records in example_data.items(): create_csv_file(object_type, records) def test_query() -> None: """ Tests querying functionality by verifying: 1. All expected Account IDs are found 2. 
Each Account's data matches what was inserted """ # Expected test data for verification expected_accounts: dict[str, dict[str, str | int]] = { _VALID_SALESFORCE_IDS[0]: { "Name": "Acme Inc.", "BillingCity": "New York", "Industry": "Technology", }, _VALID_SALESFORCE_IDS[1]: { "Name": "Globex Corp", "BillingCity": "Los Angeles", "Industry": "Manufacturing", }, _VALID_SALESFORCE_IDS[2]: { "Name": "Initech", "BillingCity": "Austin", "Industry": "Software", }, _VALID_SALESFORCE_IDS[3]: { "Name": "TechCorp Solutions", "BillingCity": "San Francisco", "Industry": "Software", "AnnualRevenue": 5000000, }, _VALID_SALESFORCE_IDS[4]: { "Name": "BioMed Research", "BillingCity": "Boston", "Industry": "Healthcare", "AnnualRevenue": 12000000, }, _VALID_SALESFORCE_IDS[5]: { "Name": "Green Energy Co", "BillingCity": "Portland", "Industry": "Energy", "AnnualRevenue": 8000000, }, _VALID_SALESFORCE_IDS[6]: { "Name": "DataFlow Analytics", "BillingCity": "Seattle", "Industry": "Technology", "AnnualRevenue": 3000000, }, _VALID_SALESFORCE_IDS[7]: { "Name": "Cloud Nine Services", "BillingCity": "Denver", "Industry": "Cloud Computing", "AnnualRevenue": 7000000, }, } # Get all Account IDs account_ids = find_ids_by_type("Account") # Verify we found all expected accounts assert len(account_ids) == len( expected_accounts ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}" assert set(account_ids) == set( expected_accounts.keys() ), "Found account IDs don't match expected IDs" # Verify each account's data for acc_id in account_ids: combined = get_record(acc_id) assert combined is not None, f"Could not find account {acc_id}" expected = expected_accounts[acc_id] # Verify account data matches for key, value in expected.items(): value = str(value) assert ( combined.data[key] == value ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}" print("All query tests passed successfully!") def test_upsert() -> None: """ Tests upsert functionality by: 1. 
Updating an existing account 2. Creating a new account 3. Verifying both operations were successful """ # Create CSV for updating an existing account and adding a new one update_data: list[dict[str, str | int]] = [ { "Id": _VALID_SALESFORCE_IDS[0], "Name": "Acme Inc. Updated", "BillingCity": "New York", "Industry": "Technology", "Description": "Updated company info", }, { "Id": _VALID_SALESFORCE_IDS[2], "Name": "New Company Inc.", "BillingCity": "Miami", "Industry": "Finance", "AnnualRevenue": 1000000, }, ] create_csv_file("Account", update_data, "update_data.csv") # Verify the update worked updated_record = get_record(_VALID_SALESFORCE_IDS[0]) assert updated_record is not None, "Updated record not found" assert updated_record.data["Name"] == "Acme Inc. Updated", "Name not updated" assert ( updated_record.data["Description"] == "Updated company info" ), "Description not added" # Verify the new record was created new_record = get_record(_VALID_SALESFORCE_IDS[2]) assert new_record is not None, "New record not found" assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect" assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect" print("All upsert tests passed successfully!") def test_relationships() -> None: """ Tests relationship shelf updates and queries by: 1. Creating test data with relationships 2. Verifying the relationships are correctly stored 3. 
Testing relationship queries """ # Create test data for each object type test_data: dict[str, list[dict[str, str | int]]] = { "Case": [ { "Id": _VALID_SALESFORCE_IDS[13], "AccountId": _VALID_SALESFORCE_IDS[0], "Subject": "Test Case 1", }, { "Id": _VALID_SALESFORCE_IDS[14], "AccountId": _VALID_SALESFORCE_IDS[0], "Subject": "Test Case 2", }, ], "Contact": [ { "Id": _VALID_SALESFORCE_IDS[48], "AccountId": _VALID_SALESFORCE_IDS[0], "FirstName": "Test", "LastName": "Contact", } ], "Opportunity": [ { "Id": _VALID_SALESFORCE_IDS[62], "AccountId": _VALID_SALESFORCE_IDS[0], "Name": "Test Opportunity", "Amount": 100000, } ], } # Create and update CSV files for each object type for object_type, records in test_data.items(): create_csv_file(object_type, records, "relationship_test.csv") # Test relationship queries # All these objects should be children of Acme Inc. child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0]) assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}" assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship" assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship" assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship" assert ( _VALID_SALESFORCE_IDS[62] in child_ids ), "Opportunity not found in relationship" # Test querying relationships for a different account (should be empty) other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) assert ( len(other_account_children) == 0 ), "Expected no children for different account" print("All relationship tests passed successfully!") def test_account_with_children() -> None: """ Tests querying all accounts and retrieving their child objects. This test verifies that: 1. All accounts can be retrieved 2. Child objects are correctly linked 3. 
Child object data is complete and accurate """ # First get all account IDs account_ids = find_ids_by_type("Account") assert len(account_ids) > 0, "No accounts found" # For each account, get its children and verify the data for account_id in account_ids: account = get_record(account_id) assert account is not None, f"Could not find account {account_id}" # Get all child objects child_ids = get_child_ids(account_id) # For Acme Inc., verify specific relationships if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc. assert ( len(child_ids) == 4 ), f"Expected 4 children for Acme Inc., found {len(child_ids)}" # Get all child records child_records = [] for child_id in child_ids: child_record = get_record(child_id) if child_record is not None: child_records.append(child_record) # Verify Cases cases = [r for r in child_records if r.type == "Case"] assert ( len(cases) == 2 ), f"Expected 2 cases for Acme Inc., found {len(cases)}" case_subjects = {case.data["Subject"] for case in cases} assert "Test Case 1" in case_subjects, "Test Case 1 not found" assert "Test Case 2" in case_subjects, "Test Case 2 not found" # Verify Contacts contacts = [r for r in child_records if r.type == "Contact"] assert ( len(contacts) == 1 ), f"Expected 1 contact for Acme Inc., found {len(contacts)}" contact = contacts[0] assert contact.data["FirstName"] == "Test", "Contact first name mismatch" assert contact.data["LastName"] == "Contact", "Contact last name mismatch" # Verify Opportunities opportunities = [r for r in child_records if r.type == "Opportunity"] assert ( len(opportunities) == 1 ), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}" opportunity = opportunities[0] assert ( opportunity.data["Name"] == "Test Opportunity" ), "Opportunity name mismatch" assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch" print("All account with children tests passed successfully!") def test_relationship_updates() -> None: """ Tests that relationships are properly 
updated when a child object's parent reference changes. This test verifies: 1. Initial relationship is created correctly 2. When parent reference is updated, old relationship is removed 3. New relationship is created correctly """ # Create initial test data - Contact linked to Acme Inc. initial_contact = [ { "Id": _VALID_SALESFORCE_IDS[40], "AccountId": _VALID_SALESFORCE_IDS[0], "FirstName": "Test", "LastName": "Contact", } ] create_csv_file("Contact", initial_contact, "initial_contact.csv") # Verify initial relationship acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0]) assert ( _VALID_SALESFORCE_IDS[40] in acme_children ), "Initial relationship not created" # Update contact to be linked to Globex Corp instead updated_contact = [ { "Id": _VALID_SALESFORCE_IDS[40], "AccountId": _VALID_SALESFORCE_IDS[1], "FirstName": "Test", "LastName": "Contact", } ] create_csv_file("Contact", updated_contact, "updated_contact.csv") # Verify old relationship is removed acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0]) assert ( _VALID_SALESFORCE_IDS[40] not in acme_children ), "Old relationship not removed" # Verify new relationship is created globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created" print("All relationship update tests passed successfully!") def test_get_affected_parent_ids() -> None: """ Tests get_affected_parent_ids_by_type by verifying: 1. Updated IDs whose type is in parent_types are included 2. Parent IDs with a child in updated_ids are included 3.
IDs that are neither of the above are not included
    """
    # Create test data with relationships
    test_data = {
        "Account": [
            {
                "Id": _VALID_SALESFORCE_IDS[0],
                "Name": "Parent Account 1",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[1],
                "Name": "Parent Account 2",
            },
            {
                "Id": _VALID_SALESFORCE_IDS[2],
                "Name": "Not Affected Account",
            },
        ],
        "Contact": [
            {
                "Id": _VALID_SALESFORCE_IDS[40],
                "AccountId": _VALID_SALESFORCE_IDS[0],
                "FirstName": "Child",
                "LastName": "Contact",
            }
        ],
    }

    # Create and update CSV files for test data
    for object_type, records in test_data.items():
        create_csv_file(object_type, records)

    # get_affected_parent_ids_by_type returns {object_type: set of IDs}, so the
    # membership checks below look inside the "Account" bucket.

    # Test Case 1: Account directly in updated_ids and parent_types
    updated_ids = {_VALID_SALESFORCE_IDS[1]}  # Parent Account 2
    parent_types = ["Account"]
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types).get(
        "Account", set()
    )
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"

    # Test Case 2: Account with child in updated_ids
    updated_ids = {_VALID_SALESFORCE_IDS[40]}  # Child Contact
    parent_types = ["Account"]
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types).get(
        "Account", set()
    )
    assert (
        _VALID_SALESFORCE_IDS[0] in affected_ids
    ), "Parent of updated child not included"

    # Test Case 3: Both direct and indirect matches
    updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]}  # Both cases
    parent_types = ["Account"]
    affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types).get(
        "Account", set()
    )
    assert len(affected_ids) == 2, "Expected exactly two affected parent IDs"
    assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included"
    assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included"
    assert (
        _VALID_SALESFORCE_IDS[2] not in affected_ids
    ), "Unaffected ID incorrectly included"

    # Test Case 4: No matches
    updated_ids = {_VALID_SALESFORCE_IDS[40]}  # Child Contact
    parent_types = ["Opportunity"]  # Wrong type
    affected_ids_by_type = get_affected_parent_ids_by_type(updated_ids, parent_types)
    assert len(affected_ids_by_type) == 0, "Should return an empty dict when no matches"

    print("All get_affected_parent_ids tests passed successfully!")


def main_build() -> None:
    clear_sf_db()
    create_csv_with_example_data()
    test_query()
    test_upsert()
    test_relationships()
    test_account_with_children()
    test_relationship_updates()
    test_get_affected_parent_ids()


if __name__ == "__main__":
    main_build()


================================================
FILE: backend/onyx/connectors/salesforce/shelve_stuff/shelve_functions.py
================================================
import csv
import shelve

from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
    get_child_to_parent_shelf_path,
)
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_id_type_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import get_object_shelf_path
from onyx.connectors.salesforce.shelve_stuff.shelve_utils import (
    get_parent_to_child_shelf_path,
)
from onyx.connectors.salesforce.utils import SalesforceObject
from onyx.connectors.salesforce.utils import validate_salesforce_id
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _update_relationship_shelves(
    child_id: str,
    parent_ids: set[str],
) -> None:
    """Update the relationship shelves when a record is updated."""
    try:
        # Convert child_id to string once
        str_child_id = str(child_id)

        # First update the child-to-parent mapping
        with shelve.open(
            get_child_to_parent_shelf_path(),
            flag="c",
            protocol=None,
            writeback=True,
        ) as child_to_parent_db:
            old_parent_ids = set(child_to_parent_db.get(str_child_id, []))
            child_to_parent_db[str_child_id] = list(parent_ids)

            # Compute the diff so the parent-to-child shelf is only touched
            # for parents that actually changed
            parent_ids_to_remove = old_parent_ids - parent_ids
            parent_ids_to_add = parent_ids - old_parent_ids

            # Only sync once at the end
            child_to_parent_db.sync()

        # Then update the parent-to-child mapping in a single shelf session
        if not parent_ids_to_remove and not parent_ids_to_add:
            return

        with shelve.open(
            get_parent_to_child_shelf_path(),
            flag="c",
            protocol=None,
            writeback=True,
        ) as parent_to_child_db:
            # Process all removals first
            for parent_id in parent_ids_to_remove:
                str_parent_id = str(parent_id)
                existing_children = set(parent_to_child_db.get(str_parent_id, []))
                if str_child_id in existing_children:
                    existing_children.remove(str_child_id)
                    parent_to_child_db[str_parent_id] = list(existing_children)

            # Then process all additions
            for parent_id in parent_ids_to_add:
                str_parent_id = str(parent_id)
                existing_children = set(parent_to_child_db.get(str_parent_id, []))
                existing_children.add(str_child_id)
                parent_to_child_db[str_parent_id] = list(existing_children)

            # Single sync at the end
            parent_to_child_db.sync()

    except Exception as e:
        logger.error(f"Error updating relationship shelves: {e}")
        logger.error(f"Child ID: {child_id}, Parent IDs: {parent_ids}")
        raise


def get_child_ids(parent_id: str) -> set[str]:
    """Get all child IDs for a given parent ID.

    Args:
        parent_id: The ID of the parent object

    Returns:
        A set of child object IDs
    """
    with shelve.open(get_parent_to_child_shelf_path()) as parent_to_child_db:
        return set(parent_to_child_db.get(parent_id, []))


def update_sf_db_with_csv(
    object_type: str,
    csv_download_path: str,
) -> list[str]:
    """Update the SF DB with a CSV file using shelve storage."""
    updated_ids: list[str] = []
    shelf_path = get_object_shelf_path(object_type)

    # First read the CSV to get all the data
    with open(csv_download_path, "r", newline="", encoding="utf-8") as f:
        reader = csv.DictReader(f)
        for row in reader:
            # Named object_id to avoid shadowing the builtin id()
            object_id = row["Id"]
            parent_ids: set[str] = set()
            field_to_remove: set[str] = set()

            # Update relationship shelves for any parent references
            for field, value in row.items():
                if validate_salesforce_id(value) and field != "Id":
                    parent_ids.add(value)
                    field_to_remove.add(field)
                if not value:
                    field_to_remove.add(field)
            _update_relationship_shelves(object_id, parent_ids)
            for field in field_to_remove:
                # We use this to extract the Primary Owner later
                if field != "LastModifiedById":
                    del row[field]

            # Update the main object shelf
            with shelve.open(shelf_path) as object_type_db:
                object_type_db[object_id] = row

            # Update the ID-to-type mapping shelf
            with shelve.open(get_id_type_shelf_path()) as id_type_db:
                id_type_db[object_id] = object_type

            updated_ids.append(object_id)

    # os.remove(csv_download_path)
    return updated_ids


def get_type_from_id(object_id: str) -> str | None:
    """Get the type of an object from its ID."""
    # Look up the object type from the ID-to-type mapping
    with shelve.open(get_id_type_shelf_path()) as id_type_db:
        if object_id not in id_type_db:
            logger.warning(f"Object ID {object_id} not found in ID-to-type mapping")
            return None
        return id_type_db[object_id]


def get_record(
    object_id: str, object_type: str | None = None
) -> SalesforceObject | None:
    """
    Retrieve the record and return it as a SalesforceObject.
    If object_type is not provided, it is looked up from the ID-to-type mapping shelf.
    """
    if object_type is None:
        if not (object_type := get_type_from_id(object_id)):
            return None

    shelf_path = get_object_shelf_path(object_type)
    with shelve.open(shelf_path) as db:
        if object_id not in db:
            logger.warning(f"Object ID {object_id} not found in {shelf_path}")
            return None
        data = db[object_id]
        return SalesforceObject(
            id=object_id,
            type=object_type,
            data=data,
        )


def find_ids_by_type(object_type: str) -> list[str]:
    """
    Find all object IDs for rows of the specified type.
    """
    shelf_path = get_object_shelf_path(object_type)
    try:
        with shelve.open(shelf_path) as db:
            return list(db.keys())
    except FileNotFoundError:
        return []


def get_affected_parent_ids_by_type(
    updated_ids: set[str], parent_types: list[str]
) -> dict[str, set[str]]:
    """Get IDs of objects that are of the specified parent types and are either in the
    updated_ids or have children in the updated_ids.

    Args:
        updated_ids: Set of IDs that were updated
        parent_types: List of object types to filter by

    Returns:
        A dictionary mapping each matching parent type to the set of affected IDs
    """
    affected_ids_by_type: dict[str, set[str]] = {}

    # Check each updated ID
    for updated_id in updated_ids:
        # Add the ID itself if it's of a parent type
        updated_type = get_type_from_id(updated_id)
        if updated_type in parent_types:
            affected_ids_by_type.setdefault(updated_type, set()).add(updated_id)
            continue

        # Get parents of this ID and add them if they're of a parent type
        with shelve.open(get_child_to_parent_shelf_path()) as child_to_parent_db:
            parent_ids = child_to_parent_db.get(updated_id, [])
            for parent_id in parent_ids:
                parent_type = get_type_from_id(parent_id)
                if parent_type in parent_types:
                    affected_ids_by_type.setdefault(parent_type, set()).add(parent_id)

    return affected_ids_by_type


================================================
FILE: backend/onyx/connectors/salesforce/shelve_stuff/shelve_utils.py
================================================
import os

from onyx.connectors.salesforce.utils import BASE_DATA_PATH
from onyx.connectors.salesforce.utils import get_object_type_path


def get_object_shelf_path(object_type: str) -> str:
    """Get the path to the shelf file for a specific object type."""
    base_path = get_object_type_path(object_type)
    os.makedirs(base_path, exist_ok=True)
    return os.path.join(base_path, "data.shelf")


def get_id_type_shelf_path() -> str:
    """Get the path to the ID-to-type mapping shelf."""
    os.makedirs(BASE_DATA_PATH, exist_ok=True)
    return os.path.join(BASE_DATA_PATH, "id_type_mapping.shelf.4g")


def get_parent_to_child_shelf_path() -> str:
    """Get the path to the parent-to-child mapping shelf."""
    os.makedirs(BASE_DATA_PATH, exist_ok=True)
    return os.path.join(BASE_DATA_PATH, "parent_to_child_mapping.shelf.4g")


def get_child_to_parent_shelf_path() -> str:
    """Get the path to the child-to-parent mapping shelf."""
    os.makedirs(BASE_DATA_PATH, exist_ok=True)
    return os.path.join(BASE_DATA_PATH,
"child_to_parent_mapping.shelf.4g") ================================================ FILE: backend/onyx/connectors/salesforce/shelve_stuff/test_salesforce_shelves.py ================================================ import csv import os import shutil from onyx.connectors.salesforce.shelve_stuff.shelve_functions import find_ids_by_type from onyx.connectors.salesforce.shelve_stuff.shelve_functions import ( get_affected_parent_ids_by_type, ) from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_child_ids from onyx.connectors.salesforce.shelve_stuff.shelve_functions import get_record from onyx.connectors.salesforce.shelve_stuff.shelve_functions import ( update_sf_db_with_csv, ) from onyx.connectors.salesforce.utils import BASE_DATA_PATH from onyx.connectors.salesforce.utils import get_object_type_path _VALID_SALESFORCE_IDS = [ "001bm00000fd9Z3AAI", "001bm00000fdYTdAAM", "001bm00000fdYTeAAM", "001bm00000fdYTfAAM", "001bm00000fdYTgAAM", "001bm00000fdYThAAM", "001bm00000fdYTiAAM", "001bm00000fdYTjAAM", "001bm00000fdYTkAAM", "001bm00000fdYTlAAM", "001bm00000fdYTmAAM", "001bm00000fdYTnAAM", "001bm00000fdYToAAM", "500bm00000XoOxtAAF", "500bm00000XoOxuAAF", "500bm00000XoOxvAAF", "500bm00000XoOxwAAF", "500bm00000XoOxxAAF", "500bm00000XoOxyAAF", "500bm00000XoOxzAAF", "500bm00000XoOy0AAF", "500bm00000XoOy1AAF", "500bm00000XoOy2AAF", "500bm00000XoOy3AAF", "500bm00000XoOy4AAF", "500bm00000XoOy5AAF", "500bm00000XoOy6AAF", "500bm00000XoOy7AAF", "500bm00000XoOy8AAF", "500bm00000XoOy9AAF", "500bm00000XoOyAAAV", "500bm00000XoOyBAAV", "500bm00000XoOyCAAV", "500bm00000XoOyDAAV", "500bm00000XoOyEAAV", "500bm00000XoOyFAAV", "500bm00000XoOyGAAV", "500bm00000XoOyHAAV", "500bm00000XoOyIAAV", "003bm00000EjHCjAAN", "003bm00000EjHCkAAN", "003bm00000EjHClAAN", "003bm00000EjHCmAAN", "003bm00000EjHCnAAN", "003bm00000EjHCoAAN", "003bm00000EjHCpAAN", "003bm00000EjHCqAAN", "003bm00000EjHCrAAN", "003bm00000EjHCsAAN", "003bm00000EjHCtAAN", "003bm00000EjHCuAAN", "003bm00000EjHCvAAN", 
"003bm00000EjHCwAAN", "003bm00000EjHCxAAN", "003bm00000EjHCyAAN", "003bm00000EjHCzAAN", "003bm00000EjHD0AAN", "003bm00000EjHD1AAN", "003bm00000EjHD2AAN", "550bm00000EXc2tAAD", "006bm000006kyDpAAI", "006bm000006kyDqAAI", "006bm000006kyDrAAI", "006bm000006kyDsAAI", "006bm000006kyDtAAI", "006bm000006kyDuAAI", "006bm000006kyDvAAI", "006bm000006kyDwAAI", "006bm000006kyDxAAI", "006bm000006kyDyAAI", "006bm000006kyDzAAI", "006bm000006kyE0AAI", "006bm000006kyE1AAI", "006bm000006kyE2AAI", "006bm000006kyE3AAI", "006bm000006kyE4AAI", "006bm000006kyE5AAI", "006bm000006kyE6AAI", "006bm000006kyE7AAI", "006bm000006kyE8AAI", "006bm000006kyE9AAI", "006bm000006kyEAAAY", "006bm000006kyEBAAY", "006bm000006kyECAAY", "006bm000006kyEDAAY", "006bm000006kyEEAAY", "006bm000006kyEFAAY", "006bm000006kyEGAAY", "006bm000006kyEHAAY", "006bm000006kyEIAAY", "006bm000006kyEJAAY", "005bm000009zy0TAAQ", "005bm000009zy25AAA", "005bm000009zy26AAA", "005bm000009zy28AAA", "005bm000009zy29AAA", "005bm000009zy2AAAQ", "005bm000009zy2BAAQ", ] def clear_sf_db() -> None: """ Clears the SF DB by deleting all files in the data directory. """ shutil.rmtree(BASE_DATA_PATH) def create_csv_file( object_type: str, records: list[dict], filename: str = "test_data.csv" ) -> None: """ Creates a CSV file for the given object type and records. Args: object_type: The Salesforce object type (e.g. 
"Account", "Contact") records: List of dictionaries containing the record data filename: Name of the CSV file to create (default: test_data.csv) """ if not records: return # Get all unique fields from records fields: set[str] = set() for record in records: fields.update(record.keys()) fields = set(sorted(list(fields))) # Sort for consistent order # Create CSV file csv_path = os.path.join(get_object_type_path(object_type), filename) with open(csv_path, "w", newline="", encoding="utf-8") as f: writer = csv.DictWriter(f, fieldnames=fields) writer.writeheader() for record in records: writer.writerow(record) # Update the database with the CSV update_sf_db_with_csv(object_type, csv_path) def create_csv_with_example_data() -> None: """ Creates CSV files with example data, organized by object type. """ example_data: dict[str, list[dict]] = { "Account": [ { "Id": _VALID_SALESFORCE_IDS[0], "Name": "Acme Inc.", "BillingCity": "New York", "Industry": "Technology", }, { "Id": _VALID_SALESFORCE_IDS[1], "Name": "Globex Corp", "BillingCity": "Los Angeles", "Industry": "Manufacturing", }, { "Id": _VALID_SALESFORCE_IDS[2], "Name": "Initech", "BillingCity": "Austin", "Industry": "Software", }, { "Id": _VALID_SALESFORCE_IDS[3], "Name": "TechCorp Solutions", "BillingCity": "San Francisco", "Industry": "Software", "AnnualRevenue": 5000000, }, { "Id": _VALID_SALESFORCE_IDS[4], "Name": "BioMed Research", "BillingCity": "Boston", "Industry": "Healthcare", "AnnualRevenue": 12000000, }, { "Id": _VALID_SALESFORCE_IDS[5], "Name": "Green Energy Co", "BillingCity": "Portland", "Industry": "Energy", "AnnualRevenue": 8000000, }, { "Id": _VALID_SALESFORCE_IDS[6], "Name": "DataFlow Analytics", "BillingCity": "Seattle", "Industry": "Technology", "AnnualRevenue": 3000000, }, { "Id": _VALID_SALESFORCE_IDS[7], "Name": "Cloud Nine Services", "BillingCity": "Denver", "Industry": "Cloud Computing", "AnnualRevenue": 7000000, }, ], "Contact": [ { "Id": _VALID_SALESFORCE_IDS[40], "FirstName": "John", 
"LastName": "Doe", "Email": "john.doe@acme.com", "Title": "CEO", }, { "Id": _VALID_SALESFORCE_IDS[41], "FirstName": "Jane", "LastName": "Smith", "Email": "jane.smith@acme.com", "Title": "CTO", }, { "Id": _VALID_SALESFORCE_IDS[42], "FirstName": "Bob", "LastName": "Johnson", "Email": "bob.j@globex.com", "Title": "Sales Director", }, { "Id": _VALID_SALESFORCE_IDS[43], "FirstName": "Sarah", "LastName": "Chen", "Email": "sarah.chen@techcorp.com", "Title": "Product Manager", "Phone": "415-555-0101", }, { "Id": _VALID_SALESFORCE_IDS[44], "FirstName": "Michael", "LastName": "Rodriguez", "Email": "m.rodriguez@biomed.com", "Title": "Research Director", "Phone": "617-555-0202", }, { "Id": _VALID_SALESFORCE_IDS[45], "FirstName": "Emily", "LastName": "Green", "Email": "emily.g@greenenergy.com", "Title": "Sustainability Lead", "Phone": "503-555-0303", }, { "Id": _VALID_SALESFORCE_IDS[46], "FirstName": "David", "LastName": "Kim", "Email": "david.kim@dataflow.com", "Title": "Data Scientist", "Phone": "206-555-0404", }, { "Id": _VALID_SALESFORCE_IDS[47], "FirstName": "Rachel", "LastName": "Taylor", "Email": "r.taylor@cloudnine.com", "Title": "Cloud Architect", "Phone": "303-555-0505", }, ], "Opportunity": [ { "Id": _VALID_SALESFORCE_IDS[62], "Name": "Acme Server Upgrade", "Amount": 50000, "Stage": "Prospecting", "CloseDate": "2024-06-30", }, { "Id": _VALID_SALESFORCE_IDS[63], "Name": "Globex Manufacturing Line", "Amount": 150000, "Stage": "Negotiation", "CloseDate": "2024-03-15", }, { "Id": _VALID_SALESFORCE_IDS[64], "Name": "Initech Software License", "Amount": 75000, "Stage": "Closed Won", "CloseDate": "2024-01-30", }, { "Id": _VALID_SALESFORCE_IDS[65], "Name": "TechCorp AI Implementation", "Amount": 250000, "Stage": "Needs Analysis", "CloseDate": "2024-08-15", "Probability": 60, }, { "Id": _VALID_SALESFORCE_IDS[66], "Name": "BioMed Lab Equipment", "Amount": 500000, "Stage": "Value Proposition", "CloseDate": "2024-09-30", "Probability": 75, }, { "Id": _VALID_SALESFORCE_IDS[67], 
"Name": "Green Energy Solar Project", "Amount": 750000, "Stage": "Proposal", "CloseDate": "2024-07-15", "Probability": 80, }, { "Id": _VALID_SALESFORCE_IDS[68], "Name": "DataFlow Analytics Platform", "Amount": 180000, "Stage": "Negotiation", "CloseDate": "2024-05-30", "Probability": 90, }, { "Id": _VALID_SALESFORCE_IDS[69], "Name": "Cloud Nine Infrastructure", "Amount": 300000, "Stage": "Qualification", "CloseDate": "2024-10-15", "Probability": 40, }, ], } # Create CSV files for each object type for object_type, records in example_data.items(): create_csv_file(object_type, records) def test_query() -> None: """ Tests querying functionality by verifying: 1. All expected Account IDs are found 2. Each Account's data matches what was inserted """ # Expected test data for verification expected_accounts: dict[str, dict[str, str | int]] = { _VALID_SALESFORCE_IDS[0]: { "Name": "Acme Inc.", "BillingCity": "New York", "Industry": "Technology", }, _VALID_SALESFORCE_IDS[1]: { "Name": "Globex Corp", "BillingCity": "Los Angeles", "Industry": "Manufacturing", }, _VALID_SALESFORCE_IDS[2]: { "Name": "Initech", "BillingCity": "Austin", "Industry": "Software", }, _VALID_SALESFORCE_IDS[3]: { "Name": "TechCorp Solutions", "BillingCity": "San Francisco", "Industry": "Software", "AnnualRevenue": 5000000, }, _VALID_SALESFORCE_IDS[4]: { "Name": "BioMed Research", "BillingCity": "Boston", "Industry": "Healthcare", "AnnualRevenue": 12000000, }, _VALID_SALESFORCE_IDS[5]: { "Name": "Green Energy Co", "BillingCity": "Portland", "Industry": "Energy", "AnnualRevenue": 8000000, }, _VALID_SALESFORCE_IDS[6]: { "Name": "DataFlow Analytics", "BillingCity": "Seattle", "Industry": "Technology", "AnnualRevenue": 3000000, }, _VALID_SALESFORCE_IDS[7]: { "Name": "Cloud Nine Services", "BillingCity": "Denver", "Industry": "Cloud Computing", "AnnualRevenue": 7000000, }, } # Get all Account IDs account_ids = find_ids_by_type("Account") # Verify we found all expected accounts assert len(account_ids) == len( 
expected_accounts ), f"Expected {len(expected_accounts)} accounts, found {len(account_ids)}" assert set(account_ids) == set( expected_accounts.keys() ), "Found account IDs don't match expected IDs" # Verify each account's data for acc_id in account_ids: combined = get_record(acc_id) assert combined is not None, f"Could not find account {acc_id}" expected = expected_accounts[acc_id] # Verify account data matches for key, value in expected.items(): value = str(value) assert ( combined.data[key] == value ), f"Account {acc_id} field {key} expected {value}, got {combined.data[key]}" print("All query tests passed successfully!") def test_upsert() -> None: """ Tests upsert functionality by: 1. Updating an existing account 2. Creating a new account 3. Verifying both operations were successful """ # Create CSV for updating an existing account and adding a new one update_data: list[dict[str, str | int]] = [ { "Id": _VALID_SALESFORCE_IDS[0], "Name": "Acme Inc. Updated", "BillingCity": "New York", "Industry": "Technology", "Description": "Updated company info", }, { "Id": _VALID_SALESFORCE_IDS[2], "Name": "New Company Inc.", "BillingCity": "Miami", "Industry": "Finance", "AnnualRevenue": 1000000, }, ] create_csv_file("Account", update_data, "update_data.csv") # Verify the update worked updated_record = get_record(_VALID_SALESFORCE_IDS[0]) assert updated_record is not None, "Updated record not found" assert updated_record.data["Name"] == "Acme Inc. 
Updated", "Name not updated" assert ( updated_record.data["Description"] == "Updated company info" ), "Description not added" # Verify the new record was created new_record = get_record(_VALID_SALESFORCE_IDS[2]) assert new_record is not None, "New record not found" assert new_record.data["Name"] == "New Company Inc.", "New record name incorrect" assert new_record.data["AnnualRevenue"] == "1000000", "New record revenue incorrect" print("All upsert tests passed successfully!") def test_relationships() -> None: """ Tests relationship shelf updates and queries by: 1. Creating test data with relationships 2. Verifying the relationships are correctly stored 3. Testing relationship queries """ # Create test data for each object type test_data: dict[str, list[dict[str, str | int]]] = { "Case": [ { "Id": _VALID_SALESFORCE_IDS[13], "AccountId": _VALID_SALESFORCE_IDS[0], "Subject": "Test Case 1", }, { "Id": _VALID_SALESFORCE_IDS[14], "AccountId": _VALID_SALESFORCE_IDS[0], "Subject": "Test Case 2", }, ], "Contact": [ { "Id": _VALID_SALESFORCE_IDS[48], "AccountId": _VALID_SALESFORCE_IDS[0], "FirstName": "Test", "LastName": "Contact", } ], "Opportunity": [ { "Id": _VALID_SALESFORCE_IDS[62], "AccountId": _VALID_SALESFORCE_IDS[0], "Name": "Test Opportunity", "Amount": 100000, } ], } # Create and update CSV files for each object type for object_type, records in test_data.items(): create_csv_file(object_type, records, "relationship_test.csv") # Test relationship queries # All these objects should be children of Acme Inc. 
child_ids = get_child_ids(_VALID_SALESFORCE_IDS[0]) assert len(child_ids) == 4, f"Expected 4 child objects, found {len(child_ids)}" assert _VALID_SALESFORCE_IDS[13] in child_ids, "Case 1 not found in relationship" assert _VALID_SALESFORCE_IDS[14] in child_ids, "Case 2 not found in relationship" assert _VALID_SALESFORCE_IDS[48] in child_ids, "Contact not found in relationship" assert ( _VALID_SALESFORCE_IDS[62] in child_ids ), "Opportunity not found in relationship" # Test querying relationships for a different account (should be empty) other_account_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) assert ( len(other_account_children) == 0 ), "Expected no children for different account" print("All relationship tests passed successfully!") def test_account_with_children() -> None: """ Tests querying all accounts and retrieving their child objects. This test verifies that: 1. All accounts can be retrieved 2. Child objects are correctly linked 3. Child object data is complete and accurate """ # First get all account IDs account_ids = find_ids_by_type("Account") assert len(account_ids) > 0, "No accounts found" # For each account, get its children and verify the data for account_id in account_ids: account = get_record(account_id) assert account is not None, f"Could not find account {account_id}" # Get all child objects child_ids = get_child_ids(account_id) # For Acme Inc., verify specific relationships if account_id == _VALID_SALESFORCE_IDS[0]: # Acme Inc. 
assert ( len(child_ids) == 4 ), f"Expected 4 children for Acme Inc., found {len(child_ids)}" # Get all child records child_records = [] for child_id in child_ids: child_record = get_record(child_id) if child_record is not None: child_records.append(child_record) # Verify Cases cases = [r for r in child_records if r.type == "Case"] assert ( len(cases) == 2 ), f"Expected 2 cases for Acme Inc., found {len(cases)}" case_subjects = {case.data["Subject"] for case in cases} assert "Test Case 1" in case_subjects, "Test Case 1 not found" assert "Test Case 2" in case_subjects, "Test Case 2 not found" # Verify Contacts contacts = [r for r in child_records if r.type == "Contact"] assert ( len(contacts) == 1 ), f"Expected 1 contact for Acme Inc., found {len(contacts)}" contact = contacts[0] assert contact.data["FirstName"] == "Test", "Contact first name mismatch" assert contact.data["LastName"] == "Contact", "Contact last name mismatch" # Verify Opportunities opportunities = [r for r in child_records if r.type == "Opportunity"] assert ( len(opportunities) == 1 ), f"Expected 1 opportunity for Acme Inc., found {len(opportunities)}" opportunity = opportunities[0] assert ( opportunity.data["Name"] == "Test Opportunity" ), "Opportunity name mismatch" assert opportunity.data["Amount"] == "100000", "Opportunity amount mismatch" print("All account with children tests passed successfully!") def test_relationship_updates() -> None: """ Tests that relationships are properly updated when a child object's parent reference changes. This test verifies: 1. Initial relationship is created correctly 2. When parent reference is updated, old relationship is removed 3. New relationship is created correctly """ # Create initial test data - Contact linked to Acme Inc. 
initial_contact = [ { "Id": _VALID_SALESFORCE_IDS[40], "AccountId": _VALID_SALESFORCE_IDS[0], "FirstName": "Test", "LastName": "Contact", } ] create_csv_file("Contact", initial_contact, "initial_contact.csv") # Verify initial relationship acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0]) assert ( _VALID_SALESFORCE_IDS[40] in acme_children ), "Initial relationship not created" # Update contact to be linked to Globex Corp instead updated_contact = [ { "Id": _VALID_SALESFORCE_IDS[40], "AccountId": _VALID_SALESFORCE_IDS[1], "FirstName": "Test", "LastName": "Contact", } ] create_csv_file("Contact", updated_contact, "updated_contact.csv") # Verify old relationship is removed acme_children = get_child_ids(_VALID_SALESFORCE_IDS[0]) assert ( _VALID_SALESFORCE_IDS[40] not in acme_children ), "Old relationship not removed" # Verify new relationship is created globex_children = get_child_ids(_VALID_SALESFORCE_IDS[1]) assert _VALID_SALESFORCE_IDS[40] in globex_children, "New relationship not created" print("All relationship update tests passed successfully!") def test_get_affected_parent_ids() -> None: """ Tests get_affected_parent_ids functionality by verifying: 1. IDs that are directly in the parent_types list are included 2. IDs that have children in the updated_ids list are included 3. 
IDs that are neither of the above are not included """ # Create test data with relationships test_data = { "Account": [ { "Id": _VALID_SALESFORCE_IDS[0], "Name": "Parent Account 1", }, { "Id": _VALID_SALESFORCE_IDS[1], "Name": "Parent Account 2", }, { "Id": _VALID_SALESFORCE_IDS[2], "Name": "Not Affected Account", }, ], "Contact": [ { "Id": _VALID_SALESFORCE_IDS[40], "AccountId": _VALID_SALESFORCE_IDS[0], "FirstName": "Child", "LastName": "Contact", } ], } # Create and update CSV files for test data for object_type, records in test_data.items(): create_csv_file(object_type, records) # Test Case 1: Account directly in updated_ids and parent_types updated_ids = {_VALID_SALESFORCE_IDS[1]} # Parent Account 2 parent_types = ["Account"] affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types) assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included" # Test Case 2: Account with child in updated_ids updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact parent_types = ["Account"] affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types) assert ( _VALID_SALESFORCE_IDS[0] in affected_ids ), "Parent of updated child not included" # Test Case 3: Both direct and indirect affects updated_ids = {_VALID_SALESFORCE_IDS[1], _VALID_SALESFORCE_IDS[40]} # Both cases parent_types = ["Account"] affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types) assert len(affected_ids) == 2, "Expected exactly two affected parent IDs" assert _VALID_SALESFORCE_IDS[0] in affected_ids, "Parent of child not included" assert _VALID_SALESFORCE_IDS[1] in affected_ids, "Direct parent ID not included" assert ( _VALID_SALESFORCE_IDS[2] not in affected_ids ), "Unaffected ID incorrectly included" # Test Case 4: No matches updated_ids = {_VALID_SALESFORCE_IDS[40]} # Child Contact parent_types = ["Opportunity"] # Wrong type affected_ids = get_affected_parent_ids_by_type(updated_ids, parent_types) assert len(affected_ids) == 0, "Should 
return empty list when no matches" print("All get_affected_parent_ids tests passed successfully!") def main_build() -> None: clear_sf_db() create_csv_with_example_data() test_query() test_upsert() test_relationships() test_account_with_children() test_relationship_updates() test_get_affected_parent_ids() if __name__ == "__main__": main_build() ================================================ FILE: backend/onyx/connectors/salesforce/sqlite_functions.py ================================================ import csv import json import os import sqlite3 import time from collections.abc import Iterator from pathlib import Path from typing import Any from typing import cast from onyx.connectors.models import BasicExpertInfo from onyx.connectors.salesforce.utils import ACCOUNT_OBJECT_TYPE from onyx.connectors.salesforce.utils import ID_FIELD from onyx.connectors.salesforce.utils import NAME_FIELD from onyx.connectors.salesforce.utils import remove_sqlite_db_files from onyx.connectors.salesforce.utils import SalesforceObject from onyx.connectors.salesforce.utils import USER_OBJECT_TYPE from onyx.connectors.salesforce.utils import validate_salesforce_id from onyx.utils.logger import setup_logger from shared_configs.utils import batch_list logger = setup_logger() SQLITE_DISK_IO_ERROR = "disk I/O error" class OnyxSalesforceSQLite: """Notes on context management using 'with self.conn': Does autocommit / rollback on exit. Does NOT close on exit! .close must be called explicitly. """ # NOTE(rkuo): this string could probably occur naturally. A more unique value # might be appropriate here. NULL_ID_STRING = "N/A" def __init__(self, filename: str, isolation_level: str | None = None): self.filename = filename self.isolation_level = isolation_level self._conn: sqlite3.Connection | None = None # this is only set on connection. This variable does not change # when a new db is initialized with this class. 
self._existing_db = True def __del__(self) -> None: self.close() @property def file_size(self) -> int: """Returns -1 if the file does not exist.""" if not self.filename: return -1 if not os.path.exists(self.filename): return -1 file_path = Path(self.filename) return file_path.stat().st_size def connect(self) -> None: if self._conn is not None: self._conn.close() self._conn = None self._existing_db = os.path.exists(self.filename) # make the path if it doesn't already exist os.makedirs(os.path.dirname(self.filename), exist_ok=True) conn = sqlite3.connect(self.filename, timeout=60.0) if self.isolation_level is not None: conn.isolation_level = self.isolation_level self._conn = conn def close(self) -> None: if self._conn is None: return self._conn.close() self._conn = None def cursor(self) -> sqlite3.Cursor: if self._conn is None: raise RuntimeError("Database connection is closed") return self._conn.cursor() def flush(self) -> None: """We're using SQLite in WAL mode sometimes. To flush to the DB we have to call this.""" if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cursor = self._conn.cursor() cursor.execute("PRAGMA wal_checkpoint(FULL)") def apply_schema(self) -> None: """Initialize the SQLite database with required tables if they don't exist. Non-destructive operation. If a disk I/O error is encountered (often due to stale WAL/SHM files from a previous crash), this method will attempt to recover by removing the corrupted files and recreating the database. 
""" try: self._apply_schema_impl() except sqlite3.OperationalError as e: if SQLITE_DISK_IO_ERROR not in str(e): raise logger.warning(f"SQLite disk I/O error detected, attempting recovery: {e}") self._recover_from_corruption() self._apply_schema_impl() def _recover_from_corruption(self) -> None: """Recover from SQLite corruption by removing all database files and reconnecting.""" logger.info(f"Removing corrupted SQLite files: {self.filename}") # Close existing connection self.close() # Remove all SQLite files (main db, WAL, SHM) remove_sqlite_db_files(self.filename) # Reconnect - this will create a fresh database self.connect() logger.info("SQLite recovery complete, fresh database created") def _apply_schema_impl(self) -> None: """Internal implementation of apply_schema.""" if self._conn is None: raise RuntimeError("Database connection is closed") start = time.monotonic() with self._conn: cursor = self._conn.cursor() if self._existing_db: file_path = Path(self.filename) file_size = file_path.stat().st_size logger.info(f"init_db - found existing sqlite db: len={file_size}") else: # NOTE(rkuo): why is this only if the db doesn't exist? 
# Enable WAL mode for better concurrent access and write performance cursor.execute("PRAGMA journal_mode=WAL") cursor.execute("PRAGMA synchronous=NORMAL") cursor.execute("PRAGMA temp_store=MEMORY") cursor.execute("PRAGMA cache_size=-2000000") # Use 2GB memory for cache # Main table for storing Salesforce objects cursor.execute( """ CREATE TABLE IF NOT EXISTS salesforce_objects ( id TEXT PRIMARY KEY, object_type TEXT NOT NULL, data TEXT NOT NULL, -- JSON serialized data last_modified INTEGER DEFAULT (strftime('%s', 'now')) -- Add timestamp for better cache management ) WITHOUT ROWID -- Optimize for primary key lookups """ ) # NOTE(rkuo): this seems completely redundant with relationship_types # Table for parent-child relationships with covering index cursor.execute( """ CREATE TABLE IF NOT EXISTS relationships ( child_id TEXT NOT NULL, parent_id TEXT NOT NULL, PRIMARY KEY (child_id, parent_id) ) WITHOUT ROWID -- Optimize for primary key lookups """ ) # New table for caching parent-child relationships with object types cursor.execute( """ CREATE TABLE IF NOT EXISTS relationship_types ( child_id TEXT NOT NULL, parent_id TEXT NOT NULL, parent_type TEXT NOT NULL, PRIMARY KEY (child_id, parent_id, parent_type) ) WITHOUT ROWID """ ) # Create a table for User email to ID mapping if it doesn't exist cursor.execute( """ CREATE TABLE IF NOT EXISTS user_email_map ( email TEXT PRIMARY KEY, user_id TEXT, -- Nullable to allow for users without IDs FOREIGN KEY (user_id) REFERENCES salesforce_objects(id) ) WITHOUT ROWID """ ) # Create indexes if they don't exist (checked manually against sqlite_master; SQLite also supports CREATE INDEX IF NOT EXISTS) def create_index_if_not_exists( index_name: str, create_statement: str ) -> None: cursor.execute( f"SELECT name FROM sqlite_master WHERE type='index' AND name='{index_name}'" ) if not cursor.fetchone(): cursor.execute(create_statement) create_index_if_not_exists( "idx_object_type", """ CREATE INDEX idx_object_type ON salesforce_objects(object_type, id) WHERE object_type IS NOT NULL
""", ) create_index_if_not_exists( "idx_parent_id", """ CREATE INDEX idx_parent_id ON relationships(parent_id, child_id) """, ) create_index_if_not_exists( "idx_child_parent", """ CREATE INDEX idx_child_parent ON relationships(child_id) WHERE child_id IS NOT NULL """, ) create_index_if_not_exists( "idx_relationship_types_lookup", """ CREATE INDEX idx_relationship_types_lookup ON relationship_types(parent_type, child_id, parent_id) """, ) elapsed = time.monotonic() - start logger.info(f"init_db - create tables and indices: elapsed={elapsed:.2f}") # Analyze tables to help query planner # NOTE(rkuo): skip ANALYZE - it takes too long and we likely don't have # complicated queries that need this # start = time.monotonic() # cursor.execute("ANALYZE relationships") # cursor.execute("ANALYZE salesforce_objects") # cursor.execute("ANALYZE relationship_types") # cursor.execute("ANALYZE user_email_map") # elapsed = time.monotonic() - start # logger.info(f"init_db - analyze: elapsed={elapsed:.2f}") # If database already existed but user_email_map needs to be populated start = time.monotonic() cursor.execute("SELECT COUNT(*) FROM user_email_map") elapsed = time.monotonic() - start logger.info(f"init_db - count user_email_map: elapsed={elapsed:.2f}") start = time.monotonic() if cursor.fetchone()[0] == 0: OnyxSalesforceSQLite._update_user_email_map(cursor) elapsed = time.monotonic() - start logger.info(f"init_db - update_user_email_map: elapsed={elapsed:.2f}") def get_user_id_by_email(self, email: str) -> str | None: """Get the Salesforce User ID for a given email address. 
Args: email: The email address to look up Returns: The Salesforce User ID stored for the email (this may be NULL_ID_STRING if the email was recorded without an ID), or None if the email is not in the table """ if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cursor = self._conn.cursor() cursor.execute( "SELECT user_id FROM user_email_map WHERE email = ?", (email,) ) result = cursor.fetchone() if result is None: return None return result[0] def update_email_to_id_table(self, email: str, id: str | None) -> None: """Update the email to ID map table with a new email and ID.""" if self._conn is None: raise RuntimeError("Database connection is closed") id_to_use = id or self.NULL_ID_STRING with self._conn: cursor = self._conn.cursor() cursor.execute( "INSERT OR REPLACE INTO user_email_map (email, user_id) VALUES (?, ?)", (email, id_to_use), ) def log_stats(self) -> None: if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cache_pages = self._conn.execute("PRAGMA cache_size").fetchone()[0] page_size = self._conn.execute("PRAGMA page_size").fetchone()[0] if cache_pages >= 0: cache_bytes = cache_pages * page_size else: cache_bytes = abs(cache_pages * 1024) logger.info( f"SQLite stats: sqlite_version={sqlite3.sqlite_version} " f"cache_pages={cache_pages} " f"page_size={page_size} " f"cache_bytes={cache_bytes}" ) # get_changed_parent_ids_by_type_2 replaces this def get_changed_parent_ids_by_type( self, changed_ids: list[str], parent_types: set[str], batch_size: int = 500, ) -> Iterator[tuple[str, str, int]]: """Get IDs of objects that are of the specified parent types and are either in changed_ids or have children in changed_ids. Yields tuples of (parent_type, parent_id, num_examined). NOTE(rkuo): This function used to have some interesting behavior ... it created batches of IDs and yielded back a list once for each parent type within that batch.
There's no need to expose the details of the internal batching to the caller, so we're now yielding once per changed parent. """ if self._conn is None: raise RuntimeError("Database connection is closed") updated_parent_ids: set[str] = ( set() ) # dedupes parent id's that have already been yielded # SQLite typically has a limit of 999 variables num_examined = 0 updated_ids_batches = batch_list(changed_ids, batch_size) with self._conn: cursor = self._conn.cursor() for batch_ids in updated_ids_batches: num_examined += len(batch_ids) batch_ids = list(set(batch_ids) - updated_parent_ids) if not batch_ids: continue id_placeholders = ",".join(["?" for _ in batch_ids]) for parent_type in parent_types: affected_ids: set[str] = set() # Get directly updated objects of parent types - using index on object_type cursor.execute( f""" SELECT id FROM salesforce_objects WHERE id IN ({id_placeholders}) AND object_type = ? """, batch_ids + [parent_type], ) affected_ids.update(row[0] for row in cursor.fetchall()) # Get parent objects of updated objects - using optimized relationship_types table cursor.execute( f""" SELECT DISTINCT parent_id FROM relationship_types INDEXED BY idx_relationship_types_lookup WHERE parent_type = ? 
AND child_id IN ({id_placeholders}) """, [parent_type] + batch_ids, ) affected_ids.update(row[0] for row in cursor.fetchall()) # Remove any parent IDs that have already been processed newly_affected_ids = affected_ids - updated_parent_ids # Add the new affected IDs to the set of updated parent IDs if newly_affected_ids: # Yield each newly affected ID individually for parent_id in newly_affected_ids: yield parent_type, parent_id, num_examined updated_parent_ids.update(newly_affected_ids) def get_changed_parent_ids_by_type_2( self, changed_ids: dict[str, str], parent_types: set[str], parent_relationship_fields_by_type: dict[str, dict[str, list[str]]], prefix_to_type: dict[str, str], ) -> Iterator[tuple[str, str, int]]: """ This function yields the IDs of changed parent records, found either directly (the changed ID is itself of a parent type) or through a relationship-field lookup on a changed child record. Yields tuples of (changed_id, parent_type, num_examined): changed_id is the ID of the changed parent record; parent_type is the object table/type of that ID (for child lookups, resolved from the ID's 3-character prefix); num_examined is the number of entries in changed_ids processed so far, which signifies our progress. changed_ids maps every ID that changed, both parents and children, to its object type. This is much simpler than get_changed_parent_ids_by_type.
TODO(rkuo): for common entities, the first 3 chars identify the object type see https://help.salesforce.com/s/articleView?id=000385203&type=1 """ changed_parent_ids: set[str] = ( set() ) # dedupes parent id's that have already been yielded # SQLite typically has a limit of 999 variables num_examined = 0 for changed_id, changed_type in changed_ids.items(): num_examined += 1 # if we yielded this id already, continue if changed_id in changed_parent_ids: continue # if this id is a parent type, yield it directly if changed_type in parent_types: yield changed_id, changed_type, num_examined changed_parent_ids.add(changed_id) continue # if this id is a child type, then check the columns # that relate it to the parent id and yield those ids # NOTE: Although unlikely, id's yielded in this way may not be of the # type we're interested in, so the caller must be prepared # for the id to not be present # get the child id record sf_object = self.get_record(changed_id, changed_type) if not sf_object: continue # get the fields that contain parent id's parent_relationship_fields = parent_relationship_fields_by_type[ changed_type ] for field_name, _ in parent_relationship_fields.items(): if field_name not in sf_object.data: logger.warning(f"{field_name=} not in data for {changed_type=}!") continue parent_id = cast(str, sf_object.data[field_name]) parent_id_prefix = parent_id[:3] if parent_id_prefix not in prefix_to_type: logger.warning( f"Could not lookup type for prefix: {parent_id_prefix=}" ) continue parent_type = prefix_to_type[parent_id_prefix] if parent_type not in parent_types: continue yield parent_id, parent_type, num_examined changed_parent_ids.add(parent_id) break def object_type_count(self, object_type: str) -> int: """Check if there is at least one object of the specified type in the database. 
Args: object_type: The Salesforce object type to check Returns: int: The number of stored objects of this type (nonzero if at least one exists) """ if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cursor = self._conn.cursor() cursor.execute( "SELECT COUNT(*) FROM salesforce_objects WHERE object_type = ?", (object_type,), ) count = cursor.fetchone()[0] return count @staticmethod def normalize_record( original_record: dict[str, Any], remove_ids: bool = True, ) -> tuple[dict[str, Any], set[str]]: """Takes a dict of field names to values and removes fields we don't want: empty fields, the "attributes" field, and (when remove_ids is True) fields other than Id whose values are Salesforce IDs. LastModifiedById is always kept. Returns the cleaned record dict and the set of parent IDs found in the record. """ parent_ids: set[str] = set() fields_to_remove: set[str] = set() record = original_record.copy() for field, value in record.items(): # remove empty fields if not value: fields_to_remove.add(field) continue if field == "attributes": fields_to_remove.add(field) continue # remove salesforce id's (and add to parent id set) if ( field != ID_FIELD and isinstance(value, str) and validate_salesforce_id(value) ): parent_ids.add(value) if remove_ids: fields_to_remove.add(field) continue # this field is real data, leave it alone # Remove unwanted fields for field in fields_to_remove: if field != "LastModifiedById": del record[field] return record, parent_ids def update_from_csv( self, object_type: str, csv_download_path: str, remove_ids: bool = True ) -> list[str]: """Update the SF DB with a CSV file using SQLite storage.""" if self._conn is None: raise RuntimeError("Database connection is closed") # some customers need this to be larger than the default 128KB, go with 16MB csv.field_size_limit(16 * 1024 * 1024) updated_ids = [] with self._conn: cursor = self._conn.cursor() with open(csv_download_path, "r", newline="", encoding="utf-8") as f: reader = csv.DictReader(f) uncommitted_rows = 0 for row in reader: if ID_FIELD not in row: logger.warning( f"Row {row}
does not have an {ID_FIELD} field in {csv_download_path}" ) continue row_id = row[ID_FIELD] normalized_record, parent_ids = ( OnyxSalesforceSQLite.normalize_record(row, remove_ids) ) normalized_record_json_str = json.dumps(normalized_record) # Update main object data # NOTE(rkuo): looks like we take a list and dump it as json into the db cursor.execute( """ INSERT OR REPLACE INTO salesforce_objects (id, object_type, data) VALUES (?, ?, ?) """, (row_id, object_type, normalized_record_json_str), ) # Update relationships using the same connection OnyxSalesforceSQLite._update_relationship_tables( cursor, row_id, parent_ids ) updated_ids.append(row_id) # periodically commit or else memory will balloon uncommitted_rows += 1 if uncommitted_rows >= 1024: self._conn.commit() uncommitted_rows = 0 # If we're updating User objects, update the email map if object_type == USER_OBJECT_TYPE: OnyxSalesforceSQLite._update_user_email_map(cursor) return updated_ids def get_child_ids(self, parent_id: str) -> set[str]: """Get all child IDs for a given parent ID.""" if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cursor = self._conn.cursor() # Force index usage with INDEXED BY cursor.execute( "SELECT child_id FROM relationships INDEXED BY idx_parent_id WHERE parent_id = ?", (parent_id,), ) child_ids = {row[0] for row in cursor.fetchall()} return child_ids def get_type_from_id(self, object_id: str) -> str | None: """Get the type of an object from its ID.""" if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cursor = self._conn.cursor() cursor.execute( "SELECT object_type FROM salesforce_objects WHERE id = ?", (object_id,) ) result = cursor.fetchone() if not result: logger.warning(f"Object ID {object_id} not found") return None return result[0] def get_record( self, object_id: str, object_type: str | None = None, isChild: bool = False ) -> SalesforceObject | None: """Retrieve the record and return it as a 
SalesforceObject.""" if self._conn is None: raise RuntimeError("Database connection is closed") if object_type is None: object_type = self.get_type_from_id(object_id) if not object_type: return None with self._conn: cursor = self._conn.cursor() # Get the object data and account data if object_type == ACCOUNT_OBJECT_TYPE or isChild: cursor.execute( "SELECT data FROM salesforce_objects WHERE id = ?", (object_id,) ) else: cursor.execute( "SELECT pso.data, r.parent_id as parent_id, sso.object_type FROM salesforce_objects pso \ LEFT JOIN relationships r on r.child_id = pso.id \ LEFT JOIN salesforce_objects sso on r.parent_id = sso.id \ WHERE pso.id = ? ", (object_id,), ) result = cursor.fetchall() if not result: logger.warning(f"Object ID {object_id} not found") return None data = json.loads(result[0][0]) if object_type != ACCOUNT_OBJECT_TYPE: # convert any account ids of the relationships back into data fields, with name for row in result: # the following skips Account objects. if len(row) < 3: continue if row[1] and row[2] and row[2] == ACCOUNT_OBJECT_TYPE: data["AccountId"] = row[1] cursor.execute( "SELECT data FROM salesforce_objects WHERE id = ?", (row[1],), ) account_data = json.loads(cursor.fetchone()[0]) data[ACCOUNT_OBJECT_TYPE] = account_data.get(NAME_FIELD, "") return SalesforceObject(id=object_id, type=object_type, data=data) def find_ids_by_type(self, object_type: str) -> list[str]: """Find all object IDs for rows of the specified type.""" if self._conn is None: raise RuntimeError("Database connection is closed") with self._conn: cursor = self._conn.cursor() cursor.execute( "SELECT id FROM salesforce_objects WHERE object_type = ?", (object_type,), ) return [row[0] for row in cursor.fetchall()] @staticmethod def _update_relationship_tables( cursor: sqlite3.Cursor, child_id: str, parent_ids: set[str] ) -> None: """Given a child id and a set of parent id's, updates the relationships of the child to the parents in the db and removes old relationships. 
Args: cursor: The database cursor to use (must be in a transaction) child_id: The ID of the child record parent_ids: Set of parent IDs to link to """ try: # Get existing parent IDs cursor.execute( "SELECT parent_id FROM relationships WHERE child_id = ?", (child_id,) ) old_parent_ids = {row[0] for row in cursor.fetchall()} # Calculate differences parent_ids_to_remove = old_parent_ids - parent_ids parent_ids_to_add = parent_ids - old_parent_ids # Remove old relationships if parent_ids_to_remove: cursor.executemany( "DELETE FROM relationships WHERE child_id = ? AND parent_id = ?", [(child_id, parent_id) for parent_id in parent_ids_to_remove], ) # Also remove from relationship_types cursor.executemany( "DELETE FROM relationship_types WHERE child_id = ? AND parent_id = ?", [(child_id, parent_id) for parent_id in parent_ids_to_remove], ) # Add new relationships if parent_ids_to_add: # First add to relationships table cursor.executemany( "INSERT INTO relationships (child_id, parent_id) VALUES (?, ?)", [(child_id, parent_id) for parent_id in parent_ids_to_add], ) # Then get the types of the parent objects and add to relationship_types for parent_id in parent_ids_to_add: cursor.execute( "SELECT object_type FROM salesforce_objects WHERE id = ?", (parent_id,), ) result = cursor.fetchone() if result: parent_type = result[0] cursor.execute( """ INSERT INTO relationship_types (child_id, parent_id, parent_type) VALUES (?, ?, ?) """, (child_id, parent_id, parent_type), ) except Exception: logger.exception( f"Error updating relationship tables: child_id={child_id} parent_ids={parent_ids}" ) raise @staticmethod def _update_user_email_map(cursor: sqlite3.Cursor) -> None: """Update the user_email_map table with current User objects. Called internally by update_from_csv when User objects are updated, and by apply_schema when the map is empty.
""" cursor.execute( """ INSERT OR REPLACE INTO user_email_map (email, user_id) SELECT json_extract(data, '$.Email'), id FROM salesforce_objects WHERE object_type = 'User' AND json_extract(data, '$.Email') IS NOT NULL """ ) def make_basic_expert_info_from_record( self, sf_object: SalesforceObject, ) -> BasicExpertInfo | None: """Parses record for LastModifiedById and returns BasicExpertInfo of the user if possible.""" object_dict: dict[str, Any] = sf_object.data if not (last_modified_by_id := object_dict.get("LastModifiedById")): logger.warning(f"No LastModifiedById found for {sf_object.id}") return None if not (last_modified_by := self.get_record(last_modified_by_id)): logger.warning(f"No LastModifiedBy found for {last_modified_by_id}") return None try: expert_info = BasicExpertInfo.from_dict(last_modified_by.data) except Exception: return None return expert_info ================================================ FILE: backend/onyx/connectors/salesforce/utils.py ================================================ import os from dataclasses import dataclass from typing import Any NAME_FIELD = "Name" MODIFIED_FIELD = "LastModifiedDate" ID_FIELD = "Id" ACCOUNT_OBJECT_TYPE = "Account" USER_OBJECT_TYPE = "User" @dataclass class SalesforceObject: id: str type: str data: dict[str, Any] def to_dict(self) -> dict[str, Any]: return { "ID": self.id, "Type": self.type, "Data": self.data, } @classmethod def from_dict(cls, data: dict[str, Any]) -> "SalesforceObject": return cls( id=data[ID_FIELD], type=data["Type"], data=data, ) # This defines the base path for all data files relative to this file # AKA BE CAREFUL WHEN MOVING THIS FILE BASE_DATA_PATH = os.path.join(os.path.dirname(__file__), "data") def get_sqlite_db_path(directory: str) -> str: """Get the path to the sqlite db file.""" return os.path.join(directory, "salesforce_db.sqlite") def remove_sqlite_db_files(db_path: str) -> None: """Remove SQLite database and all associated files (WAL, SHM). 
SQLite in WAL mode creates additional files: - .sqlite-wal: Write-ahead log - .sqlite-shm: Shared memory file If these files become stale (e.g., after a crash), they can cause 'disk I/O error' when trying to open the database. This function ensures all related files are removed. """ files_to_remove = [ db_path, f"{db_path}-wal", f"{db_path}-shm", ] for file_path in files_to_remove: if os.path.exists(file_path): os.remove(file_path) # NOTE: only used with shelves, deprecated at this point def get_object_type_path(object_type: str) -> str: """Get the directory path for a specific object type.""" type_dir = os.path.join(BASE_DATA_PATH, object_type) os.makedirs(type_dir, exist_ok=True) return type_dir _CHECKSUM_CHARS = "ABCDEFGHIJKLMNOPQRSTUVWXYZ012345" _LOOKUP = {format(i, "05b"): _CHECKSUM_CHARS[i] for i in range(32)} def validate_salesforce_id(salesforce_id: str) -> bool: """Validate the checksum portion of an 18-character Salesforce ID. Args: salesforce_id: An 18-character Salesforce ID Returns: bool: True if the checksum is valid, False otherwise """ if len(salesforce_id) != 18: return False chunks = [salesforce_id[0:5], salesforce_id[5:10], salesforce_id[10:15]] checksum = salesforce_id[15:18] calculated_checksum = "" for chunk in chunks: result_string = "".join( "1" if char.isupper() else "0" for char in reversed(chunk) ) calculated_checksum += _LOOKUP[result_string] return checksum == calculated_checksum ================================================ FILE: backend/onyx/connectors/sharepoint/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/sharepoint/connector.py ================================================ import base64 import copy import fnmatch import html import io import os import re import time from collections import deque from collections.abc import Generator from collections.abc import Iterable from datetime import datetime from datetime import timezone from 
enum import Enum from typing import Any from typing import cast from urllib.parse import quote from urllib.parse import unquote from urllib.parse import urlsplit import msal # type: ignore[import-untyped] import requests from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import serialization from cryptography.hazmat.primitives.serialization import pkcs12 from office365.graph_client import GraphClient # type: ignore[import-untyped] from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped] from office365.onedrive.sites.site import Site # type: ignore[import-untyped] from office365.onedrive.sites.sites_with_root import SitesWithRoot # type: ignore[import-untyped] from office365.runtime.auth.token_response import TokenResponse # type: ignore[import-untyped] from office365.runtime.client_request import ClientRequestException # type: ignore from office365.runtime.paths.resource_path import ResourcePath # type: ignore[import-untyped] from office365.runtime.queries.client_query import ClientQuery # type: ignore[import-untyped] from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped] from pydantic import BaseModel from pydantic import Field from requests.exceptions import HTTPError from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import REQUEST_TIMEOUT_SECONDS from onyx.configs.app_configs import SHAREPOINT_CONNECTOR_SIZE_THRESHOLD from onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import IndexingHeartbeatInterface from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces 
import SlimConnectorWithPermSync from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import EntityFailure from onyx.connectors.models import ExternalAccess from onyx.connectors.models import HierarchyNode from onyx.connectors.models import ImageSection from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.connectors.sharepoint.connector_utils import get_sharepoint_external_access from onyx.db.enums import HierarchyNodeType from onyx.file_processing.extract_file_text import extract_text_and_images from onyx.file_processing.extract_file_text import get_file_ext from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.file_types import OnyxMimeTypes from onyx.file_processing.image_utils import store_image_and_create_section from onyx.utils.b64 import get_image_type_from_bytes from onyx.utils.logger import setup_logger logger = setup_logger() SLIM_BATCH_SIZE = 1000 _EPOCH = datetime.fromtimestamp(0, tz=timezone.utc) SHARED_DOCUMENTS_MAP = { "Documents": "Shared Documents", "Dokumente": "Freigegebene Dokumente", "Documentos": "Documentos compartidos", } SHARED_DOCUMENTS_MAP_REVERSE = {v: k for k, v in SHARED_DOCUMENTS_MAP.items()} ASPX_EXTENSION = ".aspx" def _is_site_excluded(site_url: str, excluded_site_patterns: list[str]) -> bool: """Check if a site URL matches any of the exclusion glob patterns.""" for pattern in excluded_site_patterns: if fnmatch.fnmatch(site_url, pattern) or fnmatch.fnmatch( site_url.rstrip("/"), pattern.rstrip("/") ): return True return False def _is_path_excluded(item_path: str, 
excluded_path_patterns: list[str]) -> bool: """Check if a drive item path matches any of the exclusion glob patterns. item_path is the relative path within a drive, e.g. "Engineering/API/report.docx". Matches are attempted against the full path and the filename alone so that patterns like "*.tmp" match files at any depth. """ filename = item_path.rsplit("/", 1)[-1] if "/" in item_path else item_path for pattern in excluded_path_patterns: if fnmatch.fnmatch(item_path, pattern) or fnmatch.fnmatch(filename, pattern): return True return False def _build_item_relative_path(parent_reference_path: str | None, item_name: str) -> str: """Build the relative path of a drive item from its parentReference.path and name. Example: parentReference.path="/drives/abc/root:/Eng/API", name="report.docx" => "Eng/API/report.docx" """ if parent_reference_path and "root:/" in parent_reference_path: folder = unquote(parent_reference_path.split("root:/", 1)[1]) if folder: return f"{folder}/{item_name}" return item_name DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com" DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com" DEFAULT_SHAREPOINT_DOMAIN_SUFFIX = "sharepoint.com" GRAPH_API_BASE = f"{DEFAULT_GRAPH_API_HOST}/v1.0" GRAPH_API_MAX_RETRIES = 5 GRAPH_API_RETRYABLE_STATUSES = frozenset({429, 500, 502, 503, 504}) class DriveItemData(BaseModel): """Lightweight representation of a Graph API drive item, parsed from JSON. Replaces the SDK DriveItem for fetching/listing so that we can paginate lazily through the Graph API without materialising every item in memory. 
""" id: str name: str web_url: str size: int | None = None mime_type: str | None = None download_url: str | None = None last_modified_datetime: datetime | None = None last_modified_by_display_name: str | None = None last_modified_by_email: str | None = None parent_reference_path: str | None = None drive_id: str | None = None @classmethod def from_graph_json(cls, item: dict[str, Any]) -> "DriveItemData": last_mod_raw = item.get("lastModifiedDateTime") last_mod: datetime | None = None if isinstance(last_mod_raw, str): last_mod = datetime.fromisoformat(last_mod_raw.replace("Z", "+00:00")) last_modified_by = item.get("lastModifiedBy", {}).get("user", {}) parent_ref = item.get("parentReference", {}) return cls( id=item["id"], name=item.get("name", ""), web_url=item.get("webUrl", ""), size=item.get("size"), mime_type=item.get("file", {}).get("mimeType"), download_url=item.get("@microsoft.graph.downloadUrl"), last_modified_datetime=last_mod, last_modified_by_display_name=last_modified_by.get("displayName"), last_modified_by_email=( last_modified_by.get("email") or last_modified_by.get("userPrincipalName") ), parent_reference_path=parent_ref.get("path"), drive_id=parent_ref.get("driveId"), ) def to_sdk_driveitem(self, graph_client: GraphClient) -> DriveItem: """Construct a lazy SDK DriveItem for permission lookups.""" if not self.drive_id: raise ValueError("drive_id is required to construct SDK DriveItem") path = ResourcePath( self.id, ResourcePath("items", ResourcePath(self.drive_id, ResourcePath("drives"))), ) item = DriveItem(graph_client, path) item.set_property("id", self.id) return item # The office365 library's ClientContext caches the access token from its # first request and never re-invokes the token callback. Microsoft access # tokens live ~60-75 minutes, so we recreate the cached ClientContext every # 30 minutes to let MSAL transparently handle token refresh. 
_REST_CTX_MAX_AGE_S = 30 * 60 class SiteDescriptor(BaseModel): """Data class for storing SharePoint site information. Args: url: The base site URL (e.g. https://danswerai.sharepoint.com/sites/sharepoint-tests or https://danswerai.sharepoint.com/teams/team-name) drive_name: The name of the drive to access (e.g. "Shared Documents", "Other Library") If None, all drives will be accessed. folder_path: The folder path within the drive to access (e.g. "test/nested with spaces") If None, all folders will be accessed. """ url: str drive_name: str | None folder_path: str | None class CertificateData(BaseModel): """Data class for storing certificate information loaded from PFX file.""" private_key: bytes thumbprint: str def _site_page_in_time_window( page: dict[str, Any], start: datetime | None, end: datetime | None, ) -> bool: """Return True if the page's lastModifiedDateTime falls within [start, end].""" if start is None and end is None: return True raw = page.get("lastModifiedDateTime") if not raw: return True if not isinstance(raw, str): raise ValueError(f"lastModifiedDateTime is not a string: {raw}") last_modified = datetime.fromisoformat(raw.replace("Z", "+00:00")) return (start is None or last_modified >= start) and ( end is None or last_modified <= end ) def sleep_and_retry( query_obj: ClientQuery, method_name: str, max_retries: int = 3 ) -> Any: """ Execute a SharePoint query with retry logic for rate limiting. """ for attempt in range(max_retries + 1): try: return query_obj.execute_query() except ClientRequestException as e: status = e.response.status_code if e.response is not None else None # 429 / 503 — rate limit or transient error. Back off and retry. 
if status in (429, 503) and attempt < max_retries: logger.warning( f"Rate limit exceeded on {method_name}, attempt {attempt + 1}/{max_retries + 1}, sleeping and retrying" ) retry_after = e.response.headers.get("Retry-After") if retry_after: sleep_time = int(retry_after) else: # Exponential backoff: 2^attempt * 5 seconds sleep_time = min(30, (2**attempt) * 5) logger.info(f"Sleeping for {sleep_time} seconds before retry") time.sleep(sleep_time) continue # Non-retryable error or retries exhausted — log details and raise. if e.response is not None: logger.error( f"SharePoint request failed for {method_name}: status={status}, " ) raise e class SharepointConnectorCheckpoint(ConnectorCheckpoint): cached_site_descriptors: deque[SiteDescriptor] | None = None current_site_descriptor: SiteDescriptor | None = None cached_drive_names: deque[str] | None = None current_drive_name: str | None = None # Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes current_drive_web_url: str | None = None # Resolved drive ID — avoids re-resolving on checkpoint resume current_drive_id: str | None = None # Next delta API page URL for per-page checkpointing within a drive. # When set, Phase 3b fetches one page at a time so progress is persisted # between pages. None means BFS path or no active delta traversal. current_drive_delta_next_link: str | None = None process_site_pages: bool = False # Track yielded hierarchy nodes by their raw_node_id (URLs) to avoid duplicates seen_hierarchy_node_raw_ids: set[str] = Field(default_factory=set) # Track yielded document IDs to avoid processing the same document twice. # The Microsoft Graph delta API can return the same item on multiple pages. 
    seen_document_ids: set[str] = Field(default_factory=set)


class SharepointAuthMethod(Enum):
    CLIENT_SECRET = "client_secret"
    CERTIFICATE = "certificate"


class SizeCapExceeded(Exception):
    """Exception raised when the size cap is exceeded."""


def _log_and_raise_for_status(response: requests.Response) -> None:
    """Log the response text and raise for status."""
    try:
        response.raise_for_status()
    except Exception:
        logger.error(f"HTTP request failed: {response.text}")
        raise


GRAPH_INVALID_REQUEST_CODE = "invalidRequest"


def _is_graph_invalid_request(response: requests.Response) -> bool:
    """Return True if the response body is the generic Graph API
    ``{"error": {"code": "invalidRequest", "message": "Invalid request"}}``
    shape. This particular error has no actionable inner error code and is
    returned by the site-pages endpoint when a page has a corrupt canvas
    layout (e.g. duplicate web-part IDs, see SharePoint/sp-dev-docs#8822)."""
    try:
        body = response.json()
    except Exception:
        return False
    # Guard against non-object JSON bodies (e.g. a bare list or string)
    if not isinstance(body, dict):
        return False
    error = body.get("error")
    return (
        isinstance(error, dict) and error.get("code") == GRAPH_INVALID_REQUEST_CODE
    )


def load_certificate_from_pfx(pfx_data: bytes, password: str) -> CertificateData | None:
    """Load a certificate from a .pfx file for MSAL authentication."""
    try:
        # Load the certificate and private key
        private_key, certificate, additional_certificates = (
            pkcs12.load_key_and_certificates(pfx_data, password.encode("utf-8"))
        )

        # Validate that certificate and private key are not None
        if certificate is None or private_key is None:
            raise ValueError("Certificate or private key is None")

        # Convert to the PEM format that MSAL expects
        key_pem = private_key.private_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PrivateFormat.PKCS8,
            encryption_algorithm=serialization.NoEncryption(),
        )

        return CertificateData(
            private_key=key_pem,
            thumbprint=certificate.fingerprint(hashes.SHA1()).hex(),
        )
    except Exception as e:
        logger.error(f"Error loading certificate: {e}")
        return None


def acquire_token_for_rest(
    msal_app: msal.ConfidentialClientApplication,
    sp_tenant_domain: str,
    sharepoint_domain_suffix: str,
) -> TokenResponse:
    token = msal_app.acquire_token_for_client(
        scopes=[f"https://{sp_tenant_domain}.{sharepoint_domain_suffix}/.default"]
    )
    return TokenResponse.from_json(token)


def _create_document_failure(
    driveitem: DriveItemData,
    error_message: str,
    exception: Exception | None = None,
) -> ConnectorFailure:
    """Create a ConnectorFailure for document processing errors."""
    return ConnectorFailure(
        failed_document=DocumentFailure(
            document_id=driveitem.id or "unknown",
            document_link=driveitem.web_url,
        ),
        failure_message=f"SharePoint document '{driveitem.name or 'unknown'}': {error_message}",
        exception=exception,
    )


def _create_entity_failure(
    entity_id: str,
    error_message: str,
    time_range: tuple[datetime, datetime] | None = None,
    exception: Exception | None = None,
) -> ConnectorFailure:
    """Create a ConnectorFailure for entity-level errors."""
    return ConnectorFailure(
        failed_entity=EntityFailure(
            entity_id=entity_id,
            missed_time_range=time_range,
        ),
        failure_message=f"SharePoint entity '{entity_id}': {error_message}",
        exception=exception,
    )


def _probe_remote_size(url: str, timeout: int) -> int | None:
    """Determine remote size using HEAD or a range GET probe.
    Returns None if unknown."""
    try:
        head_resp = requests.head(url, timeout=timeout, allow_redirects=True)
        _log_and_raise_for_status(head_resp)
        cl = head_resp.headers.get("Content-Length")
        if cl and cl.isdigit():
            return int(cl)
    except requests.RequestException:
        pass

    # Fallback: Range request for the first byte to read the total from Content-Range
    try:
        with requests.get(
            url,
            headers={"Range": "bytes=0-0"},
            timeout=timeout,
            stream=True,
        ) as range_resp:
            _log_and_raise_for_status(range_resp)
            cr = range_resp.headers.get("Content-Range")  # e.g., "bytes 0-0/12345"
            if cr and "/" in cr:
                total = cr.split("/")[-1]
                if total.isdigit():
                    return int(total)
    except requests.RequestException:
        pass

    # If both HEAD and a range GET failed to reveal a size, signal unknown size.
    # Callers should treat None as "size unavailable" and proceed with a safe
    # streaming path that enforces a hard cap to avoid excessive memory usage.
    return None


def _download_with_cap(url: str, timeout: int, cap: int) -> bytes:
    """Stream download content with an upper bound on bytes read.

    Behavior:
    - Checks `Content-Length` first and aborts early if it exceeds `cap`.
    - Otherwise streams the body in chunks and stops once `cap` is surpassed.
    - Raises `SizeCapExceeded` when the cap would be exceeded.
    - Returns the full bytes if the content fits within `cap`.
    """
    with requests.get(url, stream=True, timeout=timeout) as resp:
        _log_and_raise_for_status(resp)

        # If the server provides Content-Length, prefer an early decision.
        cl_header = resp.headers.get("Content-Length")
        if cl_header and cl_header.isdigit():
            content_len = int(cl_header)
            if content_len > cap:
                logger.warning(
                    f"Content-Length {content_len} exceeds cap {cap}; skipping download."
                )
                raise SizeCapExceeded("pre_download")

        buf = io.BytesIO()
        # Stream in 64KB chunks; adjust if needed for slower networks.
        for chunk in resp.iter_content(64 * 1024):
            if not chunk:
                continue
            buf.write(chunk)
            if buf.tell() > cap:
                # Avoid keeping a large partial buffer; close and signal caller to skip.
                logger.warning(
                    f"Streaming download exceeded cap {cap} bytes; aborting early."
                )
                raise SizeCapExceeded("during_download")

        return buf.getvalue()


def _download_via_graph_api(
    access_token: str,
    drive_id: str,
    item_id: str,
    bytes_allowed: int,
    graph_api_base: str,
) -> bytes:
    """Download a drive item via the Graph API /content endpoint with a byte cap.

    Raises SizeCapExceeded if the cap is exceeded.
    """
    url = f"{graph_api_base}/drives/{drive_id}/items/{item_id}/content"
    headers = {"Authorization": f"Bearer {access_token}"}
    with requests.get(
        url, headers=headers, stream=True, timeout=REQUEST_TIMEOUT_SECONDS
    ) as resp:
        _log_and_raise_for_status(resp)
        buf = io.BytesIO()
        for chunk in resp.iter_content(64 * 1024):
            if not chunk:
                continue
            buf.write(chunk)
            if buf.tell() > bytes_allowed:
                raise SizeCapExceeded("during_graph_api_download")
        return buf.getvalue()


def _convert_driveitem_to_document_with_permissions(
    driveitem: DriveItemData,
    drive_name: str,
    ctx: ClientContext | None,
    graph_client: GraphClient,
    graph_api_base: str,
    include_permissions: bool = False,
    parent_hierarchy_raw_node_id: str | None = None,
    access_token: str | None = None,
    treat_sharing_link_as_public: bool = False,
) -> Document | ConnectorFailure | None:
    if not driveitem.name or not driveitem.id:
        raise ValueError("DriveItem name/id is required")

    if include_permissions and ctx is None:
        raise ValueError("ClientContext is required for permissions")

    mime_type = driveitem.mime_type
    if not mime_type or mime_type in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
        logger.debug(
            f"Skipping malformed or excluded mime type {mime_type} for {driveitem.name}"
        )
        return None

    file_size = driveitem.size
    download_url = driveitem.download_url
    if file_size is None and download_url:
        file_size = _probe_remote_size(download_url, REQUEST_TIMEOUT_SECONDS)

    if file_size is not None and file_size > SHAREPOINT_CONNECTOR_SIZE_THRESHOLD:
        logger.warning(
            f"Skipping '{driveitem.name}' over size threshold ({file_size} > {SHAREPOINT_CONNECTOR_SIZE_THRESHOLD} bytes)."
        )
        return None

    # Prefer downloadUrl streaming with a size cap
    content_bytes: bytes | None = None
    if download_url:
        try:
            content_bytes = _download_with_cap(
                download_url,
                REQUEST_TIMEOUT_SECONDS,
                SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
            )
        except SizeCapExceeded as e:
            logger.warning(f"Skipping '{driveitem.name}' exceeded size cap: {str(e)}")
            return None
        except requests.RequestException as e:
            status = e.response.status_code if e.response is not None else -1
            logger.warning(
                f"Failed to download via downloadUrl for '{driveitem.name}' (status={status}); falling back to Graph API."
            )

    # Fallback: download via the Graph API /content endpoint
    if content_bytes is None and access_token and driveitem.drive_id:
        try:
            content_bytes = _download_via_graph_api(
                access_token,
                driveitem.drive_id,
                driveitem.id,
                SHAREPOINT_CONNECTOR_SIZE_THRESHOLD,
                graph_api_base=graph_api_base,
            )
        except SizeCapExceeded:
            logger.warning(
                f"Skipping '{driveitem.name}' exceeded size cap during Graph API download."
            )
            return None
        except Exception as e:
            logger.warning(
                f"Failed to download via Graph API for '{driveitem.name}': {e}"
            )
            return _create_document_failure(
                driveitem, f"Failed to download via graph api: {e}", e
            )

    sections: list[TextSection | ImageSection] = []
    file_ext = get_file_ext(driveitem.name)

    if not content_bytes:
        logger.warning(
            f"Zero-length content for '{driveitem.name}'. Skipping text/image extraction."
        )
    elif file_ext in OnyxFileExtensions.IMAGE_EXTENSIONS:
        image_section, _ = store_image_and_create_section(
            image_data=content_bytes,
            file_id=driveitem.id,
            display_name=driveitem.name,
            file_origin=FileOrigin.CONNECTOR,
        )
        image_section.link = driveitem.web_url
        sections.append(image_section)
    else:

        def _store_embedded_image(img_data: bytes, img_name: str) -> None:
            try:
                img_mime = get_image_type_from_bytes(img_data)
            except ValueError:
                logger.debug(
                    "Skipping embedded image with unknown format for %s",
                    driveitem.name,
                )
                return

            if img_mime in OnyxMimeTypes.EXCLUDED_IMAGE_TYPES:
                logger.debug(
                    "Skipping embedded image of excluded type %s for %s",
                    img_mime,
                    driveitem.name,
                )
                return

            image_section, _ = store_image_and_create_section(
                image_data=img_data,
                file_id=f"{driveitem.id}_img_{len(sections)}",
                display_name=img_name or f"{driveitem.name} - image {len(sections)}",
                file_origin=FileOrigin.CONNECTOR,
            )
            image_section.link = driveitem.web_url
            sections.append(image_section)

        extraction_result = extract_text_and_images(
            file=io.BytesIO(content_bytes),
            file_name=driveitem.name,
            image_callback=_store_embedded_image,
        )
        if extraction_result.text_content:
            sections.append(
                TextSection(link=driveitem.web_url, text=extraction_result.text_content)
            )

    if include_permissions and ctx is not None:
        logger.info(f"Getting external access for {driveitem.name}")
        sdk_item = driveitem.to_sdk_driveitem(graph_client)
        external_access = get_sharepoint_external_access(
            ctx=ctx,
            graph_client=graph_client,
            drive_item=sdk_item,
            drive_name=drive_name,
            add_prefix=True,
            treat_sharing_link_as_public=treat_sharing_link_as_public,
        )
    else:
        external_access = ExternalAccess.empty()

    doc = Document(
        id=driveitem.id,
        sections=sections,
        source=DocumentSource.SHAREPOINT,
        semantic_identifier=driveitem.name,
        external_access=external_access,
        doc_updated_at=(
            driveitem.last_modified_datetime.replace(tzinfo=timezone.utc)
            if driveitem.last_modified_datetime
            else None
        ),
        primary_owners=[
            BasicExpertInfo(
                display_name=driveitem.last_modified_by_display_name or "",
                email=driveitem.last_modified_by_email or "",
            )
        ],
        metadata={"drive": drive_name},
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )
    return doc


def _convert_sitepage_to_document(
    site_page: dict[str, Any],
    site_name: str | None,
    ctx: ClientContext | None,
    graph_client: GraphClient,
    include_permissions: bool = False,
    parent_hierarchy_raw_node_id: str | None = None,
    treat_sharing_link_as_public: bool = False,
) -> Document:
    """Convert a SharePoint site page to a Document object."""
    # Extract text content from the site page
    page_text = ""
    # Get title and description
    title = cast(str, site_page.get("title", ""))
    description = cast(str, site_page.get("description", ""))

    # Build the text content
    if title:
        page_text += f"# {title}\n\n"
    if description:
        page_text += f"{description}\n\n"

    # Extract content from the canvas layout if available
    canvas_layout = site_page.get("canvasLayout", {})
    if canvas_layout:
        horizontal_sections = canvas_layout.get("horizontalSections", [])
        for section in horizontal_sections:
            columns = section.get("columns", [])
            for column in columns:
                webparts = column.get("webparts", [])
                for webpart in webparts:
                    # Extract text from different types of webparts
                    webpart_type = webpart.get("@odata.type", "")

                    # Extract text from text webparts
                    if webpart_type == "#microsoft.graph.textWebPart":
                        inner_html = webpart.get("innerHtml", "")
                        if inner_html:
                            # Basic HTML-to-text conversion: strip tags but keep
                            # line breaks, list bullets, and headings.
                            text_content = re.sub(r"<br\s*/?>", "\n", inner_html)
                            text_content = re.sub(r"<li>", "• ", text_content)
                            text_content = re.sub(r"</li>", "\n", text_content)
                            text_content = re.sub(
                                r"<h[1-6][^>]*>", "\n## ", text_content
                            )
                            text_content = re.sub(r"</h[1-6]>", "\n", text_content)
                            text_content = re.sub(r"<p[^>]*>", "\n", text_content)
                            text_content = re.sub(r"</p>", "\n", text_content)
                            text_content = re.sub(r"<[^>]+>", "", text_content)
                            # Decode HTML entities
                            text_content = html.unescape(text_content)
                            # Clean up extra whitespace
                            text_content = re.sub(
                                r"\n\s*\n", "\n\n", text_content
                            ).strip()
                            if text_content:
                                page_text += f"{text_content}\n\n"

                    # Extract text from standard webparts
                    elif webpart_type == "#microsoft.graph.standardWebPart":
                        data = webpart.get("data", {})

                        # Extract from serverProcessedContent
                        server_content = data.get("serverProcessedContent", {})
                        searchable_texts = server_content.get(
                            "searchablePlainTexts", []
                        )

                        for text_item in searchable_texts:
                            if isinstance(text_item, dict):
                                key = text_item.get("key", "")
                                value = text_item.get("value", "")
                                if value:
                                    # Add context based on key
                                    if key == "title":
                                        page_text += f"## {value}\n\n"
                                    else:
                                        page_text += f"{value}\n\n"

                        # Extract the webpart description if available
                        webpart_description = data.get("description", "")
                        if webpart_description:
                            page_text += f"{webpart_description}\n\n"

                        # Extract the webpart title if available
                        webpart_title = data.get("title", "")
                        if webpart_title and webpart_title != webpart_description:
                            page_text += f"## {webpart_title}\n\n"

    page_text = page_text.strip()

    # If no content was extracted, use the title as a fallback
    if not page_text and title:
        page_text = title

    # Parse creation and modification info
    created_datetime = site_page.get("createdDateTime")
    if created_datetime:
        if isinstance(created_datetime, str):
            created_datetime = datetime.fromisoformat(
                created_datetime.replace("Z", "+00:00")
            )
        elif not created_datetime.tzinfo:
            created_datetime = created_datetime.replace(tzinfo=timezone.utc)

    last_modified_datetime = site_page.get("lastModifiedDateTime")
    if last_modified_datetime:
        if isinstance(last_modified_datetime, str):
            last_modified_datetime = datetime.fromisoformat(
                last_modified_datetime.replace("Z", "+00:00")
            )
        elif not last_modified_datetime.tzinfo:
            last_modified_datetime = last_modified_datetime.replace(tzinfo=timezone.utc)

    # Extract owner information
    primary_owners = []
    created_by = site_page.get("createdBy", {}).get("user", {})
    if created_by.get("displayName"):
        primary_owners.append(
            BasicExpertInfo(
                display_name=created_by.get("displayName"),
                email=created_by.get("email", ""),
            )
        )

    web_url = site_page["webUrl"]
    semantic_identifier = cast(str, site_page.get("name", title))
    if semantic_identifier.endswith(ASPX_EXTENSION):
        semantic_identifier = semantic_identifier[: -len(ASPX_EXTENSION)]

    if include_permissions:
        external_access = get_sharepoint_external_access(
            ctx=ctx,
            graph_client=graph_client,
            site_page=site_page,
            add_prefix=True,
            treat_sharing_link_as_public=treat_sharing_link_as_public,
        )
    else:
        external_access = ExternalAccess.empty()

    doc = Document(
        id=site_page["id"],
        sections=[TextSection(link=web_url, text=page_text)],
        source=DocumentSource.SHAREPOINT,
        external_access=external_access,
        semantic_identifier=semantic_identifier,
        doc_updated_at=last_modified_datetime or created_datetime,
        primary_owners=primary_owners,
        metadata=(
            {
                "site": site_name,
            }
            if site_name
            else {}
        ),
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )
    return doc


def _convert_driveitem_to_slim_document(
    driveitem: DriveItemData,
    drive_name: str,
    ctx: ClientContext,
    graph_client: GraphClient,
    parent_hierarchy_raw_node_id: str | None = None,
    treat_sharing_link_as_public: bool = False,
) -> SlimDocument:
    if driveitem.id is None:
        raise ValueError("DriveItem ID is required")

    sdk_item = driveitem.to_sdk_driveitem(graph_client)
    external_access = get_sharepoint_external_access(
        ctx=ctx,
        graph_client=graph_client,
        drive_item=sdk_item,
        drive_name=drive_name,
        treat_sharing_link_as_public=treat_sharing_link_as_public,
    )

    return SlimDocument(
        id=driveitem.id,
        external_access=external_access,
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )


def _convert_sitepage_to_slim_document(
    site_page: dict[str, Any],
    ctx: ClientContext | None,
    graph_client: GraphClient,
    parent_hierarchy_raw_node_id: str | None = None,
    treat_sharing_link_as_public: bool = False,
) -> SlimDocument:
    """Convert a SharePoint site page to a SlimDocument object."""
    page_id = site_page.get("id")
    if page_id is None:
        raise ValueError("Site page ID is required")

    external_access = get_sharepoint_external_access(
        ctx=ctx,
        graph_client=graph_client,
        site_page=site_page,
        treat_sharing_link_as_public=treat_sharing_link_as_public,
    )
    return SlimDocument(
        id=page_id,
        external_access=external_access,
        parent_hierarchy_raw_node_id=parent_hierarchy_raw_node_id,
    )


class SharepointConnector(
    SlimConnectorWithPermSync,
    CheckpointedConnectorWithPermSync[SharepointConnectorCheckpoint],
):
    def __init__(
        self,
        batch_size: int = INDEX_BATCH_SIZE,
        sites: list[str] = [],
        excluded_sites: list[str] = [],
        excluded_paths: list[str] = [],
        include_site_pages: bool = True,
        include_site_documents: bool = True,
        treat_sharing_link_as_public: bool = False,
        authority_host: str = DEFAULT_AUTHORITY_HOST,
        graph_api_host: str = DEFAULT_GRAPH_API_HOST,
        sharepoint_domain_suffix: str = DEFAULT_SHAREPOINT_DOMAIN_SUFFIX,
    ) -> None:
        self.batch_size = batch_size
        self.sites = list(sites)
        self.excluded_sites = [s for p in excluded_sites if (s := p.strip())]
        self.excluded_paths = [s for p in excluded_paths if (s := p.strip())]
        self.treat_sharing_link_as_public = treat_sharing_link_as_public
        self.site_descriptors: list[SiteDescriptor] = self._extract_site_and_drive_info(
            sites
        )
        self._graph_client: GraphClient | None = None
        self.msal_app: msal.ConfidentialClientApplication | None = None
        self.include_site_pages = include_site_pages
        self.include_site_documents = include_site_documents
        self.sp_tenant_domain: str | None = None
        self._credential_json: dict[str, Any] | None = None
        self._cached_rest_ctx: ClientContext | None = None
        self._cached_rest_ctx_url: str | None = None
        self._cached_rest_ctx_created_at: float = 0.0

        resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
        self._azure_environment = resolved_env.environment
        self.authority_host = resolved_env.authority_host
        self.graph_api_host = resolved_env.graph_host
        self.graph_api_base = f"{self.graph_api_host}/v1.0"
        self.sharepoint_domain_suffix = resolved_env.sharepoint_domain_suffix
        if sharepoint_domain_suffix != resolved_env.sharepoint_domain_suffix:
            logger.warning(
                f"Configured sharepoint_domain_suffix '{sharepoint_domain_suffix}' "
                f"differs from the expected suffix '{resolved_env.sharepoint_domain_suffix}' "
                f"for the {resolved_env.environment} environment. "
                f"Using '{resolved_env.sharepoint_domain_suffix}'."
            )

    def validate_connector_settings(self) -> None:
        # Validate that at least one content type is enabled
        if not self.include_site_documents and not self.include_site_pages:
            raise ConnectorValidationError(
                "At least one content type must be enabled. "
                "Please check either 'Include Site Documents' or 'Include Site Pages' (or both)."
            )

        # Ensure sites are full SharePoint URLs
        for site_url in self.sites:
            if not site_url.startswith("https://") or not (
                "/sites/" in site_url or "/teams/" in site_url
            ):
                raise ConnectorValidationError(
                    "Site URLs must be full SharePoint URLs (e.g. https://your-tenant.sharepoint.com/sites/your-site or https://your-tenant.sharepoint.com/teams/your-team)"
                )

    def _extract_tenant_domain_from_sites(self) -> str | None:
        """Extract the tenant domain from configured site URLs.

        Site URLs look like https://{tenant}.sharepoint.com/sites/... so the
        tenant domain is the first label of the hostname.
        """
        for site_url in self.sites:
            try:
                hostname = urlsplit(site_url.strip()).hostname
            except ValueError:
                continue
            if not hostname:
                continue
            tenant = hostname.split(".")[0]
            if tenant:
                return tenant
        logger.warning(f"No tenant domain found from {len(self.sites)} sites")
        return None

    def _resolve_tenant_domain_from_root_site(self) -> str:
        """Resolve the tenant domain via GET /v1.0/sites/root, which only requires
        Sites.Read.All (a permission the connector already needs)."""
        root_site = self.graph_client.sites.root.get().execute_query()
        hostname = root_site.site_collection.hostname
        if not hostname:
            raise ConnectorValidationError(
                "Could not determine tenant domain from root site"
            )
        tenant_domain = hostname.split(".")[0]
        logger.info(
            "Resolved tenant domain '%s' from root site hostname '%s'",
            tenant_domain,
            hostname,
        )
        return tenant_domain

    def _resolve_tenant_domain(self) -> str:
        """Determine the tenant domain, preferring site URLs over a Graph API
        call to avoid needing extra permissions."""
        from_sites = self._extract_tenant_domain_from_sites()
        if from_sites:
            logger.info(
                "Resolved tenant domain '%s' from site URLs",
                from_sites,
            )
            return from_sites

        logger.info("No site URLs available; resolving tenant domain from root site")
        return self._resolve_tenant_domain_from_root_site()

    @property
    def graph_client(self) -> GraphClient:
        if self._graph_client is None:
            raise ConnectorMissingCredentialError("Sharepoint")

        return self._graph_client

    def _create_rest_client_context(self, site_url: str) -> ClientContext:
        """Return a ClientContext for SharePoint REST API calls, with caching.

        The office365 library's ClientContext caches the access token from its
        first request and never re-invokes the token callback. We cache the
        context and recreate it when the site URL changes or after
        ``_REST_CTX_MAX_AGE_S``.
        On recreation we also call ``load_credentials`` to build a fresh MSAL
        app with an empty token cache, guaranteeing a brand-new token from
        Azure AD."""
        elapsed = time.monotonic() - self._cached_rest_ctx_created_at
        if (
            self._cached_rest_ctx is not None
            and self._cached_rest_ctx_url == site_url
            and elapsed <= _REST_CTX_MAX_AGE_S
        ):
            return self._cached_rest_ctx

        if self._credential_json:
            logger.info(
                "Rebuilding SharePoint REST client context (elapsed=%.0fs, site_changed=%s)",
                elapsed,
                self._cached_rest_ctx_url != site_url,
            )
            self.load_credentials(self._credential_json)

        if not self.msal_app or not self.sp_tenant_domain:
            raise RuntimeError("MSAL app or tenant domain is not set")

        msal_app = self.msal_app
        sp_tenant_domain = self.sp_tenant_domain
        sp_domain_suffix = self.sharepoint_domain_suffix
        self._cached_rest_ctx = ClientContext(site_url).with_access_token(
            lambda: acquire_token_for_rest(msal_app, sp_tenant_domain, sp_domain_suffix)
        )
        self._cached_rest_ctx_url = site_url
        self._cached_rest_ctx_created_at = time.monotonic()
        return self._cached_rest_ctx

    @staticmethod
    def _strip_share_link_tokens(path: str) -> list[str]:
        # Share links often include a token prefix like /:f:/r/ or /:x:/r/.
        segments = [segment for segment in path.split("/") if segment]
        if segments and segments[0].startswith(":"):
            segments = segments[1:]
            if segments and segments[0] in {"r", "s", "g"}:
                segments = segments[1:]
        return segments

    @staticmethod
    def _normalize_sharepoint_url(url: str) -> tuple[str | None, list[str]]:
        try:
            parsed = urlsplit(url)
        except ValueError:
            logger.warning(f"Sharepoint URL '{url}' could not be parsed")
            return None, []

        if not parsed.scheme or not parsed.netloc:
            logger.warning(
                f"Sharepoint URL '{url}' is not a valid absolute URL (missing scheme or host)"
            )
            return None, []

        path_segments = SharepointConnector._strip_share_link_tokens(parsed.path)
        return f"{parsed.scheme}://{parsed.netloc}", path_segments

    @staticmethod
    def _extract_site_and_drive_info(site_urls: list[str]) -> list[SiteDescriptor]:
        site_data_list = []
        for url in site_urls:
            base_url, parts = SharepointConnector._normalize_sharepoint_url(url.strip())
            if base_url is None:
                continue

            lower_parts = [part.lower() for part in parts]
            site_type_index = None
            for site_token in ("sites", "teams"):
                if site_token in lower_parts:
                    site_type_index = lower_parts.index(site_token)
                    break

            if site_type_index is None or len(parts) <= site_type_index + 1:
                logger.warning(
                    f"Site URL '{url}' is not a valid Sharepoint URL (must contain /sites/ or /teams/)"
                )
                continue

            site_path = parts[: site_type_index + 2]
            remaining_parts = parts[site_type_index + 2 :]
            site_url = f"{base_url}/" + "/".join(site_path)

            # Extract drive name and folder path
            if remaining_parts:
                drive_name = unquote(remaining_parts[0])
                folder_path = (
                    "/".join(unquote(part) for part in remaining_parts[1:])
                    if len(remaining_parts) > 1
                    else None
                )
            else:
                drive_name = None
                folder_path = None

            site_data_list.append(
                SiteDescriptor(
                    url=site_url,
                    drive_name=drive_name,
                    folder_path=folder_path,
                )
            )
        return site_data_list

    def _resolve_drive(
        self,
        site_descriptor: SiteDescriptor,
        drive_name: str,
    ) -> tuple[str, str | None] | None:
        """Find the drive ID and web_url for a
        given drive name on a site.

        Returns (drive_id, drive_web_url) or None if the drive was not found.
        Raises on auth/permission errors so callers can propagate them.
        """
        site = self.graph_client.sites.get_by_url(site_descriptor.url)
        drives = site.drives.get().execute_query()
        logger.info(f"Found drives: {[d.name for d in drives]}")

        matched = [
            d
            for d in drives
            if (d.name and d.name.lower() == drive_name.lower())
            or (
                d.name in SHARED_DOCUMENTS_MAP
                and SHARED_DOCUMENTS_MAP[d.name] == drive_name
            )
        ]
        if not matched:
            logger.warning(f"Drive '{drive_name}' not found")
            return None

        drive = matched[0]
        drive_web_url: str | None = drive.web_url
        logger.info(f"Found drive: {drive.name} (web_url: {drive_web_url})")
        return cast(str, drive.id), drive_web_url

    def _get_drive_items_for_drive_id(
        self,
        site_descriptor: SiteDescriptor,
        drive_id: str,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> Generator[DriveItemData, None, None]:
        """Yield drive items lazily for a given drive ID.

        Uses the delta API for whole-drive enumeration (flat, incremental via
        timestamp token) and falls back to BFS /children traversal when a
        folder_path is configured, since delta cannot scope to a subtree
        efficiently.

        Returns:
            A generator of DriveItemData. The generator paginates through the
            Graph API so items are never all held in memory at once.
        """
        try:
            if site_descriptor.folder_path:
                yield from self._iter_drive_items_paged(
                    drive_id=drive_id,
                    folder_path=site_descriptor.folder_path,
                    start=start,
                    end=end,
                )
            else:
                yield from self._iter_drive_items_delta(
                    drive_id=drive_id,
                    start=start,
                    end=end,
                )
        except Exception as e:
            err_str = str(e)
            if (
                "403 Client Error" in err_str
                or "404 Client Error" in err_str
                or "invalid_client" in err_str
            ):
                raise e
            logger.warning(f"Failed to process site: {site_descriptor.url} - {err_str}")

    def _fetch_driveitems(
        self,
        site_descriptor: SiteDescriptor,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> Generator[tuple[DriveItemData, str, str | None], None, None]:
        """Yield drive items lazily for all drives in a site.

        Yields (DriveItemData, drive_name, drive_web_url) tuples one item at a time,
        paginating through the Graph API internally.
        """
        try:
            site = self.graph_client.sites.get_by_url(site_descriptor.url)
            drives = site.drives.get().execute_query()
            logger.debug(f"Found drives: {[d.name for d in drives]}")

            if site_descriptor.drive_name:
                drives = [
                    drive
                    for drive in drives
                    if drive.name == site_descriptor.drive_name
                    or (
                        drive.name in SHARED_DOCUMENTS_MAP
                        and SHARED_DOCUMENTS_MAP[drive.name] == site_descriptor.drive_name
                    )
                ]
                if not drives:
                    logger.warning(f"Drive '{site_descriptor.drive_name}' not found")
                    return

            for drive in drives:
                try:
                    drive_name = (
                        SHARED_DOCUMENTS_MAP[drive.name]
                        if drive.name in SHARED_DOCUMENTS_MAP
                        else cast(str, drive.name)
                    )
                    drive_web_url: str | None = drive.web_url

                    if site_descriptor.folder_path:
                        item_iter = self._iter_drive_items_paged(
                            drive_id=cast(str, drive.id),
                            folder_path=site_descriptor.folder_path,
                            start=start,
                            end=end,
                        )
                    else:
                        item_iter = self._iter_drive_items_delta(
                            drive_id=cast(str, drive.id),
                            start=start,
                            end=end,
                        )

                    for item in item_iter:
                        yield item, drive_name or "", drive_web_url
                except Exception as e:
                    logger.warning(f"Failed to process drive '{drive.name}': {str(e)}")

        except Exception as e:
            err_str = str(e)
            if (
                "403 Client Error" in err_str
                or "404 Client Error" in err_str
                or "invalid_client" in err_str
            ):
                raise e
            logger.warning(f"Failed to process site: {err_str}")

    def _handle_paginated_sites(
        self, sites: SitesWithRoot
    ) -> Generator[Site, None, None]:
        while sites:
            if sites.current_page:
                yield from sites.current_page
            if not sites.has_next:
                break
            sites = sites._get_next().execute_query()

    def _is_driveitem_excluded(self, driveitem: DriveItemData) -> bool:
        """Check if a drive item should be excluded based on excluded_paths patterns."""
        if not self.excluded_paths:
            return False
        relative_path = _build_item_relative_path(
            driveitem.parent_reference_path, driveitem.name
        )
        return _is_path_excluded(relative_path, self.excluded_paths)

    def _filter_excluded_sites(
        self, site_descriptors: list[SiteDescriptor]
    ) -> list[SiteDescriptor]:
        """Remove sites matching any excluded_sites glob pattern."""
        if not self.excluded_sites:
            return site_descriptors
        result = []
        for sd in site_descriptors:
            if _is_site_excluded(sd.url, self.excluded_sites):
                logger.info(f"Excluding site by denylist: {sd.url}")
                continue
            result.append(sd)
        return result

    def fetch_sites(self) -> list[SiteDescriptor]:
        sites = self.graph_client.sites.get_all_sites().execute_query()

        if not sites:
            raise RuntimeError("No sites found in the tenant")

        # Skip sites without a web URL. OneDrive personal sites ("-my.sharepoint")
        # should not be indexed with SharepointConnector.
        site_descriptors = [
            SiteDescriptor(
                url=site.web_url,
                drive_name=None,
                folder_path=None,
            )
            for site in self._handle_paginated_sites(sites)
            if site.web_url and "-my.sharepoint" not in site.web_url
        ]
        return self._filter_excluded_sites(site_descriptors)

    def _fetch_site_pages(
        self,
        site_descriptor: SiteDescriptor,
        start: datetime | None = None,
        end: datetime | None = None,
    ) -> Generator[dict[str, Any], None, None]:
        """Yield SharePoint site pages (.aspx files) one at a time.

        Pages are fetched via the Graph Pages API and yielded lazily as each
        API page arrives, so memory stays bounded regardless of total page count.
        Time-window filtering is applied per-item before yielding.
        """
        site = self.graph_client.sites.get_by_url(site_descriptor.url)
        site.execute_query()
        site_id = site.id

        site_pages_base = (
            f"{self.graph_api_base}/sites/{site_id}/pages/microsoft.graph.sitePage"
        )
        page_url: str | None = site_pages_base
        params: dict[str, str] | None = {"$expand": "canvasLayout"}
        total_yielded = 0
        yielded_ids: set[str] = set()

        while page_url:
            try:
                data = self._graph_api_get_json(page_url, params)
            except HTTPError as e:
                if e.response is not None and e.response.status_code == 404:
                    logger.warning(f"Site page not found: {page_url}")
                    break
                if (
                    e.response is not None
                    and e.response.status_code == 400
                    and _is_graph_invalid_request(e.response)
                ):
                    logger.warning(
                        f"$expand=canvasLayout on the LIST endpoint returned 400 "
                        f"for site {site_descriptor.url}. Falling back to "
                        f"per-page expansion."
                    )
                    yield from self._fetch_site_pages_individually(
                        site_pages_base, start, end, skip_ids=yielded_ids
                    )
                    return
                raise

            params = None  # nextLink already embeds query params

            for page in data.get("value", []):
                if not _site_page_in_time_window(page, start, end):
                    continue
                total_yielded += 1
                page_id = page.get("id")
                if page_id:
                    yielded_ids.add(page_id)
                yield page

            page_url = data.get("@odata.nextLink")

        logger.debug(f"Yielded {total_yielded} site pages for {site_descriptor.url}")

    def _fetch_site_pages_individually(
        self,
        site_pages_base: str,
        start: datetime | None = None,
        end: datetime | None = None,
        skip_ids: set[str] | None = None,
    ) -> Generator[dict[str, Any], None, None]:
        """Fallback for _fetch_site_pages: list pages without $expand, then
        expand canvasLayout on each page individually.

        The Graph API's LIST endpoint can return 400 when $expand=canvasLayout
        is used and *any* page in the site has a corrupt canvas layout (e.g.
        duplicate web part IDs, see SharePoint/sp-dev-docs#8822). Since the
        LIST expansion is all-or-nothing, a single bad page poisons the entire
        response.
This method works around it by fetching metadata first, then expanding each page individually so only the broken page loses its canvas content. ``skip_ids`` contains page IDs already yielded by the caller before the fallback was triggered, preventing duplicates. """ page_url: str | None = site_pages_base total_yielded = 0 _skip_ids = skip_ids or set() while page_url: try: data = self._graph_api_get_json(page_url) except HTTPError as e: if e.response is not None and e.response.status_code == 404: break raise for page in data.get("value", []): if not _site_page_in_time_window(page, start, end): continue page_id = page.get("id") if page_id and page_id in _skip_ids: continue if not page_id: total_yielded += 1 yield page continue expanded = self._try_expand_single_page(site_pages_base, page_id, page) total_yielded += 1 yield expanded page_url = data.get("@odata.nextLink") logger.debug( f"Yielded {total_yielded} site pages (per-page expansion fallback)" ) def _try_expand_single_page( self, site_pages_base: str, page_id: str, fallback_page: dict[str, Any], ) -> dict[str, Any]: """Try to GET a single page with $expand=canvasLayout. On 400, return the metadata-only fallback so the page is still indexed (without canvas content).""" pages_collection = site_pages_base.removesuffix("/microsoft.graph.sitePage") single_url = f"{pages_collection}/{page_id}/microsoft.graph.sitePage" try: return self._graph_api_get_json(single_url, {"$expand": "canvasLayout"}) except HTTPError as e: if ( e.response is not None and e.response.status_code == 400 and _is_graph_invalid_request(e.response) ): page_name = fallback_page.get("name", page_id) logger.warning( f"$expand=canvasLayout failed for page '{page_name}' ({page_id}). Indexing metadata only." 
                )
                return fallback_page
            raise

    def _acquire_token(self) -> dict[str, Any]:
        """
        Acquire token via MSAL
        """
        if self.msal_app is None:
            raise RuntimeError("MSAL app is not initialized")

        token = self.msal_app.acquire_token_for_client(
            scopes=[f"{self.graph_api_host}/.default"]
        )
        return token

    def _get_graph_access_token(self) -> str:
        token_data = self._acquire_token()
        access_token = token_data.get("access_token")
        if not access_token:
            raise RuntimeError("Failed to acquire Graph API access token")
        return access_token

    def _graph_api_get_json(
        self,
        url: str,
        params: dict[str, str] | None = None,
    ) -> dict[str, Any]:
        """Make an authenticated GET request to the Graph API with retry."""
        access_token = self._get_graph_access_token()
        headers = {"Authorization": f"Bearer {access_token}"}

        for attempt in range(GRAPH_API_MAX_RETRIES + 1):
            try:
                response = requests.get(
                    url,
                    headers=headers,
                    params=params,
                    timeout=REQUEST_TIMEOUT_SECONDS,
                )
                if response.status_code in GRAPH_API_RETRYABLE_STATUSES:
                    if attempt < GRAPH_API_MAX_RETRIES:
                        # Retry-After is normally delta-seconds, but HTTP also
                        # allows an HTTP-date; fall back to exponential backoff
                        # instead of letting int() raise ValueError.
                        try:
                            retry_after = int(
                                response.headers.get("Retry-After", str(2**attempt))
                            )
                        except ValueError:
                            retry_after = 2**attempt
                        wait = min(retry_after, 60)
                        logger.warning(
                            f"Graph API {response.status_code} on attempt {attempt + 1}, retrying in {wait}s: {url}"
                        )
                        time.sleep(wait)
                        # Re-acquire token in case it expired during a long traversal
                        access_token = self._get_graph_access_token()
                        headers = {"Authorization": f"Bearer {access_token}"}
                        continue
                _log_and_raise_for_status(response)
                return response.json()
            except (requests.ConnectionError, requests.Timeout):
                if attempt < GRAPH_API_MAX_RETRIES:
                    wait = min(2**attempt, 60)
                    logger.warning(
                        f"Graph API connection error on attempt {attempt + 1}, retrying in {wait}s: {url}"
                    )
                    time.sleep(wait)
                    continue
                raise

        raise RuntimeError(
            f"Graph API request failed after {GRAPH_API_MAX_RETRIES + 1} attempts: {url}"
        )

    def _iter_drive_items_paged(
        self,
        drive_id: str,
        folder_path: str | None = None,
        start: datetime | None = None,
        end: datetime | None = None,
        page_size: int = 200,
    ) ->
Generator[DriveItemData, None, None]:
        """Yield DriveItemData for every non-folder item in a drive via the Graph API.

        Performs BFS folder traversal manually, fetching one page of children
        at a time so that memory usage stays bounded regardless of drive size.
        """
        base = f"{self.graph_api_base}/drives/{drive_id}"
        if folder_path:
            encoded_path = quote(folder_path, safe="/")
            start_url = f"{base}/root:/{encoded_path}:/children"
        else:
            start_url = f"{base}/root/children"

        folder_queue: deque[str] = deque([start_url])

        while folder_queue:
            page_url: str | None = folder_queue.popleft()
            params: dict[str, str] | None = {"$top": str(page_size)}

            while page_url:
                data = self._graph_api_get_json(page_url, params)
                params = None  # nextLink already embeds query params

                for item in data.get("value", []):
                    if "folder" in item:
                        child_url = f"{base}/items/{item['id']}/children"
                        folder_queue.append(child_url)
                        continue

                    # Non-file items (e.g. OneNote notebooks without a "file"
                    # facet) are deliberately not skipped here; the downstream
                    # conversion filters by extension / mime type.
                    # Items without a lastModifiedDateTime are kept, and the
                    # time window is applied even when only one of start or
                    # end is set.
                    if start is not None or end is not None:
                        raw_ts = item.get("lastModifiedDateTime")
                        if raw_ts:
                            mod_dt = datetime.fromisoformat(
                                raw_ts.replace("Z", "+00:00")
                            )
                            if start is not None and mod_dt < start:
                                continue
                            if end is not None and mod_dt > end:
                                continue

                    yield DriveItemData.from_graph_json(item)

                page_url = data.get("@odata.nextLink")

    def _iter_drive_items_delta(
        self,
        drive_id: str,
        start: datetime | None = None,
        end: datetime | None = None,
        page_size: int = 200,
    ) -> Generator[DriveItemData, None, None]:
        """Yield DriveItemData for every file in a drive via the Graph delta API.

        Uses the flat delta endpoint instead of recursive folder traversal.
        On subsequent runs (start > epoch), passes the start timestamp as a
        delta token so that only changed items are returned.
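Here is a condensed sketch of the delta iteration, including the one-shot full resync when the token has expired. `fetch_page(url)` is a hypothetical callable. It raises a hypothetical `DeltaTokenExpired` where the real code sees an `HTTPError` with status 410. A second expiry during the resync propagates instead of looping forever. The URLs are placeholders.

```python
from collections.abc import Callable, Iterator
from typing import Any


class DeltaTokenExpired(Exception):
    # Hypothetical stand-in for an HTTP 410 Gone response
    pass


def iter_delta(
    fetch_page: Callable[[str], dict[str, Any]],
    initial_url: str,
    full_url: str,
    allow_full_resync: bool,
) -> Iterator[dict[str, Any]]:
    page_url: str | None = initial_url
    while page_url:
        try:
            data = fetch_page(page_url)
        except DeltaTokenExpired:
            if not allow_full_resync:
                raise
            # Restart once from a full enumeration; a second 410 propagates
            yield from iter_delta(fetch_page, full_url, full_url, False)
            return
        for item in data.get("value", []):
            # Folders and deletion tombstones are not indexable files
            if "folder" in item or "deleted" in item:
                continue
            yield item
        page_url = data.get("@odata.nextLink")


def fake_fetch(url: str) -> dict[str, Any]:
    if "token=" in url:
        raise DeltaTokenExpired(url)
    return {"value": [{"id": "f1"}, {"id": "d1", "folder": {}}, {"id": "x", "deleted": {}}]}


items = list(
    iter_delta(
        fake_fetch,
        "https://graph.example/delta?token=old",
        "https://graph.example/delta",
        True,
    )
)
print([i["id"] for i in items])  # ['f1']
```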
Falls back to full enumeration if the API returns 410 Gone (expired token). """ use_timestamp_token = start is not None and start > _EPOCH initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta" if use_timestamp_token: assert start is not None # mypy token = quote(start.isoformat(timespec="seconds")) initial_url += f"?token={token}" yield from self._iter_delta_pages( initial_url=initial_url, drive_id=drive_id, start=start, end=end, page_size=page_size, allow_full_resync=use_timestamp_token, ) def _iter_delta_pages( self, initial_url: str, drive_id: str, start: datetime | None, end: datetime | None, page_size: int, allow_full_resync: bool, ) -> Generator[DriveItemData, None, None]: """Paginate through delta API responses, yielding file DriveItemData. If the API responds with 410 Gone and allow_full_resync is True, restarts with a full delta enumeration. """ page_url: str | None = initial_url params: dict[str, str] | None = {"$top": str(page_size)} while page_url: try: data = self._graph_api_get_json(page_url, params) except requests.HTTPError as e: # 410 means the delta token expired, so we need to fall back to full enumeration if e.response is not None and e.response.status_code == 410: if not allow_full_resync: raise logger.warning( "Delta token expired (410 Gone) for drive '%s'. 
Falling back to full delta enumeration.", drive_id, ) yield from self._iter_delta_pages( initial_url=f"{self.graph_api_base}/drives/{drive_id}/root/delta", drive_id=drive_id, start=start, end=end, page_size=page_size, allow_full_resync=False, ) return raise params = None # nextLink/deltaLink already embed query params for item in data.get("value", []): if "folder" in item or "deleted" in item: continue if start is not None or end is not None: raw_ts = item.get("lastModifiedDateTime") if raw_ts: mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00")) if start is not None and mod_dt < start: continue if end is not None and mod_dt > end: continue yield DriveItemData.from_graph_json(item) page_url = data.get("@odata.nextLink") if not page_url: break def _build_delta_start_url( self, drive_id: str, start: datetime | None = None, page_size: int = 200, ) -> str: """Build the initial delta API URL with query parameters embedded. Embeds ``$top`` (and optionally a timestamp ``token``) directly in the URL so that the returned string is fully self-contained and can be stored in a checkpoint without needing a separate params dict. """ base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta" params = [f"$top={page_size}"] if start is not None and start > _EPOCH: token = quote(start.isoformat(timespec="seconds")) params.append(f"token={token}") return f"{base_url}?{'&'.join(params)}" def _fetch_one_delta_page( self, page_url: str, drive_id: str, start: datetime | None = None, end: datetime | None = None, page_size: int = 200, ) -> tuple[list[DriveItemData], str | None]: """Fetch a single page of delta API results. Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when the delta enumeration is complete (deltaLink with no nextLink). On 410 Gone (expired token) returns ``([], full_resync_url)`` so the caller can store the resync URL in the checkpoint and retry on the next cycle. 
""" try: data = self._graph_api_get_json(page_url) except requests.HTTPError as e: if e.response is not None and e.response.status_code == 410: logger.warning( "Delta token expired (410 Gone) for drive '%s'. Will restart with full delta enumeration.", drive_id, ) full_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta?$top={page_size}" return [], full_url raise items: list[DriveItemData] = [] for item in data.get("value", []): if "folder" in item or "deleted" in item: continue if start is not None or end is not None: raw_ts = item.get("lastModifiedDateTime") if raw_ts: mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00")) if start is not None and mod_dt < start: continue if end is not None and mod_dt > end: continue items.append(DriveItemData.from_graph_json(item)) next_url = data.get("@odata.nextLink") if next_url: return items, next_url return items, None @staticmethod def _clear_drive_checkpoint_state( checkpoint: "SharepointConnectorCheckpoint", ) -> None: """Reset all drive-level fields in the checkpoint.""" checkpoint.current_drive_name = None checkpoint.current_drive_id = None checkpoint.current_drive_web_url = None checkpoint.current_drive_delta_next_link = None checkpoint.seen_document_ids.clear() def _fetch_slim_documents_from_sharepoint( self, start: datetime | None = None, end: datetime | None = None, ) -> GenerateSlimDocumentOutput: site_descriptors = self._filter_excluded_sites( self.site_descriptors or self.fetch_sites() ) # Create a temporary checkpoint for hierarchy node tracking temp_checkpoint = SharepointConnectorCheckpoint(has_more=True) # goes over all urls, converts them into SlimDocument objects and then yields them in batches doc_batch: list[SlimDocument | HierarchyNode] = [] for site_descriptor in site_descriptors: site_url = site_descriptor.url # Yield site hierarchy node using helper doc_batch.extend( self._yield_site_hierarchy_node(site_descriptor, temp_checkpoint) ) # Process site documents if flag is True if 
self.include_site_documents: for driveitem, drive_name, drive_web_url in self._fetch_driveitems( site_descriptor=site_descriptor, start=start, end=end, ): if self._is_driveitem_excluded(driveitem): logger.debug(f"Excluding by path denylist: {driveitem.web_url}") continue if drive_web_url: doc_batch.extend( self._yield_drive_hierarchy_node( site_url, drive_web_url, drive_name, temp_checkpoint ) ) folder_path = self._extract_folder_path_from_parent_reference( driveitem.parent_reference_path ) if folder_path and drive_web_url: doc_batch.extend( self._yield_folder_hierarchy_nodes( site_url, drive_web_url, drive_name, folder_path, temp_checkpoint, ) ) parent_hierarchy_url: str | None = None if drive_web_url: parent_hierarchy_url = self._get_parent_hierarchy_url( site_url, drive_web_url, drive_name, driveitem ) try: logger.debug(f"Processing: {driveitem.web_url}") ctx = self._create_rest_client_context(site_descriptor.url) doc_batch.append( _convert_driveitem_to_slim_document( driveitem, drive_name, ctx, self.graph_client, parent_hierarchy_raw_node_id=parent_hierarchy_url, treat_sharing_link_as_public=self.treat_sharing_link_as_public, ) ) except Exception as e: logger.warning(f"Failed to process driveitem: {str(e)}") if len(doc_batch) >= SLIM_BATCH_SIZE: yield doc_batch doc_batch = [] # Process site pages if flag is True if self.include_site_pages: site_pages = self._fetch_site_pages( site_descriptor, start=start, end=end ) for site_page in site_pages: logger.debug( f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}" ) ctx = self._create_rest_client_context(site_descriptor.url) doc_batch.append( _convert_sitepage_to_slim_document( site_page, ctx, self.graph_client, parent_hierarchy_raw_node_id=site_descriptor.url, treat_sharing_link_as_public=self.treat_sharing_link_as_public, ) ) if len(doc_batch) >= SLIM_BATCH_SIZE: yield doc_batch doc_batch = [] yield doc_batch def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] 
| None: self._credential_json = credentials auth_method = credentials.get( "authentication_method", SharepointAuthMethod.CLIENT_SECRET.value ) sp_client_id = credentials.get("sp_client_id") sp_client_secret = credentials.get("sp_client_secret") sp_directory_id = credentials.get("sp_directory_id") sp_private_key = credentials.get("sp_private_key") sp_certificate_password = credentials.get("sp_certificate_password") if not sp_client_id: raise ConnectorValidationError("Client ID is required") if not sp_directory_id: raise ConnectorValidationError("Directory (tenant) ID is required") authority_url = f"{self.authority_host}/{sp_directory_id}" if auth_method == SharepointAuthMethod.CERTIFICATE.value: logger.info("Using certificate authentication") if not sp_private_key or not sp_certificate_password: raise ConnectorValidationError( "Private key and certificate password are required for certificate authentication" ) pfx_data = base64.b64decode(sp_private_key) certificate_data = load_certificate_from_pfx( pfx_data, sp_certificate_password ) if certificate_data is None: raise RuntimeError("Failed to load certificate") logger.info(f"Creating MSAL app with authority url {authority_url}") self.msal_app = msal.ConfidentialClientApplication( authority=authority_url, client_id=sp_client_id, client_credential=certificate_data.model_dump(), ) elif auth_method == SharepointAuthMethod.CLIENT_SECRET.value: logger.info("Using client secret authentication") self.msal_app = msal.ConfidentialClientApplication( authority=authority_url, client_id=sp_client_id, client_credential=sp_client_secret, ) else: raise ConnectorValidationError( "Invalid authentication method or missing required credentials" ) def _acquire_token_for_graph() -> dict[str, Any]: """ Acquire token via MSAL """ if self.msal_app is None: raise ConnectorValidationError("MSAL app is not initialized") token = self.msal_app.acquire_token_for_client( scopes=[f"{self.graph_api_host}/.default"] ) if token is None: raise 
ConnectorValidationError("Failed to acquire token for graph") return token self._graph_client = GraphClient( _acquire_token_for_graph, environment=self._azure_environment ) if auth_method == SharepointAuthMethod.CERTIFICATE.value: self.sp_tenant_domain = self._resolve_tenant_domain() return None def _get_drive_names_for_site(self, site_url: str) -> list[str]: """Return all library/drive names for a given SharePoint site.""" try: site = self.graph_client.sites.get_by_url(site_url) drives = site.drives.get_all(page_loaded=lambda _: None).execute_query() drive_names: list[str] = [] for drive in drives: if drive.name is None: continue drive_names.append(drive.name) return drive_names except Exception as e: logger.warning(f"Failed to fetch drives for site '{site_url}': {e}") return [] def _build_folder_url( self, site_url: str, drive_name: str, folder_path: str ) -> str: """Build a URL for a folder to use as raw_node_id. NOTE: This constructs an approximate folder URL from components rather than fetching the actual webUrl from the API. The constructed URL may differ slightly from SharePoint's canonical webUrl (e.g., URL encoding differences), but it functions correctly as a unique identifier for hierarchy tracking. We avoid fetching folder metadata to minimize API calls. """ return f"{site_url}/{drive_name}/{folder_path}" def _extract_folder_path_from_parent_reference( self, parent_reference_path: str | None ) -> str | None: """Extract folder path from DriveItem's parentReference.path. Example input: "/drives/b!abc123/root:/Engineering/API" Example output: "Engineering/API" Returns None if the item is at the root of the drive. 
""" if not parent_reference_path: return None # Path format: /drives/{drive_id}/root:/folder/path if "root:/" in parent_reference_path: folder_path = parent_reference_path.split("root:/")[1] return folder_path if folder_path else None # Item is at drive root return None def _yield_site_hierarchy_node( self, site_descriptor: SiteDescriptor, checkpoint: SharepointConnectorCheckpoint, ) -> Generator[HierarchyNode, None, None]: """Yield a hierarchy node for a site if not already yielded. Uses site.web_url as the raw_node_id (exact URL from API). """ site_url = site_descriptor.url if site_url in checkpoint.seen_hierarchy_node_raw_ids: return checkpoint.seen_hierarchy_node_raw_ids.add(site_url) # Extract display name from URL (last path segment) display_name = site_url.rstrip("/").split("/")[-1] yield HierarchyNode( raw_node_id=site_url, raw_parent_id=None, # Parent is SOURCE display_name=display_name, link=site_url, node_type=HierarchyNodeType.SITE, ) def _yield_drive_hierarchy_node( self, site_url: str, drive_web_url: str, drive_name: str, checkpoint: SharepointConnectorCheckpoint, ) -> Generator[HierarchyNode, None, None]: """Yield a hierarchy node for a drive if not already yielded. Uses drive.web_url as the raw_node_id (exact URL from API). """ if drive_web_url in checkpoint.seen_hierarchy_node_raw_ids: return checkpoint.seen_hierarchy_node_raw_ids.add(drive_web_url) yield HierarchyNode( raw_node_id=drive_web_url, raw_parent_id=site_url, # Site URL is parent display_name=drive_name, link=drive_web_url, node_type=HierarchyNodeType.DRIVE, ) def _yield_folder_hierarchy_nodes( self, site_url: str, drive_web_url: str, drive_name: str, folder_path: str, checkpoint: SharepointConnectorCheckpoint, ) -> Generator[HierarchyNode, None, None]: """Yield hierarchy nodes for all folders in a path. For path "Engineering/API/v2", yields nodes for: 1. "Engineering" (parent = drive) 2. "Engineering/API" (parent = "Engineering") 3. 
"Engineering/API/v2" (parent = "Engineering/API") Nodes are yielded in parent-to-child order. Uses constructed URLs as raw_node_id. See _build_folder_url for details on why we construct URLs rather than fetching them from the API. """ if not folder_path: return path_parts = folder_path.split("/") for i, part in enumerate(path_parts): current_path = "/".join(path_parts[: i + 1]) folder_url = self._build_folder_url(site_url, drive_name, current_path) if folder_url in checkpoint.seen_hierarchy_node_raw_ids: continue checkpoint.seen_hierarchy_node_raw_ids.add(folder_url) # Determine parent URL if i == 0: # First folder, parent is the drive parent_url = drive_web_url else: # Parent is the previous folder parent_path = "/".join(path_parts[:i]) parent_url = self._build_folder_url(site_url, drive_name, parent_path) yield HierarchyNode( raw_node_id=folder_url, raw_parent_id=parent_url, display_name=part, # Just the folder name link=folder_url, node_type=HierarchyNodeType.FOLDER, ) def _get_parent_hierarchy_url( self, site_url: str, drive_web_url: str, drive_name: str, driveitem: DriveItemData, ) -> str: """Determine the parent hierarchy node URL for a document. 
Returns: - Folder URL if document is in a folder - Drive URL if document is at drive root """ folder_path = self._extract_folder_path_from_parent_reference( driveitem.parent_reference_path ) if folder_path: return self._build_folder_url(site_url, drive_name, folder_path) # Document is at drive root return drive_web_url def _load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: SharepointConnectorCheckpoint, include_permissions: bool = False, ) -> CheckpointOutput[SharepointConnectorCheckpoint]: if self._graph_client is None: raise ConnectorMissingCredentialError("Sharepoint") checkpoint = copy.deepcopy(checkpoint) # Phase 1: Initialize cached_site_descriptors if needed if ( checkpoint.has_more and checkpoint.cached_site_descriptors is None and not checkpoint.process_site_pages ): logger.info("Initializing SharePoint sites for processing") site_descs = self._filter_excluded_sites( self.site_descriptors or self.fetch_sites() ) checkpoint.cached_site_descriptors = deque(site_descs) if not checkpoint.cached_site_descriptors: logger.warning( "No SharePoint sites found or accessible - nothing to process" ) checkpoint.has_more = False return checkpoint logger.info( f"Found {len(checkpoint.cached_site_descriptors)} sites to process" ) # Set first site and return to allow checkpoint persistence if checkpoint.cached_site_descriptors: checkpoint.current_site_descriptor = ( checkpoint.cached_site_descriptors.popleft() ) logger.info( f"Starting with site: {checkpoint.current_site_descriptor.url}" ) # Yield site hierarchy node for the first site yield from self._yield_site_hierarchy_node( checkpoint.current_site_descriptor, checkpoint ) return checkpoint # Phase 2: Initialize cached_drive_names for current site if needed if checkpoint.current_site_descriptor and checkpoint.cached_drive_names is None: # If site documents flag is False, set empty drive list to skip document processing if not self.include_site_documents: 
logger.debug("Documents disabled, skipping drive initialization") checkpoint.cached_drive_names = deque() return checkpoint logger.info( f"Initializing drives for site: {checkpoint.current_site_descriptor.url}" ) try: # If the user explicitly specified drive(s) for this site, honour that if checkpoint.current_site_descriptor.drive_name: logger.info( f"Using explicitly specified drive: {checkpoint.current_site_descriptor.drive_name}" ) checkpoint.cached_drive_names = deque( [checkpoint.current_site_descriptor.drive_name] ) else: drive_names = self._get_drive_names_for_site( checkpoint.current_site_descriptor.url ) checkpoint.cached_drive_names = deque(drive_names) if not checkpoint.cached_drive_names: logger.warning( f"No accessible drives found for site: {checkpoint.current_site_descriptor.url}" ) else: logger.info( f"Found {len(checkpoint.cached_drive_names)} drives: {list(checkpoint.cached_drive_names)}" ) except Exception as e: logger.error( f"Failed to initialize drives for site: {checkpoint.current_site_descriptor.url}: {e}" ) # Yield a ConnectorFailure for site-level access failures start_dt = datetime.fromtimestamp(start, tz=timezone.utc) end_dt = datetime.fromtimestamp(end, tz=timezone.utc) yield _create_entity_failure( checkpoint.current_site_descriptor.url, f"Failed to access site: {str(e)}", (start_dt, end_dt), e, ) # Move to next site if available if ( checkpoint.cached_site_descriptors and len(checkpoint.cached_site_descriptors) > 0 ): checkpoint.current_site_descriptor = ( checkpoint.cached_site_descriptors.popleft() ) checkpoint.cached_drive_names = None # Reset for new site return checkpoint else: # No more sites - we're done checkpoint.has_more = False return checkpoint # Return checkpoint to allow persistence after drive initialization return checkpoint # Phase 3a: Initialize the next drive for processing if ( checkpoint.current_site_descriptor and checkpoint.cached_drive_names and len(checkpoint.cached_drive_names) > 0 and 
checkpoint.current_drive_name is None ): checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft() start_dt = datetime.fromtimestamp(start, tz=timezone.utc) end_dt = datetime.fromtimestamp(end, tz=timezone.utc) site_descriptor = checkpoint.current_site_descriptor logger.info( f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}" ) logger.debug(f"Time range: {start_dt} to {end_dt}") current_drive_name = checkpoint.current_drive_name if current_drive_name is None: logger.warning("Current drive name is None, skipping") return checkpoint try: logger.info( f"Fetching drive items for drive name: {current_drive_name}" ) result = self._resolve_drive(site_descriptor, current_drive_name) if result is None: logger.warning(f"Drive '{current_drive_name}' not found, skipping") self._clear_drive_checkpoint_state(checkpoint) return checkpoint drive_id, drive_web_url = result checkpoint.current_drive_id = drive_id checkpoint.current_drive_web_url = drive_web_url except Exception as e: logger.error( f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}" ) yield _create_entity_failure( f"{site_descriptor.url}|{current_drive_name}", f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}", (start_dt, end_dt), e, ) self._clear_drive_checkpoint_state(checkpoint) return checkpoint display_drive_name = SHARED_DOCUMENTS_MAP.get( current_drive_name, current_drive_name ) if drive_web_url: yield from self._yield_drive_hierarchy_node( site_descriptor.url, drive_web_url, display_drive_name, checkpoint, ) # For non-folder-scoped drives, use delta API with per-page # checkpointing. Build the initial URL and fall through to 3b. if not site_descriptor.folder_path: checkpoint.current_drive_delta_next_link = self._build_delta_start_url( drive_id, start_dt ) # else: BFS path — delta_next_link stays None; # Phase 3b will use _iter_drive_items_paged. 
# Phase 3b: Process items from the current drive if ( checkpoint.current_site_descriptor and checkpoint.current_drive_name is not None and checkpoint.current_drive_id is not None ): site_descriptor = checkpoint.current_site_descriptor start_dt = datetime.fromtimestamp(start, tz=timezone.utc) end_dt = datetime.fromtimestamp(end, tz=timezone.utc) current_drive_name = SHARED_DOCUMENTS_MAP.get( checkpoint.current_drive_name, checkpoint.current_drive_name ) drive_web_url = checkpoint.current_drive_web_url # --- determine item source --- driveitems: Iterable[DriveItemData] has_more_delta_pages = False if checkpoint.current_drive_delta_next_link: # Delta path: fetch one page at a time for checkpointing try: page_items, next_url = self._fetch_one_delta_page( page_url=checkpoint.current_drive_delta_next_link, drive_id=checkpoint.current_drive_id, start=start_dt, end=end_dt, ) except Exception as e: logger.error( f"Failed to fetch delta page for drive '{current_drive_name}': {e}" ) yield _create_entity_failure( f"{site_descriptor.url}|{current_drive_name}", f"Failed to fetch delta page for drive '{current_drive_name}': {str(e)}", (start_dt, end_dt), e, ) self._clear_drive_checkpoint_state(checkpoint) return checkpoint driveitems = page_items has_more_delta_pages = next_url is not None if next_url: checkpoint.current_drive_delta_next_link = next_url else: # BFS path (folder-scoped): process all items at once driveitems = self._iter_drive_items_paged( drive_id=checkpoint.current_drive_id, folder_path=site_descriptor.folder_path, start=start_dt, end=end_dt, ) item_count = 0 for driveitem in driveitems: item_count += 1 if self._is_driveitem_excluded(driveitem): logger.debug(f"Excluding by path denylist: {driveitem.web_url}") continue if driveitem.id and driveitem.id in checkpoint.seen_document_ids: logger.debug( f"Skipping duplicate document {driveitem.id} ({driveitem.name})" ) continue driveitem_extension = get_file_ext(driveitem.name) if driveitem_extension not in 
OnyxFileExtensions.ALL_ALLOWED_EXTENSIONS: logger.warning( f"Skipping {driveitem.web_url} as it is not a supported file type" ) continue should_yield_if_empty = ( driveitem_extension in OnyxFileExtensions.IMAGE_EXTENSIONS or driveitem_extension == ".pdf" ) folder_path = self._extract_folder_path_from_parent_reference( driveitem.parent_reference_path ) if folder_path and drive_web_url: yield from self._yield_folder_hierarchy_nodes( site_descriptor.url, drive_web_url, current_drive_name, folder_path, checkpoint, ) parent_hierarchy_url: str | None = None if drive_web_url: parent_hierarchy_url = self._get_parent_hierarchy_url( site_descriptor.url, drive_web_url, current_drive_name, driveitem, ) try: ctx: ClientContext | None = None if include_permissions: ctx = self._create_rest_client_context(site_descriptor.url) access_token = self._get_graph_access_token() doc_or_failure = _convert_driveitem_to_document_with_permissions( driveitem, current_drive_name, ctx, self.graph_client, include_permissions=include_permissions, parent_hierarchy_raw_node_id=parent_hierarchy_url, graph_api_base=self.graph_api_base, access_token=access_token, treat_sharing_link_as_public=self.treat_sharing_link_as_public, ) if isinstance(doc_or_failure, Document): if doc_or_failure.sections: checkpoint.seen_document_ids.add(doc_or_failure.id) yield doc_or_failure elif should_yield_if_empty: doc_or_failure.sections = [ TextSection(link=driveitem.web_url, text="") ] checkpoint.seen_document_ids.add(doc_or_failure.id) yield doc_or_failure else: logger.warning( f"Skipping {driveitem.web_url} as it is empty and not a PDF or image" ) elif isinstance(doc_or_failure, ConnectorFailure): yield doc_or_failure except Exception as e: logger.warning( f"Failed to process driveitem {driveitem.web_url}: {e}" ) yield _create_document_failure( driveitem, f"Failed to process: {str(e)}", e ) logger.info(f"Processed {item_count} items in drive '{current_drive_name}'") if has_more_delta_pages: return checkpoint 
self._clear_drive_checkpoint_state(checkpoint) # Phase 4: Progression logic - determine next step # If we have more drives in current site, continue with current site if checkpoint.cached_drive_names and len(checkpoint.cached_drive_names) > 0: logger.debug( f"Continuing with {len(checkpoint.cached_drive_names)} remaining drives in current site" ) return checkpoint if ( self.include_site_pages and not checkpoint.process_site_pages and checkpoint.current_site_descriptor is not None ): logger.info( f"Processing site pages for site: {checkpoint.current_site_descriptor.url}" ) checkpoint.process_site_pages = True return checkpoint # Phase 5: Process site pages if ( checkpoint.process_site_pages and checkpoint.current_site_descriptor is not None ): # Fetch SharePoint site pages (.aspx files) site_descriptor = checkpoint.current_site_descriptor start_dt = datetime.fromtimestamp(start, tz=timezone.utc) end_dt = datetime.fromtimestamp(end, tz=timezone.utc) site_pages = self._fetch_site_pages( site_descriptor, start=start_dt, end=end_dt ) for site_page in site_pages: logger.debug( f"Processing site page: {site_page.get('webUrl', site_page.get('name', 'Unknown'))}" ) client_ctx: ClientContext | None = None if include_permissions: client_ctx = self._create_rest_client_context(site_descriptor.url) yield ( _convert_sitepage_to_document( site_page, site_descriptor.drive_name, client_ctx, self.graph_client, include_permissions=include_permissions, # Site pages have the site as their parent parent_hierarchy_raw_node_id=site_descriptor.url, treat_sharing_link_as_public=self.treat_sharing_link_as_public, ) ) logger.info( f"Finished processing site pages for site: {site_descriptor.url}" ) # If no more drives, move to next site if available if ( checkpoint.cached_site_descriptors and len(checkpoint.cached_site_descriptors) > 0 ): current_site = ( checkpoint.current_site_descriptor.url if checkpoint.current_site_descriptor else "unknown" ) checkpoint.current_site_descriptor = ( 
checkpoint.cached_site_descriptors.popleft() ) checkpoint.cached_drive_names = None # Reset for new site checkpoint.process_site_pages = False logger.info( f"Finished site '{current_site}', moving to next site: {checkpoint.current_site_descriptor.url}" ) logger.info( f"Remaining sites to process: {len(checkpoint.cached_site_descriptors) + 1}" ) # Yield site hierarchy node for the new site yield from self._yield_site_hierarchy_node( checkpoint.current_site_descriptor, checkpoint ) return checkpoint # No more sites or drives - we're done current_site = ( checkpoint.current_site_descriptor.url if checkpoint.current_site_descriptor else "unknown" ) logger.info( f"SharePoint processing complete. Finished last site: {current_site}" ) checkpoint.has_more = False return checkpoint def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: SharepointConnectorCheckpoint, ) -> CheckpointOutput[SharepointConnectorCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=False ) def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: SharepointConnectorCheckpoint, ) -> CheckpointOutput[SharepointConnectorCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=True ) def build_dummy_checkpoint(self) -> SharepointConnectorCheckpoint: return SharepointConnectorCheckpoint(has_more=True) def validate_checkpoint_json( self, checkpoint_json: str ) -> SharepointConnectorCheckpoint: return SharepointConnectorCheckpoint.model_validate_json(checkpoint_json) def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002 ) -> GenerateSlimDocumentOutput: start_dt = ( datetime.fromtimestamp(start, tz=timezone.utc) if start is not None else None ) end_dt = ( datetime.fromtimestamp(end, 
tz=timezone.utc) if end is not None else None ) yield from self._fetch_slim_documents_from_sharepoint( start=start_dt, end=end_dt, ) if __name__ == "__main__": from onyx.connectors.connector_runner import ConnectorRunner connector = SharepointConnector(sites=os.environ["SHAREPOINT_SITES"].split(",")) connector.load_credentials( { "sp_client_id": os.environ["SHAREPOINT_CLIENT_ID"], "sp_client_secret": os.environ["SHAREPOINT_CLIENT_SECRET"], "sp_directory_id": os.environ["SHAREPOINT_CLIENT_DIRECTORY_ID"], } ) # Create a time range from epoch to now end_time = datetime.now(timezone.utc) start_time = datetime.fromtimestamp(0, tz=timezone.utc) time_range = (start_time, end_time) # Initialize the runner with a batch size of 10 runner: ConnectorRunner[SharepointConnectorCheckpoint] = ConnectorRunner( connector, batch_size=10, include_permissions=False, time_range=time_range ) # Get initial checkpoint checkpoint = connector.build_dummy_checkpoint() # Run the connector while checkpoint.has_more: for doc_batch, hierarchy_node_batch, failure, next_checkpoint in runner.run( checkpoint ): if doc_batch: print(f"Retrieved batch of {len(doc_batch)} documents") for test_doc in doc_batch: print(f"Document: {test_doc.semantic_identifier}") if failure: print(f"Failure: {failure.failure_message}") if next_checkpoint: checkpoint = next_checkpoint ================================================ FILE: backend/onyx/connectors/sharepoint/connector_utils.py ================================================ from typing import Any from office365.graph_client import GraphClient # type: ignore[import-untyped] from office365.onedrive.driveitems.driveItem import DriveItem # type: ignore[import-untyped] from office365.sharepoint.client_context import ClientContext # type: ignore[import-untyped] from onyx.connectors.models import ExternalAccess from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) def get_sharepoint_external_access( ctx: ClientContext, 
graph_client: GraphClient, drive_item: DriveItem | None = None, drive_name: str | None = None, site_page: dict[str, Any] | None = None, add_prefix: bool = False, treat_sharing_link_as_public: bool = False, ) -> ExternalAccess: if drive_item and drive_item.id is None: raise ValueError("DriveItem ID is required") # Get external access using the EE implementation def noop_fallback( *args: Any, **kwargs: Any # noqa: ARG001 ) -> ExternalAccess: # noqa: ARG001 return ExternalAccess.empty() get_external_access_func = fetch_versioned_implementation_with_fallback( "onyx.external_permissions.sharepoint.permission_utils", "get_external_access_from_sharepoint", fallback=noop_fallback, ) external_access = get_external_access_func( ctx, graph_client, drive_name, drive_item, site_page, add_prefix, treat_sharing_link_as_public, ) return external_access ================================================ FILE: backend/onyx/connectors/slab/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/slab/connector.py ================================================ import json from collections.abc import Callable from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any from urllib.parse import urljoin import requests from dateutil import parser from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import ConnectorMissingCredentialError from 
onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.logger import setup_logger logger = setup_logger() # Fairly generous retry because it's not understood why occasionally GraphQL requests fail even with timeout > 1 min SLAB_GRAPHQL_MAX_TRIES = 10 SLAB_API_URL = "https://api.slab.com/v1/graphql" _SLIM_BATCH_SIZE = 1000 def run_graphql_request( graphql_query: dict, bot_token: str, max_tries: int = SLAB_GRAPHQL_MAX_TRIES ) -> str: headers = {"Authorization": bot_token, "Content-Type": "application/json"} for try_count in range(max_tries): try: response = requests.post( SLAB_API_URL, headers=headers, json=graphql_query, timeout=60 ) response.raise_for_status() if response.status_code != 200: raise ValueError(f"GraphQL query failed: {graphql_query}") return response.text except (requests.exceptions.Timeout, ValueError) as e: if try_count < max_tries - 1: logger.warning("A Slab GraphQL error occurred. Retrying...") continue if isinstance(e, requests.exceptions.Timeout): raise TimeoutError(f"Slab API timed out after {max_tries} attempts") from e else: raise ValueError(f"Slab GraphQL query failed after {max_tries} attempts") from e raise RuntimeError( "Unexpected execution from Slab Connector. This should not happen." ) # for static checker def get_all_post_ids(bot_token: str) -> list[str]: query = """ query GetAllPostIds { organization { posts { id } } } """ graphql_query = {"query": query} results = json.loads(run_graphql_request(graphql_query, bot_token)) posts = results["data"]["organization"]["posts"] return [post["id"] for post in posts] def get_post_by_id(post_id: str, bot_token: str) -> dict[str, str]: query = """ query GetPostById($postId: ID!)
{ post(id: $postId) { title content linkAccess updatedAt } } """ graphql_query = {"query": query, "variables": {"postId": post_id}} results = json.loads(run_graphql_request(graphql_query, bot_token)) return results["data"]["post"] def iterate_post_batches( batch_size: int, bot_token: str ) -> Generator[list[dict[str, str]], None, None]: """Paginate through posts via the search API. This may not be safe to use: it is unclear whether page edits change the order of results.""" query = """ query IteratePostBatches($query: String!, $first: Int, $types: [SearchType], $after: String) { search(query: $query, first: $first, types: $types, after: $after) { edges { node { ... on PostSearchResult { post { id title content updatedAt } } } } pageInfo { endCursor hasNextPage } } } """ pagination_start = None exists_more_pages = True while exists_more_pages: graphql_query = { "query": query, "variables": { "query": "", "first": batch_size, "types": ["POST"], "after": pagination_start, }, } results = json.loads(run_graphql_request(graphql_query, bot_token)) pagination_start = results["data"]["search"]["pageInfo"]["endCursor"] hits = results["data"]["search"]["edges"] posts = [hit["node"] for hit in hits] if posts: yield posts exists_more_pages = results["data"]["search"]["pageInfo"]["hasNextPage"] def get_slab_url_from_title_id(base_url: str, title: str, page_id: str) -> str: """Build a post URL from its title and ID. This is not a documented approach, but it seems to be how Slab URLs currently work; it may change without notice.""" title = ( title.replace("[", "") .replace("]", "") .replace(":", "") .replace(" ", "-") .lower() ) url_id = title + "-" + page_id return urljoin(urljoin(base_url, "posts/"), url_id) class SlabConnector(LoadConnector, PollConnector, SlimConnectorWithPermSync): def __init__( self, base_url: str, batch_size: int = INDEX_BATCH_SIZE, ) -> None: self.base_url = base_url self.batch_size = batch_size self._slab_bot_token: str | None = None def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self._slab_bot_token = credentials["slab_bot_token"] return None @property def slab_bot_token(self) -> str: if self._slab_bot_token is None: raise ConnectorMissingCredentialError("Slab") return self._slab_bot_token def _iterate_posts( self, time_filter: Callable[[datetime], bool] | None = None ) -> GenerateDocumentsOutput: doc_batch: list[Document | HierarchyNode] = [] if self.slab_bot_token is None: raise ConnectorMissingCredentialError("Slab") all_post_ids: list[str] = get_all_post_ids(self.slab_bot_token) for post_id in all_post_ids: post = get_post_by_id(post_id, self.slab_bot_token) last_modified = parser.parse(post["updatedAt"]) if time_filter is not None and not time_filter(last_modified): continue page_url = get_slab_url_from_title_id(self.base_url, post["title"], post_id) content_text = "" contents = json.loads(post["content"]) for content_segment in contents: insert = content_segment.get("insert") if insert and isinstance(insert, str): content_text += insert doc_batch.append( Document( id=post_id, # can't be url as this changes with the post title sections=[TextSection(link=page_url, text=content_text)], source=DocumentSource.SLAB, semantic_identifier=post["title"], metadata={}, ) ) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch def load_from_state(self) -> GenerateDocumentsOutput: yield from self._iterate_posts() def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: start_time = datetime.fromtimestamp(start, tz=timezone.utc) end_time = datetime.fromtimestamp(end, tz=timezone.utc) yield from self._iterate_posts( time_filter=lambda t: start_time <= t <= end_time ) def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002 ) -> GenerateSlimDocumentOutput: slim_doc_batch: 
list[SlimDocument | HierarchyNode] = [] for post_id in get_all_post_ids(self.slab_bot_token): slim_doc_batch.append( SlimDocument( id=post_id, ) ) if len(slim_doc_batch) >= _SLIM_BATCH_SIZE: yield slim_doc_batch slim_doc_batch = [] if slim_doc_batch: yield slim_doc_batch def validate_connector_settings(self) -> None: """ Very basic validation, we could do more here """ if not self.base_url.startswith("https://") and not self.base_url.startswith( "http://" ): raise ConnectorValidationError( "Base URL must start with https:// or http://" ) try: get_all_post_ids(self.slab_bot_token) except ConnectorMissingCredentialError: raise except Exception as e: raise ConnectorValidationError(f"Failed to fetch posts from Slab: {e}") ================================================ FILE: backend/onyx/connectors/slack/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/slack/access.py ================================================ from collections.abc import Callable from typing import cast from slack_sdk import WebClient from onyx.access.models import ExternalAccess from onyx.connectors.models import BasicExpertInfo from onyx.connectors.slack.models import ChannelType from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import global_version def get_channel_access( client: WebClient, channel: ChannelType, user_cache: dict[str, BasicExpertInfo | None], ) -> ExternalAccess | None: """ Get channel access permissions for a Slack channel. This functionality requires Enterprise Edition. Args: client: Slack WebClient instance channel: Slack channel object containing channel info user_cache: Cache of user IDs to BasicExpertInfo objects. May be updated in place. Returns: ExternalAccess object for the channel. None if EE is not enabled. 
""" # Check if EE is enabled if not global_version.is_ee_version(): return None # Fetch the EE implementation ee_get_channel_access = cast( Callable[ [WebClient, ChannelType, dict[str, BasicExpertInfo | None]], ExternalAccess, ], fetch_versioned_implementation( "onyx.external_permissions.slack.channel_access", "get_channel_access" ), ) return ee_get_channel_access(client, channel, user_cache) ================================================ FILE: backend/onyx/connectors/slack/connector.py ================================================ import contextvars import copy import itertools import re from collections.abc import Callable from collections.abc import Generator from concurrent.futures import as_completed from concurrent.futures import Future from concurrent.futures import ThreadPoolExecutor from datetime import datetime from datetime import timezone from enum import Enum from http.client import IncompleteRead from http.client import RemoteDisconnected from typing import Any from typing import cast from urllib.error import URLError from urllib.parse import urlparse from pydantic import BaseModel from redis import Redis from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.http_retry import ConnectionErrorRetryHandler from slack_sdk.http_retry import RetryHandler from slack_sdk.http_retry.builtin_interval_calculators import ( FixedValueRetryIntervalCalculator, ) from typing_extensions import override from onyx.access.models import ExternalAccess from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import SLACK_NUM_THREADS from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from 
onyx.connectors.interfaces import CheckpointedConnectorWithPermSync from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import CredentialsConnector from onyx.connectors.interfaces import CredentialsProviderInterface from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import NormalizationResult from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import EntityFailure from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.connectors.slack.access import get_channel_access from onyx.connectors.slack.models import ChannelType from onyx.connectors.slack.models import MessageType from onyx.connectors.slack.models import ThreadType from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient from onyx.connectors.slack.utils import ( expert_info_from_slack_id, ) from onyx.connectors.slack.utils import get_message_link from onyx.connectors.slack.utils import make_paginated_slack_api_call from onyx.connectors.slack.utils import SlackTextCleaner from onyx.db.enums import HierarchyNodeType from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.redis.redis_pool import get_redis_client from onyx.utils.logger import setup_logger logger = setup_logger() _SLACK_LIMIT = 900 class SlackCheckpoint(ConnectorCheckpoint): channel_ids: list[str] | None # e.g. 
C8E6WHE2X # channel id mapped to the timestamp we want to retrieve messages up to # NOTE: this is usually the earliest timestamp of all the messages we have # since we walk backwards channel_completion_map: dict[str, str] current_channel: ChannelType | None current_channel_access: ExternalAccess | None seen_thread_ts: list[ str ] # apparently we identify threads/messages uniquely by timestamp? def _collect_paginated_channels( client: WebClient, exclude_archived: bool, channel_types: list[str], ) -> list[ChannelType]: channels: list[ChannelType] = [] for result in make_paginated_slack_api_call( client.conversations_list, exclude_archived=exclude_archived, # also get private channels the bot is added to types=channel_types, ): channels.extend(result["channels"]) return channels def get_channels( client: WebClient, exclude_archived: bool = True, get_public: bool = True, get_private: bool = True, ) -> list[ChannelType]: """Get all channels in the workspace.""" channels: list[ChannelType] = [] channel_types = [] if get_public: channel_types.append("public_channel") if get_private: channel_types.append("private_channel") # Try fetching both public and private channels first: try: channels = _collect_paginated_channels( client=client, exclude_archived=exclude_archived, channel_types=channel_types, ) except SlackApiError as e: msg = f"Unable to fetch private channels due to: {e}." 
if not get_public: logger.warning(msg + " Public channels are not enabled.") return [] logger.warning(msg + " Trying again with public channels only.") channel_types = ["public_channel"] channels = _collect_paginated_channels( client=client, exclude_archived=exclude_archived, channel_types=channel_types, ) return channels def get_channel_messages( client: WebClient, channel: ChannelType, oldest: str | None = None, latest: str | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> Generator[list[MessageType], None, None]: """Get all messages in a channel""" # join so that the bot can access messages if not channel["is_member"]: client.conversations_join( channel=channel["id"], is_private=channel["is_private"], ) logger.info(f"Successfully joined '{channel['name']}'") for result in make_paginated_slack_api_call( client.conversations_history, channel=channel["id"], oldest=oldest, latest=latest, ): if callback: if callback.should_stop(): raise RuntimeError("get_channel_messages: Stop signal detected") callback.progress("get_channel_messages", 0) yield cast(list[MessageType], result["messages"]) def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType: """Get all messages in a thread""" threads: list[MessageType] = [] for result in make_paginated_slack_api_call( client.conversations_replies, channel=channel_id, ts=thread_id ): threads.extend(result["messages"]) return threads def get_latest_message_time(thread: ThreadType) -> datetime: max_ts = max([float(msg.get("ts", 0)) for msg in thread]) return datetime.fromtimestamp(max_ts, tz=timezone.utc) def _build_doc_id(channel_id: str, thread_ts: str) -> str: return f"{channel_id}__{thread_ts}" def thread_to_doc( channel: ChannelType, thread: ThreadType, slack_cleaner: SlackTextCleaner, client: WebClient, user_cache: dict[str, BasicExpertInfo | None], channel_access: ExternalAccess | None, ) -> Document: channel_id = channel["id"] initial_sender_expert_info = 
expert_info_from_slack_id( user_id=thread[0].get("user"), client=client, user_cache=user_cache ) initial_sender_name = ( initial_sender_expert_info.get_semantic_name() if initial_sender_expert_info else "Unknown" ) valid_experts = None if ENABLE_EXPENSIVE_EXPERT_CALLS: all_sender_ids = [m.get("user") for m in thread] experts = [ expert_info_from_slack_id( user_id=sender_id, client=client, user_cache=user_cache ) for sender_id in all_sender_ids if sender_id ] valid_experts = [expert for expert in experts if expert] first_message = slack_cleaner.index_clean(cast(str, thread[0]["text"])) snippet = ( first_message[:50].rstrip() + "..." if len(first_message) > 50 else first_message ) doc_sem_id = f"{initial_sender_name} in #{channel['name']}: {snippet}".replace( "\n", " " ) channel_name = channel["name"] return Document( id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]), sections=[ TextSection( link=get_message_link(event=m, client=client, channel_id=channel_id), text=slack_cleaner.index_clean(cast(str, m["text"])), ) for m in thread ], source=DocumentSource.SLACK, semantic_identifier=doc_sem_id, doc_updated_at=get_latest_message_time(thread), primary_owners=valid_experts, doc_metadata={ "hierarchy": { "source_path": [channel_name], "channel_name": channel_name, "channel_id": channel_id, } }, metadata={"Channel": channel_name}, external_access=channel_access, parent_hierarchy_raw_node_id=channel_id, ) # list of subtypes can be found here: https://api.slack.com/events/message _DISALLOWED_MSG_SUBTYPES = { "channel_join", "channel_leave", "channel_archive", "channel_unarchive", "pinned_item", "unpinned_item", "ekm_access_denied", "channel_posting_permissions", "group_join", "group_leave", "group_archive", "group_unarchive", "channel_name", } class SlackMessageFilterReason(str, Enum): BOT = "bot" DISALLOWED = "disallowed" def default_msg_filter(message: MessageType) -> SlackMessageFilterReason | None: """Returns a filter
reason if the message should be filtered out. Returns None if the message can be kept. """ # Don't keep messages from bots if message.get("bot_id") or message.get("app_id"): bot_profile_name = message.get("bot_profile", {}).get("name") if bot_profile_name == "DanswerBot Testing": return None return SlackMessageFilterReason.BOT # Uninformative if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES: return SlackMessageFilterReason.DISALLOWED return None def _bot_inclusive_msg_filter( message: MessageType, ) -> SlackMessageFilterReason | None: """Like default_msg_filter but allows bot/app messages through. Only filters out disallowed subtypes (channel_join, channel_leave, etc.). """ if message.get("subtype", "") in _DISALLOWED_MSG_SUBTYPES: return SlackMessageFilterReason.DISALLOWED return None def filter_channels( all_channels: list[ChannelType], channels_to_connect: list[str] | None, regex_enabled: bool, ) -> list[ChannelType]: if not channels_to_connect: return all_channels if regex_enabled: return [ channel for channel in all_channels if any( re.fullmatch(channel_to_connect, channel["name"]) for channel_to_connect in channels_to_connect ) ] # validate that all channels in `channels_to_connect` are valid # fail loudly in the case of an invalid channel so that the user # knows that one of the channels they've specified is typo'd or private all_channel_names = {channel["name"] for channel in all_channels} for channel in channels_to_connect: if channel not in all_channel_names: raise ValueError( f"Channel '{channel}' not found in workspace. 
" f"Available channels (Showing {min(len(all_channel_names), SlackConnector.MAX_CHANNELS_TO_LOG)} of " f"{len(all_channel_names)}): " f"{list(itertools.islice(all_channel_names, SlackConnector.MAX_CHANNELS_TO_LOG))}" ) return [ channel for channel in all_channels if channel["name"] in channels_to_connect ] def _channel_to_hierarchy_node( channel: ChannelType, channel_access: ExternalAccess | None, workspace_url: str | None = None, ) -> HierarchyNode: """Convert a Slack channel to a HierarchyNode. Args: channel: The Slack channel object channel_access: External access permissions for the channel workspace_url: The workspace URL (e.g., https://myworkspace.slack.com) Returns: A HierarchyNode representing the channel """ # Link format: https://{workspace}.slack.com/archives/{channel_id} link = f"{workspace_url}/archives/{channel['id']}" if workspace_url else None return HierarchyNode( raw_node_id=channel["id"], raw_parent_id=None, # Direct child of SOURCE display_name=f"#{channel['name']}", link=link, node_type=HierarchyNodeType.CHANNEL, external_access=channel_access, ) def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType: """Get a channel by its ID. Args: client: The Slack WebClient instance channel_id: The ID of the channel to fetch Returns: The channel information Raises: SlackApiError: If the channel cannot be fetched """ response = client.conversations_info( channel=channel_id, ) return cast(ChannelType, response["channel"]) def _get_messages( channel: ChannelType, client: WebClient, oldest: str | None = None, latest: str | None = None, limit: int = _SLACK_LIMIT, ) -> tuple[list[MessageType], bool]: """Fetch a single page of channel history. Slack returns messages from newest to oldest.""" # have to be in the channel in order to read messages if not channel["is_member"]: try: client.conversations_join( channel=channel["id"], is_private=channel["is_private"], ) except SlackApiError as e: if e.response["error"] == "is_archived": logger.warning(f"Channel {channel['name']} is archived.
Skipping.") return [], False logger.exception(f"Error joining channel {channel['name']}") raise logger.info(f"Successfully joined '{channel['name']}'") response = client.conversations_history( channel=channel["id"], oldest=oldest, latest=latest, limit=limit, ) response.validate() messages = cast(list[MessageType], response.get("messages", [])) cursor = cast(dict[str, Any], response.get("response_metadata", {})).get( "next_cursor", "" ) has_more = bool(cursor) return messages, has_more def _message_to_doc( message: MessageType, client: WebClient, channel: ChannelType, slack_cleaner: SlackTextCleaner, user_cache: dict[str, BasicExpertInfo | None], seen_thread_ts: set[str], channel_access: ExternalAccess | None, msg_filter_func: Callable[ [MessageType], SlackMessageFilterReason | None ] = default_msg_filter, ) -> tuple[Document | None, SlackMessageFilterReason | None]: """Returns a doc or None. If None is returned, the second element of the tuple may be a filter reason """ filtered_thread: ThreadType | None = None filter_reason: SlackMessageFilterReason | None = None thread_ts = message.get("thread_ts") if thread_ts: # NOTE: if thread_ts is present, there's a thread we need to process # ... 
otherwise, we can skip it # skip threads we've already seen, since we've already processed all # messages in that thread if thread_ts in seen_thread_ts: return None, None thread = get_thread( client=client, channel_id=channel["id"], thread_id=thread_ts ) # we'll just set and use the last filter reason if # we bomb out later filtered_thread = [] for thread_message in thread: filter_reason = msg_filter_func(thread_message) if filter_reason: continue filtered_thread.append(thread_message) else: filter_reason = msg_filter_func(message) if filter_reason: return None, filter_reason filtered_thread = [message] # we'll just set and use the last filter reason if we get an empty list if not filtered_thread: return None, filter_reason doc = thread_to_doc( channel=channel, thread=filtered_thread, slack_cleaner=slack_cleaner, client=client, user_cache=user_cache, channel_access=channel_access, ) return doc, None def _get_all_doc_ids( client: WebClient, channels: list[str] | None = None, channel_name_regex_enabled: bool = False, msg_filter_func: Callable[ [MessageType], SlackMessageFilterReason | None ] = default_msg_filter, callback: IndexingHeartbeatInterface | None = None, workspace_url: str | None = None, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, ) -> GenerateSlimDocumentOutput: """ Get all document ids in the workspace, channel by channel. This is nearly identical to get_all_docs, but it yields batches of SlimDocuments (plus one HierarchyNode per channel) instead of full documents. This makes it an order of magnitude faster than get_all_docs. """ all_channels = get_channels(client) filtered_channels = filter_channels( all_channels, channels, channel_name_regex_enabled ) user_cache: dict[str, BasicExpertInfo | None] = {} for channel in filtered_channels: channel_id = channel["id"] # NOTE: external_access is a frozen object, so it's safe to use a single # instance for all documents in the channel external_access = get_channel_access( client=client, channel=channel, user_cache=user_cache, ) # Yield
the channel as a HierarchyNode first (before any documents) yield [_channel_to_hierarchy_node(channel, external_access, workspace_url)] channel_message_batches = get_channel_messages( client=client, channel=channel, callback=callback, oldest=str(start) if start else None, # 0.0 -> None intentionally latest=str(end) if end is not None else None, ) for message_batch in channel_message_batches: slim_doc_batch: list[SlimDocument | HierarchyNode] = [] for message in message_batch: filter_reason = msg_filter_func(message) if filter_reason: continue # The document id is the channel id plus the ts of the first message in the thread. # Since we already have the first message of the thread, we don't have to # fetch the thread for id retrieval, saving time and API calls slim_doc_batch.append( SlimDocument( id=_build_doc_id( channel_id=channel_id, thread_ts=message["ts"] ), external_access=external_access, parent_hierarchy_raw_node_id=channel_id, ) ) yield slim_doc_batch class ProcessedSlackMessage(BaseModel): doc: Document | None # if the message is part of a thread, this is the thread_ts # otherwise, this is the message_ts. Either way, it is a unique identifier. # In the future, if the message becomes a thread, then the thread_ts # will be set to the message_ts.
thread_or_message_ts: str # if doc is None, filter_reason may be populated filter_reason: SlackMessageFilterReason | None failure: ConnectorFailure | None def _process_message( message: MessageType, client: WebClient, channel: ChannelType, slack_cleaner: SlackTextCleaner, user_cache: dict[str, BasicExpertInfo | None], seen_thread_ts: set[str], channel_access: ExternalAccess | None, msg_filter_func: Callable[ [MessageType], SlackMessageFilterReason | None ] = default_msg_filter, ) -> ProcessedSlackMessage: thread_ts = message.get("thread_ts") thread_or_message_ts = thread_ts or message["ts"] try: # causes random failures for testing checkpointing / continue on failure # import random # if random.random() > 0.95: # raise RuntimeError("Random failure :P") doc, filter_reason = _message_to_doc( message=message, client=client, channel=channel, slack_cleaner=slack_cleaner, user_cache=user_cache, seen_thread_ts=seen_thread_ts, channel_access=channel_access, msg_filter_func=msg_filter_func, ) return ProcessedSlackMessage( doc=doc, thread_or_message_ts=thread_or_message_ts, filter_reason=filter_reason, failure=None, ) except Exception as e: logger.exception(f"Error processing message {message['ts']}") return ProcessedSlackMessage( doc=None, thread_or_message_ts=thread_or_message_ts, filter_reason=None, failure=ConnectorFailure( failed_document=DocumentFailure( document_id=_build_doc_id( channel_id=channel["id"], thread_ts=thread_or_message_ts ), document_link=get_message_link(message, client, channel["id"]), ), failure_message=str(e), exception=e, ), ) class SlackConnector( SlimConnectorWithPermSync, CredentialsConnector, CheckpointedConnectorWithPermSync[SlackCheckpoint], ): FAST_TIMEOUT = 1 MAX_RETRIES = 7 # arbitrarily selected MAX_CHANNELS_TO_LOG = 50 # *** values to use when filtering bot channels *** # the number of messages in the batch must be greater than or equal to this number # to consider filtering the channel BOT_CHANNEL_MIN_BATCH_SIZE = 256 # the percentage of 
messages in the batch above which the channel will be considered # a bot channel BOT_CHANNEL_PERCENTAGE_THRESHOLD = 0.95 def __init__( self, channels: list[str] | None = None, # if specified, will treat the specified channel strings as # regexes, and will only index channels that fully match the regexes channel_regex_enabled: bool = False, # if True, messages from bots/apps will be indexed instead of filtered out include_bot_messages: bool = False, batch_size: int = INDEX_BATCH_SIZE, num_threads: int = SLACK_NUM_THREADS, use_redis: bool = True, ) -> None: self.channels = channels self.channel_regex_enabled = channel_regex_enabled self.include_bot_messages = include_bot_messages self.msg_filter_func = ( _bot_inclusive_msg_filter if include_bot_messages else default_msg_filter ) self.batch_size = batch_size self.num_threads = num_threads self.client: WebClient | None = None self.fast_client: WebClient | None = None # just used for efficiency self.text_cleaner: SlackTextCleaner | None = None self.user_cache: dict[str, BasicExpertInfo | None] = {} self.credentials_provider: CredentialsProviderInterface | None = None self.credential_prefix: str | None = None self.use_redis: bool = use_redis # Workspace URL for building channel links (e.g., https://myworkspace.slack.com) self._workspace_url: str | None = None # self.delay_lock: str | None = None # the redis key for the shared lock # self.delay_key: str | None = None # the redis key for the shared delay @classmethod @override def normalize_url(cls, url: str) -> NormalizationResult: """Normalize a Slack URL to extract channel_id__thread_ts format.""" parsed = urlparse(url) if "slack.com" not in parsed.netloc.lower(): return NormalizationResult(normalized_url=None, use_default=False) # Slack document IDs are format: channel_id__thread_ts # Extract from URL pattern: .../archives/{channel_id}/p{timestamp} path_parts = parsed.path.split("/") if "archives" not in path_parts: return NormalizationResult(normalized_url=None, 
use_default=False) archives_idx = path_parts.index("archives") if archives_idx + 1 >= len(path_parts): return NormalizationResult(normalized_url=None, use_default=False) channel_id = path_parts[archives_idx + 1] if archives_idx + 2 >= len(path_parts): return NormalizationResult(normalized_url=None, use_default=False) thread_part = path_parts[archives_idx + 2] if not thread_part.startswith("p"): return NormalizationResult(normalized_url=None, use_default=False) # Convert p1234567890123456 to 1234567890.123456 format timestamp_str = thread_part[1:] # Remove 'p' prefix if len(timestamp_str) == 16: # Insert dot at position 10 to match canonical format thread_ts = f"{timestamp_str[:10]}.{timestamp_str[10:]}" else: thread_ts = timestamp_str normalized = f"{channel_id}__{thread_ts}" return NormalizationResult(normalized_url=normalized, use_default=False) @staticmethod def make_credential_prefix(key: str) -> str: return f"connector:slack:credential_{key}" @staticmethod def make_delay_lock(prefix: str) -> str: return f"{prefix}:delay_lock" @staticmethod def make_delay_key(prefix: str) -> str: return f"{prefix}:delay" @staticmethod def make_slack_web_client( prefix: str, token: str, max_retry_count: int, r: Redis ) -> WebClient: delay_lock = SlackConnector.make_delay_lock(prefix) delay_key = SlackConnector.make_delay_key(prefix) # NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed # for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler. 
connection_error_retry_handler = ConnectionErrorRetryHandler( max_retry_count=max_retry_count, interval_calculator=FixedValueRetryIntervalCalculator(), error_types=[ URLError, ConnectionResetError, RemoteDisconnected, IncompleteRead, ], ) onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler( max_retry_count=max_retry_count, delay_key=delay_key, r=r, ) custom_retry_handlers: list[RetryHandler] = [ connection_error_retry_handler, onyx_rate_limit_error_retry_handler, ] client = OnyxSlackWebClient( delay_lock=delay_lock, delay_key=delay_key, r=r, token=token, retry_handlers=custom_retry_handlers, ) return client @property def channels(self) -> list[str] | None: return self._channels @channels.setter def channels(self, channels: list[str] | None) -> None: self._channels = ( [channel.removeprefix("#") for channel in channels] if channels else None ) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: raise NotImplementedError("Use set_credentials_provider with this connector.") def set_credentials_provider( self, credentials_provider: CredentialsProviderInterface ) -> None: credentials = credentials_provider.get_credentials() tenant_id = credentials_provider.get_tenant_id() if not tenant_id: raise ValueError("tenant_id cannot be None!") bot_token = credentials["slack_bot_token"] if self.use_redis: self.redis = get_redis_client(tenant_id=tenant_id) self.credential_prefix = SlackConnector.make_credential_prefix( credentials_provider.get_provider_key() ) self.client = SlackConnector.make_slack_web_client( self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis ) else: connection_error_retry_handler = ConnectionErrorRetryHandler( max_retry_count=self.MAX_RETRIES, interval_calculator=FixedValueRetryIntervalCalculator(), error_types=[ URLError, ConnectionResetError, RemoteDisconnected, IncompleteRead, ], ) self.client = WebClient( token=bot_token, retry_handlers=[connection_error_retry_handler] ) # use for requests that must 
return quickly (e.g. realtime flows where user is waiting) self.fast_client = WebClient( token=bot_token, timeout=SlackConnector.FAST_TIMEOUT ) self.text_cleaner = SlackTextCleaner(client=self.client) self.credentials_provider = credentials_provider # Extract workspace URL from auth_test response for building channel links try: auth_response = self.client.auth_test() self._workspace_url = auth_response.get("url") except Exception as e: logger.warning(f"Failed to get workspace URL from auth_test: {e}") self._workspace_url = None def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, callback: IndexingHeartbeatInterface | None = None, ) -> GenerateSlimDocumentOutput: if self.client is None: raise ConnectorMissingCredentialError("Slack") return _get_all_doc_ids( client=self.client, channels=self.channels, channel_name_regex_enabled=self.channel_regex_enabled, msg_filter_func=self.msg_filter_func, callback=callback, workspace_url=self._workspace_url, start=start, end=end, ) def _load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: SlackCheckpoint, include_permissions: bool = False, ) -> CheckpointOutput[SlackCheckpoint]: """Rough outline: Step 1: Get all channels, yield back Checkpoint. Step 2: Loop through each channel. For each channel: Step 2.1: Get messages within the time range. Step 2.2: Process messages in parallel, yield back docs. Step 2.3: Update checkpoint with new_oldest, seen_thread_ts, and current_channel. Slack returns messages from newest to oldest, so we need to keep track of the latest message we've seen in each channel. Step 2.4: If there are no more messages in the channel, switch the current channel to the next channel. 
""" num_channels_remaining = 0 if self.client is None or self.text_cleaner is None: raise ConnectorMissingCredentialError("Slack") checkpoint = cast(SlackCheckpoint, copy.deepcopy(checkpoint)) # if this is the very first time we've called this, need to # get all relevant channels and save them into the checkpoint if checkpoint.channel_ids is None: raw_channels = get_channels(self.client) filtered_channels = filter_channels( raw_channels, self.channels, self.channel_regex_enabled ) logger.info( f"Channels - initial checkpoint: all={len(raw_channels)} post_filtering={len(filtered_channels)}" ) checkpoint.channel_ids = [c["id"] for c in filtered_channels] if len(filtered_channels) == 0: checkpoint.has_more = False return checkpoint checkpoint.current_channel = filtered_channels[0] if include_permissions: # checkpoint.current_channel is guaranteed to be non-None here since we just assigned it assert checkpoint.current_channel is not None channel_access = get_channel_access( client=self.client, channel=checkpoint.current_channel, user_cache=self.user_cache, ) checkpoint.current_channel_access = channel_access checkpoint.has_more = True return checkpoint final_channel_ids = checkpoint.channel_ids for channel_id in final_channel_ids: if channel_id not in checkpoint.channel_completion_map: num_channels_remaining += 1 logger.info( f"Channels - current status: " f"processed={len(final_channel_ids) - num_channels_remaining} " f"remaining={num_channels_remaining} " f"total={len(final_channel_ids)}" ) channel = checkpoint.current_channel if channel is None: raise ValueError("current_channel key not set in checkpoint") channel_id = channel["id"] if channel_id not in final_channel_ids: raise ValueError(f"Channel {channel_id} not found in checkpoint") channel_created = channel["created"] seen_thread_ts = set(checkpoint.seen_thread_ts) try: num_bot_filtered_messages = 0 num_other_filtered_messages = 0 oldest = str(start) if start else None latest = str(end) channel_message_ts = 
checkpoint.channel_completion_map.get(channel_id) if channel_message_ts: # Set oldest to the checkpoint timestamp to resume from where we left off oldest = channel_message_ts else: # First time processing this channel - yield its hierarchy node yield _channel_to_hierarchy_node( channel, checkpoint.current_channel_access, self._workspace_url, ) logger.debug( f"Getting messages for channel {channel} within range {oldest} - {latest}" ) message_batch, has_more_in_channel = _get_messages( channel, self.client, oldest, latest ) logger.info( f"Retrieved messages: {len(message_batch)=} {channel=} {oldest=} {latest=}" ) # message_batch[0] is the newest message (Slack returns newest to oldest) new_oldest = message_batch[0]["ts"] if message_batch else latest num_threads_start = len(seen_thread_ts) # Process messages in parallel using ThreadPoolExecutor with ThreadPoolExecutor(max_workers=self.num_threads) as executor: # NOTE(rkuo): this seems to be assuming the slack sdk is thread safe. # That's a very bold assumption! Haven't seen a direct issue with this # yet, but likely not correct to rely on. futures: list[Future[ProcessedSlackMessage]] = [] for message in message_batch: # Capture the current context so that the thread gets the current tenant ID current_context = contextvars.copy_context() futures.append( executor.submit( current_context.run, _process_message, message=message, client=self.client, channel=channel, slack_cleaner=self.text_cleaner, user_cache=self.user_cache, seen_thread_ts=seen_thread_ts, channel_access=checkpoint.current_channel_access, msg_filter_func=self.msg_filter_func, ) ) for future in as_completed(futures): processed_slack_message = future.result() doc = processed_slack_message.doc thread_or_message_ts = processed_slack_message.thread_or_message_ts failure = processed_slack_message.failure if doc: # handle race conditions here since this is single # threaded. 
Multi-threaded _process_message reads from this # but since this is single threaded, we won't run into simultaneous # writes. At worst, we can duplicate a thread, which will be # deduped later on. if thread_or_message_ts not in seen_thread_ts: yield doc seen_thread_ts.add(thread_or_message_ts) elif processed_slack_message.filter_reason: if ( processed_slack_message.filter_reason == SlackMessageFilterReason.BOT ): num_bot_filtered_messages += 1 else: num_other_filtered_messages += 1 elif failure: yield failure num_threads_processed = len(seen_thread_ts) - num_threads_start # calculate a percentage progress for the current channel by determining # how much of the time range we've processed so far new_oldest_seconds_epoch = SecondsSinceUnixEpoch(new_oldest) range_start = start if start else max(0, channel_created) if new_oldest_seconds_epoch < range_start: range_complete = 0.0 else: range_complete = new_oldest_seconds_epoch - range_start range_total = end - range_start if range_total <= 0: range_total = 1 range_percent_complete = range_complete / range_total * 100.0 num_filtered = num_bot_filtered_messages + num_other_filtered_messages log_func = logger.warning if num_bot_filtered_messages > 0 else logger.info log_func( f"Message processing stats: " f"batch_len={len(message_batch)} " f"batch_yielded={num_threads_processed} " f"filtered={num_filtered} " f"(bot={num_bot_filtered_messages} other={num_other_filtered_messages}) " f"total_threads_seen={len(seen_thread_ts)}" ) logger.info( f"Current channel processing stats: {range_start=} range_end={end} percent_complete={range_percent_complete:.2f}" ) checkpoint.seen_thread_ts = list(seen_thread_ts) checkpoint.channel_completion_map[channel["id"]] = new_oldest # bypass channels where the first set of messages seen are almost all # filtered (bots + disallowed subtypes like channel_join) # check that more than BOT_CHANNEL_MIN_BATCH_SIZE messages are in the batch # we shouldn't skip based on a small sampling of messages if ( channel_message_ts
is None and len(message_batch) > SlackConnector.BOT_CHANNEL_MIN_BATCH_SIZE ): if ( num_filtered > SlackConnector.BOT_CHANNEL_PERCENTAGE_THRESHOLD * len(message_batch) ): logger.warning( "Bypassing this channel since it appears to be mostly bot messages" ) has_more_in_channel = False if not has_more_in_channel: num_channels_remaining -= 1 new_channel_id = next( ( channel_id for channel_id in final_channel_ids if channel_id not in checkpoint.channel_completion_map ), None, ) if new_channel_id: new_channel = _get_channel_by_id(self.client, new_channel_id) checkpoint.current_channel = new_channel if include_permissions: channel_access = get_channel_access( client=self.client, channel=new_channel, user_cache=self.user_cache, ) checkpoint.current_channel_access = channel_access else: checkpoint.current_channel = None checkpoint.has_more = checkpoint.current_channel is not None channels_processed = len(final_channel_ids) - num_channels_remaining channels_percent_complete = ( channels_processed / len(final_channel_ids) * 100.0 ) logger.info( f"All channels processing stats: " f"processed={len(final_channel_ids) - num_channels_remaining} " f"remaining={num_channels_remaining} " f"total={len(final_channel_ids)} " f"percent_complete={channels_percent_complete:.2f}" ) except Exception as e: logger.exception(f"Error processing channel {channel['name']}") yield ConnectorFailure( failed_entity=EntityFailure( entity_id=channel["id"], missed_time_range=( datetime.fromtimestamp(start, tz=timezone.utc), datetime.fromtimestamp(end, tz=timezone.utc), ), ), failure_message=str(e), exception=e, ) return checkpoint def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: SlackCheckpoint, ) -> CheckpointOutput[SlackCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=False ) def load_from_checkpoint_with_perm_sync( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: SlackCheckpoint, 
) -> CheckpointOutput[SlackCheckpoint]: return self._load_from_checkpoint( start, end, checkpoint, include_permissions=True ) def validate_connector_settings(self) -> None: """ 1. Verify the bot token is valid for the workspace (via auth_test). 2. Ensure the bot has enough scope to list channels. 3. (Currently disabled, see the NOTE below) Check that every channel specified in self.channels exists (only when regex is not enabled). """ if self.fast_client is None: raise ConnectorMissingCredentialError("Slack credentials not loaded.") try: # 1) Validate connection to workspace auth_response = self.fast_client.auth_test() if not auth_response.get("ok", False): error_msg = auth_response.get( "error", "Unknown error from Slack auth_test" ) raise ConnectorValidationError(f"Failed Slack auth_test: {error_msg}") # 2) Minimal test to confirm listing channels works test_resp = self.fast_client.conversations_list( limit=1, types=["public_channel"] ) if not test_resp.get("ok", False): error_msg = test_resp.get("error", "Unknown error from Slack") if error_msg == "invalid_auth": raise ConnectorValidationError( f"Invalid Slack bot token ({error_msg})." ) elif error_msg == "not_authed": raise CredentialExpiredError( f"Invalid or expired Slack bot token ({error_msg})."
) raise UnexpectedValidationError( f"Slack API returned a failure: {error_msg}" ) # 3) If channels are specified and regex is not enabled, verify each is accessible # NOTE: removed this for now since it may be too slow for large workspaces which may # have some automations which create a lot of channels (100k+) # if self.channels and not self.channel_regex_enabled: # accessible_channels = get_channels( # client=self.fast_client, # exclude_archived=True, # get_public=True, # get_private=True, # ) # # For quick lookups by name or ID, build a map: # accessible_channel_names = {ch["name"] for ch in accessible_channels} # accessible_channel_ids = {ch["id"] for ch in accessible_channels} # for user_channel in self.channels: # if ( # user_channel not in accessible_channel_names # and user_channel not in accessible_channel_ids # ): # raise ConnectorValidationError( # f"Channel '{user_channel}' not found or inaccessible in this workspace." # ) except SlackApiError as e: slack_error = e.response.get("error", "") if slack_error == "ratelimited": # Handle rate limiting specifically retry_after = int(e.response.headers.get("Retry-After", 1)) logger.warning( f"Slack API rate limited during validation. Retry suggested after {retry_after} seconds. " "Proceeding with validation, but be aware that connector operations might be throttled." ) # Continue validation without failing - the connector is likely valid but just rate limited return elif slack_error == "missing_scope": raise InsufficientPermissionsError( "Slack bot token lacks the necessary scope to list/access channels. " "Please ensure your Slack app has 'channels:read' (and/or 'groups:read' for private channels)." ) elif slack_error == "invalid_auth": raise CredentialExpiredError( f"Invalid Slack bot token ({slack_error})." ) elif slack_error == "not_authed": raise CredentialExpiredError( f"Invalid or expired Slack bot token ({slack_error})." 
) raise UnexpectedValidationError( f"Unexpected Slack error '{slack_error}' during settings validation." ) except ConnectorValidationError as e: raise e except Exception as e: raise UnexpectedValidationError( f"Unexpected error during Slack settings validation: {e}" ) @override def build_dummy_checkpoint(self) -> SlackCheckpoint: return SlackCheckpoint( channel_ids=None, channel_completion_map={}, current_channel=None, current_channel_access=None, seen_thread_ts=[], has_more=True, ) @override def validate_checkpoint_json(self, checkpoint_json: str) -> SlackCheckpoint: return SlackCheckpoint.model_validate_json(checkpoint_json) if __name__ == "__main__": import os import time from onyx.connectors.credentials_provider import OnyxStaticCredentialsProvider from shared_configs.contextvars import get_current_tenant_id slack_channel = os.environ.get("SLACK_CHANNEL") connector = SlackConnector( channels=[slack_channel] if slack_channel else None, ) provider = OnyxStaticCredentialsProvider( tenant_id=get_current_tenant_id(), connector_name="slack", credential_json={ "slack_bot_token": os.environ["SLACK_BOT_TOKEN"], }, ) connector.set_credentials_provider(provider) current = time.time() one_day_ago = current - 24 * 60 * 60 # 1 day checkpoint = connector.build_dummy_checkpoint() gen = connector.load_from_checkpoint( one_day_ago, current, cast(SlackCheckpoint, checkpoint), ) # a for loop swallows StopIteration, so drive the generator with next() # to capture its return value (the next checkpoint) try: while True: document_or_failure = next(gen) if isinstance(document_or_failure, Document): print(document_or_failure) elif isinstance(document_or_failure, ConnectorFailure): print(document_or_failure) except StopIteration as e: checkpoint = e.value print("Next checkpoint:", checkpoint) ================================================ FILE: backend/onyx/connectors/slack/models.py ================================================ from typing import NotRequired from typing_extensions import TypedDict class 
ChannelTopicPurposeType(TypedDict): """ Represents the topic or purpose of
a Slack channel. """ value: str creator: str last_set: int class ChannelType(TypedDict): """ Represents a Slack channel. """ id: str name: str is_channel: bool is_group: bool is_im: bool created: int creator: str is_archived: bool is_general: bool unlinked: int name_normalized: str is_shared: bool is_ext_shared: bool is_org_shared: bool pending_shared: list[str] is_pending_ext_shared: bool is_member: bool is_private: bool is_mpim: bool updated: int topic: ChannelTopicPurposeType purpose: ChannelTopicPurposeType previous_names: list[str] num_members: int class AttachmentType(TypedDict): """ Represents a Slack message attachment. """ service_name: NotRequired[str] text: NotRequired[str] fallback: NotRequired[str] thumb_url: NotRequired[str] thumb_width: NotRequired[int] thumb_height: NotRequired[int] id: NotRequired[int] class BotProfileType(TypedDict): """ Represents a Slack bot profile. """ id: NotRequired[str] deleted: NotRequired[bool] name: NotRequired[str] updated: NotRequired[int] app_id: NotRequired[str] team_id: NotRequired[str] class MessageType(TypedDict): """ Represents a Slack message. 
""" type: str user: str text: str ts: str attachments: NotRequired[list[AttachmentType]] # Bot-related fields bot_id: NotRequired[str] app_id: NotRequired[str] bot_profile: NotRequired[BotProfileType] # Message threading thread_ts: NotRequired[str] # Message subtype (for filtering certain message types) subtype: NotRequired[str] # list of messages in a thread ThreadType = list[MessageType] ================================================ FILE: backend/onyx/connectors/slack/onyx_retry_handler.py ================================================ import random from typing import cast from typing import Optional from redis import Redis from slack_sdk.http_retry.handler import RetryHandler from slack_sdk.http_retry.request import HttpRequest from slack_sdk.http_retry.response import HttpResponse from slack_sdk.http_retry.state import RetryState from onyx.utils.logger import setup_logger logger = setup_logger() class OnyxRedisSlackRetryHandler(RetryHandler): """ This class uses Redis to share a rate limit among multiple threads. As currently implemented, this code is already surrounded by a lock in Redis via an override of _perform_urllib_http_request in OnyxSlackWebClient. This just sets the desired retry delay with TTL in redis. In conjunction with a custom subclass of the client, the value is read and obeyed prior to an API call and also serialized. Another way to do this is just to do exponential backoff. Might be easier? Adapted from slack's RateLimitErrorRetryHandler. 
""" def __init__( self, max_retry_count: int, delay_key: str, r: Redis, ): """ max_retry_count: the maximum number of retries for rate limited (HTTP 429) responses delay_key: the redis key containing a shared TTL r: the Redis client used to read and extend that TTL """ super().__init__(max_retry_count=max_retry_count) self._redis: Redis = r self._delay_key = delay_key def _can_retry( self, *, state: RetryState, # noqa: ARG002 request: HttpRequest, # noqa: ARG002 response: Optional[HttpResponse] = None, error: Optional[Exception] = None, # noqa: ARG002 ) -> bool: return response is not None and response.status_code == 429 def prepare_for_next_attempt( self, *, state: RetryState, request: HttpRequest, # noqa: ARG002 response: Optional[HttpResponse] = None, error: Optional[Exception] = None, ) -> None: """As initially designed by the SDK authors, this function is responsible for the wait to retry ... aka we actually sleep in this function. This doesn't work well with multiple clients because every thread is unaware of the current retry value until it actually calls the endpoint. We're combining this with an actual subclass of the slack web client so that the delay is used BEFORE calling an API endpoint. The subclassed client has already taken the lock in redis when this method is called. """ ttl_ms: int | None = None retry_after_value: str | None = None retry_after_header_name: Optional[str] = None duration_s: float = 1.0 # seconds if response is None: # NOTE(rkuo): this logic comes from RateLimitErrorRetryHandler. # This reads oddly, as if the caller itself could raise the exception. # We don't have the luxury of changing this.
if error: raise error return state.next_attempt_requested = True # this signals the caller to retry # calculate wait duration based on retry-after + some jitter for k in response.headers.keys(): if k.lower() == "retry-after": retry_after_header_name = k break try: if retry_after_header_name is None: # This situation usually does not arise. Just in case. raise ValueError( "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header name is None" ) retry_after_header_value = response.headers.get(retry_after_header_name) if not retry_after_header_value: raise ValueError( "OnyxRedisSlackRetryHandler.prepare_for_next_attempt: retry-after header value is None" ) # Handle case where header value might be a list retry_after_value = ( retry_after_header_value[0] if isinstance(retry_after_header_value, list) else retry_after_header_value ) retry_after_value_int = int( retry_after_value ) # will raise ValueError if somehow we can't convert to int jitter = retry_after_value_int * 0.25 * random.random() duration_s = retry_after_value_int + jitter except ValueError: duration_s += random.random() # Read and extend the ttl ttl_ms = cast(int, self._redis.pttl(self._delay_key)) if ttl_ms < 0: # negative values are error status codes ... 
see docs ttl_ms = 0 ttl_ms_new = ttl_ms + int(duration_s * 1000.0) self._redis.set(self._delay_key, "1", px=ttl_ms_new) logger.warning( f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt setting delay: " f"current_attempt={state.current_attempt} " f"retry-after={retry_after_value} " f"{ttl_ms_new=}" ) state.increment_current_attempt() ================================================ FILE: backend/onyx/connectors/slack/onyx_slack_web_client.py ================================================ import threading import time from typing import Any from typing import cast from typing import Dict from urllib.request import Request from redis import Redis from redis.lock import Lock as RedisLock from slack_sdk import WebClient from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_BLOCKING_TIMEOUT from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TTL from onyx.utils.logger import setup_logger logger = setup_logger() class OnyxSlackWebClient(WebClient): """Use in combination with the Onyx Retry Handler. This client wrapper enforces a proper retry delay through Redis BEFORE the API call so that multiple clients can synchronize and rate limit properly. The retry handler writes the correct delay value to Redis so that it can be used by this wrapper.
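The key point is ordering: the wrapper takes a shared lock, sleeps for whatever delay is pending, and only then performs the request, so a 429 seen by one worker slows every worker down before its next call. The sketch below shows that ordering within a single process. `threading.Lock` and a monotonic deadline are hypothetical stand-ins for the Redis lock and the Redis TTL key, and `SharedRateGate`, `push_back`, and `call` are illustrative names only.

```python
import threading
import time
from collections.abc import Callable
from typing import TypeVar

T = TypeVar("T")


class SharedRateGate:
    """Hypothetical single-process analogue of OnyxSlackWebClient + retry handler."""

    def __init__(self) -> None:
        self._lock = threading.Lock()  # stands in for the Redis lock
        self._resume_at = 0.0  # stands in for the expiry of the Redis delay key
        self.num_requests = 0

    def push_back(self, seconds: float) -> None:
        # On a 429, extend the pending delay instead of sleeping here.
        # In the real client this runs while the request lock is held.
        base = max(self._resume_at, time.monotonic())
        self._resume_at = base + seconds

    def call(self, fn: Callable[[], T]) -> T:
        with self._lock:
            # honor any pending delay BEFORE making the request
            wait_s = self._resume_at - time.monotonic()
            if wait_s > 0:
                time.sleep(wait_s)
            result = fn()
            self.num_requests += 1
            return result
```

Because the delay is read under the same lock that serializes requests, no worker can slip a request in between another worker's 429 and the delay it records.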
""" def __init__( self, delay_lock: str, delay_key: str, r: Redis, *args: Any, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self._delay_key = delay_key self._delay_lock = delay_lock self._redis: Redis = r self.num_requests: int = 0 self._lock = threading.Lock() def _perform_urllib_http_request( self, *, url: str, args: Dict[str, Dict[str, Any]] ) -> Dict[str, Any]: """By locking around the base class method, we ensure that both the delay from Redis and parsing/writing of retry values to Redis are handled properly in one place""" # lock and extend the ttl lock: RedisLock = self._redis.lock( self._delay_lock, timeout=ONYX_SLACK_LOCK_TTL, ) # try to acquire the lock start = time.monotonic() while True: acquired = lock.acquire(blocking_timeout=ONYX_SLACK_LOCK_BLOCKING_TIMEOUT) if acquired: break # if we couldn't acquire the lock but it exists, there's at least some activity # so keep trying... if self._redis.exists(self._delay_lock): continue if time.monotonic() - start > ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT: raise RuntimeError( f"OnyxSlackWebClient._perform_urllib_http_request - " f"timed out waiting for lock: {ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT=}" ) try: result = super()._perform_urllib_http_request(url=url, args=args) finally: if lock.owned(): lock.release() else: logger.warning( "OnyxSlackWebClient._perform_urllib_http_request lock not owned on release" ) # elapsed = time.monotonic() - start # logger.info( # f"OnyxSlackWebClient._perform_urllib_http_request: Releasing lock: {elapsed=}" # ) return result def _perform_urllib_http_request_internal( self, url: str, req: Request, ) -> Dict[str, Any]: """Overrides the internal method which is mostly the direct call to urllib/urlopen ... so this is a good place to perform our delay.""" # read and execute the delay delay_ms = cast(int, self._redis.pttl(self._delay_key)) if delay_ms < 0: # negative values are error status codes ...
see docs delay_ms = 0 if delay_ms > 0: logger.warning( f"OnyxSlackWebClient._perform_urllib_http_request_internal delay: {delay_ms=} {self.num_requests=}" ) time.sleep(delay_ms / 1000.0) result = super()._perform_urllib_http_request_internal(url, req) with self._lock: self.num_requests += 1 # the delay key should have naturally expired by this point return result ================================================ FILE: backend/onyx/connectors/slack/utils.py ================================================ import re from collections.abc import Callable from collections.abc import Generator from functools import lru_cache from functools import wraps from typing import Any from typing import cast from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.web import SlackResponse from onyx.connectors.models import BasicExpertInfo from onyx.connectors.slack.models import MessageType from onyx.utils.logger import setup_logger from onyx.utils.retry_wrapper import retry_builder logger = setup_logger() # retry after 0.1, 1.2, 3.4, 7.8, 16.6, 34.2 seconds basic_retry_wrapper = retry_builder(tries=7) # number of messages we request per page when fetching paginated slack messages _SLACK_LIMIT = 900 # used to serialize access to the retry TTL ONYX_SLACK_LOCK_TTL = 1800 # how long the lock is allowed to idle before it expires ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock per wait attempt ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600 # how long to wait for the lock in total @lru_cache() def get_base_url(token: str) -> str: """Retrieve and cache the base URL of the Slack workspace based on the client token.""" client = WebClient(token=token) return client.auth_test()["url"] def get_message_link(event: MessageType, client: WebClient, channel_id: str) -> str: message_ts = event["ts"] message_ts_without_dot = message_ts.replace(".", "") thread_ts = event.get("thread_ts") base_url = get_base_url(client.token) link = 
f"{base_url.rstrip('/')}/archives/{channel_id}/p{message_ts_without_dot}" + ( f"?thread_ts={thread_ts}" if thread_ts else "" ) return link def make_slack_api_call( call: Callable[..., SlackResponse], **kwargs: Any ) -> SlackResponse: return call(**kwargs) def make_paginated_slack_api_call( call: Callable[..., SlackResponse], **kwargs: Any ) -> Generator[dict[str, Any], None, None]: return _make_slack_api_call_paginated(call)(**kwargs) def _make_slack_api_call_paginated( call: Callable[..., SlackResponse], ) -> Callable[..., Generator[dict[str, Any], None, None]]: """Wraps calls to slack API so that they automatically handle pagination""" @wraps(call) def paginated_call(**kwargs: Any) -> Generator[dict[str, Any], None, None]: cursor: str | None = None has_more = True while has_more: response = call(cursor=cursor, limit=_SLACK_LIMIT, **kwargs) yield cast(dict[str, Any], response.validate()) cursor = cast(dict[str, Any], response.get("response_metadata", {})).get( "next_cursor", "" ) has_more = bool(cursor) return paginated_call # NOTE(rkuo): we may not need this any more if the integrated retry handlers work as # expected. Do we want to keep this around? 
# def make_slack_api_rate_limited( # call: Callable[..., SlackResponse], max_retries: int = 7 # ) -> Callable[..., SlackResponse]: # """Wraps calls to slack API so that they automatically handle rate limiting""" # @wraps(call) # def rate_limited_call(**kwargs: Any) -> SlackResponse: # last_exception = None # for _ in range(max_retries): # try: # # Make the API call # response = call(**kwargs) # # Check for errors in the response, will raise `SlackApiError` # # if anything went wrong # response.validate() # return response # except SlackApiError as e: # last_exception = e # try: # error = e.response["error"] # except KeyError: # error = "unknown error" # if error == "ratelimited": # # Handle rate limiting: get the 'Retry-After' header value and sleep for that duration # retry_after = int(e.response.headers.get("Retry-After", 1)) # logger.info( # f"Slack call rate limited, retrying after {retry_after} seconds. Exception: {e}" # ) # time.sleep(retry_after) # elif error in ["already_reacted", "no_reaction", "internal_error"]: # # Log internal_error and return the response instead of failing # logger.warning( # f"Slack call encountered '{error}', skipping and continuing..." 
# ) # return e.response # else: # # Raise the error for non-transient errors # raise # # If the code reaches this point, all retries have been exhausted # msg = f"Max retries ({max_retries}) exceeded" # if last_exception: # raise Exception(msg) from last_exception # else: # raise Exception(msg) # return rate_limited_call # temporarily disabling due to using a different retry approach # might be permanent if everything works out # def make_slack_api_call_w_retries( # call: Callable[..., SlackResponse], **kwargs: Any # ) -> SlackResponse: # return basic_retry_wrapper(call)(**kwargs) # def make_paginated_slack_api_call_w_retries( # call: Callable[..., SlackResponse], **kwargs: Any # ) -> Generator[dict[str, Any], None, None]: # return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs) def expert_info_from_slack_id( user_id: str | None, client: WebClient, user_cache: dict[str, BasicExpertInfo | None], ) -> BasicExpertInfo | None: if not user_id: return None if user_id in user_cache: return user_cache[user_id] response = client.users_info(user=user_id) if not response["ok"]: user_cache[user_id] = None return None user: dict = cast(dict[Any, dict], response.data).get("user", {}) profile = user.get("profile", {}) expert = BasicExpertInfo( display_name=user.get("real_name") or profile.get("display_name"), first_name=profile.get("first_name"), last_name=profile.get("last_name"), email=profile.get("email"), ) user_cache[user_id] = expert return expert class SlackTextCleaner: """Utility class to replace user IDs with usernames in a message. 
    Handles caching, so the same request is not made multiple times
    for the same user ID"""

    def __init__(self, client: WebClient) -> None:
        self._client = client
        self._id_to_name_map: dict[str, str] = {}

    def _get_slack_name(self, user_id: str) -> str:
        if user_id not in self._id_to_name_map:
            try:
                response = self._client.users_info(user=user_id)
                # prefer display name if set, since that is what is shown in Slack
                self._id_to_name_map[user_id] = (
                    response["user"]["profile"]["display_name"]
                    or response["user"]["profile"]["real_name"]
                )
            except SlackApiError as e:
                logger.exception(
                    f"Error fetching data for user {user_id}: {e.response['error']}"
                )
                raise

        return self._id_to_name_map[user_id]

    def _replace_user_ids_with_names(self, message: str) -> str:
        # Find user IDs in the message
        user_ids = re.findall("<@(.*?)>", message)

        # Iterate over each user ID found
        for user_id in user_ids:
            try:
                if user_id in self._id_to_name_map:
                    user_name = self._id_to_name_map[user_id]
                else:
                    user_name = self._get_slack_name(user_id)

                # Replace the user ID with the username in the message
                message = message.replace(f"<@{user_id}>", f"@{user_name}")
            except Exception:
                logger.exception(
                    f"Unable to replace user ID with username for user_id '{user_id}'"
                )

        return message

    def index_clean(self, message: str) -> str:
        """During indexing, replace pattern sets that may confuse the model.

        Some special patterns are left in because they can carry information,
        e.g. links in the format text|link, where both the text and the link
        may be informative.
        """
        message = self._replace_user_ids_with_names(message)
        message = self.replace_tags_basic(message)
        message = self.replace_channels_basic(message)
        message = self.replace_special_mentions(message)
        message = self.replace_special_catchall(message)
        return message

    @staticmethod
    def replace_tags_basic(message: str) -> str:
        """Simply replaces all user tags `<@USER_ID>` with `@USER_ID` in order to
        prevent us from tagging users in Slack when we don't want to"""
        # Find user IDs in the message
        user_ids = re.findall("<@(.*?)>", message)
        for user_id in user_ids:
            message = message.replace(f"<@{user_id}>", f"@{user_id}")
        return message

    @staticmethod
    def replace_channels_basic(message: str) -> str:
        """Simply replaces all channel mentions `<#CHANNEL_ID|channel-name>` with
        `#channel-name` in order to make a message work as part of a link"""
        # Find channel mentions in the message
        channel_matches = re.findall(r"<#(.*?)\|(.*?)>", message)
        for channel_id, channel_name in channel_matches:
            message = message.replace(
                f"<#{channel_id}|{channel_name}>", f"#{channel_name}"
            )
        return message

    @staticmethod
    def replace_special_mentions(message: str) -> str:
        """Simply replaces @channel, @here, and @everyone so we don't tag
        a bunch of people in Slack when we don't want to"""
        message = message.replace("<!channel>", "@channel")
        message = message.replace("<!here>", "@here")
        message = message.replace("<!everyone>", "@everyone")
        return message

    @staticmethod
    def replace_special_catchall(message: str) -> str:
        """Replaces the pattern <!something|another-thing> with another-thing.

        This was added for one specific Slack special-mention format, but it may
        match other cases as well.
        """
        pattern = r"<!([^|]+)\|([^>]+)>"
        return re.sub(pattern, r"\2", message)

    @staticmethod
    def add_zero_width_whitespace_after_tag(message: str) -> str:
        """Add a 0 width whitespace after every @"""
        return message.replace("@", "@\u200b")


================================================
FILE: backend/onyx/connectors/teams/__init__.py
================================================

================================================
FILE: backend/onyx/connectors/teams/connector.py
================================================
import copy
import os
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast

import msal  # type: ignore
from office365.graph_client import GraphClient  # type: ignore
from office365.runtime.client_request_exception import ClientRequestException  # type: ignore
from office365.runtime.http.request_options import RequestOptions  # type: ignore[import-untyped]
from office365.teams.channels.channel import Channel  # type: ignore
from office365.teams.team import Team  # type: ignore

from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import CheckpointedConnectorWithPermSync
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnectorWithPermSync
from onyx.connectors.microsoft_graph_env import resolve_microsoft_environment
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.teams.models import Message
from onyx.connectors.teams.utils import fetch_expert_infos
from onyx.connectors.teams.utils import fetch_external_access
from onyx.connectors.teams.utils import fetch_messages
from onyx.connectors.teams.utils import fetch_replies
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout

logger = setup_logger()

_SLIM_DOC_BATCH_SIZE = 5000


class TeamsCheckpoint(ConnectorCheckpoint):
    todo_team_ids: list[str] | None = None


DEFAULT_AUTHORITY_HOST = "https://login.microsoftonline.com"
DEFAULT_GRAPH_API_HOST = "https://graph.microsoft.com"


class TeamsConnector(
    CheckpointedConnectorWithPermSync[TeamsCheckpoint],
    SlimConnectorWithPermSync,
):
    MAX_WORKERS = 10

    def __init__(
        self,
        # TODO: (chris) move from "Display Names" to IDs, since display names
        # are not necessarily guaranteed to be unique
        teams: list[str] = [],
        max_workers: int = MAX_WORKERS,
        authority_host: str = DEFAULT_AUTHORITY_HOST,
        graph_api_host: str = DEFAULT_GRAPH_API_HOST,
    ) -> None:
        self.graph_client: GraphClient | None = None
        self.msal_app: msal.ConfidentialClientApplication | None = None
        self.max_workers = max_workers
        self.requested_team_list: list[str] = teams

        resolved_env = resolve_microsoft_environment(graph_api_host, authority_host)
        self._azure_environment = resolved_env.environment
        self.authority_host = resolved_env.authority_host
        self.graph_api_host = resolved_env.graph_host

    # impls for BaseConnector

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
        teams_client_id = credentials["teams_client_id"]
        teams_client_secret = credentials["teams_client_secret"]
        teams_directory_id = credentials["teams_directory_id"]

        authority_url = f"{self.authority_host}/{teams_directory_id}"
        self.msal_app = msal.ConfidentialClientApplication(
            authority=authority_url,
            client_id=teams_client_id,
            client_credential=teams_client_secret,
        )

        def _acquire_token_func() -> dict[str, Any]:
            """
            Acquire token via MSAL
            """
            if self.msal_app is None:
                raise RuntimeError("MSAL app is not initialized")

            token = self.msal_app.acquire_token_for_client(
                scopes=[f"{self.graph_api_host}/.default"]
            )

            if not isinstance(token, dict):
                raise RuntimeError("`token` instance must be of type dict")

            return token

        self.graph_client = GraphClient(
            _acquire_token_func, environment=self._azure_environment
        )
        return None

    def validate_connector_settings(self) -> None:
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams credentials not loaded.")

        # Check if any requested teams have special characters that need client-side filtering
        has_special_chars = _has_odata_incompatible_chars(self.requested_team_list)
        if has_special_chars:
            logger.info(
                "Some requested team names contain special characters (&, (, )) that require "
                "client-side filtering during data retrieval."
            )

        # Minimal validation: just check if we can access the teams endpoint
        timeout = 10  # Short timeout for basic validation

        try:
            # For validation, do a lightweight check instead of full team search
            logger.info(
                f"Requested team count: {len(self.requested_team_list) if self.requested_team_list else 0}, "
                f"Has special chars: {has_special_chars}"
            )

            validation_query = self.graph_client.teams.get().top(1)
            run_with_timeout(
                timeout=timeout,
                func=lambda: validation_query.execute_query(),
            )

            logger.info(
                "Teams validation successful - Access to teams endpoint confirmed"
            )

        except TimeoutError as e:
            raise ConnectorValidationError(
                f"Timeout while validating Teams access (waited {timeout}s). "
                f"This may indicate network issues or authentication problems. "
                f"Error: {e}"
            )

        except ClientRequestException as e:
            # Compare against None explicitly: a `requests.Response` is falsy for
            # 4xx/5xx statuses, so `not e.response` would skip the 401/403 handling.
            if e.response is None:
                raise RuntimeError(f"No response provided in error; {e=}")
            status_code = e.response.status_code
            if status_code == 401:
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials (401 Unauthorized)."
                )
            elif status_code == 403:
                raise InsufficientPermissionsError(
                    "Your app lacks sufficient permissions to read Teams (403 Forbidden)."
                )
            raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")

        except Exception as e:
            error_str = str(e).lower()
            if (
                "unauthorized" in error_str
                or "401" in error_str
                or "invalid_grant" in error_str
            ):
                raise CredentialExpiredError(
                    "Invalid or expired Microsoft Teams credentials."
                )
            elif "forbidden" in error_str or "403" in error_str:
                raise InsufficientPermissionsError(
                    "App lacks required permissions to read from Microsoft Teams."
                )
            raise ConnectorValidationError(
                f"Unexpected error during Teams validation: {e}"
            )

    # impls for CheckpointedConnector

    def build_dummy_checkpoint(self) -> TeamsCheckpoint:
        return TeamsCheckpoint(
            has_more=True,
        )

    def validate_checkpoint_json(self, checkpoint_json: str) -> TeamsCheckpoint:
        return TeamsCheckpoint.model_validate_json(checkpoint_json)

    def load_from_checkpoint(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,  # noqa: ARG002
        checkpoint: TeamsCheckpoint,
    ) -> CheckpointOutput[TeamsCheckpoint]:
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        checkpoint = cast(TeamsCheckpoint, copy.deepcopy(checkpoint))

        todos = checkpoint.todo_team_ids

        if todos is None:
            teams = _collect_all_teams(
                graph_client=self.graph_client,
                requested=self.requested_team_list,
            )
            todo_team_ids = [team.id for team in teams if team.id]
            return TeamsCheckpoint(
                todo_team_ids=todo_team_ids,
                has_more=bool(todo_team_ids),
            )

        # `todos.pop()` should always return an element. This is because if
        # `todos` was the empty list, then we would have set `has_more=False`
        # during the previous invocation of `TeamsConnector.load_from_checkpoint`,
        # meaning that this function wouldn't have been called in the first place.
        todo_team_id = todos.pop()
        team = _get_team_by_id(
            graph_client=self.graph_client,
            team_id=todo_team_id,
        )
        channels = _collect_all_channels_from_team(
            team=team,
        )

        # A list with one (lazy) document iterator per channel.
        channels_docs = [
            _collect_documents_for_channel(
                graph_client=self.graph_client,
                team=team,
                channel=channel,
                start=start,
            )
            for channel in channels
        ]

        # Was previously `for doc in parallel_yield(gens=docs, max_workers=self.max_workers): ...`.
        # However, that led to some unexpected exceptions (potentially due to non-thread-safe behaviour in the Teams library).
        # Reverting back to the non-threaded case for now.
        for channel_docs in channels_docs:
            for channel_doc in channel_docs:
                if channel_doc:
                    yield channel_doc

        logger.info(
            f"Processed team with id {todo_team_id}; {len(todos)} team(s) left to process"
        )

        return TeamsCheckpoint(
            todo_team_ids=todos,
            has_more=bool(todos),
        )

    def load_from_checkpoint_with_perm_sync(
        self,
        start: SecondsSinceUnixEpoch,
        end: SecondsSinceUnixEpoch,
        checkpoint: TeamsCheckpoint,
    ) -> CheckpointOutput[TeamsCheckpoint]:
        # Teams already fetches external_access (permissions) for each document
        # in _convert_thread_to_document, so we can just delegate to load_from_checkpoint
        return self.load_from_checkpoint(start, end, checkpoint)

    # impls for SlimConnectorWithPermSync

    def retrieve_all_slim_docs_perm_sync(
        self,
        start: SecondsSinceUnixEpoch | None = None,
        end: SecondsSinceUnixEpoch | None = None,  # noqa: ARG002
        callback: IndexingHeartbeatInterface | None = None,
    ) -> GenerateSlimDocumentOutput:
        if self.graph_client is None:
            raise ConnectorMissingCredentialError("Teams")

        start = start or 0

        teams = _collect_all_teams(
            graph_client=self.graph_client,
            requested=self.requested_team_list,
        )

        for team in teams:
            if not team.id:
                logger.warning(
                    f"Expected a team with an id, instead got no id: {team=}"
                )
                continue

            channels = _collect_all_channels_from_team(
                team=team,
            )

            for channel in channels:
                if not channel.id:
                    logger.warning(
                        f"Expected a channel with an id, instead got no id: {channel=}"
                    )
                    continue
                external_access = fetch_external_access(
                    graph_client=self.graph_client, channel=channel
                )

                messages = fetch_messages(
                    graph_client=self.graph_client,
                    team_id=team.id,
                    channel_id=channel.id,
                    start=start,
                )

                slim_doc_buffer: list[SlimDocument | HierarchyNode] = []

                for message in messages:
                    slim_doc_buffer.append(
                        SlimDocument(
                            id=message.id,
                            external_access=external_access,
                        )
                    )

                    if len(slim_doc_buffer) >= _SLIM_DOC_BATCH_SIZE:
                        if callback:
                            if callback.should_stop():
                                raise RuntimeError(
                                    "retrieve_all_slim_docs_perm_sync: Stop signal detected"
                                )
                            callback.progress("retrieve_all_slim_docs_perm_sync", 1)
                        yield slim_doc_buffer
                        slim_doc_buffer = []

                # Flush any remaining slim documents collected for this channel
                if slim_doc_buffer:
                    yield slim_doc_buffer
                    slim_doc_buffer = []


def _escape_odata_string(name: str) -> str:
    """Escape special characters for OData string literals.

    Uses OData v4 string-literal escaping: a single quote ' becomes ''.
    Characters that break Microsoft Graph's OData parsing (&, (, )) are not
    escaped here; names containing them are matched with client-side filtering
    instead (see `_can_use_odata_filter`).
    """
    # Escape single quotes for OData syntax (replace ' with '')
    escaped = name.replace("'", "''")
    return escaped


def _has_odata_incompatible_chars(team_names: list[str] | None) -> bool:
    """Check if any team name contains characters that break Microsoft Graph OData filters.

    The Microsoft Graph Teams API has limited OData support. Characters like &, (, and )
    cause parsing errors and require client-side filtering instead.
    """
    if not team_names:
        return False

    return any(char in name for name in team_names for char in ["&", "(", ")"])


def _can_use_odata_filter(
    team_names: list[str] | None,
) -> tuple[bool, list[str], list[str]]:
    """Determine which teams can use OData filtering vs client-side filtering.

    Microsoft Graph /teams endpoint OData limitations:
    - Only supports basic 'eq' operators in filters
    - No 'contains', 'startswith', or other advanced operators
    - Special characters (&, (, )) break OData parsing

    Returns:
        tuple: (can_use_odata, safe_names, problematic_names)
    """
    if not team_names:
        return False, [], []

    safe_names = []
    problematic_names = []

    for name in team_names:
        if any(char in name for char in ["&", "(", ")"]):
            problematic_names.append(name)
        else:
            safe_names.append(name)

    return bool(safe_names), safe_names, problematic_names


def _build_simple_odata_filter(safe_names: list[str]) -> str | None:
    """Build simple OData filter using only 'eq' operators for safe names."""
    if not safe_names:
        return None

    filter_parts = []
    for name in safe_names:
        escaped_name = _escape_odata_string(name)
        filter_parts.append(f"displayName eq '{escaped_name}'")

    return " or ".join(filter_parts)


def _construct_semantic_identifier(channel: Channel, top_message: Message) -> str:
    top_message_user_name: str

    if top_message.from_ and top_message.from_.user:
        user_display_name = top_message.from_.user.display_name
        top_message_user_name = (
            user_display_name if user_display_name else "Unknown User"
        )
    else:
        logger.warning(f"Message {top_message=} has no `from.user` field")
        top_message_user_name = "Unknown User"

    top_message_content = top_message.body.content or ""
    top_message_subject = top_message.subject or "Unknown Subject"
    channel_name = channel.properties.get("displayName", "Unknown")

    try:
        snippet = parse_html_page_basic(top_message_content.rstrip())
        snippet = snippet[:50] + "..." if len(snippet) > 50 else snippet

    except Exception:
        logger.exception(
            f"Error parsing snippet for message {top_message.id} with url {top_message.web_url}"
        )
        snippet = ""

    semantic_identifier = (
        f"{top_message_user_name} in {channel_name} about {top_message_subject}"
    )
    if snippet:
        semantic_identifier += f": {snippet}"

    return semantic_identifier


def _convert_thread_to_document(
    graph_client: GraphClient,
    channel: Channel,
    thread: list[Message],
) -> Document | None:
    if len(thread) == 0:
        return None

    most_recent_message_datetime: datetime | None = None
    top_message = thread[0]
    thread_text = ""

    sorted_thread = sorted(thread, key=lambda m: m.created_date_time, reverse=True)

    if sorted_thread:
        most_recent_message_datetime = sorted_thread[0].created_date_time

    for message in thread:
        # Add text and a newline
        if message.body.content:
            thread_text += parse_html_page_basic(message.body.content) + "\n"

        # If it has a subject, that means it's the top-level post message, so grab its id, url, and subject
        if message.subject:
            top_message = message

    if not thread_text.strip():
        return None

    semantic_string = _construct_semantic_identifier(channel, top_message)

    expert_infos = fetch_expert_infos(graph_client=graph_client, channel=channel)
    external_access = fetch_external_access(
        graph_client=graph_client, channel=channel, expert_infos=expert_infos
    )

    return Document(
        id=top_message.id,
        sections=[TextSection(link=top_message.web_url, text=thread_text)],
        source=DocumentSource.TEAMS,
        semantic_identifier=semantic_string,
        title="",  # teams threads don't really have a "title"
        doc_updated_at=most_recent_message_datetime,
        primary_owners=expert_infos,
        metadata={},
        external_access=external_access,
    )


def _update_request_url(request: RequestOptions, next_url: str) -> None:
    request.url = next_url


def _add_prefer_header(request: RequestOptions) -> None:
    """Add Prefer header to work around Microsoft Graph API ampersand bug.

    See: https://developer.microsoft.com/en-us/graph/known-issues/?search=18185
    """
    if not hasattr(request, "headers") or request.headers is None:
        request.headers = {}
    # Add header to handle properly encoded ampersands in filters
    request.headers["Prefer"] = "legacySearch=false"


def _collect_all_teams(
    graph_client: GraphClient,
    requested: list[str] | None = None,
) -> list[Team]:
    """Collect teams from Microsoft Graph using appropriate filtering strategy.

    If any requested team name contains special characters (&, (, )), all
    requested teams are matched with client-side filtering over a paginated
    search. If none do, efficient OData server-side filtering is used.

    Args:
        graph_client: Authenticated Microsoft Graph client
        requested: List of team names to find. If None or empty, no teams
            are returned (all teams are intentionally not fetched).

    Returns:
        List of Team objects matching the requested names
    """
    teams: list[Team] = []
    next_url: str | None = None

    # Determine filtering strategy based on Microsoft Graph limitations
    if not requested:
        # No specific teams requested - return empty list (avoid fetching all teams)
        logger.info("No specific teams requested - returning empty list")
        return []

    _, safe_names, problematic_names = _can_use_odata_filter(requested)

    if problematic_names and not safe_names:
        # ALL requested teams have special characters - cannot use OData filtering
        logger.info(
            f"All requested team names contain special characters (&, (, )) which require "
            f"client-side filtering. Using basic /teams endpoint with pagination. "
            f"Teams: {problematic_names}"
        )
        # Use unfiltered query with pagination limit to avoid fetching too many teams
        use_client_side_filtering = True
        odata_filter = None
    elif problematic_names and safe_names:
        # Mixed scenario - need to fetch more teams to find the problematic ones
        logger.info(
            f"Mixed team types: will use client-side filtering for all. "
            f"Safe names: {safe_names}, Special char names: {problematic_names}"
        )
        use_client_side_filtering = True
        odata_filter = None
    elif safe_names:
        # All names are safe - use OData filtering
        logger.info(f"Using OData filtering for all requested teams: {safe_names}")
        use_client_side_filtering = False
        odata_filter = _build_simple_odata_filter(safe_names)
    else:
        # No valid names
        return []

    # Track pagination to avoid fetching too many teams for client-side filtering
    max_pages = 200
    page_count = 0

    while True:
        try:
            if use_client_side_filtering:
                # Use basic /teams endpoint with top parameter to limit results per page
                query = graph_client.teams.get().top(50)  # Limit to 50 teams per page
            else:
                # Use OData filter with only 'eq' operators
                query = graph_client.teams.get().filter(odata_filter)

            # Add header to work around Microsoft Graph API issues
            query.before_execute(lambda req: _add_prefer_header(request=req))

            if next_url:
                url = next_url
                query.before_execute(
                    lambda req: _update_request_url(request=req, next_url=url)
                )

            team_collection = query.execute_query()
        except (ClientRequestException, ValueError) as e:
            # If OData filter fails, fall back to client-side filtering
            if not use_client_side_filtering and odata_filter:
                logger.warning(
                    f"OData filter failed: {e}. Falling back to client-side filtering."
                )
                use_client_side_filtering = True
                odata_filter = None
                teams = []
                next_url = None
                page_count = 0
                continue
            # If client-side approach also fails, re-raise
            logger.error(f"Teams query failed: {e}")
            raise

        filtered_teams = (
            team
            for team in team_collection
            if _filter_team(team=team, requested=requested)
        )
        teams.extend(filtered_teams)

        # For client-side filtering, check if we found all requested teams or hit page limit
        if use_client_side_filtering:
            page_count += 1
            found_team_names = {
                team.display_name for team in teams if team.display_name
            }
            requested_set = set(requested)

            # Log progress every 10 pages to avoid excessive logging
            if page_count % 10 == 0:
                logger.info(
                    f"Searched {page_count} pages, found {len(found_team_names)} matching teams so far"
                )

            # Stop if we found all requested teams or hit the page limit
            if requested_set.issubset(found_team_names):
                logger.info(f"Found all requested teams after {page_count} pages")
                break
            elif page_count >= max_pages:
                logger.warning(
                    f"Reached maximum page limit ({max_pages}) while searching for teams. "
                    f"Found: {found_team_names & requested_set}, "
                    f"Missing: {requested_set - found_team_names}"
                )
                break

        if not team_collection.has_next:
            break

        if not isinstance(team_collection._next_request_url, str):
            raise ValueError(
                f"The next request url field should be a string, instead got {type(team_collection._next_request_url)}"
            )

        next_url = team_collection._next_request_url

    return teams


def _normalize_team_name(name: str) -> str:
    """Normalize team name for flexible matching."""
    if not name:
        return ""

    # Convert to lowercase and strip whitespace for case-insensitive matching
    return name.lower().strip()


def _matches_requested_team(
    team_display_name: str, requested: list[str] | None
) -> bool:
    """Check if team display name matches any of the requested team names.

    Uses flexible matching to handle slight variations in team names.
    """
    # If no teams were requested, match all; if the team has no name, don't match
    if not requested or not team_display_name:
        return not requested

    normalized_team_name = _normalize_team_name(team_display_name)

    for requested_name in requested:
        normalized_requested = _normalize_team_name(requested_name)

        # Exact match after normalization
        if normalized_team_name == normalized_requested:
            return True

        # Flexible matching - check if the team name contains most (at least 70%)
        # of the significant words. This helps with slight variations in formatting
        team_words = set(normalized_team_name.split())
        requested_words = set(normalized_requested.split())

        # If the requested name has special characters, split on those too
        for char in ["&", "(", ")"]:
            if char in normalized_requested:
                # Split on special characters and add words
                parts = normalized_requested.replace(char, " ").split()
                requested_words.update(parts)

        # Remove very short words that aren't meaningful
        meaningful_requested_words = {
            word for word in requested_words if len(word) >= 3
        }

        # Check if team name contains most of the meaningful words
        if (
            meaningful_requested_words
            and len(meaningful_requested_words & team_words)
            >= len(meaningful_requested_words) * 0.7
        ):
            return True

    return False


def _filter_team(
    team: Team,
    requested: list[str] | None = None,
) -> bool:
    """
    Returns true if:
    - Team is not expired / deleted
    - Team has a display-name and ID
    - Team display-name matches any of the requested teams (with flexible matching)

    Otherwise, returns false.
    """
    if not team.id or not team.display_name:
        return False

    if not _matches_requested_team(team.display_name, requested):
        return False

    props = team.properties

    expiration = props.get("expirationDateTime")
    deleted = props.get("deletedDateTime")

    # We just check for the existence of those two fields, not their actual dates.
    # This is because if these fields do exist, they have to have occurred in the past, thus making them already
    # expired / deleted.
    return not expiration and not deleted


def _get_team_by_id(
    graph_client: GraphClient,
    team_id: str,
) -> Team:
    team_collection = (
        graph_client.teams.get().filter(f"id eq '{team_id}'").top(1).execute_query()
    )

    if not team_collection:
        raise ValueError(f"No team with {team_id=} was found")
    elif team_collection.has_next:
        # shouldn't happen, but catching it regardless
        raise RuntimeError(f"Multiple teams with {team_id=} were found")

    return team_collection[0]


def _collect_all_channels_from_team(
    team: Team,
) -> list[Channel]:
    if not team.id:
        raise RuntimeError(f"The {team=} has an empty `id` field")

    channels: list[Channel] = []
    next_url = None
    while True:
        query = team.channels.get_all(
            # explicitly needed because of incorrect type definitions provided by the `office365` library
            page_loaded=lambda _: None
        )
        if next_url:
            url = next_url
            query = query.before_execute(
                lambda req: _update_request_url(request=req, next_url=url)
            )

        channel_collection = query.execute_query()
        channels.extend(channel for channel in channel_collection if channel.id)

        if not channel_collection.has_next:
            break

    return channels


def _collect_documents_for_channel(
    graph_client: GraphClient,
    team: Team,
    channel: Channel,
    start: SecondsSinceUnixEpoch,
) -> Iterator[Document | None | ConnectorFailure]:
    """
    Yields one `Document` per "thread" in the channel, or a `ConnectorFailure`
    if a thread could not be retrieved.

    A "thread" is the conjunction of the "root" message and all of its replies.
    """
    for message in fetch_messages(
        graph_client=graph_client,
        team_id=team.id,
        channel_id=channel.id,
        start=start,
    ):
        try:
            replies = list(
                fetch_replies(
                    graph_client=graph_client,
                    team_id=team.id,
                    channel_id=channel.id,
                    root_message_id=message.id,
                )
            )

            thread = [message]
            thread.extend(replies[::-1])

            # Note:
            # We convert an entire *thread* (including the root message and its replies) into one, singular `Document`.
            # I.e., we don't convert each individual message and each individual reply into their own individual `Document`s.
            if doc := _convert_thread_to_document(
                graph_client=graph_client,
                channel=channel,
                thread=thread,
            ):
                yield doc

        except Exception as e:
            yield ConnectorFailure(
                failed_entity=EntityFailure(
                    entity_id=message.id,
                ),
                failure_message=f"Retrieval of message and its replies failed; {channel.id=} {message.id}",
                exception=e,
            )


if __name__ == "__main__":
    from tests.daily.connectors.utils import load_all_from_connector

    app_id = os.environ["TEAMS_APPLICATION_ID"]
    dir_id = os.environ["TEAMS_DIRECTORY_ID"]
    secret = os.environ["TEAMS_SECRET"]

    teams_env_var = os.environ.get("TEAMS", None)
    teams = teams_env_var.split(",") if teams_env_var else []

    teams_connector = TeamsConnector(teams=teams)

    teams_connector.load_credentials(
        {
            "teams_client_id": app_id,
            "teams_directory_id": dir_id,
            "teams_client_secret": secret,
        }
    )

    teams_connector.validate_connector_settings()

    for slim_doc in teams_connector.retrieve_all_slim_docs_perm_sync():
        ...

    for doc in load_all_from_connector(
        connector=teams_connector,
        start=0.0,
        end=datetime.now(tz=timezone.utc).timestamp(),
    ).documents:
        print(doc)


================================================
FILE: backend/onyx/connectors/teams/models.py
================================================
from datetime import datetime

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field
from pydantic.alias_generators import to_camel


class Body(BaseModel):
    content_type: str
    content: str | None

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


class User(BaseModel):
    id: str
    display_name: str

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


class From(BaseModel):
    user: User | None

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


class Message(BaseModel):
    id: str
    replyToId: str | None
    subject: str | None
    from_: From | None = Field(alias="from")
    body: Body
    created_date_time: datetime
    last_modified_date_time: datetime | None
    last_edited_date_time: datetime | None
    deleted_date_time: datetime | None
    web_url: str

    model_config = ConfigDict(
        alias_generator=to_camel,
        populate_by_name=True,
    )


================================================
FILE: backend/onyx/connectors/teams/utils.py
================================================
import time
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from http import HTTPStatus

from office365.graph_client import GraphClient  # type: ignore[import-untyped]
from office365.teams.channels.channel import Channel  # type: ignore[import-untyped]
from office365.teams.channels.channel import ConversationMember

from onyx.access.models import ExternalAccess
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.teams.models import Message
from onyx.utils.logger import setup_logger

logger = setup_logger()

_PUBLIC_MEMBERSHIP_TYPE = "standard"  # public teams channel


def _sanitize_message_user_display_name(value: dict) -> dict:
    try:
        from_obj = value.get("from")
        if isinstance(from_obj, dict):
            user_obj = from_obj.get("user")
            if isinstance(user_obj, dict) and user_obj.get("displayName") is None:
                value = dict(value)
                from_obj = dict(from_obj)
                user_obj = dict(user_obj)
                user_obj["displayName"] = "Unknown User"
                from_obj["user"] = user_obj
                value["from"] = from_obj
    except (AttributeError, TypeError, KeyError):
        pass
    return value


def _retry(
    graph_client: GraphClient,
    request_url: str,
) -> dict:
    MAX_RETRIES = 10
    retry_number = 0

    while retry_number < MAX_RETRIES:
        response = graph_client.execute_request_direct(request_url)
        if response.ok:
            json = response.json()
            if not isinstance(json, dict):
                raise RuntimeError(f"Expected a JSON object, instead got {json=}")
            return json

        if response.status_code == int(HTTPStatus.TOO_MANY_REQUESTS):
            retry_number += 1

            cooldown = int(response.headers.get("Retry-After", 10))
            time.sleep(cooldown)

            continue

        response.raise_for_status()

    raise RuntimeError(
        f"Max number of retries for hitting {request_url=} exceeded; unable to fetch data"
    )


def _get_next_url(
    graph_client: GraphClient,
    json_response: dict,
) -> str | None:
    next_url = json_response.get("@odata.nextLink")

    if not next_url:
        return None

    if not isinstance(next_url, str):
        raise RuntimeError(
            f"Expected a string for the `@odata.nextLink`, instead got {next_url=}"
        )

    return next_url.removeprefix(graph_client.service_root_url()).removeprefix("/")


def _get_or_fetch_email(
    graph_client: GraphClient,
    member: ConversationMember,
) -> str | None:
    if email := member.properties.get("email"):
        return email

    user_id = member.properties.get("userId")
    if not user_id:
        logger.warning(f"No user-id found for this member; {member=}")
        return None

    json_data = _retry(graph_client=graph_client, request_url=f"users/{user_id}")
    email = json_data.get("userPrincipalName")

    if not isinstance(email, str):
        logger.warning(f"Expected email to be of type str, instead got {email=}")
        return None

    return email


def _is_channel_public(channel: Channel) -> bool:
    return channel.membership_type == _PUBLIC_MEMBERSHIP_TYPE


def fetch_messages(
    graph_client: GraphClient,
    team_id: str,
    channel_id: str,
    start: SecondsSinceUnixEpoch,
) -> Generator[Message]:
    startfmt = datetime.fromtimestamp(start, tz=timezone.utc).strftime(
        "%Y-%m-%dT%H:%M:%SZ"
    )

    initial_request_url = f"teams/{team_id}/channels/{channel_id}/messages/delta?$filter=lastModifiedDateTime gt {startfmt}"

    request_url: str | None = initial_request_url

    while request_url:
        json_response = _retry(graph_client=graph_client, request_url=request_url)

        for value in json_response.get("value", []):
            yield Message(**_sanitize_message_user_display_name(value))

        request_url = _get_next_url(
            graph_client=graph_client, json_response=json_response
        )


def fetch_replies(
    graph_client: GraphClient,
    team_id: str,
    channel_id: str,
    root_message_id: str,
) -> Generator[Message]:
    initial_request_url = (
f"teams/{team_id}/channels/{channel_id}/messages/{root_message_id}/replies" ) request_url: str | None = initial_request_url while request_url: json_response = _retry(graph_client=graph_client, request_url=request_url) for value in json_response.get("value", []): yield Message(**_sanitize_message_user_display_name(value)) request_url = _get_next_url( graph_client=graph_client, json_response=json_response ) def fetch_expert_infos( graph_client: GraphClient, channel: Channel ) -> list[BasicExpertInfo]: members = channel.members.get_all( # explicitly needed because of incorrect type definitions provided by the `office365` library page_loaded=lambda _: None ).execute_query_retry() expert_infos: list[BasicExpertInfo] = [] for member in members: if not member.display_name: logger.warning(f"Failed to grab the display-name of {member=}; skipping") continue email = _get_or_fetch_email(graph_client=graph_client, member=member) if not email: logger.warning(f"Failed to grab the email of {member=}; skipping") continue expert_infos.append( BasicExpertInfo( display_name=member.display_name, email=email, ) ) return expert_infos def fetch_external_access( graph_client: GraphClient, channel: Channel, expert_infos: list[BasicExpertInfo] | None = None, ) -> ExternalAccess: is_public = _is_channel_public(channel=channel) if is_public: return ExternalAccess.public() expert_infos = ( expert_infos if expert_infos is not None else fetch_expert_infos(graph_client=graph_client, channel=channel) ) emails = {expert_info.email for expert_info in expert_infos if expert_info.email} return ExternalAccess( external_user_emails=emails, external_user_group_ids=set(), is_public=is_public, ) ================================================ FILE: backend/onyx/connectors/testrail/__init__.py ================================================ # Package marker for TestRail connector ================================================ FILE: backend/onyx/connectors/testrail/connector.py ================================================ from
__future__ import annotations from collections.abc import Iterator from datetime import datetime from datetime import timezone from typing import Any from typing import ClassVar from typing import Optional import requests from bs4 import BeautifulSoup from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import format_document_soup from onyx.utils.logger import setup_logger from onyx.utils.text_processing import remove_markdown_image_references logger = setup_logger() class TestRailConnector(LoadConnector, PollConnector): """Connector for TestRail. Minimal implementation that indexes Test Cases per project. 
""" document_source_type: ClassVar[DocumentSource] = DocumentSource.TESTRAIL # Fields that need ID-to-label value mapping FIELDS_NEEDING_VALUE_MAPPING: ClassVar[set[str]] = { "priority_id", "custom_automation_type", "custom_scenario_db_automation", "custom_case_golden_canvas_automation", "custom_customers", "custom_case_environments", "custom_case_overall_automation", "custom_case_team_ownership", "custom_case_unit_or_integration_automation", "custom_effort", } def __init__( self, batch_size: int = INDEX_BATCH_SIZE, project_ids: str | list[int] | None = None, cases_page_size: int | None = None, max_pages: int | None = None, skip_doc_absolute_chars: int | None = None, ) -> None: self.base_url: str | None = None self.username: str | None = None self.api_key: str | None = None self.batch_size = batch_size parsed_project_ids: list[int] | None # Parse project_ids from string if needed # None = all projects (no filtering), [] = no projects, [1,2,3] = specific projects if isinstance(project_ids, str): if project_ids.strip(): parsed_project_ids = [ int(x.strip()) for x in project_ids.split(",") if x.strip() ] else: # Empty string from UI means "all projects" parsed_project_ids = None elif project_ids is None: parsed_project_ids = None else: parsed_project_ids = [int(pid) for pid in project_ids] self.project_ids: list[int] | None = parsed_project_ids # Handle empty strings from UI and convert to int with defaults self.cases_page_size = ( int(cases_page_size) if cases_page_size and str(cases_page_size).strip() else 250 ) self.max_pages = ( int(max_pages) if max_pages and str(max_pages).strip() else 10000 ) self.skip_doc_absolute_chars = ( int(skip_doc_absolute_chars) if skip_doc_absolute_chars and str(skip_doc_absolute_chars).strip() else 200000 ) # Cache for field labels and value mappings - will be populated on first use self._field_labels: dict[str, str] | None = None self._value_maps: dict[str, dict[str, str]] | None = None # --- Rich text sanitization helpers --- # 
Note: TestRail stores some fields as HTML (e.g. shared test steps). # This function handles both HTML and plain text. @staticmethod def _sanitize_rich_text(value: Any) -> str: if value is None: return "" text = str(value) # Parse HTML and remove image tags soup = BeautifulSoup(text, "html.parser") # Remove all img tags and their containers for img_tag in soup.find_all("img"): img_tag.decompose() for span in soup.find_all("span", class_="markdown-img-container"): span.decompose() # Use format_document_soup for better HTML-to-text conversion # This preserves document structure (paragraphs, lists, line breaks, etc.) text = format_document_soup(soup) # Also remove markdown-style image references (in case any remain) text = remove_markdown_image_references(text) return text.strip() def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: # Expected keys from UI credential JSON self.base_url = str(credentials["testrail_base_url"]).rstrip("/") self.username = str(credentials["testrail_username"]) # email or username self.api_key = str(credentials["testrail_api_key"]) # API key (password) return None def validate_connector_settings(self) -> None: """Lightweight validation to surface common misconfigurations early.""" projects = self._list_projects() if not projects: logger.warning("TestRail: no projects visible to this credential.") # ---- API helpers ---- def _api_get(self, endpoint: str, params: Optional[dict[str, Any]] = None) -> Any: if not self.base_url or not self.username or not self.api_key: raise ConnectorMissingCredentialError("testrail") # TestRail API base is typically /index.php?/api/v2/ url = f"{self.base_url}/index.php?/api/v2/{endpoint}" try: response = requests.get( url, auth=(self.username, self.api_key), params=params, ) response.raise_for_status() except requests.exceptions.HTTPError as e: # A requests.Response is falsy for 4xx/5xx statuses, so compare against None explicitly status = e.response.status_code if e.response is not None else None if status == 401: raise CredentialExpiredError( "Invalid or expired
TestRail credentials (HTTP 401)." ) from e if status == 403: raise InsufficientPermissionsError( "Insufficient permissions to access TestRail resources (HTTP 403)." ) from e raise UnexpectedValidationError( f"Unexpected TestRail HTTP error (status={status})." ) from e except requests.exceptions.RequestException as e: raise UnexpectedValidationError(f"TestRail request failed: {e}") from e try: return response.json() except ValueError as e: raise UnexpectedValidationError( "Invalid JSON returned by TestRail API" ) from e def _list_projects(self) -> list[dict[str, Any]]: projects = self._api_get("get_projects") if isinstance(projects, dict): projects_list = projects.get("projects") return projects_list if isinstance(projects_list, list) else [] return [] def _list_suites(self, project_id: int) -> list[dict[str, Any]]: """Return suites for a project. If the project is in single-suite mode, some TestRail instances may return an empty list; callers should gracefully fall back to calling get_cases without suite_id. """ suites = self._api_get(f"get_suites/{project_id}") # get_suites may return either a bare list or a dict wrapping the list if isinstance(suites, list): return suites if isinstance(suites, dict): suites_list = suites.get("suites") return suites_list if isinstance(suites_list, list) else [] return [] def _get_case_fields(self) -> list[dict[str, Any]]: """Get case field definitions from TestRail API.""" try: fields = self._api_get("get_case_fields") return fields if isinstance(fields, list) else [] except Exception as e: logger.warning(f"Failed to fetch case fields from TestRail: {e}") return [] def _parse_items_string(self, items_str: str) -> dict[str, str]: """Parse items string from field config into ID -> label mapping.
Format: "1, Option A\\n2, Option B\\n3, Option C" Returns: {"1": "Option A", "2": "Option B", "3": "Option C"} """ id_to_label: dict[str, str] = {} if not items_str: return id_to_label for line in items_str.split("\n"): line = line.strip() if not line: continue parts = line.split(",", 1) if len(parts) == 2: item_id = parts[0].strip() item_label = parts[1].strip() id_to_label[item_id] = item_label return id_to_label def _build_field_maps(self) -> tuple[dict[str, str], dict[str, dict[str, str]]]: """Build both field labels and value mappings in one pass. Returns: (field_labels, value_maps) where: - field_labels: system_name -> label - value_maps: system_name -> {id -> label} """ field_labels = {} value_maps = {} try: fields = self._get_case_fields() for field in fields: system_name = field.get("system_name") # Build field label map label = field.get("label") if system_name and label: field_labels[system_name] = label # Build value map if needed if system_name in self.FIELDS_NEEDING_VALUE_MAPPING: configs = field.get("configs", []) if configs and len(configs) > 0: options = configs[0].get("options", {}) items_str = options.get("items") if items_str: value_maps[system_name] = self._parse_items_string( items_str ) except Exception as e: logger.warning(f"Failed to build field maps from TestRail: {e}") return field_labels, value_maps def _get_field_labels(self) -> dict[str, str]: """Get field labels, fetching from API if not cached.""" if self._field_labels is None: self._field_labels, self._value_maps = self._build_field_maps() return self._field_labels def _get_value_maps(self) -> dict[str, dict[str, str]]: """Get value maps, fetching from API if not cached.""" if self._value_maps is None: self._field_labels, self._value_maps = self._build_field_maps() return self._value_maps def _map_field_value(self, field_name: str, field_value: Any) -> str: """Map a field value using the value map if available. 
Examples: - priority_id: 2 -> "Medium" - custom_case_team_ownership: [10] -> "Sim Platform" - custom_case_environments: [1, 2] -> "Local, Cloud" """ if field_value is None or field_value == "": return "" # Get value map for this field value_maps = self._get_value_maps() value_map = value_maps.get(field_name, {}) # Handle list values if isinstance(field_value, list): if not field_value: return "" mapped = [value_map.get(str(v), str(v)) for v in field_value] return ", ".join(mapped) # Handle single values val_str = str(field_value) return value_map.get(val_str, val_str) def _get_cases( self, project_id: int, suite_id: Optional[int], limit: int, offset: int ) -> list[dict[str, Any]]: """Get cases for a project from the API.""" params: dict[str, Any] = {"limit": limit, "offset": offset} if suite_id is not None: params["suite_id"] = suite_id cases_response = self._api_get(f"get_cases/{project_id}", params=params) cases_list: list[dict[str, Any]] = [] if isinstance(cases_response, dict): cases_items = cases_response.get("cases") if isinstance(cases_items, list): cases_list = cases_items return cases_list def _iter_cases( self, project_id: int, suite_id: Optional[int] = None, start: Optional[SecondsSinceUnixEpoch] = None, end: Optional[SecondsSinceUnixEpoch] = None, ) -> Iterator[dict[str, Any]]: # Pagination: TestRail supports 'limit' and 'offset' for many list endpoints limit = self.cases_page_size # Use a bounded page loop to avoid infinite loops on API anomalies for page_index in range(self.max_pages): offset = page_index * limit cases = self._get_cases(project_id, suite_id, limit, offset) if not cases: break # Filter by updated window if provided for case in cases: # 'updated_on' is unix timestamp (seconds) updated_on = case.get("updated_on") or case.get("created_on") if start is not None and updated_on is not None and updated_on < start: continue if end is not None and updated_on is not None and updated_on > end: continue yield case if len(cases) < limit: break def 
_build_case_link(self, project_id: int, case_id: int) -> str: # noqa: ARG002 # Standard UI link to a case return f"{self.base_url}/index.php?/cases/view/{case_id}" def _doc_from_case( self, project: dict[str, Any], case: dict[str, Any], suite: dict[str, Any] | None = None, # noqa: ARG002 ) -> Document | None: project_id = project.get("id") if not isinstance(project_id, int): logger.warning( "Skipping TestRail case because project id is missing or invalid: %s", project_id, ) return None case_id = case.get("id") if not isinstance(case_id, int): logger.warning( "Skipping TestRail case because case id is missing or invalid: %s", case_id, ) return None title = case.get("title", f"Case {case_id}") case_key = f"C{case_id}" # Convert epoch seconds to aware datetime if available updated = case.get("updated_on") or case.get("created_on") updated_dt = ( datetime.fromtimestamp(updated, tz=timezone.utc) if isinstance(updated, (int, float)) else None ) text_lines: list[str] = [] if case.get("title"): text_lines.append(f"Title: {case['title']}") if case_key: text_lines.append(f"Case ID: {case_key}") if case_id is not None: text_lines.append(f"ID: {case_id}") doc_link = case.get("custom_documentation_link") if doc_link: text_lines.append(f"Documentation: {doc_link}") # Add fields that need value mapping field_labels = self._get_field_labels() for field_name in self.FIELDS_NEEDING_VALUE_MAPPING: field_value = case.get(field_name) if field_value is not None and field_value != "" and field_value != []: mapped_value = self._map_field_value(field_name, field_value) if mapped_value: # Get label from TestRail field definition label = field_labels.get( field_name, field_name.replace("_", " ").title() ) text_lines.append(f"{label}: {mapped_value}") pre = self._sanitize_rich_text(case.get("custom_preconds")) if pre: text_lines.append(f"Preconditions: {pre}") # Steps: use separated steps format if available steps_added = False steps_separated = case.get("custom_steps_separated") if 
isinstance(steps_separated, list) and steps_separated: rendered_steps: list[str] = [] for idx, step_item in enumerate(steps_separated, start=1): step_content = self._sanitize_rich_text(step_item.get("content")) step_expected = self._sanitize_rich_text(step_item.get("expected")) parts: list[str] = [] if step_content: parts.append(f"Step {idx}: {step_content}") else: parts.append(f"Step {idx}:") if step_expected: parts.append(f"Expected: {step_expected}") rendered_steps.append("\n".join(parts)) if rendered_steps: text_lines.append("Steps:\n" + "\n".join(rendered_steps)) steps_added = True # Fallback to custom_steps and custom_expected if no separated steps if not steps_added: custom_steps = self._sanitize_rich_text(case.get("custom_steps")) custom_expected = self._sanitize_rich_text(case.get("custom_expected")) if custom_steps: text_lines.append(f"Steps: {custom_steps}") if custom_expected: text_lines.append(f"Expected: {custom_expected}") link = self._build_case_link(project_id, case_id) # Build full text and apply size policies full_text = "\n".join(text_lines) if len(full_text) > self.skip_doc_absolute_chars: logger.warning( f"Skipping TestRail case {case_id} due to excessive size: {len(full_text)} chars" ) return None # Metadata for document identification metadata: dict[str, Any] = {} if case_key: metadata["case_key"] = case_key # Include the human-friendly case key in identifiers for easier search display_title = f"{case_key}: {title}" if case_key else title return Document( id=f"TESTRAIL_CASE_{case_id}", source=DocumentSource.TESTRAIL, semantic_identifier=display_title, title=display_title, sections=[TextSection(link=link, text=full_text)], metadata=metadata, doc_updated_at=updated_dt, ) def _generate_documents( self, start: Optional[SecondsSinceUnixEpoch], end: Optional[SecondsSinceUnixEpoch], ) -> GenerateDocumentsOutput: if not self.base_url or not self.username or not self.api_key: raise ConnectorMissingCredentialError("testrail") doc_batch: list[Document 
| HierarchyNode] = [] projects = self._list_projects() project_filter: list[int] | None = self.project_ids for project in projects: project_id_raw = project.get("id") if not isinstance(project_id_raw, int): logger.warning( "Skipping TestRail project with invalid id: %s", project_id_raw ) continue project_id = project_id_raw # None = index all, [] = index none, [1,2,3] = index only those if project_filter is not None and project_id not in project_filter: continue suites = self._list_suites(project_id) if suites: for s in suites: suite_id = s.get("id") for case in self._iter_cases(project_id, suite_id, start, end): doc = self._doc_from_case(project, case, s) if doc is None: continue doc_batch.append(doc) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] else: # single-suite mode fallback for case in self._iter_cases(project_id, None, start, end): doc = self._doc_from_case(project, case, None) if doc is None: continue doc_batch.append(doc) if len(doc_batch) >= self.batch_size: yield doc_batch doc_batch = [] if doc_batch: yield doc_batch # ---- Onyx interfaces ---- def load_from_state(self) -> GenerateDocumentsOutput: return self._generate_documents(start=None, end=None) def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: return self._generate_documents(start=start, end=end) if __name__ == "__main__": from onyx.configs.app_configs import ( TESTRAIL_API_KEY, TESTRAIL_BASE_URL, TESTRAIL_USERNAME, ) connector = TestRailConnector() connector.load_credentials( { "testrail_base_url": TESTRAIL_BASE_URL, "testrail_username": TESTRAIL_USERNAME, "testrail_api_key": TESTRAIL_API_KEY, } ) connector.validate_connector_settings() # Probe a tiny batch from load total = 0 for batch in connector.load_from_state(): print(f"Fetched batch: {len(batch)} docs") total += len(batch) if total >= 10: break print(f"Total fetched in test: {total}") ================================================ FILE: 
backend/onyx/connectors/web/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/web/connector.py ================================================ import ipaddress import random import socket import time from datetime import datetime from datetime import timezone from enum import Enum from typing import Any from typing import cast from typing import Tuple from urllib.parse import urljoin from urllib.parse import urlparse import requests from bs4 import BeautifulSoup from oauthlib.oauth2 import BackendApplicationClient from playwright.sync_api import BrowserContext from playwright.sync_api import Playwright from playwright.sync_api import sync_playwright from playwright.sync_api import TimeoutError from requests_oauthlib import OAuth2Session # type:ignore from urllib3.exceptions import MaxRetryError from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_ID from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_CLIENT_SECRET from onyx.configs.app_configs import WEB_CONNECTOR_OAUTH_TOKEN_URL from onyx.configs.app_configs import WEB_CONNECTOR_VALIDATE_URLS from onyx.configs.constants import DocumentSource from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.exceptions import UnexpectedValidationError from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import web_html_cleanup from onyx.utils.logger import setup_logger from onyx.utils.sitemap import list_pages_for_site from onyx.utils.web_content import extract_pdf_text from 
onyx.utils.web_content import is_pdf_resource from shared_configs.configs import MULTI_TENANT logger = setup_logger() class ScrapeSessionContext: """Session level context for scraping""" def __init__(self, base_url: str, to_visit: list[str]): self.base_url = base_url self.to_visit = to_visit self.visited_links: set[str] = set() self.content_hashes: set[int] = set() self.doc_batch: list[Document | HierarchyNode] = [] self.at_least_one_doc: bool = False self.last_error: str | None = None self.needs_retry: bool = False self.playwright: Playwright | None = None self.playwright_context: BrowserContext | None = None def initialize(self) -> None: self.stop() self.playwright, self.playwright_context = start_playwright() def stop(self) -> None: if self.playwright_context: self.playwright_context.close() self.playwright_context = None if self.playwright: self.playwright.stop() self.playwright = None class ScrapeResult: doc: Document | None = None retry: bool = False WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20 # Threshold for determining when to replace vs append iframe content IFRAME_TEXT_LENGTH_THRESHOLD = 700 # Message indicating JavaScript is disabled, which often appears when scraping fails JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser" # Grace period after page navigation to allow bot-detection challenges # and SPA content rendering to complete PAGE_RENDER_TIMEOUT_MS = 5000 # Define common headers that mimic a real browser DEFAULT_USER_AGENT = "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/123.0.0.0 Safari/537.36" DEFAULT_HEADERS = { "User-Agent": DEFAULT_USER_AGENT, "Accept": ( "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp," "image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7" ), "Accept-Language": "en-US,en;q=0.9", # Brotli decoding has been flaky in brotlicffi/httpx for certain chunked responses; # stick to gzip/deflate to keep connectivity checks stable. 
"Accept-Encoding": "gzip, deflate", "Connection": "keep-alive", "Upgrade-Insecure-Requests": "1", "Sec-Fetch-Dest": "document", "Sec-Fetch-Mode": "navigate", "Sec-Fetch-Site": "none", "Sec-Fetch-User": "?1", "Sec-CH-UA": '"Google Chrome";v="123", "Not:A-Brand";v="8"', "Sec-CH-UA-Mobile": "?0", "Sec-CH-UA-Platform": '"macOS"', } class WEB_CONNECTOR_VALID_SETTINGS(str, Enum): # Given a base site, index everything under that path RECURSIVE = "recursive" # Given a URL, index only the given page SINGLE = "single" # Given a sitemap.xml URL, parse all the pages in it SITEMAP = "sitemap" # Given a file upload where every line is a URL, parse all the URLs provided UPLOAD = "upload" def protected_url_check(url: str) -> None: """A couple of considerations: - DNS mappings change over time, so the results should not be cached - Resolving the hostname is assumed to be fast compared to other bottlenecks, such as reading the page or embedding its contents - To be extra safe, every IP associated with the URL must be global - This check is meant to prevent misuse, not to defend against deliberate attacks """ if not WEB_CONNECTOR_VALIDATE_URLS: return parse = urlparse(url) if parse.scheme != "http" and parse.scheme != "https": raise ValueError("URL scheme must be http or https") if not parse.hostname: raise ValueError("URL must include a hostname") try: # This may give a large list of IP addresses for domains with extensive DNS configurations # such as large distributed systems of CDNs info = socket.getaddrinfo(parse.hostname, None) except socket.gaierror as e: raise ConnectionError(f"DNS resolution failed for {parse.hostname}: {e}") from e for address in info: ip = address[4][0] if not ipaddress.ip_address(ip).is_global: raise ValueError( f"Non-global IP address detected: {ip}, skipping page {url}.
" f"The Web Connector is not allowed to read loopback, link-local, or private ranges" ) def check_internet_connection(url: str) -> None: try: # Use a more realistic browser-like request session = requests.Session() session.headers.update(DEFAULT_HEADERS) response = session.get(url, timeout=5, allow_redirects=True) response.raise_for_status() except requests.exceptions.HTTPError as e: # Extract status code from the response, defaulting to -1 if response is None status_code = e.response.status_code if e.response is not None else -1 # For 403 errors, we do have internet connection, but the request is blocked by the server # this is usually due to bot detection. Future calls (via Playwright) will usually get # around this. if status_code == 403: logger.warning( f"Received 403 Forbidden for {url}, will retry with browser automation" ) return error_msg = { 400: "Bad Request", 401: "Unauthorized", 403: "Forbidden", 404: "Not Found", 500: "Internal Server Error", 502: "Bad Gateway", 503: "Service Unavailable", 504: "Gateway Timeout", }.get(status_code, "HTTP Error") raise Exception(f"{error_msg} ({status_code}) for {url} - {e}") except requests.exceptions.SSLError as e: cause = ( e.args[0].reason if isinstance(e.args, tuple) and isinstance(e.args[0], MaxRetryError) else e.args ) raise Exception(f"SSL error {str(cause)}") except (requests.RequestException, ValueError) as e: raise Exception(f"Unable to reach {url} - check your internet connection: {e}") def is_valid_url(url: str) -> bool: try: result = urlparse(url) return all([result.scheme, result.netloc]) except ValueError: return False def _same_site(base_url: str, candidate_url: str) -> bool: base, candidate = urlparse(base_url), urlparse(candidate_url) base_netloc = base.netloc.lower().removeprefix("www.") candidate_netloc = candidate.netloc.lower().removeprefix("www.") if base_netloc != candidate_netloc: return False base_path = (base.path or "/").rstrip("/") if base_path in ("", "/"): return True candidate_path = 
candidate.path or "/" if candidate_path == base_path: return True boundary = f"{base_path}/" return candidate_path.startswith(boundary) def get_internal_links( base_url: str, url: str, soup: BeautifulSoup, should_ignore_pound: bool = True ) -> set[str]: internal_links = set() for link in cast(list[dict[str, Any]], soup.find_all("a")): href = cast(str | None, link.get("href")) if not href: continue # Account for malformed backslashes in URLs href = href.replace("\\", "/") # "#!" indicates the page is using a hashbang URL, which is a client-side routing technique if should_ignore_pound and "#" in href and "#!" not in href: href = href.split("#")[0] if not is_valid_url(href): # Relative path handling href = urljoin(url, href) if _same_site(base_url, href): internal_links.add(href) return internal_links def start_playwright() -> Tuple[Playwright, BrowserContext]: playwright = sync_playwright().start() # Launch browser with more realistic settings browser = playwright.chromium.launch( headless=True, args=[ "--disable-blink-features=AutomationControlled", "--disable-features=IsolateOrigins,site-per-process", "--disable-site-isolation-trials", ], ) # Create a context with realistic browser properties context = browser.new_context( user_agent=DEFAULT_USER_AGENT, viewport={"width": 1440, "height": 900}, device_scale_factor=2.0, locale="en-US", timezone_id="America/Los_Angeles", has_touch=False, java_script_enabled=True, color_scheme="light", # Add more realistic browser properties bypass_csp=True, ignore_https_errors=True, ) # Set additional headers to mimic a real browser context.set_extra_http_headers( { "Accept": DEFAULT_HEADERS["Accept"], "Accept-Language": DEFAULT_HEADERS["Accept-Language"], "Sec-Fetch-Dest": DEFAULT_HEADERS["Sec-Fetch-Dest"], "Sec-Fetch-Mode": DEFAULT_HEADERS["Sec-Fetch-Mode"], "Sec-Fetch-Site": DEFAULT_HEADERS["Sec-Fetch-Site"], "Sec-Fetch-User": DEFAULT_HEADERS["Sec-Fetch-User"], "Sec-CH-UA": DEFAULT_HEADERS["Sec-CH-UA"], "Sec-CH-UA-Mobile": 
DEFAULT_HEADERS["Sec-CH-UA-Mobile"], "Sec-CH-UA-Platform": DEFAULT_HEADERS["Sec-CH-UA-Platform"], "Cache-Control": "max-age=0", "DNT": "1", } ) # Add a script to modify navigator properties to avoid detection context.add_init_script( """ Object.defineProperty(navigator, 'webdriver', { get: () => undefined }); Object.defineProperty(navigator, 'plugins', { get: () => [1, 2, 3, 4, 5] }); Object.defineProperty(navigator, 'languages', { get: () => ['en-US', 'en'] }); """ ) if ( WEB_CONNECTOR_OAUTH_CLIENT_ID and WEB_CONNECTOR_OAUTH_CLIENT_SECRET and WEB_CONNECTOR_OAUTH_TOKEN_URL ): client = BackendApplicationClient(client_id=WEB_CONNECTOR_OAUTH_CLIENT_ID) oauth = OAuth2Session(client=client) token = oauth.fetch_token( token_url=WEB_CONNECTOR_OAUTH_TOKEN_URL, client_id=WEB_CONNECTOR_OAUTH_CLIENT_ID, client_secret=WEB_CONNECTOR_OAUTH_CLIENT_SECRET, ) context.set_extra_http_headers( {"Authorization": "Bearer {}".format(token["access_token"])} ) return playwright, context def extract_urls_from_sitemap(sitemap_url: str) -> list[str]: # requests should handle brotli compression automatically # as long as the brotli package is available in the venv. Leaving this line here to avoid # a regression as someone says "Ah, looks like this brotli package isn't used anywhere, let's remove it" # import brotli try: response = requests.get(sitemap_url, headers=DEFAULT_HEADERS) response.raise_for_status() soup = BeautifulSoup(response.content, "html.parser") urls = [ _ensure_absolute_url(sitemap_url, loc_tag.text) for loc_tag in soup.find_all("loc") ] if len(urls) == 0 and len(soup.find_all("urlset")) == 0: # the given url doesn't look like a sitemap, let's try to find one urls = list_pages_for_site(sitemap_url) if len(urls) == 0: raise ValueError( f"No URLs found in sitemap {sitemap_url}. Try using the 'single' or 'recursive' scraping options instead." 
) return urls except requests.RequestException as e: raise RuntimeError(f"Failed to fetch sitemap from {sitemap_url}: {e}") except ValueError as e: raise RuntimeError(f"Error processing sitemap {sitemap_url}: {e}") except Exception as e: raise RuntimeError( f"Unexpected error while processing sitemap {sitemap_url}: {e}" ) def _ensure_absolute_url(source_url: str, maybe_relative_url: str) -> str: if not urlparse(maybe_relative_url).netloc: return urljoin(source_url, maybe_relative_url) return maybe_relative_url def _ensure_valid_url(url: str) -> str: if "://" not in url: return "https://" + url return url def _read_urls_file(location: str) -> list[str]: with open(location, "r") as f: urls = [_ensure_valid_url(line.strip()) for line in f if line.strip()] return urls def _get_datetime_from_last_modified_header(last_modified: str) -> datetime | None: try: return datetime.strptime(last_modified, "%a, %d %b %Y %H:%M:%S %Z").replace( tzinfo=timezone.utc ) except (ValueError, TypeError): return None def _handle_cookies(context: BrowserContext, url: str) -> None: """Handle cookies for the given URL to help with bot detection""" try: # Parse the URL to get the domain parsed_url = urlparse(url) domain = parsed_url.netloc # Add some common cookies that might help with bot detection cookies: list[dict[str, str]] = [ { "name": "cookieconsent", "value": "accepted", "domain": domain, "path": "/", }, { "name": "consent", "value": "true", "domain": domain, "path": "/", }, { "name": "session", "value": "random_session_id", "domain": domain, "path": "/", }, ] # Add cookies to the context for cookie in cookies: try: context.add_cookies([cookie]) # type: ignore except Exception as e: logger.debug(f"Failed to add cookie {cookie['name']} for {domain}: {e}") except Exception: logger.exception( f"Unexpected error while handling cookies for Web Connector with URL {url}" ) class WebConnector(LoadConnector): MAX_RETRIES = 3 def __init__( self, base_url: str, # Can't change this without 
disrupting existing users web_connector_type: str = WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value, mintlify_cleanup: bool = True, # Mostly ok to apply to other websites as well batch_size: int = INDEX_BATCH_SIZE, scroll_before_scraping: bool = False, **kwargs: Any, # noqa: ARG002 ) -> None: self.mintlify_cleanup = mintlify_cleanup self.batch_size = batch_size self.recursive = False self.scroll_before_scraping = scroll_before_scraping self.web_connector_type = web_connector_type if web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value: self.recursive = True self.to_visit_list = [_ensure_valid_url(base_url)] return elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SINGLE.value: self.to_visit_list = [_ensure_valid_url(base_url)] elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value: self.to_visit_list = extract_urls_from_sitemap(_ensure_valid_url(base_url)) elif web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.UPLOAD.value: # Explicitly check if running in multi-tenant mode to prevent potential security risks if MULTI_TENANT: raise ValueError( "Upload input for web connector is not supported in cloud environments" ) logger.warning( "This is not a UI-supported Web Connector flow; are you sure you want to do this?"
) self.to_visit_list = _read_urls_file(base_url) else: raise ValueError( f"Invalid Web Connector Config '{web_connector_type}', must choose a valid type from: {', '.join(setting.value for setting in WEB_CONNECTOR_VALID_SETTINGS)}" ) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: if credentials: logger.warning("Unexpected credentials provided for Web Connector") return None def _do_scrape( self, index: int, initial_url: str, session_ctx: ScrapeSessionContext, ) -> ScrapeResult: """Returns a ScrapeResult object with a doc and retry flag.""" if session_ctx.playwright is None: raise RuntimeError("session_ctx.playwright is None") if session_ctx.playwright_context is None: raise RuntimeError("session_ctx.playwright_context is None") result = ScrapeResult() # Handle cookies for the URL _handle_cookies(session_ctx.playwright_context, initial_url) # First do a HEAD request to check content type without downloading the entire content head_response = requests.head( initial_url, headers=DEFAULT_HEADERS, allow_redirects=True ) content_type = head_response.headers.get("content-type") is_pdf = is_pdf_resource(initial_url, content_type) if is_pdf: # PDF files are not checked for links response = requests.get(initial_url, headers=DEFAULT_HEADERS) page_text, metadata = extract_pdf_text(response.content) last_modified = response.headers.get("Last-Modified") result.doc = Document( id=initial_url, sections=[TextSection(link=initial_url, text=page_text)], source=DocumentSource.WEB, semantic_identifier=initial_url.rstrip("/").split("/")[-1] or initial_url, metadata=metadata, doc_updated_at=( _get_datetime_from_last_modified_header(last_modified) if last_modified else None ), ) return result page = session_ctx.playwright_context.new_page() try: # Use "commit" instead of "domcontentloaded" to avoid hanging on bot-detection pages # that may never fire domcontentloaded. "commit" waits only for navigation to be # committed (response received), then we add a short wait for initial rendering.
page_response = page.goto( initial_url, timeout=30000, # 30 seconds wait_until="commit", # Wait for navigation to commit ) # Give the page a moment to start rendering after navigation commits. # Allows CloudFlare and other bot-detection challenges to complete. page.wait_for_timeout(PAGE_RENDER_TIMEOUT_MS) # Wait for network activity to settle so SPAs that fetch content # asynchronously after the initial JS bundle have time to render. try: # Bounded by a timeout because long-polling, websockets, etc. may never go idle. page.wait_for_load_state("networkidle", timeout=PAGE_RENDER_TIMEOUT_MS) except TimeoutError: pass last_modified = ( page_response.header_value("Last-Modified") if page_response else None ) final_url = page.url if final_url != initial_url: protected_url_check(final_url) original_url = initial_url initial_url = final_url if initial_url in session_ctx.visited_links: logger.info( f"{index}: {original_url} redirected to {final_url} - already indexed" ) # the page is closed in the finally block below return result logger.info(f"{index}: {original_url} redirected to {final_url}") session_ctx.visited_links.add(initial_url) # Navigation completed; the HTTP status is checked after link extraction if self.scroll_before_scraping: scroll_attempts = 0 previous_height = page.evaluate("document.body.scrollHeight") while scroll_attempts < WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS: page.evaluate("window.scrollTo(0, document.body.scrollHeight)") # Wait for content to load, but catch timeout if page never reaches networkidle # (e.g., CloudFlare protection keeps making requests) try: page.wait_for_load_state( "networkidle", timeout=PAGE_RENDER_TIMEOUT_MS ) except TimeoutError: # If networkidle times out, just give it a moment for content to render time.sleep(1) time.sleep(0.5) # let javascript run new_height = page.evaluate("document.body.scrollHeight") if new_height == previous_height: break # Stop scrolling when no more content is loaded previous_height = new_height scroll_attempts += 1 content = page.content() soup = BeautifulSoup(content, "html.parser") if self.recursive:
internal_links = get_internal_links( session_ctx.base_url, initial_url, soup ) for link in internal_links: if link not in session_ctx.visited_links: session_ctx.to_visit.append(link) if page_response and 400 <= page_response.status < 600: session_ctx.last_error = f"Skipped indexing {initial_url} due to HTTP {page_response.status} response" logger.info(session_ctx.last_error) result.retry = True return result # after this point, we don't need the caller to retry parsed_html = web_html_cleanup(soup, self.mintlify_cleanup) # For websites whose content is rendered inside iframes, the code below # extracts the text from within those iframes. logger.debug( f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}" ) if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text: iframe_count = page.frame_locator("iframe").locator("html").count() if iframe_count > 0: iframe_texts = ( page.frame_locator("iframe").locator("html").all_inner_texts() ) document_text = "\n".join(iframe_texts) # If the cleaned page text is shorter than IFRAME_TEXT_LENGTH_THRESHOLD # characters, replace it with the iframe text; otherwise append the iframe text. if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD: parsed_html.cleaned_text = document_text else: parsed_html.cleaned_text += "\n" + document_text # Sometimes pages with #!
will serve duplicate content # There are also just other ways this can happen hashed_text = hash((parsed_html.title, parsed_html.cleaned_text)) if hashed_text in session_ctx.content_hashes: logger.info( f"{index}: Skipping duplicate title + content for {initial_url}" ) return result session_ctx.content_hashes.add(hashed_text) result.doc = Document( id=initial_url, sections=[TextSection(link=initial_url, text=parsed_html.cleaned_text)], source=DocumentSource.WEB, semantic_identifier=parsed_html.title or initial_url, metadata={}, doc_updated_at=( _get_datetime_from_last_modified_header(last_modified) if last_modified else None ), ) finally: page.close() return result def load_from_state(self) -> GenerateDocumentsOutput: """Traverses through all pages found on the website and converts them into documents""" if not self.to_visit_list: raise ValueError("No URLs to visit") base_url = self.to_visit_list[0] # For the recursive case check_internet_connection(base_url) # make sure we can connect to the base url session_ctx = ScrapeSessionContext(base_url, self.to_visit_list) session_ctx.initialize() while session_ctx.to_visit: initial_url = session_ctx.to_visit.pop() if initial_url in session_ctx.visited_links: continue session_ctx.visited_links.add(initial_url) try: protected_url_check(initial_url) except Exception as e: session_ctx.last_error = f"Invalid URL {initial_url} due to {e}" logger.warning(session_ctx.last_error) continue index = len(session_ctx.visited_links) logger.info(f"{index}: Visiting {initial_url}") # Add retry mechanism with exponential backoff retry_count = 0 while retry_count < self.MAX_RETRIES: if retry_count > 0: # Add a random delay between retries (exponential backoff) delay = min(2**retry_count + random.uniform(0, 1), 10) logger.info( f"Retry {retry_count}/{self.MAX_RETRIES} for {initial_url} after {delay:.2f}s delay" ) time.sleep(delay) try: result = self._do_scrape(index, initial_url, session_ctx) if result.retry: continue if result.doc: 
session_ctx.doc_batch.append(result.doc) except Exception as e: session_ctx.last_error = f"Failed to fetch '{initial_url}': {e}" logger.exception(session_ctx.last_error) session_ctx.initialize() continue finally: retry_count += 1 break # success / don't retry if len(session_ctx.doc_batch) >= self.batch_size: session_ctx.initialize() session_ctx.at_least_one_doc = True yield session_ctx.doc_batch session_ctx.doc_batch = [] if session_ctx.doc_batch: session_ctx.stop() session_ctx.at_least_one_doc = True yield session_ctx.doc_batch if not session_ctx.at_least_one_doc: if session_ctx.last_error: raise RuntimeError(session_ctx.last_error) raise RuntimeError("No valid pages found.") session_ctx.stop() def validate_connector_settings(self) -> None: # Make sure we have at least one valid URL to check if not self.to_visit_list: raise ConnectorValidationError( "No URL configured. Please provide at least one valid URL." ) if ( self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.SITEMAP.value or self.web_connector_type == WEB_CONNECTOR_VALID_SETTINGS.RECURSIVE.value ): return None # We'll just test the first URL for connectivity and correctness test_url = self.to_visit_list[0] # Check that the URL is allowed and well-formed try: protected_url_check(test_url) except ValueError as e: raise ConnectorValidationError( f"Protected URL check failed for '{test_url}': {e}" ) except ConnectionError as e: # Typically DNS or other network issues raise ConnectorValidationError(str(e)) # Make a quick request to see if we get a valid response try: check_internet_connection(test_url) except Exception as e: err_str = str(e) if "401" in err_str: raise CredentialExpiredError( f"Unauthorized access to '{test_url}': {e}" ) elif "403" in err_str: raise InsufficientPermissionsError( f"Forbidden access to '{test_url}': {e}" ) elif "404" in err_str: raise ConnectorValidationError(f"Page not found for '{test_url}': {e}") elif "Max retries exceeded" in err_str and "NameResolutionError" in err_str: 
raise ConnectorValidationError( f"Unable to resolve hostname for '{test_url}'. Please check the URL and your internet connection." ) else: # Could be a 5xx or another error, treat as unexpected raise UnexpectedValidationError( f"Unexpected error validating '{test_url}': {e}" ) if __name__ == "__main__": connector = WebConnector("https://docs.onyx.app/") document_batches = connector.load_from_state() print(next(document_batches)) ================================================ FILE: backend/onyx/connectors/wikipedia/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/wikipedia/connector.py ================================================ from typing import ClassVar from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.mediawiki import wiki class WikipediaConnector(wiki.MediaWikiConnector): """Connector for Wikipedia.""" document_source_type: ClassVar[DocumentSource] = DocumentSource.WIKIPEDIA def __init__( self, categories: list[str], pages: list[str], recurse_depth: int, language_code: str = "en", batch_size: int = INDEX_BATCH_SIZE, ) -> None: super().__init__( hostname="wikipedia.org", categories=categories, pages=pages, recurse_depth=recurse_depth, language_code=language_code, batch_size=batch_size, ) ================================================ FILE: backend/onyx/connectors/xenforo/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/xenforo/connector.py ================================================ """ This is the XenforoConnector class. It is used to connect to a Xenforo forum and load or update documents from the forum. To use this class, you need to provide the URL of the Xenforo forum board you want to connect to when creating an instance of the class. 
The URL should be a string that starts with 'http://' or 'https://', followed by the domain name of the forum, followed by the path to a board, forum, or thread. For example: base_url = 'https://www.example.com/forum/boards/some-topic/' The `load_from_state` method takes no arguments and yields the documents scraped from the forum. On the first run in a process every post is indexed; on later runs only posts from the past day are included. """ import re from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from urllib.parse import urlparse import pytz import requests from bs4 import BeautifulSoup from bs4 import Tag from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import datetime_to_utc from onyx.connectors.interfaces import GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.models import BasicExpertInfo from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.utils.logger import setup_logger logger = setup_logger() def get_title(soup: BeautifulSoup) -> str: el = soup.find("h1", "p-title-value") if not el: return "" title = el.text for char in (";", ":", "!", "*", "/", "\\", "?", '"', "<", ">", "|"): title = title.replace(char, "_") return title def get_pages(soup: BeautifulSoup, url: str) -> list[str]: page_tags = soup.select("li.pageNav-page") page_numbers = [] for button in page_tags: if re.match(r"^\d+$", button.text): page_numbers.append(button.text) max_pages = int(max(page_numbers, key=int)) if page_numbers else 1 all_pages = [] for x in range(1, max_pages + 1): all_pages.append(f"{url}page-{x}") return all_pages def parse_post_date(post_element: BeautifulSoup) -> datetime: el = post_element.find("time") if not isinstance(el, Tag) or "datetime" not in el.attrs: return datetime.fromtimestamp(0, tz=timezone.utc) date_value
= el["datetime"] # Ensure date_value is a string (if it's a list, take the first element) if isinstance(date_value, list): date_value = date_value[0] post_date = datetime.strptime(date_value, "%Y-%m-%dT%H:%M:%S%z") return datetime_to_utc(post_date) def scrape_page_posts( soup: BeautifulSoup, page_index: int, url: str, initial_run: bool, start_time: datetime, ) -> list: title = get_title(soup) documents = [] for post in soup.find_all("div", class_="message-inner"): post_date = parse_post_date(post) if initial_run or post_date > start_time: el = post.find("div", class_="bbWrapper") if not el: continue post_text = el.get_text(strip=True) + "\n" author_tag = post.find("a", class_="username") if author_tag is None: author_tag = post.find("span", class_="username") author = author_tag.get_text(strip=True) if author_tag else "Deleted author" formatted_time = post_date.strftime("%Y-%m-%d %H:%M:%S") # TODO: if a caller calls this for each page of a thread, it may see the # same post multiple times if there is a sticky post # that appears on each page of a thread. # it's important to generate unique doc id's, so page index is part of the # id. We may want to de-dupe this stuff inside the indexing service. 
document = Document( id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}", sections=[TextSection(link=url, text=post_text)], title=title, source=DocumentSource.XENFORO, semantic_identifier=title, primary_owners=[BasicExpertInfo(display_name=author)], metadata={ "type": "post", "author": author, "time": formatted_time, }, doc_updated_at=post_date, ) documents.append(document) return documents class XenforoConnector(LoadConnector): # Class variable to track if the connector has been run before has_been_run_before = False def __init__(self, base_url: str) -> None: self.base_url = base_url self.initial_run = not XenforoConnector.has_been_run_before self.start = datetime.utcnow().replace(tzinfo=pytz.utc) - timedelta(days=1) self.cookies: dict[str, str] = {} # mimic user browser to avoid being blocked by the website (see: https://www.useragents.me/) self.headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) " "AppleWebKit/537.36 (KHTML, like Gecko) " "Chrome/121.0.0.0 Safari/537.36" } def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: if credentials: logger.warning("Unexpected credentials provided for Xenforo Connector") return None def load_from_state(self) -> GenerateDocumentsOutput: # Standardize URL to always end in /. if self.base_url[-1] != "/": self.base_url += "/" # Remove all extra parameters from the end such as page, post. matches = ("threads/", "boards/", "forums/") for each in matches: if each in self.base_url: try: self.base_url = self.base_url[ 0 : self.base_url.index( "/", self.base_url.index(each) + len(each) ) + 1 ] except ValueError: pass doc_batch: list[Document | HierarchyNode] = [] all_threads = [] # If the URL contains "boards/" or "forums/", find all threads. 
if "boards/" in self.base_url or "forums/" in self.base_url: pages = get_pages(self.requestsite(self.base_url), self.base_url) # Get all pages on thread_list_page for pre_count, thread_list_page in enumerate(pages, start=1): logger.info( f"Getting pages from thread_list_page.. Current: {pre_count}/{len(pages)}\r" ) all_threads += self.get_threads(thread_list_page) # If the URL contains "threads/", add the thread to the list. elif "threads/" in self.base_url: all_threads.append(self.base_url) # Process all threads for thread_count, thread_url in enumerate(all_threads, start=1): soup = self.requestsite(thread_url) if soup is None: logger.error(f"Failed to load page: {thread_url}") continue pages = get_pages(soup, thread_url) # Getting all pages for all threads for page_index, page in enumerate(pages, start=1): logger.info( f"Progress: Page {page_index}/{len(pages)} - Thread {thread_count}/{len(all_threads)}\r" ) soup_page = self.requestsite(page) doc_batch.extend( scrape_page_posts( soup_page, page_index, thread_url, self.initial_run, self.start ) ) if doc_batch: yield doc_batch # Mark the initial run finished after all threads and pages have been processed XenforoConnector.has_been_run_before = True def get_threads(self, url: str) -> list[str]: soup = self.requestsite(url) thread_tags = soup.find_all(class_="structItem-title") base_url = "{uri.scheme}://{uri.netloc}".format(uri=urlparse(url)) threads = [] for x in thread_tags: y = x.find_all(href=True) for element in y: link = element["href"] if "threads/" in link: stripped = link[0 : link.rfind("/") + 1] if base_url + stripped not in threads: threads.append(base_url + stripped) return threads def requestsite(self, url: str) -> BeautifulSoup: try: response = requests.get( url, cookies=self.cookies, headers=self.headers, timeout=10 ) if response.status_code != 200: logger.error( f"<{url}> Request Error: {response.status_code} - {response.reason}" ) return BeautifulSoup(response.text, "html.parser") except
requests.exceptions.Timeout: logger.error(f"Timed out requesting {url}") except Exception as e: logger.error(f"Error on {url}") logger.exception(e) return BeautifulSoup("", "html.parser") if __name__ == "__main__": connector = XenforoConnector( # base_url="https://cassiopaea.org/forum/threads/how-to-change-your-emotional-state.41381/" base_url="https://xenforo.com/community/threads/whats-new-with-enhanced-search-resource-manager-and-media-gallery-in-xenforo-2-3.220935/" ) document_batches = connector.load_from_state() print(next(document_batches)) ================================================ FILE: backend/onyx/connectors/zendesk/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/zendesk/connector.py ================================================ import copy import time from collections.abc import Callable from collections.abc import Iterator from typing import Any from typing import cast import requests from pydantic import BaseModel from requests.exceptions import HTTPError from typing_extensions import override from onyx.configs.app_configs import ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS from onyx.configs.constants import DocumentSource from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( time_str_to_utc, ) from onyx.connectors.cross_connector_utils.rate_limit_wrapper import ( rate_limit_builder, ) from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.exceptions import CredentialExpiredError from onyx.connectors.exceptions import InsufficientPermissionsError from onyx.connectors.interfaces import CheckpointedConnector from onyx.connectors.interfaces import CheckpointOutput from onyx.connectors.interfaces import ConnectorFailure from onyx.connectors.interfaces import GenerateSlimDocumentOutput from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.interfaces import SlimConnectorWithPermSync from onyx.connectors.models
import BasicExpertInfo from onyx.connectors.models import ConnectorCheckpoint from onyx.connectors.models import Document from onyx.connectors.models import DocumentFailure from onyx.connectors.models import HierarchyNode from onyx.connectors.models import SlimDocument from onyx.connectors.models import TextSection from onyx.file_processing.html_utils import parse_html_page_basic from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.utils.retry_wrapper import retry_builder MAX_PAGE_SIZE = 30 # Zendesk API maximum MAX_AUTHOR_MAP_SIZE = 50_000 # Reset author map cache if it gets too large _SLIM_BATCH_SIZE = 1000 class ZendeskCredentialsNotSetUpError(PermissionError): def __init__(self) -> None: super().__init__( "Zendesk Credentials are not set up, was load_credentials called?" ) class ZendeskClient: def __init__( self, subdomain: str, email: str, token: str, calls_per_minute: int | None = None, ): self.base_url = f"https://{subdomain}.zendesk.com/api/v2" self.auth = (f"{email}/token", token) self.make_request = request_with_rate_limit(self, calls_per_minute) def request_with_rate_limit( client: ZendeskClient, max_calls_per_minute: int | None = None ) -> Callable[[str, dict[str, Any]], dict[str, Any]]: @retry_builder() @( rate_limit_builder(max_calls=max_calls_per_minute, period=60) if max_calls_per_minute else lambda x: x ) def make_request(endpoint: str, params: dict[str, Any]) -> dict[str, Any]: response = requests.get( f"{client.base_url}/{endpoint}", auth=client.auth, params=params ) if response.status_code == 429: retry_after = response.headers.get("Retry-After") if retry_after is not None: # Sleep for the duration indicated by the Retry-After header time.sleep(int(retry_after)) elif ( response.status_code == 403 and response.json().get("error") == "SupportProductInactive" ): return response.json() response.raise_for_status() return response.json() return make_request class ZendeskPageResponse(BaseModel): data: list[dict[str, Any]] 
meta: dict[str, Any] has_more: bool def _get_content_tag_mapping(client: ZendeskClient) -> dict[str, str]: content_tags: dict[str, str] = {} params = {"page[size]": MAX_PAGE_SIZE} try: while True: data = client.make_request("guide/content_tags", params) for tag in data.get("records", []): content_tags[tag["id"]] = tag["name"] # Check if there are more pages if data.get("meta", {}).get("has_more", False): params["page[after]"] = data["meta"]["after_cursor"] else: break return content_tags except Exception as e: raise Exception(f"Error fetching content tags: {str(e)}") def _get_articles( client: ZendeskClient, start_time: int | None = None, page_size: int = MAX_PAGE_SIZE ) -> Iterator[dict[str, Any]]: params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"} if start_time is not None: params["start_time"] = start_time while True: data = client.make_request("help_center/articles", params) for article in data["articles"]: yield article if not data.get("meta", {}).get("has_more"): break params["page[after]"] = data["meta"]["after_cursor"] def _get_article_page( client: ZendeskClient, start_time: int | None = None, after_cursor: str | None = None, page_size: int = MAX_PAGE_SIZE, ) -> ZendeskPageResponse: params = {"page[size]": page_size, "sort_by": "updated_at", "sort_order": "asc"} if start_time is not None: params["start_time"] = start_time if after_cursor is not None: params["page[after]"] = after_cursor data = client.make_request("help_center/articles", params) return ZendeskPageResponse( data=data["articles"], meta=data["meta"], has_more=bool(data["meta"].get("has_more", False)), ) def _get_tickets( client: ZendeskClient, start_time: int | None = None ) -> Iterator[dict[str, Any]]: params = {"start_time": start_time or 0} while True: data = client.make_request("incremental/tickets.json", params) for ticket in data["tickets"]: yield ticket if not data.get("end_of_stream", False): params["start_time"] = data["end_time"] else: break # TODO: 
maybe these don't need to be their own functions? def _get_tickets_page( client: ZendeskClient, start_time: int | None = None ) -> ZendeskPageResponse: params = {"start_time": start_time or 0} # NOTE: for some reason zendesk doesn't seem to be respecting the start_time param # in my local testing with very few tickets. We'll look into it if this becomes an # issue in larger deployments data = client.make_request("incremental/tickets.json", params) if data.get("error") == "SupportProductInactive": raise ValueError( "Zendesk Support Product is not active for this account; there are no tickets to index" ) return ZendeskPageResponse( data=data["tickets"], meta={"end_time": data["end_time"]}, has_more=not bool(data.get("end_of_stream", False)), ) def _fetch_author( client: ZendeskClient, author_id: str | int ) -> BasicExpertInfo | None: # Skip fetching if author_id is invalid # cast to str to avoid issues with zendesk changing their types if not author_id or str(author_id) == "-1": return None try: author_data = client.make_request(f"users/{author_id}", {}) user = author_data.get("user") return ( BasicExpertInfo(display_name=user.get("name"), email=user.get("email")) if user and user.get("name") and user.get("email") else None ) except requests.exceptions.HTTPError: # Treat API errors as an unknown author rather than failing the document return None def _article_to_document( article: dict[str, Any], content_tags: dict[str, str], author_map: dict[str, BasicExpertInfo], client: ZendeskClient, ) -> tuple[dict[str, BasicExpertInfo] | None, Document]: author_id = article.get("author_id") if not author_id: author = None else: author = ( author_map.get(author_id) if author_id in author_map else _fetch_author(client, author_id) ) new_author_mapping = {author_id: author} if author_id and author else None updated_at = article.get("updated_at") update_time = time_str_to_utc(updated_at) if updated_at else None # Build metadata metadata: dict[str, str | list[str]] = { "labels": [str(label) for label in article.get("label_names",
[]) if label], "content_tags": [ content_tags[tag_id] for tag_id in article.get("content_tag_ids", []) if tag_id in content_tags ], } # Remove empty values metadata = {k: v for k, v in metadata.items() if v} return new_author_mapping, Document( id=f"article:{article['id']}", sections=[ TextSection( link=cast(str, article.get("html_url")), text=parse_html_page_basic(article["body"]), ) ], source=DocumentSource.ZENDESK, semantic_identifier=article["title"], doc_updated_at=update_time, primary_owners=[author] if author else None, metadata=metadata, ) def _get_comment_text( comment: dict[str, Any], author_map: dict[str, BasicExpertInfo], client: ZendeskClient, ) -> tuple[dict[str, BasicExpertInfo] | None, str]: author_id = comment.get("author_id") if not author_id: author = None else: author = ( author_map.get(author_id) if author_id in author_map else _fetch_author(client, author_id) ) new_author_mapping = {author_id: author} if author_id and author else None comment_text = f"Comment{' by ' + author.display_name if author and author.display_name else ''}" comment_text += f"{' at ' + comment['created_at'] if comment.get('created_at') else ''}:\n{comment['body']}" return new_author_mapping, comment_text def _ticket_to_document( ticket: dict[str, Any], author_map: dict[str, BasicExpertInfo], client: ZendeskClient, default_subdomain: str, ) -> tuple[dict[str, BasicExpertInfo] | None, Document]: submitter_id = ticket.get("submitter") if not submitter_id: submitter = None else: submitter = ( author_map.get(submitter_id) if submitter_id in author_map else _fetch_author(client, submitter_id) ) new_author_mapping = ( {submitter_id: submitter} if submitter_id and submitter else None ) updated_at = ticket.get("updated_at") update_time = time_str_to_utc(updated_at) if updated_at else None metadata: dict[str, str | list[str]] = {} if status := ticket.get("status"): metadata["status"] = status if priority := ticket.get("priority"): metadata["priority"] = priority if tags := 
ticket.get("tags"): metadata["tags"] = tags if ticket_type := ticket.get("type"): metadata["ticket_type"] = ticket_type # Fetch comments for the ticket comments_data = client.make_request(f"tickets/{ticket.get('id')}/comments", {}) comments = comments_data.get("comments", []) comment_texts = [] for comment in comments: # Use a separate name so the submitter mapping returned below is not overwritten; # comment authors are merged into author_map directly. comment_author_mapping, comment_text = _get_comment_text( comment, author_map, client ) if comment_author_mapping: author_map.update(comment_author_mapping) comment_texts.append(comment_text) comments_text = "\n\n".join(comment_texts) subject = ticket.get("subject") full_text = f"Ticket Subject:\n{subject}\n\nComments:\n{comments_text}" ticket_url = ticket.get("url") subdomain = ( ticket_url.split("//")[1].split(".zendesk.com")[0] if ticket_url else default_subdomain ) ticket_display_url = ( f"https://{subdomain}.zendesk.com/agent/tickets/{ticket.get('id')}" ) return new_author_mapping, Document( id=f"zendesk_ticket_{ticket['id']}", sections=[TextSection(link=ticket_display_url, text=full_text)], source=DocumentSource.ZENDESK, semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}", doc_updated_at=update_time, primary_owners=[submitter] if submitter else None, metadata=metadata, ) class ZendeskConnectorCheckpoint(ConnectorCheckpoint): # We use cursor-based paginated retrieval for articles after_cursor_articles: str | None # We use timestamp-based paginated retrieval for tickets next_start_time_tickets: int | None cached_author_map: dict[str, BasicExpertInfo] | None cached_content_tags: dict[str, str] | None class ZendeskConnector( SlimConnectorWithPermSync, CheckpointedConnector[ZendeskConnectorCheckpoint] ): def __init__( self, content_type: str = "articles", calls_per_minute: int | None = None, ) -> None: self.content_type = content_type self.subdomain = "" # Content tags are fetched on the first checkpoint run and cached in the checkpoint self.content_tags: dict[str, str] = {} self.calls_per_minute = calls_per_minute def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] |
None: # The "subdomain" credential may hold the full URL, so strip it down to the subdomain subdomain = ( credentials["zendesk_subdomain"] .replace("https://", "") .split(".zendesk.com")[0] ) self.subdomain = subdomain self.client = ZendeskClient( subdomain, credentials["zendesk_email"], credentials["zendesk_token"], calls_per_minute=self.calls_per_minute, ) return None @override def load_from_checkpoint( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch, checkpoint: ZendeskConnectorCheckpoint, ) -> CheckpointOutput[ZendeskConnectorCheckpoint]: if self.client is None: raise ZendeskCredentialsNotSetUpError() if checkpoint.cached_content_tags is None: checkpoint.cached_content_tags = _get_content_tag_mapping(self.client) return checkpoint # save the content tags to the checkpoint self.content_tags = checkpoint.cached_content_tags if self.content_type == "articles": checkpoint = yield from self._retrieve_articles(start, end, checkpoint) return checkpoint elif self.content_type == "tickets": checkpoint = yield from self._retrieve_tickets(start, end, checkpoint) return checkpoint else: raise ValueError(f"Unsupported content_type: {self.content_type}") def _retrieve_articles( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, checkpoint: ZendeskConnectorCheckpoint, ) -> CheckpointOutput[ZendeskConnectorCheckpoint]: checkpoint = copy.deepcopy(checkpoint) # This one is built on the fly as there may be many more authors than tags author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {} after_cursor = checkpoint.after_cursor_articles doc_batch: list[Document] = [] response = _get_article_page( self.client, start_time=int(start) if start else None, after_cursor=after_cursor, ) articles = response.data has_more = response.has_more after_cursor = response.meta.get("after_cursor") for article in articles: if ( article.get("body") is None or article.get("draft") or any( label in ZENDESK_CONNECTOR_SKIP_ARTICLE_LABELS for label in article.get("label_names", []) ) ):
continue try: new_author_map, document = _article_to_document( article, self.content_tags, author_map, self.client ) except Exception as e: yield ConnectorFailure( failed_document=DocumentFailure( document_id=f"{article.get('id')}", document_link=article.get("html_url", ""), ), failure_message=str(e), exception=e, ) continue if new_author_map: author_map.update(new_author_map) doc_batch.append(document) if not has_more: yield from doc_batch checkpoint.has_more = False return checkpoint # Sometimes no documents are retrieved, but the cursor # is still updated so the connector makes progress. yield from doc_batch checkpoint.after_cursor_articles = after_cursor last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None checkpoint.has_more = bool( end is None or last_doc_updated_at is None or last_doc_updated_at.timestamp() <= end ) checkpoint.cached_author_map = ( author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None ) return checkpoint def _retrieve_tickets( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, checkpoint: ZendeskConnectorCheckpoint, ) -> CheckpointOutput[ZendeskConnectorCheckpoint]: checkpoint = copy.deepcopy(checkpoint) if self.client is None: raise ZendeskCredentialsNotSetUpError() author_map: dict[str, BasicExpertInfo] = checkpoint.cached_author_map or {} doc_batch: list[Document] = [] next_start_time = int(checkpoint.next_start_time_tickets or start or 0) ticket_response = _get_tickets_page(self.client, start_time=next_start_time) tickets = ticket_response.data has_more = ticket_response.has_more next_start_time = ticket_response.meta["end_time"] for ticket in tickets: if ticket.get("status") == "deleted": continue try: new_author_map, document = _ticket_to_document( ticket=ticket, author_map=author_map, client=self.client, default_subdomain=self.subdomain, ) except Exception as e: yield ConnectorFailure( failed_document=DocumentFailure( document_id=f"{ticket.get('id')}", 
document_link=ticket.get("url", ""), ), failure_message=str(e), exception=e, ) continue if new_author_map: author_map.update(new_author_map) doc_batch.append(document) if not has_more: yield from doc_batch checkpoint.has_more = False return checkpoint yield from doc_batch checkpoint.next_start_time_tickets = next_start_time last_doc_updated_at = doc_batch[-1].doc_updated_at if doc_batch else None checkpoint.has_more = bool( end is None or last_doc_updated_at is None or last_doc_updated_at.timestamp() <= end ) checkpoint.cached_author_map = ( author_map if len(author_map) <= MAX_AUTHOR_MAP_SIZE else None ) return checkpoint def retrieve_all_slim_docs_perm_sync( self, start: SecondsSinceUnixEpoch | None = None, end: SecondsSinceUnixEpoch | None = None, # noqa: ARG002 callback: IndexingHeartbeatInterface | None = None, # noqa: ARG002 ) -> GenerateSlimDocumentOutput: slim_doc_batch: list[SlimDocument | HierarchyNode] = [] if self.content_type == "articles": articles = _get_articles( self.client, start_time=int(start) if start else None ) for article in articles: slim_doc_batch.append( SlimDocument( id=f"article:{article['id']}", ) ) if len(slim_doc_batch) >= _SLIM_BATCH_SIZE: yield slim_doc_batch slim_doc_batch = [] elif self.content_type == "tickets": tickets = _get_tickets( self.client, start_time=int(start) if start else None ) for ticket in tickets: slim_doc_batch.append( SlimDocument( id=f"zendesk_ticket_{ticket['id']}", ) ) if len(slim_doc_batch) >= _SLIM_BATCH_SIZE: yield slim_doc_batch slim_doc_batch = [] else: raise ValueError(f"Unsupported content_type: {self.content_type}") if slim_doc_batch: yield slim_doc_batch @override def validate_connector_settings(self) -> None: if self.client is None: raise ZendeskCredentialsNotSetUpError() try: _get_article_page(self.client, start_time=0) except HTTPError as e: # Check for HTTP status codes if e.response.status_code == 401: raise CredentialExpiredError( "Your Zendesk credentials appear to be invalid or expired (HTTP 
401)." ) from e elif e.response.status_code == 403: raise InsufficientPermissionsError( "Your Zendesk token does not have sufficient permissions (HTTP 403)." ) from e elif e.response.status_code == 404: raise ConnectorValidationError( "Zendesk resource not found (HTTP 404)." ) from e else: raise ConnectorValidationError( f"Unexpected Zendesk error (status={e.response.status_code}): {e}" ) from e @override def validate_checkpoint_json( self, checkpoint_json: str ) -> ZendeskConnectorCheckpoint: return ZendeskConnectorCheckpoint.model_validate_json(checkpoint_json) @override def build_dummy_checkpoint(self) -> ZendeskConnectorCheckpoint: return ZendeskConnectorCheckpoint( after_cursor_articles=None, next_start_time_tickets=None, cached_author_map=None, cached_content_tags=None, has_more=True, ) if __name__ == "__main__": import os connector = ZendeskConnector() connector.load_credentials( { "zendesk_subdomain": os.environ["ZENDESK_SUBDOMAIN"], "zendesk_email": os.environ["ZENDESK_EMAIL"], "zendesk_token": os.environ["ZENDESK_TOKEN"], } ) current = time.time() one_day_ago = current - 24 * 60 * 60 # 1 day document_batches = connector.load_from_checkpoint( one_day_ago, current, connector.build_dummy_checkpoint(), ) print(next(document_batches)) ================================================ FILE: backend/onyx/connectors/zulip/__init__.py ================================================ ================================================ FILE: backend/onyx/connectors/zulip/connector.py ================================================ import os import tempfile import urllib.parse from collections.abc import Generator from datetime import datetime from datetime import timezone from typing import Any from typing import Dict from typing import List from typing import Tuple from typing import Union from zulip import Client from onyx.configs.app_configs import INDEX_BATCH_SIZE from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import 
GenerateDocumentsOutput from onyx.connectors.interfaces import LoadConnector from onyx.connectors.interfaces import PollConnector from onyx.connectors.interfaces import SecondsSinceUnixEpoch from onyx.connectors.models import ConnectorMissingCredentialError from onyx.connectors.models import Document from onyx.connectors.models import HierarchyNode from onyx.connectors.models import TextSection from onyx.connectors.zulip.schemas import GetMessagesResponse from onyx.connectors.zulip.schemas import Message from onyx.connectors.zulip.utils import build_search_narrow from onyx.connectors.zulip.utils import call_api from onyx.connectors.zulip.utils import encode_zulip_narrow_operand from onyx.utils.logger import setup_logger # Potential improvements # 1. Group messages by topic, making 1 document per topic per week # 2. Add end date support once https://github.com/zulip/zulip/issues/25436 is solved logger = setup_logger() class ZulipConnector(LoadConnector, PollConnector): def __init__( self, realm_name: str, realm_url: str, batch_size: int = INDEX_BATCH_SIZE ) -> None: self.batch_size = batch_size self.realm_name = realm_name # Clean and normalize the URL realm_url = realm_url.strip().lower() # Remove any trailing slashes realm_url = realm_url.rstrip("/") # Ensure the URL has a scheme if not realm_url.startswith(("http://", "https://")): realm_url = f"https://{realm_url}" try: parsed = urllib.parse.urlparse(realm_url) # Extract the base domain without any paths or ports netloc = parsed.netloc.split(":")[0] # Remove port if present if not netloc: raise ValueError( f"Invalid realm URL format: {realm_url}. URL must include a valid domain name." ) # Always use HTTPS for security self.base_url = f"https://{netloc}" self.client: Client | None = None except Exception as e: raise ValueError( f"Failed to parse Zulip realm URL: {realm_url}. " f"Please provide a URL in the format: domain.com or https://domain.com. 
" f"Error: {str(e)}" ) def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None: contents = credentials["zuliprc_content"] # The input field converts newlines to spaces in the provided # zuliprc file. This reverts them back to newlines. contents_spaces_to_newlines = contents.replace(" ", "\n") # create a temporary zuliprc file # tempfile.tempdir stays None until some tempfile function initializes it, # so resolve the directory with gettempdir(), which never returns None tempdir = tempfile.gettempdir() config_file = os.path.join(tempdir, f"zuliprc-{self.realm_name}") with open(config_file, "w") as f: f.write(contents_spaces_to_newlines) self.client = Client(config_file=config_file) return None def _message_to_narrow_link(self, m: Message) -> str: try: stream_name = m.display_recipient # assume str stream_operand = encode_zulip_narrow_operand(f"{m.stream_id}-{stream_name}") topic_operand = encode_zulip_narrow_operand(m.subject) narrow_link = f"{self.base_url}#narrow/stream/{stream_operand}/topic/{topic_operand}/near/{m.id}" return narrow_link except Exception as e: logger.error(f"Error generating Zulip message link: {e}") # Fallback to a basic link that at least includes the base URL return f"{self.base_url}#narrow/id/{m.id}" def _get_message_batch(self, anchor: str) -> Tuple[bool, List[Message]]: if self.client is None: raise ConnectorMissingCredentialError("Zulip") logger.info(f"Fetching messages starting with anchor={anchor}") request = build_search_narrow( limit=INDEX_BATCH_SIZE, anchor=anchor, apply_md=False ) response = GetMessagesResponse(**call_api(self.client.get_messages, request)) end = False if len(response.messages) == 0 or response.found_oldest: end = True # reverse, so that the last message is the new anchor # and the order is from newest to oldest return end, response.messages[::-1] def _message_to_doc(self, message: Message) -> Document: text = f"{message.sender_full_name}: {message.content}" try: # Convert timestamps to UTC datetime objects post_time = datetime.fromtimestamp(message.timestamp,
tz=timezone.utc) edit_time = ( datetime.fromtimestamp(message.last_edit_timestamp, tz=timezone.utc) if message.last_edit_timestamp is not None else None ) # Use the most recent edit time if available, otherwise use post time doc_time = edit_time if edit_time is not None else post_time except (ValueError, TypeError) as e: logger.warning(f"Failed to parse timestamp for message {message.id}: {e}") post_time = None edit_time = None doc_time = None metadata: Dict[str, Union[str, List[str]]] = { "stream_name": str(message.display_recipient), "topic": str(message.subject), "sender_name": str(message.sender_full_name), "sender_email": str(message.sender_email), "message_timestamp": str(message.timestamp), "message_id": str(message.id), "stream_id": str(message.stream_id), "has_reactions": str(len(message.reactions) > 0), "content_type": str(message.content_type or "text"), } # Always include edit timestamp in metadata when available if edit_time is not None: metadata["edit_timestamp"] = str(message.last_edit_timestamp) return Document( id=f"{message.stream_id}__{message.id}", sections=[ TextSection( link=self._message_to_narrow_link(message), text=text, ) ], source=DocumentSource.ZULIP, semantic_identifier=f"{message.display_recipient} > {message.subject}", metadata=metadata, doc_updated_at=doc_time, # Use most recent edit time or post time ) def _get_docs( self, anchor: str, start: SecondsSinceUnixEpoch | None = None ) -> Generator[Document, None, None]: message: Message | None = None while True: end, message_batch = self._get_message_batch(anchor) for message in message_batch: if start is not None and float(message.timestamp) < start: return yield self._message_to_doc(message) if end or message is None: return # Last message is oldest, use as next anchor anchor = str(message.id) def _poll_source( self, start: SecondsSinceUnixEpoch | None, end: SecondsSinceUnixEpoch | None, # noqa: ARG002 ) -> GenerateDocumentsOutput: # Since Zulip doesn't support searching by timestamp, 
# we have to always start from the newest message # and go backwards. anchor = "newest" docs: list[Document | HierarchyNode] = [] for doc in self._get_docs(anchor=anchor, start=start): docs.append(doc) if len(docs) == self.batch_size: yield docs docs = [] if docs: yield docs def load_from_state(self) -> GenerateDocumentsOutput: return self._poll_source(start=None, end=None) def poll_source( self, start: SecondsSinceUnixEpoch, end: SecondsSinceUnixEpoch ) -> GenerateDocumentsOutput: return self._poll_source(start, end) ================================================ FILE: backend/onyx/connectors/zulip/schemas.py ================================================ from typing import Any from typing import List from typing import Optional from typing import Union from pydantic import BaseModel from pydantic import Field class Message(BaseModel): id: int sender_id: int content: str recipient_id: int timestamp: int client: str is_me_message: bool sender_full_name: str sender_email: str sender_realm_str: str subject: str topic_links: Optional[List[Any]] = None last_edit_timestamp: Optional[int] = None edit_history: Any = None reactions: List[Any] submessages: List[Any] flags: List[str] = Field(default_factory=list) display_recipient: Optional[str] = None type: Optional[str] = None stream_id: int avatar_url: Optional[str] content_type: Optional[str] rendered_content: Optional[str] = None class GetMessagesResponse(BaseModel): result: str msg: str found_anchor: Optional[bool] = None found_oldest: Optional[bool] = None found_newest: Optional[bool] = None history_limited: Optional[bool] = None anchor: Optional[Union[str, int]] = None messages: List[Message] = Field(default_factory=list) ================================================ FILE: backend/onyx/connectors/zulip/utils.py ================================================ import time from collections.abc import Callable from typing import Any from typing import Dict from typing import Optional from urllib.parse import 
quote from onyx.utils.logger import setup_logger logger = setup_logger() class ZulipAPIError(Exception): def __init__(self, code: Any = None, msg: str | None = None) -> None: self.code = code self.msg = msg def __str__(self) -> str: # Parenthesize the conditional so the base message is kept when a code is present return ( f"Error occurred during Zulip API call: {self.msg}" + ("" if self.code is None else f" ({self.code})") ) class ZulipHTTPError(ZulipAPIError): def __init__(self, msg: str | None = None, status_code: Any = None) -> None: super().__init__(code=None, msg=msg) self.status_code = status_code def __str__(self) -> str: return f"HTTP error {self.status_code} occurred during Zulip API call" def __call_with_retry(fun: Callable, *args: Any, **kwargs: Any) -> Dict[str, Any]: result = fun(*args, **kwargs) if result.get("result") == "error": if result.get("code") == "RATE_LIMIT_HIT": retry_after = float(result["retry-after"]) + 1 logger.warning(f"Rate limit hit, retrying after {retry_after} seconds") time.sleep(retry_after) # Forward kwargs too, so the retried call matches the original one return __call_with_retry(fun, *args, **kwargs) return result def __raise_if_error(response: dict[str, Any]) -> None: if response.get("result") == "error": raise ZulipAPIError( code=response.get("code"), msg=response.get("msg"), ) elif response.get("result") == "http-error": raise ZulipHTTPError( msg=response.get("msg"), status_code=response.get("status_code") ) def call_api(fun: Callable, *args: Any, **kwargs: Any) -> Dict[str, Any]: response = __call_with_retry(fun, *args, **kwargs) __raise_if_error(response) return response def build_search_narrow( *, stream: Optional[str] = None, topic: Optional[str] = None, limit: int = 100, content: Optional[str] = None, apply_md: bool = False, anchor: str = "newest", ) -> Dict[str, Any]: narrow_filters = [] if stream: narrow_filters.append({"operator": "stream", "operand": stream}) if topic: narrow_filters.append({"operator": "topic", "operand": topic}) if content: narrow_filters.append({"operator": "has", "operand": content}) if not stream and not topic and not content:
narrow_filters.append({"operator": "streams", "operand": "public"}) narrow = { "anchor": anchor, "num_before": limit, "num_after": 0, "narrow": narrow_filters, } narrow["apply_markdown"] = apply_md return narrow def encode_zulip_narrow_operand(value: str) -> str: # like https://github.com/zulip/zulip/blob/1577662a6/static/js/hash_util.js#L18-L25 # safe characters necessary to make Python match Javascript's escaping behaviour, # see: https://stackoverflow.com/a/74439601 return quote(value, safe="!~*'()").replace(".", "%2E").replace("%", ".") ================================================ FILE: backend/onyx/context/search/__init__.py ================================================ ================================================ FILE: backend/onyx/context/search/enums.py ================================================ """NOTE: this needs to be separate from models.py because of circular imports. Both search/models.py and db/models.py import enums from this file AND search/models.py imports from db/models.py.""" from enum import Enum class RecencyBiasSetting(str, Enum): FAVOR_RECENT = "favor_recent" # 2x decay rate BASE_DECAY = "base_decay" NO_DECAY = "no_decay" # Determine based on query if to use base_decay or favor_recent AUTO = "auto" class QueryType(str, Enum): """ The type of first-pass query to use for hybrid search. The values of this enum are injected into the ranking profile name which should match the name in the schema. 
""" KEYWORD = "keyword" SEMANTIC = "semantic" class SearchType(str, Enum): KEYWORD = "keyword" SEMANTIC = "semantic" INTERNET = "internet" ================================================ FILE: backend/onyx/context/search/federated/models.py ================================================ from datetime import datetime from typing import TypedDict from pydantic import BaseModel from onyx.onyxbot.slack.models import ChannelType class ChannelMetadata(TypedDict): """Type definition for cached channel metadata.""" name: str type: ChannelType is_private: bool is_member: bool class SlackMessage(BaseModel): document_id: str channel_id: str message_id: str thread_id: str | None link: str metadata: dict[str, str | list[str]] timestamp: datetime recency_bias: float semantic_identifier: str text: str highlighted_texts: set[str] slack_score: float ================================================ FILE: backend/onyx/context/search/federated/slack_search.py ================================================ import json import re import time from datetime import datetime from typing import Any from typing import cast from pydantic import BaseModel from pydantic import ConfigDict from pydantic import ValidationError from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from sqlalchemy.orm import Session from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG from onyx.configs.app_configs import MAX_SLACK_THREAD_CONTEXT_MESSAGES from onyx.configs.app_configs import SLACK_THREAD_CONTEXT_BATCH_SIZE from onyx.configs.chat_configs import DOC_TIME_DECAY from onyx.connectors.models import IndexingDocument from onyx.connectors.models import TextSection from onyx.context.search.federated.models import ChannelMetadata from onyx.context.search.federated.models import SlackMessage from onyx.context.search.federated.slack_search_utils import ALL_CHANNEL_TYPES from onyx.context.search.federated.slack_search_utils import build_channel_query_filter from 
onyx.context.search.federated.slack_search_utils import build_slack_queries from onyx.context.search.federated.slack_search_utils import get_channel_type from onyx.context.search.federated.slack_search_utils import ( get_channel_type_for_missing_scope, ) from onyx.context.search.federated.slack_search_utils import is_recency_query from onyx.context.search.federated.slack_search_utils import should_include_message from onyx.context.search.models import ChunkIndexRequest from onyx.context.search.models import InferenceChunk from onyx.db.document import DocumentSource from onyx.db.models import SearchSettings from onyx.db.search_settings import get_current_search_settings from onyx.document_index.document_index_utils import ( get_multipass_config, ) from onyx.federated_connectors.slack.models import SlackEntities from onyx.indexing.chunker import Chunker from onyx.indexing.embedder import DefaultIndexingEmbedder from onyx.indexing.models import DocAwareChunk from onyx.llm.factory import get_default_llm from onyx.onyxbot.slack.models import ChannelType from onyx.onyxbot.slack.models import SlackContext from onyx.redis.redis_pool import get_redis_client from onyx.server.federated.models import FederatedConnectorDetail from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.timing import log_function_time from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE logger = setup_logger() HIGHLIGHT_START_CHAR = "\ue000" HIGHLIGHT_END_CHAR = "\ue001" CHANNEL_METADATA_CACHE_TTL = 60 * 60 * 24 # 24 hours USER_PROFILE_CACHE_TTL = 60 * 60 * 24 # 24 hours SLACK_THREAD_CONTEXT_WINDOW = 3 # Number of messages before matched message to include CHANNEL_METADATA_MAX_RETRIES = 3 # Maximum retry attempts for channel metadata fetching CHANNEL_METADATA_RETRY_DELAY = 1 # Initial retry delay in seconds (exponential backoff) def fetch_and_cache_channel_metadata( access_token: str, team_id: str, 
include_private: bool = True ) -> dict[str, ChannelMetadata]: """ Fetch ALL channel metadata via (paginated) conversations.list calls and cache it. Returns a dict mapping channel_id -> metadata including name, type, etc. This replaces per-channel conversations.info calls with a single paginated conversations.list query. Note: We ALWAYS fetch all channel types (including private) and cache them together. This ensures a single cache entry per team, avoiding duplicate API calls. """ # Use tenant-specific Redis client redis_client = get_redis_client() # (tenant_id prefix is added automatically by TenantRedis) cache_key = f"slack_federated_search:{team_id}:channels:metadata" try: cached = redis_client.get(cache_key) if cached: logger.debug(f"Channel metadata cache HIT for team {team_id}") cached_str: str = ( cached.decode("utf-8") if isinstance(cached, bytes) else str(cached) ) cached_data = cast(dict[str, ChannelMetadata], json.loads(cached_str)) logger.debug(f"Loaded {len(cached_data)} channels from cache") if not include_private: filtered: dict[str, ChannelMetadata] = { k: v for k, v in cached_data.items() if v.get("type") != ChannelType.PRIVATE_CHANNEL.value } logger.debug(f"Filtered to {len(filtered)} channels (exclude private)") return filtered return cached_data except Exception as e: logger.warning(f"Error reading from channel metadata cache: {e}") # Cache miss - fetch from Slack API with retry logic logger.debug(f"Channel metadata cache MISS for team {team_id} - fetching from API") slack_client = WebClient(token=access_token) channel_metadata: dict[str, ChannelMetadata] = {} # Retry logic with exponential backoff last_exception = None available_channel_types = ALL_CHANNEL_TYPES.copy() for attempt in range(CHANNEL_METADATA_MAX_RETRIES): try: # Use available channel types (may be reduced if scopes are missing) channel_types = ",".join(available_channel_types) # Fetch all channels, following pagination cursors cursor = None channel_count = 0 while True: response = slack_client.conversations_list( types=channel_types,
exclude_archived=True, limit=1000, cursor=cursor, ) response.validate() # Cast response.data to dict for type checking response_data: dict[str, Any] = response.data # type: ignore for ch in response_data.get("channels", []): channel_id = ch.get("id") if not channel_id: continue # Determine channel type channel_type_enum = get_channel_type(channel_info=ch) channel_type = ChannelType(channel_type_enum.value) channel_metadata[channel_id] = { "name": ch.get("name", ""), "type": channel_type, "is_private": ch.get("is_private", False), "is_member": ch.get("is_member", False), } channel_count += 1 cursor = response_data.get("response_metadata", {}).get("next_cursor") if not cursor: break logger.info(f"Fetched {channel_count} channels for team {team_id}") # Cache the results try: redis_client.set( cache_key, json.dumps(channel_metadata), ex=CHANNEL_METADATA_CACHE_TTL, ) logger.info( f"Cached {channel_count} channels for team {team_id} (TTL: {CHANNEL_METADATA_CACHE_TTL}s, key: {cache_key})" ) except Exception as e: logger.warning(f"Error caching channel metadata: {e}") return channel_metadata except SlackApiError as e: last_exception = e # Extract all needed fields from response upfront if e.response: error_response = e.response.get("error", "") needed_scope = e.response.get("needed", "") else: error_response = "" needed_scope = "" # Check if this is a missing_scope error if error_response == "missing_scope": # Get the channel type that requires this scope missing_channel_type = get_channel_type_for_missing_scope(needed_scope) if ( missing_channel_type and missing_channel_type in available_channel_types ): # Remove the problematic channel type and retry available_channel_types.remove(missing_channel_type) logger.warning( f"Missing scope '{needed_scope}' for channel type '{missing_channel_type}'. 
" f"Continuing with reduced channel types: {available_channel_types}" ) # Retry immediately with fewer types (no backoff sleep, although the loop still advances to the next attempt) if available_channel_types: # Only continue if we have types left continue # Otherwise fall through to retry logic else: logger.error( f"Missing scope '{needed_scope}' but could not map to channel type or already removed. " f"Response: {e.response}" ) # For other errors, use retry logic if attempt < CHANNEL_METADATA_MAX_RETRIES - 1: retry_delay = CHANNEL_METADATA_RETRY_DELAY * (2**attempt) logger.warning( f"Failed to fetch channel metadata (attempt {attempt + 1}/{CHANNEL_METADATA_MAX_RETRIES}): {e}. " f"Retrying in {retry_delay}s..." ) time.sleep(retry_delay) else: logger.error( f"Failed to fetch channel metadata after {CHANNEL_METADATA_MAX_RETRIES} attempts: {e}" ) # If we have some channel metadata despite errors, return it with a warning if channel_metadata: logger.warning( f"Returning partial channel metadata ({len(channel_metadata)} channels) despite errors. Last error: {last_exception}" ) return channel_metadata # If we exhausted all retries and have no data, raise the last exception if last_exception: raise SlackApiError( f"Channel metadata fetching failed after {CHANNEL_METADATA_MAX_RETRIES} attempts", last_exception.response, ) return {} def get_available_channels( access_token: str, team_id: str, include_private: bool = False ) -> list[str]: """Fetch list of available channel names using cached metadata.""" metadata = fetch_and_cache_channel_metadata(access_token, team_id, include_private) return [meta["name"] for meta in metadata.values() if meta["name"]] def get_cached_user_profile( access_token: str, team_id: str, user_id: str ) -> str | None: """ Get a user's display name from cache or fetch from Slack API. Uses Redis caching to avoid repeated API calls and rate limiting. Returns the user's real_name or email, or None if not found.
""" redis_client = get_redis_client() cache_key = f"slack_federated_search:{team_id}:user:{user_id}" # Check cache first try: cached = redis_client.get(cache_key) if cached is not None: cached_str = ( cached.decode("utf-8") if isinstance(cached, bytes) else str(cached) ) # Empty string means user was not found previously return cached_str if cached_str else None except Exception as e: logger.debug(f"Error reading user profile cache: {e}") # Cache miss - fetch from Slack API slack_client = WebClient(token=access_token) try: response = slack_client.users_profile_get(user=user_id) response.validate() profile: dict[str, Any] = response.get("profile", {}) name: str | None = profile.get("real_name") or profile.get("email") # Cache the result (empty string for not found) try: redis_client.set( cache_key, name or "", ex=USER_PROFILE_CACHE_TTL, ) except Exception as e: logger.debug(f"Error caching user profile: {e}") return name except SlackApiError as e: error_str = str(e) if "user_not_found" in error_str: logger.debug( f"User {user_id} not found in Slack workspace (likely deleted/deactivated)" ) elif "ratelimited" in error_str: # Don't cache rate limit errors - we'll retry later logger.debug(f"Rate limited fetching user {user_id}, will retry later") return None else: logger.warning(f"Could not fetch profile for user {user_id}: {e}") # Cache negative result to avoid repeated lookups for missing users try: redis_client.set(cache_key, "", ex=USER_PROFILE_CACHE_TTL) except Exception: pass return None def batch_get_user_profiles( access_token: str, team_id: str, user_ids: set[str] ) -> dict[str, str]: """ Batch fetch user profiles with caching. Returns a dict mapping user_id -> display_name for users that were found. 
""" result: dict[str, str] = {} for user_id in user_ids: name = get_cached_user_profile(access_token, team_id, user_id) if name: result[user_id] = name return result def _extract_channel_data_from_entities( entities: dict[str, Any] | None, channel_metadata_dict: dict[str, ChannelMetadata] | None, ) -> list[str] | None: """Extract available channels list from metadata based on entity configuration. Args: entities: Entity filter configuration dict channel_metadata_dict: Pre-fetched channel metadata dictionary Returns: List of available channel names, or None if not needed """ if not entities or not channel_metadata_dict: return None try: parsed_entities = SlackEntities(**entities) # Only extract if we have exclusions or channel filters if parsed_entities.exclude_channels or parsed_entities.channels: # Extract channel names from metadata dict return [ meta["name"] for meta in channel_metadata_dict.values() if meta["name"] and ( parsed_entities.include_private_channels or meta.get("type") != ChannelType.PRIVATE_CHANNEL.value ) ] except ValidationError: logger.debug("Failed to parse entities for channel data extraction") return None def _should_skip_channel( channel_id: str, allowed_private_channel: str | None, bot_token: str | None, access_token: str, include_dm: bool, channel_metadata_dict: dict[str, ChannelMetadata] | None = None, ) -> bool: """Bot context filtering: skip private channels unless explicitly allowed. Uses pre-fetched channel metadata when available to avoid API calls. 
""" if bot_token and not include_dm: try: # First try to use pre-fetched metadata from cache if channel_metadata_dict and channel_id in channel_metadata_dict: channel_meta = channel_metadata_dict[channel_id] channel_type_str = channel_meta.get("type", "") is_private_or_dm = channel_type_str in [ ChannelType.PRIVATE_CHANNEL.value, ChannelType.IM.value, ChannelType.MPIM.value, ] if is_private_or_dm and channel_id != allowed_private_channel: return True return False # Fallback: API call only if not in cache (should be rare) token_to_use = bot_token or access_token channel_client = WebClient(token=token_to_use) channel_info = channel_client.conversations_info(channel=channel_id) if isinstance(channel_info.data, dict): channel_data = channel_info.data.get("channel", {}) channel_type = get_channel_type(channel_info=channel_data) is_private_or_dm = channel_type in [ ChannelType.PRIVATE_CHANNEL, ChannelType.IM, ChannelType.MPIM, ] if is_private_or_dm and channel_id != allowed_private_channel: return True except Exception as e: logger.warning( f"Could not determine channel type for {channel_id}, filtering out: {e}" ) return True return False class SlackQueryResult(BaseModel): """Result from a single Slack query including stats.""" model_config = ConfigDict(arbitrary_types_allowed=True) messages: list[SlackMessage] filtered_channels: list[str] # Channels filtered out during this query def query_slack( query_string: str, access_token: str, limit: int | None = None, allowed_private_channel: str | None = None, bot_token: str | None = None, include_dm: bool = False, entities: dict[str, Any] | None = None, available_channels: list[str] | None = None, channel_metadata_dict: dict[str, ChannelMetadata] | None = None, ) -> SlackQueryResult: # Check if query has channel override (user specified channels in query) has_channel_override = query_string.startswith("__CHANNEL_OVERRIDE__") if has_channel_override: # Remove the marker and use the query as-is (already has channel filters) 
final_query = query_string.replace("__CHANNEL_OVERRIDE__", "").strip() else: # Normal flow: build channel filters from entity config channel_filter = "" if entities: channel_filter = build_channel_query_filter(entities, available_channels) final_query = query_string if channel_filter: # Add channel filter to query final_query = f"{query_string} {channel_filter}" logger.info(f"Final query to slack: {final_query}") # Detect if query asks for most recent results sort_by_time = is_recency_query(query_string) slack_client = WebClient(token=access_token) try: search_params: dict[str, Any] = { "query": final_query, "count": limit, "highlight": True, } # Sort by timestamp for recency-focused queries, otherwise by relevance if sort_by_time: search_params["sort"] = "timestamp" search_params["sort_dir"] = "desc" response = slack_client.search_messages(**search_params) response.validate() messages: dict[str, Any] = response.get("messages", {}) matches: list[dict[str, Any]] = messages.get("matches", []) logger.info(f"Slack search found {len(matches)} messages") except SlackApiError as slack_error: logger.error(f"Slack API error in search_messages: {slack_error}") logger.error( f"Slack API error details: status={slack_error.response.status_code}, error={slack_error.response.get('error')}" ) if "not_allowed_token_type" in str(slack_error): # Log token type prefix token_prefix = access_token[:4] if len(access_token) >= 4 else "unknown" logger.error(f"TOKEN TYPE ERROR: access_token type: {token_prefix}...") return SlackQueryResult(messages=[], filtered_channels=[]) # convert matches to slack messages slack_messages: list[SlackMessage] = [] filtered_channels: list[str] = [] for match in matches: text: str | None = match.get("text") permalink: str | None = match.get("permalink") message_id: str | None = match.get("ts") channel_id: str | None = match.get("channel", {}).get("id") channel_name: str | None = match.get("channel", {}).get("name") username: str | None = 
match.get("username") if not username: # Fallback: try to get from user field if username is missing user_info = match.get("user", "") if isinstance(user_info, str) and user_info: username = user_info # Use user ID as fallback else: username = "unknown_user" score: float = match.get("score", 0.0) if ( # can't use any() because of type checking :( not text or not permalink or not message_id or not channel_id or not channel_name or not username ): continue # Apply channel filtering if needed if _should_skip_channel( channel_id, allowed_private_channel, bot_token, access_token, include_dm, channel_metadata_dict, ): filtered_channels.append(f"{channel_name}({channel_id})") continue # generate thread id and document id thread_id = ( permalink.split("?thread_ts=", 1)[1] if "?thread_ts=" in permalink else None ) document_id = f"{channel_id}_{message_id}" decay_factor = DOC_TIME_DECAY doc_time = datetime.fromtimestamp(float(message_id)) doc_age_years = (datetime.now() - doc_time).total_seconds() / ( 365 * 24 * 60 * 60 ) recency_bias = max(1 / (1 + decay_factor * doc_age_years), 0.75) metadata: dict[str, str | list[str]] = { "channel": channel_name, "time": doc_time.isoformat(), } # extract out the highlighted texts highlighted_texts = set( re.findall( rf"{re.escape(HIGHLIGHT_START_CHAR)}(.*?){re.escape(HIGHLIGHT_END_CHAR)}", text, ) ) cleaned_text = text.replace(HIGHLIGHT_START_CHAR, "").replace( HIGHLIGHT_END_CHAR, "" ) # get the semantic identifier snippet = ( cleaned_text[:50].rstrip() + "..." 
if len(cleaned_text) > 50 else cleaned_text ).replace("\n", " ") doc_sem_id = f"{username} in #{channel_name}: {snippet}" slack_messages.append( SlackMessage( document_id=document_id, channel_id=channel_id, message_id=message_id, thread_id=thread_id, link=permalink, metadata=metadata, timestamp=doc_time, recency_bias=recency_bias, semantic_identifier=doc_sem_id, text=f"{username}: {cleaned_text}", highlighted_texts=highlighted_texts, slack_score=score, ) ) return SlackQueryResult( messages=slack_messages, filtered_channels=filtered_channels ) def merge_slack_messages( query_results: list[SlackQueryResult], ) -> tuple[list[SlackMessage], dict[str, SlackMessage], set[str]]: """Merge messages from multiple query results, deduplicating by document_id. Returns: Tuple of (merged_messages, docid_to_message, all_filtered_channels) """ merged_messages: list[SlackMessage] = [] docid_to_message: dict[str, SlackMessage] = {} all_filtered_channels: set[str] = set() for result in query_results: # Collect filtered channels from all queries all_filtered_channels.update(result.filtered_channels) for message in result.messages: if message.document_id in docid_to_message: # update the score and highlighted texts, rest should be identical docid_to_message[message.document_id].slack_score = max( docid_to_message[message.document_id].slack_score, message.slack_score, ) docid_to_message[message.document_id].highlighted_texts.update( message.highlighted_texts ) continue # add the message to the list docid_to_message[message.document_id] = message merged_messages.append(message) # re-sort by score merged_messages.sort(key=lambda x: x.slack_score, reverse=True) return merged_messages, docid_to_message, all_filtered_channels class SlackRateLimitError(Exception): """Raised when Slack API returns a rate limit error (429).""" class ThreadContextResult: """Result wrapper for thread context fetch that captures error type.""" __slots__ = ("text", "is_rate_limited", "is_error") def __init__( self, text:
str, is_rate_limited: bool = False, is_error: bool = False ): self.text = text self.is_rate_limited = is_rate_limited self.is_error = is_error @classmethod def success(cls, text: str) -> "ThreadContextResult": return cls(text) @classmethod def rate_limited(cls, original_text: str) -> "ThreadContextResult": return cls(original_text, is_rate_limited=True) @classmethod def error(cls, original_text: str) -> "ThreadContextResult": return cls(original_text, is_error=True) def _fetch_thread_context( message: SlackMessage, access_token: str, team_id: str | None = None ) -> ThreadContextResult: """ Fetch thread context for a message, returning a result object. Returns ThreadContextResult with: - success: enriched thread text - rate_limited: original text + flag indicating we should stop - error: original text for other failures (graceful degradation) """ channel_id = message.channel_id thread_id = message.thread_id message_id = message.message_id # If not a thread, return original text as success if thread_id is None: return ThreadContextResult.success(message.text) slack_client = WebClient(token=access_token, timeout=30) try: response = slack_client.conversations_replies( channel=channel_id, ts=thread_id, ) response.validate() messages: list[dict[str, Any]] = response.get("messages", []) except SlackApiError as e: # Check for rate limit error specifically if e.response and e.response.status_code == 429: logger.warning( f"Slack rate limit hit while fetching thread context for {channel_id}/{thread_id}" ) return ThreadContextResult.rate_limited(message.text) # For other Slack errors, log and return original text logger.error(f"Slack API error in thread context fetch: {e}") return ThreadContextResult.error(message.text) except Exception as e: # Network errors, timeouts, etc - treat as recoverable error logger.error(f"Unexpected error in thread context fetch: {e}") return ThreadContextResult.error(message.text) # If empty response or single message (not a thread), return 
original text if len(messages) <= 1: return ThreadContextResult.success(message.text) # Build thread text from thread starter + context window around matched message thread_text = _build_thread_text( messages, message_id, thread_id, access_token, team_id, slack_client ) return ThreadContextResult.success(thread_text) def _build_thread_text( messages: list[dict[str, Any]], message_id: str, thread_id: str, access_token: str, team_id: str | None, slack_client: WebClient, ) -> str: """Build the thread text from messages.""" msg_text = messages[0].get("text", "") msg_sender = messages[0].get("user", "") thread_text = f"<@{msg_sender}>: {msg_text}" thread_text += "\n\nReplies:" if thread_id == message_id: message_id_idx = 0 else: message_id_idx = next( (i for i, msg in enumerate(messages) if msg.get("ts") == message_id), 0 ) if not message_id_idx: return thread_text start_idx = max(1, message_id_idx - SLACK_THREAD_CONTEXT_WINDOW) if start_idx > 1: thread_text += "\n..." for i in range(start_idx, message_id_idx): msg_text = messages[i].get("text", "") msg_sender = messages[i].get("user", "") thread_text += f"\n\n<@{msg_sender}>: {msg_text}" msg_text = messages[message_id_idx].get("text", "") msg_sender = messages[message_id_idx].get("user", "") thread_text += f"\n\n<@{msg_sender}>: {msg_text}" # Add following replies len_replies = 0 for msg in messages[message_id_idx + 1 :]: msg_text = msg.get("text", "") msg_sender = msg.get("user", "") reply = f"\n\n<@{msg_sender}>: {msg_text}" thread_text += reply len_replies += len(reply) if len_replies >= DOC_EMBEDDING_CONTEXT_SIZE * 4: thread_text += "\n..." 
break # Replace user IDs with names using cached lookups userids: set[str] = set(re.findall(r"<@([A-Z0-9]+)>", thread_text)) if team_id: user_profiles = batch_get_user_profiles(access_token, team_id, userids) for userid, name in user_profiles.items(): thread_text = thread_text.replace(f"<@{userid}>", name) else: for userid in userids: try: response = slack_client.users_profile_get(user=userid) response.validate() profile: dict[str, Any] = response.get("profile", {}) user_name: str | None = profile.get("real_name") or profile.get("email") except SlackApiError as e: if "user_not_found" in str(e): logger.debug( f"User {userid} not found (likely deleted/deactivated)" ) else: logger.warning(f"Could not fetch profile for user {userid}: {e}") continue if not user_name: continue thread_text = thread_text.replace(f"<@{userid}>", user_name) return thread_text def fetch_thread_contexts_with_rate_limit_handling( slack_messages: list[SlackMessage], access_token: str, team_id: str | None, batch_size: int = SLACK_THREAD_CONTEXT_BATCH_SIZE, max_messages: int | None = MAX_SLACK_THREAD_CONTEXT_MESSAGES, ) -> list[str]: """ Fetch thread contexts in controlled batches, stopping on rate limit. Distinguishes between error types: - Rate limit (429): Stop processing further batches - Other errors: Continue processing (graceful degradation) Args: slack_messages: Messages to fetch thread context for (should be sorted by relevance) access_token: Slack OAuth token team_id: Slack team ID for user profile caching batch_size: Number of concurrent API calls per batch max_messages: Maximum messages to fetch thread context for (None = no limit) Returns: List of thread texts, one per input message. Messages beyond max_messages or after rate limit get their original text. 
""" if not slack_messages: return [] # Limit how many messages we fetch thread context for (if max_messages is set) if max_messages and max_messages < len(slack_messages): messages_for_context = slack_messages[:max_messages] messages_without_context = slack_messages[max_messages:] else: messages_for_context = slack_messages messages_without_context = [] logger.info( f"Fetching thread context for {len(messages_for_context)} of {len(slack_messages)} messages " f"(batch_size={batch_size}, max={max_messages or 'unlimited'})" ) results: list[str] = [] rate_limited = False total_batches = (len(messages_for_context) + batch_size - 1) // batch_size rate_limit_batch = 0 # Process in batches for i in range(0, len(messages_for_context), batch_size): current_batch = i // batch_size + 1 if rate_limited: # Skip remaining batches, use original message text remaining = messages_for_context[i:] skipped_batches = total_batches - rate_limit_batch logger.warning( f"Slack rate limit: skipping {len(remaining)} remaining messages " f"({skipped_batches} of {total_batches} batches). " f"Successfully enriched {len(results)} messages before rate limit." 
) results.extend([msg.text for msg in remaining]) break batch = messages_for_context[i : i + batch_size] # _fetch_thread_context returns ThreadContextResult (never raises) # allow_failures=True is a safety net for any unexpected exceptions batch_results: list[ThreadContextResult | None] = ( run_functions_tuples_in_parallel( [ ( _fetch_thread_context, (msg, access_token, team_id), ) for msg in batch ], allow_failures=True, max_workers=batch_size, ) ) # Process results - ThreadContextResult tells us exactly what happened for j, result in enumerate(batch_results): if result is None: # Unexpected exception (shouldn't happen) - use original text, stop logger.error(f"Unexpected None result for message {j} in batch") results.append(batch[j].text) rate_limited = True rate_limit_batch = current_batch elif result.is_rate_limited: # Rate limit hit - use original text, stop further batches results.append(result.text) rate_limited = True rate_limit_batch = current_batch else: # Success or recoverable error - use the text (enriched or original) results.append(result.text) if rate_limited: logger.warning( f"Slack rate limit (429) hit at batch {current_batch}/{total_batches} " f"while fetching thread context. Stopping further API calls." ) # Add original text for messages we didn't fetch context for results.extend([msg.text for msg in messages_without_context]) return results def convert_slack_score(slack_score: float) -> float: """ Convert slack score to a score between 0 and 1. Will affect UI ordering and LLM ordering, but not the pruning. I.e., should have very little effect on the search/answer quality. 
""" return max(0.0, min(1.0, slack_score / 90_000)) @log_function_time(print_only=True) def slack_retrieval( query: ChunkIndexRequest, access_token: str, db_session: Session | None = None, connector: FederatedConnectorDetail | None = None, # noqa: ARG001 entities: dict[str, Any] | None = None, limit: int | None = None, slack_event_context: SlackContext | None = None, bot_token: str | None = None, # Add bot token parameter team_id: str | None = None, # Pre-fetched data — when provided, avoids DB query (no session needed) search_settings: SearchSettings | None = None, ) -> list[InferenceChunk]: """ Main entry point for Slack federated search with entity filtering. Applies entity filtering including: - Channel selection and exclusion - Date range extraction and enforcement - DM/private channel filtering - Multi-layer caching Args: query: Search query object access_token: User OAuth access token db_session: Database session (optional if search_settings provided) connector: Federated connector detail (unused, kept for backwards compat) entities: Connector-level config (entity filtering configuration) limit: Maximum number of results slack_event_context: Context when called from Slack bot bot_token: Bot token for enhanced permissions team_id: Slack team/workspace ID Returns: List of InferenceChunk objects """ # Use connector-level config entities = entities or {} if not entities: logger.debug("No entity configuration found, using defaults") else: logger.debug(f"Using entity configuration: {entities}") # Extract limit from entity config if not explicitly provided query_limit = limit if entities: try: parsed_entities = SlackEntities(**entities) if limit is None: query_limit = parsed_entities.max_messages_per_query logger.debug(f"Using max_messages_per_query from config: {query_limit}") except Exception as e: logger.warning(f"Error parsing entities for limit: {e}") if limit is None: query_limit = 100 # Fallback default elif limit is None: query_limit = 100 # Default when no 
entities and no limit provided # Pre-fetch channel metadata from Redis cache and extract available channels # This avoids repeated Redis lookups during parallel search execution available_channels = None channel_metadata_dict = None if team_id: # Always fetch all channel types (include_private=True) to ensure single cache entry channel_metadata_dict = fetch_and_cache_channel_metadata( access_token, team_id, include_private=True ) # Extract available channels list if needed for pattern matching available_channels = _extract_channel_data_from_entities( entities, channel_metadata_dict ) # Query slack with entity filtering llm = get_default_llm() query_strings = build_slack_queries(query, llm, entities, available_channels) # Determine filtering based on entities OR context (bot) include_dm = False allowed_private_channel = None # Bot context overrides (if entities not specified) if slack_event_context and not entities: channel_type = slack_event_context.channel_type if channel_type == ChannelType.IM: # DM with user include_dm = True if channel_type == ChannelType.PRIVATE_CHANNEL: allowed_private_channel = slack_event_context.channel_id logger.debug( f"Private channel context: will only allow messages from {allowed_private_channel} + public channels" ) # Build search tasks search_tasks = [ ( query_slack, ( query_string, access_token, query_limit, allowed_private_channel, bot_token, include_dm, entities, available_channels, channel_metadata_dict, ), ) for query_string in query_strings ] # If include_dm is True AND we're not already searching all channels, # add additional searches without channel filters. # This allows searching DMs/group DMs while still searching the specified channels. # Skip this if search_all_channels is already True (would be duplicate queries). 
if ( entities and entities.get("include_dm") and not entities.get("search_all_channels") ): # Create a minimal entities dict that won't add channel filters # This ensures we search ALL conversations (DMs, group DMs, private channels) # BUT we still want to exclude channels specified in exclude_channels dm_entities = { "include_dm": True, "include_private_channels": entities.get("include_private_channels", False), "default_search_days": entities.get("default_search_days", 30), "search_all_channels": True, "channels": None, "exclude_channels": entities.get( "exclude_channels" ), # ALWAYS apply exclude_channels } for query_string in query_strings: search_tasks.append( ( query_slack, ( query_string, access_token, query_limit, allowed_private_channel, bot_token, include_dm, dm_entities, available_channels, channel_metadata_dict, ), ) ) # Execute searches in parallel results = run_functions_tuples_in_parallel(search_tasks) # Calculate stats for consolidated logging total_raw_messages = sum(len(r.messages) for r in results) # Merge and post-filter results slack_messages, docid_to_message, query_filtered_channels = merge_slack_messages( results ) messages_after_dedup = len(slack_messages) # Post-filter by channel type (DM, private channel, etc.) # NOTE: We must post-filter because Slack's search.messages API only supports # filtering by channel NAME (via in:#channel syntax), not by channel TYPE. # There's no way to specify "only public channels" or "exclude DMs" in the query. 
# Start with channels filtered during query execution, then add post-filter channels filtered_out_channels: set[str] = set(query_filtered_channels) if entities and team_id: # Use pre-fetched channel metadata to avoid cache misses # Pass it directly instead of relying on Redis cache filtered_messages = [] for msg in slack_messages: # Pass pre-fetched metadata to avoid cache lookups channel_type = get_channel_type( channel_id=msg.channel_id, channel_metadata=channel_metadata_dict, ) if should_include_message(channel_type, entities): filtered_messages.append(msg) else: # Track unique channel name for summary channel_name = msg.metadata.get("channel", msg.channel_id) filtered_out_channels.add(f"{channel_name}({msg.channel_id})") slack_messages = filtered_messages slack_messages = slack_messages[: limit or len(slack_messages)] # Log consolidated summary with request ID for correlation request_id = ( slack_event_context.message_ts[:10] if slack_event_context and slack_event_context.message_ts else "no-ctx" ) logger.info( f"[req:{request_id}] Slack federated search: {len(search_tasks)} queries, " f"{total_raw_messages} raw msgs -> {messages_after_dedup} after dedup -> " f"{len(slack_messages)} final" + ( f", filtered channels: {sorted(filtered_out_channels)}" if filtered_out_channels else "" ) ) if not slack_messages: return [] # Fetch thread context with rate limit handling and message limiting # Messages are already sorted by relevance (slack_score), so top N get full context thread_texts = fetch_thread_contexts_with_rate_limit_handling( slack_messages=slack_messages, access_token=access_token, team_id=team_id, ) for slack_message, thread_text in zip(slack_messages, thread_texts): slack_message.text = thread_text # get the highlighted texts from shortest to longest highlighted_texts: set[str] = set() for slack_message in slack_messages: highlighted_texts.update(slack_message.highlighted_texts) sorted_highlighted_texts = sorted(highlighted_texts, key=len) # For queries 
without highlights (e.g., empty recency queries), we should keep all chunks has_highlights = len(sorted_highlighted_texts) > 0 # convert slack messages to index documents index_docs: list[IndexingDocument] = [] for slack_message in slack_messages: section: TextSection = TextSection( text=slack_message.text, link=slack_message.link ) index_docs.append( IndexingDocument( id=slack_message.document_id, sections=[section], processed_sections=[section], source=DocumentSource.SLACK, title=slack_message.semantic_identifier, semantic_identifier=slack_message.semantic_identifier, metadata=slack_message.metadata, doc_updated_at=slack_message.timestamp, ) ) # chunk index docs into doc aware chunks # a single index doc can get split into multiple chunks if search_settings is None: if db_session is None: raise ValueError("Either db_session or search_settings must be provided") search_settings = get_current_search_settings(db_session) embedder = DefaultIndexingEmbedder.from_db_search_settings( search_settings=search_settings ) multipass_config = get_multipass_config(search_settings) enable_contextual_rag = ( search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG ) chunker = Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=multipass_config.multipass_indexing, enable_large_chunks=multipass_config.enable_large_chunks, enable_contextual_rag=enable_contextual_rag, ) chunks = chunker.chunk(index_docs) # prune chunks without any highlighted texts # BUT: for recency queries without keywords, keep all chunks relevant_chunks: list[DocAwareChunk] = [] chunkid_to_match_highlight: dict[str, str] = {} if not has_highlights: # No highlighted terms - keep all chunks (recency query) for chunk in chunks: chunk_id = f"{chunk.source_document.id}__{chunk.chunk_id}" relevant_chunks.append(chunk) chunkid_to_match_highlight[chunk_id] = chunk.content # No highlighting if limit and len(relevant_chunks) >= limit: break else: # Prune chunks that don't contain highlighted terms 
for chunk in chunks: match_highlight = chunk.content for highlight in sorted_highlighted_texts: # faster than re sub match_highlight = match_highlight.replace( highlight, f"<hi>{highlight}</hi>" ) # if nothing got replaced, the chunk is irrelevant if len(match_highlight) == len(chunk.content): continue chunk_id = f"{chunk.source_document.id}__{chunk.chunk_id}" relevant_chunks.append(chunk) chunkid_to_match_highlight[chunk_id] = match_highlight if limit and len(relevant_chunks) >= limit: break # convert to inference chunks top_chunks: list[InferenceChunk] = [] for chunk in relevant_chunks: document_id = chunk.source_document.id chunk_id = f"{document_id}__{chunk.chunk_id}" top_chunks.append( InferenceChunk( chunk_id=chunk.chunk_id, blurb=chunk.blurb, content=chunk.content, source_links=chunk.source_links, image_file_id=chunk.image_file_id, section_continuation=chunk.section_continuation, semantic_identifier=docid_to_message[document_id].semantic_identifier, document_id=document_id, source_type=DocumentSource.SLACK, title=chunk.title_prefix, boost=0, score=convert_slack_score(docid_to_message[document_id].slack_score), hidden=False, is_relevant=None, relevance_explanation="", metadata=docid_to_message[document_id].metadata, match_highlights=[chunkid_to_match_highlight[chunk_id]], doc_summary="", chunk_context="", updated_at=docid_to_message[document_id].timestamp, is_federated=True, ) ) return top_chunks ================================================ FILE: backend/onyx/context/search/federated/slack_search_utils.py ================================================ import fnmatch import json import re from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from pydantic import ValidationError from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS from onyx.context.search.federated.models import ChannelMetadata from onyx.context.search.models import ChunkIndexRequest from onyx.federated_connectors.slack.models
import SlackEntities from onyx.llm.interfaces import LLM from onyx.llm.models import UserMessage from onyx.llm.utils import llm_response_to_string from onyx.natural_language_processing.english_stopwords import ENGLISH_STOPWORDS_SET from onyx.onyxbot.slack.models import ChannelType from onyx.prompts.federated_search import SLACK_DATE_EXTRACTION_PROMPT from onyx.prompts.federated_search import SLACK_QUERY_EXPANSION_PROMPT from onyx.tracing.llm_utils import llm_generation_span from onyx.tracing.llm_utils import record_llm_response from onyx.utils.logger import setup_logger logger = setup_logger() # Constants for date extraction heuristics DEFAULT_RECENCY_DAYS = 7 DEFAULT_LATELY_DAYS = 14 DAYS_PER_WEEK = 7 DAYS_PER_MONTH = 30 MAX_CONTENT_WORDS = 3 # Punctuation to strip from words during analysis WORD_PUNCTUATION = ".,!?;:\"'#" RECENCY_KEYWORDS = ["recent", "latest", "newest", "last"] # All Slack channel types for fetching metadata ALL_CHANNEL_TYPES = [ ChannelType.PUBLIC_CHANNEL.value, ChannelType.IM.value, ChannelType.MPIM.value, ChannelType.PRIVATE_CHANNEL.value, ] # Map Slack API scopes to their corresponding channel types # This is used for graceful degradation when scopes are missing SCOPE_TO_CHANNEL_TYPE_MAP = { "mpim:read": ChannelType.MPIM.value, "mpim:history": ChannelType.MPIM.value, "im:read": ChannelType.IM.value, "im:history": ChannelType.IM.value, "groups:read": ChannelType.PRIVATE_CHANNEL.value, "groups:history": ChannelType.PRIVATE_CHANNEL.value, "channels:read": ChannelType.PUBLIC_CHANNEL.value, "channels:history": ChannelType.PUBLIC_CHANNEL.value, } def get_channel_type_for_missing_scope(scope: str) -> str | None: """Get the channel type that requires a specific Slack scope. 
Args: scope: The Slack API scope (e.g., 'mpim:read', 'im:history') Returns: The channel type string if scope is recognized, None otherwise Examples: >>> get_channel_type_for_missing_scope('mpim:read') 'mpim' >>> get_channel_type_for_missing_scope('im:read') 'im' >>> get_channel_type_for_missing_scope('unknown:scope') is None True """ return SCOPE_TO_CHANNEL_TYPE_MAP.get(scope) def _parse_llm_code_block_response(response: str) -> str: """Remove code block markers from LLM response if present. Handles responses wrapped in triple backticks (```) by removing the opening and closing markers. Args: response: Raw LLM response string Returns: Cleaned response with code block markers removed """ response_clean = response.strip() if response_clean.startswith("```"): lines = response_clean.split("\n") lines = lines[1:] if lines and lines[-1].strip() == "```": lines = lines[:-1] response_clean = "\n".join(lines) return response_clean def is_recency_query(query: str) -> bool: """Check if a query is primarily about recency (not content + recency). Returns True only for pure recency queries like "recent messages" or "latest updates", but False for queries with content + recency like "golf scores last saturday".
""" # Check if query contains recency keywords has_recency_keyword = any( re.search(rf"\b{re.escape(keyword)}\b", query, flags=re.IGNORECASE) for keyword in RECENCY_KEYWORDS ) if not has_recency_keyword: return False # Get combined stop words (English + Slack-specific) all_stop_words = _get_combined_stop_words() # Extract content words (excluding stop words) query_lower = query.lower() words = query_lower.split() # Count content words (not stop words, length > 2) content_word_count = 0 for word in words: clean_word = word.strip(WORD_PUNCTUATION) if clean_word and len(clean_word) > 2 and clean_word not in all_stop_words: content_word_count += 1 # If query has significant content words (>= 2), it's not a pure recency query # Examples: # - "recent messages" -> content_word_count = 0 -> pure recency # - "golf scores last saturday" -> content_word_count = 3 (golf, scores, saturday) -> not pure recency return content_word_count < 2 def extract_date_range_from_query( query: str, llm: LLM, default_search_days: int, ) -> int: query_lower = query.lower() if re.search(r"\btoday(?:\'?s)?\b", query_lower): return 0 if re.search(r"\byesterday\b", query_lower): return min(1, default_search_days) # Handle "last [day of week]" - e.g., "last monday", "last saturday" days_of_week = [ "monday", "tuesday", "wednesday", "thursday", "friday", "saturday", "sunday", ] for day in days_of_week: if re.search(rf"\b(?:last|this)\s+{day}\b", query_lower): # Assume last occurrence of that day was within the past week return min(DAYS_PER_WEEK, default_search_days) match = re.search(r"\b(?:last|past)\s+(\d+)\s+days?\b", query_lower) if match: days = int(match.group(1)) return min(days, default_search_days) if re.search(r"\b(?:last|past|this)\s+week\b", query_lower): return min(DAYS_PER_WEEK, default_search_days) match = re.search(r"\b(?:last|past)\s+(\d+)\s+weeks?\b", query_lower) if match: weeks = int(match.group(1)) return min(weeks * DAYS_PER_WEEK, default_search_days) if 
re.search(r"\b(?:last|past|this)\s+month\b", query_lower):
        return min(DAYS_PER_MONTH, default_search_days)

    match = re.search(r"\b(?:last|past)\s+(\d+)\s+months?\b", query_lower)
    if match:
        months = int(match.group(1))
        return min(months * DAYS_PER_MONTH, default_search_days)

    if re.search(r"\brecent(?:ly)?\b", query_lower):
        return min(DEFAULT_RECENCY_DAYS, default_search_days)

    if re.search(r"\blately\b", query_lower):
        return min(DEFAULT_LATELY_DAYS, default_search_days)

    try:
        prompt = SLACK_DATE_EXTRACTION_PROMPT.format(query=query)
        prompt_msg = UserMessage(content=prompt)

        # Call LLM with Braintrust tracing
        with llm_generation_span(
            llm=llm, flow="slack_date_extraction", input_messages=[prompt_msg]
        ) as span_generation:
            llm_response = llm.invoke(prompt_msg)
            record_llm_response(span_generation, llm_response)
            response = llm_response_to_string(llm_response)

        response_clean = _parse_llm_code_block_response(response)

        try:
            data = json.loads(response_clean)
            if not isinstance(data, dict):
                logger.debug(
                    f"LLM date extraction returned non-dict response for query: "
                    f"'{query}', using default: {default_search_days} days"
                )
                return default_search_days

            days_back = data.get("days_back")
            if days_back is None:
                logger.debug(
                    f"LLM date extraction returned null for query: '{query}', using default: {default_search_days} days"
                )
                return default_search_days

            if not isinstance(days_back, (int, float)):
                logger.debug(
                    f"LLM date extraction returned non-numeric days_back for "
                    f"query: '{query}', using default: {default_search_days} days"
                )
                return default_search_days

        except json.JSONDecodeError:
            logger.debug(
                f"Failed to parse LLM date extraction response for query: '{query}' "
                f"(response: '{response_clean}'), "
                f"using default: {default_search_days} days"
            )
            return default_search_days

        return min(int(days_back), default_search_days)

    except Exception as e:
        logger.warning(f"Error extracting date range with LLM for query '{query}': {e}")
        return default_search_days


def matches_exclude_pattern(channel_name: str, patterns: list[str]) -> bool:
    if not patterns:
        return False

    channel_norm = channel_name.lower().strip().lstrip("#")
    for pattern in patterns:
        pattern_norm = pattern.lower().strip().lstrip("#")
        if fnmatch.fnmatch(channel_norm, pattern_norm):
            return True
    return False


def build_channel_query_filter(
    parsed_entities: SlackEntities | dict[str, Any],
    available_channels: list[str] | None = None,
) -> str:
    # Parse entities if dict
    try:
        if isinstance(parsed_entities, dict):
            entities = SlackEntities(**parsed_entities)
        else:
            entities = parsed_entities
    except ValidationError:
        return ""

    search_all_channels = entities.search_all_channels

    if search_all_channels:
        if not entities.exclude_channels:
            return ""

        # Can't apply exclusions without available_channels
        if not available_channels:
            return ""

        excluded_channels = [
            ch
            for ch in available_channels
            if matches_exclude_pattern(ch, entities.exclude_channels)
        ]
        normalized_excluded = [ch.lstrip("#") for ch in excluded_channels]
        exclusion_filters = [f"-in:#{channel}" for channel in normalized_excluded]
        return " ".join(exclusion_filters)

    if not entities.channels:
        return ""

    included_channels: list[str] = []

    for pattern in entities.channels:
        pattern_norm = pattern.lstrip("#")
        if "*" in pattern_norm or "?" in pattern_norm:
            # Glob patterns require available_channels
            if available_channels:
                matching = [
                    ch
                    for ch in available_channels
                    if fnmatch.fnmatch(ch.lstrip("#").lower(), pattern_norm.lower())
                ]
                included_channels.extend(matching)
        else:
            # Exact match: use directly or verify against available_channels
            if not available_channels or pattern_norm in [
                ch.lstrip("#") for ch in available_channels
            ]:
                included_channels.append(pattern_norm)

    # Apply exclusions to included channels
    if entities.exclude_channels:
        included_channels = [
            ch
            for ch in included_channels
            if not matches_exclude_pattern(ch, entities.exclude_channels)
        ]

    if not included_channels:
        return ""

    normalized_channels = [ch.lstrip("#") for ch in included_channels]
    filters = [f"in:#{channel}" for channel in normalized_channels]
    return " ".join(filters)


def get_channel_type(
    channel_info: dict[str, Any] | None = None,
    channel_id: str | None = None,
    channel_metadata: dict[str, ChannelMetadata] | None = None,
) -> ChannelType:
    """
    Determine channel type from channel info dict or by looking up channel_id.
    Args:
        channel_info: Channel info dict from Slack API (direct mode)
        channel_id: Channel ID to look up (lookup mode)
        channel_metadata: Pre-fetched metadata dict (for lookup mode)

    Returns:
        ChannelType enum
    """
    if channel_info is not None:
        if channel_info.get("is_im"):
            return ChannelType.IM
        if channel_info.get("is_mpim"):
            return ChannelType.MPIM
        if channel_info.get("is_private"):
            return ChannelType.PRIVATE_CHANNEL
        return ChannelType.PUBLIC_CHANNEL

    # Lookup mode: get type from pre-fetched metadata
    if channel_id and channel_metadata:
        ch_meta = channel_metadata.get(channel_id)
        if ch_meta:
            type_str = ch_meta.get("type")
            if type_str == ChannelType.IM.value:
                return ChannelType.IM
            elif type_str == ChannelType.MPIM.value:
                return ChannelType.MPIM
            elif type_str == ChannelType.PRIVATE_CHANNEL.value:
                return ChannelType.PRIVATE_CHANNEL
            return ChannelType.PUBLIC_CHANNEL

    return ChannelType.PUBLIC_CHANNEL


def should_include_message(channel_type: ChannelType, entities: dict[str, Any]) -> bool:
    include_dm = entities.get("include_dm", False)
    include_group_dm = entities.get("include_group_dm", False)
    include_private = entities.get("include_private_channels", False)

    if channel_type == ChannelType.IM:
        return include_dm
    if channel_type == ChannelType.MPIM:
        return include_group_dm
    if channel_type == ChannelType.PRIVATE_CHANNEL:
        return include_private
    return True


def extract_channel_references_from_query(query_text: str) -> set[str]:
    """Extract channel names referenced in the query text.

    Only matches explicit channel references with prepositions or # symbols:
    - "in the office channel"
    - "from the office channel"
    - "in #office"
    - "from #office"

    Does NOT match generic phrases like "slack discussions" or "team channel".
    Args:
        query_text: The user's query text

    Returns:
        Set of channel names (without # prefix)
    """
    channel_references = set()
    query_lower = query_text.lower()

    # Only match channels with explicit prepositions (in/from) or # prefix
    # This prevents false positives like "slack discussions" being interpreted as channel "slack"
    channel_patterns = [
        r"\bin\s+(?:the\s+)?([a-z0-9_-]+)\s+(?:slack\s+)?channels?\b",  # "in the office channel"
        r"\bfrom\s+(?:the\s+)?([a-z0-9_-]+)\s+(?:slack\s+)?channels?\b",  # "from the office channel"
        r"\bin[:\s]*#([a-z0-9_-]+)\b",  # "in #office" or "in:#office"
        r"\bfrom[:\s]*#([a-z0-9_-]+)\b",  # "from #office" or "from:#office"
    ]

    for pattern in channel_patterns:
        matches = re.finditer(pattern, query_lower)
        for match in matches:
            channel_references.add(match.group(1))

    return channel_references


def validate_channel_references(
    channel_references: set[str],
    entities: dict[str, Any],
    available_channels: list[str] | None,
) -> None:
    """Validate that referenced channels exist and are allowed by entity config.

    Args:
        channel_references: Set of channel names extracted from query
        entities: Entity configuration dict
        available_channels: List of available channel names in workspace

    Raises:
        ValueError: If channel doesn't exist, is excluded, or not in inclusion list
    """
    if not channel_references or not entities:
        return

    try:
        parsed_entities = SlackEntities(**entities)

        for channel_name in channel_references:
            # Check if channel exists
            if available_channels is not None:
                # Normalize for comparison (available_channels may or may not have #)
                normalized_available = [
                    ch.lstrip("#").lower() for ch in available_channels
                ]
                if channel_name.lower() not in normalized_available:
                    raise ValueError(
                        f"Channel '{channel_name}' does not exist in your Slack workspace. "
                        f"Please check the channel name and try again."
                    )

            # Check if channel is in exclusion list
            if parsed_entities.exclude_channels:
                if matches_exclude_pattern(
                    channel_name, parsed_entities.exclude_channels
                ):
                    raise ValueError(
                        f"Channel '{channel_name}' is excluded from search by your configuration. "
                        f"Please update your connector settings to search this channel."
                    )

            # Check if channel is in inclusion list (when search_all_channels is False)
            if not parsed_entities.search_all_channels:
                if parsed_entities.channels:
                    # Normalize channel lists for comparison
                    normalized_channels = [
                        ch.lstrip("#").lower() for ch in parsed_entities.channels
                    ]
                    if channel_name.lower() not in normalized_channels:
                        raise ValueError(
                            f"Channel '{channel_name}' is not in your configured channel list. "
                            f"Please update your connector settings to include this channel."
                        )

    except ValidationError:
        # If entities are malformed, skip validation
        pass


def build_channel_override_query(channel_references: set[str], time_filter: str) -> str:
    """Build a Slack query with ONLY channel filters and time filter (no keywords).

    Args:
        channel_references: Set of channel names to search
        time_filter: Time filter string (e.g., " after:2025-11-07")

    Returns:
        Query string with __CHANNEL_OVERRIDE__ marker
    """
    normalized_channels = [ch.lstrip("#") for ch in channel_references]
    channel_filter = " ".join([f"in:#{channel}" for channel in normalized_channels])
    return f"__CHANNEL_OVERRIDE__ {channel_filter}{time_filter}"


# Slack-specific stop words (in addition to standard English stop words)
# These include Slack-specific terms and temporal/recency keywords
SLACK_SPECIFIC_STOP_WORDS = frozenset(
    RECENCY_KEYWORDS
    + [
        "dm",
        "dms",
        "message",
        "messages",
        "channel",
        "channels",
        "slack",
        "post",
        "posted",
        "posting",
        "sent",
    ]
)


def _get_combined_stop_words() -> frozenset[str]:
    """Get combined English + Slack-specific stop words.

    Returns a frozenset of stop words for filtering content words.

    Note: Currently only supports English stop words.
    Non-English queries may have suboptimal content word extraction.
    Future enhancement could detect query language and load appropriate stop words.
    """
    return ENGLISH_STOPWORDS_SET | SLACK_SPECIFIC_STOP_WORDS


def extract_content_words_from_recency_query(
    query_text: str, channel_references: set[str]
) -> list[str]:
    """Extract meaningful content words from a recency query.

    Filters out English stop words, Slack-specific terms, channel references, and proper nouns.

    Args:
        query_text: The user's query text
        channel_references: Channel names to exclude from content words

    Returns:
        List of content words (up to MAX_CONTENT_WORDS)
    """
    # Get combined stop words (English + Slack-specific)
    all_stop_words = _get_combined_stop_words()

    words = query_text.split()
    content_words = []

    for word in words:
        clean_word = word.lower().strip(WORD_PUNCTUATION)
        # Skip if it's a channel reference or a stop word
        if clean_word in channel_references:
            continue
        if clean_word and clean_word not in all_stop_words and len(clean_word) > 2:
            clean_word_orig = word.strip(WORD_PUNCTUATION)
            if clean_word_orig.lower() not in all_stop_words:
                content_words.append(clean_word_orig)

    # Filter out proper nouns (capitalized words)
    content_words_filtered = [word for word in content_words if not word[0].isupper()]

    return content_words_filtered[:MAX_CONTENT_WORDS]


def _is_valid_keyword_query(line: str) -> bool:
    """Check if a line looks like a valid keyword query vs explanatory text.

    Returns False for lines that appear to be LLM explanations rather than keywords.
    """
    # Reject lines that start with parentheses (explanatory notes)
    if line.startswith("("):
        return False

    # Reject lines that are too long (likely sentences, not keywords)
    # Keywords should be short - reject if > 50 chars or > 6 words
    if len(line) > 50 or len(line.split()) > 6:
        return False

    return True


def expand_query_with_llm(query_text: str, llm: LLM) -> list[str]:
    """Use LLM to expand query into multiple search variations.
    Args:
        query_text: The user's original query
        llm: LLM instance to use for expansion

    Returns:
        List of rephrased query strings (up to MAX_SLACK_QUERY_EXPANSIONS)
    """
    prompt = UserMessage(
        content=SLACK_QUERY_EXPANSION_PROMPT.format(
            query=query_text, max_queries=MAX_SLACK_QUERY_EXPANSIONS
        )
    )

    try:
        # Call LLM with Braintrust tracing
        with llm_generation_span(
            llm=llm, flow="slack_query_expansion", input_messages=[prompt]
        ) as span_generation:
            llm_response = llm.invoke(prompt)
            record_llm_response(span_generation, llm_response)
            response = llm_response_to_string(llm_response)

        response_clean = _parse_llm_code_block_response(response)

        # Split into lines and filter out empty lines
        raw_queries = [
            line.strip() for line in response_clean.split("\n") if line.strip()
        ]

        # Filter out lines that look like explanatory text rather than keywords
        rephrased_queries = [q for q in raw_queries if _is_valid_keyword_query(q)]

        # Log if we filtered out garbage
        if len(raw_queries) != len(rephrased_queries):
            filtered_out = set(raw_queries) - set(rephrased_queries)
            logger.warning(f"Filtered out non-keyword LLM responses: {filtered_out}")

        # If no queries generated, use empty query
        if not rephrased_queries:
            logger.debug("No content keywords extracted from query expansion")
            return [""]

        logger.debug(
            f"Expanded query into {len(rephrased_queries)} queries: {rephrased_queries}"
        )
        return rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]

    except Exception as e:
        logger.error(f"Error expanding query: {e}")
        return [query_text]


def build_slack_queries(
    query: ChunkIndexRequest,
    llm: LLM,
    entities: dict[str, Any] | None = None,
    available_channels: list[str] | None = None,
) -> list[str]:
    """Build Slack query strings with date filtering and query expansion."""
    default_search_days = 30
    if entities:
        try:
            parsed_entities = SlackEntities(**entities)
            default_search_days = parsed_entities.default_search_days
        except ValidationError as e:
            logger.warning(f"Invalid entities in build_slack_queries: {e}")

    days_back = extract_date_range_from_query(
        query=query.query,
        llm=llm,
        default_search_days=default_search_days,
    )

    # get time filter
    time_filter = ""
    if days_back is not None and days_back >= 0:
        if days_back == 0:
            time_filter = " on:today"
        else:
            cutoff_date = datetime.now(timezone.utc) - timedelta(days=days_back)
            time_filter = f" after:{cutoff_date.strftime('%Y-%m-%d')}"

    # ALWAYS extract channel references from the query (not just for recency queries)
    channel_references = extract_channel_references_from_query(query.query)

    # Validate channel references against available channels and entity config
    # This will raise ValueError if channels are invalid
    if channel_references and entities:
        try:
            validate_channel_references(
                channel_references, entities, available_channels
            )
            logger.info(
                f"Detected and validated channel references: {channel_references}"
            )

            # If valid channels detected, use ONLY those channels with NO keywords
            # Return query with ONLY time filter + channel filter (no keywords)
            return [build_channel_override_query(channel_references, time_filter)]
        except ValueError as e:
            # If validation fails, log the error and continue with normal flow
            logger.warning(f"Channel reference validation failed: {e}")
            channel_references = set()

    # use llm to generate slack queries (use original query to use same keywords as the user)
    if is_recency_query(query.query):
        # For recency queries, extract content words (excluding channel names and stop words)
        content_words = extract_content_words_from_recency_query(
            query.query, channel_references
        )
        rephrased_queries = [" ".join(content_words)] if content_words else [""]
    else:
        # For other queries, use LLM to expand into multiple variations
        rephrased_queries = expand_query_with_llm(query.query, llm)

    # Build final query strings with time filters
    return [
        rephrased_query.strip() + time_filter
        for rephrased_query in rephrased_queries[:MAX_SLACK_QUERY_EXPANSIONS]
    ]


================================================
FILE: backend/onyx/context/search/models.py
================================================
from collections.abc import Sequence
from datetime import datetime
from enum import Enum
from typing import Any

from pydantic import BaseModel
from pydantic import Field

from onyx.configs.constants import DocumentSource
from onyx.db.models import SearchSettings
from onyx.indexing.models import BaseChunk
from onyx.indexing.models import IndexingSetting
from onyx.tools.tool_implementations.web_search.models import WEB_SEARCH_PREFIX


class QueryExpansions(BaseModel):
    keywords_expansions: list[str] | None = None
    semantic_expansions: list[str] | None = None


class QueryExpansionType(Enum):
    KEYWORD = "keyword"
    SEMANTIC = "semantic"


class SearchSettingsCreationRequest(IndexingSetting):
    @classmethod
    def from_db_model(
        cls, search_settings: SearchSettings
    ) -> "SearchSettingsCreationRequest":
        indexing_setting = IndexingSetting.from_db_model(search_settings)
        return cls(**indexing_setting.model_dump())


class SavedSearchSettings(IndexingSetting):
    # Previously this also contained inference time settings. This wrapper class is kept
    # around because inference time settings may be added again.

    @classmethod
    def from_db_model(cls, search_settings: SearchSettings) -> "SavedSearchSettings":
        return cls(
            # Indexing Setting
            model_name=search_settings.model_name,
            model_dim=search_settings.model_dim,
            normalize=search_settings.normalize,
            query_prefix=search_settings.query_prefix,
            passage_prefix=search_settings.passage_prefix,
            provider_type=search_settings.provider_type,
            index_name=search_settings.index_name,
            multipass_indexing=search_settings.multipass_indexing,
            embedding_precision=search_settings.embedding_precision,
            reduced_dimension=search_settings.reduced_dimension,
            switchover_type=search_settings.switchover_type,
            enable_contextual_rag=search_settings.enable_contextual_rag,
            contextual_rag_llm_name=search_settings.contextual_rag_llm_name,
            contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider,
        )


class Tag(BaseModel):
    tag_key: str
    tag_value: str


class BaseFilters(BaseModel):
    source_type: list[DocumentSource] | None = None
    document_set: list[str] | None = None
    time_cutoff: datetime | None = None
    tags: list[Tag] | None = None


class UserFileFilters(BaseModel):
    # Scopes search to user files tagged with a given project/persona in Vespa.
    # These are NOT simply the IDs of the current project or persona — they are
    # only set when the persona's/project's user files overflowed the LLM
    # context window and must be searched via vector DB instead of being loaded
    # directly into the prompt.
    project_id_filter: int | None = None
    persona_id_filter: int | None = None


class AssistantKnowledgeFilters(BaseModel):
    """Filters for knowledge attached to an assistant (persona).

    These filters scope search to documents/folders explicitly attached
    to the assistant. When present, only documents matching these criteria
    are searched (in addition to ACL filtering).
    """

    # Document IDs explicitly attached to the assistant
    attached_document_ids: list[str] | None = None
    # Hierarchy node IDs (folders/spaces) attached to the assistant.
    # Matches chunks where ancestor_hierarchy_node_ids contains any of these.
    hierarchy_node_ids: list[int] | None = None


class IndexFilters(BaseFilters, UserFileFilters, AssistantKnowledgeFilters):
    # NOTE: These strings must be formatted in the same way as the output of
    # DocumentAccess::to_acl.
    access_control_list: list[str] | None
    tenant_id: str | None = None


class BasicChunkRequest(BaseModel):
    query: str

    # In case the caller wants to override the weighting between semantic and keyword search.
    hybrid_alpha: float | None = None

    # In case some queries favor recency more than other queries.
    recency_bias_multiplier: float = 1.0

    limit: int | None = None


class ChunkSearchRequest(BasicChunkRequest):
    # Final filters are calculated from these
    user_selected_filters: BaseFilters | None = None

    # Use with caution!
    bypass_acl: bool = False

    # From the Chat Session we know what project (if any) this search should include
    # From the user uploads and persona uploaded files, we know which of those to include


class ChunkIndexRequest(BasicChunkRequest):
    # Calculated final filters
    filters: IndexFilters

    query_keywords: list[str] | None = None


class ContextExpansionType(str, Enum):
    NOT_RELEVANT = "not_relevant"
    MAIN_SECTION_ONLY = "main_section_only"
    INCLUDE_ADJACENT_SECTIONS = "include_adjacent_sections"
    FULL_DOCUMENT = "full_document"


class InferenceChunk(BaseChunk):
    document_id: str
    source_type: DocumentSource
    semantic_identifier: str
    title: str | None  # Separate from Semantic Identifier though often same
    boost: int
    score: float | None
    hidden: bool
    is_relevant: bool | None = None
    relevance_explanation: str | None = None
    # TODO(andrei): Ideally we could improve this to where each value is just a
    # list of strings.
    metadata: dict[str, str | list[str]]
    # Matched sections in the chunk. Uses Vespa syntax e.g. <hi>TEXT</hi>
    # to specify that a set of words should be highlighted. For example:
    # ["<hi>the</hi> <hi>answer</hi> is 42", "he couldn't find an <hi>answer</hi>"]
    match_highlights: list[str]
    doc_summary: str
    chunk_context: str

    # when the doc was last updated
    updated_at: datetime | None
    primary_owners: list[str] | None = None
    secondary_owners: list[str] | None = None
    large_chunk_reference_ids: list[int] = Field(default_factory=list)

    is_federated: bool = False

    @property
    def unique_id(self) -> str:
        return f"{self.document_id}__{self.chunk_id}"

    def __repr__(self) -> str:
        blurb_words = self.blurb.split()
        short_blurb = ""
        for word in blurb_words:
            if not short_blurb:
                short_blurb = word
                continue
            if len(short_blurb) > 25:
                break
            short_blurb += " " + word
        return f"Inference Chunk: {self.document_id} - {short_blurb}..."

    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, InferenceChunk):
            return False
        return (self.document_id, self.chunk_id) == (other.document_id, other.chunk_id)

    def __hash__(self) -> int:
        return hash((self.document_id, self.chunk_id))

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, InferenceChunk):
            return NotImplemented
        if self.score is None:
            if other.score is None:
                return self.chunk_id > other.chunk_id
            return True
        if other.score is None:
            return False
        if self.score == other.score:
            return self.chunk_id > other.chunk_id
        return self.score < other.score

    def __gt__(self, other: Any) -> bool:
        if not isinstance(other, InferenceChunk):
            return NotImplemented
        if self.score is None:
            return False
        if other.score is None:
            return True
        if self.score == other.score:
            return self.chunk_id < other.chunk_id
        return self.score > other.score


class InferenceChunkUncleaned(InferenceChunk):
    metadata_suffix: str | None

    def to_inference_chunk(self) -> InferenceChunk:
        # Create a dict of all fields except 'metadata_suffix'
        # Assumes the cleaning has already been applied and just needs to translate to the right type
        inference_chunk_data = {
            k: v
            for k, v in self.model_dump().items()
            # May be other fields to throw out in the future
            if k not in ["metadata_suffix"]
        }
        return InferenceChunk(**inference_chunk_data)


class InferenceSection(BaseModel):
    """A list of chunks together with their combined content. A section could be a single
    chunk, several chunks from the same document or the entire document."""

    center_chunk: InferenceChunk
    chunks: list[InferenceChunk]
    combined_content: str


class SearchDoc(BaseModel):
    document_id: str
    chunk_ind: int
    semantic_identifier: str
    link: str | None = None
    blurb: str
    source_type: DocumentSource
    boost: int
    # Whether the document is hidden when doing a standard search
    # since a standard search will never find a hidden doc, this can only ever
    # be `True` when doing an admin search
    hidden: bool
    metadata: dict[str, str | list[str]]
    score: float | None = None
    is_relevant: bool | None = None
    relevance_explanation: str | None = None
    # Matched sections in the doc. Uses Vespa syntax e.g. <hi>TEXT</hi>
    # to specify that a set of words should be highlighted. For example:
    # ["<hi>the</hi> <hi>answer</hi> is 42", "<hi>the</hi> answer is 42"]
    match_highlights: list[str]
    # when the doc was last updated
    updated_at: datetime | None = None
    primary_owners: list[str] | None = None
    secondary_owners: list[str] | None = None
    is_internet: bool = False

    @classmethod
    def from_chunks_or_sections(
        cls,
        items: "Sequence[InferenceChunk | InferenceSection] | None",
    ) -> list["SearchDoc"]:
        """Convert a sequence of InferenceChunk or InferenceSection objects to SearchDoc objects."""
        if not items:
            return []

        search_docs = [
            cls(
                document_id=(
                    chunk := (
                        item.center_chunk
                        if isinstance(item, InferenceSection)
                        else item
                    )
                ).document_id,
                chunk_ind=chunk.chunk_id,
                semantic_identifier=chunk.semantic_identifier or "Unknown",
                link=chunk.source_links[0] if chunk.source_links else None,
                blurb=chunk.blurb,
                source_type=chunk.source_type,
                boost=chunk.boost,
                hidden=chunk.hidden,
                metadata=chunk.metadata,
                score=chunk.score,
                match_highlights=chunk.match_highlights,
                updated_at=chunk.updated_at,
                primary_owners=chunk.primary_owners,
                secondary_owners=chunk.secondary_owners,
                is_internet=False,
            )
            for item in items
        ]

        return search_docs

    # TODO - there is likely a way to clean this all up and not have the switch between these
    @classmethod
    def from_saved_search_doc(cls, saved_search_doc: "SavedSearchDoc") -> "SearchDoc":
        """Convert a SavedSearchDoc to SearchDoc by dropping the db_doc_id field."""
        saved_search_doc_data = saved_search_doc.model_dump()
        # Remove db_doc_id as it's not part of SearchDoc
        saved_search_doc_data.pop("db_doc_id", None)
        return cls(**saved_search_doc_data)

    @classmethod
    def from_saved_search_docs(
        cls, saved_search_docs: list["SavedSearchDoc"]
    ) -> list["SearchDoc"]:
        return [
            cls.from_saved_search_doc(saved_search_doc)
            for saved_search_doc in saved_search_docs
        ]

    def model_dump(self, *args: list, **kwargs: dict[str, Any]) -> dict[str, Any]:  # type: ignore
        initial_dict = super().model_dump(*args, **kwargs)  # type: ignore
        initial_dict["updated_at"] = (
            self.updated_at.isoformat() if self.updated_at else None
        )
        return initial_dict


class SearchDocsResponse(BaseModel):
    search_docs: list[SearchDoc]
    # Maps the citation number to the document id
    # Since these are no longer just links on the frontend but instead document cards, mapping it to the
    # document id is the most straightforward way.
    citation_mapping: dict[int, str]

    # For cases where the frontend only needs to display a subset of the search docs
    # The whole list is typically still needed for later steps but this set should be saved separately
    displayed_docs: list[SearchDoc] | None = None


class SavedSearchDoc(SearchDoc):
    db_doc_id: int
    score: float | None = 0.0

    @classmethod
    def from_search_doc(
        cls, search_doc: SearchDoc, db_doc_id: int = 0
    ) -> "SavedSearchDoc":
        """IMPORTANT: careful using this and not providing a db_doc_id. If db_doc_id is not
        provided, it won't be able to actually fetch the saved doc and info later on.
        So only skip providing this if the SavedSearchDoc will not be used in the future."""
        search_doc_data = search_doc.model_dump()
        search_doc_data["score"] = search_doc_data.get("score") or 0.0
        return cls(**search_doc_data, db_doc_id=db_doc_id)

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> "SavedSearchDoc":
        """Create SavedSearchDoc from serialized dictionary data (e.g., from database JSON)"""
        return cls(**data)

    @classmethod
    def from_url(cls, url: str) -> "SavedSearchDoc":
        """Create a SavedSearchDoc from a URL for internet search documents.

        Uses the INTERNET_SEARCH_DOC_ prefix for document_id to match the format used by
        inference sections created from internet content.
        """
        return cls(
            # db_doc_id can be a filler value since these docs are not saved to the database.
            db_doc_id=0,
            document_id=WEB_SEARCH_PREFIX + url,
            chunk_ind=0,
            semantic_identifier=url,
            link=url,
            blurb="",
            source_type=DocumentSource.WEB,
            boost=1,
            hidden=False,
            metadata={},
            score=0.0,
            is_relevant=None,
            relevance_explanation=None,
            match_highlights=[],
            updated_at=None,
            primary_owners=None,
            secondary_owners=None,
            is_internet=True,
        )

    def __lt__(self, other: Any) -> bool:
        if not isinstance(other, SavedSearchDoc):
            return NotImplemented
        self_score = self.score if self.score is not None else 0.0
        other_score = other.score if other.score is not None else 0.0
        return self_score < other_score


class SavedSearchDocWithContent(SavedSearchDoc):
    """Used for endpoints that need to return the actual contents of the retrieved
    section in addition to the match_highlights."""

    content: str


class PersonaSearchInfo(BaseModel):
    """Snapshot of persona data needed by the search pipeline.

    Extracted from the ORM Persona before the DB session is released so that
    SearchTool and search_pipeline never lazy-load relationships post-commit.
    """

    document_set_names: list[str]
    search_start_date: datetime | None
    attached_document_ids: list[str]
    hierarchy_node_ids: list[int]


================================================
FILE: backend/onyx/context/search/pipeline.py
================================================
from collections import defaultdict
from datetime import datetime

from sqlalchemy.orm import Session

from onyx.context.search.models import BaseFilters
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.preprocessing.access_filters import (
    build_access_filters_for_user,
)
from onyx.context.search.retrieval.search_runner import search_chunks
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()


@log_function_time(print_only=True)
def _build_index_filters(
    user_provided_filters: BaseFilters | None,
    user: User,  # Used for ACLs, anonymous users only see public docs
    project_id_filter: int | None,
    persona_id_filter: int | None,
    persona_document_sets: list[str] | None,
    persona_time_cutoff: datetime | None,
    db_session: Session | None = None,
    auto_detect_filters: bool = False,
    query: str | None = None,
    llm: LLM | None = None,
    bypass_acl: bool = False,
    # Assistant knowledge filters
    attached_document_ids: list[str] | None = None,
    hierarchy_node_ids: list[int] | None = None,
    # Pre-fetched ACL filters (skips DB query when provided)
    acl_filters: list[str] | None = None,
) -> IndexFilters:
    if auto_detect_filters and (llm is None or query is None):
        raise RuntimeError("LLM and query are required for auto detect filters")

    base_filters = user_provided_filters or BaseFilters()

    document_set_filter = (
        base_filters.document_set
        if base_filters.document_set is not None
        else persona_document_sets
    )

    time_filter = base_filters.time_cutoff or persona_time_cutoff
    source_filter = base_filters.source_type

    detected_time_filter = None
    detected_source_filter = None
    if auto_detect_filters:
        time_filter_fnc = FunctionCall(extract_time_filter, (query, llm), {})
        if not source_filter:
            source_filter_fnc = FunctionCall(
                extract_source_filter, (query, llm, db_session), {}
            )
        else:
            source_filter_fnc = None

        functions_to_run = [fn for fn in [time_filter_fnc, source_filter_fnc] if fn]
        parallel_results = run_functions_in_parallel(functions_to_run)
        # Detected favor recent is not used for now
        detected_time_filter, _detected_favor_recent = parallel_results[
            time_filter_fnc.result_id
        ]
        if source_filter_fnc:
            detected_source_filter = parallel_results[source_filter_fnc.result_id]

    # If the detected time filter is more recent, use that one
    if time_filter and detected_time_filter and detected_time_filter > time_filter:
        time_filter = detected_time_filter

    # A source filter set explicitly by the user wins; otherwise fall back to the detected one
    if not source_filter and detected_source_filter:
        source_filter = detected_source_filter

    if bypass_acl:
        user_acl_filters = None
    elif acl_filters is not None:
        user_acl_filters = acl_filters
    else:
        if db_session is None:
            raise ValueError("Either db_session or acl_filters must be provided")
        user_acl_filters = build_access_filters_for_user(user, db_session)

    final_filters = IndexFilters(
        project_id_filter=project_id_filter,
        persona_id_filter=persona_id_filter,
        source_type=source_filter,
        document_set=document_set_filter,
        time_cutoff=time_filter,
        tags=base_filters.tags,
        access_control_list=user_acl_filters,
        tenant_id=get_current_tenant_id() if MULTI_TENANT else None,
        # Assistant knowledge filters
        attached_document_ids=attached_document_ids,
        hierarchy_node_ids=hierarchy_node_ids,
    )

    return final_filters


def merge_individual_chunks(
    chunks: list[InferenceChunk],
) -> list[InferenceSection]:
    """Merge adjacent chunks from the same document into sections.

    Chunks are considered adjacent if their chunk_ids differ by 1 and they
    are from the same document. The section maintains the position of the
    first chunk in the original list.
    """
    if not chunks:
        return []

    # Create a mapping from (document_id, chunk_id) to original index
    # This helps us find the chunk that appears first in the original list
    chunk_to_original_index: dict[tuple[str, int], int] = {}
    for idx, chunk in enumerate(chunks):
        chunk_to_original_index[(chunk.document_id, chunk.chunk_id)] = idx

    # Group chunks by document_id
    doc_chunks: dict[str, list[InferenceChunk]] = defaultdict(list)
    for chunk in chunks:
        doc_chunks[chunk.document_id].append(chunk)

    # For each document, sort chunks by chunk_id to identify adjacent chunks
    for doc_id in doc_chunks:
        doc_chunks[doc_id].sort(key=lambda c: c.chunk_id)

    # Create a mapping from (document_id, chunk_id) to the section it belongs to
    # This helps us maintain the original order
    chunk_to_section: dict[tuple[str, int], InferenceSection] = {}

    # Process each document's chunks
    for doc_id, doc_chunk_list in doc_chunks.items():
        if not doc_chunk_list:
            continue

        # Group adjacent chunks into sections
        current_section_chunks = [doc_chunk_list[0]]

        for i in range(1, len(doc_chunk_list)):
            prev_chunk = doc_chunk_list[i - 1]
            curr_chunk = doc_chunk_list[i]

            # Check if chunks are adjacent (chunk_id difference is 1)
            if curr_chunk.chunk_id == prev_chunk.chunk_id + 1:
                # Add to current section
                current_section_chunks.append(curr_chunk)
            else:
                # Create section from previous chunks
                # Find the chunk that appears first in the original list
                center_chunk = min(
                    current_section_chunks,
                    key=lambda c: chunk_to_original_index.get(
                        (c.document_id, c.chunk_id), float("inf")
                    ),
                )
                section = inference_section_from_chunks(
                    center_chunk=center_chunk,
                    chunks=current_section_chunks.copy(),
                )
                if section:
                    for chunk in current_section_chunks:
                        chunk_to_section[(chunk.document_id, chunk.chunk_id)] = section

                # Start new section
                current_section_chunks = [curr_chunk]

        # Create section for the last group
        if current_section_chunks:
            # Find the chunk that appears first in the original list
            center_chunk = min(
                current_section_chunks,
                key=lambda c: chunk_to_original_index.get(
                    (c.document_id, c.chunk_id), float("inf")
                ),
            )
            section = inference_section_from_chunks(
                center_chunk=center_chunk,
                chunks=current_section_chunks.copy(),
            )
            if section:
                for chunk in current_section_chunks:
                    chunk_to_section[(chunk.document_id, chunk.chunk_id)] = section

    # Build result list maintaining original order
    # Use (document_id, chunk_id) of center_chunk as unique identifier for sections
    seen_section_ids: set[tuple[str, int]] = set()
    result: list[InferenceSection] = []

    for chunk in chunks:
        section = chunk_to_section.get((chunk.document_id, chunk.chunk_id))
        if section:
            section_id = (
                section.center_chunk.document_id,
                section.center_chunk.chunk_id,
            )
            if section_id not in seen_section_ids:
                seen_section_ids.add(section_id)
                result.append(section)
        else:
            # Chunk wasn't part of any merged section, create a single-chunk section
            single_section = inference_section_from_chunks(
                center_chunk=chunk,
                chunks=[chunk],
            )
            if single_section:
                single_section_id = (
                    single_section.center_chunk.document_id,
                    single_section.center_chunk.chunk_id,
                )
                if single_section_id not in seen_section_ids:
                    seen_section_ids.add(single_section_id)
                    result.append(single_section)

    return result


@log_function_time(print_only=True, debug_only=True)
def search_pipeline(
    # Query and settings
    chunk_search_request: ChunkSearchRequest,
    # Document index to search over
    # Note that federated sources will also be used (not related to this arg)
    document_index: DocumentIndex,
    # Used for ACLs and federated search, anonymous users only see public docs
    user: User,
    # Pre-extracted persona search configuration (None when no persona)
    persona_search_info: PersonaSearchInfo | None,
    db_session: Session | None = None,
    auto_detect_filters: bool = False,
    llm: LLM | None = None,
    # Vespa metadata filters for overflowing user files. NOT the raw IDs
    # of the current project/persona — only set when user files couldn't fit
    # in the LLM context and need to be searched via vector DB.
project_id_filter: int | None = None, persona_id_filter: int | None = None, # Pre-fetched data — when provided, avoids DB queries (no session needed) acl_filters: list[str] | None = None, embedding_model: EmbeddingModel | None = None, prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None, ) -> list[InferenceChunk]: persona_document_sets: list[str] | None = ( persona_search_info.document_set_names if persona_search_info else None ) persona_time_cutoff: datetime | None = ( persona_search_info.search_start_date if persona_search_info else None ) attached_document_ids: list[str] | None = ( persona_search_info.attached_document_ids or None if persona_search_info else None ) hierarchy_node_ids: list[int] | None = ( persona_search_info.hierarchy_node_ids or None if persona_search_info else None ) filters = _build_index_filters( user_provided_filters=chunk_search_request.user_selected_filters, user=user, project_id_filter=project_id_filter, persona_id_filter=persona_id_filter, persona_document_sets=persona_document_sets, persona_time_cutoff=persona_time_cutoff, db_session=db_session, auto_detect_filters=auto_detect_filters, query=chunk_search_request.query, llm=llm, bypass_acl=chunk_search_request.bypass_acl, attached_document_ids=attached_document_ids, hierarchy_node_ids=hierarchy_node_ids, acl_filters=acl_filters, ) query_keywords = strip_stopwords(chunk_search_request.query) query_request = ChunkIndexRequest( query=chunk_search_request.query, hybrid_alpha=chunk_search_request.hybrid_alpha, recency_bias_multiplier=chunk_search_request.recency_bias_multiplier, query_keywords=query_keywords, filters=filters, limit=chunk_search_request.limit, ) retrieved_chunks = search_chunks( query_request=query_request, user_id=user.id if user else None, document_index=document_index, db_session=db_session, embedding_model=embedding_model, prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos, ) # For some specific connectors like 
Salesforce, a user that has access to an object doesn't mean # that they have access to all of the fields of the object. censored_chunks: list[InferenceChunk] = fetch_ee_implementation_or_noop( "onyx.external_permissions.post_query_censoring", "_post_query_chunk_censoring", retrieved_chunks, )( chunks=retrieved_chunks, user=user, ) return censored_chunks ================================================ FILE: backend/onyx/context/search/preprocessing/access_filters.py ================================================ from sqlalchemy.orm import Session from onyx.access.access import get_acl_for_user from onyx.context.search.models import IndexFilters from onyx.db.models import User def build_access_filters_for_user(user: User, session: Session) -> list[str]: user_acl = get_acl_for_user(user, session) return list(user_acl) def build_user_only_filters(user: User, db_session: Session) -> IndexFilters: user_acl_filters = build_access_filters_for_user(user, db_session) return IndexFilters( source_type=None, document_set=None, time_cutoff=None, tags=None, access_control_list=user_acl_filters, ) ================================================ FILE: backend/onyx/context/search/retrieval/search_runner.py ================================================ from collections.abc import Callable from uuid import UUID from sqlalchemy.orm import Session from onyx.configs.chat_configs import HYBRID_ALPHA from onyx.configs.chat_configs import NUM_RETURNED_HITS from onyx.context.search.models import ChunkIndexRequest from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunk from onyx.context.search.models import InferenceSection from onyx.context.search.models import QueryExpansionType from onyx.context.search.utils import get_query_embedding from onyx.context.search.utils import inference_section_from_chunks from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.interfaces import VespaChunkRequest from 
onyx.document_index.interfaces_new import DocumentIndex as NewDocumentIndex from onyx.document_index.opensearch.opensearch_document_index import ( OpenSearchOldDocumentIndex, ) from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo from onyx.federated_connectors.federated_retrieval import ( get_federated_retrieval_functions, ) from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel logger = setup_logger() def combine_retrieval_results( chunk_sets: list[list[InferenceChunk]], ) -> list[InferenceChunk]: all_chunks = [chunk for chunk_set in chunk_sets for chunk in chunk_set] unique_chunks: dict[tuple[str, int], InferenceChunk] = {} for chunk in all_chunks: key = (chunk.document_id, chunk.chunk_id) if key not in unique_chunks: unique_chunks[key] = chunk continue stored_chunk_score = unique_chunks[key].score or 0 this_chunk_score = chunk.score or 0 if stored_chunk_score < this_chunk_score: unique_chunks[key] = chunk sorted_chunks = sorted( unique_chunks.values(), key=lambda x: x.score or 0, reverse=True ) return sorted_chunks def _embed_and_hybrid_search( query_request: ChunkIndexRequest, document_index: DocumentIndex, db_session: Session | None = None, embedding_model: EmbeddingModel | None = None, ) -> list[InferenceChunk]: query_embedding = get_query_embedding( query_request.query, db_session=db_session, embedding_model=embedding_model, ) hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA top_chunks = document_index.hybrid_retrieval( query=query_request.query, query_embedding=query_embedding, final_keywords=query_request.query_keywords, filters=query_request.filters, hybrid_alpha=hybrid_alpha, time_decay_multiplier=query_request.recency_bias_multiplier, num_to_retrieve=query_request.limit or NUM_RETURNED_HITS, ranking_profile_type=( QueryExpansionType.KEYWORD if hybrid_alpha <= 0.3 else 
QueryExpansionType.SEMANTIC ), ) return top_chunks def _keyword_search( query_request: ChunkIndexRequest, document_index: NewDocumentIndex, ) -> list[InferenceChunk]: return document_index.keyword_retrieval( query=query_request.query, filters=query_request.filters, num_to_retrieve=query_request.limit or NUM_RETURNED_HITS, ) def search_chunks( query_request: ChunkIndexRequest, user_id: UUID | None, document_index: DocumentIndex, db_session: Session | None = None, embedding_model: EmbeddingModel | None = None, prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None, ) -> list[InferenceChunk]: run_queries: list[tuple[Callable, tuple]] = [] source_filters = ( set(query_request.filters.source_type) if query_request.filters.source_type else None ) # Federated retrieval — use pre-fetched if available, otherwise query DB if prefetched_federated_retrieval_infos is not None: federated_retrieval_infos = prefetched_federated_retrieval_infos else: if db_session is None: raise ValueError( "Either db_session or prefetched_federated_retrieval_infos must be provided" ) federated_retrieval_infos = get_federated_retrieval_functions( db_session=db_session, user_id=user_id, source_types=list(source_filters) if source_filters else None, document_set_names=query_request.filters.document_set, ) federated_sources = set( federated_retrieval_info.source.to_non_federated_source() for federated_retrieval_info in federated_retrieval_infos ) for federated_retrieval_info in federated_retrieval_infos: run_queries.append( (federated_retrieval_info.retrieval_function, (query_request,)) ) # Don't run normal hybrid search if there are no indexed sources to # search over normal_search_enabled = (source_filters is None) or ( len(set(source_filters) - federated_sources) > 0 ) if normal_search_enabled: if ( query_request.hybrid_alpha is not None and query_request.hybrid_alpha == 0.0 and isinstance(document_index, OpenSearchOldDocumentIndex) ): # If hybrid alpha is explicitly set 
to keyword only, do pure keyword # search without generating an embedding. This is currently only # supported with OpenSearchDocumentIndex. opensearch_new_document_index: NewDocumentIndex = document_index._real_index run_queries.append( ( lambda: _keyword_search( query_request, opensearch_new_document_index ), (), ) ) else: run_queries.append( ( _embed_and_hybrid_search, (query_request, document_index, db_session, embedding_model), ) ) parallel_search_results = run_functions_tuples_in_parallel(run_queries) top_chunks = combine_retrieval_results(parallel_search_results) if not top_chunks: logger.debug( f"Search returned no results for query: {query_request.query} with filters: {query_request.filters}." ) return top_chunks # TODO: This is unused code. def inference_sections_from_ids( doc_identifiers: list[tuple[str, int]], document_index: DocumentIndex, ) -> list[InferenceSection]: # Currently only fetches whole docs doc_ids_set = set(doc_id for doc_id, _ in doc_identifiers) chunk_requests: list[VespaChunkRequest] = [ VespaChunkRequest(document_id=doc_id) for doc_id in doc_ids_set ] # No need for ACL here because the doc ids were validated beforehand filters = IndexFilters(access_control_list=None) retrieved_chunks = document_index.id_based_retrieval( chunk_requests=chunk_requests, filters=filters, ) if not retrieved_chunks: return [] # Group chunks by document ID chunks_by_doc_id: dict[str, list[InferenceChunk]] = {} for chunk in retrieved_chunks: chunks_by_doc_id.setdefault(chunk.document_id, []).append(chunk) inference_sections = [ section for chunks in chunks_by_doc_id.values() if chunks and ( section := inference_section_from_chunks( # The scores will always be 0 because the fetching by id gives back # no search scores. This is not needed though if the user is explicitly # selecting a document. 
center_chunk=chunks[0], chunks=chunks, ) ) ] return inference_sections ================================================ FILE: backend/onyx/context/search/utils.py ================================================ from typing import TypeVar from sqlalchemy.orm import Session from onyx.context.search.models import InferenceChunk from onyx.context.search.models import InferenceSection from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SavedSearchDocWithContent from onyx.context.search.models import SearchDoc from onyx.db.search_settings import get_current_search_settings from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.utils.logger import setup_logger from onyx.utils.timing import log_function_time from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.enums import EmbedTextType from shared_configs.model_server_models import Embedding logger = setup_logger() T = TypeVar( "T", InferenceSection, InferenceChunk, SearchDoc, SavedSearchDoc, SavedSearchDocWithContent, ) TSection = TypeVar( "TSection", InferenceSection, SearchDoc, SavedSearchDoc, SavedSearchDocWithContent, ) def inference_section_from_chunks( center_chunk: InferenceChunk, chunks: list[InferenceChunk], ) -> InferenceSection | None: if not chunks: return None combined_content = "\n".join([chunk.content for chunk in chunks]) return InferenceSection( center_chunk=center_chunk, chunks=chunks, combined_content=combined_content, ) # If it should be a real section, don't use this one def inference_section_from_single_chunk( chunk: InferenceChunk, ) -> InferenceSection: return InferenceSection( center_chunk=chunk, chunks=[chunk], combined_content=chunk.content, ) def get_query_embeddings( queries: list[str], db_session: Session | None = None, embedding_model: EmbeddingModel | None = None, ) -> list[Embedding]: if embedding_model is None: if db_session is None: raise 
ValueError("Either db_session or embedding_model must be provided") search_settings = get_current_search_settings(db_session) embedding_model = EmbeddingModel.from_db_model( search_settings=search_settings, server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, ) query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY) return query_embedding @log_function_time(print_only=True, debug_only=True) def get_query_embedding( query: str, db_session: Session | None = None, embedding_model: EmbeddingModel | None = None, ) -> Embedding: return get_query_embeddings( [query], db_session=db_session, embedding_model=embedding_model )[0] def convert_inference_sections_to_search_docs( inference_sections: list[InferenceSection], is_internet: bool = False, ) -> list[SearchDoc]: search_docs = SearchDoc.from_chunks_or_sections(inference_sections) for search_doc in search_docs: search_doc.is_internet = is_internet return search_docs ================================================ FILE: backend/onyx/db/README.md ================================================ An explanation of how the history of messages, tool calls, and docs is stored in the database: Messages are grouped by chat session, and a tree structure is used to allow edits and to let the user switch between branches. Each ChatMessage is either a user message or an assistant message, and the two should always alternate. System messages, custom agent prompt injections, and reminder messages are injected dynamically after the chat session is loaded into memory. The user and assistant messages are stored in pairs, though it is OK if the user message is stored and the assistant message fails. The user chat message is relatively simple and includes the user prompt and any attached documents. The assistant message includes the response, tool calls, feedback, citations, etc.
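As a rough illustration of the branching tree, the Python sketch below stores each message with only a parent pointer and recovers one branch by walking from a leaf back to the root. `Role`, `Message`, `branch_path`, and `alternates` are hypothetical names chosen for this example, not the actual `ChatMessage` model or Onyx helpers, and the single root node is an assumption of the sketch.

```python
from __future__ import annotations

from dataclasses import dataclass
from enum import Enum


class Role(Enum):
    # Hypothetical stand-in for the message types described above
    ROOT = "root"
    USER = "user"
    ASSISTANT = "assistant"


@dataclass
class Message:
    # Hypothetical minimal node, not the real ChatMessage model:
    # each message only records the id of its parent.
    id: int
    parent_id: int | None
    role: Role
    text: str = ""


def branch_path(messages: dict[int, Message], leaf_id: int) -> list[Message]:
    """Follow parent pointers from a leaf up to the root, then reverse
    so the selected branch reads oldest-first."""
    path: list[Message] = []
    current: Message | None = messages[leaf_id]
    while current is not None:
        path.append(current)
        parent_id = current.parent_id
        current = messages[parent_id] if parent_id is not None else None
    path.reverse()
    return path


def alternates(path: list[Message]) -> bool:
    """True if messages below the root go user, assistant, user, ..."""
    turns = [m for m in path if m.role is not Role.ROOT]
    return all(
        m.role is (Role.USER if i % 2 == 0 else Role.ASSISTANT)
        for i, m in enumerate(turns)
    )


# Two sibling branches under one root: message 3 is an edit of message 1.
EXAMPLE = {
    0: Message(0, None, Role.ROOT),
    1: Message(1, 0, Role.USER, "first"),
    2: Message(2, 1, Role.ASSISTANT, "reply"),
    3: Message(3, 0, Role.USER, "first (edited)"),
    4: Message(4, 3, Role.ASSISTANT, "reply to edit"),
}

if __name__ == "__main__":
    print([m.id for m in branch_path(EXAMPLE, 4)])  # [0, 3, 4]
    print(alternates(branch_path(EXAMPLE, 2)))  # True
```

In this sketch, switching branches amounts to picking a different leaf: selecting message 2 instead of message 4 yields the original first message and its reply, while the edited branch stays stored alongside it.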
Things provided as input are part of the user message; things that happen during the inference and LLM loop are part of the assistant message. Reasoning is stored with the message or tool call that occurred after the reasoning. Conceptually, the reasoning belongs to the previous message or tool call, because storing it with the following one is somewhat unintuitive when the conversation branches afterwards as a result of that reasoning. But to avoid storing reasoning as part of the user message, it is instead included with the following message or tool call. With parallel tool calls, the reasoning is included with each of the tool calls. Tool calls are stored in the ToolCall table and can represent all of the following: - Parallel tool calls: these have the same turn number and parent tool call id - Sequential tool calls: these have a different turn number and parent tool call id - Tool calls attached to the ChatMessage are top-level tool calls directly triggered by the LLM - Tool calls attached to other ToolCalls happen as part of an agent that has been called. The top-level tool call is the agent call, and the tool calls that have the agent call as a parent are the ones that happen as part of the agent. The different branches are generated by sending a new search query to an existing parent.
``` [Empty Root Message] (This allows the first message to be branched/edited as well) / | \ [First Message] [First Message Edit 1] [First Message Edit 2] | | [Second Message] [Second Message of Edit 1 Branch] ``` ================================================ FILE: backend/onyx/db/__init__.py ================================================ ================================================ FILE: backend/onyx/db/_deprecated/pg_file_store.py ================================================ """Kept around since it's used in the migration to move to S3/MinIO""" import tempfile from io import BytesIO from typing import IO from psycopg2.extensions import connection from sqlalchemy import text # NEW: for SQL large-object helpers from sqlalchemy.orm import Session from onyx.file_store.constants import MAX_IN_MEMORY_SIZE from onyx.file_store.constants import STANDARD_CHUNK_SIZE from onyx.utils.logger import setup_logger logger = setup_logger() def get_pg_conn_from_session(db_session: Session) -> connection: return db_session.connection().connection.connection # type: ignore def create_populate_lobj( content: IO, db_session: Session, ) -> int: """Create a PostgreSQL large object from *content* and return its OID. Preferred approach is to use the psycopg2 ``lobject`` API, but if that is unavailable (e.g. when the underlying connection is an asyncpg adapter) we fall back to PostgreSQL helper functions such as ``lo_from_bytea``. NOTE: this function intentionally *does not* commit the surrounding transaction – that is handled by the caller so all work stays atomic. 
""" pg_conn = None try: pg_conn = get_pg_conn_from_session(db_session) # ``AsyncAdapt_asyncpg_connection`` (asyncpg) has no ``lobject`` if not hasattr(pg_conn, "lobject"): raise AttributeError # will be handled by fallback below large_object = pg_conn.lobject() # write in multiple chunks to avoid loading the whole file into memory while True: chunk = content.read(STANDARD_CHUNK_SIZE) if not chunk: break large_object.write(chunk) large_object.close() return large_object.oid except AttributeError: # Fall back to SQL helper functions – read the full content into memory # (acceptable for the limited number and size of files handled during # migrations). ``lo_from_bytea`` returns the new OID. byte_data = content.read() result = db_session.execute( text("SELECT lo_from_bytea(0, :data) AS oid"), {"data": byte_data}, ) # ``scalar_one`` is 2.0-style; ``scalar`` works on both 1.4/2.0. lobj_oid = result.scalar() if lobj_oid is None: raise RuntimeError("Failed to create large object") return int(lobj_oid) def read_lobj( lobj_oid: int, db_session: Session, mode: str | None = None, use_tempfile: bool = False, ) -> IO: """Read a PostgreSQL large object identified by *lobj_oid*. Attempts to use the native ``lobject`` API first; if unavailable falls back to ``lo_get`` which returns the large object's contents as *bytea*. 
""" pg_conn = None try: pg_conn = get_pg_conn_from_session(db_session) if not hasattr(pg_conn, "lobject"): raise AttributeError # Ensure binary mode by default if mode is None: mode = "rb" large_object = ( pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid) ) if use_tempfile: temp_file = tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) while True: chunk = large_object.read(STANDARD_CHUNK_SIZE) if not chunk: break temp_file.write(chunk) temp_file.seek(0) return temp_file else: return BytesIO(large_object.read()) except AttributeError: # Fallback path using ``lo_get`` result = db_session.execute( text("SELECT lo_get(:oid) AS data"), {"oid": lobj_oid}, ) byte_data = result.scalar() if byte_data is None: raise RuntimeError("Failed to read large object") if use_tempfile: temp_file = tempfile.SpooledTemporaryFile(max_size=MAX_IN_MEMORY_SIZE) temp_file.write(byte_data) temp_file.seek(0) return temp_file return BytesIO(byte_data) def delete_lobj_by_id( lobj_oid: int, db_session: Session, ) -> None: """Remove a large object by OID, regardless of driver implementation.""" try: pg_conn = get_pg_conn_from_session(db_session) if hasattr(pg_conn, "lobject"): pg_conn.lobject(lobj_oid).unlink() return raise AttributeError except AttributeError: # Fallback for drivers without ``lobject`` support db_session.execute(text("SELECT lo_unlink(:oid)"), {"oid": lobj_oid}) # No explicit result expected ================================================ FILE: backend/onyx/db/api_key.py ================================================ import uuid from fastapi_users.password import PasswordHelper from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from onyx.auth.api_key import ApiKeyDescriptor from onyx.auth.api_key import build_displayable_api_key from onyx.auth.api_key import generate_api_key from onyx.auth.api_key import 
hash_api_key from onyx.auth.schemas import UserRole from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN from onyx.configs.constants import DANSWER_API_KEY_PREFIX from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER from onyx.db.enums import AccountType from onyx.db.models import ApiKey from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup from onyx.db.permissions import recompute_user_permissions__no_commit from onyx.db.users import assign_user_to_default_groups__no_commit from onyx.server.api_key.models import APIKeyArgs from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() def get_api_key_email_pattern() -> str: return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN def is_api_key_email_address(email: str) -> bool: return email.endswith(get_api_key_email_pattern()) def fetch_api_keys(db_session: Session) -> list[ApiKeyDescriptor]: api_keys = ( db_session.scalars(select(ApiKey).options(joinedload(ApiKey.user))) .unique() .all() ) return [ ApiKeyDescriptor( api_key_id=api_key.id, api_key_role=api_key.user.role, api_key_display=api_key.api_key_display, api_key_name=api_key.name, user_id=api_key.user_id, ) for api_key in api_keys ] async def fetch_user_for_api_key( hashed_api_key: str, async_db_session: AsyncSession ) -> User | None: """NOTE: this is async, since it's used during auth (which is necessarily async due to FastAPI Users)""" return await async_db_session.scalar( select(User) .join(ApiKey, ApiKey.user_id == User.id) .where(ApiKey.hashed_api_key == hashed_api_key) ) def get_api_key_fake_email( name: str, unique_id: str, ) -> str: return f"{DANSWER_API_KEY_PREFIX}{name}@{unique_id}{DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN}" def insert_api_key( db_session: Session, api_key_args: APIKeyArgs, user_id: uuid.UUID | None ) -> ApiKeyDescriptor: std_password_helper = PasswordHelper() # Get tenant_id from context var (will be 
default schema for single tenant) tenant_id = get_current_tenant_id() api_key = generate_api_key(tenant_id) api_key_user_id = uuid.uuid4() display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER api_key_user_row = User( id=api_key_user_id, email=get_api_key_fake_email(display_name, str(api_key_user_id)), # a random password for the "user" hashed_password=std_password_helper.hash(std_password_helper.generate()), is_active=True, is_superuser=False, is_verified=True, role=api_key_args.role, account_type=AccountType.SERVICE_ACCOUNT, ) db_session.add(api_key_user_row) api_key_row = ApiKey( name=api_key_args.name, hashed_api_key=hash_api_key(api_key), api_key_display=build_displayable_api_key(api_key), user_id=api_key_user_id, owner_id=user_id, ) db_session.add(api_key_row) # Assign the API key virtual user to the appropriate default group # before commit so everything is atomic. # LIMITED role service accounts should have no group membership. if api_key_args.role != UserRole.LIMITED: assign_user_to_default_groups__no_commit( db_session, api_key_user_row, is_admin=(api_key_args.role == UserRole.ADMIN), ) db_session.commit() return ApiKeyDescriptor( api_key_id=api_key_row.id, api_key_role=api_key_user_row.role, api_key_display=api_key_row.api_key_display, api_key=api_key, api_key_name=api_key_args.name, user_id=api_key_user_id, ) def update_api_key( db_session: Session, api_key_id: int, api_key_args: APIKeyArgs ) -> ApiKeyDescriptor: existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id)) if existing_api_key is None: raise ValueError(f"API key with id {api_key_id} does not exist") existing_api_key.name = api_key_args.name api_key_user = db_session.scalar( select(User).where(User.id == existing_api_key.user_id) # type: ignore ) if api_key_user is None: raise RuntimeError("API Key does not have associated user.") email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER api_key_user.email = get_api_key_fake_email(email_name, 
str(api_key_user.id)) old_role = api_key_user.role api_key_user.role = api_key_args.role # Reconcile default-group membership when the role changes. if old_role != api_key_args.role: # Remove from all default groups first. delete_stmt = delete(User__UserGroup).where( User__UserGroup.user_id == api_key_user.id, User__UserGroup.user_group_id.in_( select(UserGroup.id).where(UserGroup.is_default.is_(True)) ), ) db_session.execute(delete_stmt) # Re-assign to the correct default group (skip for LIMITED). if api_key_args.role != UserRole.LIMITED: assign_user_to_default_groups__no_commit( db_session, api_key_user, is_admin=(api_key_args.role == UserRole.ADMIN), ) else: # No group assigned for LIMITED, but we still need to recompute # since we just removed the old default-group membership above. recompute_user_permissions__no_commit(api_key_user.id, db_session) db_session.commit() return ApiKeyDescriptor( api_key_id=existing_api_key.id, api_key_display=existing_api_key.api_key_display, api_key_name=api_key_args.name, api_key_role=api_key_user.role, user_id=existing_api_key.user_id, ) def regenerate_api_key(db_session: Session, api_key_id: int) -> ApiKeyDescriptor: """NOTE: currently, any admin can regenerate any API key.""" existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id)) if existing_api_key is None: raise ValueError(f"API key with id {api_key_id} does not exist") api_key_user = db_session.scalar( select(User).where(User.id == existing_api_key.user_id) # type: ignore ) if api_key_user is None: raise RuntimeError("API Key does not have associated user.") # Get tenant_id from context var (will be default schema for single tenant) tenant_id = get_current_tenant_id() new_api_key = generate_api_key(tenant_id) existing_api_key.hashed_api_key = hash_api_key(new_api_key) existing_api_key.api_key_display = build_displayable_api_key(new_api_key) db_session.commit() return ApiKeyDescriptor( api_key_id=existing_api_key.id, 
api_key_display=existing_api_key.api_key_display, api_key=new_api_key, api_key_name=existing_api_key.name, api_key_role=api_key_user.role, user_id=existing_api_key.user_id, ) def remove_api_key(db_session: Session, api_key_id: int) -> None: existing_api_key = db_session.scalar(select(ApiKey).where(ApiKey.id == api_key_id)) if existing_api_key is None: raise ValueError(f"API key with id {api_key_id} does not exist") user_associated_with_key = db_session.scalar( select(User).where(User.id == existing_api_key.user_id) # type: ignore ) if user_associated_with_key is None: raise ValueError( f"User associated with API key with id {api_key_id} does not exist. This should not happen." ) db_session.delete(existing_api_key) db_session.delete(user_associated_with_key) db_session.commit() ================================================ FILE: backend/onyx/db/auth.py ================================================ from collections.abc import AsyncGenerator from collections.abc import Callable from typing import Any from typing import Dict from typing import TypeVar from fastapi import Depends from fastapi_users.models import ID from fastapi_users.models import UP from fastapi_users_db_sqlalchemy import SQLAlchemyUserDatabase from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyAccessTokenDatabase from sqlalchemy import func from sqlalchemy import Select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.future import select from sqlalchemy.orm import Session from onyx.auth.schemas import UserRole from onyx.configs.constants import ANONYMOUS_USER_EMAIL from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL from onyx.db.api_key import get_api_key_email_pattern from onyx.db.engine.async_sql_engine import get_async_session from onyx.db.engine.async_sql_engine import get_async_session_context_manager from onyx.db.models import AccessToken from onyx.db.models import OAuthAccount from onyx.db.models import User from onyx.utils.variable_functionality 
import ( fetch_versioned_implementation_with_fallback, ) T = TypeVar("T", bound=tuple[Any, ...]) def get_default_admin_user_emails() -> list[str]: """Returns a list of emails that should default to the Admin role. Only used in the EE version; the MIT version returns an empty list.""" get_default_admin_user_emails_fn: Callable[[], list[str]] = ( fetch_versioned_implementation_with_fallback( "onyx.auth.users", "get_default_admin_user_emails_", lambda: list[str]() ) ) return get_default_admin_user_emails_fn() def _add_live_user_count_where_clause( select_stmt: Select[T], only_admin_users: bool, ) -> Select[T]: """ Adds WHERE clauses to select_stmt that filter out users who should not be included in the live user count. Excludes: - API key users (by email pattern) - System users (anonymous user, no-auth placeholder) - External permission users (when only_admin_users is True, all non-admin users are excluded instead) """ select_stmt = select_stmt.where(~User.email.endswith(get_api_key_email_pattern())) # type: ignore # Exclude system users (anonymous user, no-auth placeholder) select_stmt = select_stmt.where(User.email != ANONYMOUS_USER_EMAIL) # type: ignore select_stmt = select_stmt.where(User.email != NO_AUTH_PLACEHOLDER_USER_EMAIL) # type: ignore if only_admin_users: return select_stmt.where(User.role == UserRole.ADMIN) return select_stmt.where( User.role != UserRole.EXT_PERM_USER, ) def get_live_users_count(db_session: Session) -> int: """ Returns the number of users in the system. This does NOT include invited users, "users" pulled in from external connectors, or API keys.
""" count_stmt = func.count(User.id) select_stmt = select(count_stmt) select_stmt_w_filters = _add_live_user_count_where_clause(select_stmt, False) user_count = db_session.scalar(select_stmt_w_filters) if user_count is None: raise RuntimeError("Was not able to fetch the user count.") return user_count async def get_user_count(only_admin_users: bool = False) -> int: async with get_async_session_context_manager() as session: count_stmt = func.count(User.id) stmt = select(count_stmt) stmt_w_filters = _add_live_user_count_where_clause(stmt, only_admin_users) user_count = await session.scalar(stmt_w_filters) if user_count is None: raise RuntimeError("Was not able to fetch the user count.") return user_count # Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]): async def create( self, create_dict: Dict[str, Any], ) -> UP: user_count = await get_user_count() if user_count == 0 or create_dict["email"] in get_default_admin_user_emails(): create_dict["role"] = UserRole.ADMIN else: create_dict["role"] = UserRole.BASIC return await super().create(create_dict) async def get_user_db( session: AsyncSession = Depends(get_async_session), ) -> AsyncGenerator[SQLAlchemyUserAdminDB, None]: yield SQLAlchemyUserAdminDB(session, User, OAuthAccount) async def get_access_token_db( session: AsyncSession = Depends(get_async_session), ) -> AsyncGenerator[SQLAlchemyAccessTokenDatabase, None]: yield SQLAlchemyAccessTokenDatabase(session, AccessToken) ================================================ FILE: backend/onyx/db/background_error.py ================================================ from sqlalchemy.orm import Session from onyx.db.models import BackgroundError def create_background_error( db_session: Session, message: str, cc_pair_id: int | None ) -> None: db_session.add(BackgroundError(message=message, cc_pair_id=cc_pair_id)) db_session.commit() 
================================================ FILE: backend/onyx/db/chat.py ================================================ from collections.abc import Sequence from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Tuple from uuid import UUID from fastapi import HTTPException from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import func from sqlalchemy import nullsfirst from sqlalchemy import or_ from sqlalchemy import Row from sqlalchemy import select from sqlalchemy import update from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.orm import joinedload from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.configs.chat_configs import HARD_DELETE_CHATS from onyx.configs.constants import MessageType from onyx.context.search.models import InferenceSection from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SearchDoc as ServerSearchDoc from onyx.db.models import ChatMessage from onyx.db.models import ChatMessage__SearchDoc from onyx.db.models import ChatSession from onyx.db.models import ChatSessionSharedStatus from onyx.db.models import Persona from onyx.db.models import SearchDoc as DBSearchDoc from onyx.db.models import ToolCall from onyx.db.models import User from onyx.db.persona import get_best_persona_id_for_user from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import FileDescriptor from onyx.llm.override_models import LLMOverride from onyx.llm.override_models import PromptOverride from onyx.server.query_and_chat.models import ChatMessageDetail from onyx.utils.logger import setup_logger from onyx.utils.postgres_sanitization import sanitize_string logger = setup_logger() # Note: search/streaming packet helpers moved to streaming_utils.py def get_chat_session_by_id( chat_session_id: UUID, user_id: UUID | None, db_session: Session, include_deleted: bool = 
False, is_shared: bool = False, eager_load_persona: bool = False, ) -> ChatSession: stmt = select(ChatSession).where(ChatSession.id == chat_session_id) if eager_load_persona: stmt = stmt.options( joinedload(ChatSession.persona).options( selectinload(Persona.tools), selectinload(Persona.user_files), selectinload(Persona.document_sets), selectinload(Persona.attached_documents), selectinload(Persona.hierarchy_nodes), ), joinedload(ChatSession.project), ) if is_shared: stmt = stmt.where(ChatSession.shared_status == ChatSessionSharedStatus.PUBLIC) else: # if user_id is None, assume this is an admin who should be able # to view all chat sessions if user_id is not None: stmt = stmt.where( or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None)) ) result = db_session.execute(stmt) chat_session = result.scalar_one_or_none() if not chat_session: raise ValueError("Invalid Chat Session ID provided") if not include_deleted and chat_session.deleted: raise ValueError("Chat session has been deleted") return chat_session def get_chat_sessions_by_slack_thread_id( slack_thread_id: str, user_id: UUID | None, db_session: Session, ) -> Sequence[ChatSession]: stmt = select(ChatSession).where(ChatSession.slack_thread_id == slack_thread_id) if user_id is not None: stmt = stmt.where( or_(ChatSession.user_id == user_id, ChatSession.user_id.is_(None)) ) return db_session.scalars(stmt).all() # Retrieves chat sessions for a user # Onyxbot flows are excluded unless include_onyxbot_flows is True def get_chat_sessions_by_user( user_id: UUID | None, deleted: bool | None, db_session: Session, include_onyxbot_flows: bool = False, limit: int = 50, before: datetime | None = None, project_id: int | None = None, only_non_project_chats: bool = False, include_failed_chats: bool = False, ) -> list[ChatSession]: stmt = select(ChatSession).where(ChatSession.user_id == user_id) if not include_onyxbot_flows: stmt = stmt.where(ChatSession.onyxbot_flow.is_(False)) stmt = stmt.order_by(desc(ChatSession.time_updated)) if deleted
is not None: stmt = stmt.where(ChatSession.deleted == deleted) if before is not None: stmt = stmt.where(ChatSession.time_updated < before) if project_id is not None: stmt = stmt.where(ChatSession.project_id == project_id) elif only_non_project_chats: stmt = stmt.where(ChatSession.project_id.is_(None)) # When filtering out failed chats, we apply the limit in Python after # filtering rather than in SQL, since the post-filter may remove rows. if limit and include_failed_chats: stmt = stmt.limit(limit) result = db_session.execute(stmt) chat_sessions = list(result.scalars().all()) if not include_failed_chats and chat_sessions: # Filter out "failed" sessions (those with only SYSTEM messages) # using a separate efficient query instead of a correlated EXISTS # subquery, which causes full sequential scans of chat_message. # Sessions created within the last 5 minutes are always kept. leeway = datetime.now(timezone.utc) - timedelta(minutes=5) session_ids = [cs.id for cs in chat_sessions if cs.time_created < leeway] if session_ids: valid_session_ids_stmt = ( select(ChatMessage.chat_session_id) .where(ChatMessage.chat_session_id.in_(session_ids)) .where(ChatMessage.message_type != MessageType.SYSTEM) .distinct() ) valid_session_ids = set( db_session.execute(valid_session_ids_stmt).scalars().all() ) chat_sessions = [ cs for cs in chat_sessions if cs.time_created >= leeway or cs.id in valid_session_ids ] if limit: chat_sessions = chat_sessions[:limit] return chat_sessions def delete_orphaned_search_docs(db_session: Session) -> None: orphaned_docs = ( db_session.query(DBSearchDoc) .outerjoin(ChatMessage__SearchDoc) .filter(ChatMessage__SearchDoc.chat_message_id.is_(None)) .all() ) for doc in orphaned_docs: db_session.delete(doc) db_session.commit() def delete_messages_and_files_from_chat_session( chat_session_id: UUID, db_session: Session ) -> None: # Select every message in this chat session along with its attached files messages_with_files = db_session.execute( select(ChatMessage.id, ChatMessage.files).where( ChatMessage.chat_session_id == chat_session_id, )
).fetchall() for _, files in messages_with_files: file_store = get_default_file_store() for file_info in files or []: file_store.delete_file(file_id=file_info.get("id")) # Delete ChatMessage records - CASCADE constraints will automatically handle: # - ChatMessage__StandardAnswer relationship records db_session.execute( delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id) ) db_session.commit() delete_orphaned_search_docs(db_session) def create_chat_session( db_session: Session, description: str | None, user_id: UUID | None, persona_id: int | None, # Can be none if temporary persona is used llm_override: LLMOverride | None = None, prompt_override: PromptOverride | None = None, onyxbot_flow: bool = False, slack_thread_id: str | None = None, project_id: int | None = None, ) -> ChatSession: chat_session = ChatSession( user_id=user_id, persona_id=persona_id, description=description, llm_override=llm_override, prompt_override=prompt_override, onyxbot_flow=onyxbot_flow, slack_thread_id=slack_thread_id, project_id=project_id, ) db_session.add(chat_session) db_session.commit() return chat_session def duplicate_chat_session_for_user_from_slack( db_session: Session, user: User, chat_session_id: UUID, ) -> ChatSession: """ This takes a chat session id for a session in Slack and: - Creates a new chat session in the DB - Tries to copy the persona from the original chat session (if it is available to the user clicking the button) - Sets the user to the given user (if provided) """ chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=None, # Ignore user permissions for this db_session=db_session, ) if not chat_session: raise HTTPException(status_code=400, detail="Invalid Chat Session ID provided") # This enforces permissions and sets a default new_persona_id = get_best_persona_id_for_user( db_session=db_session, user=user, persona_id=chat_session.persona_id, ) return create_chat_session( db_session=db_session, user_id=user.id, 
persona_id=new_persona_id, # Set this to empty string so the frontend will force a rename description="", llm_override=chat_session.llm_override, prompt_override=chat_session.prompt_override, # Chat is in UI now so this is false onyxbot_flow=False, # Maybe we want this in the future to track if it was created from Slack slack_thread_id=None, ) def update_chat_session( db_session: Session, user_id: UUID | None, chat_session_id: UUID, description: str | None = None, sharing_status: ChatSessionSharedStatus | None = None, ) -> ChatSession: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session ) if chat_session.deleted: raise ValueError("Trying to rename a deleted chat session") if description is not None: chat_session.description = description if sharing_status is not None: chat_session.shared_status = sharing_status db_session.commit() return chat_session def delete_all_chat_sessions_for_user( user: User, db_session: Session, hard_delete: bool = HARD_DELETE_CHATS ) -> None: user_id = user.id chat_sessions = ( db_session.query(ChatSession) .filter(ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)) .all() ) if hard_delete: for chat_session in chat_sessions: delete_messages_and_files_from_chat_session(chat_session.id, db_session) db_session.execute( delete(ChatSession).where( ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False) ) ) else: db_session.execute( update(ChatSession) .where(ChatSession.user_id == user_id, ChatSession.onyxbot_flow.is_(False)) .values(deleted=True) ) db_session.commit() def delete_chat_session( user_id: UUID | None, chat_session_id: UUID, db_session: Session, include_deleted: bool = False, hard_delete: bool = HARD_DELETE_CHATS, ) -> None: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session, include_deleted=include_deleted, ) if chat_session.deleted and not include_deleted: raise ValueError("Cannot 
delete an already deleted chat session") if hard_delete: delete_messages_and_files_from_chat_session(chat_session_id, db_session) db_session.execute(delete(ChatSession).where(ChatSession.id == chat_session_id)) else: chat_session = get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session ) chat_session.deleted = True db_session.commit() def get_chat_sessions_older_than( days_old: int, db_session: Session ) -> list[tuple[UUID | None, UUID]]: """ Retrieves chat sessions older than a specified number of days. Args: days_old: The number of days to consider as "old". db_session: The database session. Returns: A list of tuples, where each tuple contains the user_id (can be None) and the chat_session_id of an old chat session. """ cutoff_time = datetime.utcnow() - timedelta(days=days_old) old_sessions: Sequence[Row[Tuple[UUID | None, UUID]]] = db_session.execute( select(ChatSession.user_id, ChatSession.id).where( ChatSession.time_created < cutoff_time ) ).fetchall() # convert old_sessions to a conventional list of tuples returned_sessions: list[tuple[UUID | None, UUID]] = [ (user_id, session_id) for user_id, session_id in old_sessions ] return returned_sessions def get_chat_message( chat_message_id: int, user_id: UUID | None, db_session: Session, ) -> ChatMessage: stmt = select(ChatMessage).where(ChatMessage.id == chat_message_id) result = db_session.execute(stmt) chat_message = result.scalar_one_or_none() if not chat_message: raise ValueError("Invalid Chat Message specified") chat_user = chat_message.chat_session.user expected_user_id = chat_user.id if chat_user is not None else None if expected_user_id != user_id: logger.error( f"User {user_id} tried to fetch a chat message that does not belong to them" ) raise ValueError("Chat message does not belong to user") return chat_message def get_chat_session_by_message_id( db_session: Session, message_id: int, ) -> ChatSession: """ Should only be used for Slack Get the chat session 
associated with a specific message ID Note: this ignores permission checks. """ stmt = select(ChatMessage).where(ChatMessage.id == message_id) result = db_session.execute(stmt) chat_message = result.scalar_one_or_none() if chat_message is None: raise ValueError( f"Unable to find chat session associated with message ID: {message_id}" ) return chat_message.chat_session def get_chat_messages_by_sessions( chat_session_ids: list[UUID], user_id: UUID | None, db_session: Session, skip_permission_check: bool = False, ) -> Sequence[ChatMessage]: if not skip_permission_check: for chat_session_id in chat_session_ids: get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session ) stmt = ( select(ChatMessage) .where(ChatMessage.chat_session_id.in_(chat_session_ids)) .order_by(nullsfirst(ChatMessage.parent_message_id)) ) return db_session.execute(stmt).scalars().all() def add_chats_to_session_from_slack_thread( db_session: Session, slack_chat_session_id: UUID, new_chat_session_id: UUID, ) -> None: new_root_message = get_or_create_root_message( chat_session_id=new_chat_session_id, db_session=db_session, ) for chat_message in get_chat_messages_by_sessions( chat_session_ids=[slack_chat_session_id], user_id=None, # Ignore user permissions for this db_session=db_session, skip_permission_check=True, ): if chat_message.message_type == MessageType.SYSTEM: continue # Duplicate the message new_root_message = create_new_chat_message( db_session=db_session, chat_session_id=new_chat_session_id, parent_message=new_root_message, message=chat_message.message, files=chat_message.files, error=chat_message.error, token_count=chat_message.token_count, message_type=chat_message.message_type, reasoning_tokens=chat_message.reasoning_tokens, ) def add_search_docs_to_chat_message( chat_message_id: int, search_doc_ids: list[int], db_session: Session ) -> None: """ Link SearchDocs to a ChatMessage by creating entries in the chat_message__search_doc junction table. 
Args: chat_message_id: The ID of the chat message search_doc_ids: List of search document IDs to link db_session: The database session """ for search_doc_id in search_doc_ids: chat_message_search_doc = ChatMessage__SearchDoc( chat_message_id=chat_message_id, search_doc_id=search_doc_id ) db_session.add(chat_message_search_doc) def add_search_docs_to_tool_call( tool_call_id: int, search_doc_ids: list[int], db_session: Session ) -> None: """ Link SearchDocs to a ToolCall by creating entries in the tool_call__search_doc junction table. Args: tool_call_id: The ID of the tool call search_doc_ids: List of search document IDs to link db_session: The database session """ from onyx.db.models import ToolCall__SearchDoc for search_doc_id in search_doc_ids: tool_call_search_doc = ToolCall__SearchDoc( tool_call_id=tool_call_id, search_doc_id=search_doc_id ) db_session.add(tool_call_search_doc) def get_chat_messages_by_session( chat_session_id: UUID, user_id: UUID | None, db_session: Session, skip_permission_check: bool = False, prefetch_top_two_level_tool_calls: bool = True, ) -> list[ChatMessage]: if not skip_permission_check: # bug if we ever call this expecting the permission check to not be skipped get_chat_session_by_id( chat_session_id=chat_session_id, user_id=user_id, db_session=db_session ) stmt = ( select(ChatMessage) .where(ChatMessage.chat_session_id == chat_session_id) .order_by(nullsfirst(ChatMessage.parent_message_id)) ) # This should handle both the top level tool calls and deep research # If there are future nested agents, this can be extended. 
if prefetch_top_two_level_tool_calls: # Load tool_calls and their direct children (one level deep) stmt = stmt.options( selectinload(ChatMessage.tool_calls).selectinload( ToolCall.tool_call_children ) ) result = db_session.scalars(stmt).unique().all() else: result = db_session.scalars(stmt).all() return list(result) def get_or_create_root_message( chat_session_id: UUID, db_session: Session, ) -> ChatMessage: try: root_message: ChatMessage | None = ( db_session.query(ChatMessage) .filter( ChatMessage.chat_session_id == chat_session_id, ChatMessage.parent_message_id.is_(None), ) .one_or_none() ) except MultipleResultsFound: raise Exception( "Multiple root messages found for chat session. Data inconsistency detected." ) if root_message is not None: return root_message else: new_root_message = ChatMessage( chat_session_id=chat_session_id, parent_message_id=None, latest_child_message_id=None, message="", token_count=0, message_type=MessageType.SYSTEM, ) db_session.add(new_root_message) db_session.commit() return new_root_message def reserve_message_id( db_session: Session, chat_session_id: UUID, parent_message: int, message_type: MessageType = MessageType.ASSISTANT, ) -> ChatMessage: # Create a temporary placeholder chat message; it is updated and saved at the end empty_message = ChatMessage( chat_session_id=chat_session_id, parent_message_id=parent_message, latest_child_message_id=None, message="Response was terminated prior to completion, try regenerating.", token_count=15, message_type=message_type, ) # Add the empty message to the session db_session.add(empty_message) db_session.flush() # Get the parent message and set its child pointer to the current message parent_chat_message = ( db_session.query(ChatMessage).filter(ChatMessage.id == parent_message).first() ) if parent_chat_message: parent_chat_message.latest_child_message_id = empty_message.id # Committing because this state is recoverable: if generation fails, the user sees the placeholder text, which is clearer than a missing response.
# Ideally there's a special UI for a case like this with a regenerate button but not needed for now. db_session.commit() return empty_message def reserve_multi_model_message_ids( db_session: Session, chat_session_id: UUID, parent_message_id: int, model_display_names: list[str], ) -> list[ChatMessage]: """Reserve N assistant message placeholders for multi-model parallel streaming. All messages share the same parent (the user message). The parent's latest_child_message_id points to the LAST reserved message so that the default history-chain walker picks it up. """ reserved: list[ChatMessage] = [] for display_name in model_display_names: msg = ChatMessage( chat_session_id=chat_session_id, parent_message_id=parent_message_id, latest_child_message_id=None, message="Response was terminated prior to completion, try regenerating.", token_count=15, # placeholder; updated on completion by llm_loop_completion_handle message_type=MessageType.ASSISTANT, model_display_name=display_name, ) db_session.add(msg) reserved.append(msg) # Flush to assign IDs without committing yet db_session.flush() # Point parent's latest_child to the last reserved message parent = ( db_session.query(ChatMessage) .filter(ChatMessage.id == parent_message_id) .first() ) if parent: parent.latest_child_message_id = reserved[-1].id db_session.commit() return reserved def set_preferred_response( db_session: Session, user_message_id: int, preferred_assistant_message_id: int, ) -> None: """Mark one assistant response as the user's preferred choice in a multi-model turn. Also advances ``latest_child_message_id`` so the preferred response becomes the active branch for any subsequent messages in the conversation. Args: db_session: Active database session. user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose preferred response is being set. preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type ``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``. 
Raises: ValueError: If either message is not found, if ``user_message_id`` does not refer to a USER message, or if the assistant message is not a direct child of the user message. """ user_msg = db_session.get(ChatMessage, user_message_id) if user_msg is None: raise ValueError(f"User message {user_message_id} not found") if user_msg.message_type != MessageType.USER: raise ValueError(f"Message {user_message_id} is not a user message") assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id) if assistant_msg is None: raise ValueError( f"Assistant message {preferred_assistant_message_id} not found" ) if assistant_msg.parent_message_id != user_message_id: raise ValueError( f"Assistant message {preferred_assistant_message_id} is not a child of user message {user_message_id}" ) user_msg.preferred_response_id = preferred_assistant_message_id user_msg.latest_child_message_id = preferred_assistant_message_id db_session.commit() def create_new_chat_message( chat_session_id: UUID, parent_message: ChatMessage, message: str, token_count: int, message_type: MessageType, db_session: Session, files: list[FileDescriptor] | None = None, error: str | None = None, commit: bool = True, reserved_message_id: int | None = None, reasoning_tokens: str | None = None, ) -> ChatMessage: if reserved_message_id is not None: # Edit existing message existing_message = db_session.get(ChatMessage, reserved_message_id) if existing_message is None: raise ValueError(f"No message found with id {reserved_message_id}") existing_message.chat_session_id = chat_session_id existing_message.parent_message_id = parent_message.id existing_message.message = message existing_message.token_count = token_count existing_message.message_type = message_type existing_message.files = files existing_message.error = error existing_message.reasoning_tokens = reasoning_tokens new_chat_message = existing_message else: # Create new message new_chat_message = ChatMessage(
chat_session_id=chat_session_id, parent_message_id=parent_message.id, latest_child_message_id=None, message=message, token_count=token_count, message_type=message_type, files=files, error=error, reasoning_tokens=reasoning_tokens, ) db_session.add(new_chat_message) # Flush the session to get an ID for the new chat message db_session.flush() parent_message.latest_child_message_id = new_chat_message.id if commit: db_session.commit() return new_chat_message def set_as_latest_chat_message( chat_message: ChatMessage, user_id: UUID | None, db_session: Session, ) -> None: parent_message_id = chat_message.parent_message_id if parent_message_id is None: raise RuntimeError( f"Trying to set a latest message without parent, message id: {chat_message.id}" ) parent_message = get_chat_message( chat_message_id=parent_message_id, user_id=user_id, db_session=db_session ) parent_message.latest_child_message_id = chat_message.id db_session.commit() def create_db_search_doc( server_search_doc: ServerSearchDoc, db_session: Session, commit: bool = True, ) -> DBSearchDoc: db_search_doc = DBSearchDoc( document_id=sanitize_string(server_search_doc.document_id), chunk_ind=server_search_doc.chunk_ind, semantic_id=sanitize_string(server_search_doc.semantic_identifier), link=( sanitize_string(server_search_doc.link) if server_search_doc.link is not None else None ), blurb=sanitize_string(server_search_doc.blurb), source_type=server_search_doc.source_type, boost=server_search_doc.boost, hidden=server_search_doc.hidden, doc_metadata=server_search_doc.metadata, is_relevant=server_search_doc.is_relevant, relevance_explanation=( sanitize_string(server_search_doc.relevance_explanation) if server_search_doc.relevance_explanation is not None else None ), score=server_search_doc.score or 0.0, match_highlights=[ sanitize_string(h) for h in server_search_doc.match_highlights ], updated_at=server_search_doc.updated_at, primary_owners=( [sanitize_string(o) for o in server_search_doc.primary_owners] if 
server_search_doc.primary_owners is not None else None ), secondary_owners=( [sanitize_string(o) for o in server_search_doc.secondary_owners] if server_search_doc.secondary_owners is not None else None ), is_internet=server_search_doc.is_internet, ) db_session.add(db_search_doc) if commit: db_session.commit() else: db_session.flush() return db_search_doc def get_db_search_doc_by_id(doc_id: int, db_session: Session) -> DBSearchDoc | None: """There are no safety checks here like user permission etc., use with caution""" search_doc = db_session.query(DBSearchDoc).filter(DBSearchDoc.id == doc_id).first() return search_doc def get_db_search_doc_by_document_id( document_id: str, db_session: Session ) -> DBSearchDoc | None: """Get SearchDoc by document_id field. There are no safety checks here like user permission etc., use with caution""" search_doc = ( db_session.query(DBSearchDoc) .filter(DBSearchDoc.document_id == document_id) .first() ) return search_doc def translate_db_search_doc_to_saved_search_doc( db_search_doc: DBSearchDoc, remove_doc_content: bool = False, ) -> SavedSearchDoc: return SavedSearchDoc( db_doc_id=db_search_doc.id, score=db_search_doc.score, document_id=db_search_doc.document_id, chunk_ind=db_search_doc.chunk_ind, semantic_identifier=db_search_doc.semantic_id, link=db_search_doc.link, blurb=db_search_doc.blurb if not remove_doc_content else "", source_type=db_search_doc.source_type, boost=db_search_doc.boost, hidden=db_search_doc.hidden, metadata=db_search_doc.doc_metadata if not remove_doc_content else {}, match_highlights=( db_search_doc.match_highlights if not remove_doc_content else [] ), relevance_explanation=db_search_doc.relevance_explanation, is_relevant=db_search_doc.is_relevant, updated_at=db_search_doc.updated_at if not remove_doc_content else None, primary_owners=db_search_doc.primary_owners if not remove_doc_content else [], secondary_owners=( db_search_doc.secondary_owners if not remove_doc_content else [] ), 
is_internet=db_search_doc.is_internet, ) def translate_db_message_to_chat_message_detail( chat_message: ChatMessage, remove_doc_content: bool = False, ) -> ChatMessageDetail: # Get current feedback if any current_feedback = None if chat_message.chat_message_feedbacks: latest_feedback = chat_message.chat_message_feedbacks[-1] if latest_feedback.is_positive is not None: current_feedback = "like" if latest_feedback.is_positive else "dislike" # Convert citations from {citation_num: db_doc_id} to {citation_num: document_id} converted_citations = None if chat_message.citations and chat_message.search_docs: # Build lookup map: db_doc_id -> document_id db_doc_id_to_document_id = { doc.id: doc.document_id for doc in chat_message.search_docs } converted_citations = {} for citation_num, db_doc_id in chat_message.citations.items(): document_id = db_doc_id_to_document_id.get(db_doc_id) if document_id: converted_citations[citation_num] = document_id top_documents = [ translate_db_search_doc_to_saved_search_doc( db_doc, remove_doc_content=remove_doc_content ) for db_doc in chat_message.search_docs ] top_documents = sorted( top_documents, key=lambda doc: doc.score or 0.0, reverse=True ) chat_msg_detail = ChatMessageDetail( chat_session_id=chat_message.chat_session_id, message_id=chat_message.id, parent_message=chat_message.parent_message_id, latest_child_message=chat_message.latest_child_message_id, message=chat_message.message, reasoning_tokens=chat_message.reasoning_tokens, message_type=chat_message.message_type, context_docs=top_documents, citations=converted_citations, time_sent=chat_message.time_sent, files=chat_message.files or [], error=chat_message.error, current_feedback=current_feedback, processing_duration_seconds=chat_message.processing_duration_seconds, preferred_response_id=chat_message.preferred_response_id, model_display_name=chat_message.model_display_name, ) return chat_msg_detail def update_chat_session_updated_at_timestamp( chat_session_id: UUID, db_session: 
Session ) -> None: """ Explicitly update the timestamp on a chat session without modifying other fields. This is useful when adding messages to a chat session to reflect recent activity. """ # Direct SQL update to avoid loading the entire object if it's not already loaded db_session.execute( update(ChatSession) .where(ChatSession.id == chat_session_id) .values(time_updated=func.now()) ) # No commit - the caller is responsible for committing the transaction def create_search_doc_from_inference_section( inference_section: InferenceSection, is_internet: bool, db_session: Session, score: float = 0.0, is_relevant: bool | None = None, relevance_explanation: str | None = None, commit: bool = False, ) -> DBSearchDoc: """Create a SearchDoc in the database from an InferenceSection.""" db_search_doc = DBSearchDoc( document_id=inference_section.center_chunk.document_id, chunk_ind=inference_section.center_chunk.chunk_id, semantic_id=inference_section.center_chunk.semantic_identifier, link=( inference_section.center_chunk.source_links.get(0) if inference_section.center_chunk.source_links else None ), blurb=inference_section.center_chunk.blurb, source_type=inference_section.center_chunk.source_type, boost=inference_section.center_chunk.boost, hidden=inference_section.center_chunk.hidden, doc_metadata=inference_section.center_chunk.metadata, score=score, is_relevant=is_relevant, relevance_explanation=relevance_explanation, match_highlights=inference_section.center_chunk.match_highlights, updated_at=inference_section.center_chunk.updated_at, primary_owners=inference_section.center_chunk.primary_owners or [], secondary_owners=inference_section.center_chunk.secondary_owners or [], is_internet=is_internet, ) db_session.add(db_search_doc) if commit: db_session.commit() else: db_session.flush() return db_search_doc def create_search_doc_from_saved_search_doc( saved_search_doc: SavedSearchDoc, ) -> DBSearchDoc: """Convert SavedSearchDoc (server model) into DB SearchDoc with correct field 
mapping.""" return DBSearchDoc( document_id=saved_search_doc.document_id, chunk_ind=saved_search_doc.chunk_ind, # Map Pydantic semantic_identifier -> DB semantic_id; ensure non-null semantic_id=saved_search_doc.semantic_identifier or "Unknown", link=saved_search_doc.link, blurb=saved_search_doc.blurb, source_type=saved_search_doc.source_type, boost=saved_search_doc.boost, hidden=saved_search_doc.hidden, # Map metadata -> doc_metadata (DB column name) doc_metadata=saved_search_doc.metadata, # SavedSearchDoc.score exists and defaults to 0.0 score=saved_search_doc.score or 0.0, match_highlights=saved_search_doc.match_highlights, updated_at=saved_search_doc.updated_at, primary_owners=saved_search_doc.primary_owners, secondary_owners=saved_search_doc.secondary_owners, is_internet=saved_search_doc.is_internet, is_relevant=saved_search_doc.is_relevant, relevance_explanation=saved_search_doc.relevance_explanation, ) def update_db_session_with_messages( db_session: Session, chat_message_id: int, chat_session_id: UUID, message: str | None = None, message_type: str | None = None, token_count: int | None = None, error: str | None = None, update_parent_message: bool = True, files: list[FileDescriptor] | None = None, reasoning_tokens: str | None = None, commit: bool = False, ) -> ChatMessage: chat_message = ( db_session.query(ChatMessage) .filter( ChatMessage.id == chat_message_id, ChatMessage.chat_session_id == chat_session_id, ) .first() ) if not chat_message: raise ValueError( f"Chat message with id {chat_message_id} not found in chat session {chat_session_id}" ) # should never happen if message: chat_message.message = message if message_type: chat_message.message_type = MessageType(message_type) if token_count: chat_message.token_count = token_count if error: chat_message.error = error if files is not None: chat_message.files = files if reasoning_tokens is not None: chat_message.reasoning_tokens = reasoning_tokens if update_parent_message: parent_chat_message = ( db_session.query(ChatMessage) .filter(ChatMessage.id ==
chat_message.parent_message_id) .first() ) if parent_chat_message: parent_chat_message.latest_child_message_id = chat_message.id if commit: db_session.commit() else: db_session.flush() return chat_message ================================================ FILE: backend/onyx/db/chat_search.py ================================================ from typing import List from typing import Optional from typing import Tuple from uuid import UUID from sqlalchemy import column from sqlalchemy import desc from sqlalchemy import func from sqlalchemy import select from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from sqlalchemy.sql.expression import ColumnClause from onyx.db.models import ChatMessage from onyx.db.models import ChatSession def search_chat_sessions( user_id: UUID | None, db_session: Session, query: Optional[str] = None, page: int = 1, page_size: int = 10, include_deleted: bool = False, include_onyxbot_flows: bool = False, ) -> Tuple[List[ChatSession], bool]: """ Fast full-text search on ChatSession + ChatMessage using tsvectors. If no query is provided, returns the most recent chat sessions. Otherwise, searches both chat messages and session descriptions. Returns a tuple of (sessions, has_more) where has_more indicates if there are additional results beyond the requested page. 
""" offset_val = (page - 1) * page_size # If no query, just return the most recent sessions if not query or not query.strip(): stmt = ( select(ChatSession) .order_by(desc(ChatSession.time_created)) .offset(offset_val) .limit(page_size + 1) ) if user_id is not None: stmt = stmt.where(ChatSession.user_id == user_id) if not include_onyxbot_flows: stmt = stmt.where(ChatSession.onyxbot_flow.is_(False)) if not include_deleted: stmt = stmt.where(ChatSession.deleted.is_(False)) result = db_session.execute(stmt.options(joinedload(ChatSession.persona))) sessions = result.scalars().all() has_more = len(sessions) > page_size if has_more: sessions = sessions[:page_size] return list(sessions), has_more # Otherwise, proceed with full-text search query = query.strip() base_conditions = [] if user_id is not None: base_conditions.append(ChatSession.user_id == user_id) if not include_onyxbot_flows: base_conditions.append(ChatSession.onyxbot_flow.is_(False)) if not include_deleted: base_conditions.append(ChatSession.deleted.is_(False)) message_tsv: ColumnClause = column("message_tsv") description_tsv: ColumnClause = column("description_tsv") ts_query = func.plainto_tsquery("english", query) description_session_ids = ( select(ChatSession.id) .where(*base_conditions) .where(description_tsv.op("@@")(ts_query)) ) message_session_ids = ( select(ChatMessage.chat_session_id) .join(ChatSession, ChatMessage.chat_session_id == ChatSession.id) .where(*base_conditions) .where(message_tsv.op("@@")(ts_query)) ) combined_ids = description_session_ids.union(message_session_ids).alias( "combined_ids" ) final_stmt = ( select(ChatSession) .join(combined_ids, ChatSession.id == combined_ids.c.id) .order_by(desc(ChatSession.time_created)) .distinct() .offset(offset_val) .limit(page_size + 1) .options(joinedload(ChatSession.persona)) ) session_objs = db_session.execute(final_stmt).scalars().all() has_more = len(session_objs) > page_size if has_more: session_objs = session_objs[:page_size] return 
list(session_objs), has_more ================================================ FILE: backend/onyx/db/chunk.py ================================================ from datetime import datetime from datetime import timezone from sqlalchemy import delete from sqlalchemy.orm import Session from onyx.db.models import ChunkStats from onyx.indexing.models import UpdatableChunkData def update_chunk_boost_components__no_commit( chunk_data: list[UpdatableChunkData], db_session: Session, ) -> None: """Updates the chunk_boost_components for chunks in the database. Args: chunk_data: List of dicts containing chunk_id, document_id, and boost_score db_session: SQLAlchemy database session """ if not chunk_data: return for data in chunk_data: chunk_in_doc_id = int(data.chunk_id) if chunk_in_doc_id < 0: raise ValueError(f"Chunk ID is empty for chunk {data}") chunk_document_id = f"{data.document_id}__{chunk_in_doc_id}" chunk_stats = ( db_session.query(ChunkStats) .filter( ChunkStats.id == chunk_document_id, ) .first() ) score = data.boost_score if chunk_stats: chunk_stats.information_content_boost = score chunk_stats.last_modified = datetime.now(timezone.utc) db_session.add(chunk_stats) else: # do not save new chunks with a neutral boost score if score == 1.0: continue # Create new record chunk_stats = ChunkStats( document_id=data.document_id, chunk_in_doc_id=chunk_in_doc_id, information_content_boost=score, ) db_session.add(chunk_stats) def delete_chunk_stats_by_connector_credential_pair__no_commit( db_session: Session, document_ids: list[str] ) -> None: """This deletes just chunk stats in postgres.""" stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids)) db_session.execute(stmt) ================================================ FILE: backend/onyx/db/code_interpreter.py ================================================ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.models import CodeInterpreterServer def fetch_code_interpreter_server( 
    db_session: Session,
) -> CodeInterpreterServer:
    server = db_session.scalars(select(CodeInterpreterServer)).one()
    return server


def update_code_interpreter_server_enabled(
    db_session: Session,
    enabled: bool,
) -> CodeInterpreterServer:
    server = db_session.scalars(select(CodeInterpreterServer)).one()
    server.server_enabled = enabled
    db_session.commit()
    return server


================================================
FILE: backend/onyx/db/connector.py
================================================
from datetime import datetime
from datetime import timezone
from typing import cast

from sqlalchemy import and_
from sqlalchemy import exists
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DEFAULT_PRUNING_FREQ
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import IndexingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import FederatedConnector
from onyx.db.models import IndexAttempt
from onyx.kg.models import KGConnectorData
from onyx.server.documents.models import ConnectorBase
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger

logger = setup_logger()


def check_federated_connectors_exist(db_session: Session) -> bool:
    stmt = select(exists(FederatedConnector))
    result = db_session.execute(stmt)
    return result.scalar() or False


def check_connectors_exist(db_session: Session) -> bool:
    # Connector 0 is created on server startup as the default ingestion connector.
    # It always exists, so it is excluded from this check.
    stmt = select(exists(Connector).where(Connector.id > 0))
    result = db_session.execute(stmt)
    return result.scalar() or False


def check_user_files_exist(db_session: Session) -> bool:
    """Check if any user files exist in the
    system. This is used to determine if the search tool should be available
    when there are no regular connectors but there are user files (User Knowledge
    mode).
    """
    from onyx.db.models import UserFile
    from onyx.db.enums import UserFileStatus

    stmt = select(exists(UserFile).where(UserFile.status == UserFileStatus.COMPLETED))
    result = db_session.execute(stmt)
    return result.scalar() or False


def fetch_connectors(
    db_session: Session,
    sources: list[DocumentSource] | None = None,
    input_types: list[InputType] | None = None,
) -> list[Connector]:
    stmt = select(Connector)
    if sources is not None:
        stmt = stmt.where(Connector.source.in_(sources))
    if input_types is not None:
        stmt = stmt.where(Connector.input_type.in_(input_types))
    results = db_session.scalars(stmt)
    return list(results.all())


def connector_by_name_source_exists(
    connector_name: str, source: DocumentSource, db_session: Session
) -> bool:
    stmt = select(Connector).where(
        Connector.name == connector_name, Connector.source == source
    )
    result = db_session.execute(stmt)
    connector = result.scalar_one_or_none()
    return connector is not None


def fetch_connector_by_id(connector_id: int, db_session: Session) -> Connector | None:
    stmt = select(Connector).where(Connector.id == connector_id)
    result = db_session.execute(stmt)
    connector = result.scalar_one_or_none()
    return connector


def fetch_ingestion_connector_by_name(
    connector_name: str, db_session: Session
) -> Connector | None:
    stmt = (
        select(Connector)
        .where(Connector.name == connector_name)
        .where(Connector.source == DocumentSource.INGESTION_API)
    )
    result = db_session.execute(stmt)
    connector = result.scalar_one_or_none()
    return connector


def create_connector(
    db_session: Session,
    connector_data: ConnectorBase,
) -> ObjectCreationIdResponse:
    if connector_by_name_source_exists(
        connector_data.name, connector_data.source, db_session
    ):
        raise ValueError(
            "Connector by this name already exists, duplicate naming not allowed."
        )

    connector = Connector(
        name=connector_data.name,
        source=connector_data.source,
        input_type=connector_data.input_type,
        connector_specific_config=connector_data.connector_specific_config,
        refresh_freq=connector_data.refresh_freq,
        indexing_start=connector_data.indexing_start,
        prune_freq=connector_data.prune_freq,
    )
    db_session.add(connector)
    db_session.commit()

    return ObjectCreationIdResponse(id=connector.id)


def update_connector(
    connector_id: int,
    connector_data: ConnectorBase,
    db_session: Session,
) -> Connector | None:
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        return None

    if connector_data.name != connector.name and connector_by_name_source_exists(
        connector_data.name, connector_data.source, db_session
    ):
        raise ValueError(
            "Connector by this name already exists, duplicate naming not allowed."
        )

    connector.name = connector_data.name
    connector.source = connector_data.source
    connector.input_type = connector_data.input_type
    connector.connector_specific_config = connector_data.connector_specific_config
    connector.refresh_freq = connector_data.refresh_freq
    connector.prune_freq = (
        connector_data.prune_freq
        if connector_data.prune_freq is not None
        else DEFAULT_PRUNING_FREQ
    )

    db_session.commit()
    return connector


def delete_connector(
    db_session: Session,
    connector_id: int,
) -> StatusResponse[int]:
    """Only used in special cases (e.g. a connector is in a bad state and we need to delete it).
    Be VERY careful using this, as it could lead to a bad state if not used correctly.
    """
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        return StatusResponse(
            success=True, message="Connector was already deleted", data=connector_id
        )

    db_session.delete(connector)
    return StatusResponse(
        success=True, message="Connector deleted successfully", data=connector_id
    )


def get_connector_credential_ids(
    connector_id: int,
    db_session: Session,
) -> list[int]:
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        raise ValueError(f"Connector by id {connector_id} does not exist")

    return [association.credential.id for association in connector.credentials]


def fetch_latest_index_attempt_by_connector(
    db_session: Session,
    source: DocumentSource | None = None,
) -> list[IndexAttempt]:
    latest_index_attempts: list[IndexAttempt] = []

    if source:
        connectors = fetch_connectors(db_session, sources=[source])
    else:
        connectors = fetch_connectors(db_session)

    if not connectors:
        return []

    for connector in connectors:
        latest_index_attempt = (
            db_session.query(IndexAttempt)
            .join(ConnectorCredentialPair)
            .filter(ConnectorCredentialPair.connector_id == connector.id)
            .order_by(IndexAttempt.time_updated.desc())
            .first()
        )

        if latest_index_attempt is not None:
            latest_index_attempts.append(latest_index_attempt)

    return latest_index_attempts


def fetch_latest_index_attempts_by_status(
    db_session: Session,
) -> list[IndexAttempt]:
    subquery = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            IndexAttempt.status,
            func.max(IndexAttempt.time_updated).label("time_updated"),
        )
        .group_by(IndexAttempt.connector_credential_pair_id)
        .group_by(IndexAttempt.status)
        .subquery()
    )

    alias = aliased(IndexAttempt, subquery)

    query = db_session.query(IndexAttempt).join(
        alias,
        and_(
            IndexAttempt.connector_credential_pair_id
            == alias.connector_credential_pair_id,
            IndexAttempt.status == alias.status,
            IndexAttempt.time_updated == alias.time_updated,
        ),
    )
    return cast(list[IndexAttempt], query.all())


def fetch_unique_document_sources(db_session: Session) -> list[DocumentSource]:
    distinct_sources = db_session.query(Connector.source).distinct().all()

    sources = [
        source[0]
        for source in distinct_sources
        if source[0] != DocumentSource.INGESTION_API
    ]

    return sources


def create_initial_default_connector(db_session: Session) -> None:
    default_connector_id = 0
    default_connector = fetch_connector_by_id(default_connector_id, db_session)
    if default_connector is not None:
        if (
            default_connector.source != DocumentSource.INGESTION_API
            or default_connector.input_type != InputType.LOAD_STATE
            or default_connector.refresh_freq is not None
            or default_connector.name != "Ingestion API"
            or default_connector.connector_specific_config != {}
            or default_connector.prune_freq is not None
        ):
            logger.warning(
                "Default connector does not have expected values. Updating to proper state."
            )
            # Ensure default connector has correct values
            default_connector.source = DocumentSource.INGESTION_API
            default_connector.input_type = InputType.LOAD_STATE
            default_connector.refresh_freq = None
            default_connector.name = "Ingestion API"
            default_connector.connector_specific_config = {}
            default_connector.prune_freq = None
            db_session.commit()
        return

    # Create a new default connector if it doesn't exist
    connector = Connector(
        id=default_connector_id,
        name="Ingestion API",
        source=DocumentSource.INGESTION_API,
        input_type=InputType.LOAD_STATE,
        connector_specific_config={},
        refresh_freq=None,
        prune_freq=None,
    )
    db_session.add(connector)
    db_session.commit()


def mark_ccpair_as_pruned(cc_pair_id: int, db_session: Session) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    cc_pair.last_pruned = datetime.now(timezone.utc)
    db_session.commit()


def mark_cc_pair_as_hierarchy_fetched(db_session: Session, cc_pair_id: int) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    cc_pair.last_time_hierarchy_fetch = datetime.now(timezone.utc)
    db_session.commit()


def mark_cc_pair_as_permissions_synced(
    db_session: Session, cc_pair_id: int, start_time: datetime | None
) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    cc_pair.last_time_perm_sync = start_time
    db_session.commit()


def mark_cc_pair_as_external_group_synced(db_session: Session, cc_pair_id: int) -> None:
    stmt = select(ConnectorCredentialPair).where(
        ConnectorCredentialPair.id == cc_pair_id
    )
    cc_pair = db_session.scalar(stmt)
    if cc_pair is None:
        raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

    # The sync time can be marked after it ran because all group syncs
    # are run in full, not polling for changes.
    # If this changes, we need to update this function.
    cc_pair.last_time_external_group_sync = datetime.now(timezone.utc)
    db_session.commit()


def mark_ccpair_with_indexing_trigger(
    cc_pair_id: int, indexing_mode: IndexingMode | None, db_session: Session
) -> None:
    """indexing_mode sets a field which will be picked up by a background task
    to trigger indexing. Set to None to disable the trigger."""
    try:
        cc_pair = db_session.execute(
            select(ConnectorCredentialPair)
            .where(ConnectorCredentialPair.id == cc_pair_id)
            .with_for_update()
        ).scalar_one_or_none()

        if cc_pair is None:
            raise ValueError(f"No cc_pair with ID: {cc_pair_id}")

        cc_pair.indexing_trigger = indexing_mode
        db_session.commit()
    except Exception:
        db_session.rollback()
        raise


def get_kg_enabled_connectors(db_session: Session) -> list[KGConnectorData]:
    """
    Retrieves the connectors that have KG processing enabled for a given tenant.
    Args:
        db_session (Session): The database session to use

    Returns:
        list[KGConnectorData]: ID, lowercased source, and KG coverage days for
            each connector that has KG processing enabled
    """
    try:
        stmt = select(Connector.id, Connector.source, Connector.kg_coverage_days).where(
            Connector.kg_processing_enabled
        )
        result = db_session.execute(stmt)

        connector_results = [
            KGConnectorData(id=row[0], source=row[1].lower(), kg_coverage_days=row[2])
            for row in result.fetchall()
        ]

        return connector_results

    except Exception as e:
        logger.error(f"Error fetching KG-enabled connectors: {e}")
        raise


================================================
FILE: backend/onyx/db/connector_credential_pair.py
================================================
from datetime import datetime
from enum import Enum
from typing import TypeVarTuple

from fastapi import HTTPException
from sqlalchemy import delete
from sqlalchemy import desc
from sqlalchemy import exists
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import aliased
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource
from onyx.db.connector import fetch_connector_by_id
from onyx.db.credentials import fetch_credential_by_id
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import ProcessingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()

R = TypeVarTuple("R")


class ConnectorType(str, Enum):
    STANDARD = "standard"
    USER_FILE = "user_file"


def _add_user_filters(
    stmt: Select[tuple[*R]], user: User, get_editable: bool = True
) -> Select[tuple[*R]]:
    if user.role == UserRole.ADMIN:
        return stmt

    # If anonymous user, only show public cc_pairs
    if user.is_anonymous:
        where_clause = ConnectorCredentialPair.access_type == AccessType.PUBLIC
        return stmt.where(where_clause)

    stmt = stmt.distinct()
    UG__CCpair = aliased(UserGroup__ConnectorCredentialPair)
    User__UG = aliased(User__UserGroup)

    """
    Here we select cc_pairs by relation:
    User -> User__UserGroup -> UserGroup__ConnectorCredentialPair ->
    ConnectorCredentialPair
    """
    stmt = stmt.outerjoin(UG__CCpair).outerjoin(
        User__UG,
        User__UG.user_group_id == UG__CCpair.user_group_id,
    )

    """
    Filter cc_pairs by:
    - if the user is in the user_group that owns the cc_pair
    - if the user is not a global_curator, they must also have a curator relationship
    to the user_group
    - if editing is being done, we also filter out cc_pairs that are owned by groups
    that the user isn't a curator for
    - if we are not editing, we show all cc_pairs in the groups the user is a curator
    for (as well as public cc_pairs)
    """
    where_clause = User__UG.user_id == user.id
    if user.role == UserRole.CURATOR and get_editable:
        where_clause &= User__UG.is_curator == True  # noqa: E712
    if get_editable:
        user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
        if user.role == UserRole.CURATOR:
            user_groups = user_groups.where(
                User__UserGroup.is_curator == True  # noqa: E712
            )
        where_clause &= (
            ~exists()
            .where(UG__CCpair.cc_pair_id == ConnectorCredentialPair.id)
            .where(~UG__CCpair.user_group_id.in_(user_groups))
            .correlate(ConnectorCredentialPair)
        )
        where_clause |= ConnectorCredentialPair.creator_id == user.id
    else:
        where_clause |= ConnectorCredentialPair.access_type == AccessType.PUBLIC
        where_clause |= ConnectorCredentialPair.access_type == AccessType.SYNC

    return stmt.where(where_clause)


def get_connector_credential_pairs_for_user(
    db_session: Session,
    user: User,
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
    eager_load_credential: bool = False,
    eager_load_user: bool = False,
    order_by_desc: bool = False,
    source: DocumentSource | None = None,
    processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
    defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
    """Get connector credential pairs for a user.

    Args:
        processing_mode: Filter by processing mode. Defaults to REGULAR to hide
            FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
        defer_connector_config: If True, skips loading Connector.connector_specific_config
            to avoid fetching large JSONB blobs when they aren't needed.
    """
    if eager_load_user:
        assert (
            eager_load_credential
        ), "eager_load_credential must be True if eager_load_user is True"

    stmt = select(ConnectorCredentialPair).distinct()

    if eager_load_connector:
        connector_load = selectinload(ConnectorCredentialPair.connector)
        if defer_connector_config:
            connector_load = connector_load.defer(Connector.connector_specific_config)
        stmt = stmt.options(connector_load)

    if eager_load_credential:
        load_opts = selectinload(ConnectorCredentialPair.credential)
        if eager_load_user:
            load_opts = load_opts.joinedload(Credential.user)
        stmt = stmt.options(load_opts)

    stmt = _add_user_filters(stmt, user, get_editable)
    if source:
        stmt = stmt.join(ConnectorCredentialPair.connector).where(
            Connector.source == source.value
        )

    if ids:
        stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

    if processing_mode is not None:
        stmt = stmt.where(ConnectorCredentialPair.processing_mode == processing_mode)

    if order_by_desc:
        stmt = stmt.order_by(desc(ConnectorCredentialPair.id))

    return list(db_session.scalars(stmt).unique().all())


# For use with our thread-level parallelism utils. Note that any relationships
# you wish to use MUST be eagerly loaded, as the session will not be available
# after this function to allow lazy loading.
def get_connector_credential_pairs_for_user_parallel(
    user: User,
    get_editable: bool = True,
    ids: list[int] | None = None,
    eager_load_connector: bool = False,
    eager_load_credential: bool = False,
    eager_load_user: bool = False,
    order_by_desc: bool = False,
    source: DocumentSource | None = None,
    processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
    defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
    with get_session_with_current_tenant() as db_session:
        return get_connector_credential_pairs_for_user(
            db_session=db_session,
            user=user,
            get_editable=get_editable,
            ids=ids,
            eager_load_connector=eager_load_connector,
            eager_load_credential=eager_load_credential,
            eager_load_user=eager_load_user,
            order_by_desc=order_by_desc,
            source=source,
            processing_mode=processing_mode,
            defer_connector_config=defer_connector_config,
        )


def get_connector_credential_pairs(
    db_session: Session, ids: list[int] | None = None
) -> list[ConnectorCredentialPair]:
    stmt = select(ConnectorCredentialPair).distinct()

    if ids:
        stmt = stmt.where(ConnectorCredentialPair.id.in_(ids))

    return list(db_session.scalars(stmt).all())


def add_deletion_failure_message(
    db_session: Session,
    cc_pair_id: int,
    failure_message: str,
) -> None:
    cc_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if not cc_pair:
        return
    cc_pair.deletion_failure_message = failure_message
    db_session.commit()


def get_cc_pair_groups_for_ids(
    db_session: Session,
    cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
    stmt = select(UserGroup__ConnectorCredentialPair).distinct()
    stmt = stmt.outerjoin(
        ConnectorCredentialPair,
        UserGroup__ConnectorCredentialPair.cc_pair_id == ConnectorCredentialPair.id,
    )
    stmt = stmt.where(UserGroup__ConnectorCredentialPair.cc_pair_id.in_(cc_pair_ids))
    return list(db_session.scalars(stmt).all())


# For use with our thread-level parallelism utils.
# Note that any relationships you wish to use MUST be eagerly loaded, as the
# session will not be available after this function to allow lazy loading.
def get_cc_pair_groups_for_ids_parallel(
    cc_pair_ids: list[int],
) -> list[UserGroup__ConnectorCredentialPair]:
    with get_session_with_current_tenant() as db_session:
        return get_cc_pair_groups_for_ids(db_session, cc_pair_ids)


def get_connector_credential_pair_for_user(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    user: User,
    get_editable: bool = True,
) -> ConnectorCredentialPair | None:
    stmt = select(ConnectorCredentialPair)
    stmt = _add_user_filters(stmt, user, get_editable)
    stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
    stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
    result = db_session.execute(stmt)
    return result.scalar_one_or_none()


def get_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
) -> ConnectorCredentialPair | None:
    stmt = select(ConnectorCredentialPair)
    stmt = stmt.where(ConnectorCredentialPair.connector_id == connector_id)
    stmt = stmt.where(ConnectorCredentialPair.credential_id == credential_id)
    result = db_session.execute(stmt)
    return result.scalar_one_or_none()


def get_connector_credential_pair_from_id_for_user(
    cc_pair_id: int,
    db_session: Session,
    user: User,
    get_editable: bool = True,
) -> ConnectorCredentialPair | None:
    stmt = select(ConnectorCredentialPair).distinct()
    stmt = _add_user_filters(stmt, user, get_editable)
    stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
    result = db_session.execute(stmt)
    return result.scalar_one_or_none()


def verify_user_has_access_to_cc_pair(
    cc_pair_id: int,
    db_session: Session,
    user: User,
    get_editable: bool = True,
) -> bool:
    stmt = select(ConnectorCredentialPair.id)
    stmt = _add_user_filters(stmt, user, get_editable)
    stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)
    result = db_session.execute(stmt)
    return result.scalars().first() is not None


def get_connector_credential_pair_from_id(
    db_session: Session,
    cc_pair_id: int,
    eager_load_connector: bool = False,
    eager_load_credential: bool = False,
) -> ConnectorCredentialPair | None:
    stmt = select(ConnectorCredentialPair).distinct()
    stmt = stmt.where(ConnectorCredentialPair.id == cc_pair_id)

    if eager_load_credential:
        stmt = stmt.options(joinedload(ConnectorCredentialPair.credential))
    if eager_load_connector:
        stmt = stmt.options(joinedload(ConnectorCredentialPair.connector))

    result = db_session.execute(stmt)
    return result.scalar_one_or_none()


def get_connector_credential_pairs_for_source(
    db_session: Session,
    source: DocumentSource,
) -> list[ConnectorCredentialPair]:
    stmt = (
        select(ConnectorCredentialPair)
        .join(ConnectorCredentialPair.connector)
        .where(Connector.source == source)
    )
    return list(db_session.scalars(stmt).unique().all())


def get_last_successful_attempt_poll_range_end(
    cc_pair_id: int,
    earliest_index: float,
    search_settings: SearchSettings,
    db_session: Session,
) -> float:
    """Used to get the latest `poll_range_end` for a given connector and credential.

    This can be used to determine the next "start" time for a new index attempt.

    Note that the attempt's `time_started` is not necessarily correct: it is set
    separately and is similar to, but not exactly the same as, the `poll_range_end`.
    """
    latest_successful_index_attempt = (
        db_session.query(IndexAttempt)
        .join(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
        .filter(
            ConnectorCredentialPair.id == cc_pair_id,
            IndexAttempt.search_settings_id == search_settings.id,
            IndexAttempt.status == IndexingStatus.SUCCESS,
        )
        .order_by(IndexAttempt.poll_range_end.desc())
        .first()
    )
    if (
        not latest_successful_index_attempt
        or not latest_successful_index_attempt.poll_range_end
    ):
        return earliest_index

    return latest_successful_index_attempt.poll_range_end.timestamp()


"""Updates"""


def _update_connector_credential_pair(
    db_session: Session,
    cc_pair: ConnectorCredentialPair,
    status: ConnectorCredentialPairStatus | None = None,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    # simply don't update last_successful_index_time if run_dt is not specified
    # at worst, this would result in re-indexing documents that were already indexed
    if run_dt is not None:
        cc_pair.last_successful_index_time = run_dt
    if net_docs is not None:
        cc_pair.total_docs_indexed += net_docs
    if status is not None:
        cc_pair.status = status
    db_session.commit()


def update_connector_credential_pair_from_id(
    db_session: Session,
    cc_pair_id: int,
    status: ConnectorCredentialPairStatus | None = None,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    cc_pair = get_connector_credential_pair_from_id(
        db_session=db_session,
        cc_pair_id=cc_pair_id,
    )
    if not cc_pair:
        logger.warning(
            f"Attempted to update pair for Connector Credential Pair '{cc_pair_id}' but it does not exist"
        )
        return

    _update_connector_credential_pair(
        db_session=db_session,
        cc_pair=cc_pair,
        status=status,
        net_docs=net_docs,
        run_dt=run_dt,
    )


def update_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    status: ConnectorCredentialPairStatus | None = None,
    net_docs: int | None = None,
    run_dt: datetime | None = None,
) -> None:
    cc_pair = get_connector_credential_pair(
        db_session=db_session,
        connector_id=connector_id,
        credential_id=credential_id,
    )
    if not cc_pair:
        logger.warning(
            f"Attempted to update pair for connector id {connector_id} and credential id {credential_id}"
        )
        return

    _update_connector_credential_pair(
        db_session=db_session,
        cc_pair=cc_pair,
        status=status,
        net_docs=net_docs,
        run_dt=run_dt,
    )


def set_cc_pair_repeated_error_state(
    db_session: Session,
    cc_pair_id: int,
    in_repeated_error_state: bool,
) -> None:
    stmt = (
        update(ConnectorCredentialPair)
        .where(ConnectorCredentialPair.id == cc_pair_id)
        .values(in_repeated_error_state=in_repeated_error_state)
    )
    db_session.execute(stmt)
    db_session.commit()


def delete_connector_credential_pair__no_commit(
    db_session: Session,
    connector_id: int,
    credential_id: int,
) -> None:
    stmt = delete(ConnectorCredentialPair).where(
        ConnectorCredentialPair.connector_id == connector_id,
        ConnectorCredentialPair.credential_id == credential_id,
    )
    db_session.execute(stmt)


def associate_default_cc_pair(db_session: Session) -> None:
    existing_association = (
        db_session.query(ConnectorCredentialPair)
        .filter(
            ConnectorCredentialPair.connector_id == 0,
            ConnectorCredentialPair.credential_id == 0,
        )
        .one_or_none()
    )
    if existing_association is not None:
        return

    # DefaultCCPair has id 1 since it is the first CC pair created.
    # It is DEFAULT_CC_PAIR_ID, but we can't set it explicitly because doing so
    # would interfere with the auto-incrementing id.
    association = ConnectorCredentialPair(
        connector_id=0,
        credential_id=0,
        access_type=AccessType.PUBLIC,
        name="DefaultCCPair",
        status=ConnectorCredentialPairStatus.ACTIVE,
    )
    db_session.add(association)
    db_session.commit()


def _relate_groups_to_cc_pair__no_commit(
    db_session: Session,
    cc_pair_id: int,
    user_group_ids: list[int] | None = None,
) -> None:
    if not user_group_ids:
        return

    for group_id in user_group_ids:
        db_session.add(
            UserGroup__ConnectorCredentialPair(
                user_group_id=group_id, cc_pair_id=cc_pair_id
            )
        )


def add_credential_to_connector(
    db_session: Session,
    user: User,
    connector_id: int,
    credential_id: int,
    cc_pair_name: str,
    access_type: AccessType,
    groups: list[int] | None,
    auto_sync_options: dict | None = None,
    initial_status: ConnectorCredentialPairStatus = ConnectorCredentialPairStatus.SCHEDULED,
    last_successful_index_time: datetime | None = None,
    seeding_flow: bool = False,
    processing_mode: ProcessingMode = ProcessingMode.REGULAR,
) -> StatusResponse:
    connector = fetch_connector_by_id(connector_id, db_session)

    # If we are in the seeding flow, we shouldn't need to check if the credential belongs to the user
    if seeding_flow:
        credential = fetch_credential_by_id(
            credential_id=credential_id,
            db_session=db_session,
        )
    else:
        credential = fetch_credential_by_id_for_user(
            credential_id,
            user,
            db_session,
            get_editable=False,
        )

    if connector is None:
        raise HTTPException(status_code=404, detail="Connector does not exist")

    if access_type == AccessType.SYNC:
        if not fetch_ee_implementation_or_noop(
            "onyx.external_permissions.sync_params",
            "check_if_valid_sync_source",
            noop_return_value=True,
        )(connector.source):
            raise HTTPException(
                status_code=400,
                detail=f"Connector of type {connector.source} does not support SYNC access type",
            )

    if credential is None:
        error_msg = (
            f"Credential {credential_id} does not exist or does not belong to user"
        )
        logger.error(error_msg)
        raise HTTPException(
            status_code=401,
            detail=error_msg,
        )

    existing_association = (
        db_session.query(ConnectorCredentialPair)
        .filter(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
        )
        .one_or_none()
    )
    if existing_association is not None:
        return StatusResponse(
            success=False,
            message=f"Connector {connector_id} already has Credential {credential_id}",
            data=connector_id,
        )

    association = ConnectorCredentialPair(
        creator_id=user.id,
        connector_id=connector_id,
        credential_id=credential_id,
        name=cc_pair_name,
        status=initial_status,
        access_type=access_type,
        auto_sync_options=auto_sync_options,
last_successful_index_time=last_successful_index_time, processing_mode=processing_mode, ) db_session.add(association) db_session.flush() # make sure the association has an id db_session.refresh(association) _relate_groups_to_cc_pair__no_commit( db_session=db_session, cc_pair_id=association.id, user_group_ids=groups, ) db_session.commit() return StatusResponse( success=True, message=f"Creating new association between Connector {connector_id} and Credential {credential_id}", data=association.id, ) def remove_credential_from_connector( connector_id: int, credential_id: int, user: User, db_session: Session, ) -> StatusResponse[int]: connector = fetch_connector_by_id(connector_id, db_session) credential = fetch_credential_by_id_for_user( credential_id, user, db_session, get_editable=False, ) if connector is None: raise HTTPException(status_code=404, detail="Connector does not exist") if credential is None: raise HTTPException( status_code=404, detail="Credential does not exist or does not belong to user", ) association = get_connector_credential_pair_for_user( db_session=db_session, connector_id=connector_id, credential_id=credential_id, user=user, get_editable=True, ) if association is not None: fetch_ee_implementation_or_noop( "onyx.db.external_perm", "delete_user__ext_group_for_cc_pair__no_commit", )( db_session=db_session, cc_pair_id=association.id, ) db_session.delete(association) db_session.commit() return StatusResponse( success=True, message=f"Credential {credential_id} removed from Connector", data=connector_id, ) return StatusResponse( success=False, message=f"Connector already does not have Credential {credential_id}", data=connector_id, ) def fetch_indexable_standard_connector_credential_pair_ids( db_session: Session, active_cc_pairs_only: bool = True, limit: int | None = None, ) -> list[int]: stmt = select(ConnectorCredentialPair.id) # For regular indexing checks if active_cc_pairs_only: stmt = stmt.where( ConnectorCredentialPair.status.in_( 
ConnectorCredentialPairStatus.active_statuses() ) ) else: # For embedding swap checks, include PAUSED and exclude DELETING or INVALID stmt = stmt.where( ConnectorCredentialPair.status.in_( ConnectorCredentialPairStatus.indexable_statuses() ) ) if limit: stmt = stmt.limit(limit) return list(db_session.scalars(stmt)) def fetch_connector_credential_pair_for_connector( db_session: Session, connector_id: int, ) -> ConnectorCredentialPair | None: stmt = select(ConnectorCredentialPair).where( ConnectorCredentialPair.connector_id == connector_id, ) return db_session.scalar(stmt) def resync_cc_pair( cc_pair: ConnectorCredentialPair, search_settings_id: int, db_session: Session, ) -> None: """ Updates state stored in the connector_credential_pair table based on the latest index attempt for the given search settings. Args: cc_pair: ConnectorCredentialPair to resync search_settings_id: SearchSettings to use for resync db_session: Database session """ def find_latest_index_attempt( connector_id: int, credential_id: int, only_include_success: bool, db_session: Session, ) -> IndexAttempt | None: query = ( db_session.query(IndexAttempt) .join( ConnectorCredentialPair, IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id, ) .filter( ConnectorCredentialPair.connector_id == connector_id, ConnectorCredentialPair.credential_id == credential_id, IndexAttempt.search_settings_id == search_settings_id, ) ) if only_include_success: query = query.filter(IndexAttempt.status == IndexingStatus.SUCCESS) latest_index_attempt = query.order_by(desc(IndexAttempt.time_started)).first() return latest_index_attempt last_success = find_latest_index_attempt( connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, only_include_success=True, db_session=db_session, ) cc_pair.last_successful_index_time = ( last_success.time_started if last_success else None ) db_session.commit() # ── Metrics query helpers ────────────────────────────────────────────── def 
get_connector_health_for_metrics( db_session: Session, ) -> list: # Returns list of Row tuples """Return connector health data for Prometheus metrics. Each row is (cc_pair_id, status, in_repeated_error_state, last_successful_index_time, name, source). """ return ( db_session.query( ConnectorCredentialPair.id, ConnectorCredentialPair.status, ConnectorCredentialPair.in_repeated_error_state, ConnectorCredentialPair.last_successful_index_time, ConnectorCredentialPair.name, Connector.source, ) .join( Connector, ConnectorCredentialPair.connector_id == Connector.id, ) .all() ) ================================================ FILE: backend/onyx/db/constants.py ================================================ SLACK_BOT_PERSONA_PREFIX = "__slack_bot_persona__" DEFAULT_PERSONA_SLACK_CHANNEL_NAME = "DEFAULT_SLACK_CHANNEL" CONNECTOR_VALIDATION_ERROR_MESSAGE_PREFIX = "ConnectorValidationError:" # Sentinel value to distinguish between "not provided" and "explicitly set to None" class UnsetType: def __repr__(self) -> str: return "" UNSET = UnsetType() ================================================ FILE: backend/onyx/db/credentials.py ================================================ from typing import Any from sqlalchemy import exists from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from sqlalchemy.sql.expression import and_ from sqlalchemy.sql.expression import or_ from onyx.auth.schemas import UserRole from onyx.configs.constants import DocumentSource from onyx.connectors.google_utils.shared_constants import ( DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY, ) from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Credential from onyx.db.models import Credential__UserGroup from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import User from onyx.db.models import User__UserGroup from 
onyx.server.documents.models import CredentialBase from onyx.utils.logger import setup_logger logger = setup_logger() # The credentials for these sources are not real so # permissions are not enforced for them CREDENTIAL_PERMISSIONS_TO_IGNORE = { DocumentSource.FILE, DocumentSource.WEB, DocumentSource.NOT_APPLICABLE, DocumentSource.GOOGLE_SITES, DocumentSource.WIKIPEDIA, DocumentSource.MEDIAWIKI, } PUBLIC_CREDENTIAL_ID = 0 def _add_user_filters( stmt: Select, user: User, get_editable: bool = True, ) -> Select: """Attaches filters to the statement to ensure that the user can only access the appropriate credentials""" if user.is_anonymous: raise ValueError("Anonymous users are not allowed to access credentials") if user.role == UserRole.ADMIN: # Admins can access all credentials that are public, owned by them, # not associated with any user, or from a source whose permissions are ignored return stmt.where( or_( Credential.user_id == user.id, Credential.user_id.is_(None), Credential.admin_public == True, # noqa: E712 Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE), ) ) if user.role == UserRole.BASIC: # Basic users can only access credentials that are owned by them return stmt.where(Credential.user_id == user.id) stmt = stmt.distinct() """ THIS PART IS FOR CURATORS AND GLOBAL CURATORS Here we select Credentials by relation: User -> User__UserGroup -> Credential__UserGroup -> Credential """ stmt = stmt.outerjoin(Credential__UserGroup).outerjoin( User__UserGroup, User__UserGroup.user_group_id == Credential__UserGroup.user_group_id, ) """ Filter Credentials by: - if the user is in the user_group that owns the Credential - if the user is a curator, they must also have a curator relationship to the user_group - if editing is being done, we also filter out Credentials that are owned by groups that the user isn't a curator for - if we are not editing, we show all Credentials in the groups the user is a curator for (as well as public Credentials) - if we are not editing, we return all Credentials directly
connected to the user """ where_clause = User__UserGroup.user_id == user.id if user.role == UserRole.CURATOR: where_clause &= User__UserGroup.is_curator == True # noqa: E712 if get_editable: user_groups = select(User__UserGroup.user_group_id).where( User__UserGroup.user_id == user.id ) if user.role == UserRole.CURATOR: user_groups = user_groups.where( User__UserGroup.is_curator == True # noqa: E712 ) where_clause &= ( ~exists() .where(Credential__UserGroup.credential_id == Credential.id) .where(~Credential__UserGroup.user_group_id.in_(user_groups)) .correlate(Credential) ) else: where_clause |= Credential.curator_public == True # noqa: E712 where_clause |= Credential.user_id == user.id # noqa: E712 where_clause |= Credential.source.in_(CREDENTIAL_PERMISSIONS_TO_IGNORE) return stmt.where(where_clause) def _relate_credential_to_user_groups__no_commit( db_session: Session, credential_id: int, user_group_ids: list[int], ) -> None: credential_user_groups = [] for group_id in user_group_ids: credential_user_groups.append( Credential__UserGroup( credential_id=credential_id, user_group_id=group_id, ) ) db_session.add_all(credential_user_groups) def fetch_credentials_for_user( db_session: Session, user: User, get_editable: bool = True, ) -> list[Credential]: stmt = select(Credential) stmt = _add_user_filters(stmt, user, get_editable=get_editable) results = db_session.scalars(stmt) return list(results.all()) def fetch_credential_by_id_for_user( credential_id: int, user: User, db_session: Session, get_editable: bool = True, ) -> Credential | None: stmt = select(Credential).distinct() stmt = stmt.where(Credential.id == credential_id) stmt = _add_user_filters( stmt=stmt, user=user, get_editable=get_editable, ) result = db_session.execute(stmt) credential = result.scalar_one_or_none() return credential def fetch_credential_by_id( credential_id: int, db_session: Session, ) -> Credential | None: stmt = select(Credential).distinct() stmt = stmt.where(Credential.id == credential_id) 
result = db_session.execute(stmt) credential = result.scalar_one_or_none() return credential def fetch_credentials_by_source_for_user( db_session: Session, user: User, document_source: DocumentSource | None = None, get_editable: bool = True, ) -> list[Credential]: base_query = select(Credential).where(Credential.source == document_source) base_query = _add_user_filters(base_query, user, get_editable=get_editable) credentials = db_session.execute(base_query).scalars().all() return list(credentials) def fetch_credentials_by_source( db_session: Session, document_source: DocumentSource | None = None, ) -> list[Credential]: base_query = select(Credential).where(Credential.source == document_source) credentials = db_session.execute(base_query).scalars().all() return list(credentials) def swap_credentials_connector( new_credential_id: int, connector_id: int, user: User, db_session: Session ) -> ConnectorCredentialPair: # Check if the user has permission to use the new credential new_credential = fetch_credential_by_id_for_user( new_credential_id, user, db_session ) if not new_credential: raise ValueError( f"No Credential found with id {new_credential_id} or user doesn't have permission to use it" ) # Existing pair existing_pair = db_session.execute( select(ConnectorCredentialPair).where( ConnectorCredentialPair.connector_id == connector_id ) ).scalar_one_or_none() if not existing_pair: raise ValueError( f"No ConnectorCredentialPair found for connector_id {connector_id}" ) # Check if the new credential is compatible with the connector if new_credential.source != existing_pair.connector.source: raise ValueError( f"New credential source {new_credential.source} does not match connector source {existing_pair.connector.source}" ) db_session.execute( update(DocumentByConnectorCredentialPair) .where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == existing_pair.credential_id, ) ) 
.values(credential_id=new_credential_id) ) # Update the existing pair with the new credential existing_pair.credential_id = new_credential_id existing_pair.credential = new_credential # Update ccpair status if it's in INVALID state if existing_pair.status == ConnectorCredentialPairStatus.INVALID: existing_pair.status = ConnectorCredentialPairStatus.ACTIVE # Commit the changes db_session.commit() # Refresh the object to ensure all relationships are up-to-date db_session.refresh(existing_pair) return existing_pair def create_credential( credential_data: CredentialBase, user: User, db_session: Session, ) -> Credential: credential = Credential( credential_json=credential_data.credential_json, user_id=user.id, admin_public=credential_data.admin_public, source=credential_data.source, name=credential_data.name, curator_public=credential_data.curator_public, ) db_session.add(credential) db_session.flush() # This ensures the credential gets an ID _relate_credential_to_user_groups__no_commit( db_session=db_session, credential_id=credential.id, user_group_ids=credential_data.groups, ) db_session.commit() # Expire to ensure credential_json is reloaded as SensitiveValue from DB db_session.expire(credential) return credential def _cleanup_credential__user_group_relationships__no_commit( db_session: Session, credential_id: int ) -> None: """NOTE: does not commit the transaction.""" db_session.query(Credential__UserGroup).filter( Credential__UserGroup.credential_id == credential_id ).delete(synchronize_session=False) def alter_credential( credential_id: int, name: str, credential_json: dict[str, Any], user: User, db_session: Session, ) -> Credential | None: # TODO: add user group relationship update credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: return None credential.name = name # Get existing credential_json and merge with new values existing_json = ( credential.credential_json.get_value(apply_mask=False) if 
credential.credential_json else {} ) credential.credential_json = { # type: ignore[assignment] **existing_json, **credential_json, } credential.user_id = user.id db_session.commit() # Expire to ensure credential_json is reloaded as SensitiveValue from DB db_session.expire(credential) return credential def update_credential( credential_id: int, credential_data: CredentialBase, user: User, db_session: Session, ) -> Credential | None: credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: return None credential.credential_json = credential_data.credential_json # type: ignore[assignment] credential.user_id = user.id if user is not None else None db_session.commit() # Expire to ensure credential_json is reloaded as SensitiveValue from DB db_session.expire(credential) return credential def update_credential_json( credential_id: int, credential_json: dict[str, Any], user: User, db_session: Session, ) -> Credential | None: credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: return None credential.credential_json = credential_json # type: ignore[assignment] db_session.commit() # Expire to ensure credential_json is reloaded as SensitiveValue from DB db_session.expire(credential) return credential def backend_update_credential_json( credential: Credential, credential_json: dict[str, Any], db_session: Session, ) -> None: """This should not be used in any flows involving the frontend or users""" credential.credential_json = credential_json # type: ignore[assignment] db_session.commit() def _delete_credential_internal( credential: Credential, credential_id: int, db_session: Session, force: bool = False, ) -> None: """Internal utility function to handle the actual deletion of a credential""" associated_connectors = ( db_session.query(ConnectorCredentialPair) .filter(ConnectorCredentialPair.credential_id == credential_id) .all() ) associated_doc_cc_pairs = ( 
db_session.query(DocumentByConnectorCredentialPair) .filter(DocumentByConnectorCredentialPair.credential_id == credential_id) .all() ) if associated_connectors or associated_doc_cc_pairs: if force: logger.warning( f"Force deleting credential {credential_id} and its associated records" ) # Delete DocumentByConnectorCredentialPair records first for doc_cc_pair in associated_doc_cc_pairs: db_session.delete(doc_cc_pair) # Then delete ConnectorCredentialPair records for cc_pair in associated_connectors: db_session.delete(cc_pair) # Flush (but do not commit) these deletions before deleting the credential db_session.flush() else: raise ValueError( f"Cannot delete credential as it is still associated with " f"{len(associated_connectors)} connector(s) and {len(associated_doc_cc_pairs)} document(s)." ) if force: logger.warning(f"Force deleting credential {credential_id}") else: logger.notice(f"Deleting credential {credential_id}") _cleanup_credential__user_group_relationships__no_commit(db_session, credential_id) db_session.delete(credential) db_session.commit() def delete_credential_for_user( credential_id: int, user: User, db_session: Session, force: bool = False, ) -> None: """Delete a credential that belongs to a specific user""" credential = fetch_credential_by_id_for_user(credential_id, user, db_session) if credential is None: raise ValueError( f"Credential by provided id {credential_id} does not exist or does not belong to user" ) _delete_credential_internal(credential, credential_id, db_session, force) def delete_credential( credential_id: int, db_session: Session, force: bool = False, ) -> None: """Delete a credential regardless of ownership (admin function)""" credential = fetch_credential_by_id(credential_id, db_session) if credential is None: raise ValueError(f"Credential by provided id {credential_id} does not exist") _delete_credential_internal(credential, credential_id, db_session, force) def create_initial_public_credential(db_session: Session) -> None: error_msg = ( "DB
is not in a valid initial state." " There must exist an empty public credential for data connectors that do not require additional authentication." ) first_credential = fetch_credential_by_id( credential_id=PUBLIC_CREDENTIAL_ID, db_session=db_session, ) if first_credential is not None: credential_json_value = ( first_credential.credential_json.get_value(apply_mask=False) if first_credential.credential_json else {} ) if credential_json_value != {} or first_credential.user is not None: raise ValueError(error_msg) return credential = Credential( id=PUBLIC_CREDENTIAL_ID, credential_json={}, user_id=None, ) db_session.add(credential) db_session.commit() def cleanup_gmail_credentials(db_session: Session) -> None: gmail_credentials = fetch_credentials_by_source( db_session=db_session, document_source=DocumentSource.GMAIL ) for credential in gmail_credentials: db_session.delete(credential) db_session.commit() def cleanup_google_drive_credentials(db_session: Session) -> None: google_drive_credentials = fetch_credentials_by_source( db_session=db_session, document_source=DocumentSource.GOOGLE_DRIVE ) for credential in google_drive_credentials: db_session.delete(credential) db_session.commit() def delete_service_account_credentials( user: User, db_session: Session, source: DocumentSource ) -> None: credentials = fetch_credentials_for_user(db_session=db_session, user=user) for credential in credentials: credential_json = ( credential.credential_json.get_value(apply_mask=False) if credential.credential_json else {} ) if ( credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY) and credential.source == source ): db_session.delete(credential) db_session.commit() ================================================ FILE: backend/onyx/db/dal.py ================================================ """Base Data Access Layer (DAL) for database operations. The DAL pattern groups related database operations into cohesive classes with explicit session management. It supports two usage modes: 1.
**External session** (FastAPI endpoints) — the caller provides a session whose lifecycle is managed by FastAPI's dependency injection. 2. **Self-managed session** (Celery tasks, scripts) — the DAL creates its own session via the tenant-aware session factory. Subclasses add domain-specific query methods while inheriting session management. See ``ee.onyx.db.scim.ScimDAL`` for a concrete example. Example (FastAPI):: def get_scim_dal(db_session: Session = Depends(get_session)) -> ScimDAL: return ScimDAL(db_session) @router.get("/users") def list_users(dal: ScimDAL = Depends(get_scim_dal)) -> ...: return dal.list_user_mappings(...) Example (Celery):: with ScimDAL.from_tenant("tenant_abc") as dal: dal.create_user_mapping(...) dal.commit() """ from __future__ import annotations from collections.abc import Generator from contextlib import contextmanager from sqlalchemy.orm import Session from onyx.db.engine.sql_engine import get_session_with_tenant class DAL: """Base Data Access Layer. Holds a SQLAlchemy session and provides transaction control helpers. Subclasses add domain-specific query methods. """ def __init__(self, db_session: Session) -> None: self._session = db_session @property def session(self) -> Session: """Direct access to the underlying session for advanced use cases.""" return self._session def commit(self) -> None: self._session.commit() def flush(self) -> None: self._session.flush() def rollback(self) -> None: self._session.rollback() @classmethod @contextmanager def from_tenant(cls, tenant_id: str) -> Generator["DAL", None, None]: """Create a DAL with a self-managed session for the given tenant. The session is automatically closed when the context manager exits. The caller must explicitly call ``commit()`` to persist changes. 
""" with get_session_with_tenant(tenant_id=tenant_id) as session: yield cls(session) ================================================ FILE: backend/onyx/db/deletion_attempt.py ================================================ from sqlalchemy.orm import Session from onyx.db.index_attempt import get_last_attempt from onyx.db.models import ConnectorCredentialPair from onyx.db.models import IndexingStatus from onyx.db.search_settings import get_current_search_settings def check_deletion_attempt_is_allowed( connector_credential_pair: ConnectorCredentialPair, db_session: Session, allow_scheduled: bool = False, ) -> str | None: """ To be deletable: (1) connector should be paused (2) there should be no in-progress/planned index attempts Returns an error message if the deletion attempt is not allowed, otherwise None. """ base_error_msg = ( f"Connector with ID '{connector_credential_pair.connector_id}' and credential ID " f"'{connector_credential_pair.credential_id}' is not deletable." ) if connector_credential_pair.status.is_active(): return base_error_msg + " Connector must be paused." connector_id = connector_credential_pair.connector_id credential_id = connector_credential_pair.credential_id search_settings = get_current_search_settings(db_session) last_indexing = get_last_attempt( connector_id=connector_id, credential_id=credential_id, search_settings_id=search_settings.id, db_session=db_session, ) if not last_indexing: return None if last_indexing.status == IndexingStatus.IN_PROGRESS or ( last_indexing.status == IndexingStatus.NOT_STARTED and not allow_scheduled ): return ( base_error_msg + " There is an ongoing / planned indexing attempt. " + "The indexing attempt must be completed or cancelled before deletion." 
) return None ================================================ FILE: backend/onyx/db/discord_bot.py ================================================ """CRUD operations for Discord bot models.""" from datetime import datetime from datetime import timezone from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from onyx.auth.api_key import build_displayable_api_key from onyx.auth.api_key import generate_api_key from onyx.auth.api_key import hash_api_key from onyx.auth.schemas import UserRole from onyx.configs.constants import DISCORD_SERVICE_API_KEY_NAME from onyx.db.api_key import insert_api_key from onyx.db.models import ApiKey from onyx.db.models import DiscordBotConfig from onyx.db.models import DiscordChannelConfig from onyx.db.models import DiscordGuildConfig from onyx.db.models import User from onyx.db.utils import DiscordChannelView from onyx.server.api_key.models import APIKeyArgs from onyx.utils.logger import setup_logger logger = setup_logger() # === DiscordBotConfig === def get_discord_bot_config(db_session: Session) -> DiscordBotConfig | None: """Get the Discord bot config for this tenant (at most one).""" return db_session.scalar(select(DiscordBotConfig).limit(1)) def create_discord_bot_config( db_session: Session, bot_token: str, ) -> DiscordBotConfig: """Create the Discord bot config. Raises ValueError if already exists. The check constraint on id='SINGLETON' ensures only one config per tenant. 
""" existing = get_discord_bot_config(db_session) if existing: raise ValueError("Discord bot config already exists") config = DiscordBotConfig(bot_token=bot_token) db_session.add(config) try: db_session.flush() except IntegrityError: # Race condition: another request created the config concurrently db_session.rollback() raise ValueError("Discord bot config already exists") return config def delete_discord_bot_config(db_session: Session) -> bool: """Delete the Discord bot config. Returns True if deleted.""" result = db_session.execute(delete(DiscordBotConfig)) db_session.flush() return result.rowcount > 0 # type: ignore[attr-defined] # === Discord Service API Key === def get_discord_service_api_key(db_session: Session) -> ApiKey | None: """Get the Discord service API key if it exists.""" return db_session.scalar( select(ApiKey).where(ApiKey.name == DISCORD_SERVICE_API_KEY_NAME) ) def get_or_create_discord_service_api_key( db_session: Session, tenant_id: str, ) -> str: """Get existing Discord service API key or create one. The API key is used by the Discord bot to authenticate with the Onyx API pods when sending chat requests. Args: db_session: Database session for the tenant. tenant_id: The tenant ID (used for logging/context). Returns: The raw API key string (not hashed). Raises: RuntimeError: If API key creation fails. """ # Check for existing key existing = get_discord_service_api_key(db_session) if existing: # Database only stores the hash, so we must regenerate to get the raw key. # This is safe since the Discord bot is the only consumer of this key. 
logger.debug( f"Found existing Discord service API key for tenant {tenant_id} that isn't in cache, regenerating to update cache" ) new_api_key = generate_api_key(tenant_id) existing.hashed_api_key = hash_api_key(new_api_key) existing.api_key_display = build_displayable_api_key(new_api_key) db_session.flush() return new_api_key # Create new API key logger.info(f"Creating Discord service API key for tenant {tenant_id}") api_key_args = APIKeyArgs( name=DISCORD_SERVICE_API_KEY_NAME, role=UserRole.LIMITED, # Limited role is sufficient for chat requests ) api_key_descriptor = insert_api_key( db_session=db_session, api_key_args=api_key_args, user_id=None, # Service account, no owner ) if not api_key_descriptor.api_key: raise RuntimeError( f"Failed to create Discord service API key for tenant {tenant_id}" ) return api_key_descriptor.api_key def delete_discord_service_api_key(db_session: Session) -> bool: """Delete the Discord service API key for a tenant. Called when: - Bot config is deleted (self-hosted) - All guild configs are deleted (Cloud) Args: db_session: Database session for the tenant. Returns: True if the key was deleted, False if it didn't exist. 
""" existing_key = get_discord_service_api_key(db_session) if not existing_key: return False # Also delete the associated user api_key_user = db_session.scalar( select(User).where(User.id == existing_key.user_id) # type: ignore[arg-type] ) db_session.delete(existing_key) if api_key_user: db_session.delete(api_key_user) db_session.flush() logger.info("Deleted Discord service API key") return True # === DiscordGuildConfig === def get_guild_configs( db_session: Session, include_channels: bool = False, ) -> list[DiscordGuildConfig]: """Get all guild configs for this tenant.""" stmt = select(DiscordGuildConfig) if include_channels: stmt = stmt.options(joinedload(DiscordGuildConfig.channels)) return list(db_session.scalars(stmt).unique().all()) def get_guild_config_by_internal_id( db_session: Session, internal_id: int, ) -> DiscordGuildConfig | None: """Get a specific guild config by its ID.""" return db_session.scalar( select(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id) ) def get_guild_config_by_discord_id( db_session: Session, guild_id: int, ) -> DiscordGuildConfig | None: """Get a guild config by Discord guild ID.""" return db_session.scalar( select(DiscordGuildConfig).where(DiscordGuildConfig.guild_id == guild_id) ) def get_guild_config_by_registration_key( db_session: Session, registration_key: str, ) -> DiscordGuildConfig | None: """Get a guild config by its registration key.""" return db_session.scalar( select(DiscordGuildConfig).where( DiscordGuildConfig.registration_key == registration_key ) ) def create_guild_config( db_session: Session, registration_key: str, ) -> DiscordGuildConfig: """Create a new guild config with a registration key (guild_id=NULL).""" config = DiscordGuildConfig(registration_key=registration_key) db_session.add(config) db_session.flush() return config def register_guild( db_session: Session, config: DiscordGuildConfig, guild_id: int, guild_name: str, ) -> DiscordGuildConfig: """Complete registration by setting guild_id 
and guild_name.""" config.guild_id = guild_id config.guild_name = guild_name config.registered_at = datetime.now(timezone.utc) db_session.flush() return config def update_guild_config( db_session: Session, config: DiscordGuildConfig, enabled: bool, default_persona_id: int | None = None, ) -> DiscordGuildConfig: """Update guild config fields.""" config.enabled = enabled config.default_persona_id = default_persona_id db_session.flush() return config def delete_guild_config( db_session: Session, internal_id: int, ) -> bool: """Delete guild config (cascades to channel configs). Returns True if deleted.""" result = db_session.execute( delete(DiscordGuildConfig).where(DiscordGuildConfig.id == internal_id) ) db_session.flush() return result.rowcount > 0 # type: ignore[attr-defined] # === DiscordChannelConfig === def get_channel_configs( db_session: Session, guild_config_id: int, ) -> list[DiscordChannelConfig]: """Get all channel configs for a guild.""" return list( db_session.scalars( select(DiscordChannelConfig).where( DiscordChannelConfig.guild_config_id == guild_config_id ) ).all() ) def get_channel_config_by_discord_ids( db_session: Session, guild_id: int, channel_id: int, ) -> DiscordChannelConfig | None: """Get a specific channel config by guild_id and channel_id.""" return db_session.scalar( select(DiscordChannelConfig) .join(DiscordGuildConfig) .where( DiscordGuildConfig.guild_id == guild_id, DiscordChannelConfig.channel_id == channel_id, ) ) def get_channel_config_by_internal_ids( db_session: Session, guild_config_id: int, channel_config_id: int, ) -> DiscordChannelConfig | None: """Get a specific channel config by guild_config_id and channel_config_id""" return db_session.scalar( select(DiscordChannelConfig).where( DiscordChannelConfig.guild_config_id == guild_config_id, DiscordChannelConfig.id == channel_config_id, ) ) def update_discord_channel_config( db_session: Session, config: DiscordChannelConfig, channel_name: str, thread_only_mode: bool, 
require_bot_invocation: bool, enabled: bool, persona_override_id: int | None = None, ) -> DiscordChannelConfig: """Update channel config fields.""" config.channel_name = channel_name config.require_bot_invocation = require_bot_invocation config.persona_override_id = persona_override_id config.enabled = enabled config.thread_only_mode = thread_only_mode db_session.flush() return config def delete_discord_channel_config( db_session: Session, guild_config_id: int, channel_config_id: int, ) -> bool: """Delete a channel config. Returns True if deleted.""" result = db_session.execute( delete(DiscordChannelConfig).where( DiscordChannelConfig.guild_config_id == guild_config_id, DiscordChannelConfig.id == channel_config_id, ) ) db_session.flush() return result.rowcount > 0 # type: ignore[attr-defined] def create_channel_config( db_session: Session, guild_config_id: int, channel_view: DiscordChannelView, ) -> DiscordChannelConfig: """Create a new channel config with default settings (disabled by default, admin enables via UI).""" config = DiscordChannelConfig( guild_config_id=guild_config_id, channel_id=channel_view.channel_id, channel_name=channel_view.channel_name, channel_type=channel_view.channel_type, is_private=channel_view.is_private, ) db_session.add(config) db_session.flush() return config def bulk_create_channel_configs( db_session: Session, guild_config_id: int, channels: list[DiscordChannelView], ) -> list[DiscordChannelConfig]: """Create multiple channel configs at once. 
Skips existing channels.""" # Get existing channel IDs for this guild existing_channel_ids = set( db_session.scalars( select(DiscordChannelConfig.channel_id).where( DiscordChannelConfig.guild_config_id == guild_config_id ) ).all() ) # Create configs for new channels only new_configs = [] for channel_view in channels: if channel_view.channel_id not in existing_channel_ids: config = DiscordChannelConfig( guild_config_id=guild_config_id, channel_id=channel_view.channel_id, channel_name=channel_view.channel_name, channel_type=channel_view.channel_type, is_private=channel_view.is_private, ) db_session.add(config) new_configs.append(config) db_session.flush() return new_configs def sync_channel_configs( db_session: Session, guild_config_id: int, current_channels: list[DiscordChannelView], ) -> tuple[int, int, int]: """Sync channel configs with current Discord channels. - Creates configs for new channels (disabled by default) - Removes configs for deleted channels - Updates names and types for existing channels if changed Returns: (added_count, removed_count, updated_count) """ current_channel_map = { channel_view.channel_id: channel_view for channel_view in current_channels } current_channel_ids = set(current_channel_map.keys()) # Get existing configs existing_configs = get_channel_configs(db_session, guild_config_id) existing_channel_ids = {c.channel_id for c in existing_configs} # Find channels to add, remove, and potentially update to_add = current_channel_ids - existing_channel_ids to_remove = existing_channel_ids - current_channel_ids # Add new channels added_count = 0 for channel_id in to_add: channel_view = current_channel_map[channel_id] create_channel_config(db_session, guild_config_id, channel_view) added_count += 1 # Remove deleted channels removed_count = 0 for config in existing_configs: if config.channel_id in to_remove: db_session.delete(config) removed_count += 1 # Update names, types, and privacy for existing channels if changed updated_count = 0 for 
config in existing_configs: if config.channel_id in current_channel_ids: channel_view = current_channel_map[config.channel_id] changed = False if config.channel_name != channel_view.channel_name: config.channel_name = channel_view.channel_name changed = True if config.channel_type != channel_view.channel_type: config.channel_type = channel_view.channel_type changed = True if config.is_private != channel_view.is_private: config.is_private = channel_view.is_private changed = True if changed: updated_count += 1 db_session.flush() return added_count, removed_count, updated_count ================================================ FILE: backend/onyx/db/document.py ================================================ import contextlib import time from collections.abc import Generator from collections.abc import Iterable from collections.abc import Sequence from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import tuple_ from sqlalchemy import update from sqlalchemy.dialects.postgresql import insert from sqlalchemy.engine.util import TransactionalContext from sqlalchemy.exc import OperationalError from sqlalchemy.orm import Session from sqlalchemy.sql.expression import null from onyx.configs.constants import DEFAULT_BOOST from onyx.configs.constants import DocumentSource from onyx.configs.kg_configs import KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.document_access import apply_document_access_filter from onyx.db.entities import delete_from_kg_entities__no_commit from onyx.db.entities import 
delete_from_kg_entities_extraction_staging__no_commit from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.feedback import delete_document_feedback_for_documents__no_commit from onyx.db.models import Connector from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Credential from onyx.db.models import Document as DbDocument from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import KGEntity from onyx.db.models import KGRelationship from onyx.db.models import User from onyx.db.relationships import delete_from_kg_relationships__no_commit from onyx.db.relationships import ( delete_from_kg_relationships_extraction_staging__no_commit, ) from onyx.db.tag import delete_document_tags_for_documents__no_commit from onyx.db.utils import DocumentRow from onyx.db.utils import model_to_dict from onyx.db.utils import SortOrder from onyx.document_index.interfaces import DocumentMetadata from onyx.kg.models import KGStage from onyx.server.documents.models import ConnectorCredentialPairIdentifier from onyx.utils.logger import setup_logger logger = setup_logger() ONE_HOUR_IN_SECONDS = 60 * 60 def check_docs_exist(db_session: Session) -> bool: stmt = select(exists(DbDocument)) result = db_session.execute(stmt) return result.scalar() or False def count_documents_by_needs_sync(session: Session) -> int: """Get the count of all documents where: 1. last_modified is newer than last_synced 2. last_synced is null (meaning we've never synced) AND the document has a relationship with a connector/credential pair TODO: The documents without a relationship with a connector/credential pair should be cleaned up somehow eventually. 
This function executes the query and returns the count of documents matching the criteria.""" return ( session.query(DbDocument.id) .filter( or_( DbDocument.last_modified > DbDocument.last_synced, DbDocument.last_synced.is_(None), ) ) .count() ) def construct_document_id_select_by_needs_sync() -> Select: """Get all document IDs that need syncing across all connector credential pairs. Returns a Select statement for documents where: 1. last_modified is newer than last_synced 2. last_synced is null (meaning we've never synced) AND the document has a relationship with a connector/credential pair """ return select(DbDocument.id).where( or_( DbDocument.last_modified > DbDocument.last_synced, DbDocument.last_synced.is_(None), ) ) def construct_document_id_select_for_connector_credential_pair( connector_id: int, credential_id: int | None = None ) -> Select: initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) ) stmt = ( select(DbDocument.id).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct() ) return stmt def construct_document_select_for_connector_credential_pair( connector_id: int, credential_id: int | None = None ) -> Select: initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) ) stmt = select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct() return stmt def get_documents_for_cc_pair( db_session: Session, cc_pair_id: int, ) -> list[DbDocument]: cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, ) if not cc_pair: raise ValueError(f"No CC pair found with ID: {cc_pair_id}") stmt = construct_document_select_for_connector_credential_pair( connector_id=cc_pair.connector_id, 
credential_id=cc_pair.credential_id ) return list(db_session.scalars(stmt).all()) def get_document_ids_for_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int ) -> list[str]: doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) ) return list(db_session.execute(doc_ids_stmt).scalars().all()) def get_documents_for_connector_credential_pair_limited_columns( db_session: Session, connector_id: int, credential_id: int, sort_order: SortOrder | None = None, ) -> Sequence[DocumentRow]: doc_ids_subquery = select(DocumentByConnectorCredentialPair.id).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) ) doc_ids_subquery = doc_ids_subquery.join( DbDocument, DocumentByConnectorCredentialPair.id == DbDocument.id ) stmt = select( DbDocument.id, DbDocument.doc_metadata, DbDocument.external_user_group_ids ) stmt = stmt.where(DbDocument.id.in_(doc_ids_subquery)) if sort_order == SortOrder.ASC: stmt = stmt.order_by(DbDocument.last_modified.asc()) elif sort_order == SortOrder.DESC: stmt = stmt.order_by(DbDocument.last_modified.desc()) rows = db_session.execute(stmt).mappings().all() doc_rows: list[DocumentRow] = [] for row in rows: doc_row = DocumentRow( id=row.id, doc_metadata=row.doc_metadata, external_user_group_ids=row.external_user_group_ids or [], ) doc_rows.append(doc_row) return doc_rows def get_documents_for_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, limit: int | None = None ) -> Sequence[DbDocument]: initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) ) stmt = 
select(DbDocument).where(DbDocument.id.in_(initial_doc_ids_stmt)).distinct() if limit: stmt = stmt.limit(limit) return db_session.scalars(stmt).all() def get_documents_by_ids( db_session: Session, document_ids: list[str], ) -> list[DbDocument]: stmt = select(DbDocument).where(DbDocument.id.in_(document_ids)) documents = db_session.execute(stmt).scalars().all() return list(documents) def get_documents_by_source( db_session: Session, source: DocumentSource, creator_id: UUID | None = None, ) -> list[DbDocument]: """Get all documents associated with a specific source type. This queries through the connector relationship to find all documents that were indexed by connectors of the given source type. Args: db_session: Database session source: The document source type to filter by creator_id: If provided, only return documents from connectors created by this user. Filters via ConnectorCredentialPair. """ stmt = ( select(DbDocument) .join( DocumentByConnectorCredentialPair, DbDocument.id == DocumentByConnectorCredentialPair.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join( Connector, ConnectorCredentialPair.connector_id == Connector.id, ) .where(Connector.source == source) ) if creator_id is not None: stmt = stmt.where(ConnectorCredentialPair.creator_id == creator_id) stmt = stmt.distinct() documents = db_session.execute(stmt).scalars().all() return list(documents) def _apply_last_updated_cursor_filter( stmt: Select, cursor_last_modified: datetime | None, cursor_last_synced: datetime | None, cursor_document_id: str | None, is_ascending: bool, ) -> Select: """Apply cursor filter for last_updated sorting. ASC uses nulls_first (NULLs at start), DESC uses nulls_last (NULLs at end). This affects which extra clauses are needed when the cursor has NULL last_synced vs non-NULL last_synced. 
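A minimal pure-Python sketch of the keyset predicate this function builds, assuming rows are plain (last_modified, last_synced, id) tuples with None standing in for SQL NULL. The helpers `after_cursor` and `asc_key` are hypothetical illustrations, not part of Onyx. Any comparison involving None is treated as not true, mirroring SQL three-valued logic:

```python
import operator
from datetime import datetime, timedelta
from typing import Optional, Tuple

# Hypothetical row shape: (last_modified, last_synced, doc_id); None plays SQL NULL.
Row = Tuple[datetime, Optional[datetime], str]


def _cmp(a, b, op) -> bool:
    # Like SQL, a comparison involving NULL (None) is never true.
    return a is not None and b is not None and op(a, b)


def after_cursor(row: Row, cursor: Row, ascending: bool) -> bool:
    # Mirrors the WHERE clause built by _apply_last_updated_cursor_filter.
    lm, ls, doc_id = row
    c_lm, c_ls, c_id = cursor
    past = operator.gt if ascending else operator.lt
    if _cmp(lm, c_lm, past):
        return True
    if lm != c_lm:
        return False
    if c_ls is None:
        if ls is None:
            return past(doc_id, c_id)
        # nulls_first (ASC): every non-NULL last_synced sorts after a NULL cursor.
        return ascending
    if _cmp(ls, c_ls, past) or (ls == c_ls and past(doc_id, c_id)):
        return True
    # nulls_last (DESC): NULL last_synced rows sort after every non-NULL value.
    return ls is None and not ascending


def asc_key(row: Row):
    # ORDER BY last_modified ASC, last_synced ASC NULLS FIRST, id ASC
    lm, ls, doc_id = row
    return (lm, ls is not None, ls if ls is not None else datetime.min, doc_id)


t0 = datetime(2024, 1, 1)
rows = [
    (t0, None, 'a'),
    (t0, t0 + timedelta(hours=1), 'b'),
    (t0 + timedelta(days=1), None, 'c'),
]
ordered = sorted(rows, key=asc_key)
next_page = [r for r in ordered if after_cursor(r, ordered[0], ascending=True)]
```

Because DESC NULLS LAST is exactly the reverse of ASC NULLS FIRST when ids are unique, one sort key covers both directions; for any cursor, the predicate should select precisely the rows that follow it in that order.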
""" if not cursor_last_modified or not cursor_document_id: return stmt # Pick comparison operators based on sort direction if is_ascending: modified_cmp = DbDocument.last_modified > cursor_last_modified synced_cmp = DbDocument.last_synced > cursor_last_synced id_cmp = DbDocument.id > cursor_document_id else: modified_cmp = DbDocument.last_modified < cursor_last_modified synced_cmp = DbDocument.last_synced < cursor_last_synced id_cmp = DbDocument.id < cursor_document_id if cursor_last_synced is None: # Cursor has NULL last_synced # ASC (nulls_first): NULL is at start, so non-NULL values come after # DESC (nulls_last): NULL is at end, so nothing with non-NULL comes after base_clauses = [ modified_cmp, and_( DbDocument.last_modified == cursor_last_modified, DbDocument.last_synced.is_(None), id_cmp, ), ] if is_ascending: # Any non-NULL last_synced comes after NULL when nulls_first base_clauses.append( and_( DbDocument.last_modified == cursor_last_modified, DbDocument.last_synced.is_not(None), ) ) return stmt.where(or_(*base_clauses)) # Cursor has non-NULL last_synced # ASC (nulls_first): NULLs came before, so no NULL clause needed # DESC (nulls_last): NULLs come after non-NULL values synced_clauses = [ synced_cmp, and_(DbDocument.last_synced == cursor_last_synced, id_cmp), ] if not is_ascending: # NULLs come after all non-NULL values when nulls_last synced_clauses.append(DbDocument.last_synced.is_(None)) return stmt.where( or_( modified_cmp, and_( DbDocument.last_modified == cursor_last_modified, or_(*synced_clauses), ), ) ) def _apply_name_cursor_filter_asc( stmt: Select, cursor_name: str | None, cursor_document_id: str | None, ) -> Select: """Apply cursor filter for name ASC sorting.""" if not cursor_name or not cursor_document_id: return stmt return stmt.where( or_( DbDocument.semantic_id > cursor_name, and_( DbDocument.semantic_id == cursor_name, DbDocument.id > cursor_document_id, ), ) ) def _apply_name_cursor_filter_desc( stmt: Select, cursor_name: str | None, 
cursor_document_id: str | None, ) -> Select: """Apply cursor filter for name DESC sorting.""" if not cursor_name or not cursor_document_id: return stmt return stmt.where( or_( DbDocument.semantic_id < cursor_name, and_( DbDocument.semantic_id == cursor_name, DbDocument.id < cursor_document_id, ), ) ) def get_accessible_documents_for_hierarchy_node_paginated( db_session: Session, parent_hierarchy_node_id: int, user_email: str | None, external_group_ids: list[str], limit: int, # Sort options sort_by_name: bool = False, sort_ascending: bool = False, # Cursor fields for last_updated sorting cursor_last_modified: datetime | None = None, cursor_last_synced: datetime | None = None, # Cursor field for name sorting cursor_name: str | None = None, # Document ID for tie-breaking (used by both sort types) cursor_document_id: str | None = None, ) -> list[DbDocument]: stmt = select(DbDocument).where( DbDocument.parent_hierarchy_node_id == parent_hierarchy_node_id ) stmt = apply_document_access_filter(stmt, user_email, external_group_ids) # Apply cursor filter based on sort type and direction if sort_by_name: if sort_ascending: stmt = _apply_name_cursor_filter_asc(stmt, cursor_name, cursor_document_id) stmt = stmt.order_by(DbDocument.semantic_id.asc(), DbDocument.id.asc()) else: stmt = _apply_name_cursor_filter_desc(stmt, cursor_name, cursor_document_id) stmt = stmt.order_by(DbDocument.semantic_id.desc(), DbDocument.id.desc()) else: # Sort by last_updated if sort_ascending: stmt = _apply_last_updated_cursor_filter( stmt, cursor_last_modified, cursor_last_synced, cursor_document_id, is_ascending=True, ) stmt = stmt.order_by( DbDocument.last_modified.asc(), DbDocument.last_synced.asc().nulls_first(), DbDocument.id.asc(), ) else: stmt = _apply_last_updated_cursor_filter( stmt, cursor_last_modified, cursor_last_synced, cursor_document_id, is_ascending=False, ) stmt = stmt.order_by( DbDocument.last_modified.desc(), DbDocument.last_synced.desc().nulls_last(), DbDocument.id.desc(), ) # 
Use distinct to avoid duplicates when a document belongs to multiple cc_pairs stmt = stmt.distinct() stmt = stmt.limit(limit) return list(db_session.execute(stmt).scalars().all()) def filter_existing_document_ids( db_session: Session, document_ids: list[str], ) -> set[str]: """Filter a list of document IDs to only those that exist in the database. Args: db_session: Database session document_ids: List of document IDs to check for existence Returns: Set of document IDs from the input list that exist in the database """ if not document_ids: return set() stmt = select(DbDocument.id).where(DbDocument.id.in_(document_ids)) return set(db_session.execute(stmt).scalars().all()) def fetch_document_ids_by_links( db_session: Session, links: list[str], ) -> dict[str, str]: """Fetch document IDs for documents whose link matches any of the provided values.""" if not links: return {} stmt = select(DbDocument.link, DbDocument.id).where(DbDocument.link.in_(links)) rows = db_session.execute(stmt).all() return {link: doc_id for link, doc_id in rows if link} def get_document_connector_count( db_session: Session, document_id: str, ) -> int: results = get_document_connector_counts(db_session, [document_id]) if not results or len(results) == 0: return 0 return results[0][1] def get_document_connector_counts( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, int]]: stmt = ( select( DocumentByConnectorCredentialPair.id, func.count(), ) .where(DocumentByConnectorCredentialPair.id.in_(document_ids)) .group_by(DocumentByConnectorCredentialPair.id) ) return db_session.execute(stmt).all() # type: ignore def get_document_counts_for_cc_pairs( db_session: Session, cc_pairs: list[ConnectorCredentialPairIdentifier] ) -> Sequence[tuple[int, int, int]]: """Returns a sequence of tuples of (connector_id, credential_id, document count)""" if not cc_pairs: return [] # Prepare a list of (connector_id, credential_id) tuples cc_ids = [(x.connector_id, x.credential_id) for x in cc_pairs] # 
Batch to avoid generating extremely large IN clauses that can blow Postgres stack depth batch_size = 1000 aggregated_counts: dict[tuple[int, int], int] = {} for start_idx in range(0, len(cc_ids), batch_size): batch = cc_ids[start_idx : start_idx + batch_size] stmt = ( select( DocumentByConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id, func.count(), ) .where( and_( tuple_( DocumentByConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id, ).in_(batch), DocumentByConnectorCredentialPair.has_been_indexed.is_(True), ) ) .group_by( DocumentByConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id, ) ) for connector_id, credential_id, cnt in db_session.execute(stmt).all(): aggregated_counts[(connector_id, credential_id)] = cnt # Convert aggregated results back to the expected sequence of tuples return [ (connector_id, credential_id, cnt) for (connector_id, credential_id), cnt in aggregated_counts.items() ] def get_document_counts_for_all_cc_pairs( db_session: Session, ) -> Sequence[tuple[int, int, int]]: """Return (connector_id, credential_id, count) for ALL CC pairs with indexed docs. Executes a single grouped query so Postgres can fully leverage indexes, avoiding large batched IN-lists. """ stmt = ( select( DocumentByConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id, func.count(), ) .where(DocumentByConnectorCredentialPair.has_been_indexed.is_(True)) .group_by( DocumentByConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id, ) ) return db_session.execute(stmt).all() # type: ignore def get_access_info_for_document( db_session: Session, document_id: str, ) -> tuple[str, list[str | None], bool] | None: """Gets access info for a single document by calling the get_access_info_for_documents function and passing a list with a single document ID. Args: db_session (Session): The database session to use. 
document_id (str): The document ID to fetch access info for. Returns: Optional[Tuple[str, List[str | None], bool]]: A tuple containing the document ID, a list of user emails, and a boolean indicating if the document is globally public, or None if no results are found. """ results = get_access_info_for_documents(db_session, [document_id]) if not results: return None return results[0] def get_access_info_for_documents( db_session: Session, document_ids: list[str], ) -> Sequence[tuple[str, list[str | None], bool]]: """Gets back all relevant access info for the given documents. This includes the user emails for the cc pairs that the document is associated with, plus whether any of the associated cc pairs intends to make the document globally public. Returns the list where each element contains: - Document ID (which is also the ID of the DocumentByConnectorCredentialPair) - List of emails of Onyx users with direct access to the doc (includes a "None" element if the connector was set up by an admin when auth was off) - bool for whether the document is public (the document can later also be marked public by the automatic permission sync step) """ stmt = select( DocumentByConnectorCredentialPair.id, func.array_agg(func.coalesce(User.email, null())).label("user_emails"), func.bool_or(ConnectorCredentialPair.access_type == AccessType.PUBLIC).label( "public_doc" ), ).where(DocumentByConnectorCredentialPair.id.in_(document_ids)) stmt = ( stmt.join( Credential, DocumentByConnectorCredentialPair.credential_id == Credential.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .outerjoin( User, and_( Credential.user_id == User.id, ConnectorCredentialPair.access_type != AccessType.SYNC, ), ) # don't include CC pairs that are being deleted # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to
ignore them .where(ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING) .group_by(DocumentByConnectorCredentialPair.id) ) return db_session.execute(stmt).all() # type: ignore def upsert_documents( db_session: Session, document_metadata_batch: list[DocumentMetadata], initial_boost: int = DEFAULT_BOOST, ) -> None: """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause. Also note, this function should not be used for updating documents, only creating and ensuring that it exists. It IGNORES the doc_updated_at field""" seen_documents: dict[str, DocumentMetadata] = {} for document_metadata in document_metadata_batch: doc_id = document_metadata.document_id if doc_id not in seen_documents: seen_documents[doc_id] = document_metadata if not seen_documents: logger.info("No documents to upsert. Skipping.") return includes_permissions = any(doc.external_access for doc in seen_documents.values()) insert_stmt = insert(DbDocument).values( [ model_to_dict( DbDocument( id=doc.document_id, from_ingestion_api=doc.from_ingestion_api, boost=initial_boost, hidden=False, semantic_id=doc.semantic_identifier, link=doc.first_link, doc_updated_at=None, # this is intentional last_modified=datetime.now(timezone.utc), primary_owners=doc.primary_owners, secondary_owners=doc.secondary_owners, kg_stage=KGStage.NOT_STARTED, parent_hierarchy_node_id=doc.parent_hierarchy_node_id, **( { "external_user_emails": list( doc.external_access.external_user_emails ), "external_user_group_ids": list( doc.external_access.external_user_group_ids ), "is_public": doc.external_access.is_public, } if doc.external_access else {} ), doc_metadata=doc.doc_metadata, ) ) for doc in seen_documents.values() ] ) update_set = { "from_ingestion_api": insert_stmt.excluded.from_ingestion_api, "boost": insert_stmt.excluded.boost, "hidden": insert_stmt.excluded.hidden, "semantic_id": insert_stmt.excluded.semantic_id, "link": insert_stmt.excluded.link, "primary_owners": 
insert_stmt.excluded.primary_owners, "secondary_owners": insert_stmt.excluded.secondary_owners, "doc_metadata": insert_stmt.excluded.doc_metadata, "parent_hierarchy_node_id": insert_stmt.excluded.parent_hierarchy_node_id, } if includes_permissions: # Use COALESCE to preserve existing permissions when new values are NULL. # This prevents subsequent indexing runs (which don't fetch permissions) # from overwriting permissions set by permission sync jobs. update_set.update( { "external_user_emails": func.coalesce( insert_stmt.excluded.external_user_emails, DbDocument.external_user_emails, ), "external_user_group_ids": func.coalesce( insert_stmt.excluded.external_user_group_ids, DbDocument.external_user_group_ids, ), "is_public": func.coalesce( insert_stmt.excluded.is_public, DbDocument.is_public, ), } ) on_conflict_stmt = insert_stmt.on_conflict_do_update( index_elements=["id"], set_=update_set, # Conflict target ) db_session.execute(on_conflict_stmt) db_session.commit() def upsert_document_by_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, document_ids: list[str] ) -> None: """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause.""" if not document_ids: logger.info("`document_ids` is empty. 
Skipping.") return insert_stmt = insert(DocumentByConnectorCredentialPair).values( [ model_to_dict( DocumentByConnectorCredentialPair( id=doc_id, connector_id=connector_id, credential_id=credential_id, has_been_indexed=False, ) ) for doc_id in document_ids ] ) # this must be `on_conflict_do_nothing` rather than `on_conflict_do_update` # since we don't want to update the `has_been_indexed` field for documents # that already exist on_conflict_stmt = insert_stmt.on_conflict_do_nothing() db_session.execute(on_conflict_stmt) db_session.commit() def mark_document_as_indexed_for_cc_pair__no_commit( db_session: Session, connector_id: int, credential_id: int, document_ids: Iterable[str], ) -> None: """Should be called only after a successful index operation for a batch.""" db_session.execute( update(DocumentByConnectorCredentialPair) .where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, DocumentByConnectorCredentialPair.id.in_(document_ids), ) ) .values(has_been_indexed=True) ) def update_docs_updated_at__no_commit( ids_to_new_updated_at: dict[str, datetime], db_session: Session, ) -> None: doc_ids = list(ids_to_new_updated_at.keys()) documents_to_update = ( db_session.query(DbDocument).filter(DbDocument.id.in_(doc_ids)).all() ) for document in documents_to_update: document.doc_updated_at = ids_to_new_updated_at[document.id] def update_docs_last_modified__no_commit( document_ids: list[str], db_session: Session, ) -> None: documents_to_update = ( db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all() ) now = datetime.now(timezone.utc) for doc in documents_to_update: doc.last_modified = now def update_docs_chunk_count__no_commit( document_ids: list[str], doc_id_to_chunk_count: dict[str, int], db_session: Session, ) -> None: documents_to_update = ( db_session.query(DbDocument).filter(DbDocument.id.in_(document_ids)).all() ) for doc in documents_to_update: 
doc.chunk_count = doc_id_to_chunk_count[doc.id] def mark_document_as_modified( document_id: str, db_session: Session, ) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) doc = db_session.scalar(stmt) if doc is None: raise ValueError(f"No document with ID: {document_id}") # update last_modified doc.last_modified = datetime.now(timezone.utc) db_session.commit() def mark_document_as_synced(document_id: str, db_session: Session) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) doc = db_session.scalar(stmt) if doc is None: raise ValueError(f"No document with ID: {document_id}") # update last_synced doc.last_synced = datetime.now(timezone.utc) db_session.commit() def delete_document_by_connector_credential_pair__no_commit( db_session: Session, document_id: str, connector_credential_pair_identifier: ( ConnectorCredentialPairIdentifier | None ) = None, ) -> None: """Deletes the document-by-cc-pair relationship entries for a single document. Foreign key rows are left in place. The implicit assumption is that the document itself still has other cc_pair references and needs to continue existing. """ delete_documents_by_connector_credential_pair__no_commit( db_session=db_session, document_ids=[document_id], connector_credential_pair_identifier=connector_credential_pair_identifier, ) def delete_documents_by_connector_credential_pair__no_commit( db_session: Session, document_ids: list[str], connector_credential_pair_identifier: ( ConnectorCredentialPairIdentifier | None ) = None, ) -> None: """This deletes only the document-by-cc-pair entries for the given documents; if `connector_credential_pair_identifier` is provided, only the entries for that cc pair are removed. Foreign key rows are left in place. The implicit assumption is that the document itself still has other cc_pair references and needs to continue existing.
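A minimal sketch of the same delete using the standard-library sqlite3 module instead of SQLAlchemy, assuming a hypothetical `document_by_cc_pair` table with a composite primary key. It illustrates that omitting the cc pair removes the association rows for every cc pair, while passing one scopes the delete to that pair only:

```python
import sqlite3
from typing import Optional, Sequence, Tuple


def delete_doc_cc_pair_rows(
    conn: sqlite3.Connection,
    document_ids: Sequence[str],
    cc_pair: Optional[Tuple[int, int]] = None,
) -> None:
    # Hypothetical sqlite3 stand-in for the SQLAlchemy delete() in this function.
    if not document_ids:
        return
    placeholders = ', '.join('?' for _ in document_ids)
    sql = f'DELETE FROM document_by_cc_pair WHERE id IN ({placeholders})'
    params: list = list(document_ids)
    if cc_pair is not None:
        # Scope the delete to one (connector_id, credential_id) pair.
        sql += ' AND connector_id = ? AND credential_id = ?'
        params.extend(cc_pair)
    conn.execute(sql, params)


conn = sqlite3.connect(':memory:')
conn.execute(
    'CREATE TABLE document_by_cc_pair ('
    'id TEXT, connector_id INTEGER, credential_id INTEGER, '
    'PRIMARY KEY (id, connector_id, credential_id))'
)
conn.executemany(
    'INSERT INTO document_by_cc_pair VALUES (?, ?, ?)',
    [('doc1', 1, 1), ('doc1', 2, 2), ('doc2', 1, 1)],
)
delete_doc_cc_pair_rows(conn, ['doc1'], cc_pair=(1, 1))
remaining = conn.execute(
    'SELECT id, connector_id, credential_id FROM document_by_cc_pair '
    'ORDER BY id, connector_id'
).fetchall()
```

Like the function it illustrates, the sketch does not commit; that is left to the caller.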
""" stmt = delete(DocumentByConnectorCredentialPair).where( DocumentByConnectorCredentialPair.id.in_(document_ids) ) if connector_credential_pair_identifier: stmt = stmt.where( and_( DocumentByConnectorCredentialPair.connector_id == connector_credential_pair_identifier.connector_id, DocumentByConnectorCredentialPair.credential_id == connector_credential_pair_identifier.credential_id, ) ) db_session.execute(stmt) def delete_all_documents_by_connector_credential_pair__no_commit( db_session: Session, connector_id: int, credential_id: int, ) -> None: """Deletes all document by connector credential pair entries for a specific connector and credential. This is primarily used during connector deletion to ensure all references are removed before deleting the connector itself. This is crucial because connector_id is part of the primary key in DocumentByConnectorCredentialPair, and attempting to delete the Connector would otherwise try to set the foreign key to NULL, which fails for primary keys. NOTE: Does not commit the transaction, this must be done by the caller. 
""" stmt = delete(DocumentByConnectorCredentialPair).where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) ) db_session.execute(stmt) def delete_documents__no_commit(db_session: Session, document_ids: list[str]) -> None: db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids))) def delete_documents_complete__no_commit( db_session: Session, document_ids: list[str] ) -> None: """This completely deletes the documents from the db, including all foreign key relationships""" # Start with the kg references delete_from_kg_relationships__no_commit( db_session=db_session, document_ids=document_ids, ) delete_from_kg_entities__no_commit( db_session=db_session, document_ids=document_ids, ) delete_from_kg_relationships_extraction_staging__no_commit( db_session=db_session, document_ids=document_ids, ) delete_from_kg_entities_extraction_staging__no_commit( db_session=db_session, document_ids=document_ids, ) # Continue with deleting the chunk stats for the documents delete_chunk_stats_by_connector_credential_pair__no_commit( db_session=db_session, document_ids=document_ids, ) delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids) delete_document_feedback_for_documents__no_commit( document_ids=document_ids, db_session=db_session ) delete_document_tags_for_documents__no_commit( document_ids=document_ids, db_session=db_session ) delete_documents__no_commit(db_session, document_ids) def delete_all_documents_for_connector_credential_pair( db_session: Session, connector_id: int, credential_id: int, timeout: int = ONE_HOUR_IN_SECONDS, ) -> None: """Delete all documents for a given connector credential pair. This will delete all documents and their associated data (chunks, feedback, tags, etc.) NOTE: a bit inefficient, but it's not a big deal since this is done rarely - only during an index swap. 
If we wanted to make this more efficient, we could use a single delete statement + cascade. """ batch_size = 1000 start_time = time.monotonic() while True: # Get document IDs in batches stmt = ( select(DocumentByConnectorCredentialPair.id) .where( DocumentByConnectorCredentialPair.connector_id == connector_id, DocumentByConnectorCredentialPair.credential_id == credential_id, ) .limit(batch_size) ) document_ids = db_session.scalars(stmt).all() if not document_ids: break delete_documents_complete__no_commit( db_session=db_session, document_ids=list(document_ids) ) db_session.commit() if time.monotonic() - start_time > timeout: raise RuntimeError("Timeout reached while deleting documents") def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool: """Acquire locks for the specified documents. Ideally this shouldn't be called with a large list of document_ids (an exception can be made if the lock is held only very briefly). Raises an exception immediately if any of the documents are already locked. This prevents deadlocks (assuming that the caller passes in all required document IDs in a single call). """ stmt = ( select(DbDocument.id) .where(DbDocument.id.in_(document_ids)) .with_for_update(nowait=True) ) # will raise exception if any of the documents are already locked documents = db_session.scalars(stmt).all() # make sure we found every document if len(documents) != len(set(document_ids)): logger.warning("Didn't find row for all specified document IDs. Aborting.") return False return True _NUM_LOCK_ATTEMPTS = 10 _LOCK_RETRY_DELAY = 10 @contextlib.contextmanager def prepare_to_modify_documents( db_session: Session, document_ids: list[str], retry_delay: int = _LOCK_RETRY_DELAY ) -> Generator[TransactionalContext, None, None]: """Try to acquire locks for the documents to prevent other jobs from modifying them at the same time (i.e. to avoid race conditions). This should be called ahead of any modification to Vespa.
Locks should be released by the caller as soon as updates are complete by finishing the transaction. NOTE: only one commit is allowed within the context manager returned by this function. Multiple commits will result in a sqlalchemy.exc.InvalidRequestError. NOTE: this function will commit any existing transaction. """ db_session.commit() # ensure that we're not in a transaction lock_acquired = False for i in range(_NUM_LOCK_ATTEMPTS): try: with db_session.begin() as transaction: lock_acquired = acquire_document_locks( db_session=db_session, document_ids=document_ids ) if lock_acquired: yield transaction break except OperationalError as e: logger.warning( f"Failed to acquire locks for documents on attempt {i}, retrying. Error: {e}" ) time.sleep(retry_delay) if not lock_acquired: raise RuntimeError( f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for documents: {document_ids}" ) def get_ingestion_documents( db_session: Session, ) -> list[DbDocument]: # TODO add the option to filter by DocumentSource stmt = select(DbDocument).where(DbDocument.from_ingestion_api.is_(True)) documents = db_session.execute(stmt).scalars().all() return list(documents) def get_documents_by_cc_pair( cc_pair_id: int, db_session: Session, ) -> list[DbDocument]: return ( db_session.query(DbDocument) .join( DocumentByConnectorCredentialPair, DbDocument.id == DocumentByConnectorCredentialPair.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .filter(ConnectorCredentialPair.id == cc_pair_id) .all() ) def get_document( document_id: str, db_session: Session, ) -> DbDocument | None: stmt = select(DbDocument).where(DbDocument.id == document_id) doc: DbDocument | None = db_session.execute(stmt).scalar_one_or_none() return doc def get_cc_pairs_for_document( db_session: Session, document_id: str, ) -> 
list[ConnectorCredentialPair]: stmt = ( select(ConnectorCredentialPair) .join( DocumentByConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .where(DocumentByConnectorCredentialPair.id == document_id) ) return list(db_session.execute(stmt).scalars().all()) def get_document_sources( db_session: Session, document_ids: list[str], ) -> dict[str, DocumentSource]: """Gets the sources for a list of document IDs. Returns a dictionary mapping document ID to its source. If a document belongs to multiple CC pairs with different sources, only one of them is returned; which one is not guaranteed, since the query is unordered and later rows overwrite earlier ones in the dict. """ stmt = ( select( DocumentByConnectorCredentialPair.id, Connector.source, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join( Connector, ConnectorCredentialPair.connector_id == Connector.id, ) .where(DocumentByConnectorCredentialPair.id.in_(document_ids)) .distinct() ) results = db_session.execute(stmt).all() return {doc_id: source for doc_id, source in results} def fetch_chunk_counts_for_documents( document_ids: list[str], db_session: Session, ) -> list[tuple[str, int]]: """ Return a list of (document_id, chunk_count) tuples, in the same order as document_ids. If a document_id is not found in the database, or its chunk_count is NULL, it is returned with a chunk_count of 0. """ stmt = select(DbDocument.id, DbDocument.chunk_count).where( DbDocument.id.in_(document_ids) ) results = db_session.execute(stmt).all() # Create a dictionary of document_id to chunk_count (NULL counts become 0) chunk_counts = {str(row.id): row.chunk_count or 0 for row in results} # Return a list of tuples in input order. Documents that were not found, or # whose chunk count is unknown, are reported with a chunk_count of 0.
return [(doc_id, chunk_counts.get(doc_id, 0)) for doc_id in document_ids] def fetch_chunk_count_for_document( document_id: str, db_session: Session, ) -> int | None: stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id) return db_session.execute(stmt).scalar_one_or_none() def get_unprocessed_kg_document_batch_for_connector( db_session: Session, connector_id: int, kg_coverage_start: datetime, kg_max_coverage_days: int, batch_size: int = 100, ) -> list[DbDocument]: """ Retrieves a batch of documents that have not been processed for knowledge graph extraction, or that have been updated since they were last processed. Args: db_session (Session): The database session to use connector_id (int): The ID of the connector to get documents for kg_coverage_start (datetime): Earliest doc_updated_at to consider kg_max_coverage_days (int): Maximum age, in days, of documents to consider; the effective cutoff is the later of this and kg_coverage_start batch_size (int): The maximum number of documents to retrieve Returns: list[DbDocument]: List of documents that need KG processing """ stmt = ( select(DbDocument) .join( DocumentByConnectorCredentialPair, DbDocument.id == DocumentByConnectorCredentialPair.id, ) .where( and_( DocumentByConnectorCredentialPair.connector_id == connector_id, DbDocument.doc_updated_at >= max( kg_coverage_start, datetime.now() - timedelta(days=kg_max_coverage_days), ), or_( DbDocument.kg_stage.is_(None), DbDocument.kg_stage == KGStage.NOT_STARTED, DbDocument.doc_updated_at > DbDocument.kg_processing_time, ), ) ) .distinct() .limit(batch_size) ) documents = db_session.scalars(stmt).all() db_session.flush() return list(documents) def get_kg_extracted_document_ids(db_session: Session) -> list[str]: """ Retrieves all document IDs where kg_stage is EXTRACTED. Args: db_session (Session): The database session to use Returns: list[str]: List of document IDs that have been KG extracted """ stmt = select(DbDocument.id).where(DbDocument.kg_stage == KGStage.EXTRACTED) return list(db_session.scalars(stmt).all()) def update_document_kg_info( db_session: Session, document_id: str, kg_stage: KGStage ) -> None: """Updates the knowledge graph related information for a document.
Args: db_session (Session): The database session to use document_id (str): The ID of the document to update kg_stage (KGStage): The stage of the knowledge graph processing for the document Also sets kg_processing_time to the current UTC time. Does not commit; if no document with the given ID exists, the update is a no-op. """ stmt = ( update(DbDocument) .where(DbDocument.id == document_id) .values( kg_stage=kg_stage, kg_processing_time=datetime.now(timezone.utc), ) ) db_session.execute(stmt) def update_document_kg_stage( db_session: Session, document_id: str, kg_stage: KGStage, ) -> None: stmt = ( update(DbDocument).where(DbDocument.id == document_id).values(kg_stage=kg_stage) ) db_session.execute(stmt) db_session.flush() def get_all_kg_extracted_documents_info( db_session: Session, ) -> list[str]: """Retrieves the IDs of all documents whose kg_stage is EXTRACTED. Args: db_session (Session): The database session to use Returns: list[str]: The matching document IDs, sorted by ID """ stmt = ( select(DbDocument.id) .where(DbDocument.kg_stage == KGStage.EXTRACTED) .order_by(DbDocument.id) ) # scalars() unwraps the single-column rows into plain IDs results = db_session.scalars(stmt).all() return [str(doc_id) for doc_id in results] def get_base_llm_doc_information( db_session: Session, document_ids: list[str] ) -> list[str]: stmt = select(DbDocument).where(DbDocument.id.in_(document_ids)) results = db_session.scalars(stmt).all() documents = [] for doc in results: documents.append( f"""* [{doc.semantic_id}]({doc.link}) ({doc.doc_updated_at})""" ) return documents[:KG_SIMPLE_ANSWER_MAX_DISPLAYED_SOURCES] def get_document_updated_at( document_id: str, db_session: Session, ) -> datetime | None: """Retrieves the doc_updated_at timestamp for a given document ID.
Args: document_id (str): The ID of the document to query db_session (Session): The database session to use Returns: Optional[datetime]: The doc_updated_at timestamp if found, None if the document doesn't exist or has no timestamp """ stmt = select(DbDocument.doc_updated_at).where(DbDocument.id == document_id) return db_session.execute(stmt).scalar_one_or_none() def reset_all_document_kg_stages(db_session: Session) -> int: """Reset the KG stage of all documents that are not in NOT_STARTED state to NOT_STARTED. Args: db_session (Session): The database session to use Returns: int: Number of documents that were reset """ stmt = ( update(DbDocument) .where(DbDocument.kg_stage != KGStage.NOT_STARTED) .values(kg_stage=KGStage.NOT_STARTED) ) result = db_session.execute(stmt) # The hasattr check is needed for type checking, even though rowcount # is guaranteed to exist at runtime for UPDATE operations return result.rowcount if hasattr(result, "rowcount") else 0 def update_document_kg_stages( db_session: Session, source_stage: KGStage, target_stage: KGStage ) -> int: """Move every document currently in source_stage to target_stage. Used by the reset flow, e.g. to send documents that have been extracted but not clustered back to NOT_STARTED. Args: db_session (Session): The database session to use source_stage (KGStage): Only documents currently in this stage are updated target_stage (KGStage): The stage to move them to Returns: int: Number of documents that were updated """ stmt = ( update(DbDocument) .where(DbDocument.kg_stage == source_stage) .values(kg_stage=target_stage) ) result = db_session.execute(stmt) # The hasattr check is needed for type checking, even though rowcount # is guaranteed to exist at runtime for UPDATE operations return result.rowcount if hasattr(result, "rowcount") else 0 def get_skipped_kg_documents(db_session: Session) -> list[str]: """ Retrieves all document IDs where kg_stage is SKIPPED.
Args: db_session (Session): The database session to use Returns: list[str]: List of document IDs that have been skipped in KG processing """ stmt = select(DbDocument.id).where(DbDocument.kg_stage == KGStage.SKIPPED) return list(db_session.scalars(stmt).all()) # def get_kg_doc_info_for_entity_name( # db_session: Session, document_id: str, entity_type: str # ) -> KGEntityDocInfo: # """ # Get the semantic ID and the link for an entity name. # """ # result = ( # db_session.query(Document.semantic_id, Document.link) # .filter(Document.id == document_id) # .first() # ) # if result is None: # return KGEntityDocInfo( # doc_id=None, # doc_semantic_id=None, # doc_link=None, # semantic_entity_name=f"{entity_type}:{document_id}", # semantic_linked_entity_name=f"{entity_type}:{document_id}", # ) # return KGEntityDocInfo( # doc_id=document_id, # doc_semantic_id=result[0], # doc_link=result[1], # semantic_entity_name=f"{entity_type.upper()}:{result[0]}", # semantic_linked_entity_name=f"[{entity_type.upper()}:{result[0]}]({result[1]})", # ) def check_for_documents_needing_kg_processing( db_session: Session, kg_coverage_start: datetime, kg_max_coverage_days: int ) -> bool: """Check if there are any documents that need KG processing. A document needs KG processing if: 1. It is associated with a connector that has kg_processing_enabled = true 2. 
AND either: - Its kg_stage is NOT_STARTED or NULL - OR its doc_updated_at timestamp is later than its kg_processing_time Only documents whose doc_updated_at is on or after the later of kg_coverage_start and (now - kg_max_coverage_days) are considered. Args: db_session (Session): The database session to use kg_coverage_start (datetime): Earliest doc_updated_at to consider kg_max_coverage_days (int): Maximum age, in days, of documents to consider Returns: bool: True if there are any documents needing KG processing, False otherwise """ stmt = ( select(1) .select_from(DbDocument) .join( DocumentByConnectorCredentialPair, DbDocument.id == DocumentByConnectorCredentialPair.id, ) .join( Connector, DocumentByConnectorCredentialPair.connector_id == Connector.id, ) .where( and_( Connector.kg_processing_enabled.is_(True), DbDocument.doc_updated_at >= max( kg_coverage_start, datetime.now() - timedelta(days=kg_max_coverage_days), ), or_( DbDocument.kg_stage.is_(None), DbDocument.kg_stage == KGStage.NOT_STARTED, DbDocument.doc_updated_at > DbDocument.kg_processing_time, ), ) ) .exists() ) return db_session.execute(select(stmt)).scalar() or False def check_for_documents_needing_kg_clustering(db_session: Session) -> bool: """Check if there are any documents that need KG clustering. A document needs KG clustering if: 1. It is associated with a connector that has kg_processing_enabled = true 2.
AND either: - Its kg_stage is EXTRACTED - OR its last_modified timestamp is later than its kg_processing_time Documents whose connector credential pair is being deleted are ignored. Args: db_session (Session): The database session to use Returns: bool: True if there are any documents needing KG clustering, False otherwise """ stmt = ( select(1) .select_from(DbDocument) .join( DocumentByConnectorCredentialPair, DbDocument.id == DocumentByConnectorCredentialPair.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join( Connector, ConnectorCredentialPair.connector_id == Connector.id, ) .where( and_( Connector.kg_processing_enabled.is_(True), ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING, or_( DbDocument.kg_stage == KGStage.EXTRACTED, DbDocument.last_modified > DbDocument.kg_processing_time, ), ) ) .exists() ) return db_session.execute(select(stmt)).scalar() or False def get_document_kg_entities_and_relationships( db_session: Session, document_id: str ) -> tuple[list[KGEntity], list[KGRelationship]]: """ Get the KG entities extracted from the document, plus the KG relationships that reference it (through one of those entities or via source_document). Returns two empty lists if the document has no entities.
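A minimal in-memory sketch of the selection rule used by this query, assuming hypothetical `Entity` and `Relationship` dataclasses as stand-ins for the `KGEntity` and `KGRelationship` ORM models (only the fields the query touches are modeled):

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Entity:  # hypothetical stand-in for KGEntity
    id_name: str
    document_id: str


@dataclass
class Relationship:  # hypothetical stand-in for KGRelationship
    source_node: str
    target_node: str
    source_document: str | None


def entities_and_relationships(
    entities: list[Entity], relationships: list[Relationship], document_id: str
) -> tuple[list[Entity], list[Relationship]]:
    doc_entities = [e for e in entities if e.document_id == document_id]
    # Mirror the early return: no entities means no relationships either
    if not doc_entities:
        return [], []
    names = {e.id_name for e in doc_entities}
    # Keep relationships touching one of the entities, or extracted from the document
    related = [
        r
        for r in relationships
        if r.source_node in names
        or r.target_node in names
        or r.source_document == document_id
    ]
    return doc_entities, related
```

Note the early return: relationships whose source_document matches are only returned when the document also has at least one entity.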
""" entities = ( db_session.query(KGEntity).filter(KGEntity.document_id == document_id).all() ) if not entities: return [], [] entity_id_names = [entity.id_name for entity in entities] relationships = ( db_session.query(KGRelationship) .filter( or_( KGRelationship.source_node.in_(entity_id_names), KGRelationship.target_node.in_(entity_id_names), KGRelationship.source_document == document_id, ) ) .all() ) return entities, relationships def get_num_chunks_for_document(db_session: Session, document_id: str) -> int: stmt = select(DbDocument.chunk_count).where(DbDocument.id == document_id) return db_session.execute(stmt).scalar_one_or_none() or 0 def update_document_metadata__no_commit( db_session: Session, document_id: str, doc_metadata: dict[str, Any], ) -> None: """Update the doc_metadata field for a document. Note: Does not commit. Caller is responsible for committing. Args: db_session: Database session document_id: The ID of the document to update doc_metadata: The new metadata dictionary to set """ stmt = ( update(DbDocument) .where(DbDocument.id == document_id) .values(doc_metadata=doc_metadata) ) db_session.execute(stmt) def delete_document_by_id__no_commit( db_session: Session, document_id: str, ) -> None: """Delete a single document and its connector credential pair relationships. Note: Does not commit. Caller is responsible for committing. This uses delete_documents_complete__no_commit which handles all foreign key relationships (KG entities, relationships, chunk stats, cc pair associations, feedback, tags). """ delete_documents_complete__no_commit(db_session, [document_id]) ================================================ FILE: backend/onyx/db/document_access.py ================================================ """ Document access filtering utilities. 
This module provides reusable access filtering logic for documents based on: - Connector access type (PUBLIC vs SYNC) - Document-level public flag - User email matching external_user_emails - User group overlap with external_user_group_ids This is a standalone module to avoid circular imports between document.py and persona.py. """ from sqlalchemy import and_ from sqlalchemy import any_ from sqlalchemy import cast from sqlalchemy import or_ from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import String from sqlalchemy.dialects import postgresql from sqlalchemy.orm import Session from sqlalchemy.sql.elements import ColumnElement from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Document from onyx.db.models import DocumentByConnectorCredentialPair def apply_document_access_filter( stmt: Select, user_email: str | None, external_group_ids: list[str], ) -> Select: """ Apply document access filtering to a query. This joins with DocumentByConnectorCredentialPair and ConnectorCredentialPair to: 1. Check if the document is from a PUBLIC connector (access_type = PUBLIC) 2. Check document-level permissions (is_public, external_user_emails, external_user_group_ids) 3. 
Exclude documents from cc_pairs that are being deleted Args: stmt: The SELECT statement to modify (must be selecting from Document) user_email: The user's email for permission checking external_group_ids: List of external group IDs the user belongs to Returns: Modified SELECT statement with access filtering applied """ # Join to get cc_pair info for each document stmt = stmt.join( DocumentByConnectorCredentialPair, Document.id == DocumentByConnectorCredentialPair.id, ).join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) # Exclude documents from cc_pairs that are being deleted stmt = stmt.where( ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING ) # Build access filters access_filters: list[ColumnElement[bool]] = [ # Document is from a PUBLIC connector ConnectorCredentialPair.access_type == AccessType.PUBLIC, # Document is marked as public (e.g., "Anyone with link" in source) Document.is_public.is_(True), ] if user_email: access_filters.append(any_(Document.external_user_emails) == user_email) if external_group_ids: access_filters.append( Document.external_user_group_ids.overlap( cast(postgresql.array(external_group_ids), postgresql.ARRAY(String)) ) ) stmt = stmt.where(or_(*access_filters)) return stmt def get_accessible_documents_by_ids( db_session: Session, document_ids: list[str], user_email: str | None, external_group_ids: list[str], ) -> list[Document]: """ Fetch documents by IDs, filtering to only those the user has access to. 
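A minimal in-memory sketch of the access rule that apply_document_access_filter encodes, assuming a hypothetical `Doc` record that flattens one (Document, cc_pair) join row; in SQL a document belonging to several cc_pairs is visible if any of its rows passes:

```python
from __future__ import annotations

from dataclasses import dataclass, field


@dataclass
class Doc:  # hypothetical flattened view of Document joined with one cc_pair
    cc_pair_is_public: bool  # ConnectorCredentialPair.access_type == PUBLIC
    cc_pair_deleting: bool  # ConnectorCredentialPair.status == DELETING
    is_public: bool = False
    external_user_emails: list[str] = field(default_factory=list)
    external_user_group_ids: list[str] = field(default_factory=list)


def can_access(doc: Doc, user_email: str | None, group_ids: list[str]) -> bool:
    # Rows from cc_pairs being deleted are always excluded
    if doc.cc_pair_deleting:
        return False
    # Otherwise any one of the OR-ed conditions grants access
    return (
        doc.cc_pair_is_public
        or doc.is_public
        or (bool(user_email) and user_email in doc.external_user_emails)
        or bool(set(group_ids) & set(doc.external_user_group_ids))
    )
```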
Uses the same access filtering logic as other document queries: - Documents from PUBLIC connectors - Documents marked as public (e.g., "Anyone with link") - Documents where user email matches external_user_emails - Documents where user's groups overlap with external_user_group_ids Args: db_session: Database session document_ids: List of document IDs to fetch user_email: User's email for permission checking external_group_ids: List of external group IDs the user belongs to Returns: List of Document objects from the input that the user has access to """ if not document_ids: return [] stmt = select(Document).where(Document.id.in_(document_ids)) stmt = apply_document_access_filter(stmt, user_email, external_group_ids) # Use distinct to avoid duplicates when a document belongs to multiple cc_pairs stmt = stmt.distinct() return list(db_session.execute(stmt).scalars().all()) ================================================ FILE: backend/onyx/db/document_set.py ================================================ from collections.abc import Sequence from typing import cast from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import aliased from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids from onyx.db.connector_credential_pair import get_connector_credential_pairs from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.federated import create_federated_connector_document_set_mapping from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Document from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import DocumentSet as DocumentSetDBModel 
from onyx.db.models import DocumentSet__ConnectorCredentialPair from onyx.db.models import DocumentSet__UserGroup from onyx.db.models import FederatedConnector__DocumentSet from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserRole from onyx.server.features.document_set.models import DocumentSetCreationRequest from onyx.server.features.document_set.models import DocumentSetUpdateRequest from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select: if user.role == UserRole.ADMIN: return stmt stmt = stmt.distinct() DocumentSet__UG = aliased(DocumentSet__UserGroup) User__UG = aliased(User__UserGroup) """ Here we select document sets by relation: User -> User__UserGroup -> DocumentSet__UserGroup -> DocumentSet """ stmt = stmt.outerjoin(DocumentSet__UG).outerjoin( User__UserGroup, User__UserGroup.user_group_id == DocumentSet__UG.user_group_id, ) """ Filter DocumentSets by: - if the user is in the user_group that owns the DocumentSet - if the user is a CURATOR (not a GLOBAL_CURATOR) and editing is being done, they must also have a curator relationship to the user_group - if editing is being done, we also filter out DocumentSets that are owned by any group the user is not a member of (for CURATORs, not a curator of); DocumentSets created by the user remain editable - if we are not editing, we show all DocumentSets in the groups the user belongs to (as well as public DocumentSets) """ # Anonymous users only see public DocumentSets if user.is_anonymous: where_clause = DocumentSetDBModel.is_public == True # noqa: E712 return stmt.where(where_clause) where_clause = User__UserGroup.user_id == user.id if user.role == UserRole.CURATOR and get_editable: where_clause &= User__UserGroup.is_curator == True # noqa: E712 if get_editable: user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) if user.role == UserRole.CURATOR: user_groups =
user_groups.where(User__UG.is_curator == True) # noqa: E712 where_clause &= ( ~exists() .where(DocumentSet__UG.document_set_id == DocumentSetDBModel.id) .where(~DocumentSet__UG.user_group_id.in_(user_groups)) .correlate(DocumentSetDBModel) ) where_clause |= DocumentSetDBModel.user_id == user.id else: where_clause |= DocumentSetDBModel.is_public == True # noqa: E712 return stmt.where(where_clause) def _delete_document_set_cc_pairs__no_commit( db_session: Session, document_set_id: int, is_current: bool | None = None ) -> None: """NOTE: does not commit transaction, this must be done by the caller""" stmt = delete(DocumentSet__ConnectorCredentialPair).where( DocumentSet__ConnectorCredentialPair.document_set_id == document_set_id ) if is_current is not None: stmt = stmt.where(DocumentSet__ConnectorCredentialPair.is_current == is_current) db_session.execute(stmt) def _mark_document_set_cc_pairs_as_outdated__no_commit( db_session: Session, document_set_id: int ) -> None: """NOTE: does not commit transaction, this must be done by the caller""" stmt = select(DocumentSet__ConnectorCredentialPair).where( DocumentSet__ConnectorCredentialPair.document_set_id == document_set_id ) for row in db_session.scalars(stmt): row.is_current = False def delete_document_set_privacy__no_commit( document_set_id: int, db_session: Session ) -> None: """No private document sets in Onyx MIT""" def get_document_set_by_id_for_user( db_session: Session, document_set_id: int, user: User, get_editable: bool = True, ) -> DocumentSetDBModel | None: stmt = ( select(DocumentSetDBModel) .distinct() .options(selectinload(DocumentSetDBModel.federated_connectors)) ) stmt = stmt.where(DocumentSetDBModel.id == document_set_id) stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable) return db_session.scalar(stmt) def get_document_set_by_id( db_session: Session, document_set_id: int, ) -> DocumentSetDBModel | None: stmt = select(DocumentSetDBModel).distinct() stmt = 
stmt.where(DocumentSetDBModel.id == document_set_id) return db_session.scalar(stmt) def get_document_set_by_name( db_session: Session, document_set_name: str ) -> DocumentSetDBModel | None: return db_session.scalar( select(DocumentSetDBModel).where(DocumentSetDBModel.name == document_set_name) ) def get_document_sets_by_name( db_session: Session, document_set_names: list[str] ) -> Sequence[DocumentSetDBModel]: return db_session.scalars( select(DocumentSetDBModel).where( DocumentSetDBModel.name.in_(document_set_names) ) ).all() def get_document_sets_by_ids( db_session: Session, document_set_ids: list[int] ) -> Sequence[DocumentSetDBModel]: if not document_set_ids: return [] return db_session.scalars( select(DocumentSetDBModel).where(DocumentSetDBModel.id.in_(document_set_ids)) ).all() def make_doc_set_private( document_set_id: int, # noqa: ARG001 user_ids: list[UUID] | None, group_ids: list[int] | None, db_session: Session, # noqa: ARG001 ) -> None: # May cause error if someone switches down to MIT from EE if user_ids or group_ids: raise NotImplementedError("Onyx MIT does not support private Document Sets") def _check_if_cc_pairs_are_owned_by_groups( db_session: Session, cc_pair_ids: list[int], group_ids: list[int], ) -> None: """ This function checks if the CC pairs are owned by the specified groups or public. If not, it raises a ValueError. 
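A minimal in-memory sketch of this check, assuming a hypothetical `AccessType` stand-in (its member values are assumptions, not Onyx's enum) and plain (cc_pair_id, user_group_id) tuples in place of the rows returned by get_cc_pair_groups_for_ids. As written, a CC pair fails only when it is PRIVATE; PUBLIC and SYNC pairs pass even without group ownership:

```python
from __future__ import annotations

from enum import Enum


class AccessType(Enum):  # hypothetical stand-in for onyx.db.enums.AccessType
    PUBLIC = "public"
    PRIVATE = "private"
    SYNC = "sync"


def check_owned_by_groups(
    cc_pair_ids: list[int],
    group_ids: list[int],
    ownership: set[tuple[int, int]],  # (cc_pair_id, user_group_id) pairs
    access_types: dict[int, AccessType],
) -> None:
    # A cc_pair is "missing" if at least one requested group does not own it
    missing = [
        cc_pair_id
        for cc_pair_id in cc_pair_ids
        if any((cc_pair_id, g) not in ownership for g in group_ids)
    ]
    # Only PRIVATE cc_pairs are rejected
    for cc_pair_id in missing:
        if access_types[cc_pair_id] == AccessType.PRIVATE:
            raise ValueError(
                f"Connector Credential Pair with ID: '{cc_pair_id}' "
                "is not owned by the specified groups"
            )
```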
""" group_cc_pair_relationships = get_cc_pair_groups_for_ids( db_session=db_session, cc_pair_ids=cc_pair_ids, ) group_cc_pair_relationships_set = { (relationship.cc_pair_id, relationship.user_group_id) for relationship in group_cc_pair_relationships } missing_cc_pair_ids = [] for cc_pair_id in cc_pair_ids: for group_id in group_ids: if (cc_pair_id, group_id) not in group_cc_pair_relationships_set: missing_cc_pair_ids.append(cc_pair_id) break if missing_cc_pair_ids: cc_pairs = get_connector_credential_pairs( db_session=db_session, ids=missing_cc_pair_ids, ) for cc_pair in cc_pairs: if cc_pair.access_type == AccessType.PRIVATE: raise ValueError( f"Connector Credential Pair with ID: '{cc_pair.id}' is not owned by the specified groups" ) def insert_document_set( document_set_creation_request: DocumentSetCreationRequest, user_id: UUID | None, db_session: Session, ) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]: # Check if we have either CC pairs or federated connectors (or both) if ( not document_set_creation_request.cc_pair_ids and not document_set_creation_request.federated_connectors ): raise ValueError("Cannot create a document set with no connectors") if not document_set_creation_request.is_public: _check_if_cc_pairs_are_owned_by_groups( db_session=db_session, cc_pair_ids=document_set_creation_request.cc_pair_ids, group_ids=document_set_creation_request.groups or [], ) new_document_set_row: DocumentSetDBModel ds_cc_pairs: list[DocumentSet__ConnectorCredentialPair] try: new_document_set_row = DocumentSetDBModel( name=document_set_creation_request.name, description=document_set_creation_request.description, user_id=user_id, is_public=document_set_creation_request.is_public, is_up_to_date=DISABLE_VECTOR_DB, time_last_modified_by_user=func.now(), ) db_session.add(new_document_set_row) db_session.flush() # ensure the new document set gets assigned an ID # Create CC pair mappings ds_cc_pairs = [ DocumentSet__ConnectorCredentialPair( 
document_set_id=new_document_set_row.id, connector_credential_pair_id=cc_pair_id, is_current=True, ) for cc_pair_id in document_set_creation_request.cc_pair_ids ] db_session.add_all(ds_cc_pairs) # Create federated connector mappings for fc_config in document_set_creation_request.federated_connectors: create_federated_connector_document_set_mapping( db_session=db_session, federated_connector_id=fc_config.federated_connector_id, document_set_id=new_document_set_row.id, entities=fc_config.entities, ) versioned_private_doc_set_fn = fetch_versioned_implementation( "onyx.db.document_set", "make_doc_set_private" ) # Private Document Sets versioned_private_doc_set_fn( document_set_id=new_document_set_row.id, user_ids=document_set_creation_request.users, group_ids=document_set_creation_request.groups, db_session=db_session, ) db_session.commit() except Exception as e: db_session.rollback() logger.error(f"Error creating document set: {e}") raise return new_document_set_row, ds_cc_pairs def update_document_set( db_session: Session, document_set_update_request: DocumentSetUpdateRequest, user: User, ) -> tuple[DocumentSetDBModel, list[DocumentSet__ConnectorCredentialPair]]: """If successful and the vector DB is enabled, this sets document_set_row.is_up_to_date = False. That flag is picked up via Celery in check_for_vespa_sync_task, which triggers a long-running background sync to Vespa.
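A minimal in-memory sketch of the two-phase cc pair swap this function starts, assuming a hypothetical `Link` dataclass in place of DocumentSet__ConnectorCredentialPair rows: the update marks existing rows as outdated and adds current ones, and mark_document_set_as_synced later deletes the outdated rows once the background sync is done:

```python
from __future__ import annotations

from dataclasses import dataclass


@dataclass
class Link:  # hypothetical stand-in for a DocumentSet__ConnectorCredentialPair row
    cc_pair_id: int
    is_current: bool


def update_links(links: list[Link], new_cc_pair_ids: list[int]) -> list[Link]:
    # Phase 1 (update): keep old rows but mark them outdated, then add current rows
    for link in links:
        link.is_current = False
    return links + [Link(cc_pair_id, True) for cc_pair_id in new_cc_pair_ids]


def mark_synced(links: list[Link]) -> list[Link]:
    # Phase 2 (after the sync): drop the outdated rows
    return [link for link in links if link.is_current]
```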
""" # Check if we have either CC pairs or federated connectors (or both) if ( not document_set_update_request.cc_pair_ids and not document_set_update_request.federated_connectors ): raise ValueError("Cannot update a document set with no connectors") if not document_set_update_request.is_public: _check_if_cc_pairs_are_owned_by_groups( db_session=db_session, cc_pair_ids=document_set_update_request.cc_pair_ids, group_ids=document_set_update_request.groups, ) try: # update the description document_set_row = get_document_set_by_id_for_user( db_session=db_session, document_set_id=document_set_update_request.id, user=user, get_editable=True, ) if document_set_row is None: raise ValueError( f"No document set with ID '{document_set_update_request.id}'" ) if not document_set_row.is_up_to_date: raise ValueError( "Cannot update document set while it is syncing. Please wait for it to finish syncing, and then try again." ) document_set_row.description = document_set_update_request.description if not DISABLE_VECTOR_DB: document_set_row.is_up_to_date = False document_set_row.is_public = document_set_update_request.is_public document_set_row.time_last_modified_by_user = func.now() versioned_private_doc_set_fn = fetch_versioned_implementation( "onyx.db.document_set", "make_doc_set_private" ) # Private Document Sets versioned_private_doc_set_fn( document_set_id=document_set_row.id, user_ids=document_set_update_request.users, group_ids=document_set_update_request.groups, db_session=db_session, ) # update the attached CC pairs # first, mark all existing CC pairs as not current _mark_document_set_cc_pairs_as_outdated__no_commit( db_session=db_session, document_set_id=document_set_row.id ) # add in rows for the new CC pairs ds_cc_pairs = [ DocumentSet__ConnectorCredentialPair( document_set_id=document_set_update_request.id, connector_credential_pair_id=cc_pair_id, is_current=True, ) for cc_pair_id in document_set_update_request.cc_pair_ids ] db_session.add_all(ds_cc_pairs) # Update 
federated connector mappings # Delete existing federated connector mappings for this document set delete_stmt = delete(FederatedConnector__DocumentSet).where( FederatedConnector__DocumentSet.document_set_id == document_set_row.id ) db_session.execute(delete_stmt) # Create new federated connector mappings for fc_config in document_set_update_request.federated_connectors: create_federated_connector_document_set_mapping( db_session=db_session, federated_connector_id=fc_config.federated_connector_id, document_set_id=document_set_row.id, entities=fc_config.entities, ) db_session.commit() except: db_session.rollback() raise return document_set_row, ds_cc_pairs def mark_document_set_as_synced(document_set_id: int, db_session: Session) -> None: stmt = select(DocumentSetDBModel).where(DocumentSetDBModel.id == document_set_id) document_set = db_session.scalar(stmt) if document_set is None: raise ValueError(f"No document set with ID: {document_set_id}") # mark as up to date document_set.is_up_to_date = True # delete outdated relationship table rows _delete_document_set_cc_pairs__no_commit( db_session=db_session, document_set_id=document_set_id, is_current=False ) db_session.commit() def delete_document_set( document_set_row: DocumentSetDBModel, db_session: Session ) -> None: # delete all relationships to CC pairs _delete_document_set_cc_pairs__no_commit( db_session=db_session, document_set_id=document_set_row.id ) db_session.delete(document_set_row) db_session.commit() def mark_document_set_as_to_be_deleted( db_session: Session, document_set_id: int, user: User, ) -> None: """Cleans up all document_set -> cc_pair relationships and marks the document set as needing an update. 
    The actual document set row will be deleted by the background job which syncs
    these changes to Vespa."""

    try:
        document_set_row = get_document_set_by_id_for_user(
            db_session=db_session,
            document_set_id=document_set_id,
            user=user,
            get_editable=True,
        )
        if document_set_row is None:
            error_msg = f"Document set with ID: '{document_set_id}' does not exist "
            if user is not None:
                error_msg += f"or is not editable by user with email: '{user.email}'"
            raise ValueError(error_msg)
        if not document_set_row.is_up_to_date:
            raise ValueError(
                "Cannot delete document set while it is syncing. Please wait for it to finish syncing, and then try again."
            )

        # delete all relationships to CC pairs
        _delete_document_set_cc_pairs__no_commit(
            db_session=db_session, document_set_id=document_set_id
        )

        # delete all federated connector mappings so the cleanup task can fully
        # remove the document set once the Vespa sync completes
        delete_stmt = delete(FederatedConnector__DocumentSet).where(
            FederatedConnector__DocumentSet.document_set_id == document_set_id
        )
        db_session.execute(delete_stmt)

        # delete all private document set information
        versioned_delete_private_fn = fetch_versioned_implementation(
            "onyx.db.document_set", "delete_document_set_privacy__no_commit"
        )
        versioned_delete_private_fn(
            document_set_id=document_set_id, db_session=db_session
        )

        # mark the row as needing a sync, it will be deleted there since there
        # are no more relationships to cc pairs
        document_set_row.is_up_to_date = False
        db_session.commit()
    except:
        db_session.rollback()
        raise


def delete_document_set_cc_pair_relationship__no_commit(
    connector_id: int, credential_id: int, db_session: Session
) -> int:
    """Deletes all rows from DocumentSet__ConnectorCredentialPair that belong to the
    CC pair identified by connector_id and credential_id. Returns the number of rows
    deleted."""
    delete_stmt = delete(DocumentSet__ConnectorCredentialPair).where(
        and_(
            ConnectorCredentialPair.connector_id == connector_id,
            ConnectorCredentialPair.credential_id == credential_id,
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
            == ConnectorCredentialPair.id,
        )
    )
    result = db_session.execute(delete_stmt)
    return result.rowcount  # type: ignore


def fetch_document_sets(
    user_id: UUID | None,  # noqa: ARG001
    db_session: Session,
    include_outdated: bool = False,
) -> list[tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]]:
    """Returns a list where each element is a tuple of:
    1. The document set itself
    2. All CC pairs associated with the document set"""
    stmt = (
        select(DocumentSetDBModel, ConnectorCredentialPair)
        .join(
            DocumentSet__ConnectorCredentialPair,
            DocumentSetDBModel.id
            == DocumentSet__ConnectorCredentialPair.document_set_id,
            isouter=True,  # outer join is needed to also fetch document sets with no cc pairs
        )
        .join(
            ConnectorCredentialPair,
            ConnectorCredentialPair.id
            == DocumentSet__ConnectorCredentialPair.connector_credential_pair_id,
            isouter=True,  # outer join is needed to also fetch document sets with no cc pairs
        )
    )
    if not include_outdated:
        stmt = stmt.where(
            or_(
                DocumentSet__ConnectorCredentialPair.is_current == True,  # noqa: E712
                # `None` handles case where no CC Pairs exist for a Document Set
                DocumentSet__ConnectorCredentialPair.is_current.is_(None),
            )
        )

    results = cast(
        list[tuple[DocumentSetDBModel, ConnectorCredentialPair | None]],
        db_session.execute(stmt).all(),
    )

    aggregated_results: dict[
        int, tuple[DocumentSetDBModel, list[ConnectorCredentialPair]]
    ] = {}
    for document_set, cc_pair in results:
        if document_set.id not in aggregated_results:
            aggregated_results[document_set.id] = (
                document_set,
                [cc_pair] if cc_pair else [],
            )
        else:
            if cc_pair:
                aggregated_results[document_set.id][1].append(cc_pair)

    return [
        (document_set, cc_pairs)
        for document_set, cc_pairs in aggregated_results.values()
    ]


def fetch_all_document_sets_for_user(
    db_session: Session,
    user: User,
    get_editable: bool = True,
) -> Sequence[DocumentSetDBModel]:
    stmt = (
        select(DocumentSetDBModel)
        .distinct()
        .options(
            selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSetDBModel.users),
            selectinload(DocumentSetDBModel.groups),
            selectinload(DocumentSetDBModel.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        )
    )
    stmt = _add_user_filters(stmt, user, get_editable=get_editable)
    return db_session.scalars(stmt).unique().all()


def fetch_documents_for_document_set_paginated(
    document_set_id: int,
    db_session: Session,
    current_only: bool = True,
    last_document_id: str | None = None,
    limit: int = 100,
) -> tuple[Sequence[Document], str | None]:
    stmt = (
        select(Document)
        .join(
            DocumentByConnectorCredentialPair,
            DocumentByConnectorCredentialPair.id == Document.id,
        )
        .join(
            ConnectorCredentialPair,
            and_(
                ConnectorCredentialPair.connector_id
                == DocumentByConnectorCredentialPair.connector_id,
                ConnectorCredentialPair.credential_id
                == DocumentByConnectorCredentialPair.credential_id,
            ),
        )
        .join(
            DocumentSet__ConnectorCredentialPair,
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
            == ConnectorCredentialPair.id,
        )
        .join(
            DocumentSetDBModel,
            DocumentSetDBModel.id
            == DocumentSet__ConnectorCredentialPair.document_set_id,
        )
        .where(DocumentSetDBModel.id == document_set_id)
        .order_by(Document.id)
        .limit(limit)
    )
    if last_document_id is not None:
        stmt = stmt.where(Document.id > last_document_id)
    if current_only:
        stmt = stmt.where(
            DocumentSet__ConnectorCredentialPair.is_current == True  # noqa: E712
        )
    stmt = stmt.distinct()

    documents = db_session.scalars(stmt).all()
    return documents, documents[-1].id if documents else None


def construct_document_id_select_by_docset(
    document_set_id: int,
    current_only: bool = True,
) -> Select:
    """This returns a statement that should be executed using
    .yield_per() to minimize overhead.
    The primary consumers of this function are background processing task generators."""

    stmt = (
        select(Document.id)
        .join(
            DocumentByConnectorCredentialPair,
            DocumentByConnectorCredentialPair.id == Document.id,
        )
        .join(
            ConnectorCredentialPair,
            and_(
                ConnectorCredentialPair.connector_id
                == DocumentByConnectorCredentialPair.connector_id,
                ConnectorCredentialPair.credential_id
                == DocumentByConnectorCredentialPair.credential_id,
            ),
        )
        .join(
            DocumentSet__ConnectorCredentialPair,
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
            == ConnectorCredentialPair.id,
        )
        .join(
            DocumentSetDBModel,
            DocumentSetDBModel.id
            == DocumentSet__ConnectorCredentialPair.document_set_id,
        )
        .where(DocumentSetDBModel.id == document_set_id)
        .order_by(Document.id)
    )

    if current_only:
        stmt = stmt.where(
            DocumentSet__ConnectorCredentialPair.is_current == True  # noqa: E712
        )

    stmt = stmt.distinct()
    return stmt


def fetch_document_sets_for_document(
    document_id: str,
    db_session: Session,
) -> list[str]:
    """
    Fetches the document set names for a single document ID.

    :param document_id: The ID of the document to fetch sets for.
    :param db_session: The SQLAlchemy session to use for the query.
    :return: A list of document set names, or an empty list if no result is found.
    """
    result = fetch_document_sets_for_documents([document_id], db_session)
    if not result:
        return []

    return result[0][1]


def fetch_document_sets_for_documents(
    document_ids: list[str],
    db_session: Session,
) -> Sequence[tuple[str, list[str]]]:
    """Gives back a list of (document_id, list[document_set_names]) tuples"""

    """Building subqueries"""
    # NOTE: have to build these subqueries first in order to guarantee that we get one
    # returned row for each specified document_id. Basically, we want to do the filters first,
    # then the outer joins.
    # don't include CC pairs that are being deleted
    # NOTE: CC pairs can never go from DELETING to any other state -> it's safe to ignore them
    # as we can assume their document sets are no longer relevant
    valid_cc_pairs_subquery = aliased(
        ConnectorCredentialPair,
        select(ConnectorCredentialPair)
        .where(
            ConnectorCredentialPair.status != ConnectorCredentialPairStatus.DELETING
        )  # noqa: E712
        .subquery(),
    )

    valid_document_set__cc_pairs_subquery = aliased(
        DocumentSet__ConnectorCredentialPair,
        select(DocumentSet__ConnectorCredentialPair)
        .where(DocumentSet__ConnectorCredentialPair.is_current == True)  # noqa: E712
        .subquery(),
    )
    """End building subqueries"""

    stmt = (
        select(
            Document.id,
            func.coalesce(
                func.array_remove(func.array_agg(DocumentSetDBModel.name), None), []
            ).label("document_set_names"),
        )
        # Here we select document sets by relation:
        # Document -> DocumentByConnectorCredentialPair -> ConnectorCredentialPair ->
        # DocumentSet__ConnectorCredentialPair -> DocumentSet
        .outerjoin(
            DocumentByConnectorCredentialPair,
            Document.id == DocumentByConnectorCredentialPair.id,
        )
        .outerjoin(
            valid_cc_pairs_subquery,
            and_(
                DocumentByConnectorCredentialPair.connector_id
                == valid_cc_pairs_subquery.connector_id,
                DocumentByConnectorCredentialPair.credential_id
                == valid_cc_pairs_subquery.credential_id,
            ),
        )
        .outerjoin(
            valid_document_set__cc_pairs_subquery,
            valid_cc_pairs_subquery.id
            == valid_document_set__cc_pairs_subquery.connector_credential_pair_id,
        )
        .outerjoin(
            DocumentSetDBModel,
            DocumentSetDBModel.id
            == valid_document_set__cc_pairs_subquery.document_set_id,
        )
        .where(Document.id.in_(document_ids))
        .group_by(Document.id)
    )
    return db_session.execute(stmt).all()  # type: ignore


def get_or_create_document_set_by_name(
    db_session: Session,
    document_set_name: str,
    document_set_description: str = "Default Persona created Document-Set, please update description",
) -> DocumentSetDBModel:
    """This is used by the default personas which need to attach to document sets
    on server startup"""
    doc_set = get_document_set_by_name(db_session, document_set_name)
    if doc_set is not None:
        return doc_set

    new_doc_set = DocumentSetDBModel(
        name=document_set_name,
        description=document_set_description,
        user_id=None,
        is_up_to_date=True,
    )

    db_session.add(new_doc_set)
    db_session.commit()

    return new_doc_set


def check_document_sets_are_public(
    db_session: Session,
    document_set_ids: list[int],
) -> bool:
    """Returns True if every CC pair attached to the given document sets is public,
    and False if any of them is non-public (meaning that some documents in these
    document sets are not public)."""
    connector_credential_pair_ids = (
        db_session.query(
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
        )
        .filter(
            DocumentSet__ConnectorCredentialPair.document_set_id.in_(document_set_ids)
        )
        .subquery()
    )

    not_public_exists = (
        db_session.query(ConnectorCredentialPair.id)
        .filter(
            ConnectorCredentialPair.id.in_(
                connector_credential_pair_ids  # type:ignore
            ),
            ConnectorCredentialPair.access_type != AccessType.PUBLIC,
        )
        .limit(1)
        .first()
        is not None
    )

    return not not_public_exists

================================================
FILE: backend/onyx/db/engine/__init__.py
================================================

================================================
FILE: backend/onyx/db/engine/async_sql_engine.py
================================================
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from typing import Any
from typing import AsyncContextManager

import asyncpg  # type: ignore
from fastapi import HTTPException
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.asyncio import create_async_engine

from onyx.configs.app_configs import AWS_REGION_NAME
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW
from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE
from onyx.configs.app_configs import POSTGRES_DB
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.db.engine.iam_auth import create_ssl_context_if_iam
from onyx.db.engine.iam_auth import get_iam_auth_token
from onyx.db.engine.sql_engine import ASYNC_DB_API
from onyx.db.engine.sql_engine import build_connection_string
from onyx.db.engine.sql_engine import is_valid_schema_name
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.engine.sql_engine import USE_IAM_AUTH
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.contextvars import get_current_tenant_id


# Global so we don't create more than one engine per process
_ASYNC_ENGINE: AsyncEngine | None = None


async def get_async_connection() -> Any:
    """
    Custom connection function for async engine when using IAM auth.
    """
    host = POSTGRES_HOST
    port = POSTGRES_PORT
    user = POSTGRES_USER
    db = POSTGRES_DB
    token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)

    # asyncpg requires 'ssl="require"' if SSL needed
    return await asyncpg.connect(
        user=user, password=token, host=host, port=int(port), database=db, ssl="require"
    )


def get_sqlalchemy_async_engine() -> AsyncEngine:
    global _ASYNC_ENGINE
    if _ASYNC_ENGINE is None:
        app_name = SqlEngine.get_app_name() + "_async"
        connection_string = build_connection_string(
            db_api=ASYNC_DB_API,
            use_iam_auth=USE_IAM_AUTH,
        )

        connect_args: dict[str, Any] = {}
        if app_name:
            connect_args["server_settings"] = {"application_name": app_name}

        connect_args["ssl"] = create_ssl_context_if_iam()

        engine_kwargs = {
            "connect_args": connect_args,
            "pool_pre_ping": POSTGRES_POOL_PRE_PING,
            "pool_recycle": POSTGRES_POOL_RECYCLE,
        }

        if POSTGRES_USE_NULL_POOL:
            engine_kwargs["poolclass"] = pool.NullPool
        else:
            engine_kwargs["pool_size"] = POSTGRES_API_SERVER_POOL_SIZE
            engine_kwargs["max_overflow"] = POSTGRES_API_SERVER_POOL_OVERFLOW

        _ASYNC_ENGINE = create_async_engine(
            connection_string,
            **engine_kwargs,
        )

        if USE_IAM_AUTH:

            @event.listens_for(_ASYNC_ENGINE.sync_engine, "do_connect")
            def provide_iam_token_async(
                dialect: Any,  # noqa: ARG001
                conn_rec: Any,  # noqa: ARG001
                cargs: Any,  # noqa: ARG001
                cparams: Any,
            ) -> None:
                # For async engine using asyncpg, we still need to set the IAM token here.
                host = POSTGRES_HOST
                port = POSTGRES_PORT
                user = POSTGRES_USER
                token = get_iam_auth_token(host, port, user, AWS_REGION_NAME)
                cparams["password"] = token
                cparams["ssl"] = create_ssl_context_if_iam()

    return _ASYNC_ENGINE


async def get_async_session(
    tenant_id: str | None = None,
) -> AsyncGenerator[AsyncSession, None]:
    """For use w/ Depends for *async* FastAPI endpoints.

    For standard `async with ... as ...` use, use get_async_session_context_manager.
    """

    if tenant_id is None:
        tenant_id = get_current_tenant_id()

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    engine = get_sqlalchemy_async_engine()

    # no need to use the schema translation map for self-hosted + default schema
    if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
        async with AsyncSession(bind=engine, expire_on_commit=False) as session:
            yield session
        return

    # Create connection with schema translation to handle querying the right schema
    schema_translate_map = {None: tenant_id}
    async with engine.connect() as connection:
        connection = await connection.execution_options(
            schema_translate_map=schema_translate_map
        )
        async with AsyncSession(
            bind=connection, expire_on_commit=False
        ) as async_session:
            yield async_session


def get_async_session_context_manager(
    tenant_id: str | None = None,
) -> AsyncContextManager[AsyncSession]:
    return asynccontextmanager(get_async_session)(tenant_id)

================================================
FILE: backend/onyx/db/engine/connection_warmup.py
================================================
from sqlalchemy import text

from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.sql_engine import get_sqlalchemy_engine


async def warm_up_connections(
    sync_connections_to_warm_up: int = 20, async_connections_to_warm_up: int = 20
) -> None:
    sync_postgres_engine = get_sqlalchemy_engine()
    connections = [
        sync_postgres_engine.connect() for _ in range(sync_connections_to_warm_up)
    ]
    for conn in connections:
        conn.execute(text("SELECT 1"))
    for conn in connections:
        conn.close()

    async_postgres_engine = get_sqlalchemy_async_engine()
    async_connections = [
        await async_postgres_engine.connect()
        for _ in range(async_connections_to_warm_up)
    ]
    for async_conn in async_connections:
        await async_conn.execute(text("SELECT 1"))
    for async_conn in async_connections:
        await async_conn.close()

================================================
FILE: backend/onyx/db/engine/iam_auth.py
================================================
import functools
import os
import ssl
from typing import Any

import boto3

from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.configs.constants import SSL_CERT_FILE


def get_iam_auth_token(
    host: str, port: str, user: str, region: str = "us-east-2"
) -> str:
    """
    Generate an IAM authentication token using boto3.
    """
    client = boto3.client("rds", region_name=region)
    token = client.generate_db_auth_token(
        DBHostname=host, Port=int(port), DBUsername=user
    )
    return token


def configure_psycopg2_iam_auth(
    cparams: dict[str, Any], host: str, port: str, user: str, region: str
) -> None:
    """
    Configure cparams for psycopg2 with IAM token and SSL.
    """
    token = get_iam_auth_token(host, port, user, region)
    cparams["password"] = token
    cparams["sslmode"] = "require"
    cparams["sslrootcert"] = SSL_CERT_FILE


def provide_iam_token(
    dialect: Any,  # noqa: ARG001
    conn_rec: Any,  # noqa: ARG001
    cargs: Any,  # noqa: ARG001
    cparams: Any,
) -> None:
    if USE_IAM_AUTH:
        host = POSTGRES_HOST
        port = POSTGRES_PORT
        user = POSTGRES_USER
        region = os.getenv("AWS_REGION_NAME", "us-east-2")
        # Configure for psycopg2 with IAM token
        configure_psycopg2_iam_auth(cparams, host, port, user, region)


@functools.cache
def create_ssl_context_if_iam() -> ssl.SSLContext | None:
    """Create an SSL context if IAM authentication is enabled, else return None."""
    if USE_IAM_AUTH:
        return ssl.create_default_context(cafile=SSL_CERT_FILE)
    return None

================================================
FILE: backend/onyx/db/engine/sql_engine.py
================================================
import os
import re
import threading
import time
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any

from fastapi import HTTPException
from sqlalchemy import event
from sqlalchemy import pool
from sqlalchemy.engine import create_engine
from sqlalchemy.engine import Engine
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from onyx.configs.app_configs import LOG_POSTGRES_CONN_COUNTS
from onyx.configs.app_configs import LOG_POSTGRES_LATENCY
from onyx.configs.app_configs import POSTGRES_DB
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_POOL_PRE_PING
from onyx.configs.app_configs import POSTGRES_POOL_RECYCLE
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USE_NULL_POOL
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.constants import POSTGRES_UNKNOWN_APP_NAME
from onyx.db.engine.iam_auth import provide_iam_token
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

# Moved is_valid_schema_name here to avoid circular import


logger = setup_logger()


# Schema name validation (moved here to avoid circular import)
SCHEMA_NAME_REGEX = re.compile(r"^[a-zA-Z0-9_-]+$")


def is_valid_schema_name(name: str) -> bool:
    return SCHEMA_NAME_REGEX.match(name) is not None


SYNC_DB_API = "psycopg2"
ASYNC_DB_API = "asyncpg"

# why isn't this in configs?
USE_IAM_AUTH = os.getenv("USE_IAM_AUTH", "False").lower() == "true"


def build_connection_string(
    *,
    db_api: str = ASYNC_DB_API,
    user: str = POSTGRES_USER,
    password: str = POSTGRES_PASSWORD,
    host: str = POSTGRES_HOST,
    port: str = POSTGRES_PORT,
    db: str = POSTGRES_DB,
    app_name: str | None = None,
    use_iam_auth: bool = USE_IAM_AUTH,
    region: str = "us-west-2",  # noqa: ARG001
) -> str:
    if use_iam_auth:
        base_conn_str = f"postgresql+{db_api}://{user}@{host}:{port}/{db}"
    else:
        base_conn_str = f"postgresql+{db_api}://{user}:{password}@{host}:{port}/{db}"

    # For asyncpg, do not include application_name in the connection string
    if app_name and db_api != "asyncpg":
        if "?" in base_conn_str:
            return f"{base_conn_str}&application_name={app_name}"
        else:
            return f"{base_conn_str}?application_name={app_name}"
    return base_conn_str


if LOG_POSTGRES_LATENCY:

    @event.listens_for(Engine, "before_cursor_execute")
    def before_cursor_execute(  # type: ignore
        conn,
        cursor,  # noqa: ARG001
        statement,  # noqa: ARG001
        parameters,  # noqa: ARG001
        context,  # noqa: ARG001
        executemany,  # noqa: ARG001
    ):
        conn.info["query_start_time"] = time.time()

    @event.listens_for(Engine, "after_cursor_execute")
    def after_cursor_execute(  # type: ignore
        conn,
        cursor,  # noqa: ARG001
        statement,
        parameters,  # noqa: ARG001
        context,  # noqa: ARG001
        executemany,  # noqa: ARG001
    ):
        total_time = time.time() - conn.info["query_start_time"]
        if total_time > 0.1:
            logger.debug(
                f"Query Complete: {statement}\n\nTotal Time: {total_time:.4f} seconds"
            )


if LOG_POSTGRES_CONN_COUNTS:
    checkout_count = 0
    checkin_count = 0

    @event.listens_for(Engine, "checkout")
    def log_checkout(dbapi_connection, connection_record, connection_proxy):  # type: ignore # noqa: ARG001
        global checkout_count
        checkout_count += 1

        active_connections = connection_proxy._pool.checkedout()
        idle_connections = connection_proxy._pool.checkedin()
        pool_size = connection_proxy._pool.size()
        logger.debug(
            "Connection Checkout\n"
            f"Active Connections: {active_connections};\n"
            f"Idle: {idle_connections};\n"
            f"Pool Size: {pool_size};\n"
            f"Total connection checkouts: {checkout_count}"
        )

    @event.listens_for(Engine, "checkin")
    def log_checkin(dbapi_connection, connection_record):  # type: ignore # noqa: ARG001
        global checkin_count
        checkin_count += 1
        logger.debug(f"Total connection checkins: {checkin_count}")


class SqlEngine:
    _engine: Engine | None = None
    _readonly_engine: Engine | None = None
    _lock: threading.Lock = threading.Lock()
    _readonly_lock: threading.Lock = threading.Lock()
    _app_name: str = POSTGRES_UNKNOWN_APP_NAME

    @classmethod
    def init_engine(
        cls,
        pool_size: int,
        # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
        max_overflow: int,
        app_name: str | None = None,  # noqa: ARG003
        db_api: str = SYNC_DB_API,
        use_iam: bool = USE_IAM_AUTH,
        connection_string: str | None = None,
        **extra_engine_kwargs: Any,
    ) -> None:
        """NOTE: enforce that pool_size and pool_max_overflow are passed in.

        These are important args, and if incorrectly specified, we have run into
        hitting the pool limit / using too many connections and overwhelming the database.

        Specifying connection_string directly will cause some of the other parameters
        to be ignored.
        """
        with cls._lock:
            if cls._engine:
                return

            if not connection_string:
                connection_string = build_connection_string(
                    db_api=db_api,
                    app_name=cls._app_name + "_sync",
                    use_iam_auth=use_iam,
                )

            # Start with base kwargs that are valid for all pool types
            final_engine_kwargs: dict[str, Any] = {}

            if POSTGRES_USE_NULL_POOL:
                # if null pool is specified, then we need to make sure that
                # we remove any passed in kwargs related to pool size that would
                # cause the initialization to fail
                final_engine_kwargs.update(extra_engine_kwargs)

                final_engine_kwargs["poolclass"] = pool.NullPool
                if "pool_size" in final_engine_kwargs:
                    del final_engine_kwargs["pool_size"]
                if "max_overflow" in final_engine_kwargs:
                    del final_engine_kwargs["max_overflow"]
            else:
                final_engine_kwargs["pool_size"] = pool_size
                final_engine_kwargs["max_overflow"] = max_overflow
                final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
                final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

                # any passed in kwargs override the defaults
                final_engine_kwargs.update(extra_engine_kwargs)

            logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
            # echo=True here for inspecting all emitted db queries
            engine = create_engine(connection_string, **final_engine_kwargs)

            if use_iam:
                event.listen(engine, "do_connect", provide_iam_token)

            cls._engine = engine

    @classmethod
    def init_readonly_engine(
        cls,
        pool_size: int,
        # is really `pool_max_overflow`, but calling it `max_overflow` to stay consistent with SQLAlchemy
        max_overflow: int,
        **extra_engine_kwargs: Any,
    ) -> None:
        """NOTE: enforce that pool_size and pool_max_overflow are passed in.
        These are important args, and if incorrectly specified, we have run into
        hitting the pool limit / using too many connections and overwhelming the database."""
        with cls._readonly_lock:
            if cls._readonly_engine:
                return

            if not DB_READONLY_USER or not DB_READONLY_PASSWORD:
                raise ValueError(
                    "Custom database user credentials not configured in environment variables"
                )

            # Build connection string with custom user
            connection_string = build_connection_string(
                user=DB_READONLY_USER,
                password=DB_READONLY_PASSWORD,
                use_iam_auth=False,  # Custom users typically don't use IAM auth
                db_api=SYNC_DB_API,  # Explicitly use sync DB API
            )

            # Start with base kwargs that are valid for all pool types
            final_engine_kwargs: dict[str, Any] = {}

            if POSTGRES_USE_NULL_POOL:
                # if null pool is specified, then we need to make sure that
                # we remove any passed in kwargs related to pool size that would
                # cause the initialization to fail
                final_engine_kwargs.update(extra_engine_kwargs)

                final_engine_kwargs["poolclass"] = pool.NullPool
                if "pool_size" in final_engine_kwargs:
                    del final_engine_kwargs["pool_size"]
                if "max_overflow" in final_engine_kwargs:
                    del final_engine_kwargs["max_overflow"]
            else:
                final_engine_kwargs["pool_size"] = pool_size
                final_engine_kwargs["max_overflow"] = max_overflow
                final_engine_kwargs["pool_pre_ping"] = POSTGRES_POOL_PRE_PING
                final_engine_kwargs["pool_recycle"] = POSTGRES_POOL_RECYCLE

                # any passed in kwargs override the defaults
                final_engine_kwargs.update(extra_engine_kwargs)

            logger.info(f"Creating engine with kwargs: {final_engine_kwargs}")
            # echo=True here for inspecting all emitted db queries
            engine = create_engine(connection_string, **final_engine_kwargs)

            if USE_IAM_AUTH:
                event.listen(engine, "do_connect", provide_iam_token)

            cls._readonly_engine = engine

    @classmethod
    def get_engine(cls) -> Engine:
        if not cls._engine:
            raise RuntimeError("Engine not initialized. Must call init_engine first.")
        return cls._engine

    @classmethod
    def get_readonly_engine(cls) -> Engine:
        if not cls._readonly_engine:
            raise RuntimeError(
                "Readonly engine not initialized. Must call init_readonly_engine first."
            )
        return cls._readonly_engine

    @classmethod
    def set_app_name(cls, app_name: str) -> None:
        cls._app_name = app_name

    @classmethod
    def get_app_name(cls) -> str:
        if not cls._app_name:
            return ""
        return cls._app_name

    @classmethod
    def reset_engine(cls) -> None:
        with cls._lock:
            if cls._engine:
                cls._engine.dispose()
                cls._engine = None

    @classmethod
    @contextmanager
    def scoped_engine(cls, **init_kwargs: Any) -> Generator[None, None, None]:
        """Context manager that initializes the engine and guarantees cleanup."""
        cls.init_engine(**init_kwargs)
        try:
            yield
        finally:
            cls.reset_engine()


def get_sqlalchemy_engine() -> Engine:
    return SqlEngine.get_engine()


def get_readonly_sqlalchemy_engine() -> Engine:
    return SqlEngine.get_readonly_engine()


@contextmanager
def get_session_with_current_tenant() -> Generator[Session, None, None]:
    """Standard way to get a DB session."""
    tenant_id = get_current_tenant_id()
    with get_session_with_tenant(tenant_id=tenant_id) as session:
        yield session


@contextmanager
def get_session_with_current_tenant_if_none(
    session: Session | None,
) -> Generator[Session, None, None]:
    if session is None:
        tenant_id = get_current_tenant_id()
        with get_session_with_tenant(tenant_id=tenant_id) as session:
            yield session
    else:
        yield session


# Used in multi tenant mode when need to refer to the shared `public` schema
@contextmanager
def get_session_with_shared_schema() -> Generator[Session, None, None]:
    token = CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
    with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as session:
        yield session
    CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


@contextmanager
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
    """
    Generate a database session for a specific tenant.
    """
    engine = get_sqlalchemy_engine()

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    # no need to use the schema translation map for self-hosted + default schema
    if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
        with Session(bind=engine, expire_on_commit=False) as session:
            yield session
        return

    # Create connection with schema translation to handle querying the right schema
    schema_translate_map = {None: tenant_id}
    with engine.connect().execution_options(
        schema_translate_map=schema_translate_map
    ) as connection:
        with Session(bind=connection, expire_on_commit=False) as session:
            yield session


def get_session() -> Generator[Session, None, None]:
    """For use w/ Depends for FastAPI endpoints.

    Has some additional validation, and likely should be merged
    with get_session_with_current_tenant in the future."""
    tenant_id = get_current_tenant_id()
    if tenant_id == POSTGRES_DEFAULT_SCHEMA and MULTI_TENANT:
        raise BasicAuthenticationError(detail="User must authenticate")

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    with get_session_with_current_tenant() as db_session:
        yield db_session


@contextmanager
def get_db_readonly_user_session_with_current_tenant() -> (
    Generator[Session, None, None]
):
    """
    Generate a database session using a custom database user for the current tenant.
    The custom user credentials are obtained from environment variables.
    """
    tenant_id = get_current_tenant_id()

    readonly_engine = get_readonly_sqlalchemy_engine()

    if not is_valid_schema_name(tenant_id):
        raise HTTPException(status_code=400, detail="Invalid tenant ID")

    # no need to use the schema translation map for self-hosted + default schema
    if not MULTI_TENANT and tenant_id == POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE:
        with Session(readonly_engine, expire_on_commit=False) as session:
            yield session
        return

    schema_translate_map = {None: tenant_id}
    with readonly_engine.connect().execution_options(
        schema_translate_map=schema_translate_map
    ) as connection:
        with Session(bind=connection, expire_on_commit=False) as session:
            yield session

================================================
FILE: backend/onyx/db/engine/tenant_utils.py
================================================
from sqlalchemy import text

from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import SqlEngine
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX


def get_schemas_needing_migration(
    tenant_schemas: list[str], head_rev: str
) -> list[str]:
    """Return only schemas whose current alembic version is not at head.

    Uses a server-side PL/pgSQL loop to collect each schema's alembic version
    into a temp table one at a time. This avoids building a massive UNION ALL
    query (which locks the DB and times out at 17k+ schemas) and instead
    acquires locks sequentially, one schema per iteration.
    """
    if not tenant_schemas:
        return []

    engine = SqlEngine.get_engine()

    with engine.connect() as conn:
        # Populate a temp input table with exactly the schemas we care about.
        # The DO block reads from this table so it only iterates the requested
        # schemas instead of every tenant_% schema in the database.
conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot")) conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input")) conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)")) conn.execute( text( "INSERT INTO _tenant_schemas_input (schema_name) SELECT unnest(CAST(:schemas AS text[]))" ), {"schemas": tenant_schemas}, ) conn.execute( text( "CREATE TEMP TABLE _alembic_version_snapshot (schema_name text, version_num text)" ) ) conn.execute( text( """ DO $$ DECLARE s text; schemas text[]; BEGIN SELECT array_agg(schema_name) INTO schemas FROM _tenant_schemas_input; IF schemas IS NULL THEN RAISE NOTICE 'No tenant schemas found.'; RETURN; END IF; FOREACH s IN ARRAY schemas LOOP BEGIN EXECUTE format( 'INSERT INTO _alembic_version_snapshot SELECT %L, version_num FROM %I.alembic_version', s, s ); EXCEPTION -- undefined_table: schema exists but has no alembic_version -- table yet (new tenant, not yet migrated). -- invalid_schema_name: tenant is registered but its -- PostgreSQL schema does not exist yet (e.g. provisioning -- incomplete). Both cases mean no version is available and -- the schema will be included in the migration list. WHEN undefined_table THEN NULL; WHEN invalid_schema_name THEN NULL; END; END LOOP; END; $$ """ ) ) rows = conn.execute( text("SELECT schema_name, version_num FROM _alembic_version_snapshot") ) version_by_schema = {row[0]: row[1] for row in rows} conn.execute(text("DROP TABLE IF EXISTS _alembic_version_snapshot")) conn.execute(text("DROP TABLE IF EXISTS _tenant_schemas_input")) # Schemas missing from the snapshot have no alembic_version table yet and # also need migration. version_by_schema.get(s) returns None for those, # and None != head_rev, so they are included automatically. 
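# Worked example (illustrative values only, not real revision ids): with
# tenant_schemas = ["tenant_a", "tenant_b", "tenant_c"], head_rev = "abc123",
# and version_by_schema = {"tenant_a": "abc123", "tenant_b": "old456"}
# (tenant_c has no alembic_version table yet), the filter below returns
# ["tenant_b", "tenant_c"], preserving the order of tenant_schemas.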
return [s for s in tenant_schemas if version_by_schema.get(s) != head_rev] def get_all_tenant_ids() -> list[str]: """Return all tenant schema names. In self-hosted (single-tenant) mode this is just [POSTGRES_DEFAULT_SCHEMA]; in multi-tenant mode it is every schema whose name starts with TENANT_ID_PREFIX.""" tenant_ids: list[str] if not MULTI_TENANT: return [POSTGRES_DEFAULT_SCHEMA] with get_session_with_shared_schema() as session: result = session.execute( text( f""" SELECT schema_name FROM information_schema.schemata WHERE schema_name NOT IN ('pg_catalog', 'information_schema', '{POSTGRES_DEFAULT_SCHEMA}')""" ) ) tenant_ids = [row[0] for row in result] valid_tenants = [ tenant for tenant in tenant_ids if tenant is None or tenant.startswith(TENANT_ID_PREFIX) ] return valid_tenants ================================================ FILE: backend/onyx/db/engine/time_utils.py ================================================ from datetime import datetime from sqlalchemy import text from sqlalchemy.orm import Session def get_db_current_time(db_session: Session) -> datetime: result = db_session.execute(text("SELECT NOW()")).scalar() if result is None: raise ValueError("Database did not return a time") return result ================================================ FILE: backend/onyx/db/entities.py ================================================ import uuid from datetime import datetime from datetime import timezone from typing import List from sqlalchemy import func from sqlalchemy import literal from sqlalchemy import select from sqlalchemy import update from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.orm import Session import onyx.db.document as dbdocument from onyx.db.entity_type import UNGROUNDED_SOURCE_NAME from onyx.db.models import Document from onyx.db.models import KGEntity from onyx.db.models import KGEntityExtractionStaging from onyx.db.models import KGEntityType from 
onyx.kg.models import KGGroundingType from onyx.kg.models import KGStage from onyx.kg.utils.formatting_utils import
make_entity_id def upsert_staging_entity( db_session: Session, name: str, entity_type: str, document_id: str | None = None, occurrences: int = 1, attributes: dict[str, str] | None = None, event_time: datetime | None = None, ) -> KGEntityExtractionStaging: """Insert a staging entity, or add `occurrences` to the existing row's count if an entity with the same id_name already exists. Args: db_session: SQLAlchemy session name: Name of the entity entity_type: Type of the entity (must match an existing KGEntityType) document_id: ID of the document the entity belongs to occurrences: Number of times this entity has been found attributes: Attributes of the entity; the "key" and "parent" entries are stored as entity_key and parent_key rather than as attributes event_time: Time the entity was added to the database Returns: KGEntityExtractionStaging: The created or updated entity """ entity_type = entity_type.upper() name = name.title() id_name = make_entity_id(entity_type, name) attributes = attributes or {} entity_key = attributes.get("key") entity_parent = attributes.get("parent") keep_attributes = { attr_key: attr_val for attr_key, attr_val in attributes.items() if attr_key not in ("key", "parent") } # Insert the entity, or increment its occurrences on an id_name conflict stmt = ( pg_insert(KGEntityExtractionStaging) .values( id_name=id_name, name=name, entity_type_id_name=entity_type, entity_key=entity_key, parent_key=entity_parent, document_id=document_id, occurrences=occurrences, attributes=keep_attributes, event_time=event_time, ) .on_conflict_do_update( index_elements=["id_name"], set_=dict( occurrences=KGEntityExtractionStaging.occurrences + occurrences, ), ) .returning(KGEntityExtractionStaging) ) result = db_session.execute(stmt).scalar() if result is None: raise RuntimeError( f"Failed to create or increment staging entity with id_name: {id_name}" ) # Update the document's kg_stage if document_id is provided if document_id is not None: 
db_session.query(Document).filter(Document.id == document_id).update( { "kg_stage": KGStage.EXTRACTED, "kg_processing_time": datetime.now(timezone.utc), } ) db_session.flush() return result def transfer_entity( db_session: Session, entity:
KGEntityExtractionStaging, ) -> KGEntity: """Transfer an entity from the extraction staging table to the normalized table. Args: db_session: SQLAlchemy session entity: Entity to transfer Returns: KGEntity: The transferred entity """ # Create the transferred entity stmt = ( pg_insert(KGEntity) .values( id_name=make_entity_id(entity.entity_type_id_name, uuid.uuid4().hex[:20]), name=entity.name.casefold(), entity_key=entity.entity_key, parent_key=entity.parent_key, alternative_names=entity.alternative_names or [], entity_type_id_name=entity.entity_type_id_name, document_id=entity.document_id, occurrences=entity.occurrences, attributes=entity.attributes, event_time=entity.event_time, ) .on_conflict_do_update( index_elements=["name", "entity_type_id_name", "document_id"], set_=dict( occurrences=KGEntity.occurrences + entity.occurrences, attributes=KGEntity.attributes.op("||")( literal(entity.attributes, JSONB) ), entity_key=func.coalesce(KGEntity.entity_key, entity.entity_key), parent_key=func.coalesce(KGEntity.parent_key, entity.parent_key), event_time=entity.event_time, time_updated=datetime.now(), ), ) .returning(KGEntity) ) new_entity = db_session.execute(stmt).scalar() if new_entity is None: raise RuntimeError(f"Failed to transfer entity with id_name: {entity.id_name}") # Update the document's kg_stage if document_id is provided if entity.document_id is not None: dbdocument.update_document_kg_info( db_session, document_id=entity.document_id, kg_stage=KGStage.NORMALIZED, ) # Update transferred db_session.query(KGEntityExtractionStaging).filter( KGEntityExtractionStaging.id_name == entity.id_name ).update({"transferred_id_name": new_entity.id_name}) db_session.flush() return new_entity def merge_entities( db_session: Session, parent: KGEntity, child: KGEntityExtractionStaging ) -> KGEntity: """Merge an entity from the extraction staging table into an existing entity in the normalized table. 
Args: db_session: SQLAlchemy session parent: Parent entity to merge into child: Child staging entity to merge Returns: KGEntity: The merged entity """ # check we're not merging two entities with different document_ids if ( parent.document_id is not None and child.document_id is not None and parent.document_id != child.document_id ): raise ValueError( "Overwriting the document_id of an entity with a document_id already is not allowed" ) # update the parent entity (document_id, alternative_names, occurrences, attributes, entity_key, parent_key) setting_doc = parent.document_id is None and child.document_id is not None document_id = child.document_id if setting_doc else parent.document_id alternative_names = set(parent.alternative_names or []) alternative_names.update(child.alternative_names or []) alternative_names.add(child.name.lower()) alternative_names.discard(parent.name) stmt = ( update(KGEntity) .where(KGEntity.id_name == parent.id_name) .values( document_id=document_id, alternative_names=list(alternative_names), occurrences=parent.occurrences + child.occurrences, attributes=parent.attributes | child.attributes, entity_key=parent.entity_key or child.entity_key, parent_key=parent.parent_key or child.parent_key, ) .returning(KGEntity) ) result = db_session.execute(stmt).scalar() if result is None: raise RuntimeError(f"Failed to merge entities with id_name: {parent.id_name}") # Update the document's kg_stage if document_id is set if setting_doc and child.document_id is not None: dbdocument.update_document_kg_info( db_session, document_id=child.document_id, kg_stage=KGStage.NORMALIZED, ) # Update transferred db_session.query(KGEntityExtractionStaging).filter( KGEntityExtractionStaging.id_name == child.id_name ).update({"transferred_id_name": parent.id_name}) db_session.flush() return result def get_kg_entity_by_document(db: Session, document_id: str) -> KGEntity | None: """ Return the KGEntity linked to document_id in the kg_entities table, if one exists. 
Args: db: SQLAlchemy database session document_id: The document ID to search for Returns: The matching KGEntity if found, None otherwise """ query = select(KGEntity).where(KGEntity.document_id == document_id) result = db.execute(query).scalar() return result def get_grounded_entities_by_types( db_session: Session, entity_types: List[str], grounding: KGGroundingType ) -> List[KGEntity]: """Get all entities whose type is in entity_types and whose entity type has the given grounding. Args: db_session: SQLAlchemy session entity_types: List of entity types to filter by grounding: Grounding type the entity types must have Returns: List of KGEntity objects belonging to the specified entity types """ return ( db_session.query(KGEntity) .join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name) .filter(KGEntity.entity_type_id_name.in_(entity_types)) .filter(KGEntityType.grounding == grounding) .all() ) def get_document_id_for_entity(db_session: Session, entity_id_name: str) -> str | None: """Get the document ID associated with an entity. Args: db_session: SQLAlchemy database session entity_id_name: The entity id_name to look up Returns: The document ID if found, None otherwise """ entity = ( db_session.query(KGEntity).filter(KGEntity.id_name == entity_id_name).first() ) return entity.document_id if entity else None def delete_from_kg_entities_extraction_staging__no_commit( db_session: Session, document_ids: list[str] ) -> None: """Delete the given documents' entities from the extraction staging table (no commit).""" db_session.query(KGEntityExtractionStaging).filter( KGEntityExtractionStaging.document_id.in_(document_ids) ).delete(synchronize_session=False) def delete_from_kg_entities__no_commit( db_session: Session, document_ids: list[str] ) -> None: """Delete the given documents' entities from the normalized table (no commit).""" db_session.query(KGEntity).filter(KGEntity.document_id.in_(document_ids)).delete( synchronize_session=False ) def get_entity_name(db_session: Session, 
entity_id_name: str) -> str | None: """Get the name of an entity.""" entity = ( db_session.query(KGEntity).filter(KGEntity.id_name ==
entity_id_name).first() ) return entity.name if entity else None def get_entity_stats_by_grounded_source_name( db_session: Session, ) -> dict[str, tuple[datetime, int]]: """ Returns a dict mapping each grounded_source_name to a tuple in which: - the first element is the latest update time across all entities with the same entity-type - the second element is the count of `KGEntity`s """ results = ( db_session.query( KGEntityType.grounded_source_name, func.count(KGEntity.id_name).label("entities_count"), func.max(KGEntity.time_updated).label("last_updated"), ) .join(KGEntityType, KGEntity.entity_type_id_name == KGEntityType.id_name) .group_by(KGEntityType.grounded_source_name) .all() ) # `row.grounded_source_name` is NULLABLE in the database schema. # Thus, for all "ungrounded" entity-types, we use a default name. return { (row.grounded_source_name or UNGROUNDED_SOURCE_NAME): ( row.last_updated, row.entities_count, ) for row in results } ================================================ FILE: backend/onyx/db/entity_type.py ================================================ from collections import defaultdict from sqlalchemy import update from sqlalchemy.orm import Session from onyx.db.connector import fetch_unique_document_sources from onyx.db.document import DocumentSource from onyx.db.models import Connector from onyx.db.models import KGEntityType from onyx.kg.models import KGAttributeEntityOption from onyx.server.kg.models import EntityType UNGROUNDED_SOURCE_NAME = "Ungrounded" def get_entity_types_with_grounded_source_name( db_session: Session, ) -> list[KGEntityType]: """Get all entity types that have non-null grounded_source_name. 
Args: db_session: SQLAlchemy session Returns: List of KGEntityType objects that have grounded_source_name defined """ return ( db_session.query(KGEntityType) .filter(KGEntityType.grounded_source_name.isnot(None)) .all() ) def get_entity_types( db_session: Session, active: bool | None = True, ) -> list[KGEntityType]: # Query the database for all distinct entity types if active is None: return db_session.query(KGEntityType).order_by(KGEntityType.id_name).all() else: return ( db_session.query(KGEntityType) .filter(KGEntityType.active == active) .order_by(KGEntityType.id_name) .all() ) def get_configured_entity_types(db_session: Session) -> dict[str, list[KGEntityType]]: # get entity types from configured sources configured_connector_sources = { source.value.lower() for source in fetch_unique_document_sources(db_session=db_session) } entity_types = ( db_session.query(KGEntityType) .filter(KGEntityType.grounded_source_name.in_(configured_connector_sources)) .all() ) entity_type_set = {et.id_name for et in entity_types} # get implied entity types from those entity types for et in entity_types: for prop in et.parsed_attributes.metadata_attribute_conversion.values(): if prop.implication_property is None: continue implied_et = prop.implication_property.implied_entity_type if implied_et == KGAttributeEntityOption.FROM_EMAIL: if "ACCOUNT" not in entity_type_set: entity_type_set.add("ACCOUNT") if "EMPLOYEE" not in entity_type_set: entity_type_set.add("EMPLOYEE") elif isinstance(implied_et, str): if implied_et not in entity_type_set: entity_type_set.add(implied_et) ets = ( db_session.query(KGEntityType) .filter(KGEntityType.id_name.in_(entity_type_set)) .all() ) et_map = defaultdict(list) for et in ets: key = et.grounded_source_name or UNGROUNDED_SOURCE_NAME et_map[key].append(et) return et_map def update_entity_types_and_related_connectors__commit( db_session: Session, updates: list[EntityType] ) -> None: for upd in updates: db_session.execute( update(KGEntityType) 
.where(KGEntityType.id_name == upd.name) .values( description=upd.description, active=upd.active, ) ) db_session.flush() # Update connector sources configured_entity_types = get_configured_entity_types(db_session=db_session) active_entity_type_sources = { et.grounded_source_name for ets in configured_entity_types.values() for et in ets if et.active } # Update connectors that should be enabled db_session.execute( update(Connector) .where( Connector.source.in_( [ source for source in DocumentSource if source.value.lower() in active_entity_type_sources ] ) ) .where(~Connector.kg_processing_enabled) .values(kg_processing_enabled=True) ) # Update connectors that should be disabled db_session.execute( update(Connector) .where( Connector.source.in_( [ source for source in DocumentSource if source.value.lower() not in active_entity_type_sources ] ) ) .where(Connector.kg_processing_enabled) .values(kg_processing_enabled=False) ) db_session.commit() ================================================ FILE: backend/onyx/db/enums.py ================================================ from __future__ import annotations from enum import Enum as PyEnum from typing import ClassVar class AccountType(str, PyEnum): """ What kind of account this is — determines whether the user enters the group-based permission system. 
STANDARD + SERVICE_ACCOUNT → participate in group system BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior """ STANDARD = "STANDARD" BOT = "BOT" EXT_PERM_USER = "EXT_PERM_USER" SERVICE_ACCOUNT = "SERVICE_ACCOUNT" ANONYMOUS = "ANONYMOUS" def is_web_login(self) -> bool: """Whether this account type supports interactive web login.""" return self not in ( AccountType.BOT, AccountType.EXT_PERM_USER, ) class GrantSource(str, PyEnum): """How a permission grant was created.""" USER = "USER" SCIM = "SCIM" SYSTEM = "SYSTEM" class IndexingStatus(str, PyEnum): NOT_STARTED = "not_started" IN_PROGRESS = "in_progress" SUCCESS = "success" CANCELED = "canceled" FAILED = "failed" COMPLETED_WITH_ERRORS = "completed_with_errors" def is_terminal(self) -> bool: terminal_states = { IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS, IndexingStatus.CANCELED, IndexingStatus.FAILED, } return self in terminal_states def is_successful(self) -> bool: return ( self == IndexingStatus.SUCCESS or self == IndexingStatus.COMPLETED_WITH_ERRORS ) class PermissionSyncStatus(str, PyEnum): """Status enum for permission sync attempts""" NOT_STARTED = "not_started" IN_PROGRESS = "in_progress" SUCCESS = "success" CANCELED = "canceled" FAILED = "failed" COMPLETED_WITH_ERRORS = "completed_with_errors" def is_terminal(self) -> bool: terminal_states = { PermissionSyncStatus.SUCCESS, PermissionSyncStatus.COMPLETED_WITH_ERRORS, PermissionSyncStatus.CANCELED, PermissionSyncStatus.FAILED, } return self in terminal_states def is_successful(self) -> bool: return ( self == PermissionSyncStatus.SUCCESS or self == PermissionSyncStatus.COMPLETED_WITH_ERRORS ) class IndexingMode(str, PyEnum): UPDATE = "update" REINDEX = "reindex" class ProcessingMode(str, PyEnum): """Determines how documents are processed after fetching.""" REGULAR = "REGULAR" # Full pipeline: chunk → embed → Vespa FILE_SYSTEM = "FILE_SYSTEM" # Write to file system only (JSON documents) RAW_BINARY = "RAW_BINARY" # Write raw binary to S3 (no text 
extraction) class SyncType(str, PyEnum): DOCUMENT_SET = "document_set" USER_GROUP = "user_group" CONNECTOR_DELETION = "connector_deletion" PRUNING = "pruning" # not really a sync, but close enough EXTERNAL_PERMISSIONS = "external_permissions" EXTERNAL_GROUP = "external_group" def __str__(self) -> str: return self.value class SyncStatus(str, PyEnum): IN_PROGRESS = "in_progress" SUCCESS = "success" FAILED = "failed" CANCELED = "canceled" def is_terminal(self) -> bool: terminal_states = { SyncStatus.SUCCESS, SyncStatus.FAILED, } return self in terminal_states class MCPAuthenticationType(str, PyEnum): NONE = "NONE" API_TOKEN = "API_TOKEN" OAUTH = "OAUTH" PT_OAUTH = "PT_OAUTH" # Pass-Through OAuth class MCPTransport(str, PyEnum): """MCP transport types""" STDIO = "STDIO" # TODO: currently unsupported, need to add a user guide for setup SSE = "SSE" # Server-Sent Events (deprecated but still used) STREAMABLE_HTTP = "STREAMABLE_HTTP" # Modern HTTP streaming class MCPAuthenticationPerformer(str, PyEnum): ADMIN = "ADMIN" PER_USER = "PER_USER" class MCPServerStatus(str, PyEnum): CREATED = "CREATED" # Server created, needs auth configuration AWAITING_AUTH = "AWAITING_AUTH" # Auth configured, pending user authentication FETCHING_TOOLS = "FETCHING_TOOLS" # Auth complete, fetching tools CONNECTED = "CONNECTED" # Fully configured and connected DISCONNECTED = "DISCONNECTED" # Server disconnected, but not deleted # Consistent with Celery task statuses class TaskStatus(str, PyEnum): PENDING = "PENDING" STARTED = "STARTED" SUCCESS = "SUCCESS" FAILURE = "FAILURE" class IndexModelStatus(str, PyEnum): PAST = "PAST" PRESENT = "PRESENT" FUTURE = "FUTURE" def is_current(self) -> bool: return self == IndexModelStatus.PRESENT def is_future(self) -> bool: return self == IndexModelStatus.FUTURE class ChatSessionSharedStatus(str, PyEnum): PUBLIC = "public" PRIVATE = "private" class ConnectorCredentialPairStatus(str, PyEnum): SCHEDULED = "SCHEDULED" INITIAL_INDEXING = "INITIAL_INDEXING" ACTIVE = 
"ACTIVE" PAUSED = "PAUSED" DELETING = "DELETING" INVALID = "INVALID" @classmethod def active_statuses(cls) -> list["ConnectorCredentialPairStatus"]: return [ ConnectorCredentialPairStatus.ACTIVE, ConnectorCredentialPairStatus.SCHEDULED, ConnectorCredentialPairStatus.INITIAL_INDEXING, ] @classmethod def indexable_statuses(cls) -> list["ConnectorCredentialPairStatus"]: # Superset of active statuses for indexing model swaps return cls.active_statuses() + [ ConnectorCredentialPairStatus.PAUSED, ] def is_active(self) -> bool: return self in self.active_statuses() class AccessType(str, PyEnum): PUBLIC = "public" PRIVATE = "private" SYNC = "sync" class EmbeddingPrecision(str, PyEnum): # matches vespa tensor type # only support float / bfloat16 for now, since there's not a # good reason to specify anything else BFLOAT16 = "bfloat16" FLOAT = "float" class UserFileStatus(str, PyEnum): PROCESSING = "PROCESSING" INDEXING = "INDEXING" COMPLETED = "COMPLETED" SKIPPED = "SKIPPED" FAILED = "FAILED" CANCELED = "CANCELED" DELETING = "DELETING" class ThemePreference(str, PyEnum): LIGHT = "light" DARK = "dark" SYSTEM = "system" class DefaultAppMode(str, PyEnum): AUTO = "AUTO" CHAT = "CHAT" SEARCH = "SEARCH" class SwitchoverType(str, PyEnum): REINDEX = "reindex" ACTIVE_ONLY = "active_only" INSTANT = "instant" class OpenSearchDocumentMigrationStatus(str, PyEnum): """Status for Vespa to OpenSearch migration per document.""" PENDING = "pending" COMPLETED = "completed" FAILED = "failed" PERMANENTLY_FAILED = "permanently_failed" class OpenSearchTenantMigrationStatus(str, PyEnum): """Status for tenant-level OpenSearch migration.""" PENDING = "pending" COMPLETED = "completed" # Onyx Build Mode Enums class BuildSessionStatus(str, PyEnum): ACTIVE = "active" IDLE = "idle" class SharingScope(str, PyEnum): PRIVATE = "private" PUBLIC_ORG = "public_org" PUBLIC_GLOBAL = "public_global" class SandboxStatus(str, PyEnum): PROVISIONING = "provisioning" RUNNING = "running" SLEEPING = 
"sleeping" # Pod
terminated, snapshots saved to S3 TERMINATED = "terminated" FAILED = "failed" def is_active(self) -> bool: """Check if sandbox is in an active state (running).""" return self == SandboxStatus.RUNNING def is_terminal(self) -> bool: """Check if sandbox is in a terminal state.""" return self in (SandboxStatus.TERMINATED, SandboxStatus.FAILED) def is_sleeping(self) -> bool: """Check if sandbox is sleeping (pod terminated but can be restored).""" return self == SandboxStatus.SLEEPING class ArtifactType(str, PyEnum): WEB_APP = "web_app" PPTX = "pptx" DOCX = "docx" IMAGE = "image" MARKDOWN = "markdown" EXCEL = "excel" class HierarchyNodeType(str, PyEnum): """Types of hierarchy nodes across different sources""" # Generic FOLDER = "folder" # Root-level type SOURCE = "source" # Root node for a source (e.g., "Google Drive") # Google Drive SHARED_DRIVE = "shared_drive" MY_DRIVE = "my_drive" # Confluence SPACE = "space" PAGE = "page" # Confluence pages can be both hierarchy nodes AND documents # Jira PROJECT = "project" # Notion DATABASE = "database" WORKSPACE = "workspace" # Sharepoint SITE = "site" DRIVE = "drive" # Document library within a site # Slack CHANNEL = "channel" class LLMModelFlowType(str, PyEnum): CHAT = "chat" VISION = "vision" CONTEXTUAL_RAG = "contextual_rag" class HookPoint(str, PyEnum): DOCUMENT_INGESTION = "document_ingestion" QUERY_PROCESSING = "query_processing" class HookFailStrategy(str, PyEnum): HARD = "hard" # exception propagates, pipeline aborts SOFT = "soft" # log error, return original input, pipeline continues class Permission(str, PyEnum): """ Permission tokens for group-based authorization. 19 tokens total. full_admin_panel_access is an override — if present, any permission check passes. 
""" # Basic (auto-granted to every new group) BASIC_ACCESS = "basic" # Read tokens — implied only, never granted directly READ_CONNECTORS = "read:connectors" READ_DOCUMENT_SETS = "read:document_sets" READ_AGENTS = "read:agents" READ_USERS = "read:users" # Add / Manage pairs ADD_AGENTS = "add:agents" MANAGE_AGENTS = "manage:agents" MANAGE_DOCUMENT_SETS = "manage:document_sets" ADD_CONNECTORS = "add:connectors" MANAGE_CONNECTORS = "manage:connectors" MANAGE_LLMS = "manage:llms" # Toggle tokens READ_AGENT_ANALYTICS = "read:agent_analytics" MANAGE_ACTIONS = "manage:actions" READ_QUERY_HISTORY = "read:query_history" MANAGE_USER_GROUPS = "manage:user_groups" CREATE_USER_API_KEYS = "create:user_api_keys" CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys" CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots" # Override — any permission check passes FULL_ADMIN_PANEL_ACCESS = "admin" # Permissions that are implied by other grants and must never be stored # directly in the permission_grant table. 
IMPLIED: ClassVar[frozenset[Permission]] Permission.IMPLIED = frozenset( { Permission.READ_CONNECTORS, Permission.READ_DOCUMENT_SETS, Permission.READ_AGENTS, Permission.READ_USERS, } ) ================================================ FILE: backend/onyx/db/federated.py ================================================ from datetime import datetime from typing import Any from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import joinedload from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.configs.constants import FederatedConnectorSource from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import DocumentSet from onyx.db.models import FederatedConnector from onyx.db.models import FederatedConnector__DocumentSet from onyx.db.models import FederatedConnectorOAuthToken from onyx.federated_connectors.factory import get_federated_connector from onyx.utils.logger import setup_logger logger = setup_logger() def fetch_federated_connector_by_id( federated_connector_id: int, db_session: Session ) -> FederatedConnector | None: """Fetch a federated connector by its ID.""" stmt = select(FederatedConnector).where( FederatedConnector.id == federated_connector_id ) result = db_session.execute(stmt) return result.scalar_one_or_none() def fetch_all_federated_connectors(db_session: Session) -> list[FederatedConnector]: """Fetch all federated connectors with their OAuth tokens and document sets.""" stmt = select(FederatedConnector).options( selectinload(FederatedConnector.oauth_tokens), selectinload(FederatedConnector.document_sets), ) result = db_session.execute(stmt) return list(result.scalars().all()) def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]: with get_session_with_current_tenant() as db_session: return fetch_all_federated_connectors(db_session) def validate_federated_connector_credentials( source: FederatedConnectorSource, credentials: dict[str, Any], ) -> bool: 
"""Validate credentials for a federated connector using the connector's validation logic.""" try: # the initialization will fail if the credentials are invalid get_federated_connector(source, credentials) return True except Exception as e: logger.error(f"Error validating credentials for source {source}: {e}") return False def create_federated_connector( db_session: Session, source: FederatedConnectorSource, credentials: dict[str, Any], config: dict[str, Any] | None = None, ) -> FederatedConnector: """Create a new federated connector with credential and config validation.""" # Validate credentials before creating if not validate_federated_connector_credentials(source, credentials): raise ValueError( f"Invalid credentials for federated connector source: {source}" ) # Validate config using connector-specific validation if config: try: # Get connector instance to access validate_config method connector = get_federated_connector(source, credentials) if not connector.validate_config(config): raise ValueError( f"Invalid config for federated connector source: {source}" ) except Exception as e: raise ValueError(f"Config validation failed for {source}: {str(e)}") federated_connector = FederatedConnector( source=source, credentials=credentials, config=config or {}, ) db_session.add(federated_connector) db_session.commit() return federated_connector def update_federated_connector_oauth_token( db_session: Session, federated_connector_id: int, user_id: UUID, token: str, expires_at: datetime | None = None, ) -> FederatedConnectorOAuthToken: """Update or create OAuth token for a federated connector and user.""" # First, try to find existing token for this user and connector stmt = select(FederatedConnectorOAuthToken).where( FederatedConnectorOAuthToken.federated_connector_id == federated_connector_id, FederatedConnectorOAuthToken.user_id == user_id, ) existing_token = db_session.execute(stmt).scalar_one_or_none() if existing_token: # Update existing token existing_token.token = 
token # type: ignore[assignment] existing_token.expires_at = expires_at db_session.commit() return existing_token else: # Create new token oauth_token = FederatedConnectorOAuthToken( federated_connector_id=federated_connector_id, user_id=user_id, token=token, expires_at=expires_at, ) db_session.add(oauth_token) db_session.commit() return oauth_token def get_federated_connector_oauth_token( db_session: Session, federated_connector_id: int, user_id: UUID, ) -> FederatedConnectorOAuthToken | None: """Get OAuth token for a federated connector and user.""" stmt = select(FederatedConnectorOAuthToken).where( FederatedConnectorOAuthToken.federated_connector_id == federated_connector_id, FederatedConnectorOAuthToken.user_id == user_id, ) result = db_session.execute(stmt) return result.scalar_one_or_none() def list_federated_connector_oauth_tokens( db_session: Session, user_id: UUID, ) -> list[FederatedConnectorOAuthToken]: """List a user's OAuth tokens across all federated connectors, with each token's federated connector eagerly loaded.""" stmt = ( select(FederatedConnectorOAuthToken) .where( FederatedConnectorOAuthToken.user_id == user_id, ) .options( joinedload(FederatedConnectorOAuthToken.federated_connector), ) ) result = db_session.scalars(stmt) return list(result) def create_federated_connector_document_set_mapping( db_session: Session, federated_connector_id: int, document_set_id: int, entities: dict[str, Any], ) -> FederatedConnector__DocumentSet: """Create a mapping between federated connector and document set with entities.""" mapping = FederatedConnector__DocumentSet( federated_connector_id=federated_connector_id, document_set_id=document_set_id, entities=entities, ) db_session.add(mapping) db_session.commit() return mapping def update_federated_connector_document_set_entities( db_session: Session, federated_connector_id: int, document_set_id: int, entities: dict[str, Any], ) -> FederatedConnector__DocumentSet | None: """Update entities for a federated 
connector document set mapping.""" stmt =
select(FederatedConnector__DocumentSet).where( FederatedConnector__DocumentSet.federated_connector_id == federated_connector_id, FederatedConnector__DocumentSet.document_set_id == document_set_id, ) mapping = db_session.execute(stmt).scalar_one_or_none() if mapping: mapping.entities = entities db_session.commit() return mapping return None def get_federated_connector_document_set_mappings( db_session: Session, federated_connector_id: int, ) -> list[FederatedConnector__DocumentSet]: """Get all document set mappings for a federated connector.""" stmt = select(FederatedConnector__DocumentSet).where( FederatedConnector__DocumentSet.federated_connector_id == federated_connector_id ) result = db_session.execute(stmt) return list(result.scalars().all()) def delete_federated_connector_document_set_mapping( db_session: Session, federated_connector_id: int, document_set_id: int, ) -> bool: """Delete a federated connector document set mapping.""" stmt = select(FederatedConnector__DocumentSet).where( FederatedConnector__DocumentSet.federated_connector_id == federated_connector_id, FederatedConnector__DocumentSet.document_set_id == document_set_id, ) mapping = db_session.execute(stmt).scalar_one_or_none() if mapping: db_session.delete(mapping) db_session.commit() return True return False def get_federated_connector_document_set_mappings_by_document_set_names( db_session: Session, document_set_names: list[str], ) -> list[FederatedConnector__DocumentSet]: """Get the federated connector mappings (across all federated connectors) for the given document set names.""" stmt = ( select(FederatedConnector__DocumentSet) .join( DocumentSet, FederatedConnector__DocumentSet.document_set_id == DocumentSet.id, ) .options(joinedload(FederatedConnector__DocumentSet.federated_connector)) .where(DocumentSet.name.in_(document_set_names)) ) result = db_session.scalars(stmt) # Use unique() because joinedload can cause duplicate rows return list(result.unique()) def update_federated_connector( 
db_session: Session,
federated_connector_id: int, credentials: dict[str, Any] | None = None, config: dict[str, Any] | None = None, ) -> FederatedConnector | None: """Update a federated connector with credential and config validation.""" federated_connector = fetch_federated_connector_by_id( federated_connector_id, db_session ) if not federated_connector: return None # Use provided credentials if updating them, otherwise use existing credentials # This is needed to instantiate the connector for config validation when only config is being updated creds_to_use = ( credentials if credentials is not None else ( federated_connector.credentials.get_value(apply_mask=False) if federated_connector.credentials else {} ) ) if credentials is not None: # Validate credentials before updating if not validate_federated_connector_credentials( federated_connector.source, credentials ): raise ValueError( f"Invalid credentials for federated connector source: {federated_connector.source}" ) federated_connector.credentials = credentials # type: ignore[assignment] if config is not None: # Validate config using connector-specific validation try: # Get connector instance to access validate_config method connector = get_federated_connector( federated_connector.source, creds_to_use ) if not connector.validate_config(config): raise ValueError( f"Invalid config for federated connector source: {federated_connector.source}" ) except Exception as e: raise ValueError( f"Config validation failed for {federated_connector.source}: {str(e)}" ) federated_connector.config = config db_session.commit() return federated_connector def delete_federated_connector( db_session: Session, federated_connector_id: int, ) -> bool: """Delete a federated connector and all its related data.""" federated_connector = fetch_federated_connector_by_id( federated_connector_id, db_session ) if not federated_connector: return False # Delete related OAuth tokens (cascade should handle this) # Delete related document set mappings (cascade should 
handle this) db_session.delete(federated_connector) db_session.commit() return True ================================================ FILE: backend/onyx/db/feedback.py ================================================ from datetime import datetime from datetime import timezone from uuid import UUID from fastapi import HTTPException from sqlalchemy import and_ from sqlalchemy import asc from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import exists from sqlalchemy import Select from sqlalchemy import select from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from onyx.configs.constants import MessageType from onyx.configs.constants import SearchFeedbackType from onyx.db.chat import get_chat_message from onyx.db.enums import AccessType from onyx.db.models import ChatMessageFeedback from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Document as DbDocument from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import DocumentRetrievalFeedback from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup__ConnectorCredentialPair from onyx.db.models import UserRole from onyx.utils.logger import setup_logger logger = setup_logger() def _fetch_db_doc_by_id(doc_id: str, db_session: Session) -> DbDocument: stmt = select(DbDocument).where(DbDocument.id == doc_id) result = db_session.execute(stmt) doc = result.scalar_one_or_none() if not doc: raise ValueError("Invalid Document ID Provided") return doc def _add_user_filters(stmt: Select, user: User, get_editable: bool = True) -> Select: if user.role == UserRole.ADMIN: return stmt stmt = stmt.distinct() DocByCC = aliased(DocumentByConnectorCredentialPair) CCPair = aliased(ConnectorCredentialPair) UG__CCpair = aliased(UserGroup__ConnectorCredentialPair) User__UG = aliased(User__UserGroup) """ Here we select documents by relation: User -> User__UserGroup -> 
    UserGroup__ConnectorCredentialPair -> ConnectorCredentialPair ->
    DocumentByConnectorCredentialPair -> Document
    """
    stmt = (
        stmt.outerjoin(DocByCC, DocByCC.id == DbDocument.id)
        .outerjoin(
            CCPair,
            and_(
                CCPair.connector_id == DocByCC.connector_id,
                CCPair.credential_id == DocByCC.credential_id,
            ),
        )
        .outerjoin(UG__CCpair, UG__CCpair.cc_pair_id == CCPair.id)
        .outerjoin(User__UG, User__UG.user_group_id == UG__CCpair.user_group_id)
    )
    """
    Filter Documents by:
    - if the user is in the user_group that owns the object
    - if the user is not a global_curator, they must also have a curator relationship
    to the user_group
    - if editing is being done, we also filter out objects that are owned by groups
    that the user isn't a curator for
    - if we are not editing, we show all objects in the groups the user is a curator
    for (as well as public objects)
    """
    # Anonymous users only see public documents
    if user.is_anonymous:
        where_clause = CCPair.access_type == AccessType.PUBLIC
        return stmt.where(where_clause)
    where_clause = User__UG.user_id == user.id
    if user.role == UserRole.CURATOR and get_editable:
        where_clause &= User__UG.is_curator == True  # noqa: E712
    if get_editable:
        user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id)
        where_clause &= (
            ~exists()
            .where(UG__CCpair.cc_pair_id == CCPair.id)
            .where(~UG__CCpair.user_group_id.in_(user_groups))
            .correlate(CCPair)
        )
    else:
        where_clause |= CCPair.access_type == AccessType.PUBLIC
    return stmt.where(where_clause)


def fetch_docs_ranked_by_boost_for_user(
    db_session: Session,
    user: User,
    ascending: bool = False,
    limit: int = 100,
) -> list[DbDocument]:
    order_func = asc if ascending else desc
    stmt = select(DbDocument)
    stmt = _add_user_filters(stmt=stmt, user=user, get_editable=False)
    stmt = stmt.order_by(
        order_func(DbDocument.boost), order_func(DbDocument.semantic_id)
    )
    stmt = stmt.limit(limit)
    result = db_session.execute(stmt)
    doc_list = result.scalars().all()
    return list(doc_list)


def update_document_boost_for_user(
    db_session: Session,
    document_id: str,
    boost: int,
    user: User,
) -> None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    stmt = _add_user_filters(stmt, user, get_editable=True)
    result: DbDocument | None = db_session.execute(stmt).scalar_one_or_none()
    if result is None:
        raise HTTPException(
            status_code=400, detail="Document is not editable by this user"
        )
    result.boost = boost
    # updating last_modified triggers sync
    # TODO: Should this submit to the queue directly so that the UI can update?
    result.last_modified = datetime.now(timezone.utc)
    db_session.commit()


def update_document_hidden_for_user(
    db_session: Session,
    document_id: str,
    hidden: bool,
    user: User,
) -> None:
    stmt = select(DbDocument).where(DbDocument.id == document_id)
    stmt = _add_user_filters(stmt, user, get_editable=True)
    result = db_session.execute(stmt).scalar_one_or_none()
    if result is None:
        raise HTTPException(
            status_code=400, detail="Document is not editable by this user"
        )
    result.hidden = hidden
    # updating last_modified triggers sync
    # TODO: Should this submit to the queue directly so that the UI can update?
    result.last_modified = datetime.now(timezone.utc)
    db_session.commit()


def create_doc_retrieval_feedback(
    message_id: int,
    document_id: str,
    document_rank: int,
    db_session: Session,
    clicked: bool = False,
    feedback: SearchFeedbackType | None = None,
) -> None:
    """Creates a new Document feedback row and updates the boost value in Postgres and Vespa"""
    db_doc = _fetch_db_doc_by_id(document_id, db_session)
    retrieval_feedback = DocumentRetrievalFeedback(
        chat_message_id=message_id,
        document_id=document_id,
        document_rank=document_rank,
        clicked=clicked,
        feedback=feedback,
    )
    if feedback is not None:
        if feedback == SearchFeedbackType.ENDORSE:
            db_doc.boost += 1
        elif feedback == SearchFeedbackType.REJECT:
            db_doc.boost -= 1
        elif feedback == SearchFeedbackType.HIDE:
            db_doc.hidden = True
        elif feedback == SearchFeedbackType.UNHIDE:
            db_doc.hidden = False
        else:
            raise ValueError("Unhandled document feedback type")
    if feedback in [
        SearchFeedbackType.ENDORSE,
        SearchFeedbackType.REJECT,
        SearchFeedbackType.HIDE,
    ]:
        # updating last_modified triggers sync
        # TODO: Should this submit to the queue directly so that the UI can update?
        db_doc.last_modified = datetime.now(timezone.utc)
    db_session.add(retrieval_feedback)
    db_session.commit()


def delete_document_feedback_for_documents__no_commit(
    document_ids: list[str], db_session: Session
) -> None:
    """NOTE: does not commit transaction so that this can be used as part of a larger transaction block."""
    stmt = delete(DocumentRetrievalFeedback).where(
        DocumentRetrievalFeedback.document_id.in_(document_ids)
    )
    db_session.execute(stmt)


def create_chat_message_feedback(
    is_positive: bool | None,
    feedback_text: str | None,
    chat_message_id: int,
    user_id: UUID | None,
    db_session: Session,
    # Slack user requested help from human
    required_followup: bool | None = None,
    predefined_feedback: str | None = None,  # Added predefined_feedback parameter
) -> None:
    if (
        is_positive is None
        and feedback_text is None
        and required_followup is None
        and predefined_feedback is None
    ):
        raise ValueError("No feedback provided")
    chat_message = get_chat_message(
        chat_message_id=chat_message_id, user_id=user_id, db_session=db_session
    )
    if chat_message.message_type != MessageType.ASSISTANT:
        raise ValueError("Can only provide feedback on LLM Outputs")
    message_feedback = ChatMessageFeedback(
        chat_message_id=chat_message_id,
        is_positive=is_positive,
        feedback_text=feedback_text,
        required_followup=required_followup,
        predefined_feedback=predefined_feedback,
    )
    db_session.add(message_feedback)
    db_session.commit()


def remove_chat_message_feedback(
    chat_message_id: int,
    user_id: UUID | None,
    db_session: Session,
) -> None:
    """Remove all feedback for a chat message."""
    chat_message = get_chat_message(
        chat_message_id=chat_message_id, user_id=user_id, db_session=db_session
    )
    if chat_message.message_type != MessageType.ASSISTANT:
        raise ValueError("Can only remove feedback from LLM Outputs")
    # Delete all feedback for this message
    db_session.query(ChatMessageFeedback).filter(
        ChatMessageFeedback.chat_message_id == chat_message_id
    ).delete()
    db_session.commit()
================================================
FILE: backend/onyx/db/file_content.py
================================================
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from onyx.db.models import FileContent


def get_file_content_by_file_id(
    file_id: str,
    db_session: Session,
) -> FileContent:
    record = db_session.query(FileContent).filter_by(file_id=file_id).first()
    if not record:
        raise RuntimeError(
            f"File content for file_id {file_id} does not exist or was deleted"
        )
    return record


def get_file_content_by_file_id_optional(
    file_id: str,
    db_session: Session,
) -> FileContent | None:
    return db_session.query(FileContent).filter_by(file_id=file_id).first()


def upsert_file_content(
    file_id: str,
    lobj_oid: int,
    file_size: int,
    db_session: Session,
) -> FileContent:
    """Atomic upsert using INSERT ... ON CONFLICT DO UPDATE to avoid
    race conditions when concurrent calls target the same file_id."""
    stmt = insert(FileContent).values(
        file_id=file_id,
        lobj_oid=lobj_oid,
        file_size=file_size,
    )
    stmt = stmt.on_conflict_do_update(
        index_elements=[FileContent.file_id],
        set_={
            "lobj_oid": stmt.excluded.lobj_oid,
            "file_size": stmt.excluded.file_size,
        },
    )
    db_session.execute(stmt)
    # Return the merged ORM instance so callers can inspect the result
    return db_session.get(FileContent, file_id)  # type: ignore[return-value]


def transfer_file_content_file_id(
    old_file_id: str,
    new_file_id: str,
    db_session: Session,
) -> None:
    """Move a file_content row from old_file_id to new_file_id in-place.

    This avoids creating a duplicate row that shares the same Large Object OID,
    keeping OID ownership unique at all times.
    The caller must ensure that new_file_id already exists in file_record
    (FK target)."""
    rows = (
        db_session.query(FileContent)
        .filter_by(file_id=old_file_id)
        .update({"file_id": new_file_id})
    )
    if not rows:
        raise RuntimeError(
            f"File content for file_id {old_file_id} does not exist or was deleted"
        )


def delete_file_content_by_file_id(
    file_id: str,
    db_session: Session,
) -> None:
    db_session.query(FileContent).filter_by(file_id=file_id).delete()


================================================
FILE: backend/onyx/db/file_record.py
================================================
from sqlalchemy import and_
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

from onyx.background.task_utils import QUERY_REPORT_NAME_PREFIX
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import FileType
from onyx.db.models import FileRecord


def get_query_history_export_files(
    db_session: Session,
) -> list[FileRecord]:
    return list(
        db_session.scalars(
            select(FileRecord).where(
                and_(
                    FileRecord.file_id.like(f"{QUERY_REPORT_NAME_PREFIX}-%"),
                    FileRecord.file_type == FileType.CSV,
                    FileRecord.file_origin == FileOrigin.QUERY_HISTORY_CSV,
                )
            )
        )
    )


def get_filerecord_by_file_id_optional(
    file_id: str,
    db_session: Session,
) -> FileRecord | None:
    return db_session.query(FileRecord).filter_by(file_id=file_id).first()


def get_filerecord_by_file_id(
    file_id: str,
    db_session: Session,
) -> FileRecord:
    filestore = db_session.query(FileRecord).filter_by(file_id=file_id).first()
    if not filestore:
        raise RuntimeError(f"File by id {file_id} does not exist or was deleted")
    return filestore


def get_filerecord_by_prefix(
    prefix: str,
    db_session: Session,
) -> list[FileRecord]:
    if not prefix:
        return db_session.query(FileRecord).all()
    return (
        db_session.query(FileRecord).filter(FileRecord.file_id.like(f"{prefix}%")).all()
    )


def delete_filerecord_by_file_id(
    file_id: str,
    db_session: Session,
) -> None:
    db_session.query(FileRecord).filter_by(file_id=file_id).delete()


def upsert_filerecord(
    file_id: str,
    display_name: str,
    file_origin: FileOrigin,
    file_type: str,
    bucket_name: str,
    object_key: str,
    db_session: Session,
    file_metadata: dict | None = None,
) -> FileRecord:
    """Atomic upsert using INSERT ... ON CONFLICT DO UPDATE to avoid
    race conditions when concurrent calls target the same file_id."""
    stmt = insert(FileRecord).values(
        file_id=file_id,
        display_name=display_name,
        file_origin=file_origin,
        file_type=file_type,
        file_metadata=file_metadata,
        bucket_name=bucket_name,
        object_key=object_key,
    )
    stmt = stmt.on_conflict_do_update(
        index_elements=[FileRecord.file_id],
        set_={
            "display_name": stmt.excluded.display_name,
            "file_origin": stmt.excluded.file_origin,
            "file_type": stmt.excluded.file_type,
            "file_metadata": stmt.excluded.file_metadata,
            "bucket_name": stmt.excluded.bucket_name,
            "object_key": stmt.excluded.object_key,
        },
    )
    db_session.execute(stmt)
    return db_session.get(FileRecord, file_id)  # type: ignore[return-value]


================================================
FILE: backend/onyx/db/hierarchy.py
================================================
"""CRUD operations for HierarchyNode."""

from collections import defaultdict

from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource
from onyx.connectors.models import HierarchyNode as PydanticHierarchyNode
from onyx.db.enums import HierarchyNodeType
from onyx.db.models import Document
from onyx.db.models import HierarchyNode
from onyx.db.models import HierarchyNodeByConnectorCredentialPair
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation

logger = setup_logger()

# Sources where hierarchy nodes can also be documents.
# For these sources, pages/items can be both a hierarchy node (with children)
# AND a document with indexed content. For example:
# - Notion: Pages with child pages are hierarchy nodes, but also documents
# - Confluence: Pages can have child pages and also contain content
# Other sources like Google Drive have folders as hierarchy nodes, but folders
# are not documents themselves.
SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS: set[DocumentSource] = {
    DocumentSource.NOTION,
    DocumentSource.CONFLUENCE,
}


def _get_source_display_name(source: DocumentSource) -> str:
    """Get a human-readable display name for a source type."""
    return source.value.replace("_", " ").title()


def get_hierarchy_node_by_raw_id(
    db_session: Session,
    raw_node_id: str,
    source: DocumentSource,
) -> HierarchyNode | None:
    """Get a hierarchy node by its raw ID and source."""
    stmt = select(HierarchyNode).where(
        HierarchyNode.raw_node_id == raw_node_id,
        HierarchyNode.source == source,
    )
    return db_session.execute(stmt).scalar_one_or_none()


def get_source_hierarchy_node(
    db_session: Session,
    source: DocumentSource,
) -> HierarchyNode | None:
    """Get the SOURCE-type root node for a given source."""
    stmt = select(HierarchyNode).where(
        HierarchyNode.source == source,
        HierarchyNode.node_type == HierarchyNodeType.SOURCE,
    )
    return db_session.execute(stmt).scalar_one_or_none()


def ensure_source_node_exists(
    db_session: Session,
    source: DocumentSource,
    commit: bool = True,
) -> HierarchyNode:
    """
    Ensure that a SOURCE-type root node exists for the given source.

    This function is idempotent - it will return the existing SOURCE node if one exists,
    or create a new one if not.

    The SOURCE node is the root of the hierarchy tree for a given source type
    (e.g., "Google Drive", "Confluence"). All other hierarchy nodes for that
    source should ultimately have this node as an ancestor.
    For the SOURCE node:
    - raw_node_id is set to the source name (e.g., "google_drive")
    - parent_id is None (it's the root)
    - display_name is a human-readable version (e.g., "Google Drive")

    Args:
        db_session: SQLAlchemy session
        source: The document source type
        commit: Whether to commit the transaction

    Returns:
        The existing or newly created SOURCE-type HierarchyNode
    """
    # Try to get existing SOURCE node first
    existing_node = get_source_hierarchy_node(db_session, source)
    if existing_node:
        return existing_node
    # Create the SOURCE node
    display_name = _get_source_display_name(source)
    source_node = HierarchyNode(
        raw_node_id=source.value,  # Use source name as raw_node_id
        display_name=display_name,
        link=None,
        source=source,
        node_type=HierarchyNodeType.SOURCE,
        document_id=None,
        parent_id=None,  # SOURCE nodes have no parent
    )
    db_session.add(source_node)
    # Flush to get the ID and detect any race conditions
    try:
        db_session.flush()
    except Exception:
        # Race condition - another worker created it. Roll back and fetch.
        db_session.rollback()
        existing_node = get_source_hierarchy_node(db_session, source)
        if existing_node:
            return existing_node
        # If still not found, re-raise the original exception
        raise
    if commit:
        db_session.commit()
    logger.info(
        f"Created SOURCE hierarchy node for {source.value}: id={source_node.id}, display_name={display_name}"
    )
    return source_node


def resolve_parent_hierarchy_node_id(
    db_session: Session,
    raw_parent_id: str | None,
    source: DocumentSource,
) -> int | None:
    """
    Resolve a raw_parent_id to a database HierarchyNode ID.

    If raw_parent_id is None, returns the SOURCE node ID for backward compatibility.
    If the parent node doesn't exist, returns the SOURCE node ID as fallback.
""" if raw_parent_id is None: # No parent specified - use the SOURCE node source_node = get_source_hierarchy_node(db_session, source) return source_node.id if source_node else None parent_node = get_hierarchy_node_by_raw_id(db_session, raw_parent_id, source) if parent_node: return parent_node.id # Parent not found - fall back to SOURCE node logger.warning( f"Parent hierarchy node not found: raw_id={raw_parent_id}, source={source}. Falling back to SOURCE node." ) source_node = get_source_hierarchy_node(db_session, source) return source_node.id if source_node else None def upsert_parents( db_session: Session, node: PydanticHierarchyNode, source: DocumentSource, node_by_id: dict[str, PydanticHierarchyNode], done_ids: set[str], is_connector_public: bool = False, ) -> None: """ Upsert the parents of a hierarchy node. """ if ( node.node_type == HierarchyNodeType.SOURCE or (node.raw_parent_id not in node_by_id) or (node.raw_parent_id in done_ids) ): return parent_node = node_by_id[node.raw_parent_id] upsert_parents( db_session, parent_node, source, node_by_id, done_ids, is_connector_public=is_connector_public, ) upsert_hierarchy_node( db_session, parent_node, source, commit=False, is_connector_public=is_connector_public, ) done_ids.add(parent_node.raw_node_id) def upsert_hierarchy_node( db_session: Session, node: PydanticHierarchyNode, source: DocumentSource, commit: bool = True, is_connector_public: bool = False, ) -> HierarchyNode: """ Upsert a hierarchy node from a Pydantic model. If a node with the same raw_node_id and source exists, updates it. Otherwise, creates a new node. Args: db_session: SQLAlchemy session node: The Pydantic hierarchy node to upsert source: Document source type commit: Whether to commit the transaction is_connector_public: If True, the connector is public (organization-wide access) and all hierarchy nodes should be marked as public regardless of their external_access settings. 
            This ensures nodes from public connectors are accessible to all users.
    """
    # Resolve parent_id from raw_parent_id
    parent_id = (
        None
        if node.node_type == HierarchyNodeType.SOURCE
        else resolve_parent_hierarchy_node_id(db_session, node.raw_parent_id, source)
    )
    # For public connectors, all nodes are public
    # Otherwise, extract permission fields from external_access if present
    if is_connector_public:
        is_public = True
        external_user_emails: list[str] | None = None
        external_user_group_ids: list[str] | None = None
    elif node.external_access:
        is_public = node.external_access.is_public
        external_user_emails = (
            list(node.external_access.external_user_emails)
            if node.external_access.external_user_emails
            else None
        )
        external_user_group_ids = (
            list(node.external_access.external_user_group_ids)
            if node.external_access.external_user_group_ids
            else None
        )
    else:
        is_public = False
        external_user_emails = None
        external_user_group_ids = None
    # Check if node already exists
    existing_node = get_hierarchy_node_by_raw_id(db_session, node.raw_node_id, source)
    if existing_node:
        # Update existing node
        existing_node.display_name = node.display_name
        existing_node.link = node.link
        existing_node.node_type = node.node_type
        existing_node.parent_id = parent_id
        # Update permission fields
        existing_node.is_public = is_public
        existing_node.external_user_emails = external_user_emails
        existing_node.external_user_group_ids = external_user_group_ids
        hierarchy_node = existing_node
    else:
        # Create new node
        hierarchy_node = HierarchyNode(
            raw_node_id=node.raw_node_id,
            display_name=node.display_name,
            link=node.link,
            source=source,
            node_type=node.node_type,
            parent_id=parent_id,
            is_public=is_public,
            external_user_emails=external_user_emails,
            external_user_group_ids=external_user_group_ids,
        )
        db_session.add(hierarchy_node)
    if commit:
        db_session.commit()
    else:
        db_session.flush()
    return hierarchy_node


def upsert_hierarchy_nodes_batch(
    db_session: Session,
    nodes: list[PydanticHierarchyNode],
    source: DocumentSource,
    commit: bool = True,
    is_connector_public: bool = False,
) -> list[HierarchyNode]:
    """
    Batch upsert hierarchy nodes.

    Note: This function requires that for each node passed in, all
    its ancestors exist in either the database or elsewhere in the nodes list.
    This function handles parent dependencies for you as long as that condition is met
    (so you don't need to worry about parent nodes appearing before their children in the list).

    Args:
        db_session: SQLAlchemy session
        nodes: List of Pydantic hierarchy nodes to upsert
        source: Document source type
        commit: Whether to commit the transaction
        is_connector_public: If True, the connector is public (organization-wide access)
            and all hierarchy nodes should be marked as public regardless of their
            external_access settings.
    """
    node_by_id = {}
    for node in nodes:
        if node.node_type != HierarchyNodeType.SOURCE:
            node_by_id[node.raw_node_id] = node
    done_ids = set[str]()
    results = []
    for node in nodes:
        if node.raw_node_id in done_ids:
            continue
        upsert_parents(
            db_session,
            node,
            source,
            node_by_id,
            done_ids,
            is_connector_public=is_connector_public,
        )
        hierarchy_node = upsert_hierarchy_node(
            db_session,
            node,
            source,
            commit=False,
            is_connector_public=is_connector_public,
        )
        done_ids.add(node.raw_node_id)
        results.append(hierarchy_node)
    if commit:
        db_session.commit()
    return results


def link_hierarchy_nodes_to_documents(
    db_session: Session,
    document_ids: list[str],
    source: DocumentSource,
    commit: bool = True,
) -> int:
    """
    Link hierarchy nodes to their corresponding documents.

    For connectors like Notion and Confluence where pages can be both hierarchy nodes
    AND documents, we need to set the document_id field on hierarchy nodes after the
    documents are created. This is because hierarchy nodes are processed before documents,
    and the FK constraint on document_id requires the document to exist first.
    Args:
        db_session: SQLAlchemy session
        document_ids: List of document IDs that were just created/updated
        source: The document source (e.g., NOTION, CONFLUENCE)
        commit: Whether to commit the transaction

    Returns:
        Number of hierarchy nodes that were linked to documents
    """
    # Skip for sources where hierarchy nodes cannot also be documents
    if source not in SOURCES_WITH_HIERARCHY_NODE_DOCUMENTS:
        return 0
    if not document_ids:
        return 0
    # Find hierarchy nodes where raw_node_id matches a document_id
    # These are pages that are both hierarchy nodes and documents
    stmt = select(HierarchyNode).where(
        HierarchyNode.source == source,
        HierarchyNode.raw_node_id.in_(document_ids),
        HierarchyNode.document_id.is_(None),  # Only update if not already linked
    )
    nodes_to_update = list(db_session.execute(stmt).scalars().all())
    # Update document_id for each matching node
    for node in nodes_to_update:
        node.document_id = node.raw_node_id
    if commit:
        db_session.commit()
    if nodes_to_update:
        logger.debug(
            f"Linked {len(nodes_to_update)} hierarchy nodes to documents for source {source.value}"
        )
    return len(nodes_to_update)


def get_hierarchy_node_children(
    db_session: Session,
    parent_id: int,
    limit: int = 100,
    offset: int = 0,
) -> list[HierarchyNode]:
    """Get children of a hierarchy node, paginated."""
    stmt = (
        select(HierarchyNode)
        .where(HierarchyNode.parent_id == parent_id)
        .order_by(HierarchyNode.display_name)
        .limit(limit)
        .offset(offset)
    )
    return list(db_session.execute(stmt).scalars().all())


def get_hierarchy_node_by_id(
    db_session: Session,
    node_id: int,
) -> HierarchyNode | None:
    """Get a hierarchy node by its database ID."""
    return db_session.get(HierarchyNode, node_id)


def get_root_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
) -> list[HierarchyNode]:
    """Get all root-level hierarchy nodes for a source (children of SOURCE node)."""
    source_node = get_source_hierarchy_node(db_session, source)
    if not source_node:
        return []
    return get_hierarchy_node_children(db_session, source_node.id)


def get_all_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
) -> list[HierarchyNode]:
    """
    Get ALL hierarchy nodes for a given source.

    This is used to populate the Redis cache. Returns all nodes including
    the SOURCE-type root node.

    Args:
        db_session: SQLAlchemy session
        source: The document source to get nodes for

    Returns:
        List of all HierarchyNode objects for the source
    """
    stmt = select(HierarchyNode).where(HierarchyNode.source == source)
    return list(db_session.execute(stmt).scalars().all())


def _get_accessible_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
    user_email: str,  # noqa: ARG001
    external_group_ids: list[str],  # noqa: ARG001
) -> list[HierarchyNode]:
    """
    MIT version: Returns all hierarchy nodes for the source without permission filtering.

    In the MIT version, permission checks are not performed on hierarchy nodes.
    The EE version overrides this to apply permission filtering based on user
    email and external group IDs.

    Args:
        db_session: SQLAlchemy session
        source: Document source type
        user_email: User's email (unused in MIT version)
        external_group_ids: User's external group IDs (unused in MIT version)

    Returns:
        List of all HierarchyNode objects for the source
    """
    stmt = select(HierarchyNode).where(HierarchyNode.source == source)
    stmt = stmt.order_by(HierarchyNode.display_name)
    return list(db_session.execute(stmt).scalars().all())


def get_accessible_hierarchy_nodes_for_source(
    db_session: Session,
    source: DocumentSource,
    user_email: str,
    external_group_ids: list[str],
) -> list[HierarchyNode]:
    """
    Get hierarchy nodes for a source that are accessible to the user.
    Uses fetch_versioned_implementation to get the appropriate version:
    - MIT version: Returns all nodes (no permission filtering)
    - EE version: Filters based on user email and external group IDs
    """
    versioned_fn = fetch_versioned_implementation(
        "onyx.db.hierarchy", "_get_accessible_hierarchy_nodes_for_source"
    )
    return versioned_fn(db_session, source, user_email, external_group_ids)


def get_document_parent_hierarchy_node_ids(
    db_session: Session,
    document_ids: list[str],
) -> dict[str, int | None]:
    """
    Get the parent_hierarchy_node_id for multiple documents in a single query.

    Args:
        db_session: SQLAlchemy session
        document_ids: List of document IDs to look up

    Returns:
        Dict mapping document_id -> parent_hierarchy_node_id (or None if not set)
    """
    if not document_ids:
        return {}
    stmt = select(Document.id, Document.parent_hierarchy_node_id).where(
        Document.id.in_(document_ids)
    )
    results = db_session.execute(stmt).all()
    return {doc_id: parent_id for doc_id, parent_id in results}


def update_document_parent_hierarchy_nodes(
    db_session: Session,
    doc_parent_map: dict[str, int | None],
    commit: bool = True,
) -> int:
    """Bulk-update Document.parent_hierarchy_node_id for multiple documents.

    Only updates rows whose current value differs from the desired value to
    avoid unnecessary writes.
    Args:
        db_session: SQLAlchemy session
        doc_parent_map: Mapping of document_id → desired parent_hierarchy_node_id
        commit: Whether to commit the transaction

    Returns:
        Number of documents actually updated
    """
    if not doc_parent_map:
        return 0
    doc_ids = list(doc_parent_map.keys())
    existing = get_document_parent_hierarchy_node_ids(db_session, doc_ids)
    by_parent: dict[int | None, list[str]] = defaultdict(list)
    for doc_id, desired_parent_id in doc_parent_map.items():
        current = existing.get(doc_id)
        if current == desired_parent_id or doc_id not in existing:
            continue
        by_parent[desired_parent_id].append(doc_id)
    updated = 0
    for desired_parent_id, ids in by_parent.items():
        db_session.query(Document).filter(Document.id.in_(ids)).update(
            {Document.parent_hierarchy_node_id: desired_parent_id},
            synchronize_session=False,
        )
        updated += len(ids)
    if commit:
        db_session.commit()
    elif updated:
        db_session.flush()
    return updated


def update_hierarchy_node_permissions(
    db_session: Session,
    raw_node_id: str,
    source: DocumentSource,
    is_public: bool,
    external_user_emails: list[str] | None,
    external_user_group_ids: list[str] | None,
    commit: bool = True,
) -> bool:
    """
    Update permissions for an existing hierarchy node.

    This is used during permission sync to update folder permissions
    without needing the full Pydantic HierarchyNode model.
    Args:
        db_session: SQLAlchemy session
        raw_node_id: Raw node ID from the source system
        source: Document source type
        is_public: Whether the node is public
        external_user_emails: List of user emails with access
        external_user_group_ids: List of group IDs with access
        commit: Whether to commit the transaction

    Returns:
        True if the node was found and updated, False if not found
    """
    existing_node = get_hierarchy_node_by_raw_id(db_session, raw_node_id, source)
    if not existing_node:
        logger.warning(
            f"Hierarchy node not found for permission update: raw_node_id={raw_node_id}, source={source}"
        )
        return False
    existing_node.is_public = is_public
    existing_node.external_user_emails = external_user_emails
    existing_node.external_user_group_ids = external_user_group_ids
    if commit:
        db_session.commit()
    else:
        db_session.flush()
    return True


def upsert_hierarchy_node_cc_pair_entries(
    db_session: Session,
    hierarchy_node_ids: list[int],
    connector_id: int,
    credential_id: int,
    commit: bool = True,
) -> None:
    """Insert rows into HierarchyNodeByConnectorCredentialPair, ignoring conflicts.

    This records that the given cc_pair "owns" these hierarchy nodes. Used by
    indexing, pruning, and hierarchy-fetching paths.
    """
    if not hierarchy_node_ids:
        return
    _M = HierarchyNodeByConnectorCredentialPair
    stmt = pg_insert(_M).values(
        [
            {
                _M.hierarchy_node_id: node_id,
                _M.connector_id: connector_id,
                _M.credential_id: credential_id,
            }
            for node_id in hierarchy_node_ids
        ]
    )
    stmt = stmt.on_conflict_do_nothing()
    db_session.execute(stmt)
    if commit:
        db_session.commit()
    else:
        db_session.flush()


def remove_stale_hierarchy_node_cc_pair_entries(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    live_hierarchy_node_ids: set[int],
    commit: bool = True,
) -> int:
    """Delete join-table rows for this cc_pair that are NOT in the live set.

    If ``live_hierarchy_node_ids`` is empty ALL rows for the cc_pair are deleted
    (i.e. the connector no longer has any hierarchy nodes).
Callers that want a no-op when there are no live nodes must guard before calling. Returns the number of deleted rows. """ stmt = delete(HierarchyNodeByConnectorCredentialPair).where( HierarchyNodeByConnectorCredentialPair.connector_id == connector_id, HierarchyNodeByConnectorCredentialPair.credential_id == credential_id, ) if live_hierarchy_node_ids: stmt = stmt.where( HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.notin_( live_hierarchy_node_ids ) ) result: CursorResult = db_session.execute(stmt) # type: ignore[assignment] deleted = result.rowcount if commit: db_session.commit() elif deleted: db_session.flush() return deleted def delete_orphaned_hierarchy_nodes( db_session: Session, source: DocumentSource, commit: bool = True, ) -> list[str]: """Delete hierarchy nodes for a source that have zero cc_pair associations. SOURCE-type nodes are excluded (they are synthetic roots). Returns the list of raw_node_ids that were deleted (for cache eviction). """ # Find orphaned nodes: no rows in the join table orphan_stmt = ( select(HierarchyNode.id, HierarchyNode.raw_node_id) .outerjoin( HierarchyNodeByConnectorCredentialPair, HierarchyNode.id == HierarchyNodeByConnectorCredentialPair.hierarchy_node_id, ) .where( HierarchyNode.source == source, HierarchyNode.node_type != HierarchyNodeType.SOURCE, HierarchyNodeByConnectorCredentialPair.hierarchy_node_id.is_(None), ) ) orphans = db_session.execute(orphan_stmt).all() if not orphans: return [] orphan_ids = [row[0] for row in orphans] deleted_raw_ids = [row[1] for row in orphans] db_session.execute(delete(HierarchyNode).where(HierarchyNode.id.in_(orphan_ids))) if commit: db_session.commit() else: db_session.flush() return deleted_raw_ids def reparent_orphaned_hierarchy_nodes( db_session: Session, source: DocumentSource, commit: bool = True, ) -> list[HierarchyNode]: """Re-parent hierarchy nodes whose parent_id is NULL to the SOURCE node. 
After pruning deletes stale nodes, their former children get parent_id=NULL via the SET NULL cascade. This function points them back to the SOURCE root. Returns the reparented HierarchyNode objects (with updated parent_id) so callers can refresh downstream caches. """ source_node = get_source_hierarchy_node(db_session, source) if not source_node: return [] stmt = select(HierarchyNode).where( HierarchyNode.source == source, HierarchyNode.parent_id.is_(None), HierarchyNode.node_type != HierarchyNodeType.SOURCE, ) orphans = list(db_session.execute(stmt).scalars().all()) if not orphans: return [] for node in orphans: node.parent_id = source_node.id if commit: db_session.commit() else: db_session.flush() return orphans ================================================ FILE: backend/onyx/db/hook.py ================================================ import datetime from uuid import UUID from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.engine import CursorResult from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.db.constants import UNSET from onyx.db.constants import UnsetType from onyx.db.enums import HookFailStrategy from onyx.db.enums import HookPoint from onyx.db.models import Hook from onyx.db.models import HookExecutionLog from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError # ── Hook CRUD ──────────────────────────────────────────────────────────── def get_hook_by_id( *, db_session: Session, hook_id: int, include_deleted: bool = False, include_creator: bool = False, ) -> Hook | None: stmt = select(Hook).where(Hook.id == hook_id) if not include_deleted: stmt = stmt.where(Hook.deleted.is_(False)) if include_creator: stmt = stmt.options(selectinload(Hook.creator)) return db_session.scalar(stmt) def get_non_deleted_hook_by_hook_point( *, db_session: Session, hook_point: HookPoint, include_creator: bool = False, ) 
-> Hook | None: stmt = ( select(Hook).where(Hook.hook_point == hook_point).where(Hook.deleted.is_(False)) ) if include_creator: stmt = stmt.options(selectinload(Hook.creator)) return db_session.scalar(stmt) def get_hooks( *, db_session: Session, include_deleted: bool = False, include_creator: bool = False, ) -> list[Hook]: stmt = select(Hook) if not include_deleted: stmt = stmt.where(Hook.deleted.is_(False)) if include_creator: stmt = stmt.options(selectinload(Hook.creator)) stmt = stmt.order_by(Hook.hook_point, Hook.created_at.desc()) return list(db_session.scalars(stmt).all()) def create_hook__no_commit( *, db_session: Session, name: str, hook_point: HookPoint, endpoint_url: str | None = None, api_key: str | None = None, fail_strategy: HookFailStrategy, timeout_seconds: float, is_active: bool = False, is_reachable: bool | None = None, creator_id: UUID | None = None, ) -> Hook: """Create a new hook for the given hook point. At most one non-deleted hook per hook point is allowed. Raises OnyxError(CONFLICT) if a hook already exists, including under concurrent duplicate creates where the partial unique index fires an IntegrityError. """ existing = get_non_deleted_hook_by_hook_point( db_session=db_session, hook_point=hook_point ) if existing: raise OnyxError( OnyxErrorCode.CONFLICT, f"A hook for '{hook_point.value}' already exists (id={existing.id}).", ) hook = Hook( name=name, hook_point=hook_point, endpoint_url=endpoint_url, api_key=api_key, fail_strategy=fail_strategy, timeout_seconds=timeout_seconds, is_active=is_active, is_reachable=is_reachable, creator_id=creator_id, ) # Use a savepoint so that a failed insert only rolls back this operation, # not the entire outer transaction. 
savepoint = db_session.begin_nested() try: db_session.add(hook) savepoint.commit() except IntegrityError as exc: savepoint.rollback() if "ix_hook_one_non_deleted_per_point" in str(exc.orig): raise OnyxError( OnyxErrorCode.CONFLICT, f"A hook for '{hook_point.value}' already exists.", ) raise # re-raise unrelated integrity errors (FK violations, etc.) return hook def update_hook__no_commit( *, db_session: Session, hook_id: int, name: str | None = None, endpoint_url: str | None | UnsetType = UNSET, api_key: str | None | UnsetType = UNSET, fail_strategy: HookFailStrategy | None = None, timeout_seconds: float | None = None, is_active: bool | None = None, is_reachable: bool | None = None, include_creator: bool = False, ) -> Hook: """Update hook fields. Sentinel conventions: - endpoint_url, api_key: pass UNSET to leave unchanged; pass None to clear. - name, fail_strategy, timeout_seconds, is_active, is_reachable: pass None to leave unchanged. """ hook = get_hook_by_id( db_session=db_session, hook_id=hook_id, include_creator=include_creator ) if hook is None: raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.") if name is not None: hook.name = name if not isinstance(endpoint_url, UnsetType): hook.endpoint_url = endpoint_url if not isinstance(api_key, UnsetType): hook.api_key = api_key # type: ignore[assignment] # EncryptedString coerces str → SensitiveValue at the ORM level if fail_strategy is not None: hook.fail_strategy = fail_strategy if timeout_seconds is not None: hook.timeout_seconds = timeout_seconds if is_active is not None: hook.is_active = is_active if is_reachable is not None: hook.is_reachable = is_reachable db_session.flush() return hook def delete_hook__no_commit( *, db_session: Session, hook_id: int, ) -> None: hook = get_hook_by_id(db_session=db_session, hook_id=hook_id) if hook is None: raise OnyxError(OnyxErrorCode.NOT_FOUND, f"Hook with id {hook_id} not found.") hook.deleted = True hook.is_active = False db_session.flush() # ── 
HookExecutionLog CRUD ──────────────────────────────────────────────── def create_hook_execution_log__no_commit( *, db_session: Session, hook_id: int, is_success: bool, error_message: str | None = None, status_code: int | None = None, duration_ms: int | None = None, ) -> HookExecutionLog: log = HookExecutionLog( hook_id=hook_id, is_success=is_success, error_message=error_message, status_code=status_code, duration_ms=duration_ms, ) db_session.add(log) db_session.flush() return log def get_hook_execution_logs( *, db_session: Session, hook_id: int, limit: int, ) -> list[HookExecutionLog]: stmt = ( select(HookExecutionLog) .where(HookExecutionLog.hook_id == hook_id) .order_by(HookExecutionLog.created_at.desc()) .limit(limit) ) return list(db_session.scalars(stmt).all()) def cleanup_old_execution_logs__no_commit( *, db_session: Session, max_age_days: int, ) -> int: """Delete execution logs older than max_age_days. Returns the number of rows deleted.""" cutoff = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( days=max_age_days ) result: CursorResult = db_session.execute( # type: ignore[assignment] delete(HookExecutionLog) .where(HookExecutionLog.created_at < cutoff) .execution_options(synchronize_session=False) ) return result.rowcount ================================================ FILE: backend/onyx/db/image_generation.py ================================================ from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.db.models import ImageGenerationConfig from onyx.db.models import LLMProvider from onyx.db.models import ModelConfiguration from onyx.llm.utils import get_max_input_tokens from onyx.utils.logger import setup_logger logger = setup_logger() # Default image generation config constants DEFAULT_IMAGE_PROVIDER_ID = "openai_gpt_image_1" DEFAULT_IMAGE_MODEL_NAME = "gpt-image-1" DEFAULT_IMAGE_PROVIDER = "openai" def 
create_image_generation_config__no_commit( db_session: Session, image_provider_id: str, model_configuration_id: int, is_default: bool = False, ) -> ImageGenerationConfig: """Create a new image generation config.""" # If setting as default, clear ALL existing defaults in a single atomic update # This is more atomic than select-then-update pattern if is_default: db_session.execute( update(ImageGenerationConfig) .where(ImageGenerationConfig.is_default.is_(True)) .values(is_default=False) ) new_config = ImageGenerationConfig( image_provider_id=image_provider_id, model_configuration_id=model_configuration_id, is_default=is_default, ) db_session.add(new_config) db_session.flush() return new_config def get_all_image_generation_configs( db_session: Session, ) -> list[ImageGenerationConfig]: """Get all image generation configs. Returns: List of all ImageGenerationConfig objects """ stmt = select(ImageGenerationConfig) return list(db_session.scalars(stmt).all()) def get_image_generation_config( db_session: Session, image_provider_id: str, ) -> ImageGenerationConfig | None: """Get a single image generation config by image_provider_id with relationships loaded. Args: db_session: Database session image_provider_id: The image provider ID (primary key) Returns: The ImageGenerationConfig or None if not found """ stmt = ( select(ImageGenerationConfig) .where(ImageGenerationConfig.image_provider_id == image_provider_id) .options( selectinload(ImageGenerationConfig.model_configuration).selectinload( ModelConfiguration.llm_provider ) ) ) return db_session.scalar(stmt) def get_default_image_generation_config( db_session: Session, ) -> ImageGenerationConfig | None: """Get the default image generation config. 
Returns: The default ImageGenerationConfig or None if not set """ stmt = ( select(ImageGenerationConfig) .where(ImageGenerationConfig.is_default.is_(True)) .options( selectinload(ImageGenerationConfig.model_configuration).selectinload( ModelConfiguration.llm_provider ) ) ) return db_session.scalar(stmt) def set_default_image_generation_config( db_session: Session, image_provider_id: str, ) -> None: """Set a config as the default (clears previous default). Args: db_session: Database session image_provider_id: The image provider ID to set as default Raises: ValueError: If config not found """ # Get the config to set as default new_default = db_session.get(ImageGenerationConfig, image_provider_id) if not new_default: raise ValueError( f"ImageGenerationConfig with image_provider_id {image_provider_id} not found" ) # Clear ALL existing defaults in a single atomic update # This is more atomic than select-then-update pattern db_session.execute( update(ImageGenerationConfig) .where( ImageGenerationConfig.is_default.is_(True), ImageGenerationConfig.image_provider_id != image_provider_id, ) .values(is_default=False) ) # Set new default new_default.is_default = True db_session.commit() def unset_default_image_generation_config( db_session: Session, image_provider_id: str, ) -> None: """Unset a config as the default.""" config = db_session.get(ImageGenerationConfig, image_provider_id) if not config: raise ValueError( f"ImageGenerationConfig with image_provider_id {image_provider_id} not found" ) config.is_default = False db_session.commit() def delete_image_generation_config__no_commit( db_session: Session, image_provider_id: str, ) -> None: """Delete an image generation config by image_provider_id.""" config = db_session.get(ImageGenerationConfig, image_provider_id) if not config: raise ValueError( f"ImageGenerationConfig with image_provider_id {image_provider_id} not found" ) db_session.delete(config) db_session.flush() def create_default_image_gen_config_from_api_key( 
db_session: Session, api_key: str, provider: str = DEFAULT_IMAGE_PROVIDER, image_provider_id: str = DEFAULT_IMAGE_PROVIDER_ID, model_name: str = DEFAULT_IMAGE_MODEL_NAME, ) -> ImageGenerationConfig | None: """Create default image gen config using an API key directly. This function is used during tenant provisioning to automatically create a default image generation config when an OpenAI provider is configured. Args: db_session: Database session api_key: API key for the LLM provider provider: Provider name (default: openai) image_provider_id: Static unique key for the config (default: openai_gpt_image_1) model_name: Model name for image generation (default: gpt-image-1) Returns: The created ImageGenerationConfig, or None if: - the image_generation_config table already has records - creation fails (the error is logged and the transaction is rolled back) """ # Check if any image generation configs already exist (optimization to avoid work) existing_configs = get_all_image_generation_configs(db_session) if existing_configs: logger.info("Image generation config already exists, skipping default creation") return None try: # Create new LLM provider for image generation new_provider = LLMProvider( name=f"Image Gen - {image_provider_id}", provider=provider, api_key=api_key, api_base=None, api_version=None, deployment_name=None, is_public=True, ) db_session.add(new_provider) db_session.flush() # Create model configuration max_input_tokens = get_max_input_tokens( model_name=model_name, model_provider=provider, ) model_config = ModelConfiguration( llm_provider_id=new_provider.id, name=model_name, is_visible=True, max_input_tokens=max_input_tokens, ) db_session.add(model_config) db_session.flush() # Create image generation config config = create_image_generation_config__no_commit( db_session=db_session, image_provider_id=image_provider_id, model_configuration_id=model_config.id, is_default=True, ) db_session.commit() logger.info(f"Created default image generation config: {image_provider_id}") return config except Exception: db_session.rollback()
logger.exception( f"Failed to create default image generation config {image_provider_id}" ) return None ================================================ FILE: backend/onyx/db/index_attempt.py ================================================ from collections.abc import Sequence from datetime import datetime from datetime import timedelta from datetime import timezone from typing import NamedTuple from typing import TYPE_CHECKING from typing import TypeVarTuple from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import func from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from onyx.connectors.models import ConnectorFailure from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import IndexModelStatus from onyx.db.models import ConnectorCredentialPair from onyx.db.models import IndexAttempt from onyx.db.models import IndexAttemptError from onyx.db.models import SearchSettings from onyx.server.documents.models import ConnectorCredentialPairIdentifier from onyx.utils.logger import setup_logger from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType if TYPE_CHECKING: from onyx.configs.constants import DocumentSource # from sqlalchemy.sql.selectable import Select # Comment out unused imports that cause mypy errors # from onyx.auth.models import UserRole # from onyx.configs.constants import MAX_LAST_VALID_CHECKPOINT_AGE_SECONDS # from onyx.db.connector_credential_pair import ConnectorCredentialPairIdentifier # from onyx.db.engine import async_query_for_dms logger = setup_logger() def get_last_attempt_for_cc_pair( cc_pair_id: int, search_settings_id: int, db_session: Session, ) -> IndexAttempt | None: return ( db_session.query(IndexAttempt) 
.filter( IndexAttempt.connector_credential_pair_id == cc_pair_id, IndexAttempt.search_settings_id == search_settings_id, ) .order_by(IndexAttempt.time_updated.desc()) .first() ) def get_recent_completed_attempts_for_cc_pair( cc_pair_id: int, search_settings_id: int, limit: int, db_session: Session, ) -> list[IndexAttempt]: """Most recent to least recent.""" return ( db_session.query(IndexAttempt) .filter( IndexAttempt.connector_credential_pair_id == cc_pair_id, IndexAttempt.search_settings_id == search_settings_id, IndexAttempt.status.notin_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ), ) .order_by(IndexAttempt.time_updated.desc()) .limit(limit) .all() ) def get_recent_attempts_for_cc_pair( cc_pair_id: int, search_settings_id: int, limit: int, db_session: Session, ) -> list[IndexAttempt]: """Most recent to least recent.""" return ( db_session.query(IndexAttempt) .filter( IndexAttempt.connector_credential_pair_id == cc_pair_id, IndexAttempt.search_settings_id == search_settings_id, ) .order_by(IndexAttempt.time_updated.desc()) .limit(limit) .all() ) def get_index_attempt( db_session: Session, index_attempt_id: int, eager_load_cc_pair: bool = False, eager_load_search_settings: bool = False, ) -> IndexAttempt | None: stmt = select(IndexAttempt).where(IndexAttempt.id == index_attempt_id) if eager_load_cc_pair: stmt = stmt.options( joinedload(IndexAttempt.connector_credential_pair).joinedload( ConnectorCredentialPair.connector ) ) stmt = stmt.options( joinedload(IndexAttempt.connector_credential_pair).joinedload( ConnectorCredentialPair.credential ) ) if eager_load_search_settings: stmt = stmt.options(joinedload(IndexAttempt.search_settings)) return db_session.scalars(stmt).first() def count_error_rows_for_index_attempt( index_attempt_id: int, db_session: Session, ) -> int: return ( db_session.query(IndexAttemptError) .filter(IndexAttemptError.index_attempt_id == index_attempt_id) .count() ) def create_index_attempt( connector_credential_pair_id: int, 
search_settings_id: int, db_session: Session, from_beginning: bool = False, celery_task_id: str | None = None, ) -> int: new_attempt = IndexAttempt( connector_credential_pair_id=connector_credential_pair_id, search_settings_id=search_settings_id, from_beginning=from_beginning, status=IndexingStatus.NOT_STARTED, celery_task_id=celery_task_id, ) db_session.add(new_attempt) db_session.commit() return new_attempt.id def delete_index_attempt(db_session: Session, index_attempt_id: int) -> None: index_attempt = get_index_attempt(db_session, index_attempt_id) if index_attempt: db_session.delete(index_attempt) db_session.commit() def mock_successful_index_attempt( connector_credential_pair_id: int, search_settings_id: int, docs_indexed: int, db_session: Session, ) -> int: """Should not be used in any user triggered flows""" db_time = func.now() new_attempt = IndexAttempt( connector_credential_pair_id=connector_credential_pair_id, search_settings_id=search_settings_id, from_beginning=True, status=IndexingStatus.SUCCESS, total_docs_indexed=docs_indexed, new_docs_indexed=docs_indexed, # Need this to be some convincing random looking value and it can't be 0 # or the indexing rate would calculate out to infinity time_started=db_time - timedelta(seconds=1.92), time_updated=db_time, ) db_session.add(new_attempt) db_session.commit() return new_attempt.id def get_in_progress_index_attempts( connector_id: int | None, db_session: Session, ) -> list[IndexAttempt]: stmt = select(IndexAttempt) if connector_id is not None: stmt = stmt.where( IndexAttempt.connector_credential_pair.has(connector_id=connector_id) ) stmt = stmt.where(IndexAttempt.status == IndexingStatus.IN_PROGRESS) incomplete_attempts = db_session.scalars(stmt) return list(incomplete_attempts.all()) def get_all_index_attempts_by_status( status: IndexingStatus, db_session: Session ) -> list[IndexAttempt]: """Returns index attempts with the given status. 
Only call this with non-terminal statuses, since the list of attempts in a terminal status may be quite large. Results are ordered by time_created (oldest to newest).""" stmt = select(IndexAttempt) stmt = stmt.where(IndexAttempt.status == status) stmt = stmt.order_by(IndexAttempt.time_created) new_attempts = db_session.scalars(stmt) return list(new_attempts.all()) def transition_attempt_to_in_progress( index_attempt_id: int, db_session: Session, ) -> IndexAttempt: """Move a NOT_STARTED attempt to IN_PROGRESS, locking the row (SELECT ... FOR UPDATE) while updating it.""" try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() ).scalar_one_or_none() if attempt is None: raise RuntimeError( f"Unable to find IndexAttempt for ID '{index_attempt_id}'" ) if attempt.status != IndexingStatus.NOT_STARTED: raise RuntimeError( f"Indexing attempt with ID '{index_attempt_id}' is not in NOT_STARTED status. " f"Current status is '{attempt.status}'." ) attempt.status = IndexingStatus.IN_PROGRESS attempt.time_started = attempt.time_started or func.now() # type: ignore db_session.commit() return attempt except Exception: db_session.rollback() logger.exception("transition_attempt_to_in_progress exceptioned.") raise def mark_attempt_in_progress( index_attempt: IndexAttempt, db_session: Session, ) -> None: try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt.id) .with_for_update() ).scalar_one() attempt.status = IndexingStatus.IN_PROGRESS attempt.time_started = index_attempt.time_started or func.now() # type: ignore db_session.commit() # Add telemetry for index attempt status change optional_telemetry( record_type=RecordType.INDEX_ATTEMPT_STATUS, data={ "index_attempt_id": index_attempt.id, "status": IndexingStatus.IN_PROGRESS.value, "cc_pair_id": index_attempt.connector_credential_pair_id, }, ) except Exception: db_session.rollback() raise def mark_attempt_succeeded( index_attempt_id: int, db_session: Session, ) -> IndexAttempt: try: attempt =
db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() ).scalar_one() attempt.status = IndexingStatus.SUCCESS attempt.celery_task_id = None db_session.commit() # Add telemetry for index attempt status change optional_telemetry( record_type=RecordType.INDEX_ATTEMPT_STATUS, data={ "index_attempt_id": index_attempt_id, "status": IndexingStatus.SUCCESS.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) return attempt except Exception: db_session.rollback() raise def mark_attempt_partially_succeeded( index_attempt_id: int, db_session: Session, ) -> IndexAttempt: try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() ).scalar_one() attempt.status = IndexingStatus.COMPLETED_WITH_ERRORS attempt.celery_task_id = None db_session.commit() # Add telemetry for index attempt status change optional_telemetry( record_type=RecordType.INDEX_ATTEMPT_STATUS, data={ "index_attempt_id": index_attempt_id, "status": IndexingStatus.COMPLETED_WITH_ERRORS.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) return attempt except Exception: db_session.rollback() raise def mark_attempt_canceled( index_attempt_id: int, db_session: Session, reason: str = "Unknown", ) -> None: try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() ).scalar_one() if not attempt.time_started: attempt.time_started = datetime.now(timezone.utc) attempt.status = IndexingStatus.CANCELED attempt.error_msg = reason db_session.commit() # Add telemetry for index attempt status change optional_telemetry( record_type=RecordType.INDEX_ATTEMPT_STATUS, data={ "index_attempt_id": index_attempt_id, "status": IndexingStatus.CANCELED.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) except Exception: db_session.rollback() raise def mark_attempt_failed( index_attempt_id: int, db_session: Session, failure_reason: str = 
"Unknown", full_exception_trace: str | None = None, ) -> None: try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() ).scalar_one() if not attempt.time_started: attempt.time_started = datetime.now(timezone.utc) attempt.status = IndexingStatus.FAILED attempt.error_msg = failure_reason attempt.full_exception_trace = full_exception_trace attempt.celery_task_id = None db_session.commit() # Add telemetry for index attempt status change optional_telemetry( record_type=RecordType.INDEX_ATTEMPT_STATUS, data={ "index_attempt_id": index_attempt_id, "status": IndexingStatus.FAILED.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) except Exception: db_session.rollback() raise def update_docs_indexed( db_session: Session, index_attempt_id: int, total_docs_indexed: int, new_docs_indexed: int, docs_removed_from_index: int, ) -> None: """Increments the total_docs_indexed, new_docs_indexed, and docs_removed_from_index fields of an index attempt by the given amounts (the values are added to what is already stored in the db).""" try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() # Locks the row when we try to update ).scalar_one() attempt.total_docs_indexed = ( attempt.total_docs_indexed or 0 ) + total_docs_indexed attempt.new_docs_indexed = (attempt.new_docs_indexed or 0) + new_docs_indexed attempt.docs_removed_from_index = ( attempt.docs_removed_from_index or 0 ) + docs_removed_from_index db_session.commit() except Exception: db_session.rollback() logger.exception("update_docs_indexed exceptioned.") raise def get_last_attempt( connector_id: int, credential_id: int, search_settings_id: int | None, db_session: Session, ) -> IndexAttempt | None: stmt = ( select(IndexAttempt) .join(ConnectorCredentialPair) .where( ConnectorCredentialPair.connector_id == connector_id, ConnectorCredentialPair.credential_id == credential_id, IndexAttempt.search_settings_id == search_settings_id, ) ) # Note, the below is
using time_created instead of time_updated stmt = stmt.order_by(desc(IndexAttempt.time_created)) return db_session.execute(stmt).scalars().first() def get_latest_index_attempts_by_status( secondary_index: bool, db_session: Session, status: IndexingStatus, ) -> Sequence[IndexAttempt]: """ Retrieves the most recent index attempt with the specified status for each connector_credential_pair. Filters attempts based on the secondary_index flag to get either future or present index attempts. Returns a sequence of IndexAttempt objects, one for each unique connector_credential_pair. """ latest_failed_attempts = ( select( IndexAttempt.connector_credential_pair_id, func.max(IndexAttempt.id).label("max_failed_id"), ) .join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id) .where( SearchSettings.status == ( IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT ), IndexAttempt.status == status, ) .group_by(IndexAttempt.connector_credential_pair_id) .subquery() ) stmt = select(IndexAttempt).join( latest_failed_attempts, ( IndexAttempt.connector_credential_pair_id == latest_failed_attempts.c.connector_credential_pair_id ) & (IndexAttempt.id == latest_failed_attempts.c.max_failed_id), ) return db_session.execute(stmt).scalars().all() T = TypeVarTuple("T") def _add_only_finished_clause(stmt: Select[tuple[*T]]) -> Select[tuple[*T]]: return stmt.where( IndexAttempt.status.not_in( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ), ) def get_latest_index_attempts( secondary_index: bool, db_session: Session, eager_load_cc_pair: bool = False, only_finished: bool = False, ) -> Sequence[IndexAttempt]: ids_stmt = select( IndexAttempt.connector_credential_pair_id, func.max(IndexAttempt.id).label("max_id"), ).join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id) status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT ids_stmt = ids_stmt.where(SearchSettings.status == status) if only_finished: 
ids_stmt = _add_only_finished_clause(ids_stmt) ids_stmt = ids_stmt.group_by(IndexAttempt.connector_credential_pair_id) ids_subquery = ids_stmt.subquery() stmt = ( select(IndexAttempt) .join( ids_subquery, IndexAttempt.connector_credential_pair_id == ids_subquery.c.connector_credential_pair_id, ) .where(IndexAttempt.id == ids_subquery.c.max_id) ) if only_finished: stmt = _add_only_finished_clause(stmt) if eager_load_cc_pair: stmt = stmt.options( joinedload(IndexAttempt.connector_credential_pair), joinedload(IndexAttempt.error_rows), ) return db_session.execute(stmt).scalars().unique().all() # For use with our thread-level parallelism utils. Note that any relationships # you wish to use MUST be eagerly loaded, as the session will not be available # after this function to allow lazy loading. def get_latest_index_attempts_parallel( secondary_index: bool, eager_load_cc_pair: bool = False, only_finished: bool = False, ) -> Sequence[IndexAttempt]: with get_session_with_current_tenant() as db_session: return get_latest_index_attempts( secondary_index, db_session, eager_load_cc_pair, only_finished, ) def get_latest_index_attempt_for_cc_pair_id( db_session: Session, connector_credential_pair_id: int, secondary_index: bool, only_finished: bool = True, ) -> IndexAttempt | None: stmt = select(IndexAttempt) stmt = stmt.where( IndexAttempt.connector_credential_pair_id == connector_credential_pair_id, ) if only_finished: stmt = _add_only_finished_clause(stmt) status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT stmt = stmt.join(SearchSettings).where(SearchSettings.status == status) stmt = stmt.order_by(desc(IndexAttempt.time_created)) stmt = stmt.limit(1) return db_session.execute(stmt).scalar_one_or_none() def get_latest_successful_index_attempt_for_cc_pair_id( db_session: Session, connector_credential_pair_id: int, secondary_index: bool = False, ) -> IndexAttempt | None: """Returns the most recent successful index attempt for the given cc pair, 
filtered to the current (or future) search settings. Uses MAX(id) semantics to match get_latest_index_attempts_by_status.""" status = IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT stmt = ( select(IndexAttempt) .where( IndexAttempt.connector_credential_pair_id == connector_credential_pair_id, IndexAttempt.status.in_( [IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS] ), ) .join(SearchSettings) .where(SearchSettings.status == status) .order_by(desc(IndexAttempt.id)) .limit(1) ) return db_session.execute(stmt).scalar_one_or_none() def get_latest_successful_index_attempts_parallel( secondary_index: bool = False, ) -> Sequence[IndexAttempt]: """Batch version: returns the latest successful index attempt per cc pair. Covers both SUCCESS and COMPLETED_WITH_ERRORS (matching is_successful()).""" model_status = ( IndexModelStatus.FUTURE if secondary_index else IndexModelStatus.PRESENT ) with get_session_with_current_tenant() as db_session: latest_ids = ( select( IndexAttempt.connector_credential_pair_id, func.max(IndexAttempt.id).label("max_id"), ) .join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id) .where( SearchSettings.status == model_status, IndexAttempt.status.in_( [IndexingStatus.SUCCESS, IndexingStatus.COMPLETED_WITH_ERRORS] ), ) .group_by(IndexAttempt.connector_credential_pair_id) .subquery() ) stmt = select(IndexAttempt).join( latest_ids, ( IndexAttempt.connector_credential_pair_id == latest_ids.c.connector_credential_pair_id ) & (IndexAttempt.id == latest_ids.c.max_id), ) return db_session.execute(stmt).scalars().all() def count_index_attempts_for_cc_pair( db_session: Session, cc_pair_id: int, only_current: bool = True, disinclude_finished: bool = False, ) -> int: stmt = select(IndexAttempt).where( IndexAttempt.connector_credential_pair_id == cc_pair_id ) if disinclude_finished: stmt = stmt.where( IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ) ) if only_current: stmt 
= stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.PRESENT ) # Count total items for pagination count_stmt = stmt.with_only_columns(func.count()).order_by(None) total_count = db_session.execute(count_stmt).scalar_one() return total_count def get_paginated_index_attempts_for_cc_pair_id( db_session: Session, cc_pair_id: int, page: int, page_size: int, only_current: bool = True, disinclude_finished: bool = False, ) -> list[IndexAttempt]: stmt = select(IndexAttempt).where( IndexAttempt.connector_credential_pair_id == cc_pair_id ) if disinclude_finished: stmt = stmt.where( IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ) ) if only_current: stmt = stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.PRESENT ) stmt = stmt.order_by(IndexAttempt.time_started.desc()) # Apply pagination stmt = stmt.offset(page * page_size).limit(page_size) return list(db_session.execute(stmt).scalars().unique().all()) def get_index_attempts_for_cc_pair( db_session: Session, cc_pair_identifier: ConnectorCredentialPairIdentifier, only_current: bool = True, disinclude_finished: bool = False, ) -> Sequence[IndexAttempt]: stmt = ( select(IndexAttempt) .join(ConnectorCredentialPair) .where( and_( ConnectorCredentialPair.connector_id == cc_pair_identifier.connector_id, ConnectorCredentialPair.credential_id == cc_pair_identifier.credential_id, ) ) ) if disinclude_finished: stmt = stmt.where( IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ) ) if only_current: stmt = stmt.join(SearchSettings).where( SearchSettings.status == IndexModelStatus.PRESENT ) stmt = stmt.order_by(IndexAttempt.time_created.desc()) return db_session.execute(stmt).scalars().all() def delete_index_attempts( cc_pair_id: int, db_session: Session, ) -> None: # First, delete related entries in IndexAttemptErrors stmt_errors = delete(IndexAttemptError).where( IndexAttemptError.index_attempt_id.in_( 
select(IndexAttempt.id).where( IndexAttempt.connector_credential_pair_id == cc_pair_id ) ) ) db_session.execute(stmt_errors) stmt = delete(IndexAttempt).where( IndexAttempt.connector_credential_pair_id == cc_pair_id, ) db_session.execute(stmt) def expire_index_attempts( search_settings_id: int, db_session: Session, ) -> None: not_started_query = ( update(IndexAttempt) .where(IndexAttempt.search_settings_id == search_settings_id) .where(IndexAttempt.status == IndexingStatus.NOT_STARTED) .values( status=IndexingStatus.CANCELED, error_msg="Canceled, likely due to model swap", ) ) db_session.execute(not_started_query) update_query = ( update(IndexAttempt) .where(IndexAttempt.search_settings_id == search_settings_id) .where(IndexAttempt.status != IndexingStatus.SUCCESS) .values( status=IndexingStatus.FAILED, error_msg="Canceled due to embedding model swap", ) ) db_session.execute(update_query) db_session.commit() def cancel_indexing_attempts_for_ccpair( cc_pair_id: int, db_session: Session, include_secondary_index: bool = False, ) -> None: stmt = ( update(IndexAttempt) .where(IndexAttempt.connector_credential_pair_id == cc_pair_id) .where(IndexAttempt.status == IndexingStatus.NOT_STARTED) .values( status=IndexingStatus.CANCELED, error_msg="Canceled by user", time_started=datetime.now(timezone.utc), ) ) if not include_secondary_index: subquery = select(SearchSettings.id).where( SearchSettings.status != IndexModelStatus.FUTURE ) stmt = stmt.where(IndexAttempt.search_settings_id.in_(subquery)) db_session.execute(stmt) def cancel_indexing_attempts_past_model( db_session: Session, ) -> None: """Stops all indexing attempts that are in progress or not started for any embedding model that is neither present nor future (i.e. whose status is PAST).""" db_session.execute( update(IndexAttempt) .where( IndexAttempt.status.in_( [IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED] ), IndexAttempt.search_settings_id == SearchSettings.id, SearchSettings.status == IndexModelStatus.PAST, )
.values(status=IndexingStatus.FAILED) ) def cancel_indexing_attempts_for_search_settings( search_settings_id: int, db_session: Session, ) -> None: """Stops all indexing attempts that are in progress or not started for the specified search settings.""" db_session.execute( update(IndexAttempt) .where( IndexAttempt.status.in_( [IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED] ), IndexAttempt.search_settings_id == search_settings_id, ) .values(status=IndexingStatus.FAILED) ) def count_unique_cc_pairs_with_successful_index_attempts( search_settings_id: int | None, db_session: Session, ) -> int: """Collect all of the Index Attempts that are successful and for the specified embedding model. Then take the distinct connector_credential_pair_id values, which is equivalent to distinct (connector_id, credential_id) pairs. Finally, count them to get the total number of unique cc-pairs with successful attempts.""" unique_pairs_count = ( db_session.query(IndexAttempt.connector_credential_pair_id) .join(ConnectorCredentialPair) .filter( IndexAttempt.search_settings_id == search_settings_id, IndexAttempt.status == IndexingStatus.SUCCESS, ) .distinct() .count() ) return unique_pairs_count def count_unique_active_cc_pairs_with_successful_index_attempts( search_settings_id: int | None, db_session: Session, ) -> int: """Collect all of the Index Attempts that are successful and for the specified embedding model, but only for non-paused connector-credential pairs. Then take the distinct connector_credential_pair_id values, which is equivalent to distinct (connector_id, credential_id) pairs.
Finally, do a count to get the total number of unique non-paused cc-pairs with successful attempts.""" unique_pairs_count = ( db_session.query(IndexAttempt.connector_credential_pair_id) .join(ConnectorCredentialPair) .filter( IndexAttempt.search_settings_id == search_settings_id, IndexAttempt.status == IndexingStatus.SUCCESS, ConnectorCredentialPair.status != ConnectorCredentialPairStatus.PAUSED, ) .distinct() .count() ) return unique_pairs_count def create_index_attempt_error( index_attempt_id: int | None, connector_credential_pair_id: int, failure: ConnectorFailure, db_session: Session, ) -> int: new_error = IndexAttemptError( index_attempt_id=index_attempt_id, connector_credential_pair_id=connector_credential_pair_id, document_id=( failure.failed_document.document_id if failure.failed_document else None ), document_link=( failure.failed_document.document_link if failure.failed_document else None ), entity_id=(failure.failed_entity.entity_id if failure.failed_entity else None), failed_time_range_start=( failure.failed_entity.missed_time_range[0] if failure.failed_entity and failure.failed_entity.missed_time_range else None ), failed_time_range_end=( failure.failed_entity.missed_time_range[1] if failure.failed_entity and failure.failed_entity.missed_time_range else None ), failure_message=failure.failure_message, is_resolved=False, ) db_session.add(new_error) db_session.commit() return new_error.id def get_index_attempt_errors( index_attempt_id: int, db_session: Session, ) -> list[IndexAttemptError]: stmt = select(IndexAttemptError).where( IndexAttemptError.index_attempt_id == index_attempt_id ) errors = db_session.scalars(stmt) return list(errors.all()) def count_index_attempt_errors_for_cc_pair( cc_pair_id: int, unresolved_only: bool, db_session: Session, ) -> int: stmt = ( select(func.count()) .select_from(IndexAttemptError) .where(IndexAttemptError.connector_credential_pair_id == cc_pair_id) ) if unresolved_only: stmt = 
stmt.where(IndexAttemptError.is_resolved.is_(False)) result = db_session.scalar(stmt) return 0 if result is None else result def get_index_attempt_errors_for_cc_pair( cc_pair_id: int, unresolved_only: bool, db_session: Session, page: int | None = None, page_size: int | None = None, ) -> list[IndexAttemptError]: stmt = select(IndexAttemptError).where( IndexAttemptError.connector_credential_pair_id == cc_pair_id ) if unresolved_only: stmt = stmt.where(IndexAttemptError.is_resolved.is_(False)) # Order by most recent first stmt = stmt.order_by(desc(IndexAttemptError.time_created)) if page is not None and page_size is not None: stmt = stmt.offset(page * page_size).limit(page_size) return list(db_session.scalars(stmt).all()) # ── Metrics query helpers ────────────────────────────────────────────── class ActiveIndexAttemptMetric(NamedTuple): """Row returned by get_active_index_attempts_for_metrics.""" status: IndexingStatus source: "DocumentSource" cc_pair_id: int cc_pair_name: str | None attempt_count: int def get_active_index_attempts_for_metrics( db_session: Session, ) -> list[ActiveIndexAttemptMetric]: """Return non-terminal index attempts grouped by status, source, and connector. Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count). 
""" from onyx.db.models import Connector terminal_statuses = [s for s in IndexingStatus if s.is_terminal()] rows = ( db_session.query( IndexAttempt.status, Connector.source, ConnectorCredentialPair.id, ConnectorCredentialPair.name, func.count(), ) .join( ConnectorCredentialPair, IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id, ) .join( Connector, ConnectorCredentialPair.connector_id == Connector.id, ) .filter(IndexAttempt.status.notin_(terminal_statuses)) .group_by( IndexAttempt.status, Connector.source, ConnectorCredentialPair.id, ConnectorCredentialPair.name, ) .all() ) return [ActiveIndexAttemptMetric(*row) for row in rows] def get_failed_attempt_counts_by_cc_pair( db_session: Session, since: datetime | None = None, ) -> dict[int, int]: """Return {cc_pair_id: failed_attempt_count} for all connectors. When ``since`` is provided, only attempts created after that timestamp are counted. Defaults to the last 90 days to avoid unbounded historical aggregation. """ if since is None: since = datetime.now(timezone.utc) - timedelta(days=90) rows = ( db_session.query( IndexAttempt.connector_credential_pair_id, func.count(), ) .filter(IndexAttempt.status == IndexingStatus.FAILED) .filter(IndexAttempt.time_created >= since) .group_by(IndexAttempt.connector_credential_pair_id) .all() ) return {cc_id: count for cc_id, count in rows} def get_docs_indexed_by_cc_pair( db_session: Session, since: datetime | None = None, ) -> dict[int, int]: """Return {cc_pair_id: total_new_docs_indexed} across successful attempts. Only counts attempts with status SUCCESS to avoid inflating counts with partial results from failed attempts. When ``since`` is provided, only attempts created after that timestamp are included. 
""" if since is None: since = datetime.now(timezone.utc) - timedelta(days=90) query = ( db_session.query( IndexAttempt.connector_credential_pair_id, func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)), ) .filter(IndexAttempt.status == IndexingStatus.SUCCESS) .filter(IndexAttempt.time_created >= since) .group_by(IndexAttempt.connector_credential_pair_id) ) rows = query.all() return {cc_id: int(total or 0) for cc_id, total in rows} ================================================ FILE: backend/onyx/db/indexing_coordination.py ================================================ """Database-based indexing coordination to replace Redis fencing.""" from pydantic import BaseModel from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import Session from onyx.db.engine.time_utils import get_db_current_time from onyx.db.enums import IndexingStatus from onyx.db.index_attempt import count_error_rows_for_index_attempt from onyx.db.index_attempt import create_index_attempt from onyx.db.index_attempt import get_index_attempt from onyx.db.models import IndexAttempt from onyx.utils.logger import setup_logger logger = setup_logger() INDEXING_PROGRESS_TIMEOUT_HOURS = 6 class CoordinationStatus(BaseModel): """Status of an indexing attempt's coordination.""" found: bool total_batches: int | None completed_batches: int total_failures: int total_docs: int total_chunks: int status: IndexingStatus | None = None cancellation_requested: bool = False class IndexingCoordination: """Database-based coordination for indexing tasks, replacing Redis fencing.""" @staticmethod def try_create_index_attempt( db_session: Session, cc_pair_id: int, search_settings_id: int, celery_task_id: str, from_beginning: bool = False, ) -> int | None: """ Try to create a new index attempt for the given CC pair and search settings. Returns the index_attempt_id if successful, None if another attempt is already running. 
This replaces the Redis fencing mechanism by using database constraints and transactions to prevent duplicate attempts. """ try: # Check for existing active attempts (this is the "fence" check) existing_attempt = db_session.execute( select(IndexAttempt) .where( IndexAttempt.connector_credential_pair_id == cc_pair_id, IndexAttempt.search_settings_id == search_settings_id, IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ), ) .with_for_update(nowait=True) ).first() if existing_attempt: logger.info( f"Indexing already in progress: " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " f"existing_attempt={existing_attempt[0].id}" ) return None # Create new index attempt (this is setting the "fence") attempt_id = create_index_attempt( connector_credential_pair_id=cc_pair_id, search_settings_id=search_settings_id, from_beginning=from_beginning, db_session=db_session, celery_task_id=celery_task_id, ) logger.info( f"Created Index Attempt: " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " f"attempt_id={attempt_id} " f"celery_task_id={celery_task_id}" ) return attempt_id except SQLAlchemyError as e: logger.info( f"Failed to create index attempt (likely race condition): " f"cc_pair={cc_pair_id} " f"search_settings={search_settings_id} " f"error={str(e)}" ) db_session.rollback() return None @staticmethod def check_cancellation_requested( db_session: Session, index_attempt_id: int, ) -> bool: """ Check if cancellation has been requested for this indexing attempt. This replaces Redis termination signals. """ attempt = get_index_attempt(db_session, index_attempt_id) return attempt.cancellation_requested if attempt else False @staticmethod def request_cancellation( db_session: Session, index_attempt_id: int, ) -> None: """ Request cancellation of an indexing attempt. This replaces Redis termination signals. 
""" attempt = get_index_attempt(db_session, index_attempt_id) if attempt: attempt.cancellation_requested = True db_session.commit() logger.info(f"Requested cancellation for attempt {index_attempt_id}") @staticmethod def set_total_batches( db_session: Session, index_attempt_id: int, total_batches: int, ) -> None: """ Set the total number of batches for this indexing attempt. Called by docfetching when extraction is complete. """ attempt = get_index_attempt(db_session, index_attempt_id) if attempt: attempt.total_batches = total_batches db_session.commit() logger.info( f"Set total batches: attempt={index_attempt_id} total={total_batches}" ) @staticmethod def update_batch_completion_and_docs( db_session: Session, index_attempt_id: int, total_docs_indexed: int, new_docs_indexed: int, total_chunks: int, ) -> tuple[int, int | None]: """ Update batch completion and document counts atomically. Returns (completed_batches, total_batches). This extends the existing update_docs_indexed pattern. """ try: attempt = db_session.execute( select(IndexAttempt) .where(IndexAttempt.id == index_attempt_id) .with_for_update() # Same pattern as existing update_docs_indexed ).scalar_one() # Existing document count updates attempt.total_docs_indexed = ( attempt.total_docs_indexed or 0 ) + total_docs_indexed attempt.new_docs_indexed = ( attempt.new_docs_indexed or 0 ) + new_docs_indexed # New coordination updates attempt.completed_batches = (attempt.completed_batches or 0) + 1 attempt.total_chunks = (attempt.total_chunks or 0) + total_chunks db_session.commit() logger.info( f"Updated batch completion: " f"attempt={index_attempt_id} " f"completed={attempt.completed_batches} " f"total={attempt.total_batches} " f"docs={total_docs_indexed} " ) return attempt.completed_batches, attempt.total_batches except Exception: db_session.rollback() logger.exception( f"Failed to update batch completion for attempt {index_attempt_id}" ) raise @staticmethod def get_coordination_status( db_session: Session, 
index_attempt_id: int, ) -> CoordinationStatus: """ Get the current coordination status for an indexing attempt. This replaces reading FileStore state files. """ attempt = get_index_attempt(db_session, index_attempt_id) if not attempt: return CoordinationStatus( found=False, total_batches=None, completed_batches=0, total_failures=0, total_docs=0, total_chunks=0, status=None, cancellation_requested=False, ) return CoordinationStatus( found=True, total_batches=attempt.total_batches, completed_batches=attempt.completed_batches, total_failures=count_error_rows_for_index_attempt( index_attempt_id, db_session ), total_docs=attempt.total_docs_indexed or 0, total_chunks=attempt.total_chunks, status=attempt.status, cancellation_requested=attempt.cancellation_requested, ) @staticmethod def get_orphaned_index_attempt_ids(db_session: Session) -> list[int]: """ Gets a list of potentially orphaned index attempts. These are attempts in non-terminal state that have task IDs but may have died. This replaces the old get_unfenced_index_attempt_ids function. The actual orphan detection requires checking with Celery, which should be done by the caller. """ # Find attempts that are active and have task IDs # The caller needs to check each one with Celery to confirm orphaned status active_attempts = ( db_session.execute( select(IndexAttempt).where( IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ), IndexAttempt.celery_task_id.isnot(None), ) ) .scalars() .all() ) return [attempt.id for attempt in active_attempts] @staticmethod def update_progress_tracking( db_session: Session, index_attempt_id: int, current_batches_completed: int, timeout_hours: int = INDEXING_PROGRESS_TIMEOUT_HOURS, force_update_progress: bool = False, ) -> bool: """ Update progress tracking for stall detection. Returns True if sufficient progress was made, False if stalled. 
""" attempt = get_index_attempt(db_session, index_attempt_id) if not attempt: logger.error(f"Index attempt {index_attempt_id} not found in database") return False current_time = get_db_current_time(db_session) # Check whether this is the first time tracking progress # or whether the caller wants to simulate guaranteed progress if attempt.last_progress_time is None or force_update_progress: # First time tracking - initialize attempt.last_progress_time = current_time attempt.last_batches_completed_count = current_batches_completed db_session.commit() return True time_elapsed = (current_time - attempt.last_progress_time).total_seconds() # only write to the db once every timeout_hours/2 (timeout_hours * 1800 seconds); # this ensures that at most timeout_hours pass with no recorded activity if time_elapsed < timeout_hours * 1800: return True # Check if progress has been made if current_batches_completed <= attempt.last_batches_completed_count: # if between timeout_hours/2 and timeout_hours have passed # without an update, we consider the attempt stalled return False # Progress made - update tracking attempt.last_progress_time = current_time attempt.last_batches_completed_count = current_batches_completed db_session.commit() return True ================================================ FILE: backend/onyx/db/input_prompt.py ================================================ from uuid import UUID from fastapi import HTTPException from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import aliased from sqlalchemy.orm import Session from onyx.db.models import InputPrompt from onyx.db.models import InputPrompt__User from onyx.db.models import User from onyx.server.features.input_prompt.models import InputPromptSnapshot from onyx.server.manage.models import UserInfo from onyx.utils.logger import setup_logger logger = setup_logger() def insert_input_prompt( prompt: str, content: str, is_public:
bool, user: User | None, db_session: Session, ) -> InputPrompt: user_id = user.id if user else None # Use atomic INSERT ... ON CONFLICT DO NOTHING with RETURNING # to avoid race conditions with the uniqueness check stmt = pg_insert(InputPrompt).values( prompt=prompt, content=content, active=True, is_public=is_public, user_id=user_id, ) # Use the appropriate constraint based on whether this is a user-owned or public prompt if user_id is not None: stmt = stmt.on_conflict_do_nothing(constraint="uq_inputprompt_prompt_user_id") else: # Partial unique indexes cannot be targeted by constraint name; # must use index_elements + index_where stmt = stmt.on_conflict_do_nothing( index_elements=[InputPrompt.prompt], index_where=InputPrompt.user_id.is_(None), ) stmt = stmt.returning(InputPrompt) result = db_session.execute(stmt) input_prompt = result.scalar_one_or_none() if input_prompt is None: raise HTTPException( status_code=409, detail=f"A prompt shortcut with the name '{prompt}' already exists", ) db_session.commit() return input_prompt def update_input_prompt( user: User, input_prompt_id: int, prompt: str, content: str, active: bool, db_session: Session, ) -> InputPrompt: input_prompt = db_session.scalar( select(InputPrompt).where(InputPrompt.id == input_prompt_id) ) if input_prompt is None: raise ValueError(f"No input prompt with id {input_prompt_id}") if not validate_user_prompt_authorization(user, input_prompt): raise HTTPException(status_code=401, detail="You don't own this prompt") input_prompt.prompt = prompt input_prompt.content = content input_prompt.active = active try: db_session.commit() except IntegrityError: db_session.rollback() raise HTTPException( status_code=409, detail=f"A prompt shortcut with the name '{prompt}' already exists", ) return input_prompt def validate_user_prompt_authorization(user: User, input_prompt: InputPrompt) -> bool: prompt = InputPromptSnapshot.from_model(input_prompt=input_prompt) # Public prompts cannot be modified via the user API 
(only admins via admin endpoints) if prompt.is_public or prompt.user_id is None: return False # Anonymous users cannot modify user-owned prompts if user.is_anonymous: return False # User must own the prompt user_details = UserInfo.from_model(user) return str(user_details.id) == str(prompt.user_id) def remove_public_input_prompt(input_prompt_id: int, db_session: Session) -> None: input_prompt = db_session.scalar( select(InputPrompt).where(InputPrompt.id == input_prompt_id) ) if input_prompt is None: raise ValueError(f"No input prompt with id {input_prompt_id}") if not input_prompt.is_public: raise HTTPException(status_code=400, detail="This prompt is not public") db_session.delete(input_prompt) db_session.commit() def remove_input_prompt( user: User, input_prompt_id: int, db_session: Session, delete_public: bool = False, ) -> None: input_prompt = db_session.scalar( select(InputPrompt).where(InputPrompt.id == input_prompt_id) ) if input_prompt is None: raise ValueError(f"No input prompt with id {input_prompt_id}") if input_prompt.is_public and not delete_public: raise HTTPException( status_code=400, detail="Cannot delete public prompts with this method" ) if not validate_user_prompt_authorization(user, input_prompt): raise HTTPException(status_code=401, detail="You do not own this prompt") db_session.delete(input_prompt) db_session.commit() def fetch_input_prompt_by_id( id: int, user_id: UUID | None, db_session: Session ) -> InputPrompt: query = select(InputPrompt).where(InputPrompt.id == id) if user_id: # Use .is_(None) so SQLAlchemy emits "user_id IS NULL"; a plain Python # `is None` check on the column is always False and silently drops public prompts query = query.where( (InputPrompt.user_id == user_id) | (InputPrompt.user_id.is_(None)) ) else: # If no user_id is provided, only fetch prompts without a user_id (aka public) query = query.where(InputPrompt.user_id.is_(None)) result = db_session.scalar(query) if result is None: raise HTTPException(422, "No input prompt found") return result def fetch_public_input_prompts( db_session: Session, ) -> list[InputPrompt]: query =
select(InputPrompt).where(InputPrompt.is_public) return list(db_session.scalars(query).all()) def fetch_input_prompts_by_user( db_session: Session, user_id: UUID | None, active: bool | None = None, include_public: bool = False, ) -> list[InputPrompt]: """ Returns the prompts owned by the user, plus public prompts when include_public is True, excluding those the user has specifically disabled. Without a user, returns only public prompts (and only if include_public is True). """ query = select(InputPrompt) if user_id is not None: # If we have a user, left join to InputPrompt__User to check "disabled" IPU = aliased(InputPrompt__User) query = query.join( IPU, (IPU.input_prompt_id == InputPrompt.id) & (IPU.user_id == user_id), isouter=True, ) # Exclude disabled prompts query = query.where(or_(IPU.disabled.is_(None), IPU.disabled.is_(False))) if include_public: # Return both user-owned and public prompts query = query.where( or_( InputPrompt.user_id == user_id, InputPrompt.is_public, ) ) else: # Return only user-owned prompts query = query.where(InputPrompt.user_id == user_id) else: # user_id is None - anonymous usage if include_public: query = query.where(InputPrompt.is_public) else: # No user and not requesting public prompts - return nothing return [] if active is not None: query = query.where(InputPrompt.active == active) return list(db_session.scalars(query).all()) def disable_input_prompt_for_user( input_prompt_id: int, user_id: UUID, db_session: Session, ) -> None: """ Sets (or creates) a record in InputPrompt__User with disabled=True so that this prompt is hidden for the user.
""" ipu = ( db_session.query(InputPrompt__User) .filter_by(input_prompt_id=input_prompt_id, user_id=user_id) .first() ) if ipu is None: # Create a new association row ipu = InputPrompt__User( input_prompt_id=input_prompt_id, user_id=user_id, disabled=True ) db_session.add(ipu) else: # Just update the existing record ipu.disabled = True db_session.commit() ================================================ FILE: backend/onyx/db/kg_config.py ================================================ from onyx.configs.constants import KV_KG_CONFIG_KEY from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.kg.models import KGConfigSettings from onyx.server.kg.models import EnableKGConfigRequest from onyx.utils.logger import setup_logger logger = setup_logger() def set_kg_config_settings(kg_config_settings: KGConfigSettings) -> None: kv_store = get_kv_store() kv_store.store(KV_KG_CONFIG_KEY, kg_config_settings.model_dump()) def get_kg_config_settings() -> KGConfigSettings: kv_store = get_kv_store() try: # keep refresh_cache=True until the beta is over, since the config may be updated manually in the db stored_config = kv_store.load(KV_KG_CONFIG_KEY, refresh_cache=True) return KGConfigSettings.model_validate(stored_config or {}) except KvKeyNotFoundError: # Default to an empty kg config if no config has been set yet logger.debug(f"No kg config found in KV store for key: {KV_KG_CONFIG_KEY}") return KGConfigSettings() except Exception as e: logger.error(f"Error loading kg config from KV store: {str(e)}") return KGConfigSettings() def validate_kg_settings(kg_config_settings: KGConfigSettings) -> None: if not kg_config_settings.KG_ENABLED: raise ValueError("KG is not enabled") if not kg_config_settings.KG_VENDOR: raise ValueError("KG_VENDOR is not set") if not kg_config_settings.KG_VENDOR_DOMAINS: raise ValueError("KG_VENDOR_DOMAINS is not set") def is_kg_config_settings_enabled_valid(kg_config_settings: KGConfigSettings) -> bool:
try: validate_kg_settings(kg_config_settings) return True except Exception: return False def enable_kg(enable_req: EnableKGConfigRequest) -> None: kg_config_settings = get_kg_config_settings() kg_config_settings.KG_ENABLED = True kg_config_settings.KG_VENDOR = enable_req.vendor kg_config_settings.KG_VENDOR_DOMAINS = enable_req.vendor_domains kg_config_settings.KG_IGNORE_EMAIL_DOMAINS = enable_req.ignore_domains kg_config_settings.KG_COVERAGE_START = enable_req.coverage_start.strftime( "%Y-%m-%d" ) kg_config_settings.KG_MAX_COVERAGE_DAYS = 10000 # TODO: revisit after public beta validate_kg_settings(kg_config_settings) set_kg_config_settings(kg_config_settings) def disable_kg() -> None: kg_config_settings = get_kg_config_settings() kg_config_settings.KG_ENABLED = False set_kg_config_settings(kg_config_settings) ================================================ FILE: backend/onyx/db/kg_temp_view.py ================================================ # import random # from sqlalchemy import text # from sqlalchemy.ext.declarative import declarative_base # from sqlalchemy.orm import Session # from onyx.agents.agent_search.kb_search.models import KGViewNames # from onyx.configs.app_configs import DB_READONLY_USER # from onyx.configs.kg_configs import KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX # from onyx.configs.kg_configs import KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX # from onyx.configs.kg_configs import KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX # from onyx.db.engine.sql_engine import get_session_with_current_tenant # Base = declarative_base() # def get_user_view_names( # user_email: str, tenant_id: str # ) -> KGViewNames: # user_email_cleaned = ( # user_email.replace("@", "__") # .replace(".", "_") # .replace("+", "_") # ) # random_suffix_str = str( # random.randint(1000000, 9999999) # ) # return KGViewNames( # allowed_docs_view_name=( # f'"{tenant_id}".' 
# f"{KG_TEMP_ALLOWED_DOCS_VIEW_NAME_PREFIX}_" # f"{user_email_cleaned}_{random_suffix_str}" # ), # kg_relationships_view_name=( # f'"{tenant_id}".' # f"{KG_TEMP_KG_RELATIONSHIPS_VIEW_NAME_PREFIX}_" # f"{user_email_cleaned}_{random_suffix_str}" # ), # kg_entity_view_name=( # f'"{tenant_id}".' # f"{KG_TEMP_KG_ENTITIES_VIEW_NAME_PREFIX}_" # f"{user_email_cleaned}_{random_suffix_str}" # ), # ) # # First, create the view definition # def create_views( # db_session: Session, # tenant_id: str, # user_email: str, # allowed_docs_view_name: str, # kg_relationships_view_name: str, # kg_entity_view_name: str, # ) -> None: # # Create ALLOWED_DOCS view # allowed_docs_view = text( # f""" # CREATE OR REPLACE VIEW {allowed_docs_view_name} AS # WITH kg_used_docs AS ( # SELECT document_id as kg_used_doc_id # FROM "{tenant_id}".kg_entity d # WHERE document_id IS NOT NULL # ), # base_public_docs AS ( # SELECT d.id as allowed_doc_id # FROM "{tenant_id}".document d # INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id # WHERE d.is_public # ), # user_owned_and_public_docs AS ( # SELECT d.id as allowed_doc_id # FROM "{tenant_id}".document_by_connector_credential_pair d # JOIN "{tenant_id}".credential c ON d.credential_id = c.id # JOIN "{tenant_id}".connector_credential_pair ccp ON # d.connector_id = ccp.connector_id AND # d.credential_id = ccp.credential_id # JOIN "{tenant_id}".user u ON c.user_id = u.id # INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id # WHERE ccp.status != 'DELETING' # AND ccp.access_type != 'SYNC' # AND (u.email = :user_email or ccp.access_type::text = 'PUBLIC') # ), # user_group_accessible_docs AS ( # SELECT d.id as allowed_doc_id # FROM "{tenant_id}".document_by_connector_credential_pair d # JOIN "{tenant_id}".connector_credential_pair ccp ON # d.connector_id = ccp.connector_id AND # d.credential_id = ccp.credential_id # JOIN "{tenant_id}".user_group__connector_credential_pair ugccp ON # ccp.id = ugccp.cc_pair_id # JOIN "{tenant_id}".user__user_group uug 
ON # uug.user_group_id = ugccp.user_group_id # JOIN "{tenant_id}".user u ON uug.user_id = u.id # INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id # WHERE kud.kg_used_doc_id IS NOT NULL # AND ccp.status != 'DELETING' # AND ccp.access_type != 'SYNC' # AND u.email = :user_email # ), # external_user_docs AS ( # SELECT d.id as allowed_doc_id # FROM "{tenant_id}".document d # INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id # WHERE kud.kg_used_doc_id IS NOT NULL # AND :user_email = ANY(external_user_emails) # ), # external_group_docs AS ( # SELECT d.id as allowed_doc_id # FROM "{tenant_id}".document d # INNER JOIN kg_used_docs kud ON kud.kg_used_doc_id = d.id # JOIN "{tenant_id}".user__external_user_group_id ueg ON ueg.external_user_group_id = ANY(d.external_user_group_ids) # JOIN "{tenant_id}".user u ON ueg.user_id = u.id # WHERE kud.kg_used_doc_id IS NOT NULL # AND u.email = :user_email # ) # SELECT DISTINCT allowed_doc_id FROM ( # SELECT allowed_doc_id FROM base_public_docs # UNION # SELECT allowed_doc_id FROM user_owned_and_public_docs # UNION # SELECT allowed_doc_id FROM user_group_accessible_docs # UNION # SELECT allowed_doc_id FROM external_user_docs # UNION # SELECT allowed_doc_id FROM external_group_docs # ) combined_docs # """ # ).bindparams(user_email=user_email) # # Create the main view that uses ALLOWED_DOCS for Relationships # kg_relationships_view = text( # f""" # CREATE OR REPLACE VIEW {kg_relationships_view_name} AS # SELECT kgr.id_name as relationship, # kgr.source_node as source_entity, # kgr.target_node as target_entity, # kgr.source_node_type as source_entity_type, # kgr.target_node_type as target_entity_type, # kgr.type as relationship_description, # kgr.relationship_type_id_name as relationship_type, # kgr.source_document as source_document, # d.doc_updated_at as source_date, # se.attributes as source_entity_attributes, # te.attributes as target_entity_attributes # FROM "{tenant_id}".kg_relationship kgr # INNER JOIN 
{allowed_docs_view_name} AD on AD.allowed_doc_id = kgr.source_document # JOIN "{tenant_id}".document d on d.id = kgr.source_document # JOIN "{tenant_id}".kg_entity se on se.id_name = kgr.source_node # JOIN "{tenant_id}".kg_entity te on te.id_name = kgr.target_node # """ # ) # # Create the main view that uses ALLOWED_DOCS for Entities # kg_entity_view = text( # f""" # CREATE OR REPLACE VIEW {kg_entity_view_name} AS # SELECT kge.id_name as entity, # kge.entity_type_id_name as entity_type, # kge.attributes as entity_attributes, # kge.document_id as source_document, # d.doc_updated_at as source_date # FROM "{tenant_id}".kg_entity kge # INNER JOIN {allowed_docs_view_name} AD on AD.allowed_doc_id = kge.document_id # JOIN "{tenant_id}".document d on d.id = kge.document_id # """ # ) # # Execute the CREATE VIEW statements using the session # db_session.execute(allowed_docs_view) # db_session.execute(kg_relationships_view) # db_session.execute(kg_entity_view) # # Grant SELECT on the views to the read-only user # db_session.execute( # text(f"GRANT SELECT ON {kg_relationships_view_name} TO {DB_READONLY_USER}") # ) # db_session.execute( # text(f"GRANT SELECT ON {kg_entity_view_name} TO {DB_READONLY_USER}") # ) # db_session.commit() # return None # def drop_views( # allowed_docs_view_name: str | None = None, # kg_relationships_view_name: str | None = None, # kg_entity_view_name: str | None = None, # ) -> None: # """ # Drops the temporary views created by create_views.
# Args: # allowed_docs_view_name: Name of the allowed_docs view # kg_relationships_view_name: Name of the allowed kg_relationships view # kg_entity_view_name: Name of the allowed kg_entity view # """ # with get_session_with_current_tenant() as db_drop_session: # if kg_relationships_view_name: # revoke_kg_relationships = text( # f"REVOKE SELECT ON {kg_relationships_view_name} FROM {DB_READONLY_USER}" # ) # db_drop_session.execute(revoke_kg_relationships) # drop_kg_relationships = text( # f"DROP VIEW IF EXISTS {kg_relationships_view_name}" # ) # db_drop_session.execute(drop_kg_relationships) # if kg_entity_view_name: # revoke_kg_entities = text( # f"REVOKE SELECT ON {kg_entity_view_name} FROM {DB_READONLY_USER}" # ) # db_drop_session.execute(revoke_kg_entities) # drop_kg_entities = text(f"DROP VIEW IF EXISTS {kg_entity_view_name}") # db_drop_session.execute(drop_kg_entities) # if allowed_docs_view_name: # drop_allowed_docs = text(f"DROP VIEW IF EXISTS {allowed_docs_view_name}") # db_drop_session.execute(drop_allowed_docs) # db_drop_session.commit() # return None ================================================ FILE: backend/onyx/db/llm.py ================================================ from sqlalchemy import delete from sqlalchemy import select from sqlalchemy import update from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.db.enums import LLMModelFlowType from onyx.db.models import CloudEmbeddingProvider as CloudEmbeddingProviderModel from onyx.db.models import DocumentSet from onyx.db.models import ImageGenerationConfig from onyx.db.models import LLMModelFlow from onyx.db.models import LLMProvider as LLMProviderModel from onyx.db.models import LLMProvider__Persona from onyx.db.models import LLMProvider__UserGroup from onyx.db.models import ModelConfiguration from onyx.db.models import Persona from onyx.db.models import SearchSettings from
onyx.db.models import Tool as ToolModel from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.llm.utils import model_supports_image_input from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations from onyx.server.manage.embedding.models import CloudEmbeddingProvider from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest from onyx.server.manage.llm.models import LLMProviderUpsertRequest from onyx.server.manage.llm.models import LLMProviderView from onyx.server.manage.llm.models import SyncModelEntry from onyx.utils.logger import setup_logger from shared_configs.enums import EmbeddingProvider logger = setup_logger() def update_group_llm_provider_relationships__no_commit( llm_provider_id: int, group_ids: list[int] | None, db_session: Session, ) -> None: # Delete existing relationships db_session.query(LLMProvider__UserGroup).filter( LLMProvider__UserGroup.llm_provider_id == llm_provider_id ).delete(synchronize_session="fetch") # Add new relationships from given group_ids if group_ids: new_relationships = [ LLMProvider__UserGroup( llm_provider_id=llm_provider_id, user_group_id=group_id, ) for group_id in group_ids ] db_session.add_all(new_relationships) def update_llm_provider_persona_relationships__no_commit( db_session: Session, llm_provider_id: int, persona_ids: list[int] | None, ) -> None: """Replace the persona restrictions for a provider within an open transaction.""" db_session.execute( delete(LLMProvider__Persona).where( LLMProvider__Persona.llm_provider_id == llm_provider_id ) ) if persona_ids: db_session.add_all( LLMProvider__Persona( llm_provider_id=llm_provider_id, persona_id=persona_id, ) for persona_id in persona_ids ) def fetch_user_group_ids(db_session: Session, user: User) -> set[int]: """Fetch the set of user group IDs for a given user. Args: db_session: Database session user: User to fetch groups for Returns: Set of user group IDs. Empty set for anonymous users. 
""" if user.is_anonymous: return set() return set( db_session.scalars( select(User__UserGroup.user_group_id).where( User__UserGroup.user_id == user.id ) ).all() ) def can_user_access_llm_provider( provider: LLMProviderModel, user_group_ids: set[int], persona: Persona | None, is_admin: bool = False, ) -> bool: """Check if a user may use an LLM provider. Args: provider: The LLM provider to check access for user_group_ids: Set of user group IDs the user belongs to persona: The persona being used (if any) is_admin: If True, bypass user group restrictions but still respect persona restrictions Access logic: - is_public controls USER access (group bypass): when True, all users can access regardless of group membership. When False, user must be in a whitelisted group (or be admin). - Persona restrictions are ALWAYS enforced when set, regardless of is_public. This allows admins to make a provider available to all users while still restricting which personas (assistants) can use it. Decision matrix: 1. is_public=True, no personas set → everyone has access 2. is_public=True, personas set → all users, but only whitelisted personas 3. is_public=False, groups+personas set → must satisfy BOTH (admins bypass groups) 4. is_public=False, only groups set → must be in group (admins bypass) 5. is_public=False, only personas set → must use whitelisted persona 6. 
is_public=False, neither set → admin-only (locked) """ provider_group_ids = {g.id for g in (provider.groups or [])} provider_persona_ids = {p.id for p in (provider.personas or [])} has_groups = bool(provider_group_ids) has_personas = bool(provider_persona_ids) # Persona restrictions are always enforced when set, regardless of is_public if has_personas and not (persona and persona.id in provider_persona_ids): return False if provider.is_public: return True if has_groups: return is_admin or bool(user_group_ids & provider_group_ids) # No groups: either persona-whitelisted (already passed) or admin-only if locked return has_personas or is_admin def validate_persona_ids_exist( db_session: Session, persona_ids: list[int] ) -> tuple[set[int], list[int]]: """Validate that persona IDs exist in the database. Returns: Tuple of (fetched_persona_ids, missing_personas) """ fetched_persona_ids = set( db_session.scalars(select(Persona.id).where(Persona.id.in_(persona_ids))).all() ) missing_personas = sorted(set(persona_ids) - fetched_persona_ids) return fetched_persona_ids, missing_personas def get_personas_using_provider( db_session: Session, provider_name: str ) -> list[Persona]: """Get all non-deleted personas that use a specific LLM provider.""" return list( db_session.scalars( select(Persona).where( Persona.llm_model_provider_override == provider_name, Persona.deleted == False, # noqa: E712 ) ).all() ) def fetch_persona_with_groups(db_session: Session, persona_id: int) -> Persona | None: """Fetch a persona with its groups eagerly loaded.""" return db_session.scalar( select(Persona) .options(selectinload(Persona.groups)) .where(Persona.id == persona_id, Persona.deleted == False) # noqa: E712 ) def upsert_cloud_embedding_provider( db_session: Session, provider: CloudEmbeddingProviderCreationRequest ) -> CloudEmbeddingProvider: existing_provider = ( db_session.query(CloudEmbeddingProviderModel) .filter_by(provider_type=provider.provider_type) .first() ) if existing_provider: for 
key, value in provider.model_dump().items(): setattr(existing_provider, key, value) else: new_provider = CloudEmbeddingProviderModel(**provider.model_dump()) db_session.add(new_provider) existing_provider = new_provider db_session.commit() db_session.refresh(existing_provider) return CloudEmbeddingProvider.from_request(existing_provider) def upsert_llm_provider( llm_provider_upsert_request: LLMProviderUpsertRequest, db_session: Session, ) -> LLMProviderView: existing_llm_provider: LLMProviderModel | None = None if llm_provider_upsert_request.id: existing_llm_provider = fetch_existing_llm_provider_by_id( id=llm_provider_upsert_request.id, db_session=db_session ) if not existing_llm_provider: raise ValueError( f"LLM provider with id {llm_provider_upsert_request.id} not found" ) if existing_llm_provider.name != llm_provider_upsert_request.name: raise ValueError( f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed" ) else: existing_llm_provider = fetch_existing_llm_provider( name=llm_provider_upsert_request.name, db_session=db_session ) if existing_llm_provider: raise ValueError( f"LLM provider with name '{llm_provider_upsert_request.name}' already exists" ) existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name) db_session.add(existing_llm_provider) # Filter out empty strings and None values from custom_config to allow # providers like Bedrock to fall back to IAM roles when credentials are not provided custom_config = llm_provider_upsert_request.custom_config if custom_config: custom_config = { k: v for k, v in custom_config.items() if v is not None and v.strip() != "" } # Set to None if the dict is empty after filtering custom_config = custom_config or None api_base = llm_provider_upsert_request.api_base or None existing_llm_provider.provider = llm_provider_upsert_request.provider # EncryptedString accepts str for writes, returns SensitiveValue for reads existing_llm_provider.api_key = 
llm_provider_upsert_request.api_key # type: ignore[assignment] existing_llm_provider.api_base = api_base existing_llm_provider.api_version = llm_provider_upsert_request.api_version existing_llm_provider.custom_config = custom_config existing_llm_provider.is_public = llm_provider_upsert_request.is_public existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name if not existing_llm_provider.id: # If it's not already in the DB, flush to generate an ID db_session.flush() # Build a lookup of existing model configurations by name (single iteration) existing_by_name = { mc.name: mc for mc in existing_llm_provider.model_configurations } models_to_exist = { mc.name for mc in llm_provider_upsert_request.model_configurations } # Build a lookup of requested visibility by model name requested_visibility = { mc.name: mc.is_visible for mc in llm_provider_upsert_request.model_configurations } # Collect IDs of models no longer in the request (deleted below) removed_ids = [ mc.id for name, mc in existing_by_name.items() if name not in models_to_exist ] default_model = fetch_default_llm_model(db_session) # Prevent removing or hiding the default model if default_model: for name, mc in existing_by_name.items(): if mc.id == default_model.id: if default_model.id in removed_ids: raise ValueError( f"Cannot remove the default model '{name}'. Please change the default model before removing." ) if not requested_visibility.get(name, True): raise ValueError( f"Cannot hide the default model '{name}'. Please change the default model before hiding."
) break if removed_ids: db_session.query(ModelConfiguration).filter( ModelConfiguration.id.in_(removed_ids) ).delete(synchronize_session="fetch") db_session.flush() # Import here to avoid circular imports from onyx.llm.utils import get_max_input_tokens for model_config in llm_provider_upsert_request.model_configurations: max_input_tokens = model_config.max_input_tokens if max_input_tokens is None: max_input_tokens = get_max_input_tokens( model_name=model_config.name, model_provider=llm_provider_upsert_request.provider, ) supported_flows = [LLMModelFlowType.CHAT] if model_config.supports_image_input: supported_flows.append(LLMModelFlowType.VISION) existing = existing_by_name.get(model_config.name) if existing: update_model_configuration__no_commit( db_session=db_session, model_configuration_id=existing.id, supported_flows=supported_flows, is_visible=model_config.is_visible, max_input_tokens=max_input_tokens, display_name=model_config.display_name, ) else: insert_new_model_configuration__no_commit( db_session=db_session, llm_provider_id=existing_llm_provider.id, model_name=model_config.name, supported_flows=supported_flows, is_visible=model_config.is_visible, max_input_tokens=max_input_tokens, display_name=model_config.display_name, ) # Make sure the relationship table stays up to date update_group_llm_provider_relationships__no_commit( llm_provider_id=existing_llm_provider.id, group_ids=llm_provider_upsert_request.groups, db_session=db_session, ) update_llm_provider_persona_relationships__no_commit( db_session=db_session, llm_provider_id=existing_llm_provider.id, persona_ids=llm_provider_upsert_request.personas, ) db_session.flush() db_session.refresh(existing_llm_provider) try: db_session.commit() except Exception as e: db_session.rollback() raise ValueError(f"Failed to save LLM provider: {str(e)}") from e full_llm_provider = LLMProviderView.from_model(existing_llm_provider) return full_llm_provider def sync_model_configurations( db_session: Session, provider_name: 
str, models: list[SyncModelEntry], ) -> int: """Sync model configurations for a dynamic provider (OpenRouter, Bedrock, Ollama, etc.). This inserts NEW models from the source API without overwriting existing ones. User preferences (is_visible, max_input_tokens) are preserved for existing models. Args: db_session: Database session provider_name: Name of the LLM provider models: List of SyncModelEntry objects describing the fetched models Returns: Number of new models added """ provider = fetch_existing_llm_provider(name=provider_name, db_session=db_session) if not provider: raise ValueError(f"LLM Provider '{provider_name}' not found") # Get existing model names to count new additions existing_names = {mc.name for mc in provider.model_configurations} new_count = 0 for model in models: if model.name not in existing_names: # Insert new model with is_visible=False (user must explicitly enable) supported_flows = [LLMModelFlowType.CHAT] if model.supports_image_input: supported_flows.append(LLMModelFlowType.VISION) insert_new_model_configuration__no_commit( db_session=db_session, llm_provider_id=provider.id, model_name=model.name, supported_flows=supported_flows, is_visible=False, max_input_tokens=model.max_input_tokens, display_name=model.display_name, ) new_count += 1 if new_count > 0: db_session.commit() return new_count def fetch_existing_embedding_providers( db_session: Session, ) -> list[CloudEmbeddingProviderModel]: return list(db_session.scalars(select(CloudEmbeddingProviderModel)).all()) def fetch_existing_doc_sets( db_session: Session, doc_ids: list[int] ) -> list[DocumentSet]: return list( db_session.scalars(select(DocumentSet).where(DocumentSet.id.in_(doc_ids))).all() ) def fetch_existing_tools(db_session: Session, tool_ids: list[int]) -> list[ToolModel]: return list( db_session.scalars(select(ToolModel).where(ToolModel.id.in_(tool_ids))).all() ) def fetch_existing_models( db_session: Session, flow_types: list[LLMModelFlowType], ) -> list[ModelConfiguration]: 
models = ( select(ModelConfiguration) .join(LLMModelFlow) .where(LLMModelFlow.llm_model_flow_type.in_(flow_types)) .options( selectinload(ModelConfiguration.llm_provider), selectinload(ModelConfiguration.llm_model_flows), ) ) return list(db_session.scalars(models).all()) def fetch_existing_llm_providers( db_session: Session, flow_type_filter: list[LLMModelFlowType], only_public: bool = False, exclude_image_generation_providers: bool = True, ) -> list[LLMProviderModel]: """Fetch all LLM providers with optional filtering. Args: db_session: Database session flow_type_filter: List of flow types to filter by, empty list for no filter only_public: If True, only return public providers exclude_image_generation_providers: If True, exclude providers that are used for image generation configs """ stmt = select(LLMProviderModel) if flow_type_filter: providers_with_flows = ( select(ModelConfiguration.llm_provider_id) .join(LLMModelFlow) .where(LLMModelFlow.llm_model_flow_type.in_(flow_type_filter)) .distinct() ) stmt = stmt.where(LLMProviderModel.id.in_(providers_with_flows)) if exclude_image_generation_providers: image_gen_provider_ids = select(ModelConfiguration.llm_provider_id).join( ImageGenerationConfig ) stmt = stmt.where(~LLMProviderModel.id.in_(image_gen_provider_ids)) stmt = stmt.options( selectinload(LLMProviderModel.model_configurations), selectinload(LLMProviderModel.groups), selectinload(LLMProviderModel.personas), ) providers = list(db_session.scalars(stmt).all()) if only_public: return [provider for provider in providers if provider.is_public] return providers def fetch_existing_llm_provider( name: str, db_session: Session ) -> LLMProviderModel | None: provider_model = db_session.scalar( select(LLMProviderModel) .where(LLMProviderModel.name == name) .options( selectinload(LLMProviderModel.model_configurations), selectinload(LLMProviderModel.groups), selectinload(LLMProviderModel.personas), ) ) return provider_model def fetch_existing_llm_provider_by_id( id: int, 
db_session: Session ) -> LLMProviderModel | None: provider_model = db_session.scalar( select(LLMProviderModel) .where(LLMProviderModel.id == id) .options( selectinload(LLMProviderModel.model_configurations), selectinload(LLMProviderModel.groups), selectinload(LLMProviderModel.personas), ) ) return provider_model def fetch_embedding_provider( db_session: Session, provider_type: EmbeddingProvider ) -> CloudEmbeddingProviderModel | None: return db_session.scalar( select(CloudEmbeddingProviderModel).where( CloudEmbeddingProviderModel.provider_type == provider_type ) ) def fetch_default_llm_model(db_session: Session) -> ModelConfiguration | None: return fetch_default_model(db_session, LLMModelFlowType.CHAT) def fetch_default_vision_model(db_session: Session) -> ModelConfiguration | None: return fetch_default_model(db_session, LLMModelFlowType.VISION) def fetch_default_contextual_rag_model( db_session: Session, ) -> ModelConfiguration | None: return fetch_default_model(db_session, LLMModelFlowType.CONTEXTUAL_RAG) def fetch_default_model( db_session: Session, flow_type: LLMModelFlowType, ) -> ModelConfiguration | None: model_config = db_session.scalar( select(ModelConfiguration) .options(selectinload(ModelConfiguration.llm_provider)) .join(LLMModelFlow) .where( LLMModelFlow.llm_model_flow_type == flow_type, LLMModelFlow.is_default == True, # noqa: E712 ) ) return model_config def fetch_llm_provider_view( db_session: Session, provider_name: str ) -> LLMProviderView | None: provider_model = fetch_existing_llm_provider( name=provider_name, db_session=db_session ) if not provider_model: return None return LLMProviderView.from_model(provider_model) def remove_embedding_provider( db_session: Session, provider_type: EmbeddingProvider ) -> None: db_session.execute( delete(SearchSettings).where(SearchSettings.provider_type == provider_type) ) # Delete the embedding provider db_session.execute( delete(CloudEmbeddingProviderModel).where( CloudEmbeddingProviderModel.provider_type == 
provider_type ) ) db_session.commit() def remove_llm_provider(db_session: Session, provider_id: int) -> None: provider = db_session.get(LLMProviderModel, provider_id) if not provider: raise ValueError("LLM Provider not found") # Clear the provider override from any personas using it # This causes them to fall back to the default provider personas_using_provider = get_personas_using_provider(db_session, provider.name) for persona in personas_using_provider: persona.llm_model_provider_override = None db_session.execute( delete(LLMProvider__UserGroup).where( LLMProvider__UserGroup.llm_provider_id == provider_id ) ) # Remove LLMProvider db_session.execute( delete(LLMProviderModel).where(LLMProviderModel.id == provider_id) ) db_session.commit() def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> None: """Remove LLM provider.""" provider = db_session.get(LLMProviderModel, provider_id) if not provider: raise ValueError("LLM Provider not found") # Clear the provider override from any personas using it # This causes them to fall back to the default provider personas_using_provider = get_personas_using_provider(db_session, provider.name) for persona in personas_using_provider: persona.llm_model_provider_override = None db_session.execute( delete(LLMProvider__UserGroup).where( LLMProvider__UserGroup.llm_provider_id == provider_id ) ) # Remove LLMProvider db_session.execute( delete(LLMProviderModel).where(LLMProviderModel.id == provider_id) ) db_session.flush() def update_default_provider( provider_id: int, model_name: str, db_session: Session ) -> None: _update_default_model( db_session, provider_id, model_name, LLMModelFlowType.CHAT, ) def update_default_vision_provider( provider_id: int, vision_model: str, db_session: Session ) -> None: provider = db_session.scalar( select(LLMProviderModel).where( LLMProviderModel.id == provider_id, ) ) if provider is None: raise ValueError(f"LLM Provider with id={provider_id} does not exist") if not 
model_supports_image_input(vision_model, provider.provider): raise ValueError( f"Model '{vision_model}' for provider '{provider.provider}' does not support image input" ) _update_default_model( db_session=db_session, provider_id=provider_id, model=vision_model, flow_type=LLMModelFlowType.VISION, ) def update_no_default_contextual_rag_provider( db_session: Session, ) -> None: db_session.execute( update(LLMModelFlow) .where( LLMModelFlow.llm_model_flow_type == LLMModelFlowType.CONTEXTUAL_RAG, LLMModelFlow.is_default == True, # noqa: E712 ) .values(is_default=False) ) db_session.commit() def update_default_contextual_model( db_session: Session, enable_contextual_rag: bool, contextual_rag_llm_provider: str | None, contextual_rag_llm_name: str | None, ) -> None: """Sets or clears the default contextual RAG model. Should be called whenever the PRESENT search settings change (e.g. inline update or FUTURE → PRESENT swap). """ if ( not enable_contextual_rag or not contextual_rag_llm_name or not contextual_rag_llm_provider ): update_no_default_contextual_rag_provider(db_session=db_session) return provider = fetch_existing_llm_provider( name=contextual_rag_llm_provider, db_session=db_session ) if not provider: raise ValueError(f"Provider '{contextual_rag_llm_provider}' not found") model_config = next( ( mc for mc in provider.model_configurations if mc.name == contextual_rag_llm_name ), None, ) if not model_config: raise ValueError( f"Model '{contextual_rag_llm_name}' not found for provider '{contextual_rag_llm_provider}'" ) add_model_to_flow( db_session=db_session, model_configuration_id=model_config.id, flow_type=LLMModelFlowType.CONTEXTUAL_RAG, ) _update_default_model( db_session=db_session, provider_id=provider.id, model=contextual_rag_llm_name, flow_type=LLMModelFlowType.CONTEXTUAL_RAG, ) return def fetch_auto_mode_providers(db_session: Session) -> list[LLMProviderModel]: """Fetch all LLM providers that are in Auto mode.""" query = ( select(LLMProviderModel)
.where(LLMProviderModel.is_auto_mode.is_(True)) .options(selectinload(LLMProviderModel.model_configurations)) ) return list(db_session.scalars(query).all()) def sync_auto_mode_models( db_session: Session, provider: LLMProviderModel, llm_recommendations: LLMRecommendations, ) -> int: """Sync models from GitHub config to a provider in Auto mode. In Auto mode, the model list and default are controlled by GitHub config. The schema has: - default_model: The default model config (always visible) - additional_visible_models: List of additional visible models The admin only provides API credentials. Args: db_session: Database session provider: LLM provider in Auto mode llm_recommendations: Recommended models loaded from the GitHub config Returns: The number of changes made. """ changes = 0 # Build the list of all visible models from the config # All models in the config are visible (default + additional_visible_models) recommended_visible_models = llm_recommendations.get_visible_models( provider.provider ) recommended_visible_model_names = [ model.name for model in recommended_visible_models ] # Get existing models existing_models: dict[str, ModelConfiguration] = { mc.name: mc for mc in db_session.scalars( select(ModelConfiguration).where( ModelConfiguration.llm_provider_id == provider.id ) ).all() } # Mark models that are no longer in GitHub config as not visible for model_name, model in existing_models.items(): if model_name not in recommended_visible_model_names: if model.is_visible: model.is_visible = False changes += 1 # Add or update models from GitHub config for model_config in recommended_visible_models: if model_config.name in existing_models: # Update existing model existing = existing_models[model_config.name] # Check each field for changes updated = False if existing.display_name != model_config.display_name: existing.display_name = model_config.display_name updated = True # All models in the config are visible if not existing.is_visible: existing.is_visible = True updated = True if updated:
changes += 1 else: # Add new model - all models from GitHub config are visible insert_new_model_configuration__no_commit( db_session=db_session, llm_provider_id=provider.id, model_name=model_config.name, supported_flows=[LLMModelFlowType.CHAT], is_visible=True, max_input_tokens=None, display_name=model_config.display_name, ) changes += 1 # Update the default if this provider currently holds the global CHAT default. # We flush (but don't commit) so that _update_default_model can see the new # model rows, then commit everything atomically to avoid a window where the # old default is invisible but still pointed-to. db_session.flush() recommended_default = llm_recommendations.get_default_model(provider.provider) if recommended_default: current_default = fetch_default_llm_model(db_session) if ( current_default and current_default.llm_provider_id == provider.id and current_default.name != recommended_default.name ): _update_default_model__no_commit( db_session=db_session, provider_id=provider.id, model=recommended_default.name, flow_type=LLMModelFlowType.CHAT, ) changes += 1 db_session.commit() return changes def create_new_flow_mapping__no_commit( db_session: Session, model_configuration_id: int, flow_type: LLMModelFlowType, ) -> LLMModelFlow: result = db_session.execute( insert(LLMModelFlow) .values( model_configuration_id=model_configuration_id, llm_model_flow_type=flow_type, is_default=False, ) .on_conflict_do_nothing() .returning(LLMModelFlow) ) flow = result.scalar() if not flow: # Row already exists — fetch it flow = db_session.scalar( select(LLMModelFlow).where( LLMModelFlow.model_configuration_id == model_configuration_id, LLMModelFlow.llm_model_flow_type == flow_type, ) ) if not flow: raise ValueError( f"Failed to create or find flow mapping for model_configuration_id={model_configuration_id} and flow_type={flow_type}" ) return flow def insert_new_model_configuration__no_commit( db_session: Session, llm_provider_id: int, model_name: str, supported_flows: 
list[LLMModelFlowType], is_visible: bool, max_input_tokens: int | None, display_name: str | None, ) -> int | None: result = db_session.execute( insert(ModelConfiguration) .values( llm_provider_id=llm_provider_id, name=model_name, is_visible=is_visible, max_input_tokens=max_input_tokens, display_name=display_name, supports_image_input=LLMModelFlowType.VISION in supported_flows, ) .on_conflict_do_nothing() .returning(ModelConfiguration.id) ) model_config_id = result.scalar() if not model_config_id: return None for flow_type in supported_flows: create_new_flow_mapping__no_commit( db_session=db_session, model_configuration_id=model_config_id, flow_type=flow_type, ) return model_config_id def update_model_configuration__no_commit( db_session: Session, model_configuration_id: int, supported_flows: list[LLMModelFlowType], is_visible: bool, max_input_tokens: int | None, display_name: str | None, ) -> None: result = db_session.execute( update(ModelConfiguration) .values( is_visible=is_visible, max_input_tokens=max_input_tokens, display_name=display_name, supports_image_input=LLMModelFlowType.VISION in supported_flows, ) .where(ModelConfiguration.id == model_configuration_id) .returning(ModelConfiguration) ) model_configuration = result.scalar() if not model_configuration: raise ValueError( f"Failed to update model configuration with id={model_configuration_id}" ) new_flows = { flow_type for flow_type in supported_flows if flow_type not in model_configuration.llm_model_flow_types } removed_flows = { flow_type for flow_type in model_configuration.llm_model_flow_types if flow_type not in supported_flows } for flow_type in new_flows: create_new_flow_mapping__no_commit( db_session=db_session, model_configuration_id=model_configuration_id, flow_type=flow_type, ) for flow_type in removed_flows: db_session.execute( delete(LLMModelFlow).where( LLMModelFlow.model_configuration_id == model_configuration_id, LLMModelFlow.llm_model_flow_type == flow_type, ) ) db_session.flush() def 
_update_default_model__no_commit( db_session: Session, provider_id: int, model: str, flow_type: LLMModelFlowType, ) -> None: result = db_session.execute( select(ModelConfiguration, LLMModelFlow) .join( LLMModelFlow, LLMModelFlow.model_configuration_id == ModelConfiguration.id ) .where( ModelConfiguration.llm_provider_id == provider_id, ModelConfiguration.name == model, LLMModelFlow.llm_model_flow_type == flow_type, ) ).first() if not result: raise ValueError( f"Model '{model}' is not a valid model for provider_id={provider_id}" ) model_config, new_default = result # Clear existing default and set in an atomic operation db_session.execute( update(LLMModelFlow) .where( LLMModelFlow.llm_model_flow_type == flow_type, LLMModelFlow.is_default == True, # noqa: E712 ) .values(is_default=False) ) new_default.is_default = True model_config.is_visible = True def _update_default_model( db_session: Session, provider_id: int, model: str, flow_type: LLMModelFlowType, ) -> None: _update_default_model__no_commit(db_session, provider_id, model, flow_type) db_session.commit() def add_model_to_flow( db_session: Session, model_configuration_id: int, flow_type: LLMModelFlowType, ) -> None: # Function does nothing on conflict create_new_flow_mapping__no_commit( db_session=db_session, model_configuration_id=model_configuration_id, flow_type=flow_type, ) db_session.commit() ================================================ FILE: backend/onyx/db/mcp.py ================================================ import datetime from typing import cast from uuid import UUID from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.orm import Session from sqlalchemy.orm.attributes import flag_modified from onyx.db.enums import MCPAuthenticationPerformer from onyx.db.enums import MCPServerStatus from onyx.db.enums import MCPTransport from onyx.db.models import MCPAuthenticationType from onyx.db.models import MCPConnectionConfig from onyx.db.models import 
MCPServer from onyx.db.models import Persona from onyx.db.models import Tool from onyx.db.models import User from onyx.server.features.mcp.models import MCPConnectionData from onyx.utils.logger import setup_logger from onyx.utils.sensitive import SensitiveValue logger = setup_logger() # MCPServer operations def get_all_mcp_servers(db_session: Session) -> list[MCPServer]: """Get all MCP servers""" return list( db_session.scalars(select(MCPServer).order_by(MCPServer.created_at)).all() ) def get_mcp_server_by_id(server_id: int, db_session: Session) -> MCPServer: """Get MCP server by ID""" server = db_session.scalar(select(MCPServer).where(MCPServer.id == server_id)) if not server: raise ValueError("MCP server by specified id does not exist") return server def get_mcp_servers_by_owner(owner_email: str, db_session: Session) -> list[MCPServer]: """Get all MCP servers owned by a specific user""" return list( db_session.scalars( select(MCPServer).where(MCPServer.owner == owner_email) ).all() ) def get_mcp_servers_for_persona( persona_id: int, db_session: Session, user: User, # noqa: ARG001 ) -> list[MCPServer]: """Get all MCP servers associated with a persona via its tools""" # Get the persona and its tools persona = db_session.query(Persona).filter(Persona.id == persona_id).first() if not persona: return [] # Collect unique MCP server IDs from the persona's tools mcp_server_ids = set() for tool in persona.tools: if tool.mcp_server_id: mcp_server_ids.add(tool.mcp_server_id) if not mcp_server_ids: return [] # Fetch the MCP servers mcp_servers = ( db_session.query(MCPServer).filter(MCPServer.id.in_(mcp_server_ids)).all() ) return list(mcp_servers) def get_mcp_servers_accessible_to_user( user_id: UUID, db_session: Session ) -> list[MCPServer]: """Get all MCP servers accessible to a user (directly or through groups)""" user = db_session.scalar(select(User).where(User.id == user_id)) # type: ignore if not user: return [] user = cast(User, user) # Get servers accessible directly 
to user user_servers = list(user.accessible_mcp_servers) # TODO: Add group-based access once relationships are fully implemented # For now, just return direct user access return user_servers def create_mcp_server__no_commit( owner_email: str, name: str, description: str | None, server_url: str, auth_type: MCPAuthenticationType | None, transport: MCPTransport | None, auth_performer: MCPAuthenticationPerformer | None, db_session: Session, admin_connection_config_id: int | None = None, ) -> MCPServer: """Create a new MCP server""" new_server = MCPServer( owner=owner_email, name=name, description=description, server_url=server_url, transport=transport, auth_type=auth_type, auth_performer=auth_performer, admin_connection_config_id=admin_connection_config_id, ) db_session.add(new_server) db_session.flush() # Get the ID without committing return new_server def update_mcp_server__no_commit( server_id: int, db_session: Session, name: str | None = None, description: str | None = None, server_url: str | None = None, auth_type: MCPAuthenticationType | None = None, admin_connection_config_id: int | None = None, auth_performer: MCPAuthenticationPerformer | None = None, transport: MCPTransport | None = None, status: MCPServerStatus | None = None, last_refreshed_at: datetime.datetime | None = None, ) -> MCPServer: """Update an existing MCP server""" server = get_mcp_server_by_id(server_id, db_session) if name is not None: server.name = name if description is not None: server.description = description if server_url is not None: server.server_url = server_url if auth_type is not None: server.auth_type = auth_type if admin_connection_config_id is not None: server.admin_connection_config_id = admin_connection_config_id if auth_performer is not None: server.auth_performer = auth_performer if transport is not None: server.transport = transport if status is not None: server.status = status if last_refreshed_at is not None: server.last_refreshed_at = last_refreshed_at db_session.flush() # 
Don't commit yet, let caller decide when to commit return server def delete_mcp_server(server_id: int, db_session: Session) -> None: """Delete an MCP server and all associated tools (via CASCADE)""" server = get_mcp_server_by_id(server_id, db_session) # Count tools that will be deleted tools_count = db_session.query(Tool).filter(Tool.mcp_server_id == server_id).count() logger.info(f"Deleting MCP server {server_id} with {tools_count} associated tools") db_session.delete(server) db_session.commit() logger.info(f"Successfully deleted MCP server {server_id} and its tools") def get_all_mcp_tools_for_server(server_id: int, db_session: Session) -> list[Tool]: """Get all MCP tools for a server""" return list( db_session.scalars(select(Tool).where(Tool.mcp_server_id == server_id)).all() ) def add_user_to_mcp_server(server_id: int, user_id: UUID, db_session: Session) -> None: """Grant a user access to an MCP server""" server = get_mcp_server_by_id(server_id, db_session) user = db_session.scalar(select(User).where(User.id == user_id)) # type: ignore if not user: raise ValueError("User not found") if user not in server.users: server.users.append(user) db_session.commit() def remove_user_from_mcp_server( server_id: int, user_id: UUID, db_session: Session ) -> None: """Remove a user's access to an MCP server""" server = get_mcp_server_by_id(server_id, db_session) user = db_session.scalar(select(User).where(User.id == user_id)) # type: ignore if not user: raise ValueError("User not found") if user in server.users: server.users.remove(user) db_session.commit() # MCPConnectionConfig operations def extract_connection_data( config: MCPConnectionConfig | None, apply_mask: bool = False ) -> MCPConnectionData: """Extract MCPConnectionData from a connection config, with proper typing. This helper encapsulates the cast from the JSON column's dict[str, Any] to the typed MCPConnectionData structure. 
""" if config is None or config.config is None: return MCPConnectionData(headers={}) if isinstance(config.config, SensitiveValue): return cast(MCPConnectionData, config.config.get_value(apply_mask=apply_mask)) return cast(MCPConnectionData, config.config) def get_connection_config_by_id( config_id: int, db_session: Session ) -> MCPConnectionConfig: """Get connection config by ID""" config = db_session.scalar( select(MCPConnectionConfig).where(MCPConnectionConfig.id == config_id) ) if not config: raise ValueError("Connection config by specified id does not exist") return config def get_user_connection_config( server_id: int, user_email: str, db_session: Session ) -> MCPConnectionConfig | None: """Get a user's connection config for a specific MCP server""" return db_session.scalar( select(MCPConnectionConfig).where( and_( MCPConnectionConfig.mcp_server_id == server_id, MCPConnectionConfig.user_email == user_email, ) ) ) def get_user_connection_configs_for_server( server_id: int, db_session: Session ) -> list[MCPConnectionConfig]: """Get all user connection configs for a specific MCP server""" return list( db_session.scalars( select(MCPConnectionConfig).where( MCPConnectionConfig.mcp_server_id == server_id ) ).all() ) def create_connection_config( config_data: MCPConnectionData, db_session: Session, mcp_server_id: int | None = None, user_email: str = "", ) -> MCPConnectionConfig: """Create a new connection config""" new_config = MCPConnectionConfig( mcp_server_id=mcp_server_id, user_email=user_email, config=config_data, ) db_session.add(new_config) db_session.flush() # Don't commit yet, let caller decide when to commit return new_config def update_connection_config( config_id: int, db_session: Session, config_data: MCPConnectionData | None = None, ) -> MCPConnectionConfig: """Update an existing connection config""" config = get_connection_config_by_id(config_id, db_session) if config_data is not None: config.config = config_data # type: ignore[assignment] # Force 
SQLAlchemy to detect the change by marking the field as modified flag_modified(config, "config") db_session.commit() return config def upsert_user_connection_config( server_id: int, user_email: str, config_data: MCPConnectionData, db_session: Session, ) -> MCPConnectionConfig: """Create or update a user's connection config for an MCP server""" existing_config = get_user_connection_config(server_id, user_email, db_session) if existing_config: existing_config.config = config_data # type: ignore[assignment] db_session.flush() # Don't commit yet, let caller decide when to commit return existing_config else: return create_connection_config( config_data=config_data, mcp_server_id=server_id, user_email=user_email, db_session=db_session, ) # TODO: do this in one db call def get_server_auth_template( server_id: int, db_session: Session ) -> MCPConnectionConfig | None: """Get the authentication template for a server (from the admin connection config)""" server = get_mcp_server_by_id(server_id, db_session) if not server.admin_connection_config_id: return None if server.auth_performer == MCPAuthenticationPerformer.ADMIN: return None # admin server implies no template return server.admin_connection_config def delete_connection_config(config_id: int, db_session: Session) -> None: """Delete a connection config""" config = get_connection_config_by_id(config_id, db_session) db_session.delete(config) db_session.flush() # Don't commit yet, let caller decide when to commit def delete_user_connection_configs_for_server( server_id: int, user_email: str, db_session: Session ) -> None: """Delete all connection configs for a user on a specific server""" configs = db_session.scalars( select(MCPConnectionConfig).where( and_( MCPConnectionConfig.mcp_server_id == server_id, MCPConnectionConfig.user_email == user_email, ) ) ).all() for config in configs: db_session.delete(config) db_session.commit() def delete_all_user_connection_configs_for_server_no_commit( server_id: int, db_session: Session 
) -> None: """Delete all user connection configs for a specific MCP server""" db_session.execute( delete(MCPConnectionConfig).where( MCPConnectionConfig.mcp_server_id == server_id ) ) db_session.flush() # Don't commit yet, let caller decide when to commit ================================================ FILE: backend/onyx/db/memory.py ================================================ from uuid import UUID from pydantic import BaseModel from pydantic import ConfigDict from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.models import Memory from onyx.db.models import User MAX_MEMORIES_PER_USER = 10 class UserInfo(BaseModel): name: str | None = None role: str | None = None email: str | None = None def to_dict(self) -> dict: return { "name": self.name, "role": self.role, "email": self.email, } class UserMemoryContext(BaseModel): model_config = ConfigDict(frozen=True) user_id: UUID | None = None user_info: UserInfo user_preferences: str | None = None memories: tuple[str, ...] 
= () def without_memories(self) -> "UserMemoryContext": """Return a copy with memories cleared but user info/preferences intact.""" return UserMemoryContext( user_id=self.user_id, user_info=self.user_info, user_preferences=self.user_preferences, memories=(), ) def as_formatted_list(self) -> list[str]: """Returns combined list of user info, preferences, and memories.""" result = [] if self.user_info.name: result.append(f"User's name: {self.user_info.name}") if self.user_info.role: result.append(f"User's role: {self.user_info.role}") if self.user_info.email: result.append(f"User's email: {self.user_info.email}") if self.user_preferences: result.append(f"User preferences: {self.user_preferences}") result.extend(self.memories) return result def get_memories(user: User, db_session: Session) -> UserMemoryContext: user_info = UserInfo( name=user.personal_name, role=user.personal_role, email=user.email, ) user_preferences = None if user.user_preferences: user_preferences = user.user_preferences memory_rows = db_session.scalars( select(Memory).where(Memory.user_id == user.id).order_by(Memory.id.asc()) ).all() memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text) return UserMemoryContext( user_id=user.id, user_info=user_info, user_preferences=user_preferences, memories=memories, ) def add_memory( user_id: UUID, memory_text: str, db_session: Session, ) -> Memory: """Insert a new Memory row for the given user. If the user already has MAX_MEMORIES_PER_USER memories, the oldest one (lowest id) is deleted before inserting the new one. 
""" existing = db_session.scalars( select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc()) ).all() if len(existing) >= MAX_MEMORIES_PER_USER: db_session.delete(existing[0]) memory = Memory( user_id=user_id, memory_text=memory_text, ) db_session.add(memory) db_session.commit() return memory def update_memory_at_index( user_id: UUID, index: int, new_text: str, db_session: Session, ) -> Memory | None: """Update the memory at the given 0-based index (ordered by id ASC, matching get_memories()). Returns the updated Memory row, or None if the index is out of range. """ memory_rows = db_session.scalars( select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.asc()) ).all() if index < 0 or index >= len(memory_rows): return None memory = memory_rows[index] memory.memory_text = new_text db_session.commit() return memory ================================================ FILE: backend/onyx/db/models.py ================================================ import datetime import json from typing import Any from typing import Literal from typing import NotRequired from uuid import uuid4 from pydantic import BaseModel from sqlalchemy.orm import validates from typing_extensions import TypedDict # noreorder from uuid import UUID from pydantic import ValidationError from sqlalchemy.dialects.postgresql import JSONB as PGJSONB from sqlalchemy.dialects.postgresql import UUID as PGUUID from fastapi_users_db_sqlalchemy import SQLAlchemyBaseOAuthAccountTableUUID from fastapi_users_db_sqlalchemy import SQLAlchemyBaseUserTableUUID from fastapi_users_db_sqlalchemy.access_token import SQLAlchemyBaseAccessTokenTableUUID from fastapi_users_db_sqlalchemy.generics import TIMESTAMPAware from sqlalchemy import Boolean from sqlalchemy import DateTime from sqlalchemy import desc from sqlalchemy import Enum from sqlalchemy import Float from sqlalchemy import ForeignKey from sqlalchemy import ForeignKeyConstraint from sqlalchemy import func from sqlalchemy import Index from 
sqlalchemy import Integer from sqlalchemy import BigInteger from sqlalchemy import Sequence from sqlalchemy import String from sqlalchemy import Text from sqlalchemy import text from sqlalchemy import UniqueConstraint from sqlalchemy.dialects import postgresql from sqlalchemy import event from sqlalchemy.engine.interfaces import Dialect from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import Mapped from sqlalchemy.orm import Mapper from sqlalchemy.orm import mapped_column from sqlalchemy.orm import relationship from sqlalchemy.types import LargeBinary from sqlalchemy.types import TypeDecorator from sqlalchemy import PrimaryKeyConstraint from onyx.db.enums import AccountType from onyx.auth.schemas import UserRole from onyx.configs.constants import ( ANONYMOUS_USER_UUID, DEFAULT_BOOST, FederatedConnectorSource, MilestoneRecordType, ) from onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.configs.constants import MessageType from onyx.db.enums import ( AccessType, ArtifactType, BuildSessionStatus, EmbeddingPrecision, HierarchyNodeType, HookFailStrategy, HookPoint, IndexingMode, OpenSearchDocumentMigrationStatus, OpenSearchTenantMigrationStatus, ProcessingMode, SandboxStatus, SyncType, SyncStatus, MCPAuthenticationType, UserFileStatus, MCPAuthenticationPerformer, MCPTransport, MCPServerStatus, Permission, GrantSource, LLMModelFlowType, ThemePreference, DefaultAppMode, SwitchoverType, SharingScope, ) from onyx.configs.constants import NotificationType from onyx.configs.constants import SearchFeedbackType from onyx.configs.constants import TokenRateLimitScope from onyx.connectors.models import InputType from onyx.db.enums import ChatSessionSharedStatus from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import IndexModelStatus from onyx.db.enums import PermissionSyncStatus from onyx.db.enums import TaskStatus from onyx.db.pydantic_type import 
PydanticListType, PydanticType from onyx.kg.models import KGEntityTypeAttributes from onyx.utils.logger import setup_logger from onyx.utils.special_types import JSON_ro from onyx.file_store.models import FileDescriptor from onyx.llm.override_models import LLMOverride from onyx.llm.override_models import PromptOverride from onyx.kg.models import KGStage from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig from onyx.utils.encryption import decrypt_bytes_to_string from onyx.utils.encryption import encrypt_string_to_bytes from onyx.utils.sensitive import SensitiveValue from onyx.utils.headers import HeaderItemDict from shared_configs.enums import EmbeddingProvider # TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL # and update Mapped[User | None] relationships to Mapped[User] where needed. logger = setup_logger() PROMPT_LENGTH = 5_000_000 class Base(DeclarativeBase): __abstract__ = True class _EncryptedBase(TypeDecorator): """Base for encrypted column types that wrap values in SensitiveValue.""" impl = LargeBinary cache_ok = True _is_json: bool = False def wrap_raw(self, value: Any) -> SensitiveValue: """Encrypt a raw value and wrap it in SensitiveValue. Called by the attribute set event so the Python-side type is always SensitiveValue, regardless of whether the value was loaded from the DB or assigned in application code. 
""" if self._is_json: if not isinstance(value, dict): raise TypeError( f"EncryptedJson column expected dict, got {type(value).__name__}" ) raw_str = json.dumps(value) else: if not isinstance(value, str): raise TypeError( f"EncryptedString column expected str, got {type(value).__name__}" ) raw_str = value return SensitiveValue( encrypted_bytes=encrypt_string_to_bytes(raw_str), decrypt_fn=decrypt_bytes_to_string, is_json=self._is_json, ) def compare_values(self, x: Any, y: Any) -> bool: if x is None or y is None: return x == y if isinstance(x, SensitiveValue): x = x.get_value(apply_mask=False) if isinstance(y, SensitiveValue): y = y.get_value(apply_mask=False) return x == y class EncryptedString(_EncryptedBase): # Must redeclare cache_ok in this child class since we explicitly redeclare _is_json cache_ok = True _is_json: bool = False def process_bind_param( self, value: str | SensitiveValue[str] | None, dialect: Dialect, # noqa: ARG002 ) -> bytes | None: if value is not None: # Handle both raw strings and SensitiveValue wrappers if isinstance(value, SensitiveValue): # Get raw value for storage value = value.get_value(apply_mask=False) return encrypt_string_to_bytes(value) return value def process_result_value( self, value: bytes | None, dialect: Dialect, # noqa: ARG002 ) -> SensitiveValue[str] | None: if value is not None: return SensitiveValue( encrypted_bytes=value, decrypt_fn=decrypt_bytes_to_string, is_json=False, ) return None class EncryptedJson(_EncryptedBase): cache_ok = True _is_json: bool = True def process_bind_param( self, value: dict[str, Any] | SensitiveValue[dict[str, Any]] | None, dialect: Dialect, # noqa: ARG002 ) -> bytes | None: if value is not None: if isinstance(value, SensitiveValue): value = value.get_value(apply_mask=False) json_str = json.dumps(value) return encrypt_string_to_bytes(json_str) return value def process_result_value( self, value: bytes | None, dialect: Dialect, # noqa: ARG002 ) -> SensitiveValue[dict[str, Any]] | None: if value 
is not None: return SensitiveValue( encrypted_bytes=value, decrypt_fn=decrypt_bytes_to_string, is_json=True, ) return None _REGISTERED_ATTRS: set[str] = set() @event.listens_for(Mapper, "mapper_configured") def _register_sensitive_value_set_events( mapper: Mapper, class_: type, ) -> None: """Auto-wrap raw values in SensitiveValue when assigned to encrypted columns.""" for prop in mapper.column_attrs: for col in prop.columns: if isinstance(col.type, _EncryptedBase): col_type = col.type attr = getattr(class_, prop.key) # Guard against double-registration (e.g. if mapper is # re-configured in test setups) attr_key = f"{class_.__qualname__}.{prop.key}" if attr_key in _REGISTERED_ATTRS: continue _REGISTERED_ATTRS.add(attr_key) @event.listens_for(attr, "set", retval=True) def _wrap_value( target: Any, # noqa: ARG001 value: Any, oldvalue: Any, # noqa: ARG001 initiator: Any, # noqa: ARG001 _col_type: _EncryptedBase = col_type, ) -> Any: if value is not None and not isinstance(value, SensitiveValue): return _col_type.wrap_raw(value) return value class NullFilteredString(TypeDecorator): impl = String # This type's behavior is fully deterministic and doesn't depend on any external factors. 
cache_ok = True def process_bind_param( self, value: str | None, dialect: Dialect, # noqa: ARG002 ) -> str | None: if value is not None and "\x00" in value: logger.warning(f"NUL characters found in value: {value}") return value.replace("\x00", "") return value def process_result_value( self, value: str | None, dialect: Dialect, # noqa: ARG002 ) -> str | None: return value """ Auth/Authz (users, permissions, access) Tables """ class OAuthAccount(SQLAlchemyBaseOAuthAccountTableUUID, Base): # even an almost empty token from keycloak will not fit the default 1024 bytes access_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore refresh_token: Mapped[str] = mapped_column(Text, nullable=False) # type: ignore class User(SQLAlchemyBaseUserTableUUID, Base): oauth_accounts: Mapped[list[OAuthAccount]] = relationship( "OAuthAccount", lazy="joined", cascade="all, delete-orphan" ) role: Mapped[UserRole] = mapped_column( Enum(UserRole, native_enum=False, default=UserRole.BASIC) ) account_type: Mapped[AccountType] = mapped_column( Enum(AccountType, native_enum=False), nullable=False, default=AccountType.STANDARD, server_default="STANDARD", ) """ Preferences should probably live in a separate table at some point, but for now they are stored here for simplicity """ temperature_override_enabled: Mapped[bool | None] = mapped_column( Boolean, default=None ) auto_scroll: Mapped[bool | None] = mapped_column(Boolean, default=None) shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=False) theme_preference: Mapped[ThemePreference | None] = mapped_column( Enum(ThemePreference, native_enum=False), nullable=True, default=None, ) chat_background: Mapped[str | None] = mapped_column(String, nullable=True) default_app_mode: Mapped[DefaultAppMode] = mapped_column( Enum(DefaultAppMode, native_enum=False), nullable=False, default=DefaultAppMode.CHAT, ) # personalization fields are exposed via the chat user settings "Personalization" tab personal_name: Mapped[str | None] =
mapped_column(String, nullable=True) personal_role: Mapped[str | None] = mapped_column(String, nullable=True) use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) enable_memory_tool: Mapped[bool] = mapped_column( Boolean, nullable=False, default=True ) user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True) chosen_assistants: Mapped[list[int] | None] = mapped_column( postgresql.JSONB(), nullable=True, default=None ) visible_assistants: Mapped[list[int]] = mapped_column( postgresql.JSONB(), nullable=False, default=[] ) hidden_assistants: Mapped[list[int]] = mapped_column( postgresql.JSONB(), nullable=False, default=[] ) pinned_assistants: Mapped[list[int] | None] = mapped_column( postgresql.JSONB(), nullable=True, default=None ) effective_permissions: Mapped[list[str]] = mapped_column( postgresql.JSONB(), nullable=False, default=list, server_default=text("'[]'::jsonb"), ) oidc_expiry: Mapped[datetime.datetime] = mapped_column( TIMESTAMPAware(timezone=True), nullable=True ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) default_model: Mapped[str] = mapped_column(Text, nullable=True) # organized in typical structured fashion # formatted as `displayName__provider__modelName` # Voice preferences voice_auto_send: Mapped[bool] = mapped_column(Boolean, default=False) voice_auto_playback: Mapped[bool] = mapped_column(Boolean, default=False) voice_playback_speed: Mapped[float] = mapped_column(Float, default=1.0) # relationships credentials: Mapped[list["Credential"]] = relationship( "Credential", back_populates="user" ) chat_sessions: Mapped[list["ChatSession"]] = relationship( "ChatSession", back_populates="user" ) input_prompts: Mapped[list["InputPrompt"]] = relationship( "InputPrompt", back_populates="user" 
) # Personas owned by this user personas: Mapped[list["Persona"]] = relationship("Persona", back_populates="user") # Custom tools created by this user custom_tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="user") # Notifications for the UI notifications: Mapped[list["Notification"]] = relationship( "Notification", back_populates="user" ) cc_pairs: Mapped[list["ConnectorCredentialPair"]] = relationship( "ConnectorCredentialPair", back_populates="creator", primaryjoin="User.id == foreign(ConnectorCredentialPair.creator_id)", ) projects: Mapped[list["UserProject"]] = relationship( "UserProject", back_populates="user" ) files: Mapped[list["UserFile"]] = relationship("UserFile", back_populates="user") # MCP servers accessible to this user accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship( "MCPServer", secondary="mcp_server__user", back_populates="users" ) memories: Mapped[list["Memory"]] = relationship( "Memory", back_populates="user", cascade="all, delete-orphan", order_by="desc(Memory.id)", ) oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship( "OAuthUserToken", back_populates="user", cascade="all, delete-orphan", ) @validates("email") def validate_email(self, key: str, value: str) -> str: # noqa: ARG002 return value.lower() if value else value @property def password_configured(self) -> bool: """ Returns True if the user has no linked OAuth (or OIDC) accounts, in which case a password login is assumed to be configured.
""" return not bool(self.oauth_accounts) @property def is_anonymous(self) -> bool: """Returns True if this is the anonymous user.""" return str(self.id) == ANONYMOUS_USER_UUID class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base): pass class Memory(Base): __tablename__ = "memory" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) user_id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) memory_text: Mapped[str] = mapped_column(Text, nullable=False) conversation_id: Mapped[UUID | None] = mapped_column( PGUUID(as_uuid=True), nullable=True ) message_id: Mapped[int | None] = mapped_column(Integer, nullable=True) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) user: Mapped["User"] = relationship("User", back_populates="memories") class ApiKey(Base): __tablename__ = "api_key" id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str | None] = mapped_column(String, nullable=True) hashed_api_key: Mapped[str] = mapped_column(String, unique=True) api_key_display: Mapped[str] = mapped_column(String, unique=True) # the ID of the "user" who represents the access credentials for the API key user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), nullable=False) # the ID of the user who owns the key owner_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) # Add this relationship to access the User object via user_id user: Mapped["User"] = relationship("User", foreign_keys=[user_id]) class PersonalAccessToken(Base): __tablename__ = "personal_access_token" id: Mapped[int] = mapped_column(Integer, primary_key=True) 
name: Mapped[str] = mapped_column(String, nullable=False) # User-provided label hashed_token: Mapped[str] = mapped_column( String(64), unique=True, nullable=False ) # SHA256 = 64 hex chars token_display: Mapped[str] = mapped_column(String, nullable=False) user_id: Mapped[UUID] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) expires_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True, index=True ) # NULL = no expiration. Revocation sets this to NOW() for immediate expiry. # Audit fields created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) last_used_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) is_revoked: Mapped[bool] = mapped_column( Boolean, server_default=text("false"), nullable=False ) # True if user explicitly revoked (vs naturally expired) user: Mapped["User"] = relationship("User", foreign_keys=[user_id]) # Indexes for performance __table_args__ = ( Index( "ix_pat_user_created", user_id, created_at.desc() ), # Fast user token listing ) class Notification(Base): __tablename__ = "notification" id: Mapped[int] = mapped_column(primary_key=True) notif_type: Mapped[NotificationType] = mapped_column( Enum(NotificationType, native_enum=False) ) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=True ) dismissed: Mapped[bool] = mapped_column(Boolean, default=False) last_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) first_shown: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) title: Mapped[str] = mapped_column(String) description: Mapped[str | None] = mapped_column(String, nullable=True) user: Mapped[User] = relationship("User", back_populates="notifications") additional_data: Mapped[dict | None] = mapped_column( postgresql.JSONB(), nullable=True ) # Unique constraint ix_notification_user_type_data on 
(user_id, notif_type, additional_data) # ensures notification deduplication for batch inserts. Defined in migration 8405ca81cc83. __table_args__ = ( Index( "ix_notification_user_sort", "user_id", "dismissed", desc("first_shown"), ), ) """ Association Tables NOTE: must be at the top since they are referenced by other tables """ class Persona__DocumentSet(Base): __tablename__ = "persona__document_set" persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) document_set_id: Mapped[int] = mapped_column( ForeignKey("document_set.id"), primary_key=True ) class Persona__User(Base): __tablename__ = "persona__user" persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True ) class DocumentSet__User(Base): __tablename__ = "document_set__user" document_set_id: Mapped[int] = mapped_column( ForeignKey("document_set.id"), primary_key=True ) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True ) class DocumentSet__ConnectorCredentialPair(Base): __tablename__ = "document_set__connector_credential_pair" document_set_id: Mapped[int] = mapped_column( ForeignKey("document_set.id"), primary_key=True ) connector_credential_pair_id: Mapped[int] = mapped_column( ForeignKey("connector_credential_pair.id"), primary_key=True ) # if `True`, then is part of the current state of the document set # if `False`, then is a part of the prior state of the document set # rows with `is_current=False` should be deleted when the document # set is updated and should not exist for a given document set if # `DocumentSet.is_up_to_date == True` is_current: Mapped[bool] = mapped_column( Boolean, nullable=False, default=True, primary_key=True, ) document_set: Mapped["DocumentSet"] = relationship("DocumentSet") class ChatMessage__SearchDoc(Base): __tablename__ = 
"chat_message__search_doc" chat_message_id: Mapped[int] = mapped_column( ForeignKey("chat_message.id", ondelete="CASCADE"), primary_key=True ) search_doc_id: Mapped[int] = mapped_column( ForeignKey("search_doc.id", ondelete="CASCADE"), primary_key=True ) class ToolCall__SearchDoc(Base): __tablename__ = "tool_call__search_doc" tool_call_id: Mapped[int] = mapped_column( ForeignKey("tool_call.id", ondelete="CASCADE"), primary_key=True ) search_doc_id: Mapped[int] = mapped_column( ForeignKey("search_doc.id", ondelete="CASCADE"), primary_key=True ) class Document__Tag(Base): __tablename__ = "document__tag" document_id: Mapped[str] = mapped_column( ForeignKey("document.id"), primary_key=True ) tag_id: Mapped[int] = mapped_column( ForeignKey("tag.id"), primary_key=True, index=True ) class Persona__Tool(Base): """An entry in this table represents a tool that is **available** to a persona. It does NOT necessarily mean that the tool is actually usable to the persona. For example, a persona may have the image generation tool attached to it, even though the image generation tool is not set up / enabled. In this case, the tool should not show up in the UI for the persona + it should not be usable by the persona in chat. 
""" __tablename__ = "persona__tool" persona_id: Mapped[int] = mapped_column( ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True ) tool_id: Mapped[int] = mapped_column( ForeignKey("tool.id", ondelete="CASCADE"), primary_key=True ) class StandardAnswer__StandardAnswerCategory(Base): __tablename__ = "standard_answer__standard_answer_category" standard_answer_id: Mapped[int] = mapped_column( ForeignKey("standard_answer.id"), primary_key=True ) standard_answer_category_id: Mapped[int] = mapped_column( ForeignKey("standard_answer_category.id"), primary_key=True ) class SlackChannelConfig__StandardAnswerCategory(Base): __tablename__ = "slack_channel_config__standard_answer_category" slack_channel_config_id: Mapped[int] = mapped_column( ForeignKey("slack_channel_config.id"), primary_key=True ) standard_answer_category_id: Mapped[int] = mapped_column( ForeignKey("standard_answer_category.id"), primary_key=True ) class ChatMessage__StandardAnswer(Base): __tablename__ = "chat_message__standard_answer" chat_message_id: Mapped[int] = mapped_column( ForeignKey("chat_message.id", ondelete="CASCADE"), primary_key=True ) standard_answer_id: Mapped[int] = mapped_column( ForeignKey("standard_answer.id"), primary_key=True ) """ Documents/Indexing Tables """ class ConnectorCredentialPair(Base): """Connectors and Credentials can have a many-to-many relationship I.e. A Confluence Connector may have multiple admin users who can run it with their own credentials I.e. 
An admin user may use the same credential to index multiple Confluence Spaces """ __tablename__ = "connector_credential_pair" # NOTE: this `id` column has to use `Sequence` instead of `autoincrement=True` # due to some SQLAlchemy quirks + this not being a primary key column id: Mapped[int] = mapped_column( Integer, Sequence("connector_credential_pair_id_seq"), unique=True, nullable=False, ) name: Mapped[str] = mapped_column(String, nullable=False) status: Mapped[ConnectorCredentialPairStatus] = mapped_column( Enum(ConnectorCredentialPairStatus, native_enum=False), nullable=False ) # this is separate from the `status` above, since a connector can be `INITIAL_INDEXING`, `ACTIVE`, # or `PAUSED` and still be in a repeated error state. in_repeated_error_state: Mapped[bool] = mapped_column(Boolean, default=False) connector_id: Mapped[int] = mapped_column( ForeignKey("connector.id"), primary_key=True ) deletion_failure_message: Mapped[str | None] = mapped_column(String, nullable=True) credential_id: Mapped[int] = mapped_column( ForeignKey("credential.id"), primary_key=True ) # controls whether the documents indexed by this CC pair are visible to all # or only to those who are given explicit access # (e.g. via owning the credential or being a part of a group that is given access) access_type: Mapped[AccessType] = mapped_column( Enum(AccessType, native_enum=False), nullable=False ) # special info needed for the auto-sync feature. The exact structure depends on the # source type (defined in the connector's `source` field) # E.g.
for google_drive perm sync: # {"customer_id": "123567", "company_domain": "@onyx.app"} auto_sync_options: Mapped[dict[str, Any] | None] = mapped_column( postgresql.JSONB(), nullable=True ) last_time_perm_sync: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) last_time_external_group_sync: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) # Time finished, not used for calculating backend jobs which uses time started (created) last_successful_index_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) # last successful prune last_pruned: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True, index=True ) # last successful hierarchy fetch last_time_hierarchy_fetch: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) total_docs_indexed: Mapped[int] = mapped_column(Integer, default=0) indexing_trigger: Mapped[IndexingMode | None] = mapped_column( Enum(IndexingMode, native_enum=False), nullable=True ) # Determines how documents are processed after fetching: # REGULAR: Full pipeline (chunk → embed → Vespa) # FILE_SYSTEM: Write to file system only (for CLI agent sandbox) processing_mode: Mapped[ProcessingMode] = mapped_column( Enum(ProcessingMode, native_enum=False), nullable=False, default=ProcessingMode.REGULAR, server_default="REGULAR", ) connector: Mapped["Connector"] = relationship( "Connector", back_populates="credentials" ) credential: Mapped["Credential"] = relationship( "Credential", back_populates="connectors" ) document_sets: Mapped[list["DocumentSet"]] = relationship( "DocumentSet", secondary=DocumentSet__ConnectorCredentialPair.__table__, primaryjoin=( (DocumentSet__ConnectorCredentialPair.connector_credential_pair_id == id) & (DocumentSet__ConnectorCredentialPair.is_current.is_(True)) ), back_populates="connector_credential_pairs", overlaps="document_set", ) 
index_attempts: Mapped[list["IndexAttempt"]] = relationship( "IndexAttempt", back_populates="connector_credential_pair" ) # the user id of the user that created this cc pair creator_id: Mapped[UUID | None] = mapped_column(nullable=True) creator: Mapped["User"] = relationship( "User", back_populates="cc_pairs", primaryjoin="foreign(ConnectorCredentialPair.creator_id) == remote(User.id)", ) background_errors: Mapped[list["BackgroundError"]] = relationship( "BackgroundError", back_populates="cc_pair", cascade="all, delete-orphan" ) class HierarchyNode(Base): """ Represents a structural node in a connected source's hierarchy. Examples: folders, drives, spaces, projects, channels. Stores hierarchy structure WITH permission information, using the same permission model as Documents (external_user_emails, external_user_group_ids, is_public). This enables user-scoped hierarchy browsing in the UI. Some hierarchy nodes (e.g., Confluence pages) can also be documents. In these cases, `document_id` will be set. """ __tablename__ = "hierarchy_node" # Primary key - Integer for simplicity id: Mapped[int] = mapped_column(Integer, primary_key=True) # Raw identifier from the source system # e.g., "1h7uWUR2BYZjtMfEXFt43tauj-Gp36DTPtwnsNuA665I" for Google Drive # For SOURCE nodes, this is the source name (e.g., "google_drive") raw_node_id: Mapped[str] = mapped_column(String, nullable=False) # Human-readable name for display # e.g., "Engineering", "Q4 Planning", "Google Drive" display_name: Mapped[str] = mapped_column(String, nullable=False) # Link to view this node in the source system link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True) # Source type (google_drive, confluence, etc.) 
source: Mapped[DocumentSource] = mapped_column( Enum(DocumentSource, native_enum=False), nullable=False ) # What kind of structural node this is node_type: Mapped[HierarchyNodeType] = mapped_column( Enum(HierarchyNodeType, native_enum=False), nullable=False ) # ============= PERMISSION FIELDS (same pattern as Document) ============= # Email addresses of external users with access to this node in the source system external_user_emails: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) # External group IDs with access (prefixed by source type) external_user_group_ids: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) # Whether this node is publicly accessible (org-wide or world-public) # SOURCE nodes are always public. Other nodes get this from source permissions. is_public: Mapped[bool] = mapped_column(Boolean, default=False) # ========================================================================== # Foreign keys # For hierarchy nodes that are also documents (e.g., Confluence pages) # SET NULL when document is deleted - node can exist without its document document_id: Mapped[str | None] = mapped_column( ForeignKey("document.id", ondelete="SET NULL"), nullable=True ) # Self-referential FK for tree structure # SET NULL when parent is deleted - orphan children for cleanup via pruning parent_id: Mapped[int | None] = mapped_column( ForeignKey("hierarchy_node.id", ondelete="SET NULL"), nullable=True, index=True ) # Relationships document: Mapped["Document | None"] = relationship( "Document", back_populates="hierarchy_node", foreign_keys=[document_id] ) parent: Mapped["HierarchyNode | None"] = relationship( "HierarchyNode", remote_side=[id], back_populates="children" ) children: Mapped[list["HierarchyNode"]] = relationship( "HierarchyNode", back_populates="parent", passive_deletes=True ) child_documents: Mapped[list["Document"]] = relationship( "Document", back_populates="parent_hierarchy_node", 
foreign_keys="Document.parent_hierarchy_node_id", passive_deletes=True, ) # Personas that have this hierarchy node attached for scoped search personas: Mapped[list["Persona"]] = relationship( "Persona", secondary="persona__hierarchy_node", back_populates="hierarchy_nodes", viewonly=True, ) __table_args__ = ( # Unique constraint: same raw_node_id + source should not exist twice UniqueConstraint( "raw_node_id", "source", name="uq_hierarchy_node_raw_id_source" ), Index("ix_hierarchy_node_source_type", source, node_type), ) class Document(Base): __tablename__ = "document" # NOTE: if more sensitive data is added here for display, make sure to add user/group permission # this should correspond to the ID of the document # (as is passed around in Onyx) id: Mapped[str] = mapped_column(NullFilteredString, primary_key=True) from_ingestion_api: Mapped[bool] = mapped_column( Boolean, default=False, nullable=True ) # 0 for neutral, positive for mostly endorse, negative for mostly reject boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST) hidden: Mapped[bool] = mapped_column(Boolean, default=False) semantic_id: Mapped[str] = mapped_column(NullFilteredString) # First Section's link link: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True) # The updated time is also used as a measure of the last successful state of the doc # pulled from the source (to help skip reindexing already updated docs in case of # connector retries) # TODO: rename this column because it conflates the time of the source doc # with the local last modified time of the doc and any associated metadata # it should just be the server timestamp of the source doc doc_updated_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) # Number of chunks in the document (in Vespa) # Only null for documents indexed prior to this change chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True) # last time any vespa relevant row metadata or 
the doc changed. # does not include last_synced last_modified: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=False, index=True, default=func.now() ) # last successful sync to vespa last_synced: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True, index=True ) # The following are not attached to User because the account/email may not be known # within Onyx # Something like the document creator primary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) secondary_owners: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) # Permission sync columns # Email addresses are saved at the document level for externally synced permissions # This is because the normal flow of assigning permissions, through the cc_pair, # doesn't apply here external_user_emails: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) # These group ids have been prefixed by the source type external_user_group_ids: Mapped[list[str] | None] = mapped_column( postgresql.ARRAY(String), nullable=True ) is_public: Mapped[bool] = mapped_column(Boolean, default=False) # Reference to parent hierarchy node (the folder/space containing this doc) # If None, document's hierarchy position is unknown or connector doesn't support hierarchy # SET NULL when hierarchy node is deleted - document should not be blocked by node deletion parent_hierarchy_node_id: Mapped[int | None] = mapped_column( ForeignKey("hierarchy_node.id", ondelete="SET NULL"), nullable=True, index=True ) # columns for the knowledge graph data kg_stage: Mapped[KGStage] = mapped_column( Enum(KGStage, native_enum=False), comment="Status of knowledge graph extraction for this document", index=True, ) kg_processing_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) retrieval_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] =
relationship( "DocumentRetrievalFeedback", back_populates="document" ) doc_metadata: Mapped[dict[str, Any] | None] = mapped_column( postgresql.JSONB(), nullable=True, default=None ) tags = relationship( "Tag", secondary=Document__Tag.__table__, back_populates="documents", ) # Relationship to parent hierarchy node (the folder/space containing this doc) parent_hierarchy_node: Mapped["HierarchyNode | None"] = relationship( "HierarchyNode", back_populates="child_documents", foreign_keys=[parent_hierarchy_node_id], ) # For documents that ARE hierarchy nodes (e.g., Confluence pages with children) hierarchy_node: Mapped["HierarchyNode | None"] = relationship( "HierarchyNode", back_populates="document", foreign_keys="HierarchyNode.document_id", passive_deletes=True, ) # Personas that have this document directly attached for scoped search attached_personas: Mapped[list["Persona"]] = relationship( "Persona", secondary="persona__document", back_populates="attached_documents", viewonly=True, ) __table_args__ = ( Index( "ix_document_sync_status", last_modified, last_synced, ), ) class OpenSearchDocumentMigrationRecord(Base): """Tracks the migration status of documents from Vespa to OpenSearch. This table can be dropped when the migration is complete for all Onyx instances. 
""" __tablename__ = "opensearch_document_migration_record" document_id: Mapped[str] = mapped_column( String, ForeignKey("document.id", ondelete="CASCADE"), primary_key=True, nullable=False, index=True, ) status: Mapped[OpenSearchDocumentMigrationStatus] = mapped_column( Enum(OpenSearchDocumentMigrationStatus, native_enum=False), default=OpenSearchDocumentMigrationStatus.PENDING, nullable=False, index=True, ) error_message: Mapped[str | None] = mapped_column(Text, nullable=True) attempts_count: Mapped[int] = mapped_column( Integer, default=0, nullable=False, index=True ) last_attempt_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False, index=True, ) document: Mapped["Document"] = relationship("Document") class OpenSearchTenantMigrationRecord(Base): """Tracks the state of the OpenSearch migration for a tenant. Should only contain one row. This table can be dropped when the migration is complete for all Onyx instances. """ __tablename__ = "opensearch_tenant_migration_record" __table_args__ = ( # Singleton pattern - unique index on constant ensures only one row. 
Index("idx_opensearch_tenant_migration_singleton", text("(true)"), unique=True), ) id: Mapped[int] = mapped_column(primary_key=True, nullable=False) document_migration_record_table_population_status: Mapped[ OpenSearchTenantMigrationStatus ] = mapped_column( Enum(OpenSearchTenantMigrationStatus, native_enum=False), default=OpenSearchTenantMigrationStatus.PENDING, nullable=False, ) num_times_observed_no_additional_docs_to_populate_migration_table: Mapped[int] = ( mapped_column(Integer, default=0, nullable=False) ) overall_document_migration_status: Mapped[OpenSearchTenantMigrationStatus] = ( mapped_column( Enum(OpenSearchTenantMigrationStatus, native_enum=False), default=OpenSearchTenantMigrationStatus.PENDING, nullable=False, ) ) num_times_observed_no_additional_docs_to_migrate: Mapped[int] = mapped_column( Integer, default=0, nullable=False, ) last_updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) # Opaque continuation token from Vespa's Visit API. # NULL means "not started". # Otherwise contains a serialized mapping between slice ID and continuation # token for that slice. 
vespa_visit_continuation_token: Mapped[str | None] = mapped_column( Text, nullable=True ) total_chunks_migrated: Mapped[int] = mapped_column( Integer, default=0, nullable=False ) total_chunks_errored: Mapped[int] = mapped_column( Integer, default=0, nullable=False ) total_chunks_in_vespa: Mapped[int] = mapped_column( Integer, default=0, nullable=False ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False, ) migration_completed_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) enable_opensearch_retrieval: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False ) approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column( Integer, nullable=True ) class KGEntityType(Base): __tablename__ = "kg_entity_type" # Primary identifier id_name: Mapped[str] = mapped_column( String, primary_key=True, nullable=False, index=True ) description: Mapped[str | None] = mapped_column(NullFilteredString, nullable=True) grounding: Mapped[str] = mapped_column( NullFilteredString, nullable=False, index=False ) attributes: Mapped[dict | None] = mapped_column( postgresql.JSONB, nullable=True, default=dict, server_default="{}", comment="Filtering based on document attribute", ) @property def parsed_attributes(self) -> KGEntityTypeAttributes: if self.attributes is None: return KGEntityTypeAttributes() try: return KGEntityTypeAttributes(**self.attributes) except ValidationError: return KGEntityTypeAttributes() occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1) active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) deep_extraction: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False ) # Tracking fields time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), ) time_created: Mapped[datetime.datetime] = mapped_column( 
DateTime(timezone=True), server_default=func.now() ) grounded_source_name: Mapped[str | None] = mapped_column( NullFilteredString, nullable=True, index=False ) entity_values: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=True, default=None ) clustering: Mapped[dict] = mapped_column( postgresql.JSONB, nullable=False, default=dict, server_default="{}", comment="Clustering information for this entity type", ) class KGRelationshipType(Base): __tablename__ = "kg_relationship_type" # Primary identifier id_name: Mapped[str] = mapped_column( NullFilteredString, primary_key=True, nullable=False, index=True, ) name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) source_entity_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) target_entity_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) definition: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False, comment="Whether this relationship type represents a definition", ) clustering: Mapped[dict] = mapped_column( postgresql.JSONB, nullable=False, default=dict, server_default="{}", comment="Clustering information for this relationship type", ) type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1) # Tracking fields time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), ) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) # Relationships to EntityType source_type: Mapped["KGEntityType"] = relationship( "KGEntityType", foreign_keys=[source_entity_type_id_name], backref="source_relationship_type", 
) target_type: Mapped["KGEntityType"] = relationship( "KGEntityType", foreign_keys=[target_entity_type_id_name], backref="target_relationship_type", ) class KGRelationshipTypeExtractionStaging(Base): __tablename__ = "kg_relationship_type_extraction_staging" # Primary identifier id_name: Mapped[str] = mapped_column( NullFilteredString, primary_key=True, nullable=False, index=True, ) name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) source_entity_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) target_entity_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) definition: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False, comment="Whether this relationship type represents a definition", ) clustering: Mapped[dict] = mapped_column( postgresql.JSONB, nullable=False, default=dict, server_default="{}", comment="Clustering information for this relationship type", ) type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1) transferred: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False, ) # Tracking fields time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) # Relationships to EntityType source_type: Mapped["KGEntityType"] = relationship( "KGEntityType", foreign_keys=[source_entity_type_id_name], backref="source_relationship_type_staging", ) target_type: Mapped["KGEntityType"] = relationship( "KGEntityType", foreign_keys=[target_entity_type_id_name], backref="target_relationship_type_staging", ) class KGEntity(Base): __tablename__ = "kg_entity" # Primary identifier id_name: Mapped[str] = mapped_column( 
NullFilteredString, primary_key=True, index=True ) # Basic entity information name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) entity_key: Mapped[str] = mapped_column( NullFilteredString, nullable=True, index=True ) parent_key: Mapped[str | None] = mapped_column( NullFilteredString, nullable=True, index=True ) name_trigrams: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String(3)), nullable=True, ) attributes: Mapped[dict] = mapped_column( postgresql.JSONB, nullable=False, default=dict, server_default="{}", comment="Attributes for this entity", ) document_id: Mapped[str | None] = mapped_column( NullFilteredString, nullable=True, index=True ) alternative_names: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=False, default=list ) # Reference to KGEntityType entity_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) # Relationship to KGEntityType entity_type: Mapped["KGEntityType"] = relationship("KGEntityType", backref="entity") description: Mapped[str | None] = mapped_column(String, nullable=True) keywords: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=False, default=list ) occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1) # Access control acl: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=False, default=list ) # Boosts - using JSON for flexibility boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict) event_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True, comment="Time of the event being processed", ) # Tracking fields time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), ) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) __table_args__ = ( # 
Fixed column names in indexes Index("ix_entity_type_acl", entity_type_id_name, acl), Index("ix_entity_name_search", name, entity_type_id_name), ) class KGEntityExtractionStaging(Base): __tablename__ = "kg_entity_extraction_staging" # Primary identifier id_name: Mapped[str] = mapped_column( NullFilteredString, primary_key=True, nullable=False, index=True, ) # Basic entity information name: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) attributes: Mapped[dict] = mapped_column( postgresql.JSONB, nullable=False, default=dict, server_default="{}", comment="Attributes for this entity", ) document_id: Mapped[str | None] = mapped_column( NullFilteredString, nullable=True, index=True ) alternative_names: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=False, default=list ) # Reference to KGEntityType entity_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) # Relationship to KGEntityType entity_type: Mapped["KGEntityType"] = relationship( "KGEntityType", backref="entity_staging" ) description: Mapped[str | None] = mapped_column(String, nullable=True) keywords: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=False, default=list ) occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1) # Access control acl: Mapped[list[str]] = mapped_column( postgresql.ARRAY(String), nullable=False, default=list ) # Boosts - using JSON for flexibility boosts: Mapped[dict] = mapped_column(postgresql.JSONB, nullable=False, default=dict) transferred_id_name: Mapped[str | None] = mapped_column( NullFilteredString, nullable=True, ) # Parent Child Information entity_key: Mapped[str] = mapped_column( NullFilteredString, nullable=True, index=True ) parent_key: Mapped[str | None] = mapped_column( NullFilteredString, nullable=True, index=True ) event_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), 
nullable=True, comment="Time of the event being processed", ) # Tracking fields time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) __table_args__ = ( # Fixed column names in indexes Index("ix_entity_type_acl", entity_type_id_name, acl), Index("ix_entity_name_search", name, entity_type_id_name), ) class KGRelationship(Base): __tablename__ = "kg_relationship" # Primary identifier - now part of composite key id_name: Mapped[str] = mapped_column( NullFilteredString, nullable=False, index=True, ) source_document: Mapped[str | None] = mapped_column( NullFilteredString, ForeignKey("document.id"), nullable=True, index=True ) # Source and target nodes (foreign keys to Entity table) source_node: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True ) target_node: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity.id_name"), nullable=False, index=True ) source_node_type: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) target_node_type: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) # Relationship type type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) # Add new relationship type reference relationship_type_id_name: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_relationship_type.id_name"), nullable=False, index=True, ) # Add the SQLAlchemy relationship property relationship_type: Mapped["KGRelationshipType"] = relationship( "KGRelationshipType", backref="relationship" ) occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1) # Tracking fields time_updated: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), ) time_created: Mapped[datetime.datetime] = mapped_column( 
DateTime(timezone=True), server_default=func.now() ) # Relationships to Entity table source: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[source_node]) target: Mapped["KGEntity"] = relationship("KGEntity", foreign_keys=[target_node]) document: Mapped["Document"] = relationship( "Document", foreign_keys=[source_document] ) __table_args__ = ( # Composite primary key PrimaryKeyConstraint("id_name", "source_document"), # Index for querying relationships by type Index("ix_kg_relationship_type", type), # Composite index for source/target queries Index("ix_kg_relationship_nodes", source_node, target_node), # Ensure unique relationships between nodes of a specific type UniqueConstraint( "source_node", "target_node", "type", name="uq_kg_relationship_source_target_type", ), ) class KGRelationshipExtractionStaging(Base): __tablename__ = "kg_relationship_extraction_staging" # Primary identifier - now part of composite key id_name: Mapped[str] = mapped_column( NullFilteredString, nullable=False, index=True, ) source_document: Mapped[str | None] = mapped_column( NullFilteredString, ForeignKey("document.id"), nullable=True, index=True ) # Source and target nodes (foreign keys to Entity table) source_node: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_extraction_staging.id_name"), nullable=False, index=True, ) target_node: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_extraction_staging.id_name"), nullable=False, index=True, ) source_node_type: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) target_node_type: Mapped[str] = mapped_column( NullFilteredString, ForeignKey("kg_entity_type.id_name"), nullable=False, index=True, ) # Relationship type type: Mapped[str] = mapped_column(NullFilteredString, nullable=False, index=True) # Add new relationship type reference relationship_type_id_name: Mapped[str] = mapped_column( NullFilteredString, 
        ForeignKey("kg_relationship_type_extraction_staging.id_name"),
        nullable=False,
        index=True,
    )

    # Add the SQLAlchemy relationship property
    relationship_type: Mapped["KGRelationshipTypeExtractionStaging"] = relationship(
        "KGRelationshipTypeExtractionStaging", backref="relationship_staging"
    )

    occurrences: Mapped[int] = mapped_column(Integer, nullable=False, default=1)

    transferred: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
    )

    # Tracking fields
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    # Relationships to Entity table
    source: Mapped["KGEntityExtractionStaging"] = relationship(
        "KGEntityExtractionStaging", foreign_keys=[source_node]
    )
    target: Mapped["KGEntityExtractionStaging"] = relationship(
        "KGEntityExtractionStaging", foreign_keys=[target_node]
    )
    document: Mapped["Document"] = relationship(
        "Document", foreign_keys=[source_document]
    )

    __table_args__ = (
        # Composite primary key
        PrimaryKeyConstraint("id_name", "source_document"),
        # Index for querying relationships by type
        Index("ix_kg_relationship_type", type),
        # Composite index for source/target queries
        Index("ix_kg_relationship_nodes", source_node, target_node),
        # Ensure unique relationships between nodes of a specific type
        UniqueConstraint(
            "source_node",
            "target_node",
            "type",
            name="uq_kg_relationship_source_target_type",
        ),
    )


class KGTerm(Base):
    __tablename__ = "kg_term"

    # Make id_term the primary key
    id_term: Mapped[str] = mapped_column(
        NullFilteredString, primary_key=True, nullable=False, index=True
    )

    # List of entity types this term applies to
    entity_types: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), nullable=False, default=list
    )

    # Tracking fields
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    __table_args__ = (
        # Index for searching terms with specific entity types
        Index("ix_search_term_entities", entity_types),
        # Index for term lookups
        Index("ix_search_term_term", id_term),
    )


class ChunkStats(Base):
    __tablename__ = "chunk_stats"
    # NOTE: if more sensitive data is added here for display, make sure to add
    # user/group permissions

    # this should correspond to the ID of the document
    # (as is passed around in Onyx)
    id: Mapped[str] = mapped_column(
        NullFilteredString,
        primary_key=True,
        default=lambda context: (
            f"{context.get_current_parameters()['document_id']}__{context.get_current_parameters()['chunk_in_doc_id']}"
        ),
        index=True,
    )

    # Reference to parent document
    document_id: Mapped[str] = mapped_column(
        NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
    )

    chunk_in_doc_id: Mapped[int] = mapped_column(
        Integer,
        nullable=False,
    )

    information_content_boost: Mapped[float | None] = mapped_column(
        Float, nullable=True
    )

    last_modified: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=False, index=True, default=func.now()
    )
    last_synced: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True, index=True
    )

    __table_args__ = (
        Index(
            "ix_chunk_sync_status",
            last_modified,
            last_synced,
        ),
        UniqueConstraint(
            "document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
        ),
    )


class Tag(Base):
    __tablename__ = "tag"

    id: Mapped[int] = mapped_column(primary_key=True)
    tag_key: Mapped[str] = mapped_column(String)
    tag_value: Mapped[str] = mapped_column(String)
    source: Mapped[DocumentSource] = mapped_column(
        Enum(DocumentSource, native_enum=False)
    )
    is_list: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    documents = relationship(
        "Document",
        secondary=Document__Tag.__table__,
        back_populates="tags",
    )

    __table_args__ = (
        UniqueConstraint(
            "tag_key",
            "tag_value",
            "source",
            "is_list",
            name="_tag_key_value_source_list_uc",
        ),
    )


class Connector(Base):
    __tablename__ = "connector"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)
    source: Mapped[DocumentSource] = mapped_column(
        Enum(DocumentSource, native_enum=False)
    )
    input_type = mapped_column(Enum(InputType, native_enum=False))
    connector_specific_config: Mapped[dict[str, Any]] = mapped_column(
        postgresql.JSONB()
    )
    indexing_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime, nullable=True
    )

    kg_processing_enabled: Mapped[bool] = mapped_column(
        Boolean,
        nullable=False,
        default=False,
        comment="Whether this connector should extract knowledge graph entities",
    )

    kg_coverage_days: Mapped[int | None] = mapped_column(Integer, nullable=True)

    refresh_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
    prune_freq: Mapped[int | None] = mapped_column(Integer, nullable=True)
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    credentials: Mapped[list["ConnectorCredentialPair"]] = relationship(
        "ConnectorCredentialPair",
        back_populates="connector",
        cascade="all, delete-orphan",
    )
    documents_by_connector: Mapped[list["DocumentByConnectorCredentialPair"]] = (
        relationship(
            "DocumentByConnectorCredentialPair",
            back_populates="connector",
            passive_deletes=True,
        )
    )

    # Keep this validation logic in sync with RefreshFrequencySchema etc. on the
    # front end until we have a centralized validation schema

    # TODO(rkuo): experiment with SQLAlchemy validators rather than manual checks
    # https://docs.sqlalchemy.org/en/20/orm/mapped_attributes.html
    def validate_refresh_freq(self) -> None:
        if self.refresh_freq is not None:
            if self.refresh_freq < 60:
                raise ValueError(
                    "refresh_freq must be greater than or equal to 1 minute."
                )

    def validate_prune_freq(self) -> None:
        if self.prune_freq is not None:
            if self.prune_freq < 300:
                raise ValueError(
                    "prune_freq must be greater than or equal to 5 minutes."
                )


class Credential(Base):
    __tablename__ = "credential"

    name: Mapped[str] = mapped_column(String, nullable=True)

    source: Mapped[DocumentSource] = mapped_column(
        Enum(DocumentSource, native_enum=False)
    )

    id: Mapped[int] = mapped_column(primary_key=True)
    credential_json: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
        EncryptedJson()
    )
    user_id: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
    )
    # if `true`, then all Admins will have access to the credential
    admin_public: Mapped[bool] = mapped_column(Boolean, default=True)
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    curator_public: Mapped[bool] = mapped_column(Boolean, default=False)

    connectors: Mapped[list["ConnectorCredentialPair"]] = relationship(
        "ConnectorCredentialPair",
        back_populates="credential",
        cascade="all, delete-orphan",
    )

    documents_by_credential: Mapped[list["DocumentByConnectorCredentialPair"]] = (
        relationship(
            "DocumentByConnectorCredentialPair",
            back_populates="credential",
            passive_deletes=True,
        )
    )

    user: Mapped[User | None] = relationship("User", back_populates="credentials")


class FederatedConnector(Base):
    __tablename__ = "federated_connector"

    id: Mapped[int] = mapped_column(primary_key=True)
    source: Mapped[FederatedConnectorSource] = mapped_column(
        Enum(FederatedConnectorSource, native_enum=False)
    )
    credentials: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
        EncryptedJson(), nullable=False
    )
    config: Mapped[dict[str, Any]] = mapped_column(
        postgresql.JSONB(), default=dict, nullable=False, server_default="{}"
    )

    oauth_tokens: Mapped[list["FederatedConnectorOAuthToken"]] = relationship(
        "FederatedConnectorOAuthToken",
        back_populates="federated_connector",
        cascade="all, delete-orphan",
    )
    document_sets: Mapped[list["FederatedConnector__DocumentSet"]] = relationship(
        "FederatedConnector__DocumentSet",
        back_populates="federated_connector",
        cascade="all, delete-orphan",
    )


class FederatedConnectorOAuthToken(Base):
    __tablename__ = "federated_connector_oauth_token"

    id: Mapped[int] = mapped_column(primary_key=True)
    federated_connector_id: Mapped[int] = mapped_column(
        ForeignKey("federated_connector.id", ondelete="CASCADE"), nullable=False
    )
    user_id: Mapped[UUID] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )
    token: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=False
    )

    expires_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime, nullable=True
    )

    federated_connector: Mapped["FederatedConnector"] = relationship(
        "FederatedConnector", back_populates="oauth_tokens"
    )
    user: Mapped["User"] = relationship("User")


class FederatedConnector__DocumentSet(Base):
    __tablename__ = "federated_connector__document_set"

    id: Mapped[int] = mapped_column(primary_key=True)
    federated_connector_id: Mapped[int] = mapped_column(
        ForeignKey("federated_connector.id", ondelete="CASCADE"), nullable=False
    )
    document_set_id: Mapped[int] = mapped_column(
        ForeignKey("document_set.id", ondelete="CASCADE"), nullable=False
    )

    # unique per source type. Validated before insertion.
    entities: Mapped[dict[str, Any]] = mapped_column(postgresql.JSONB(), nullable=False)

    federated_connector: Mapped["FederatedConnector"] = relationship(
        "FederatedConnector", back_populates="document_sets"
    )
    document_set: Mapped["DocumentSet"] = relationship(
        "DocumentSet", back_populates="federated_connectors"
    )

    __table_args__ = (
        UniqueConstraint(
            "federated_connector_id",
            "document_set_id",
            name="uq_federated_connector_document_set",
        ),
    )


class SearchSettings(Base):
    __tablename__ = "search_settings"

    id: Mapped[int] = mapped_column(primary_key=True)
    model_name: Mapped[str] = mapped_column(String)
    model_dim: Mapped[int] = mapped_column(Integer)
    normalize: Mapped[bool] = mapped_column(Boolean)
    query_prefix: Mapped[str | None] = mapped_column(String, nullable=True)
    passage_prefix: Mapped[str | None] = mapped_column(String, nullable=True)

    status: Mapped[IndexModelStatus] = mapped_column(
        Enum(IndexModelStatus, native_enum=False)
    )
    index_name: Mapped[str] = mapped_column(String)
    provider_type: Mapped[EmbeddingProvider | None] = mapped_column(
        ForeignKey("embedding_provider.provider_type"), nullable=True
    )

    # Type of switchover to perform when switching embedding models
    # REINDEX: waits for all connectors to complete
    # ACTIVE_ONLY: waits for only non-paused connectors to complete
    # INSTANT: swaps immediately without waiting
    switchover_type: Mapped[SwitchoverType] = mapped_column(
        Enum(SwitchoverType, native_enum=False), default=SwitchoverType.REINDEX
    )

    # allows for quantization -> less memory usage for a small performance hit
    embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
        Enum(EmbeddingPrecision, native_enum=False)
    )

    # can be used to reduce dimensionality of vectors and save memory with
    # a small performance hit. More details in the `Reducing embedding dimensions`
    # section here:
    # https://platform.openai.com/docs/guides/embeddings#embedding-models
    # If not specified, will just use the model_dim without any reduction.
    # NOTE: this is only currently available for OpenAI models
    reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Mini and Large Chunks (large chunk also checks for model max context)
    multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)

    # Contextual RAG
    enable_contextual_rag: Mapped[bool] = mapped_column(Boolean, default=False)

    # Contextual RAG LLM
    contextual_rag_llm_name: Mapped[str | None] = mapped_column(String, nullable=True)
    contextual_rag_llm_provider: Mapped[str | None] = mapped_column(
        String, nullable=True
    )

    multilingual_expansion: Mapped[list[str]] = mapped_column(
        postgresql.ARRAY(String), default=[]
    )

    cloud_provider: Mapped["CloudEmbeddingProvider"] = relationship(
        "CloudEmbeddingProvider",
        back_populates="search_settings",
        foreign_keys=[provider_type],
    )

    index_attempts: Mapped[list["IndexAttempt"]] = relationship(
        "IndexAttempt", back_populates="search_settings"
    )

    __table_args__ = (
        Index(
            "ix_embedding_model_present_unique",
            "status",
            unique=True,
            postgresql_where=(status == IndexModelStatus.PRESENT),
        ),
        Index(
            "ix_embedding_model_future_unique",
            "status",
            unique=True,
            postgresql_where=(status == IndexModelStatus.FUTURE),
        ),
    )

    def __repr__(self) -> str:
        return f""

    @property
    def api_version(self) -> str | None:
        return (
            self.cloud_provider.api_version
            if self.cloud_provider is not None
            else None
        )

    @property
    def deployment_name(self) -> str | None:
        return (
            self.cloud_provider.deployment_name
            if self.cloud_provider is not None
            else None
        )

    @property
    def api_url(self) -> str | None:
        return self.cloud_provider.api_url if self.cloud_provider is not None else None

    @property
    def api_key(self) -> str | None:
        if self.cloud_provider is None or self.cloud_provider.api_key is None:
            return None
        return self.cloud_provider.api_key.get_value(apply_mask=False)

    @property
    def large_chunks_enabled(self) -> bool:
        """
        Given multipass usage and an embedder, decides whether large chunks are allowed
        based on model/provider
        constraints.
        """
        # Only local models that support a larger context are from Nomic
        # Cohere does not support larger contexts (they recommend not going above ~512 tokens)
        return SearchSettings.can_use_large_chunks(
            self.multipass_indexing, self.model_name, self.provider_type
        )

    @property
    def final_embedding_dim(self) -> int:
        return self.reduced_dimension or self.model_dim

    @staticmethod
    def can_use_large_chunks(
        multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
    ) -> bool:
        """
        Given multipass usage and an embedder, decides whether large chunks are allowed
        based on model/provider constraints.
        """
        # Only local models that support a larger context are from Nomic
        # Cohere does not support larger contexts (they recommend not going above ~512 tokens)
        return (
            multipass
            and model_name.startswith("nomic-ai")
            and provider_type != EmbeddingProvider.COHERE
        )


class IndexAttempt(Base):
    """
    Represents an attempt to index a group of 0 or more documents from a
    source. For example, a single pull from Google Drive, a single event from
    slack event API, or a single website crawl.
    """

    __tablename__ = "index_attempt"

    id: Mapped[int] = mapped_column(primary_key=True)

    connector_credential_pair_id: Mapped[int] = mapped_column(
        ForeignKey("connector_credential_pair.id"),
        nullable=False,
    )

    # Some index attempts that run from beginning will still have this as False
    # This is only for attempts that are explicitly marked as from the start via
    # the run once API
    from_beginning: Mapped[bool] = mapped_column(Boolean)
    status: Mapped[IndexingStatus] = mapped_column(
        Enum(IndexingStatus, native_enum=False, index=True)
    )
    # new_docs_indexed and total_docs_indexed may be slightly out of sync if the
    # user switches embedding models
    new_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
    total_docs_indexed: Mapped[int | None] = mapped_column(Integer, default=0)
    docs_removed_from_index: Mapped[int | None] = mapped_column(Integer, default=0)
    # only filled if status = "failed"
    error_msg: Mapped[str | None] = mapped_column(Text, default=None)
    # only filled if status = "failed" AND an unhandled exception caused the failure
    full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)
    # Nullable because in the past, we didn't allow swapping out embedding models live
    search_settings_id: Mapped[int | None] = mapped_column(
        ForeignKey("search_settings.id", ondelete="SET NULL"),
        nullable=True,
    )

    # for polling connectors, the start and end time of the poll window
    # will be set when the index attempt starts
    poll_range_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True, default=None
    )
    poll_range_end: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True, default=None
    )

    # Points to the last checkpoint that was saved for this run.
    # The pointer here can be taken to the FileStore to grab the actual checkpoint value
    checkpoint_pointer: Mapped[str | None] = mapped_column(String, nullable=True)

    # Database-based coordination fields (replacing Redis fencing)
    celery_task_id: Mapped[str | None] = mapped_column(String, nullable=True)
    cancellation_requested: Mapped[bool] = mapped_column(Boolean, default=False)

    # Batch coordination fields
    # Once this is set, docfetching has completed
    total_batches: Mapped[int | None] = mapped_column(Integer, nullable=True)
    # batches that are fully indexed (i.e. have completed docfetching and docprocessing)
    completed_batches: Mapped[int] = mapped_column(Integer, default=0)
    # TODO: unused, remove this column
    total_failures_batch_level: Mapped[int] = mapped_column(Integer, default=0)
    total_chunks: Mapped[int] = mapped_column(Integer, default=0)

    # Progress tracking for stall detection
    last_progress_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    last_batches_completed_count: Mapped[int] = mapped_column(Integer, default=0)

    # Heartbeat tracking for worker liveness detection
    heartbeat_counter: Mapped[int] = mapped_column(Integer, default=0)
    last_heartbeat_value: Mapped[int] = mapped_column(Integer, default=0)
    last_heartbeat_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        index=True,
    )
    # when the actual indexing run began
    # NOTE: will use the api_server clock rather than DB server clock
    time_started: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship(
        "ConnectorCredentialPair", back_populates="index_attempts"
    )

    search_settings: Mapped[SearchSettings | None] = relationship(
        "SearchSettings", back_populates="index_attempts"
    )

    error_rows = relationship(
        "IndexAttemptError",
        back_populates="index_attempt",
        cascade="all, delete-orphan",
    )

    __table_args__ = (
        Index(
            "ix_index_attempt_latest_for_connector_credential_pair",
            "connector_credential_pair_id",
            "time_created",
        ),
        Index(
            "ix_index_attempt_ccpair_search_settings_time_updated",
            "connector_credential_pair_id",
            "search_settings_id",
            desc("time_updated"),
            unique=False,
        ),
        Index(
            "ix_index_attempt_cc_pair_settings_poll",
            "connector_credential_pair_id",
            "search_settings_id",
            "status",
            desc("time_updated"),
        ),
        # NEW: Index for coordination queries
        Index(
            "ix_index_attempt_active_coordination",
            "connector_credential_pair_id",
            "search_settings_id",
            "status",
        ),
    )

    def __repr__(self) -> str:
        return (
            f""
            f"time_created={self.time_created!r}, "
            f"time_updated={self.time_updated!r}, "
        )

    def is_finished(self) -> bool:
        return self.status.is_terminal()

    def is_coordination_complete(self) -> bool:
        """Check if all batches have been processed"""
        return (
            self.total_batches is not None
            and self.completed_batches >= self.total_batches
        )


class HierarchyFetchAttempt(Base):
    """Tracks attempts to fetch hierarchy nodes from a source"""

    __tablename__ = "hierarchy_fetch_attempt"

    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), primary_key=True, default=uuid4
    )
    connector_credential_pair_id: Mapped[int] = mapped_column(
        ForeignKey("connector_credential_pair.id", ondelete="CASCADE"),
        nullable=False,
    )
    status: Mapped[IndexingStatus] = mapped_column(
        Enum(IndexingStatus, native_enum=False), nullable=False, index=True
    )

    # Statistics
    nodes_fetched: Mapped[int | None] = mapped_column(Integer, default=0)
    nodes_updated: Mapped[int | None] = mapped_column(Integer, default=0)

    # Error information (only filled if status = "failed")
    error_msg: Mapped[str | None] = mapped_column(Text, default=None)
    full_exception_trace: Mapped[str | None] = mapped_column(Text, default=None)

    # Timestamps
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        index=True,
    )
    time_started: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )

    # Relationships
    connector_credential_pair: Mapped["ConnectorCredentialPair"] = relationship(
        "ConnectorCredentialPair"
    )

    __table_args__ = (
        Index(
            "ix_hierarchy_fetch_attempt_cc_pair",
            connector_credential_pair_id,
        ),
    )


class IndexAttemptError(Base):
    __tablename__ = "index_attempt_errors"

    id: Mapped[int] = mapped_column(primary_key=True)

    index_attempt_id: Mapped[int] = mapped_column(
        ForeignKey("index_attempt.id"),
        nullable=False,
    )
    connector_credential_pair_id: Mapped[int] = mapped_column(
        ForeignKey("connector_credential_pair.id"),
        nullable=False,
    )

    document_id: Mapped[str | None] = mapped_column(String, nullable=True)
    document_link: Mapped[str | None] = mapped_column(String, nullable=True)

    entity_id: Mapped[str | None] = mapped_column(String, nullable=True)
    failed_time_range_start: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    failed_time_range_end: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    failure_message: Mapped[str] = mapped_column(Text)
    is_resolved: Mapped[bool] = mapped_column(Boolean, default=False)

    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
    )

    # This is the reverse side of the relationship
    index_attempt = relationship("IndexAttempt", back_populates="error_rows")


class SyncRecord(Base):
    """
    Represents the status of a "sync" operation (e.g. document set, user group, deletion).

    A "sync" operation is an operation which needs to update a set of documents within
    Vespa, usually to match the state of Postgres.
    """

    __tablename__ = "sync_record"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # document set id, user group id, or deletion id
    entity_id: Mapped[int] = mapped_column(Integer)
    sync_type: Mapped[SyncType] = mapped_column(Enum(SyncType, native_enum=False))
    sync_status: Mapped[SyncStatus] = mapped_column(Enum(SyncStatus, native_enum=False))

    num_docs_synced: Mapped[int] = mapped_column(Integer, default=0)

    sync_start_time: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True))
    sync_end_time: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    __table_args__ = (
        Index(
            "ix_sync_record_entity_id_sync_type_sync_start_time",
            "entity_id",
            "sync_type",
            "sync_start_time",
        ),
        Index(
            "ix_sync_record_entity_id_sync_type_sync_status",
            "entity_id",
            "sync_type",
            "sync_status",
        ),
    )


class HierarchyNodeByConnectorCredentialPair(Base):
    """Tracks which cc_pairs reference each hierarchy node.

    During pruning, stale entries are removed for the current cc_pair.
    Hierarchy nodes with zero remaining entries are then deleted.
    """

    __tablename__ = "hierarchy_node_by_connector_credential_pair"

    hierarchy_node_id: Mapped[int] = mapped_column(
        ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
    )
    connector_id: Mapped[int] = mapped_column(primary_key=True)
    credential_id: Mapped[int] = mapped_column(primary_key=True)

    __table_args__ = (
        ForeignKeyConstraint(
            ["connector_id", "credential_id"],
            [
                "connector_credential_pair.connector_id",
                "connector_credential_pair.credential_id",
            ],
            ondelete="CASCADE",
        ),
        Index(
            "ix_hierarchy_node_cc_pair_connector_credential",
            "connector_id",
            "credential_id",
        ),
    )


class DocumentByConnectorCredentialPair(Base):
    """Represents an indexing of a document by a specific connector / credential pair"""

    __tablename__ = "document_by_connector_credential_pair"

    id: Mapped[str] = mapped_column(ForeignKey("document.id"), primary_key=True)
    # TODO: transition this to use the ConnectorCredentialPair id directly
    connector_id: Mapped[int] = mapped_column(
        ForeignKey("connector.id", ondelete="CASCADE"), primary_key=True
    )
    credential_id: Mapped[int] = mapped_column(
        ForeignKey("credential.id", ondelete="CASCADE"), primary_key=True
    )

    # used to better keep track of document counts at a connector level
    # e.g. if a document is added as part of permission syncing, it should
    # not be counted as part of the connector's document count until
    # the actual indexing is complete
    has_been_indexed: Mapped[bool] = mapped_column(Boolean)

    connector: Mapped[Connector] = relationship(
        "Connector", back_populates="documents_by_connector", passive_deletes=True
    )
    credential: Mapped[Credential] = relationship(
        "Credential", back_populates="documents_by_credential", passive_deletes=True
    )

    __table_args__ = (
        Index(
            "idx_document_cc_pair_connector_credential",
            "connector_id",
            "credential_id",
            unique=False,
        ),
        # Index to optimize get_document_counts_for_cc_pairs query pattern
        Index(
            "idx_document_cc_pair_counts",
            "connector_id",
            "credential_id",
            "has_been_indexed",
            unique=False,
        ),
    )


"""
Messages Tables
"""


class ChatSession(Base):
    __tablename__ = "chat_session"

    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), primary_key=True, default=uuid4
    )
    user_id: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
    )
    persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id"), nullable=True
    )
    description: Mapped[str | None] = mapped_column(Text, nullable=True)
    # Whether this chat was created by OnyxBot
    onyxbot_flow: Mapped[bool] = mapped_column(Boolean, default=False)
    # Only ever set to True if the system is configured not to hard-delete chats
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)
    # controls whether or not this conversation is viewable by others
    shared_status: Mapped[ChatSessionSharedStatus] = mapped_column(
        Enum(ChatSessionSharedStatus, native_enum=False),
        default=ChatSessionSharedStatus.PRIVATE,
    )

    current_alternate_model: Mapped[str | None] = mapped_column(String, default=None)

    slack_thread_id: Mapped[str | None] = mapped_column(
        String, nullable=True, default=None
    )

    project_id: Mapped[int | None] = mapped_column(
        ForeignKey("user_project.id"), nullable=True
    )

    project: Mapped["UserProject"] = relationship(
        "UserProject",
        back_populates="chat_sessions",
        foreign_keys=[project_id],
    )

    # the latest "overrides" specified by the user. These take precedence over
    # the attached persona. However, overrides specified directly in the
    # `send-message` call will take precedence over these.
    # NOTE: currently only used by the chat seeding flow, will be used in the
    # future once we allow users to override default values via the Chat UI
    # itself
    llm_override: Mapped[LLMOverride | None] = mapped_column(
        PydanticType(LLMOverride), nullable=True
    )

    # The latest temperature override specified by the user
    temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True)

    prompt_override: Mapped[PromptOverride | None] = mapped_column(
        PydanticType(PromptOverride), nullable=True
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
    )
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    user: Mapped[User] = relationship("User", back_populates="chat_sessions")
    messages: Mapped[list["ChatMessage"]] = relationship(
        "ChatMessage",
        back_populates="chat_session",
        cascade="all, delete-orphan",
        foreign_keys="ChatMessage.chat_session_id",
    )
    persona: Mapped["Persona"] = relationship("Persona")


class ChatMessage(Base):
    """Note: the first message in a chain has no contents. It is an empty root node,
    a workaround that allows edits on the first message of a session.

    Since every user message is followed by an LLM response, chat messages generally
    come in pairs. They are nevertheless kept as separate messages to support future
    agentic extensions, so fields will be largely duplicated within a pair.
    """

    __tablename__ = "chat_message"

    id: Mapped[int] = mapped_column(primary_key=True)
    # The chat session this message belongs to
    chat_session_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), ForeignKey("chat_session.id")
    )

    # Parent message pointer for the tree structure, nullable because the first message is
    # an empty root node to allow edits on the first message of a session.
    parent_message_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id"), nullable=True
    )
    # This only maps to the latest because only that message chain is needed.
    # It can be updated as needed to trace other branches.
    latest_child_message_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id"), nullable=True
    )

    # Only set on summary messages - the ID of the last message included in this summary
    # Used for chat history compression
    last_summarized_message_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id", ondelete="SET NULL"),
        nullable=True,
    )

    # For multi-model turns: the user message points to the assistant response
    # that was selected as the preferred one to continue the conversation with.
    preferred_response_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
    )

    # The display name of the model that generated this assistant message
    model_display_name: Mapped[str | None] = mapped_column(String, nullable=True)

    # Message contents
    reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
    message: Mapped[str] = mapped_column(Text)
    token_count: Mapped[int] = mapped_column(Integer)
    message_type: Mapped[MessageType] = mapped_column(
        Enum(MessageType, native_enum=False)
    )
    # Files attached to the message; when parsed into history, they become a
    # separate message
    files: Mapped[list[FileDescriptor] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    # Maps the citation numbers to a SearchDoc id
    citations: Mapped[dict[int, int] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )

    # Metadata
    error: Mapped[str | None] = mapped_column(Text, nullable=True)
    time_sent: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    # True if this assistant message is a clarification question (deep research flow)
    is_clarification: Mapped[bool] = mapped_column(Boolean, default=False)
    # Duration in seconds for processing this message (assistant messages only)
    processing_duration_seconds: Mapped[float | None] = mapped_column(
        Float, nullable=True
    )

    # Relationships
    chat_session: Mapped[ChatSession] = relationship(
        "ChatSession",
        back_populates="messages",
        foreign_keys=[chat_session_id],
    )

    chat_message_feedbacks: Mapped[list["ChatMessageFeedback"]] = relationship(
        "ChatMessageFeedback",
        back_populates="chat_message",
    )

    document_feedbacks: Mapped[list["DocumentRetrievalFeedback"]] = relationship(
        "DocumentRetrievalFeedback",
        back_populates="chat_message",
    )

    # Even though search docs come from tool calls, the answer has a final set of
    # saved search docs that we will show
    search_docs: Mapped[list["SearchDoc"]] = relationship(
        "SearchDoc",
        secondary=ChatMessage__SearchDoc.__table__,
        back_populates="chat_messages",
        cascade="all, delete-orphan",
        single_parent=True,
    )

    parent_message: Mapped["ChatMessage | None"] = relationship(
        "ChatMessage",
        foreign_keys=[parent_message_id],
        remote_side="ChatMessage.id",
    )

    latest_child_message: Mapped["ChatMessage | None"] = relationship(
        "ChatMessage",
        foreign_keys=[latest_child_message_id],
        remote_side="ChatMessage.id",
    )

    preferred_response: Mapped["ChatMessage | None"] = relationship(
        "ChatMessage",
        foreign_keys=[preferred_response_id],
        remote_side="ChatMessage.id",
    )

    # Chat messages only need to know their immediate tool call children
    # If there are nested tool calls, they are stored in the tool_call_children relationship.
    tool_calls: Mapped[list["ToolCall"] | None] = relationship(
        "ToolCall",
        back_populates="chat_message",
    )

    standard_answers: Mapped[list["StandardAnswer"]] = relationship(
        "StandardAnswer",
        secondary=ChatMessage__StandardAnswer.__table__,
        back_populates="chat_messages",
    )


class ToolCall(Base):
    """Represents a Tool Call and Tool Response"""

    __tablename__ = "tool_call"

    id: Mapped[int] = mapped_column(primary_key=True)
    chat_session_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), ForeignKey("chat_session.id", ondelete="CASCADE")
    )

    # If this is not None, it's a top level tool call from the user message
    # If this is None, it's a lower level call from another tool/agent
    parent_chat_message_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id", ondelete="CASCADE"), nullable=True
    )
    # If this is not None, this tool call is a child of another tool call
    parent_tool_call_id: Mapped[int | None] = mapped_column(
        ForeignKey("tool_call.id", ondelete="CASCADE"), nullable=True
    )
    # The tools with the same turn number (and parent) were called in parallel
    # Ones with different turn numbers (and same parent) were called sequentially
    turn_number: Mapped[int] = mapped_column(Integer)
    # Index order of tool calls from the LLM for parallel tool calls
    tab_index: Mapped[int] = mapped_column(Integer, default=0)

    # Not a FK because we want to be able to delete the tool without deleting
    # this entry
    tool_id: Mapped[int] = mapped_column(Integer())
    # This is needed because LLMs expect the tool call and the response to have matching IDs
    # This is better than just regenerating one randomly
    tool_call_id: Mapped[str] = mapped_column(String())
    # Preceding reasoning tokens for this tool call, not included in the history
    reasoning_tokens: Mapped[str | None] = mapped_column(Text, nullable=True)
    # For "Agents" like the Research Agent for Deep Research -
    # the argument and final report are stored as the argument and response.
    tool_call_arguments: Mapped[dict[str, JSON_ro]] = mapped_column(postgresql.JSONB())
    tool_call_response: Mapped[str] = mapped_column(Text)
    # This just counts the number of tokens in the arg because it's all that's kept for the history
    # Only the top level tools (the ones with a parent_chat_message_id) have token counts that are counted
    # towards the session total.
    tool_call_tokens: Mapped[int] = mapped_column(Integer())
    # For image generation tool - stores GeneratedImage objects for replay
    generated_images: Mapped[list[dict] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )

    # Relationships
    chat_session: Mapped[ChatSession] = relationship("ChatSession")
    chat_message: Mapped["ChatMessage | None"] = relationship(
        "ChatMessage",
        foreign_keys=[parent_chat_message_id],
        back_populates="tool_calls",
    )
    parent_tool_call: Mapped["ToolCall | None"] = relationship(
        "ToolCall",
        foreign_keys=[parent_tool_call_id],
        remote_side="ToolCall.id",
    )
    tool_call_children: Mapped[list["ToolCall"]] = relationship(
        "ToolCall",
        foreign_keys=[parent_tool_call_id],
        back_populates="parent_tool_call",
    )
    # Other tools may need to save other data; a more generic way to store
    # rich tool returns may be needed
    search_docs: Mapped[list["SearchDoc"]] = relationship(
        "SearchDoc",
        secondary=ToolCall__SearchDoc.__table__,
        back_populates="tool_calls",
        cascade="all, delete-orphan",
        single_parent=True,
    )


class SearchDoc(Base):
    """Different from the Document table. This one stores the state of a document
    as returned by a retrieval.
    This allows chat sessions to be replayed with the searched docs.

    Notably, this does not include the contents of the Document/Chunk; during
    inference, if a stored SearchDoc is selected, the contents must be retrieved again.
    """

    __tablename__ = "search_doc"

    id: Mapped[int] = mapped_column(primary_key=True)
    document_id: Mapped[str] = mapped_column(String)
    chunk_ind: Mapped[int] = mapped_column(Integer)
    semantic_id: Mapped[str] = mapped_column(String)
    link: Mapped[str | None] = mapped_column(String, nullable=True)
    blurb: Mapped[str] = mapped_column(String)
    boost: Mapped[int] = mapped_column(Integer)
    source_type: Mapped[DocumentSource] = mapped_column(
        Enum(DocumentSource, native_enum=False)
    )
    hidden: Mapped[bool] = mapped_column(Boolean)
    doc_metadata: Mapped[dict[str, str | list[str]]] = mapped_column(postgresql.JSONB())
    score: Mapped[float] = mapped_column(Float)
    match_highlights: Mapped[list[str]] = mapped_column(postgresql.ARRAY(String))
    # This is for the document, not this row in the table
    updated_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )
    primary_owners: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True
    )
    secondary_owners: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True
    )
    is_internet: Mapped[bool] = mapped_column(Boolean, default=False, nullable=True)

    is_relevant: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    relevance_explanation: Mapped[str | None] = mapped_column(String, nullable=True)

    chat_messages: Mapped[list["ChatMessage"]] = relationship(
        "ChatMessage",
        secondary=ChatMessage__SearchDoc.__table__,
        back_populates="search_docs",
    )
    tool_calls: Mapped[list["ToolCall"]] = relationship(
        "ToolCall",
        secondary=ToolCall__SearchDoc.__table__,
        back_populates="search_docs",
    )


class SearchQuery(Base):
    # This table contains search queries for the Search UI.
    # There are no follow-ups and less is stored, because the reply functionality
    # simply reruns the search query (results may have changed since), which is the
    # more common pattern for search.
    __tablename__ = "search_query"

    id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), primary_key=True, default=uuid4
    )
    user_id: Mapped[UUID] = mapped_column(
        PGUUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE")
    )
    query: Mapped[str] = mapped_column(String)
    query_expansions: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True
    )
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )


"""
Feedback, Logging, Metrics Tables
"""


class DocumentRetrievalFeedback(Base):
    __tablename__ = "document_retrieval_feedback"

    id: Mapped[int] = mapped_column(primary_key=True)
    chat_message_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
    )
    document_id: Mapped[str] = mapped_column(ForeignKey("document.id"))
    # How high up this document is in the results, 1 for first
    document_rank: Mapped[int] = mapped_column(Integer)
    clicked: Mapped[bool] = mapped_column(Boolean, default=False)
    feedback: Mapped[SearchFeedbackType | None] = mapped_column(
        Enum(SearchFeedbackType, native_enum=False), nullable=True
    )

    chat_message: Mapped[ChatMessage] = relationship(
        "ChatMessage",
        back_populates="document_feedbacks",
        foreign_keys=[chat_message_id],
    )
    document: Mapped[Document] = relationship(
        "Document", back_populates="retrieval_feedbacks"
    )


class ChatMessageFeedback(Base):
    __tablename__ = "chat_feedback"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    chat_message_id: Mapped[int | None] = mapped_column(
        ForeignKey("chat_message.id", ondelete="SET NULL"), nullable=True
    )
    is_positive: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    required_followup: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    feedback_text: Mapped[str | None] = mapped_column(Text, nullable=True)
    predefined_feedback: Mapped[str | None] = mapped_column(String, nullable=True)

    chat_message: Mapped[ChatMessage] = relationship(
        "ChatMessage",
        back_populates="chat_message_feedbacks",
        foreign_keys=[chat_message_id],
    )


class LLMProvider(Base):
    __tablename__ = "llm_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True)
    provider: Mapped[str] = mapped_column(String)
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )
    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
    api_version: Mapped[str | None] = mapped_column(String, nullable=True)
    # custom configs that should be passed to the LLM provider at inference time
    # (e.g. `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, etc. for bedrock)
    custom_config: Mapped[dict[str, str] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    # Deprecated: use LLMModelFlow with CHAT flow type instead
    default_model_name: Mapped[str | None] = mapped_column(String, nullable=True)

    deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)

    # Deprecated: use LLMModelFlow.is_default with CHAT flow type instead
    is_default_provider: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
    # Deprecated: use LLMModelFlow.is_default with VISION flow type instead
    is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
    # Deprecated: use LLMModelFlow with VISION flow type instead
    default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    # Auto mode: models, visibility, and defaults are managed by GitHub config
    is_auto_mode: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="llm_provider__user_group",
        viewonly=True,
    )
    personas: Mapped[list["Persona"]] = relationship(
        "Persona",
        secondary="llm_provider__persona",
        back_populates="allowed_by_llm_providers",
        viewonly=True,
    )
    model_configurations: Mapped[list["ModelConfiguration"]] = relationship(
        "ModelConfiguration",
        back_populates="llm_provider",
        foreign_keys="ModelConfiguration.llm_provider_id",
    )


class ModelConfiguration(Base):
    __tablename__ = "model_configuration"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    llm_provider_id: Mapped[int] = mapped_column(
        ForeignKey("llm_provider.id", ondelete="CASCADE"),
        nullable=False,
    )
    name: Mapped[str] = mapped_column(String, nullable=False)

    # Represents whether a given model is usable by the end user.
    # This field is primarily used for "Well Known LLM Providers", since for them,
    # we have a pre-defined list of LLM models that we allow them to choose from.
    # For example, for OpenAI, we allow the end-user to choose multiple models from
    # `["gpt-4", "gpt-4o", etc.]`. Once they make their selections, we set each
    # selected model to `is_visible = True`.
    #
    # For "Custom LLM Providers", we don't provide a comprehensive list of models
    # for the end-user to choose from; *they provide it themselves*. Therefore,
    # for Custom LLM Providers, `is_visible` will always be True.
    is_visible: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    # Max input tokens can be null when:
    # - The end-user configures models through a "Well Known LLM Provider".
    # - The end-user is configuring a model and chooses not to set a max-input-tokens limit.
    max_input_tokens: Mapped[int | None] = mapped_column(Integer, nullable=True)

    # Deprecated: use LLMModelFlow with VISION flow type instead
    supports_image_input: Mapped[bool | None] = mapped_column(Boolean, nullable=True)

    # Human-readable display name for the model.
    # For dynamic providers (OpenRouter, Bedrock, Ollama), this comes from the source API.
    # For static providers (OpenAI, Anthropic), this may be null and will fall back to LiteLLM.
    display_name: Mapped[str | None] = mapped_column(String, nullable=True)

    llm_provider: Mapped["LLMProvider"] = relationship(
        "LLMProvider",
        back_populates="model_configurations",
    )

    llm_model_flows: Mapped[list["LLMModelFlow"]] = relationship(
        "LLMModelFlow",
        back_populates="model_configuration",
        cascade="all, delete-orphan",
        passive_deletes=True,
    )

    @property
    def llm_model_flow_types(self) -> list[LLMModelFlowType]:
        return [flow.llm_model_flow_type for flow in self.llm_model_flows]


class LLMModelFlow(Base):
    __tablename__ = "llm_model_flow"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)

    llm_model_flow_type: Mapped[LLMModelFlowType] = mapped_column(
        Enum(LLMModelFlowType, native_enum=False), nullable=False
    )
    model_configuration_id: Mapped[int] = mapped_column(
        ForeignKey("model_configuration.id", ondelete="CASCADE"),
        nullable=False,
    )
    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    model_configuration: Mapped["ModelConfiguration"] = relationship(
        "ModelConfiguration",
        back_populates="llm_model_flows",
    )

    __table_args__ = (
        UniqueConstraint(
            "llm_model_flow_type",
            "model_configuration_id",
            name="uq_model_config_per_llm_model_flow_type",
        ),
        Index(
            "ix_one_default_per_llm_model_flow",
            "llm_model_flow_type",
            unique=True,
            postgresql_where=(is_default == True),  # noqa: E712
        ),
    )


class ImageGenerationConfig(Base):
    __tablename__ = "image_generation_config"

    image_provider_id: Mapped[str] = mapped_column(String, primary_key=True)
    model_configuration_id: Mapped[int] = mapped_column(
        ForeignKey("model_configuration.id", ondelete="CASCADE"),
        nullable=False,
    )
    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    model_configuration: Mapped["ModelConfiguration"] = relationship(
        "ModelConfiguration"
    )

    __table_args__ = (
        Index("ix_image_generation_config_is_default", "is_default"),
        Index(
            "ix_image_generation_config_model_configuration_id",
            "model_configuration_id",
        ),
    )


class VoiceProvider(Base):
    """Configuration for voice services (STT and TTS)."""

    __tablename__ = "voice_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True)
    provider_type: Mapped[str] = mapped_column(
        String
    )  # "openai", "azure", "elevenlabs"
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )
    api_base: Mapped[str | None] = mapped_column(String, nullable=True)
    custom_config: Mapped[dict[str, Any] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )

    # Model/voice configuration
    stt_model: Mapped[str | None] = mapped_column(
        String, nullable=True
    )  # e.g., "whisper-1"
    tts_model: Mapped[str | None] = mapped_column(
        String, nullable=True
    )  # e.g., "tts-1", "tts-1-hd"
    default_voice: Mapped[str | None] = mapped_column(
        String, nullable=True
    )  # e.g., "alloy", "echo"

    # STT and TTS can use different providers - only one provider per type
    is_default_stt: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    is_default_tts: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    # Enforce only one default STT provider and one default TTS provider at DB level
    __table_args__ = (
        Index(
            "ix_voice_provider_one_default_stt",
            "is_default_stt",
            unique=True,
            postgresql_where=(is_default_stt == True),  # noqa: E712
        ),
        Index(
            "ix_voice_provider_one_default_tts",
            "is_default_tts",
            unique=True,
            postgresql_where=(is_default_tts == True),  # noqa: E712
        ),
    )


class CloudEmbeddingProvider(Base):
    __tablename__ = "embedding_provider"

    provider_type: Mapped[EmbeddingProvider] = mapped_column(
        Enum(EmbeddingProvider), primary_key=True
    )
    api_url: Mapped[str | None] = mapped_column(String, nullable=True)
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(EncryptedString())
    api_version: Mapped[str | None] = mapped_column(String, nullable=True)
    deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)

    search_settings: Mapped[list["SearchSettings"]] = relationship(
        "SearchSettings",
        back_populates="cloud_provider",
    )

    def __repr__(self) -> str:
        return f""


class InternetSearchProvider(Base):
    __tablename__ = "internet_search_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    provider_type: Mapped[str] = mapped_column(String, nullable=False)
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )
    config: Mapped[dict[str, str] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    def __repr__(self) -> str:
        return f""


class InternetContentProvider(Base):
    __tablename__ = "internet_content_provider"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
    provider_type: Mapped[str] = mapped_column(String, nullable=False)
    api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )
    config: Mapped[WebContentProviderConfig | None] = mapped_column(
        PydanticType(WebContentProviderConfig), nullable=True
    )
    is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    time_created: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )
    time_updated: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), onupdate=func.now()
    )

    def __repr__(self) -> str:
        return f""


class DocumentSet(Base):
    __tablename__ = "document_set"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True)
    description: Mapped[str | None] = mapped_column(String)
    user_id: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
    )
    # Whether changes to the document set have been propagated
    is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
    # If `False`, then the document set is not visible to users who are not explicitly
    # given access to it either via the `users` or `groups` relationships
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    # Last time a user updated this document set
    time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now()
    )

    connector_credential_pairs: Mapped[list[ConnectorCredentialPair]] = relationship(
        "ConnectorCredentialPair",
        secondary=DocumentSet__ConnectorCredentialPair.__table__,
        primaryjoin=(
            (DocumentSet__ConnectorCredentialPair.document_set_id == id)
            & (DocumentSet__ConnectorCredentialPair.is_current.is_(True))
        ),
        secondaryjoin=(
            DocumentSet__ConnectorCredentialPair.connector_credential_pair_id
            == ConnectorCredentialPair.id
        ),
        back_populates="document_sets",
        overlaps="document_set",
    )
    personas: Mapped[list["Persona"]] = relationship(
        "Persona",
        secondary=Persona__DocumentSet.__table__,
        back_populates="document_sets",
    )
    # Other users with access
    users: Mapped[list[User]] = relationship(
        "User",
        secondary=DocumentSet__User.__table__,
        viewonly=True,
    )
    # EE only
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="document_set__user_group",
        viewonly=True,
    )
    federated_connectors: Mapped[list["FederatedConnector__DocumentSet"]] = (
        relationship(
            "FederatedConnector__DocumentSet",
            back_populates="document_set",
            cascade="all, delete-orphan",
        )
    )


class Tool(Base):
    __tablename__ = "tool"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    # The name of the tool that the LLM will see
    name: Mapped[str] = mapped_column(String, nullable=False)
    description: Mapped[str] = mapped_column(Text, nullable=True)
    # ID of the tool in the codebase, only applies for in-code tools.
    # Tools defined via the UI will have this as None
    in_code_tool_id: Mapped[str | None] = mapped_column(String, nullable=True)
    display_name: Mapped[str] = mapped_column(String, nullable=True)
    # OpenAPI schema for the tool. Only applies to tools defined via the UI.
    openapi_schema: Mapped[dict[str, Any] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    # MCP tool input schema. Only applies to MCP tools.
    mcp_input_schema: Mapped[dict[str, Any] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    custom_headers: Mapped[list[HeaderItemDict] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )
    # User who created / owns the tool. Will be None for built-in tools.
    user_id: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
    )
    # Whether to pass through the user's OAuth token as the Authorization header
    passthrough_auth: Mapped[bool] = mapped_column(Boolean, default=False)
    # MCP server this tool is associated with (null for non-MCP tools)
    mcp_server_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True
    )
    # OAuth configuration for this tool (null for tools without OAuth)
    oauth_config_id: Mapped[int | None] = mapped_column(
        Integer, ForeignKey("oauth_config.id", ondelete="SET NULL"), nullable=True
    )
    enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)

    user: Mapped[User | None] = relationship("User", back_populates="custom_tools")
    oauth_config: Mapped["OAuthConfig | None"] = relationship(
        "OAuthConfig", back_populates="tools"
    )
    # Relationship to Persona through the association table
    personas: Mapped[list["Persona"]] = relationship(
        "Persona",
        secondary=Persona__Tool.__table__,
        back_populates="tools",
    )
    # MCP server relationship
    mcp_server: Mapped["MCPServer | None"] = relationship(
        "MCPServer", back_populates="current_actions"
    )


class OAuthConfig(Base):
    """OAuth provider configuration that can be shared across multiple tools"""

    __tablename__ = "oauth_config"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True, nullable=False)

    # OAuth provider endpoints
    authorization_url: Mapped[str] = mapped_column(Text, nullable=False)
    token_url: Mapped[str] = mapped_column(Text, nullable=False)

    # Client credentials (encrypted)
    client_id: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=False
    )
    client_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=False
    )

    # Optional configurations
    scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
    additional_params: Mapped[dict[str, Any] | None] = mapped_column(
        postgresql.JSONB(), nullable=True
    )

    # Metadata
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    tools: Mapped[list["Tool"]] = relationship("Tool", back_populates="oauth_config")
    user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
        "OAuthUserToken", back_populates="oauth_config", cascade="all, delete-orphan"
    )


class OAuthUserToken(Base):
    """Per-user OAuth tokens for a specific OAuth configuration"""

    __tablename__ = "oauth_user_token"

    id: Mapped[int] = mapped_column(Integer, primary_key=True)
    oauth_config_id: Mapped[int] = mapped_column(
        ForeignKey("oauth_config.id", ondelete="CASCADE"), nullable=False
    )
    user_id: Mapped[UUID] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=False
    )

    # Token data (encrypted)
    # Structure: {
    #   "access_token": "...",
    #   "refresh_token": "...",  # Optional
    #   "token_type": "Bearer",
    #   "expires_at": 1234567890,  # Unix timestamp, optional
    #   "scope": "repo user"  # Optional
    # }
    token_data: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
        EncryptedJson(), nullable=False
    )

    # Metadata
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    updated_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True),
        server_default=func.now(),
        onupdate=func.now(),
        nullable=False,
    )

    # Relationships
    oauth_config: Mapped["OAuthConfig"] = relationship(
        "OAuthConfig", back_populates="user_tokens"
    )
    user: Mapped["User"] = relationship("User")

    # Unique constraint: One token per user per OAuth config
    __table_args__ = (
        UniqueConstraint("oauth_config_id", "user_id", name="uq_oauth_user_token"),
    )


class StarterMessage(BaseModel):
    """Starter message for a persona."""

    name: str
    message: str


class Persona__PersonaLabel(Base):
    __tablename__ = "persona__persona_label"

    persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True)
    persona_label_id: Mapped[int] = mapped_column(
        ForeignKey("persona_label.id", ondelete="CASCADE"), primary_key=True
    )


class Persona(Base):
    __tablename__ = "persona"

    id: Mapped[int] = mapped_column(primary_key=True)
    user_id: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), nullable=True
    )
    name: Mapped[str] = mapped_column(String)
    description: Mapped[str] = mapped_column(String)
    # Allows the persona to specify a default LLM model
    # NOTE: only applied to the actual response generation; not used for things like
    # auto-detected time filters, relevance filters, etc.
    llm_model_provider_override: Mapped[str | None] = mapped_column(
        String, nullable=True
    )
    llm_model_version_override: Mapped[str | None] = mapped_column(
        String, nullable=True
    )
    default_model_configuration_id: Mapped[int | None] = mapped_column(
        Integer,
        ForeignKey("model_configuration.id", ondelete="SET NULL"),
        nullable=True,
    )
    starter_messages: Mapped[list[StarterMessage] | None] = mapped_column(
        PydanticListType(StarterMessage), nullable=True
    )
    search_start_date: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), default=None
    )
    # Built-in personas are configured via backend during deployment
    # Treated specially (cannot be user edited etc.)
    builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)

    # Featured personas are highlighted in the UI
    is_featured: Mapped[bool] = mapped_column(Boolean, default=False)
    # controls whether the persona is listed in user-facing agent lists
    is_listed: Mapped[bool] = mapped_column(Boolean, default=True)
    # controls the ordering of personas in the UI
    # higher priority personas are displayed first, ties are resolved by the ID,
    # where lower value IDs (e.g. created earlier) are displayed first
    display_priority: Mapped[int | None] = mapped_column(
        Integer, nullable=True, default=None
    )
    deleted: Mapped[bool] = mapped_column(Boolean, default=False)

    # Custom Agent Prompt
    system_prompt: Mapped[str | None] = mapped_column(
        String(length=PROMPT_LENGTH), nullable=True
    )
    replace_base_system_prompt: Mapped[bool] = mapped_column(Boolean, default=False)
    task_prompt: Mapped[str | None] = mapped_column(
        String(length=PROMPT_LENGTH), nullable=True
    )
    datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)

    uploaded_image_id: Mapped[str | None] = mapped_column(String, nullable=True)
    icon_name: Mapped[str | None] = mapped_column(String, nullable=True)

    # These are only defaults, users can select from all if desired
    document_sets: Mapped[list[DocumentSet]] = relationship(
        "DocumentSet",
        secondary=Persona__DocumentSet.__table__,
        back_populates="personas",
    )
    tools: Mapped[list[Tool]] = relationship(
        "Tool",
        secondary=Persona__Tool.__table__,
        back_populates="personas",
    )
    # Owner
    user: Mapped[User | None] = relationship("User", back_populates="personas")
    # Other users with access
    users: Mapped[list[User]] = relationship(
        "User",
        secondary=Persona__User.__table__,
        viewonly=True,
    )
    # EE only
    is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
    groups: Mapped[list["UserGroup"]] = relationship(
        "UserGroup",
        secondary="persona__user_group",
        viewonly=True,
    )
    allowed_by_llm_providers: Mapped[list["LLMProvider"]] = relationship(
        "LLMProvider",
        secondary="llm_provider__persona",
        back_populates="personas",
        viewonly=True,
    )
    # Relationship to UserFile
    user_files: Mapped[list["UserFile"]] = relationship(
        "UserFile",
        secondary="persona__user_file",
        back_populates="assistants",
    )
    labels: Mapped[list["PersonaLabel"]] = relationship(
        "PersonaLabel",
        secondary=Persona__PersonaLabel.__table__,
        back_populates="personas",
    )
    # Hierarchy nodes attached to this persona for scoped search
    hierarchy_nodes: Mapped[list["HierarchyNode"]] = relationship(
        "HierarchyNode",
        secondary="persona__hierarchy_node",
        back_populates="personas",
    )
    # Individual documents attached to this persona for scoped search
    attached_documents: Mapped[list["Document"]] = relationship(
        "Document",
        secondary="persona__document",
        back_populates="attached_personas",
    )

    # Default personas loaded via yaml cannot have the same name
    __table_args__ = (
        Index(
            "_builtin_persona_name_idx",
            "name",
            unique=True,
            postgresql_where=(builtin_persona == True),  # noqa: E712
        ),
    )


class Persona__UserFile(Base):
    __tablename__ = "persona__user_file"

    persona_id: Mapped[int] = mapped_column(
        ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
    )
    user_file_id: Mapped[UUID] = mapped_column(
        ForeignKey("user_file.id", ondelete="CASCADE"), primary_key=True
    )


class Persona__HierarchyNode(Base):
    """Association table linking personas to hierarchy nodes.

    This allows assistants to be configured with specific hierarchy nodes
    (folders, spaces, channels, etc.) for scoped search/retrieval.
    """

    __tablename__ = "persona__hierarchy_node"

    persona_id: Mapped[int] = mapped_column(
        ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
    )
    hierarchy_node_id: Mapped[int] = mapped_column(
        ForeignKey("hierarchy_node.id", ondelete="CASCADE"), primary_key=True
    )


class Persona__Document(Base):
    """Association table linking personas to individual documents.

    This allows assistants to be configured with specific documents
    for scoped search/retrieval. Complements hierarchy_nodes which
    allow attaching folders/spaces.
    """

    __tablename__ = "persona__document"

    persona_id: Mapped[int] = mapped_column(
        ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
    )
    document_id: Mapped[str] = mapped_column(
        ForeignKey("document.id", ondelete="CASCADE"), primary_key=True
    )


class PersonaLabel(Base):
    __tablename__ = "persona_label"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String, unique=True)
    personas: Mapped[list["Persona"]] = relationship(
        "Persona",
        secondary=Persona__PersonaLabel.__table__,
        back_populates="labels",
    )


class Assistant__UserSpecificConfig(Base):
    __tablename__ = "assistant__user_specific_config"

    assistant_id: Mapped[int] = mapped_column(
        ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True
    )
    user_id: Mapped[UUID] = mapped_column(
        ForeignKey("user.id", ondelete="CASCADE"), primary_key=True
    )
    disabled_tool_ids: Mapped[list[int]] = mapped_column(
        postgresql.ARRAY(Integer), nullable=False
    )


AllowedAnswerFilters = (
    Literal["well_answered_postfilter"] | Literal["questionmark_prefilter"]
)


class ChannelConfig(TypedDict):
    """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column
    in Postgres"""

    channel_name: str | None  # None for default channel config
    respond_tag_only: NotRequired[bool]  # defaults to False
    respond_to_bots: NotRequired[bool]  # defaults to False
    is_ephemeral: NotRequired[bool]  # defaults to False
    respond_member_group_list: NotRequired[list[str]]
    answer_filters: NotRequired[list[AllowedAnswerFilters]]
    # If None then no follow up
    # If empty list, follow up with no tags
    follow_up_tags: NotRequired[list[str]]
    show_continue_in_web_ui: NotRequired[bool]  # defaults to False
    disabled: NotRequired[bool]  # defaults to False


class SlackChannelConfig(Base):
    __tablename__ = "slack_channel_config"

    id: Mapped[int] = mapped_column(primary_key=True)
    slack_bot_id: Mapped[int] = mapped_column(
        ForeignKey("slack_bot.id"), nullable=False
    )
    persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id"), nullable=True
    )
    channel_config: Mapped[ChannelConfig] = mapped_column(
        postgresql.JSONB(), nullable=False
    )

    enable_auto_filters: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False
    )

    is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)

    persona: Mapped[Persona | None] = relationship("Persona")

    slack_bot: Mapped["SlackBot"] = relationship(
        "SlackBot",
        back_populates="slack_channel_configs",
    )
    standard_answer_categories: Mapped[list["StandardAnswerCategory"]] = relationship(
        "StandardAnswerCategory",
        secondary=SlackChannelConfig__StandardAnswerCategory.__table__,
        back_populates="slack_channel_configs",
    )

    __table_args__ = (
        UniqueConstraint(
            "slack_bot_id",
            "is_default",
            name="uq_slack_channel_config_slack_bot_id_default",
        ),
        Index(
            "ix_slack_channel_config_slack_bot_id_default",
            "slack_bot_id",
            "is_default",
            unique=True,
            postgresql_where=(is_default == True),  # noqa: E712
        ),
    )


class SlackBot(Base):
    __tablename__ = "slack_bot"

    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str] = mapped_column(String)
    enabled: Mapped[bool] = mapped_column(Boolean, default=True)

    bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), unique=True
    )
    app_token: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), unique=True
    )
    user_token: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=True
    )

    slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
        "SlackChannelConfig",
        back_populates="slack_bot",
        cascade="all, delete-orphan",
    )


class DiscordBotConfig(Base):
    """Global Discord bot configuration (one per tenant).

    Stores the bot token when not provided via DISCORD_BOT_TOKEN env var.
    Uses a fixed ID with check constraint to enforce only one row per tenant.
    """

    __tablename__ = "discord_bot_config"

    id: Mapped[str] = mapped_column(
        String, primary_key=True, server_default=text("'SINGLETON'")
    )
    bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
        EncryptedString(), nullable=False
    )
    created_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )


class DiscordGuildConfig(Base):
    """Configuration for a Discord guild (server) connected to this tenant.

    registration_key is a one-time key used to link a Discord server to this tenant.
    Format: discord_.
    guild_id is NULL until the Discord admin runs !register with the key.
    """

    __tablename__ = "discord_guild_config"

    id: Mapped[int] = mapped_column(primary_key=True)

    # Discord snowflake - NULL until registered via command in Discord
    guild_id: Mapped[int | None] = mapped_column(BigInteger, nullable=True, unique=True)
    guild_name: Mapped[str | None] = mapped_column(String(256), nullable=True)

    # One-time registration key: discord_.
    registration_key: Mapped[str] = mapped_column(String, unique=True, nullable=False)

    registered_at: Mapped[datetime.datetime | None] = mapped_column(
        DateTime(timezone=True), nullable=True
    )

    # Configuration
    default_persona_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
    )
    enabled: Mapped[bool] = mapped_column(
        Boolean, server_default=text("true"), nullable=False
    )

    # Relationships
    default_persona: Mapped["Persona | None"] = relationship(
        "Persona", foreign_keys=[default_persona_id]
    )
    channels: Mapped[list["DiscordChannelConfig"]] = relationship(
        back_populates="guild_config", cascade="all, delete-orphan"
    )


class DiscordChannelConfig(Base):
    """Per-channel configuration for Discord bot behavior.

    Used to whitelist specific channels and configure per-channel behavior.
    """

    __tablename__ = "discord_channel_config"

    id: Mapped[int] = mapped_column(primary_key=True)
    guild_config_id: Mapped[int] = mapped_column(
        ForeignKey("discord_guild_config.id", ondelete="CASCADE"), nullable=False
    )

    # Discord snowflake
    channel_id: Mapped[int] = mapped_column(BigInteger, nullable=False)
    channel_name: Mapped[str] = mapped_column(String(), nullable=False)

    # Channel type from Discord (text, forum)
    channel_type: Mapped[str] = mapped_column(
        String(20), server_default=text("'text'"), nullable=False
    )

    # True if @everyone cannot view the channel
    is_private: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # If true, bot only responds to messages in threads
    # Otherwise, will reply in channel
    thread_only_mode: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # If true (default), bot only responds when @mentioned
    # If false, bot responds to ALL messages in this channel
    require_bot_invocation: Mapped[bool] = mapped_column(
        Boolean, server_default=text("true"), nullable=False
    )

    # Override the guild's default persona for this channel
    persona_override_id: Mapped[int | None] = mapped_column(
        ForeignKey("persona.id", ondelete="SET NULL"), nullable=True
    )

    enabled: Mapped[bool] = mapped_column(
        Boolean, server_default=text("false"), nullable=False
    )

    # Relationships
    guild_config: Mapped["DiscordGuildConfig"] = relationship(back_populates="channels")
    persona_override: Mapped["Persona | None"] = relationship()

    # Constraints
    __table_args__ = (
        UniqueConstraint(
            "guild_config_id", "channel_id", name="uq_discord_channel_guild_channel"
        ),
    )


class Milestone(Base):
    # This table tracks significant events for a deployment on its way to finding value.
    # It is not currently used for any features, but it may be used in the future to inform
    # users about product features and encourage usage/exploration.
__tablename__ = "milestone" id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), primary_key=True, default=uuid4 ) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=True ) event_type: Mapped[MilestoneRecordType] = mapped_column(String) # Need to track counts and specific ids of certain events to know if the Milestone has been reached event_tracker: Mapped[dict | None] = mapped_column( postgresql.JSONB(), nullable=True ) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) user: Mapped[User | None] = relationship("User") __table_args__ = (UniqueConstraint("event_type", name="uq_milestone_event_type"),) class TaskQueueState(Base): # Currently refers to Celery Tasks __tablename__ = "task_queue_jobs" id: Mapped[int] = mapped_column(primary_key=True) # Celery task id. currently only for readability/diagnostics task_id: Mapped[str] = mapped_column(String) # For any job type, this would be the same task_name: Mapped[str] = mapped_column(String) # Note that if the task dies, this won't necessarily be marked FAILED correctly status: Mapped[TaskStatus] = mapped_column(Enum(TaskStatus, native_enum=False)) start_time: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True) ) register_time: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) class KVStore(Base): __tablename__ = "key_value_store" key: Mapped[str] = mapped_column(String, primary_key=True) value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) encrypted_value: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column( EncryptedJson(), nullable=True ) class FileRecord(Base): __tablename__ = "file_record" # Internal file ID, must be unique across all files. 
file_id: Mapped[str] = mapped_column(String, primary_key=True) display_name: Mapped[str] = mapped_column(String, nullable=True) file_origin: Mapped[FileOrigin] = mapped_column(Enum(FileOrigin, native_enum=False)) file_type: Mapped[str] = mapped_column(String, default="text/plain") file_metadata: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True) # External storage support (S3, MinIO, Azure Blob, etc.) bucket_name: Mapped[str] = mapped_column(String) object_key: Mapped[str] = mapped_column(String) # Timestamps for external storage created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) class FileContent(Base): """Stores file content in PostgreSQL using Large Objects. Used when FILE_STORE_BACKEND=postgres to avoid needing S3/MinIO.""" __tablename__ = "file_content" file_id: Mapped[str] = mapped_column( String, ForeignKey("file_record.file_id", ondelete="CASCADE"), primary_key=True, ) # PostgreSQL Large Object OID referencing pg_largeobject lobj_oid: Mapped[int] = mapped_column(BigInteger, nullable=False) file_size: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) """ ************************************************************************ Enterprise Edition Models ************************************************************************ These models are only used by Enterprise Edition features in Onyx. They are kept here to simplify the codebase and avoid having different assumptions on the shape of data being passed around between the MIT and EE versions of Onyx. In the MIT version of Onyx, assume these tables are always empty.
""" class SamlAccount(Base): __tablename__ = "saml" id: Mapped[int] = mapped_column(primary_key=True) user_id: Mapped[int] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), unique=True ) encrypted_cookie: Mapped[str] = mapped_column(Text, unique=True) expires_at: Mapped[datetime.datetime] = mapped_column(DateTime(timezone=True)) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) user: Mapped[User] = relationship("User") class User__UserGroup(Base): __tablename__ = "user__user_group" __table_args__ = (Index("ix_user__user_group_user_id", "user_id"),) is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), primary_key=True, nullable=True ) class PermissionGrant(Base): __tablename__ = "permission_grant" __table_args__ = ( UniqueConstraint( "group_id", "permission", name="uq_permission_grant_group_permission" ), ) id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False ) permission: Mapped[Permission] = mapped_column( Enum( Permission, native_enum=False, values_callable=lambda x: [e.value for e in x], ), nullable=False, ) grant_source: Mapped[GrantSource] = mapped_column( Enum(GrantSource, native_enum=False), nullable=False ) granted_by: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="SET NULL"), nullable=True ) granted_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) is_deleted: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False, server_default=text("false") ) group: Mapped["UserGroup"] = relationship( "UserGroup", back_populates="permission_grants" ) 
@validates("permission") def _validate_permission(self, _key: str, value: Permission) -> Permission: if value in Permission.IMPLIED: raise ValueError( f"{value!r} is an implied permission and cannot be granted directly" ) return value class UserGroup__ConnectorCredentialPair(Base): __tablename__ = "user_group__connector_credential_pair" user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) cc_pair_id: Mapped[int] = mapped_column( ForeignKey("connector_credential_pair.id"), primary_key=True ) # if `True`, then is part of the current state of the UserGroup # if `False`, then is a part of the prior state of the UserGroup # rows with `is_current=False` should be deleted when the UserGroup # is updated and should not exist for a given UserGroup if # `UserGroup.is_up_to_date == True` is_current: Mapped[bool] = mapped_column( Boolean, default=True, primary_key=True, ) cc_pair: Mapped[ConnectorCredentialPair] = relationship( "ConnectorCredentialPair", ) class Persona__UserGroup(Base): __tablename__ = "persona__user_group" persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"), primary_key=True) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) class LLMProvider__Persona(Base): """Association table restricting LLM providers to specific personas. If no such rows exist for a given LLM provider, then it is accessible by all personas. 
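
The rule that a provider with no rows is unrestricted can be sketched as below. The sketch assumes the association rows have already been loaded as (llm_provider_id, persona_id) pairs. `accessible_provider_ids` is a hypothetical helper, not the query Onyx actually runs, and it does not consider group-based restrictions from LLMProvider__UserGroup.

```python
from collections import defaultdict
from collections.abc import Iterable


def accessible_provider_ids(
    persona_id: int,
    provider_ids: Iterable[int],
    provider_persona_rows: Iterable[tuple[int, int]],
) -> list[int]:
    # Group the llm_provider__persona rows by provider.
    allowed_personas: dict[int, set[int]] = defaultdict(set)
    for llm_provider_id, row_persona_id in provider_persona_rows:
        allowed_personas[llm_provider_id].add(row_persona_id)

    result: list[int] = []
    for provider_id in provider_ids:
        restricted_to = allowed_personas.get(provider_id)
        # No rows for this provider: it is available to every persona.
        if not restricted_to or persona_id in restricted_to:
            result.append(provider_id)
    return result
```
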
""" __tablename__ = "llm_provider__persona" llm_provider_id: Mapped[int] = mapped_column( ForeignKey("llm_provider.id", ondelete="CASCADE"), primary_key=True ) persona_id: Mapped[int] = mapped_column( ForeignKey("persona.id", ondelete="CASCADE"), primary_key=True ) class LLMProvider__UserGroup(Base): __tablename__ = "llm_provider__user_group" llm_provider_id: Mapped[int] = mapped_column( ForeignKey("llm_provider.id"), primary_key=True ) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) class DocumentSet__UserGroup(Base): __tablename__ = "document_set__user_group" document_set_id: Mapped[int] = mapped_column( ForeignKey("document_set.id"), primary_key=True ) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) class Credential__UserGroup(Base): __tablename__ = "credential__user_group" credential_id: Mapped[int] = mapped_column( ForeignKey("credential.id"), primary_key=True ) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) class UserGroup(Base): __tablename__ = "user_group" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String, unique=True) # whether or not changes to the UserGroup have been propagated to Vespa is_up_to_date: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # tell the sync job to clean up the group is_up_for_deletion: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False ) # whether this is a default group (e.g. 
"Basic", "Admins") that cannot be deleted is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) # Last time a user updated this user group time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) users: Mapped[list[User]] = relationship( "User", secondary=User__UserGroup.__table__, ) user_group_relationships: Mapped[list[User__UserGroup]] = relationship( "User__UserGroup", viewonly=True, ) cc_pairs: Mapped[list[ConnectorCredentialPair]] = relationship( "ConnectorCredentialPair", secondary=UserGroup__ConnectorCredentialPair.__table__, viewonly=True, ) cc_pair_relationships: Mapped[list[UserGroup__ConnectorCredentialPair]] = ( relationship( "UserGroup__ConnectorCredentialPair", viewonly=True, ) ) personas: Mapped[list[Persona]] = relationship( "Persona", secondary=Persona__UserGroup.__table__, viewonly=True, ) document_sets: Mapped[list[DocumentSet]] = relationship( "DocumentSet", secondary=DocumentSet__UserGroup.__table__, viewonly=True, ) credentials: Mapped[list[Credential]] = relationship( "Credential", secondary=Credential__UserGroup.__table__, ) # MCP servers accessible to this user group accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship( "MCPServer", secondary="mcp_server__user_group", back_populates="user_groups" ) permission_grants: Mapped[list["PermissionGrant"]] = relationship( "PermissionGrant", back_populates="group", cascade="all, delete-orphan" ) """Tables related to Token Rate Limiting NOTE: `TokenRateLimit` is partially an MIT feature (global rate limit) """ class TokenRateLimit(Base): __tablename__ = "token_rate_limit" id: Mapped[int] = mapped_column(primary_key=True) enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) token_budget: Mapped[int] = mapped_column(Integer, nullable=False) period_hours: Mapped[int] = mapped_column(Integer, nullable=False) scope: Mapped[TokenRateLimitScope] = mapped_column( 
Enum(TokenRateLimitScope, native_enum=False) ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) class TokenRateLimit__UserGroup(Base): __tablename__ = "token_rate_limit__user_group" rate_limit_id: Mapped[int] = mapped_column( ForeignKey("token_rate_limit.id"), primary_key=True ) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) class StandardAnswerCategory(Base): __tablename__ = "standard_answer_category" id: Mapped[int] = mapped_column(primary_key=True) name: Mapped[str] = mapped_column(String, unique=True) standard_answers: Mapped[list["StandardAnswer"]] = relationship( "StandardAnswer", secondary=StandardAnswer__StandardAnswerCategory.__table__, back_populates="categories", ) slack_channel_configs: Mapped[list["SlackChannelConfig"]] = relationship( "SlackChannelConfig", secondary=SlackChannelConfig__StandardAnswerCategory.__table__, back_populates="standard_answer_categories", ) class StandardAnswer(Base): __tablename__ = "standard_answer" id: Mapped[int] = mapped_column(primary_key=True) keyword: Mapped[str] = mapped_column(String) answer: Mapped[str] = mapped_column(String) active: Mapped[bool] = mapped_column(Boolean) match_regex: Mapped[bool] = mapped_column(Boolean) match_any_keywords: Mapped[bool] = mapped_column(Boolean) __table_args__ = ( Index( "unique_keyword_active", keyword, active, unique=True, postgresql_where=(active == True), # noqa: E712 ), ) categories: Mapped[list[StandardAnswerCategory]] = relationship( "StandardAnswerCategory", secondary=StandardAnswer__StandardAnswerCategory.__table__, back_populates="standard_answers", ) chat_messages: Mapped[list[ChatMessage]] = relationship( "ChatMessage", secondary=ChatMessage__StandardAnswer.__table__, back_populates="standard_answers", ) class BackgroundError(Base): """Important background errors. Serves to: 1. Ensure that important logs are kept around and not lost on rotation/container restarts 2. 
A trail for high-signal events so that the debugger doesn't need to remember/know every possible relevant log line. """ __tablename__ = "background_error" id: Mapped[int] = mapped_column(primary_key=True) message: Mapped[str] = mapped_column(String) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) # option to link the error to a specific CC Pair cc_pair_id: Mapped[int | None] = mapped_column( ForeignKey("connector_credential_pair.id", ondelete="CASCADE"), nullable=True ) cc_pair: Mapped["ConnectorCredentialPair | None"] = relationship( "ConnectorCredentialPair", back_populates="background_errors" ) """Tables related to Permission Sync""" class User__ExternalUserGroupId(Base): """Maps user info, both internal and external, to the names of the user's external groups. This maps the user to all of their external groups so that the external group names can be included in ACL matching at query time. User-level permissions can be handled by adding the Onyx user directly to the document's ACL list.""" __tablename__ = "user__external_user_group_id" user_id: Mapped[UUID] = mapped_column(ForeignKey("user.id"), primary_key=True) # These group ids have been prefixed by the source type external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True) cc_pair_id: Mapped[int] = mapped_column( ForeignKey("connector_credential_pair.id"), primary_key=True ) # Signifies whether or not the group should be cleaned up at the end of a # group sync run. stale: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) __table_args__ = ( Index( "ix_user_external_group_cc_pair_stale", "cc_pair_id", "stale", ), Index( "ix_user_external_group_stale", "stale", ), ) class PublicExternalUserGroup(Base): """Stores all public external user "groups".
For example, things like Google Drive folders that are marked as `Anyone with the link` or `Anyone in the domain`. """ __tablename__ = "public_external_user_group" external_user_group_id: Mapped[str] = mapped_column(String, primary_key=True) cc_pair_id: Mapped[int] = mapped_column( ForeignKey("connector_credential_pair.id", ondelete="CASCADE"), primary_key=True ) # Signifies whether or not the group should be cleaned up at the end of a # group sync run. stale: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) __table_args__ = ( Index( "ix_public_external_group_cc_pair_stale", "cc_pair_id", "stale", ), Index( "ix_public_external_group_stale", "stale", ), ) class UsageReport(Base): """Stores metadata about usage reports generated by admins, including the user who generated each report and the period it covers. The report's zip file itself is stored via the FileRecord referenced by `report_name`. """ __tablename__ = "usage_reports" id: Mapped[int] = mapped_column(primary_key=True) report_name: Mapped[str] = mapped_column(ForeignKey("file_record.file_id")) # if None, report was auto-generated requestor_user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=True ) time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) period_from: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True) ) period_to: Mapped[datetime.datetime | None] = mapped_column(DateTime(timezone=True)) requestor = relationship("User") file = relationship("FileRecord") class InputPrompt(Base): __tablename__ = "inputprompt" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) prompt: Mapped[str] = mapped_column(String) content: Mapped[str] = mapped_column(String) active: Mapped[bool] = mapped_column(Boolean) user: Mapped[User | None] = relationship("User", back_populates="input_prompts") is_public: Mapped[bool] = mapped_column(Boolean, nullable=False,
default=True) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=True ) __table_args__ = ( # Unique constraint on (prompt, user_id) for user-owned prompts UniqueConstraint("prompt", "user_id", name="uq_inputprompt_prompt_user_id"), # Partial unique index for public prompts (user_id IS NULL) Index( "uq_inputprompt_prompt_public", "prompt", unique=True, postgresql_where=text("user_id IS NULL"), ), ) class InputPrompt__User(Base): __tablename__ = "inputprompt__user" input_prompt_id: Mapped[int] = mapped_column( ForeignKey("inputprompt.id"), primary_key=True ) user_id: Mapped[UUID | None] = mapped_column( ForeignKey("user.id"), primary_key=True ) disabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) class Project__UserFile(Base): __tablename__ = "project__user_file" project_id: Mapped[int] = mapped_column( ForeignKey("user_project.id"), primary_key=True ) user_file_id: Mapped[UUID] = mapped_column( ForeignKey("user_file.id"), primary_key=True ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) __table_args__ = ( Index( "ix_project__user_file_project_id_created_at", project_id, created_at.desc(), ), ) class UserProject(Base): __tablename__ = "user_project" id: Mapped[int] = mapped_column(primary_key=True, autoincrement=True) user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False) name: Mapped[str] = mapped_column(nullable=False) description: Mapped[str] = mapped_column(nullable=True) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) user: Mapped["User"] = relationship(back_populates="projects") user_files: Mapped[list["UserFile"]] = relationship( "UserFile", secondary=Project__UserFile.__table__, back_populates="projects", ) chat_sessions: Mapped[list["ChatSession"]] = relationship( "ChatSession", back_populates="project", lazy="selectin" ) instructions: 
Mapped[str] = mapped_column(String) class UserDocument(str, Enum): CHAT = "chat" RECENT = "recent" FILE = "file" class UserFile(Base): __tablename__ = "user_file" id: Mapped[UUID] = mapped_column(PGUUID(as_uuid=True), primary_key=True) user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=False) assistants: Mapped[list["Persona"]] = relationship( "Persona", secondary=Persona__UserFile.__table__, back_populates="user_files", ) file_id: Mapped[str] = mapped_column(nullable=False) name: Mapped[str] = mapped_column(nullable=False) created_at: Mapped[datetime.datetime] = mapped_column( default=datetime.datetime.utcnow ) user: Mapped["User"] = relationship(back_populates="files") token_count: Mapped[int | None] = mapped_column(Integer, nullable=True) file_type: Mapped[str] = mapped_column(String, nullable=False) status: Mapped[UserFileStatus] = mapped_column( Enum(UserFileStatus, native_enum=False, name="userfilestatus"), nullable=False, default=UserFileStatus.PROCESSING, ) needs_project_sync: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False ) needs_persona_sync: Mapped[bool] = mapped_column( Boolean, nullable=False, default=False ) last_project_sync_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) chunk_count: Mapped[int | None] = mapped_column(Integer, nullable=True) last_accessed_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) link_url: Mapped[str | None] = mapped_column(String, nullable=True) content_type: Mapped[str | None] = mapped_column(String, nullable=True) projects: Mapped[list["UserProject"]] = relationship( "UserProject", secondary=Project__UserFile.__table__, back_populates="user_files", lazy="selectin", ) """ Multi-tenancy related tables """ class PublicBase(DeclarativeBase): __abstract__ = True # Strictly keeps track of the tenant that a given user will authenticate to. 
class UserTenantMapping(Base): __tablename__ = "user_tenant_mapping" __table_args__ = ({"schema": "public"},) email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True) tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True) active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) @validates("email") def validate_email(self, key: str, value: str) -> str: # noqa: ARG002 return value.lower() if value else value class AvailableTenant(Base): __tablename__ = "available_tenant" """ These entries will only exist ephemerally and are meant to be picked up by new users on registration. """ tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False) alembic_version: Mapped[str] = mapped_column(String, nullable=False) date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False) # This is a mapping from tenant IDs to anonymous user paths class TenantAnonymousUserPath(Base): __tablename__ = "tenant_anonymous_user_path" tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False) anonymous_user_path: Mapped[str] = mapped_column( String, nullable=False, unique=True ) class MCPServer(Base): """Model for storing MCP server configurations""" __tablename__ = "mcp_server" id: Mapped[int] = mapped_column(Integer, primary_key=True) # Owner email of user who configured this server owner: Mapped[str] = mapped_column(String, nullable=False) name: Mapped[str] = mapped_column(String, nullable=False) description: Mapped[str | None] = mapped_column(String, nullable=True) server_url: Mapped[str] = mapped_column(String, nullable=False) # Transport type for connecting to the MCP server transport: Mapped[MCPTransport | None] = mapped_column( Enum(MCPTransport, native_enum=False), nullable=True ) # Auth type: "none", "api_token", or "oauth" auth_type: Mapped[MCPAuthenticationType | None] = mapped_column( Enum(MCPAuthenticationType, native_enum=False), nullable=True ) # Who 
performs authentication for this server (ADMIN or PER_USER) auth_performer: Mapped[MCPAuthenticationPerformer | None] = mapped_column( Enum(MCPAuthenticationPerformer, native_enum=False), nullable=True ) # Status tracking for configuration flow status: Mapped[MCPServerStatus] = mapped_column( Enum(MCPServerStatus, native_enum=False), nullable=False, server_default="CREATED", ) # Admin connection config - used for the config page # and (when applicable) admin-managed auth # and (when applicable) per-user auth admin_connection_config_id: Mapped[int | None] = mapped_column( Integer, ForeignKey("mcp_connection_config.id", ondelete="SET NULL"), nullable=True, ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) last_refreshed_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) # Relationships admin_connection_config: Mapped["MCPConnectionConfig | None"] = relationship( "MCPConnectionConfig", foreign_keys=[admin_connection_config_id], back_populates="admin_servers", ) user_connection_configs: Mapped[list["MCPConnectionConfig"]] = relationship( "MCPConnectionConfig", foreign_keys="MCPConnectionConfig.mcp_server_id", back_populates="mcp_server", passive_deletes=True, ) current_actions: Mapped[list["Tool"]] = relationship( "Tool", back_populates="mcp_server", cascade="all, delete-orphan" ) # Many-to-many relationships for access control users: Mapped[list["User"]] = relationship( "User", secondary="mcp_server__user", back_populates="accessible_mcp_servers" ) user_groups: Mapped[list["UserGroup"]] = relationship( "UserGroup", secondary="mcp_server__user_group", back_populates="accessible_mcp_servers", ) class MCPServer__User(Base): __tablename__ = "mcp_server__user" mcp_server_id: Mapped[int] = mapped_column( ForeignKey("mcp_server.id", 
ondelete="CASCADE"), primary_key=True ) user_id: Mapped[UUID] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), primary_key=True ) class MCPServer__UserGroup(Base): __tablename__ = "mcp_server__user_group" mcp_server_id: Mapped[int] = mapped_column( ForeignKey("mcp_server.id"), primary_key=True ) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id"), primary_key=True ) class MCPConnectionConfig(Base): """Model for storing MCP connection configurations (credentials, auth data)""" __tablename__ = "mcp_connection_config" id: Mapped[int] = mapped_column(Integer, primary_key=True) # Server this config is for (nullable for template configs) mcp_server_id: Mapped[int | None] = mapped_column( Integer, ForeignKey("mcp_server.id", ondelete="CASCADE"), nullable=True ) # User email this config is for (empty for admin configs and templates) user_email: Mapped[str] = mapped_column(String, nullable=False, default="") # Config data stored as JSON # Format: { # "refresh_token": "", # OAuth only # "access_token": "", # OAuth only # "headers": {"key": "value", "key2": "value2"}, # "header_substitutions": {"": ""}, # stored header template substitutions # "request_body": ["path/in/body:value", "path2/in2/body2:value2"] # TBD # "client_id": "", # For dynamically registered OAuth clients # "client_secret": "", # For confidential clients # "registration_access_token": "", # For managing registration # "registration_client_uri": "", # For managing registration # } config: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column( EncryptedJson(), nullable=False, default=dict ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) # Relationships mcp_server: Mapped["MCPServer | None"] = relationship( "MCPServer", foreign_keys=[mcp_server_id], 
back_populates="user_connection_configs", ) admin_servers: Mapped[list["MCPServer"]] = relationship( "MCPServer", foreign_keys="MCPServer.admin_connection_config_id", back_populates="admin_connection_config", ) __table_args__ = ( Index("ix_mcp_connection_config_user_email", "user_email"), Index("ix_mcp_connection_config_server_user", "mcp_server_id", "user_email"), ) """ Permission Sync Tables """ class DocPermissionSyncAttempt(Base): """ Represents an attempt to sync document permissions for a connector credential pair. Similar to IndexAttempt but specifically for document permission syncing operations. """ __tablename__ = "doc_permission_sync_attempt" id: Mapped[int] = mapped_column(primary_key=True) connector_credential_pair_id: Mapped[int] = mapped_column( ForeignKey("connector_credential_pair.id"), nullable=False, ) # Status of the sync attempt status: Mapped[PermissionSyncStatus] = mapped_column( Enum(PermissionSyncStatus, native_enum=False, index=True) ) # Counts for tracking progress total_docs_synced: Mapped[int | None] = mapped_column(Integer, default=0) docs_with_permission_errors: Mapped[int | None] = mapped_column(Integer, default=0) # Error message if sync fails error_message: Mapped[str | None] = mapped_column(Text, default=None) # Timestamps time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), index=True, ) time_started: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) time_finished: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) # Relationships connector_credential_pair: Mapped[ConnectorCredentialPair] = relationship( "ConnectorCredentialPair" ) __table_args__ = ( Index( "ix_permission_sync_attempt_latest_for_cc_pair", "connector_credential_pair_id", "time_created", ), Index( "ix_permission_sync_attempt_status_time", "status", desc("time_finished"), ), ) def __repr__(self) -> str: return f"" def 
is_finished(self) -> bool: return self.status.is_terminal() class ExternalGroupPermissionSyncAttempt(Base): """ Represents an attempt to sync external group memberships for users. This tracks the syncing of user-to-external-group mappings across connectors. """ __tablename__ = "external_group_permission_sync_attempt" id: Mapped[int] = mapped_column(primary_key=True) # Can be tied to a specific connector or be a global group sync connector_credential_pair_id: Mapped[int | None] = mapped_column( ForeignKey("connector_credential_pair.id"), nullable=True, # Nullable for global group syncs across all connectors ) # Status of the group sync attempt status: Mapped[PermissionSyncStatus] = mapped_column( Enum(PermissionSyncStatus, native_enum=False, index=True) ) # Counts for tracking progress total_users_processed: Mapped[int | None] = mapped_column(Integer, default=0) total_groups_processed: Mapped[int | None] = mapped_column(Integer, default=0) total_group_memberships_synced: Mapped[int | None] = mapped_column( Integer, default=0 ) # Error message if sync fails error_message: Mapped[str | None] = mapped_column(Text, default=None) # Timestamps time_created: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), index=True, ) time_started: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) time_finished: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), default=None ) # Relationships connector_credential_pair: Mapped[ConnectorCredentialPair | None] = relationship( "ConnectorCredentialPair" ) __table_args__ = ( Index( "ix_group_sync_attempt_cc_pair_time", "connector_credential_pair_id", "time_created", ), Index( "ix_group_sync_attempt_status_time", "status", desc("time_finished"), ), ) def __repr__(self) -> str: return f"" def is_finished(self) -> bool: return self.status.is_terminal() class License(Base): """Stores the signed license blob (singleton pattern - only one 
row).""" __tablename__ = "license" __table_args__ = ( # Singleton pattern - unique index on constant ensures only one row Index("idx_license_singleton", text("(true)"), unique=True), ) id: Mapped[int] = mapped_column(primary_key=True) license_data: Mapped[str] = mapped_column(Text, nullable=False) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now() ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) class TenantUsage(Base): """ Tracks per-tenant usage statistics within a time window for cloud usage limits. Each row represents usage for a specific tenant during a specific time window. A new row is created when the window rolls over (typically weekly). """ __tablename__ = "tenant_usage" id: Mapped[int] = mapped_column(primary_key=True) # The start of the usage tracking window (e.g., start of the week in UTC) window_start: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), nullable=False, index=True ) # Cumulative LLM usage cost in cents for the window llm_cost_cents: Mapped[float] = mapped_column(Float, nullable=False, default=0.0) # Number of chunks indexed during the window chunks_indexed: Mapped[int] = mapped_column(Integer, nullable=False, default=0) # Number of API calls using API keys or Personal Access Tokens api_calls: Mapped[int] = mapped_column(Integer, nullable=False, default=0) # Number of non-streaming API calls (more expensive operations) non_streaming_api_calls: Mapped[int] = mapped_column( Integer, nullable=False, default=0 ) # Last updated timestamp for tracking freshness updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now() ) __table_args__ = ( # Ensure only one row per window start (tenant_id is in the schema name) UniqueConstraint("window_start", name="uq_tenant_usage_window"), ) """Tables related to Build Mode (CLI Agent 
Platform)""" class BuildSession(Base): """Stores metadata about CLI agent build sessions.""" __tablename__ = "build_session" id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), primary_key=True, default=uuid4 ) user_id: Mapped[UUID | None] = mapped_column( PGUUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=True ) name: Mapped[str | None] = mapped_column(String, nullable=True) status: Mapped[BuildSessionStatus] = mapped_column( Enum(BuildSessionStatus, native_enum=False, name="buildsessionstatus"), nullable=False, default=BuildSessionStatus.ACTIVE, ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) last_activity_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) nextjs_port: Mapped[int | None] = mapped_column(Integer, nullable=True) demo_data_enabled: Mapped[bool] = mapped_column( Boolean, nullable=False, server_default=text("true") ) sharing_scope: Mapped[SharingScope] = mapped_column( String, nullable=False, default=SharingScope.PRIVATE, server_default="private", ) # Relationships user: Mapped[User | None] = relationship("User", foreign_keys=[user_id]) artifacts: Mapped[list["Artifact"]] = relationship( "Artifact", back_populates="session", cascade="all, delete-orphan" ) messages: Mapped[list["BuildMessage"]] = relationship( "BuildMessage", back_populates="session", cascade="all, delete-orphan" ) snapshots: Mapped[list["Snapshot"]] = relationship( "Snapshot", back_populates="session", cascade="all, delete-orphan" ) __table_args__ = ( Index("ix_build_session_user_created", "user_id", desc("created_at")), Index("ix_build_session_status", "status"), ) class Sandbox(Base): """Stores sandbox container metadata for users (one sandbox per user).""" __tablename__ = "sandbox" id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), primary_key=True, default=uuid4 ) user_id: Mapped[UUID] = 
mapped_column( PGUUID(as_uuid=True), ForeignKey("user.id", ondelete="CASCADE"), nullable=False, unique=True, ) container_id: Mapped[str | None] = mapped_column(String, nullable=True) status: Mapped[SandboxStatus] = mapped_column( Enum(SandboxStatus, native_enum=False, name="sandboxstatus"), nullable=False, default=SandboxStatus.PROVISIONING, ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) last_heartbeat: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) # Relationships user: Mapped[User] = relationship("User") __table_args__ = ( Index("ix_sandbox_status", "status"), Index("ix_sandbox_container_id", "container_id"), ) class Artifact(Base): """Stores metadata about artifacts generated by CLI agents.""" __tablename__ = "artifact" id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), primary_key=True, default=uuid4 ) session_id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), ForeignKey("build_session.id", ondelete="CASCADE"), nullable=False, ) type: Mapped[ArtifactType] = mapped_column( Enum(ArtifactType, native_enum=False, name="artifacttype"), nullable=False ) # path of artifact in sandbox relative to outputs/ path: Mapped[str] = mapped_column(String, nullable=False) name: Mapped[str] = mapped_column(String, nullable=False) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) # Relationships session: Mapped[BuildSession] = relationship( "BuildSession", back_populates="artifacts" ) __table_args__ = ( Index("ix_artifact_session_created", "session_id", desc("created_at")), Index("ix_artifact_type", "type"), ) class Snapshot(Base): """Stores metadata about session output snapshots.""" __tablename__ = "snapshot" id: Mapped[UUID] = 
mapped_column( PGUUID(as_uuid=True), primary_key=True, default=uuid4 ) session_id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), ForeignKey("build_session.id", ondelete="CASCADE"), nullable=False, ) storage_path: Mapped[str] = mapped_column(String, nullable=False) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) # Relationships session: Mapped[BuildSession] = relationship( "BuildSession", back_populates="snapshots" ) __table_args__ = ( Index("ix_snapshot_session_created", "session_id", desc("created_at")), ) class BuildMessage(Base): """Stores messages exchanged in build sessions. All message data is stored in message_metadata as JSON (the raw ACP packet). The turn_index groups all assistant responses under the user prompt they respond to. Packet types stored in message_metadata: - user_message: {type: "user_message", content: {...}} - agent_message: {type: "agent_message", content: {...}} (accumulated from chunks) - agent_thought: {type: "agent_thought", content: {...}} (accumulated from chunks) - tool_call_progress: {type: "tool_call_progress", status: "completed", ...} (only completed) - agent_plan_update: {type: "agent_plan_update", entries: [...]} (upserted, latest only) """ __tablename__ = "build_message" id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), primary_key=True, default=uuid4 ) session_id: Mapped[UUID] = mapped_column( PGUUID(as_uuid=True), ForeignKey("build_session.id", ondelete="CASCADE"), nullable=False, ) turn_index: Mapped[int] = mapped_column(Integer, nullable=False) type: Mapped[MessageType] = mapped_column( Enum(MessageType, native_enum=False, name="messagetype"), nullable=False ) message_metadata: Mapped[dict[str, Any]] = mapped_column(PGJSONB, nullable=False) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), 
nullable=False ) # Relationships session: Mapped[BuildSession] = relationship( "BuildSession", back_populates="messages" ) __table_args__ = ( Index( "ix_build_message_session_turn", "session_id", "turn_index", "created_at" ), ) """ SCIM 2.0 Provisioning Models (Enterprise Edition only) Used for automated user/group provisioning from identity providers (Okta, Azure AD). """ class ScimToken(Base): """Bearer tokens for IdP SCIM authentication.""" __tablename__ = "scim_token" id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String, nullable=False) hashed_token: Mapped[str] = mapped_column( String(64), unique=True, nullable=False ) # SHA256 = 64 hex chars token_display: Mapped[str] = mapped_column( String, nullable=False ) # Last 4 chars for UI identification created_by_id: Mapped[UUID] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), nullable=False ) is_active: Mapped[bool] = mapped_column( Boolean, server_default=text("true"), nullable=False ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) last_used_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) created_by: Mapped[User] = relationship("User", foreign_keys=[created_by_id]) class ScimUserMapping(Base): """Maps SCIM externalId from the IdP to an Onyx User.""" __tablename__ = "scim_user_mapping" id: Mapped[int] = mapped_column(Integer, primary_key=True) external_id: Mapped[str | None] = mapped_column( String, unique=True, index=True, nullable=True ) user_id: Mapped[UUID] = mapped_column( ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False ) scim_username: Mapped[str | None] = mapped_column(String, nullable=True) department: Mapped[str | None] = mapped_column(String, nullable=True) manager: Mapped[str | None] = mapped_column(String, nullable=True) given_name: Mapped[str | None] = mapped_column(String, nullable=True) 
family_name: Mapped[str | None] = mapped_column(String, nullable=True) scim_emails_json: Mapped[str | None] = mapped_column(Text, nullable=True) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) user: Mapped[User] = relationship("User", foreign_keys=[user_id]) class ScimGroupMapping(Base): """Maps SCIM externalId from the IdP to an Onyx UserGroup.""" __tablename__ = "scim_group_mapping" id: Mapped[int] = mapped_column(Integer, primary_key=True) external_id: Mapped[str] = mapped_column(String, unique=True, index=True) user_group_id: Mapped[int] = mapped_column( ForeignKey("user_group.id", ondelete="CASCADE"), unique=True, nullable=False ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) user_group: Mapped[UserGroup] = relationship( "UserGroup", foreign_keys=[user_group_id] ) class CodeInterpreterServer(Base): """Details about the code interpreter server""" __tablename__ = "code_interpreter_server" id: Mapped[int] = mapped_column(Integer, primary_key=True) server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True) class CacheStore(Base): """Key-value cache table used by ``PostgresCacheBackend``. Replaces Redis for simple KV caching, locks, and list operations when ``CACHE_BACKEND=postgres`` (NO_VECTOR_DB deployments). Intentionally separate from ``KVStore``: - Stores raw bytes (LargeBinary) vs JSONB, matching Redis semantics. - Has ``expires_at`` for TTL; rows are periodically garbage-collected. 
- Holds ephemeral data (tokens, stop signals, lock state) not persistent application config, so cleanup can be aggressive. """ __tablename__ = "cache_store" key: Mapped[str] = mapped_column(String, primary_key=True) value: Mapped[bytes | None] = mapped_column(LargeBinary, nullable=True) expires_at: Mapped[datetime.datetime | None] = mapped_column( DateTime(timezone=True), nullable=True ) class Hook(Base): """Pairs a HookPoint with a customer-provided API endpoint. At most one non-deleted Hook per HookPoint is allowed, enforced by a partial unique index on (hook_point) where deleted=false. """ __tablename__ = "hook" id: Mapped[int] = mapped_column(Integer, primary_key=True) name: Mapped[str] = mapped_column(String, nullable=False) hook_point: Mapped[HookPoint] = mapped_column( Enum(HookPoint, native_enum=False), nullable=False ) endpoint_url: Mapped[str | None] = mapped_column(Text, nullable=True) api_key: Mapped[SensitiveValue[str] | None] = mapped_column( EncryptedString(), nullable=True ) is_reachable: Mapped[bool | None] = mapped_column( Boolean, nullable=True, default=None ) # null = never validated, true = last check passed, false = last check failed fail_strategy: Mapped[HookFailStrategy] = mapped_column( Enum(HookFailStrategy, native_enum=False), nullable=False, default=HookFailStrategy.HARD, ) timeout_seconds: Mapped[float] = mapped_column(Float, nullable=False, default=30.0) is_active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) deleted: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) creator_id: Mapped[UUID | None] = mapped_column( PGUUID(as_uuid=True), ForeignKey("user.id", ondelete="SET NULL"), nullable=True, ) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False ) updated_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), onupdate=func.now(), nullable=False, ) creator: Mapped["User | None"] = 
relationship("User", foreign_keys=[creator_id]) execution_logs: Mapped[list["HookExecutionLog"]] = relationship( "HookExecutionLog", back_populates="hook", cascade="all, delete-orphan" ) __table_args__ = ( Index( "ix_hook_one_non_deleted_per_point", "hook_point", unique=True, postgresql_where=(deleted == False), # noqa: E712 ), ) class HookExecutionLog(Base): """Records hook executions for health monitoring and debugging. Currently only failures are logged; the is_success column exists so success logging can be added later without a schema change. Retention: rows older than 30 days are deleted by a nightly Celery task. """ __tablename__ = "hook_execution_log" id: Mapped[int] = mapped_column(Integer, primary_key=True) hook_id: Mapped[int] = mapped_column( Integer, ForeignKey("hook.id", ondelete="CASCADE"), nullable=False, index=True, ) is_success: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) error_message: Mapped[str | None] = mapped_column(Text, nullable=True) status_code: Mapped[int | None] = mapped_column(Integer, nullable=True) duration_ms: Mapped[int | None] = mapped_column(Integer, nullable=True) created_at: Mapped[datetime.datetime] = mapped_column( DateTime(timezone=True), server_default=func.now(), nullable=False, index=True ) hook: Mapped["Hook"] = relationship("Hook", back_populates="execution_logs") ================================================ FILE: backend/onyx/db/notification.py ================================================ from datetime import datetime from datetime import timezone from uuid import UUID from sqlalchemy import cast from sqlalchemy import select from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session from sqlalchemy.sql import func from onyx.auth.schemas import UserRole from onyx.configs.constants import NotificationType from onyx.db.models import Notification from onyx.db.models import User def create_notification( user_id: UUID | 
None, notif_type: NotificationType, db_session: Session, title: str, description: str | None = None, additional_data: dict | None = None, autocommit: bool = True, ) -> Notification: # Previously, we only matched the first identical, undismissed notification # Now, we assume some uniqueness to notifications # If we previously issued a notification that was dismissed, we no longer issue a new one # Normalize additional_data to match the unique index behavior # The index uses COALESCE(additional_data, '{}'::jsonb) # We need to match this logic in our query additional_data_normalized = additional_data if additional_data is not None else {} existing_notification = ( db_session.query(Notification) .filter_by(user_id=user_id, notif_type=notif_type) .filter( func.coalesce(Notification.additional_data, cast({}, postgresql.JSONB)) == additional_data_normalized ) .first() ) if existing_notification: # Update the last_shown timestamp if the notification is not dismissed if not existing_notification.dismissed: existing_notification.last_shown = func.now() if autocommit: db_session.commit() return existing_notification # Create a new notification if none exists notification = Notification( user_id=user_id, notif_type=notif_type, title=title, description=description, dismissed=False, last_shown=func.now(), first_shown=func.now(), additional_data=additional_data, ) db_session.add(notification) if autocommit: db_session.commit() return notification def get_notification_by_id( notification_id: int, user: User, db_session: Session ) -> Notification: user_id = user.id notif = db_session.get(Notification, notification_id) if not notif: raise ValueError(f"No notification found with id {notification_id}") if notif.user_id != user_id and not ( notif.user_id is None and user is not None and user.role == UserRole.ADMIN ): raise PermissionError( f"User {user_id} is not authorized to access notification {notification_id}" ) return notif def get_notifications( user: User | None, db_session: 
Session, notif_type: NotificationType | None = None, include_dismissed: bool = True, ) -> list[Notification]: query = select(Notification).where( Notification.user_id == user.id if user else Notification.user_id.is_(None) ) if not include_dismissed: query = query.where(Notification.dismissed.is_(False)) if notif_type: query = query.where(Notification.notif_type == notif_type) # Sort: undismissed first, then by date (newest first) query = query.order_by( Notification.dismissed.asc(), Notification.first_shown.desc(), ) return list(db_session.execute(query).scalars().all()) def dismiss_all_notifications( notif_type: NotificationType, db_session: Session, ) -> None: db_session.query(Notification).filter(Notification.notif_type == notif_type).update( {"dismissed": True} ) db_session.commit() def dismiss_notification(notification: Notification, db_session: Session) -> None: notification.dismissed = True db_session.commit() def batch_dismiss_notifications( notifications: list[Notification], db_session: Session, ) -> None: for notification in notifications: notification.dismissed = True db_session.commit() def batch_create_notifications( user_ids: list[UUID], notif_type: NotificationType, db_session: Session, title: str, description: str | None = None, additional_data: dict | None = None, ) -> int: """ Create notifications for multiple users in a single batch operation. Uses ON CONFLICT DO NOTHING for atomic idempotent inserts - if a user already has a notification with the same (user_id, notif_type, additional_data), the insert is silently skipped. Returns the number of notifications created. 
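A minimal sketch of this idempotent batch insert. It uses the stdlib `sqlite3` module as a stand-in for PostgreSQL, so `INSERT OR IGNORE` plays the role of `ON CONFLICT DO NOTHING`. A unique expression index on `COALESCE(additional_data, '{}')` treats NULL and an empty payload as the same key. The table layout and the helper names `make_db` and `batch_create` are hypothetical, and the payload is stored as canonical JSON text rather than JSONB.

```python
import json
import sqlite3


def make_db() -> sqlite3.Connection:
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE notification ("
        " id INTEGER PRIMARY KEY,"
        " user_id TEXT NOT NULL,"
        " notif_type TEXT NOT NULL,"
        " additional_data TEXT)"
    )
    # Uniqueness treats a NULL payload and an empty JSON object as the same key.
    conn.execute(
        "CREATE UNIQUE INDEX uq_notification ON notification"
        " (user_id, notif_type, COALESCE(additional_data, '{}'))"
    )
    return conn


def batch_create(conn, user_ids, notif_type, additional_data=None) -> int:
    # Normalize None to {} and serialize canonically so equal dicts compare equal.
    normalized = additional_data if additional_data is not None else {}
    payload = json.dumps(normalized, sort_keys=True)
    before = conn.total_changes
    conn.executemany(
        "INSERT OR IGNORE INTO notification (user_id, notif_type, additional_data)"
        " VALUES (?, ?, ?)",
        [(uid, notif_type, payload) for uid in user_ids],
    )
    conn.commit()
    # Rows skipped because of a conflict are not counted as changes.
    return conn.total_changes - before
```

Calling `batch_create` twice with overlapping user lists creates each user's notification only once, and the return value counts only newly inserted rows.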
Relies on unique index on (user_id, notif_type, COALESCE(additional_data, '{}')) """ if not user_ids: return 0 now = datetime.now(timezone.utc) # Use empty dict instead of None to match COALESCE behavior in the unique index additional_data_normalized = additional_data if additional_data is not None else {} values = [ { "user_id": uid, "notif_type": notif_type.value, "title": title, "description": description, "dismissed": False, "last_shown": now, "first_shown": now, "additional_data": additional_data_normalized, } for uid in user_ids ] stmt = insert(Notification).values(values).on_conflict_do_nothing() result = db_session.execute(stmt) db_session.commit() # rowcount returns number of rows inserted (excludes conflicts) # CursorResult has rowcount but session.execute type hints are too broad return result.rowcount if result.rowcount >= 0 else 0 # type: ignore[attr-defined] def update_notification_last_shown( notification: Notification, db_session: Session ) -> None: notification.last_shown = func.now() db_session.commit() ================================================ FILE: backend/onyx/db/oauth_config.py ================================================ from typing import Any from uuid import UUID from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.models import OAuthConfig from onyx.db.models import OAuthUserToken from onyx.db.models import Tool from onyx.utils.logger import setup_logger logger = setup_logger() # OAuth Config CRUD operations def create_oauth_config( name: str, authorization_url: str, token_url: str, client_id: str, client_secret: str, scopes: list[str] | None, additional_params: dict[str, str] | None, db_session: Session, ) -> OAuthConfig: """Create a new OAuth configuration""" oauth_config = OAuthConfig( name=name, authorization_url=authorization_url, token_url=token_url, client_id=client_id, client_secret=client_secret, scopes=scopes, additional_params=additional_params, ) db_session.add(oauth_config) 
db_session.commit() return oauth_config def get_oauth_config(oauth_config_id: int, db_session: Session) -> OAuthConfig | None: """Get OAuth configuration by ID""" return db_session.scalar( select(OAuthConfig).where(OAuthConfig.id == oauth_config_id) ) def get_oauth_configs(db_session: Session) -> list[OAuthConfig]: """Get all OAuth configurations""" return list(db_session.scalars(select(OAuthConfig)).all()) def update_oauth_config( oauth_config_id: int, db_session: Session, name: str | None = None, authorization_url: str | None = None, token_url: str | None = None, client_id: str | None = None, client_secret: str | None = None, scopes: list[str] | None = None, additional_params: dict[str, Any] | None = None, clear_client_id: bool = False, clear_client_secret: bool = False, ) -> OAuthConfig: """ Update OAuth configuration. NOTE: If client_id or client_secret are None, existing values are preserved. To clear these values, set clear_client_id or clear_client_secret to True. This allows partial updates without re-entering secrets. 
""" oauth_config = db_session.scalar( select(OAuthConfig).where(OAuthConfig.id == oauth_config_id) ) if oauth_config is None: raise ValueError(f"OAuth config with id {oauth_config_id} does not exist") # Update only provided fields if name is not None: oauth_config.name = name if authorization_url is not None: oauth_config.authorization_url = authorization_url if token_url is not None: oauth_config.token_url = token_url if clear_client_id: oauth_config.client_id = "" # type: ignore[assignment] elif client_id is not None: oauth_config.client_id = client_id # type: ignore[assignment] if clear_client_secret: oauth_config.client_secret = "" # type: ignore[assignment] elif client_secret is not None: oauth_config.client_secret = client_secret # type: ignore[assignment] if scopes is not None: oauth_config.scopes = scopes if additional_params is not None: oauth_config.additional_params = additional_params db_session.commit() return oauth_config def delete_oauth_config(oauth_config_id: int, db_session: Session) -> None: """ Delete OAuth configuration. Sets oauth_config_id to NULL for associated tools due to SET NULL foreign key. Cascades delete to user tokens. """ oauth_config = db_session.scalar( select(OAuthConfig).where(OAuthConfig.id == oauth_config_id) ) if oauth_config is None: raise ValueError(f"OAuth config with id {oauth_config_id} does not exist") db_session.delete(oauth_config) db_session.commit() # User Token operations def get_user_oauth_token( oauth_config_id: int, user_id: UUID, db_session: Session ) -> OAuthUserToken | None: """Get user's OAuth token for a specific configuration""" return db_session.scalar( select(OAuthUserToken).where( OAuthUserToken.oauth_config_id == oauth_config_id, OAuthUserToken.user_id == user_id, ) ) def get_all_user_oauth_tokens( user_id: UUID, db_session: Session ) -> list[OAuthUserToken]: """ Get all user OAuth tokens. 
""" stmt = select(OAuthUserToken).where(OAuthUserToken.user_id == user_id) return list(db_session.scalars(stmt).all()) def upsert_user_oauth_token( oauth_config_id: int, user_id: UUID, token_data: dict, db_session: Session ) -> OAuthUserToken: """Insert or update user's OAuth token for a specific configuration""" existing_token = get_user_oauth_token(oauth_config_id, user_id, db_session) if existing_token: # Update existing token existing_token.token_data = token_data # type: ignore[assignment] db_session.commit() return existing_token else: # Create new token new_token = OAuthUserToken( oauth_config_id=oauth_config_id, user_id=user_id, token_data=token_data, ) db_session.add(new_token) db_session.commit() return new_token def delete_user_oauth_token( oauth_config_id: int, user_id: UUID, db_session: Session ) -> None: """Delete user's OAuth token for a specific configuration""" user_token = get_user_oauth_token(oauth_config_id, user_id, db_session) if user_token is None: raise ValueError( f"OAuth token for user {user_id} and config {oauth_config_id} does not exist" ) db_session.delete(user_token) db_session.commit() # Helper operations def get_tools_by_oauth_config(oauth_config_id: int, db_session: Session) -> list[Tool]: """Get all tools that use a specific OAuth configuration""" return list( db_session.scalars( select(Tool).where(Tool.oauth_config_id == oauth_config_id) ).all() ) ================================================ FILE: backend/onyx/db/opensearch_migration.py ================================================ """Database operations for OpenSearch migration tracking. This module provides functions to track the progress of migrating documents from Vespa to OpenSearch. 
""" import json from datetime import datetime from datetime import timezone from sqlalchemy import select from sqlalchemy import text from sqlalchemy.dialects.postgresql import insert from sqlalchemy.orm import Session from onyx.background.celery.tasks.opensearch_migration.constants import ( GET_VESPA_CHUNKS_SLICE_COUNT, ) from onyx.background.celery.tasks.opensearch_migration.constants import ( TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE, ) from onyx.configs.app_configs import ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX from onyx.db.enums import OpenSearchDocumentMigrationStatus from onyx.db.models import Document from onyx.db.models import OpenSearchDocumentMigrationRecord from onyx.db.models import OpenSearchTenantMigrationRecord from onyx.document_index.vespa.shared_utils.utils import ( replace_invalid_doc_id_characters, ) from onyx.utils.logger import setup_logger logger = setup_logger() def get_paginated_document_batch( db_session: Session, limit: int, prev_ending_document_id: str | None = None, ) -> list[str]: """Gets a paginated batch of document IDs from the Document table. We need some deterministic ordering to ensure that we don't miss any documents when paginating. This function uses the document ID. It is possible a document is inserted above a spot this function has already passed. In that event we assume that the document will be indexed into OpenSearch anyway and we don't need to migrate. TODO(andrei): Consider ordering on last_modified in addition to ID to better match get_opensearch_migration_records_needing_migration. Args: db_session: SQLAlchemy session. limit: Number of document IDs to fetch. prev_ending_document_id: Document ID to start after (for pagination). If None, returns the first batch of documents. If not None, this should be the last ordered ID which was fetched in a previous batch. Defaults to None. Returns: List of document IDs. 
""" stmt = select(Document.id).order_by(Document.id.asc()).limit(limit) if prev_ending_document_id is not None: stmt = stmt.where(Document.id > prev_ending_document_id) return list(db_session.scalars(stmt).all()) def get_last_opensearch_migration_document_id( db_session: Session, ) -> str | None: """ Gets the last document ID in the OpenSearchDocumentMigrationRecord table. Returns None if no records are found. """ stmt = ( select(OpenSearchDocumentMigrationRecord.document_id) .order_by(OpenSearchDocumentMigrationRecord.document_id.desc()) .limit(1) ) return db_session.scalars(stmt).first() def create_opensearch_migration_records_with_commit( db_session: Session, document_ids: list[str], ) -> None: """Creates new OpenSearchDocumentMigrationRecord records. Silently skips any document IDs that already have records. """ if not document_ids: return values = [ { "document_id": document_id, "status": OpenSearchDocumentMigrationStatus.PENDING, } for document_id in document_ids ] stmt = insert(OpenSearchDocumentMigrationRecord).values(values) stmt = stmt.on_conflict_do_nothing(index_elements=["document_id"]) db_session.execute(stmt) db_session.commit() def get_opensearch_migration_records_needing_migration( db_session: Session, limit: int, ) -> list[OpenSearchDocumentMigrationRecord]: """Gets records of documents that need to be migrated. Properties: - First tries documents with status PENDING. - Of these, orders documents with the oldest last_modified to prioritize documents that were modified a long time ago, as they are presumed to be stable. This column is modified in many flows so is not a guarantee of the document having been indexed. - Then if there's room in the result, tries documents with status FAILED. - Of these, first orders documents on the least attempts_count so as to have a backoff for recently-failed docs. Then orders on last_modified as before. 
""" result: list[OpenSearchDocumentMigrationRecord] = [] # Step 1: Fetch as many PENDING status records as possible ordered by # last_modified (oldest first). last_modified lives on Document, so we join. stmt_pending = ( select(OpenSearchDocumentMigrationRecord) .join(Document, OpenSearchDocumentMigrationRecord.document_id == Document.id) .where( OpenSearchDocumentMigrationRecord.status == OpenSearchDocumentMigrationStatus.PENDING ) .order_by(Document.last_modified.asc()) .limit(limit) ) result.extend(list(db_session.scalars(stmt_pending).all())) remaining = limit - len(result) # Step 2: If more are needed, fetch records with status FAILED, ordered by # attempts_count (lowest first), then last_modified (oldest first). if remaining > 0: stmt_failed = ( select(OpenSearchDocumentMigrationRecord) .join( Document, OpenSearchDocumentMigrationRecord.document_id == Document.id, ) .where( OpenSearchDocumentMigrationRecord.status == OpenSearchDocumentMigrationStatus.FAILED ) .order_by( OpenSearchDocumentMigrationRecord.attempts_count.asc(), Document.last_modified.asc(), ) .limit(remaining) ) result.extend(list(db_session.scalars(stmt_failed).all())) return result def get_total_opensearch_migration_record_count( db_session: Session, ) -> int: """Gets the total number of OpenSearch migration records. Used to check whether every document has been tracked for migration. """ return db_session.query(OpenSearchDocumentMigrationRecord).count() def get_total_document_count(db_session: Session) -> int: """Gets the total number of documents. Used to check whether every document has been tracked for migration. """ return db_session.query(Document).count() def try_insert_opensearch_tenant_migration_record_with_commit( db_session: Session, ) -> None: """Tries to insert the singleton row on OpenSearchTenantMigrationRecord. Does nothing if the row already exists. 
""" stmt = insert(OpenSearchTenantMigrationRecord).on_conflict_do_nothing( index_elements=[text("(true)")] ) db_session.execute(stmt) db_session.commit() def increment_num_times_observed_no_additional_docs_to_migrate_with_commit( db_session: Session, ) -> None: """Increments the number of times observed no additional docs to migrate. Requires the OpenSearchTenantMigrationRecord to exist. Used to track when to stop the migration task. """ record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: raise RuntimeError("OpenSearchTenantMigrationRecord not found.") record.num_times_observed_no_additional_docs_to_migrate += 1 db_session.commit() def increment_num_times_observed_no_additional_docs_to_populate_migration_table_with_commit( db_session: Session, ) -> None: """ Increments the number of times observed no additional docs to populate the migration table. Requires the OpenSearchTenantMigrationRecord to exist. Used to track when to stop the migration check task. """ record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: raise RuntimeError("OpenSearchTenantMigrationRecord not found.") record.num_times_observed_no_additional_docs_to_populate_migration_table += 1 db_session.commit() def should_document_migration_be_permanently_failed( opensearch_document_migration_record: OpenSearchDocumentMigrationRecord, ) -> bool: return ( opensearch_document_migration_record.status == OpenSearchDocumentMigrationStatus.PERMANENTLY_FAILED or ( opensearch_document_migration_record.status == OpenSearchDocumentMigrationStatus.FAILED and opensearch_document_migration_record.attempts_count >= TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE ) ) def get_vespa_visit_state( db_session: Session, ) -> tuple[dict[int, str | None], int]: """Gets the current Vespa migration state from the tenant migration record. Requires the OpenSearchTenantMigrationRecord to exist. 
Returns: Tuple of (continuation_token_map, total_chunks_migrated). """ record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: raise RuntimeError("OpenSearchTenantMigrationRecord not found.") if record.vespa_visit_continuation_token is None: continuation_token_map: dict[int, str | None] = { slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT) } else: json_loaded_continuation_token_map = json.loads( record.vespa_visit_continuation_token ) continuation_token_map = { int(key): value for key, value in json_loaded_continuation_token_map.items() } return continuation_token_map, record.total_chunks_migrated def update_vespa_visit_progress_with_commit( db_session: Session, continuation_token_map: dict[int, str | None], chunks_processed: int, chunks_errored: int, approx_chunk_count_in_vespa: int | None, ) -> None: """Updates the Vespa migration progress and commits. Requires the OpenSearchTenantMigrationRecord to exist. Args: db_session: SQLAlchemy session. continuation_token_map: The new continuation token map. None entry means the visit is complete for that slice. chunks_processed: Number of chunks processed in this batch (added to the running total). chunks_errored: Number of chunks errored in this batch (added to the running errored total). approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If None, the existing value is used. 
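Example (a minimal sketch of the storage format, assuming the map is stored exactly as `json.dumps` emits it): JSON object keys are always strings, so the integer slice IDs come back as "0", "1", ... and must be cast with `int()` on read, which is what `get_vespa_visit_state` does. `encode_token_map` and `decode_token_map` are hypothetical helper names.

```python
from __future__ import annotations

import json


def encode_token_map(token_map: dict[int, str | None]) -> str:
    # json.dumps turns the int keys into strings and None into null.
    return json.dumps(token_map)


def decode_token_map(raw: str) -> dict[int, str | None]:
    # Restore the int slice IDs that JSON stringified.
    return {int(key): value for key, value in json.loads(raw).items()}


original = {0: "token-abc", 1: None, 2: "token-xyz"}
stored = encode_token_map(original)
restored = decode_token_map(stored)
# Per the Args above, a None entry marks a slice whose visit is complete.
finished_slices = [slice_id for slice_id, token in restored.items() if token is None]
```

Without the `int()` cast, lookups such as `restored[0]` would fail with a KeyError after a round trip.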
""" record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: raise RuntimeError("OpenSearchTenantMigrationRecord not found.") record.vespa_visit_continuation_token = json.dumps(continuation_token_map) record.total_chunks_migrated += chunks_processed record.total_chunks_errored += chunks_errored record.approx_chunk_count_in_vespa = ( approx_chunk_count_in_vespa if approx_chunk_count_in_vespa is not None else record.approx_chunk_count_in_vespa ) db_session.commit() def mark_migration_completed_time_if_not_set_with_commit( db_session: Session, ) -> None: """Marks the migration completed time if not set. Requires the OpenSearchTenantMigrationRecord to exist. """ record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: raise RuntimeError("OpenSearchTenantMigrationRecord not found.") if record.migration_completed_at is not None: return record.migration_completed_at = datetime.now(timezone.utc) db_session.commit() def is_migration_completed(db_session: Session) -> bool: """Returns True if the migration is completed. Can be run even if the migration record does not exist. """ record = db_session.query(OpenSearchTenantMigrationRecord).first() return record is not None and record.migration_completed_at is not None def build_sanitized_to_original_doc_id_mapping( db_session: Session, ) -> dict[str, str]: """Pre-computes a mapping of sanitized -> original document IDs. Only includes documents whose ID contains single quotes (the only character that gets sanitized by replace_invalid_doc_id_characters). For all other documents, sanitized == original and no mapping entry is needed. Scans over all documents. Checks if the sanitized ID already exists as a genuine separate document in the Document table. If so, raises as there is no way of resolving the conflict in the migration. The user will need to reindex. Args: db_session: SQLAlchemy session. 
Returns: Dict mapping sanitized_id -> original_id, only for documents where the IDs differ. Empty dict means no documents have single quotes in their IDs. """ # Find all documents with single quotes in their ID. stmt = select(Document.id).where(Document.id.contains("'")) ids_with_quotes = list(db_session.scalars(stmt).all()) result: dict[str, str] = {} for original_id in ids_with_quotes: sanitized_id = replace_invalid_doc_id_characters(original_id) if sanitized_id != original_id: result[sanitized_id] = original_id # See if there are any documents whose ID is a sanitized ID of another # document. If there is even one match, we cannot proceed. stmt = select(Document.id).where(Document.id.in_(result.keys())) ids_with_matches = list(db_session.scalars(stmt).all()) if ids_with_matches: raise RuntimeError( f"Documents with IDs {ids_with_matches} have sanitized IDs that match other documents. " "This is not supported and the user will need to reindex." ) return result def get_opensearch_migration_state( db_session: Session, ) -> tuple[int, datetime | None, datetime | None, int | None]: """Returns the state of the Vespa to OpenSearch migration. If the tenant migration record is not found, returns defaults of 0, None, None, None. Args: db_session: SQLAlchemy session. Returns: Tuple of (total_chunks_migrated, created_at, migration_completed_at, approx_chunk_count_in_vespa). """ record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: return 0, None, None, None return ( record.total_chunks_migrated, record.created_at, record.migration_completed_at, record.approx_chunk_count_in_vespa, ) def get_opensearch_retrieval_state( db_session: Session, ) -> bool: """Returns the state of the OpenSearch retrieval. If the tenant migration record is not found, defaults to ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX. 
""" record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: return ENABLE_OPENSEARCH_RETRIEVAL_FOR_ONYX return record.enable_opensearch_retrieval def set_enable_opensearch_retrieval_with_commit( db_session: Session, enable: bool, ) -> None: """Sets the enable_opensearch_retrieval flag on the singleton record. Creates the record if it doesn't exist yet. """ try_insert_opensearch_tenant_migration_record_with_commit(db_session) record = db_session.query(OpenSearchTenantMigrationRecord).first() if record is None: raise RuntimeError("OpenSearchTenantMigrationRecord not found.") record.enable_opensearch_retrieval = enable db_session.commit() ================================================ FILE: backend/onyx/db/pat.py ================================================ """Database operations for Personal Access Tokens.""" import asyncio from datetime import datetime from datetime import timezone from uuid import UUID from sqlalchemy import select from sqlalchemy import update from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import Session from onyx.auth.pat import build_displayable_pat from onyx.auth.pat import calculate_expiration from onyx.auth.pat import generate_pat from onyx.auth.pat import hash_pat from onyx.db.engine.async_sql_engine import get_async_session_context_manager from onyx.db.models import PersonalAccessToken from onyx.db.models import User from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() async def fetch_user_for_pat( hashed_token: str, async_db_session: AsyncSession ) -> User | None: """Fetch user associated with PAT. Returns None if invalid, expired, or inactive user. NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users). NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()). 
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts) are loaded correctly — matching the pattern in fetch_user_for_api_key. """ now = datetime.now(timezone.utc) user = await async_db_session.scalar( select(User) .join(PersonalAccessToken, PersonalAccessToken.user_id == User.id) .where(PersonalAccessToken.hashed_token == hashed_token) .where(User.is_active) # type: ignore .where( (PersonalAccessToken.expires_at.is_(None)) | (PersonalAccessToken.expires_at > now) ) ) if not user: return None _schedule_pat_last_used_update(hashed_token, now) return user def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None: """Fire-and-forget update of last_used_at, throttled to 5-minute granularity.""" async def _update() -> None: try: tenant_id = get_current_tenant_id() async with get_async_session_context_manager(tenant_id) as session: pat = await session.scalar( select(PersonalAccessToken).where( PersonalAccessToken.hashed_token == hashed_token ) ) if not pat: return if ( pat.last_used_at is not None and (now - pat.last_used_at).total_seconds() <= 300 ): return await session.execute( update(PersonalAccessToken) .where(PersonalAccessToken.hashed_token == hashed_token) .values(last_used_at=now) ) await session.commit() except Exception as e: logger.warning(f"Failed to update last_used_at for PAT: {e}") asyncio.create_task(_update()) def create_pat( db_session: Session, user_id: UUID, name: str, expiration_days: int | None, ) -> tuple[PersonalAccessToken, str]: """Create new PAT. Returns (db_record, raw_token). Raises ValueError if user is inactive or not found. 
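Example (a minimal sketch under stated assumptions): the raw token is handed back once and only its hash is persisted, so authentication hashes the presented token and looks it up by equality, as `fetch_user_for_pat` does with `hashed_token`. The token format, the SHA-256 choice, and the helper names are assumptions; the real `generate_pat` and `hash_pat` in `onyx.auth.pat` are not shown in this file.

```python
import hashlib
import secrets


def generate_token() -> str:
    # Hypothetical format: 32 random bytes, URL-safe base64.
    return secrets.token_urlsafe(32)


def hash_token(raw_token: str) -> str:
    # Deterministic hash so it can serve as a lookup key.
    return hashlib.sha256(raw_token.encode("utf-8")).hexdigest()


# Stand-in for the PersonalAccessToken table: hashed_token -> user_id.
token_store: dict[str, str] = {}

raw = generate_token()
token_store[hash_token(raw)] = "user-123"  # only the hash is stored

found_user = token_store.get(hash_token(raw))
wrong_user = token_store.get(hash_token(raw + "x"))
```

A database leak then exposes hashes only; the raw token cannot be shown again after creation, which is why a separate display string is kept.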
""" user = db_session.scalar(select(User).where(User.id == user_id)) # type: ignore if not user or not user.is_active: raise ValueError("Cannot create PAT for inactive or non-existent user") tenant_id = get_current_tenant_id() raw_token = generate_pat(tenant_id) pat = PersonalAccessToken( name=name, hashed_token=hash_pat(raw_token), token_display=build_displayable_pat(raw_token), user_id=user_id, expires_at=calculate_expiration(expiration_days), ) db_session.add(pat) db_session.commit() return pat, raw_token def list_user_pats(db_session: Session, user_id: UUID) -> list[PersonalAccessToken]: """List all active (non-expired) PATs for a user.""" return list( db_session.scalars( select(PersonalAccessToken) .where(PersonalAccessToken.user_id == user_id) .where( (PersonalAccessToken.expires_at.is_(None)) | (PersonalAccessToken.expires_at > datetime.now(timezone.utc)) ) .order_by(PersonalAccessToken.created_at.desc()) ).all() ) def revoke_pat(db_session: Session, pat_id: int, user_id: UUID) -> bool: """Revoke PAT by setting expires_at=NOW() for immediate expiry. Returns True if revoked, False if not found, not owned by user, or already expired. """ now = datetime.now(timezone.utc) pat = db_session.scalar( select(PersonalAccessToken) .where(PersonalAccessToken.id == pat_id) .where(PersonalAccessToken.user_id == user_id) .where( (PersonalAccessToken.expires_at.is_(None)) | (PersonalAccessToken.expires_at > now) ) # Only revoke active (non-expired) tokens ) if not pat: return False # Revoke by setting expires_at to NOW() and marking as revoked for audit trail pat.expires_at = now pat.is_revoked = True db_session.commit() return True ================================================ FILE: backend/onyx/db/permission_sync_attempt.py ================================================ """Permission sync attempt CRUD operations and utilities. 
This module contains all CRUD operations for both DocPermissionSyncAttempt and ExternalGroupPermissionSyncAttempt models, along with shared utilities. """ from typing import Any from typing import cast from sqlalchemy import delete from sqlalchemy import func from sqlalchemy import select from sqlalchemy.engine.cursor import CursorResult from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from onyx.db.enums import PermissionSyncStatus from onyx.db.models import ConnectorCredentialPair from onyx.db.models import DocPermissionSyncAttempt from onyx.db.models import ExternalGroupPermissionSyncAttempt from onyx.utils.logger import setup_logger from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType logger = setup_logger() # ============================================================================= # DOC PERMISSION SYNC ATTEMPT CRUD # ============================================================================= def create_doc_permission_sync_attempt( connector_credential_pair_id: int, db_session: Session, ) -> int: """Create a new doc permission sync attempt. Args: connector_credential_pair_id: The ID of the connector credential pair db_session: The database session Returns: The ID of the created attempt """ attempt = DocPermissionSyncAttempt( connector_credential_pair_id=connector_credential_pair_id, status=PermissionSyncStatus.NOT_STARTED, ) db_session.add(attempt) db_session.commit() return attempt.id def get_doc_permission_sync_attempt( db_session: Session, attempt_id: int, eager_load_connector: bool = False, ) -> DocPermissionSyncAttempt | None: """Get a doc permission sync attempt by ID. 
Args: db_session: The database session attempt_id: The ID of the attempt eager_load_connector: If True, eagerly loads the connector and cc_pair relationships Returns: The attempt if found, None otherwise """ stmt = select(DocPermissionSyncAttempt).where( DocPermissionSyncAttempt.id == attempt_id ) if eager_load_connector: stmt = stmt.options( joinedload(DocPermissionSyncAttempt.connector_credential_pair).joinedload( ConnectorCredentialPair.connector ) ) return db_session.scalars(stmt).first() def get_latest_doc_permission_sync_attempt_for_cc_pair( db_session: Session, connector_credential_pair_id: int, ) -> DocPermissionSyncAttempt | None: """Get the latest doc permission sync attempt for a connector credential pair.""" return db_session.execute( select(DocPermissionSyncAttempt) .where( DocPermissionSyncAttempt.connector_credential_pair_id == connector_credential_pair_id ) .order_by(DocPermissionSyncAttempt.time_created.desc()) .limit(1) ).scalar_one_or_none() def get_recent_doc_permission_sync_attempts_for_cc_pair( cc_pair_id: int, limit: int, db_session: Session, ) -> list[DocPermissionSyncAttempt]: """Get recent doc permission sync attempts for a cc pair, most recent first.""" return list( db_session.execute( select(DocPermissionSyncAttempt) .where(DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id) .order_by(DocPermissionSyncAttempt.time_created.desc()) .limit(limit) ).scalars() ) def mark_doc_permission_sync_attempt_in_progress( attempt_id: int, db_session: Session, ) -> DocPermissionSyncAttempt: """Mark a doc permission sync attempt as IN_PROGRESS. Locks the row during update.""" try: attempt = db_session.execute( select(DocPermissionSyncAttempt) .where(DocPermissionSyncAttempt.id == attempt_id) .with_for_update() ).scalar_one() if attempt.status != PermissionSyncStatus.NOT_STARTED: raise RuntimeError( f"Doc permission sync attempt with ID '{attempt_id}' is not in NOT_STARTED status. " f"Current status is '{attempt.status}'." 
) attempt.status = PermissionSyncStatus.IN_PROGRESS attempt.time_started = func.now() # type: ignore db_session.commit() return attempt except Exception: db_session.rollback() logger.exception("mark_doc_permission_sync_attempt_in_progress exceptioned.") raise def mark_doc_permission_sync_attempt_failed( attempt_id: int, db_session: Session, error_message: str, ) -> None: """Mark a doc permission sync attempt as failed.""" try: attempt = db_session.execute( select(DocPermissionSyncAttempt) .where(DocPermissionSyncAttempt.id == attempt_id) .with_for_update() ).scalar_one() if not attempt.time_started: attempt.time_started = func.now() # type: ignore attempt.status = PermissionSyncStatus.FAILED attempt.time_finished = func.now() # type: ignore attempt.error_message = error_message db_session.commit() # Add telemetry for permission sync attempt status change optional_telemetry( record_type=RecordType.PERMISSION_SYNC_COMPLETE, data={ "doc_permission_sync_attempt_id": attempt_id, "status": PermissionSyncStatus.FAILED.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) except Exception: db_session.rollback() raise def complete_doc_permission_sync_attempt( db_session: Session, attempt_id: int, total_docs_synced: int, docs_with_permission_errors: int, ) -> DocPermissionSyncAttempt: """Complete a doc permission sync attempt by updating progress and setting final status. This combines the progress update and final status marking into a single operation. If there were permission errors, the attempt is marked as COMPLETED_WITH_ERRORS, otherwise it's marked as SUCCESS. 
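Example (a minimal sketch, independent of the ORM): the attempt lifecycle enforced by these helpers can be modelled as a small state machine. `Status` is a hypothetical mirror of `PermissionSyncStatus`, limited to the members referenced in this module, with placeholder values; counters may start as None and are accumulated with `(current or 0) + delta`.

```python
from __future__ import annotations

from enum import Enum, auto


class Status(Enum):
    # Hypothetical mirror of the PermissionSyncStatus members used here;
    # the real enum's values are not shown in this file.
    NOT_STARTED = auto()
    IN_PROGRESS = auto()
    SUCCESS = auto()
    COMPLETED_WITH_ERRORS = auto()
    FAILED = auto()


def mark_in_progress(status: Status) -> Status:
    # Only a NOT_STARTED attempt may start, as in mark_*_in_progress.
    if status != Status.NOT_STARTED:
        raise RuntimeError(f"Attempt is not in NOT_STARTED status: {status}")
    return Status.IN_PROGRESS


def complete(
    total_synced: int | None, batch_synced: int, batch_errors: int
) -> tuple[int, Status]:
    # Counters may still be NULL, so treat None as 0 before adding.
    new_total = (total_synced or 0) + batch_synced
    final = Status.COMPLETED_WITH_ERRORS if batch_errors > 0 else Status.SUCCESS
    return new_total, final


status = mark_in_progress(Status.NOT_STARTED)
total, final_status = complete(None, 120, 3)
```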
Args: db_session: The database session attempt_id: The ID of the attempt total_docs_synced: Total number of documents synced docs_with_permission_errors: Number of documents that had permission errors Returns: The completed attempt """ try: attempt = db_session.execute( select(DocPermissionSyncAttempt) .where(DocPermissionSyncAttempt.id == attempt_id) .with_for_update() ).scalar_one() # Update progress counters attempt.total_docs_synced = (attempt.total_docs_synced or 0) + total_docs_synced attempt.docs_with_permission_errors = ( attempt.docs_with_permission_errors or 0 ) + docs_with_permission_errors # Set final status based on whether there were errors if docs_with_permission_errors > 0: attempt.status = PermissionSyncStatus.COMPLETED_WITH_ERRORS else: attempt.status = PermissionSyncStatus.SUCCESS attempt.time_finished = func.now() # type: ignore db_session.commit() # Add telemetry optional_telemetry( record_type=RecordType.PERMISSION_SYNC_COMPLETE, data={ "doc_permission_sync_attempt_id": attempt_id, "status": attempt.status.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) return attempt except Exception: db_session.rollback() logger.exception("complete_doc_permission_sync_attempt exceptioned.") raise # ============================================================================= # EXTERNAL GROUP PERMISSION SYNC ATTEMPT CRUD # ============================================================================= def create_external_group_sync_attempt( connector_credential_pair_id: int | None, db_session: Session, ) -> int: """Create a new external group sync attempt. 
Args: connector_credential_pair_id: The ID of the connector credential pair, or None for global syncs db_session: The database session Returns: The ID of the created attempt """ attempt = ExternalGroupPermissionSyncAttempt( connector_credential_pair_id=connector_credential_pair_id, status=PermissionSyncStatus.NOT_STARTED, ) db_session.add(attempt) db_session.commit() return attempt.id def get_external_group_sync_attempt( db_session: Session, attempt_id: int, eager_load_connector: bool = False, ) -> ExternalGroupPermissionSyncAttempt | None: """Get an external group sync attempt by ID. Args: db_session: The database session attempt_id: The ID of the attempt eager_load_connector: If True, eagerly loads the connector and cc_pair relationships Returns: The attempt if found, None otherwise """ stmt = select(ExternalGroupPermissionSyncAttempt).where( ExternalGroupPermissionSyncAttempt.id == attempt_id ) if eager_load_connector: stmt = stmt.options( joinedload( ExternalGroupPermissionSyncAttempt.connector_credential_pair ).joinedload(ConnectorCredentialPair.connector) ) return db_session.scalars(stmt).first() def get_recent_external_group_sync_attempts_for_cc_pair( cc_pair_id: int | None, limit: int, db_session: Session, ) -> list[ExternalGroupPermissionSyncAttempt]: """Get recent external group sync attempts for a cc pair, most recent first. 
If cc_pair_id is None, gets global group sync attempts.""" stmt = select(ExternalGroupPermissionSyncAttempt) if cc_pair_id is not None: stmt = stmt.where( ExternalGroupPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id ) else: stmt = stmt.where( ExternalGroupPermissionSyncAttempt.connector_credential_pair_id.is_(None) ) return list( db_session.execute( stmt.order_by(ExternalGroupPermissionSyncAttempt.time_created.desc()).limit( limit ) ).scalars() ) def mark_external_group_sync_attempt_in_progress( attempt_id: int, db_session: Session, ) -> ExternalGroupPermissionSyncAttempt: """Mark an external group sync attempt as IN_PROGRESS. Locks the row during update.""" try: attempt = db_session.execute( select(ExternalGroupPermissionSyncAttempt) .where(ExternalGroupPermissionSyncAttempt.id == attempt_id) .with_for_update() ).scalar_one() if attempt.status != PermissionSyncStatus.NOT_STARTED: raise RuntimeError( f"External group sync attempt with ID '{attempt_id}' is not in NOT_STARTED status. " f"Current status is '{attempt.status}'." 
) attempt.status = PermissionSyncStatus.IN_PROGRESS attempt.time_started = func.now() # type: ignore db_session.commit() return attempt except Exception: db_session.rollback() logger.exception("mark_external_group_sync_attempt_in_progress exceptioned.") raise def mark_external_group_sync_attempt_failed( attempt_id: int, db_session: Session, error_message: str, ) -> None: """Mark an external group sync attempt as failed.""" try: attempt = db_session.execute( select(ExternalGroupPermissionSyncAttempt) .where(ExternalGroupPermissionSyncAttempt.id == attempt_id) .with_for_update() ).scalar_one() if not attempt.time_started: attempt.time_started = func.now() # type: ignore attempt.status = PermissionSyncStatus.FAILED attempt.time_finished = func.now() # type: ignore attempt.error_message = error_message db_session.commit() # Add telemetry for permission sync attempt status change optional_telemetry( record_type=RecordType.PERMISSION_SYNC_COMPLETE, data={ "external_group_sync_attempt_id": attempt_id, "status": PermissionSyncStatus.FAILED.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) except Exception: db_session.rollback() raise def complete_external_group_sync_attempt( db_session: Session, attempt_id: int, total_users_processed: int, total_groups_processed: int, total_group_memberships_synced: int, errors_encountered: int = 0, ) -> ExternalGroupPermissionSyncAttempt: """Complete an external group sync attempt by updating progress and setting final status. This combines the progress update and final status marking into a single operation. If there were errors, the attempt is marked as COMPLETED_WITH_ERRORS, otherwise it's marked as SUCCESS. 
Args: db_session: The database session attempt_id: The ID of the attempt total_users_processed: Total users processed total_groups_processed: Total groups processed total_group_memberships_synced: Total group memberships synced errors_encountered: Number of errors encountered (determines if COMPLETED_WITH_ERRORS) Returns: The completed attempt """ try: attempt = db_session.execute( select(ExternalGroupPermissionSyncAttempt) .where(ExternalGroupPermissionSyncAttempt.id == attempt_id) .with_for_update() ).scalar_one() # Update progress counters attempt.total_users_processed = ( attempt.total_users_processed or 0 ) + total_users_processed attempt.total_groups_processed = ( attempt.total_groups_processed or 0 ) + total_groups_processed attempt.total_group_memberships_synced = ( attempt.total_group_memberships_synced or 0 ) + total_group_memberships_synced # Set final status based on whether there were errors if errors_encountered > 0: attempt.status = PermissionSyncStatus.COMPLETED_WITH_ERRORS else: attempt.status = PermissionSyncStatus.SUCCESS attempt.time_finished = func.now() # type: ignore db_session.commit() # Add telemetry optional_telemetry( record_type=RecordType.PERMISSION_SYNC_COMPLETE, data={ "external_group_sync_attempt_id": attempt_id, "status": attempt.status.value, "cc_pair_id": attempt.connector_credential_pair_id, }, ) return attempt except Exception: db_session.rollback() logger.exception("complete_external_group_sync_attempt exceptioned.") raise # ============================================================================= # DELETION FUNCTIONS # ============================================================================= def delete_doc_permission_sync_attempts__no_commit( db_session: Session, cc_pair_id: int, ) -> int: """Delete all doc permission sync attempts for a connector credential pair. This does not commit the transaction. It should be used within an existing transaction. 
Args: db_session: The database session cc_pair_id: The connector credential pair ID Returns: The number of attempts deleted """ stmt = delete(DocPermissionSyncAttempt).where( DocPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id ) result = cast(CursorResult[Any], db_session.execute(stmt)) return result.rowcount or 0 def delete_external_group_permission_sync_attempts__no_commit( db_session: Session, cc_pair_id: int, ) -> int: """Delete all external group permission sync attempts for a connector credential pair. This does not commit the transaction. It should be used within an existing transaction. Args: db_session: The database session cc_pair_id: The connector credential pair ID Returns: The number of attempts deleted """ stmt = delete(ExternalGroupPermissionSyncAttempt).where( ExternalGroupPermissionSyncAttempt.connector_credential_pair_id == cc_pair_id ) result = cast(CursorResult[Any], db_session.execute(stmt)) return result.rowcount or 0 ================================================ FILE: backend/onyx/db/permissions.py ================================================ """ DB operations for recomputing user effective_permissions. These live in onyx/db/ (not onyx/auth/) because they are pure DB operations that query PermissionGrant rows and update the User.effective_permissions JSONB column. Keeping them here avoids circular imports when called from other onyx/db/ modules such as users.py. """ from collections import defaultdict from uuid import UUID from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from onyx.db.models import PermissionGrant from onyx.db.models import User from onyx.db.models import User__UserGroup def recompute_user_permissions__no_commit( user_ids: UUID | str | list[UUID] | list[str], db_session: Session ) -> None: """Recompute granted permissions for one or more users. Accepts a single UUID or a list. Uses a single query regardless of how many users are passed, avoiding N+1 issues. 
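Example (a minimal in-memory sketch): the input normalization and per-user grouping can be shown without a database. `rows` stands in for the `(user_id, permission)` pairs returned by the single join query, with permissions as plain strings rather than enum members, and `group_permissions` is a hypothetical helper. Touching `perms_by_user[uid]` first guarantees that users with no grants still end up with an empty list.

```python
from __future__ import annotations

from collections import defaultdict
from uuid import UUID


def group_permissions(
    user_ids: UUID | str | list[UUID] | list[str],
    rows: list[tuple[UUID | str, str]],
) -> dict[UUID | str, list[str]]:
    # Normalize a single ID into a list, as the real function does.
    uid_list = [user_ids] if isinstance(user_ids, (UUID, str)) else list(user_ids)
    perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
    for uid in uid_list:
        perms_by_user[uid]  # create an empty entry for every requested user
    for uid, perm in rows:
        perms_by_user[uid].add(perm)  # a set dedupes grants from several groups
    # Sorted lists are what gets written to effective_permissions.
    return {uid: sorted(perms) for uid, perms in perms_by_user.items()}


alice, bob = "alice", "bob"
result = group_permissions(
    [alice, bob],
    [(alice, "write"), (alice, "read"), (alice, "read")],
)
```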
Stores only directly granted permissions — implication expansion happens at read time via get_effective_permissions(). Does NOT commit — caller must commit the session. """ if isinstance(user_ids, (UUID, str)): uid_list = [user_ids] else: uid_list = list(user_ids) if not uid_list: return # Single query to fetch ALL permissions for these users across ALL their # groups (a user may belong to multiple groups with different grants). rows = db_session.execute( select(User__UserGroup.user_id, PermissionGrant.permission) .join( PermissionGrant, PermissionGrant.group_id == User__UserGroup.user_group_id, ) .where( User__UserGroup.user_id.in_(uid_list), PermissionGrant.is_deleted.is_(False), ) ).all() # Group permissions by user; users with no grants get an empty set. perms_by_user: dict[UUID | str, set[str]] = defaultdict(set) for uid in uid_list: perms_by_user[uid] # ensure every user has an entry for uid, perm in rows: perms_by_user[uid].add(perm.value) for uid, perms in perms_by_user.items(): db_session.execute( update(User) .where(User.id == uid) # type: ignore[arg-type] .values(effective_permissions=sorted(perms)) ) def recompute_permissions_for_group__no_commit( group_id: int, db_session: Session ) -> None: """Recompute granted permissions for all users in a group. Does NOT commit — caller must commit the session. 
""" user_ids: list[UUID] = [ uid for uid in db_session.execute( select(User__UserGroup.user_id).where( User__UserGroup.user_group_id == group_id, User__UserGroup.user_id.isnot(None), ) ) .scalars() .all() if uid is not None ] if not user_ids: return recompute_user_permissions__no_commit(user_ids, db_session) ================================================ FILE: backend/onyx/db/persona.py ================================================ from collections.abc import Sequence from datetime import datetime from enum import Enum from uuid import UUID from fastapi import HTTPException from sqlalchemy import exists from sqlalchemy import func from sqlalchemy import not_ from sqlalchemy import or_ from sqlalchemy import Select from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import aliased from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.access.hierarchy_access import get_user_external_group_ids from onyx.auth.schemas import UserRole from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.configs.constants import NotificationType from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX from onyx.db.document_access import get_accessible_documents_by_ids from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Document from onyx.db.models import DocumentSet from onyx.db.models import FederatedConnector__DocumentSet from onyx.db.models import HierarchyNode from onyx.db.models import Persona from onyx.db.models import Persona__User from onyx.db.models import Persona__UserGroup from onyx.db.models import PersonaLabel from onyx.db.models import StarterMessage from onyx.db.models import Tool from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserFile from onyx.db.models import UserGroup from onyx.db.notification import create_notification from 
onyx.server.features.persona.models import FullPersonaSnapshot from onyx.server.features.persona.models import MinimalPersonaSnapshot from onyx.server.features.persona.models import PersonaSharedNotificationData from onyx.server.features.persona.models import PersonaSnapshot from onyx.server.features.persona.models import PersonaUpsertRequest from onyx.server.features.tool.tool_visibility import should_expose_tool_to_fe from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() def get_default_behavior_persona( db_session: Session, eager_load_for_tools: bool = False, ) -> Persona | None: stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID) if eager_load_for_tools: stmt = stmt.options( selectinload(Persona.tools), selectinload(Persona.document_sets), selectinload(Persona.attached_documents), selectinload(Persona.hierarchy_nodes), ) return db_session.scalars(stmt).first() class PersonaLoadType(Enum): NONE = "none" MINIMAL = "minimal" FULL = "full" def _add_user_filters( stmt: Select[tuple[Persona]], user: User, get_editable: bool = True ) -> Select[tuple[Persona]]: if user.role == UserRole.ADMIN: return stmt stmt = stmt.distinct() Persona__UG = aliased(Persona__UserGroup) User__UG = aliased(User__UserGroup) """ Here we select Personas by relation: User -> User__UserGroup -> Persona__UserGroup -> Persona """ stmt = ( stmt.outerjoin(Persona__UG) .outerjoin( User__UserGroup, User__UserGroup.user_group_id == Persona__UG.user_group_id, ) .outerjoin( Persona__User, Persona__User.persona_id == Persona.id, ) ) """ Filter Personas by: - if the user is in the user_group that owns the Persona - if the user is not a global_curator, they must also have a curator relationship to the user_group - if editing is being done, we also filter out Personas that are owned by groups that the user isn't a curator for - if we are not editing, we show all Personas in the groups the user is a curator
for (as well as public Personas) - if we are not editing, we return all Personas directly connected to the user """ # Anonymous users only see public Personas if user.is_anonymous: where_clause = Persona.is_public == True # noqa: E712 return stmt.where(where_clause) # If curator ownership restriction is enabled, curators can only access their own assistants if CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS and user.role in [ UserRole.CURATOR, UserRole.GLOBAL_CURATOR, ]: where_clause = (Persona.user_id == user.id) | (Persona.user_id.is_(None)) return stmt.where(where_clause) where_clause = User__UserGroup.user_id == user.id if user.role == UserRole.CURATOR and get_editable: where_clause &= User__UserGroup.is_curator == True # noqa: E712 if get_editable: user_groups = select(User__UG.user_group_id).where(User__UG.user_id == user.id) if user.role == UserRole.CURATOR: user_groups = user_groups.where(User__UG.is_curator == True) # noqa: E712 where_clause &= ( ~exists() .where(Persona__UG.persona_id == Persona.id) .where(~Persona__UG.user_group_id.in_(user_groups)) .correlate(Persona) ) else: # Group the public persona conditions public_condition = (Persona.is_public == True) & ( # noqa: E712 Persona.is_listed == True # noqa: E712 ) where_clause |= public_condition where_clause |= Persona__User.user_id == user.id where_clause |= Persona.user_id == user.id return stmt.where(where_clause) def fetch_persona_by_id_for_user( db_session: Session, persona_id: int, user: User, get_editable: bool = True ) -> Persona: stmt = select(Persona).where(Persona.id == persona_id).distinct() stmt = _add_user_filters(stmt=stmt, user=user, get_editable=get_editable) persona = db_session.scalars(stmt).one_or_none() if not persona: raise HTTPException( status_code=403, detail=f"Persona with ID {persona_id} does not exist or user is not authorized to access it", ) return persona def get_best_persona_id_for_user( db_session: Session, user: User, persona_id: int | None = None ) -> int | None: 
    if persona_id is not None:
        stmt = select(Persona).where(Persona.id == persona_id).distinct()
        stmt = _add_user_filters(
            stmt=stmt,
            user=user,
            # We don't want to filter by editable here, we just want to see if the
            # persona is usable by the user
            get_editable=False,
        )
        persona = db_session.scalars(stmt).one_or_none()
        if persona:
            return persona.id

    # If the persona is not found, or the slack bot is using doc sets instead of personas,
    # we need to find the best persona for the user.
    # This is the persona with the highest display priority that the user has access to.
    stmt = select(Persona).order_by(Persona.display_priority.desc()).distinct()
    stmt = _add_user_filters(stmt=stmt, user=user, get_editable=True)
    # This query can match many personas, so take the first row of the ordered
    # result; one_or_none() would raise MultipleResultsFound here.
    persona = db_session.scalars(stmt).first()
    return persona.id if persona else None


def _get_persona_by_name(
    persona_name: str, user: User | None, db_session: Session
) -> Persona | None:
    """Fetch a persona by name with access control.

    Access rules:
    - user=None (system operations): can see all personas
    - Admin users: can see all personas
    - Non-admin users: can only see their own personas
    """
    stmt = select(Persona).where(Persona.name == persona_name)
    if user and user.role != UserRole.ADMIN:
        stmt = stmt.where(Persona.user_id == user.id)
    result = db_session.execute(stmt).scalar_one_or_none()
    return result


def update_persona_access(
    persona_id: int,
    creator_user_id: UUID | None,
    db_session: Session,
    is_public: bool | None = None,
    user_ids: list[UUID] | None = None,
    group_ids: list[int] | None = None,
) -> None:
    """Updates the access settings for a persona including public status and user shares.
    NOTE: Callers are responsible for committing."""

    needs_sync = False
    if is_public is not None:
        needs_sync = True
        persona = db_session.query(Persona).filter(Persona.id == persona_id).first()
        if persona:
            persona.is_public = is_public

    # NOTE: For user-ids and group-ids, `None` means "leave unchanged", `[]` means "clear all shares",
    # and a non-empty list means "replace with these shares".

    if user_ids is not None:
        needs_sync = True
        db_session.query(Persona__User).filter(
            Persona__User.persona_id == persona_id
        ).delete(synchronize_session="fetch")

        for user_uuid in user_ids:
            db_session.add(Persona__User(persona_id=persona_id, user_id=user_uuid))
            if user_uuid != creator_user_id:
                create_notification(
                    user_id=user_uuid,
                    notif_type=NotificationType.PERSONA_SHARED,
                    title="A new agent was shared with you!",
                    db_session=db_session,
                    additional_data=PersonaSharedNotificationData(
                        persona_id=persona_id,
                    ).model_dump(),
                )

    # MIT doesn't support group-based sharing, so we allow clearing (no-op since
    # there shouldn't be any) but raise an error if trying to add actual groups.
    if group_ids is not None:
        needs_sync = True
        db_session.query(Persona__UserGroup).filter(
            Persona__UserGroup.persona_id == persona_id
        ).delete(synchronize_session="fetch")
        if group_ids:
            raise NotImplementedError("Onyx MIT does not support group-based sharing")

    # When sharing changes, user file ACLs need to be updated in the vector DB
    if needs_sync:
        mark_persona_user_files_for_sync(persona_id, db_session)


def create_update_persona(
    persona_id: int | None,
    create_persona_request: PersonaUpsertRequest,
    user: User,
    db_session: Session,
) -> FullPersonaSnapshot:
    """Higher level function than upsert_persona, although either is valid to use."""
    # Permission to actually use these is checked later

    try:
        # Featured persona validation
        if create_persona_request.is_featured:
            # Curators can edit featured personas, but not make them
            # TODO this will be reworked soon with RBAC permissions feature
            if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
                pass
            elif user.role != UserRole.ADMIN:
                raise ValueError("Only admins can make a featured persona")

        # Convert incoming string UUIDs to UUID objects for DB operations
        converted_user_file_ids = None
        if create_persona_request.user_file_ids is not None:
            try:
                converted_user_file_ids = [
                    UUID(str_id) for str_id in create_persona_request.user_file_ids
                ]
            except Exception:
                raise ValueError("Invalid user_file_ids; must be UUID strings")

        persona = upsert_persona(
            persona_id=persona_id,
            user=user,
            db_session=db_session,
            description=create_persona_request.description,
            name=create_persona_request.name,
            document_set_ids=create_persona_request.document_set_ids,
            tool_ids=create_persona_request.tool_ids,
            is_public=create_persona_request.is_public,
            llm_model_provider_override=create_persona_request.llm_model_provider_override,
            llm_model_version_override=create_persona_request.llm_model_version_override,
            starter_messages=create_persona_request.starter_messages,
            system_prompt=create_persona_request.system_prompt,
            task_prompt=create_persona_request.task_prompt,
            datetime_aware=create_persona_request.datetime_aware,
            replace_base_system_prompt=create_persona_request.replace_base_system_prompt,
            uploaded_image_id=create_persona_request.uploaded_image_id,
            icon_name=create_persona_request.icon_name,
            display_priority=create_persona_request.display_priority,
            remove_image=create_persona_request.remove_image,
            search_start_date=create_persona_request.search_start_date,
            label_ids=create_persona_request.label_ids,
            is_featured=create_persona_request.is_featured,
            user_file_ids=converted_user_file_ids,
            commit=False,
            hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
            document_ids=create_persona_request.document_ids,
        )

        versioned_update_persona_access = fetch_versioned_implementation(
            "onyx.db.persona", "update_persona_access"
        )

        versioned_update_persona_access(
            persona_id=persona.id,
            creator_user_id=user.id,
            db_session=db_session,
            user_ids=create_persona_request.users,
            group_ids=create_persona_request.groups,
        )
        db_session.commit()

    except ValueError as e:
        logger.exception("Failed to create persona")
        raise HTTPException(status_code=400, detail=str(e))

    return FullPersonaSnapshot.from_model(persona)


def update_persona_shared(
    persona_id: int,
    user_ids: list[UUID] | None,
    user: User,
    db_session: Session,
    group_ids: list[int] | None = None,
    is_public: bool | None = None,
    label_ids: list[int] | None = None,
) -> None:
    """Simplified version of `create_update_persona` which only touches the
    accessibility rather than any of the logic (e.g.
    prompt, connected data sources, etc.)."""
    persona = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )

    if user and user.role != UserRole.ADMIN and persona.user_id != user.id:
        raise PermissionError("You don't have permission to modify this persona")

    versioned_update_persona_access = fetch_versioned_implementation(
        "onyx.db.persona", "update_persona_access"
    )
    versioned_update_persona_access(
        persona_id=persona_id,
        creator_user_id=user.id,
        db_session=db_session,
        is_public=is_public,
        user_ids=user_ids,
        group_ids=group_ids,
    )

    if label_ids is not None:
        labels = (
            db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
        )
        if len(labels) != len(label_ids):
            raise ValueError("Some label IDs were not found in the database")
        persona.labels.clear()
        persona.labels = labels

    db_session.commit()


def update_persona_public_status(
    persona_id: int,
    is_public: bool,
    db_session: Session,
    user: User,
) -> None:
    persona = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )

    if user.role != UserRole.ADMIN and persona.user_id != user.id:
        raise ValueError("You don't have permission to modify this persona")

    persona.is_public = is_public
    db_session.commit()


def _build_persona_filters(
    stmt: Select[tuple[Persona]],
    include_default: bool,
    include_slack_bot_personas: bool,
    include_deleted: bool,
) -> Select[tuple[Persona]]:
    """Filters which Personas are included in the query.

    Args:
        stmt: The base query to filter.
        include_default: If True, includes builtin/default personas.
        include_slack_bot_personas: If True, includes Slack bot personas.
        include_deleted: If True, includes deleted personas.

    Returns:
        The modified query with the filters applied.
    """
    if not include_default:
        stmt = stmt.where(Persona.builtin_persona.is_(False))
    if not include_slack_bot_personas:
        stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
    if not include_deleted:
        stmt = stmt.where(Persona.deleted.is_(False))
    return stmt


def get_minimal_persona_snapshots_for_user(
    user: User,
    db_session: Session,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> list[MinimalPersonaSnapshot]:
    stmt = select(Persona)
    stmt = _add_user_filters(stmt, user, get_editable)
    stmt = _build_persona_filters(
        stmt, include_default, include_slack_bot_personas, include_deleted
    )
    stmt = stmt.options(
        selectinload(Persona.tools),
        selectinload(Persona.labels),
        selectinload(Persona.document_sets).options(
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        ),
        selectinload(Persona.hierarchy_nodes),
        selectinload(Persona.attached_documents).selectinload(
            Document.parent_hierarchy_node
        ),
        selectinload(Persona.user),
    )
    results = db_session.scalars(stmt).all()
    return [MinimalPersonaSnapshot.from_model(persona) for persona in results]


def get_persona_snapshots_for_user(
    user: User,
    db_session: Session,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> list[PersonaSnapshot]:
    stmt = select(Persona)
    stmt = _add_user_filters(stmt, user, get_editable)
    stmt = _build_persona_filters(
        stmt, include_default, include_slack_bot_personas, include_deleted
    )
    stmt = stmt.options(
        selectinload(Persona.tools),
        selectinload(Persona.hierarchy_nodes),
        selectinload(Persona.attached_documents).selectinload(
            Document.parent_hierarchy_node
        ),
        selectinload(Persona.labels),
        selectinload(Persona.document_sets).options(
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        ),
        selectinload(Persona.user),
        selectinload(Persona.user_files),
        selectinload(Persona.users),
        selectinload(Persona.groups),
    )

    results = db_session.scalars(stmt).all()
    return [PersonaSnapshot.from_model(persona) for persona in results]


def get_persona_count_for_user(
    user: User,
    db_session: Session,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> int:
    """Counts the total number of personas accessible to the user.

    Args:
        user: The user whose role, group memberships, and direct shares
            determine which personas are counted. Anonymous users see only
            public personas.
        db_session: Database session for executing queries.
        get_editable: If True, only counts personas the user can edit.
        include_default: If True, includes builtin/default personas.
        include_slack_bot_personas: If True, includes Slack bot personas.
        include_deleted: If True, includes deleted personas.

    Returns:
        Total count of personas matching the filters and user permissions.
    """
    stmt = _build_persona_base_query(
        user=user,
        get_editable=get_editable,
        include_default=include_default,
        include_slack_bot_personas=include_slack_bot_personas,
        include_deleted=include_deleted,
    )

    # Convert to count query.
    count_stmt = stmt.with_only_columns(func.count(func.distinct(Persona.id))).order_by(
        None
    )

    return db_session.scalar(count_stmt) or 0


def get_minimal_persona_snapshots_paginated(
    user: User,
    db_session: Session,
    page_num: int,
    page_size: int,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> list[MinimalPersonaSnapshot]:
    """Gets a single page of minimal persona snapshots with ordering.

    Personas are ordered by display_priority (ASC, nulls last) then by the
    absolute value of ID (ASC).

    Args:
        user: The user whose role, group memberships, and direct shares
            determine which personas are returned. Anonymous users see only
            public personas.
        db_session: Database session for executing queries.
        page_num: Zero-indexed page number (e.g., 0 for the first page).
        page_size: Number of items per page.
        get_editable: If True, only returns personas the user can edit.
        include_default: If True, includes builtin/default personas.
        include_slack_bot_personas: If True, includes Slack bot personas.
        include_deleted: If True, includes deleted personas.

    Returns:
        List of MinimalPersonaSnapshot objects for the requested page, ordered
        by display_priority (nulls last) then ID.
    """
    stmt = _get_paginated_persona_query(
        user,
        page_num,
        page_size,
        get_editable,
        include_default,
        include_slack_bot_personas,
        include_deleted,
    )

    # Do eager loading of columns we know MinimalPersonaSnapshot.from_model will
    # need.
    stmt = stmt.options(
        selectinload(Persona.tools),
        selectinload(Persona.hierarchy_nodes),
        selectinload(Persona.attached_documents).selectinload(
            Document.parent_hierarchy_node
        ),
        selectinload(Persona.labels),
        selectinload(Persona.document_sets).options(
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        ),
        selectinload(Persona.user),
    )

    results = db_session.scalars(stmt).all()
    return [MinimalPersonaSnapshot.from_model(persona) for persona in results]


def get_persona_snapshots_paginated(
    user: User,
    db_session: Session,
    page_num: int,
    page_size: int,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> list[PersonaSnapshot]:
    """Gets a single page of persona snapshots (admin view) with ordering.

    Personas are ordered by display_priority (ASC, nulls last) then by the
    absolute value of ID (ASC).

    This function returns PersonaSnapshot objects, which contain more detailed
    information than MinimalPersonaSnapshot and are used for admin views.

    Args:
        user: The user whose role, group memberships, and direct shares
            determine which personas are returned. Anonymous users see only
            public personas.
        db_session: Database session for executing queries.
        page_num: Zero-indexed page number (e.g., 0 for the first page).
        page_size: Number of items per page.
        get_editable: If True, only returns personas the user can edit.
        include_default: If True, includes builtin/default personas.
        include_slack_bot_personas: If True, includes Slack bot personas.
        include_deleted: If True, includes deleted personas.

    Returns:
        List of PersonaSnapshot objects for the requested page, ordered by
        display_priority (nulls last) then ID.
    """
    stmt = _get_paginated_persona_query(
        user,
        page_num,
        page_size,
        get_editable,
        include_default,
        include_slack_bot_personas,
        include_deleted,
    )

    # Do eager loading of columns we know PersonaSnapshot.from_model will need.
    stmt = stmt.options(
        selectinload(Persona.tools),
        selectinload(Persona.hierarchy_nodes),
        selectinload(Persona.attached_documents).selectinload(
            Document.parent_hierarchy_node
        ),
        selectinload(Persona.labels),
        selectinload(Persona.document_sets).options(
            selectinload(DocumentSet.connector_credential_pairs).selectinload(
                ConnectorCredentialPair.connector
            ),
            selectinload(DocumentSet.users),
            selectinload(DocumentSet.groups),
            selectinload(DocumentSet.federated_connectors).selectinload(
                FederatedConnector__DocumentSet.federated_connector
            ),
        ),
        selectinload(Persona.user),
        selectinload(Persona.user_files),
        selectinload(Persona.users),
        selectinload(Persona.groups),
    )

    results = db_session.scalars(stmt).all()
    return [PersonaSnapshot.from_model(persona) for persona in results]


def _get_paginated_persona_query(
    user: User,
    page_num: int,
    page_size: int,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> Select[tuple[Persona]]:
    """Builds a paginated query on personas ordered by display_priority and ID.

    Personas are ordered by display_priority (ASC, nulls last) then by the
    absolute value of ID (ASC) to match the frontend personaComparator() logic.

    Args:
        user: The user whose role, group memberships, and direct shares
            determine which personas are returned. Anonymous users see only
            public personas.
        page_num: Zero-indexed page number (e.g., 0 for the first page).
        page_size: Number of items per page.
        get_editable: If True, only returns personas the user can edit.
        include_default: If True, includes builtin/default personas.
        include_slack_bot_personas: If True, includes Slack bot personas.
        include_deleted: If True, includes deleted personas.
    Returns:
        SQLAlchemy Select statement with all filters, ordering, and pagination
        applied.
    """
    stmt = _build_persona_base_query(
        user=user,
        get_editable=get_editable,
        include_default=include_default,
        include_slack_bot_personas=include_slack_bot_personas,
        include_deleted=include_deleted,
    )

    # Add the abs(id) expression to the SELECT list (required for DISTINCT +
    # ORDER BY).
    stmt = stmt.add_columns(func.abs(Persona.id).label("abs_id"))

    # Apply ordering.
    stmt = stmt.order_by(
        Persona.display_priority.asc().nullslast(),
        func.abs(Persona.id).asc(),
    )

    # Apply pagination.
    stmt = stmt.offset(page_num * page_size).limit(page_size)

    return stmt


def _build_persona_base_query(
    user: User,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> Select[tuple[Persona]]:
    """Builds a base persona query with all user and persona filters applied.

    This helper constructs a filtered query that can then be customized for
    counting, pagination, or full retrieval.

    Args:
        user: The user whose role, group memberships, and direct shares
            determine which personas are returned. Anonymous users see only
            public personas.
        get_editable: If True, only returns personas the user can edit.
        include_default: If True, includes builtin/default personas.
        include_slack_bot_personas: If True, includes Slack bot personas.
        include_deleted: If True, includes deleted personas.

    Returns:
        SQLAlchemy Select statement with all filters applied.
    """
    stmt = select(Persona)
    stmt = _add_user_filters(stmt, user, get_editable)
    stmt = _build_persona_filters(
        stmt, include_default, include_slack_bot_personas, include_deleted
    )
    return stmt


def get_raw_personas_for_user(
    user: User,
    db_session: Session,
    get_editable: bool = True,
    include_default: bool = True,
    include_slack_bot_personas: bool = False,
    include_deleted: bool = False,
) -> Sequence[Persona]:
    stmt = _build_persona_base_query(
        user, get_editable, include_default, include_slack_bot_personas, include_deleted
    )
    return db_session.scalars(stmt).all()


def get_personas(db_session: Session) -> Sequence[Persona]:
    """WARNING: Unsafe, can fetch personas from all users."""
    stmt = select(Persona).distinct()
    stmt = stmt.where(not_(Persona.name.startswith(SLACK_BOT_PERSONA_PREFIX)))
    stmt = stmt.where(Persona.deleted.is_(False))
    return db_session.execute(stmt).unique().scalars().all()


def mark_persona_as_deleted(
    persona_id: int,
    user: User,
    db_session: Session,
) -> None:
    persona = get_persona_by_id(persona_id=persona_id, user=user, db_session=db_session)
    persona.deleted = True
    affected_file_ids = [uf.id for uf in persona.user_files]
    if affected_file_ids:
        _mark_files_need_persona_sync(db_session, affected_file_ids)
    db_session.commit()


def mark_persona_as_not_deleted(
    persona_id: int,
    user: User,
    db_session: Session,
) -> None:
    persona = get_persona_by_id(
        persona_id=persona_id, user=user, db_session=db_session, include_deleted=True
    )
    if not persona.deleted:
        raise ValueError(f"Persona with ID {persona_id} is not deleted.")
    persona.deleted = False
    affected_file_ids = [uf.id for uf in persona.user_files]
    if affected_file_ids:
        _mark_files_need_persona_sync(db_session, affected_file_ids)
    db_session.commit()


def mark_delete_persona_by_name(
    persona_name: str, db_session: Session, is_default: bool = True
) -> None:
    stmt = (
        update(Persona)
        .where(Persona.name == persona_name, Persona.builtin_persona == is_default)
        .values(deleted=True)
    )

    db_session.execute(stmt)
    db_session.commit()


def update_personas_display_priority(
    display_priority_map: dict[int, int],
    db_session: Session,
    user: User,
    commit_db_txn: bool = False,
) -> None:
    """Updates the display priorities of the specified Personas.

    Args:
        display_priority_map: A map of persona IDs to intended display
            priorities.
        db_session: Database session for executing queries.
        user: The user making the change. Only personas visible to this user
            can be updated.
        commit_db_txn: If True, commits the database transaction after
            updating the display priorities. Defaults to False.

    Raises:
        ValueError: The caller tried to update a persona that the user does
            not have access to.
    """
    # No-op to save a query if it is not necessary.
    if len(display_priority_map) == 0:
        return

    personas = get_raw_personas_for_user(
        user,
        db_session,
        get_editable=False,
        include_default=True,
        include_slack_bot_personas=True,
        include_deleted=True,
    )
    available_personas_map: dict[int, Persona] = {
        persona.id: persona for persona in personas
    }

    for persona_id, priority in display_priority_map.items():
        if persona_id not in available_personas_map:
            raise ValueError(
                f"Invalid persona ID provided: Persona with ID {persona_id} was not found for this user."
            )

        available_personas_map[persona_id].display_priority = priority

    if commit_db_txn:
        db_session.commit()


def mark_persona_user_files_for_sync(
    persona_id: int,
    db_session: Session,
) -> None:
    """When persona sharing changes, mark all of its user files for sync
    so that their ACLs get updated in the vector DB."""
    persona = (
        db_session.query(Persona)
        .options(selectinload(Persona.user_files))
        .filter(Persona.id == persona_id)
        .first()
    )
    if not persona:
        return

    file_ids = [uf.id for uf in persona.user_files]
    _mark_files_need_persona_sync(db_session, file_ids)


def _mark_files_need_persona_sync(
    db_session: Session,
    user_file_ids: list[UUID],
) -> None:
    """Flag the given UserFile rows so the background sync task picks them up
    and updates their persona metadata in the vector DB."""
    if not user_file_ids:
        return
    db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).update(
        {UserFile.needs_persona_sync: True},
        synchronize_session=False,
    )


def upsert_persona(
    user: User | None,
    name: str,
    description: str,
    llm_model_provider_override: str | None,
    llm_model_version_override: str | None,
    starter_messages: list[StarterMessage] | None,
    # Embedded prompt fields
    system_prompt: str | None,
    task_prompt: str | None,
    datetime_aware: bool | None,
    is_public: bool,
    db_session: Session,
    document_set_ids: list[int] | None = None,
    tool_ids: list[int] | None = None,
    persona_id: int | None = None,
    commit: bool = True,
    uploaded_image_id: str | None = None,
    icon_name: str | None = None,
    display_priority: int | None = None,
    is_listed: bool = True,
    remove_image: bool | None = None,
    search_start_date: datetime | None = None,
    builtin_persona: bool = False,
    is_featured: bool | None = None,
    label_ids: list[int] | None = None,
    user_file_ids: list[UUID] | None = None,
    hierarchy_node_ids: list[int] | None = None,
    document_ids: list[str] | None = None,
    replace_base_system_prompt: bool = False,
) -> Persona:
    """
    NOTE: This operation cannot update persona configuration options that are core to the
    persona, such as its display priority and whether or not the assistant is a built-in / default assistant
    """

    if persona_id is not None:
        existing_persona = db_session.query(Persona).filter_by(id=persona_id).first()
    else:
        existing_persona = _get_persona_by_name(
            persona_name=name, user=user, db_session=db_session
        )

        # Check for duplicate names when creating new personas
        # Deleted personas are allowed to be overwritten
        if existing_persona and not existing_persona.deleted:
            raise ValueError(
                f"Assistant with name '{name}' already exists. Please rename your assistant."
            )

    if existing_persona and user:
        # this checks if the user has permission to edit the persona
        # will raise an Exception if the user does not have permission
        # Skip check if user is None (system/admin operation)
        existing_persona = fetch_persona_by_id_for_user(
            db_session=db_session,
            persona_id=existing_persona.id,
            user=user,
            get_editable=True,
        )

    # Fetch and attach tools by IDs
    tools = None
    if tool_ids is not None:
        tools = db_session.query(Tool).filter(Tool.id.in_(tool_ids)).all()
        if not tools and tool_ids:
            raise ValueError("Tools not found")

    # Fetch and attach document_sets by IDs
    document_sets = None
    if document_set_ids is not None:
        document_sets = (
            db_session.query(DocumentSet)
            .filter(DocumentSet.id.in_(document_set_ids))
            .all()
        )
        if not document_sets and document_set_ids:
            raise ValueError("document_sets not found")

    # Fetch and attach user_files by IDs
    user_files = None
    if user_file_ids is not None:
        user_files = (
            db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all()
        )
        if not user_files and user_file_ids:
            raise ValueError("user_files not found")

    labels = None
    if label_ids is not None:
        labels = (
            db_session.query(PersonaLabel).filter(PersonaLabel.id.in_(label_ids)).all()
        )
        if len(labels) != len(label_ids):
            raise ValueError("Some label IDs were not found in the database")

    # Fetch and attach hierarchy_nodes by IDs
    hierarchy_nodes = None
    if hierarchy_node_ids:
        hierarchy_nodes = (
            db_session.query(HierarchyNode)
            .filter(HierarchyNode.id.in_(hierarchy_node_ids))
            .all()
        )
        if not hierarchy_nodes and hierarchy_node_ids:
            raise ValueError("hierarchy_nodes not found")

    # Fetch and attach documents by IDs, filtering for access permissions
    attached_documents = None
    if document_ids is not None:
        user_email = user.email if user else None
        external_group_ids = (
            get_user_external_group_ids(db_session, user) if user else []
        )
        attached_documents = get_accessible_documents_by_ids(
            db_session=db_session,
            document_ids=document_ids,
            user_email=user_email,
            external_group_ids=external_group_ids,
        )
        if not attached_documents and document_ids:
            raise ValueError("documents not found or not accessible")

    # ensure all specified tools are valid
    if tools:
        validate_persona_tools(tools, db_session)

    if existing_persona:
        # Built-in personas can only be updated through YAML configuration.
        # This ensures that core system personas are not modified unintentionally.
        if existing_persona.builtin_persona and not builtin_persona:
            raise ValueError("Cannot update builtin persona with non-builtin.")

        # The following update excludes `default`, `built-in`, and display priority.
        # Display priority is handled separately in the `display-priority` endpoint.
        # `default` and `built-in` properties can only be set when creating a persona.
        existing_persona.name = name
        existing_persona.description = description
        existing_persona.llm_model_provider_override = llm_model_provider_override
        existing_persona.llm_model_version_override = llm_model_version_override
        existing_persona.starter_messages = starter_messages
        existing_persona.deleted = False  # Un-delete if previously deleted
        existing_persona.is_public = is_public
        if remove_image or uploaded_image_id:
            existing_persona.uploaded_image_id = uploaded_image_id
        existing_persona.icon_name = icon_name
        existing_persona.is_listed = is_listed
        existing_persona.search_start_date = search_start_date
        if label_ids is not None:
            existing_persona.labels.clear()
            existing_persona.labels = labels or []
        existing_persona.is_featured = (
            is_featured if is_featured is not None else existing_persona.is_featured
        )
        # Update embedded prompt fields if provided
        if system_prompt is not None:
            existing_persona.system_prompt = system_prompt
        if task_prompt is not None:
            existing_persona.task_prompt = task_prompt
        if datetime_aware is not None:
            existing_persona.datetime_aware = datetime_aware
        existing_persona.replace_base_system_prompt = replace_base_system_prompt

        # Do not delete any associations manually added unless
        # a new updated list is provided
        if document_sets is not None:
            existing_persona.document_sets.clear()
            existing_persona.document_sets = document_sets or []

        # Note: prompts are now embedded in personas - no separate prompts relationship

        if tools is not None:
            existing_persona.tools = tools or []

        if user_file_ids is not None:
            old_file_ids = {uf.id for uf in existing_persona.user_files}
            new_file_ids = {uf.id for uf in (user_files or [])}
            affected_file_ids = old_file_ids | new_file_ids
            existing_persona.user_files.clear()
            existing_persona.user_files = user_files or []
            if affected_file_ids:
                _mark_files_need_persona_sync(db_session, list(affected_file_ids))

        if hierarchy_node_ids is not None:
            existing_persona.hierarchy_nodes.clear()
            existing_persona.hierarchy_nodes = hierarchy_nodes or []
        if document_ids is not None:
            existing_persona.attached_documents.clear()
            existing_persona.attached_documents = attached_documents or []

        # We should only update display priority if it is not already set
        if existing_persona.display_priority is None:
            existing_persona.display_priority = display_priority

        persona = existing_persona

    else:
        # Create a new persona; the embedded prompt fields are set directly here
        new_persona = Persona(
            id=persona_id,
            user_id=user.id if user else None,
            is_public=is_public,
            name=name,
            description=description,
            builtin_persona=builtin_persona,
            system_prompt=system_prompt or "",
            task_prompt=task_prompt or "",
            datetime_aware=(datetime_aware if datetime_aware is not None else True),
            replace_base_system_prompt=replace_base_system_prompt,
            document_sets=document_sets or [],
            llm_model_provider_override=llm_model_provider_override,
            llm_model_version_override=llm_model_version_override,
            starter_messages=starter_messages,
            tools=tools or [],
            uploaded_image_id=uploaded_image_id,
            icon_name=icon_name,
            display_priority=display_priority,
            is_listed=is_listed,
            search_start_date=search_start_date,
            is_featured=(is_featured if is_featured is not None else False),
            user_files=user_files or [],
            labels=labels or [],
            hierarchy_nodes=hierarchy_nodes or [],
            attached_documents=attached_documents or [],
        )
        db_session.add(new_persona)
        if user_files:
            _mark_files_need_persona_sync(db_session, [uf.id for uf in user_files])
        persona = new_persona

    if commit:
        db_session.commit()
    else:
        # flush the session so that the persona has an ID
        db_session.flush()

    return persona


def delete_old_default_personas(
    db_session: Session,
) -> None:
    """Note: this locks out the Summarize and Paraphrase personas for now.
    It needs a more graceful fix later, or those personas should never have IDs.

    This function is idempotent, so it can be run multiple times without issue.
    """
    OLD_SUFFIX = "_old"

    stmt = (
        update(Persona)
        .where(
            Persona.builtin_persona,
            Persona.id > 0,
            or_(
                Persona.deleted.is_(False),
                not_(Persona.name.endswith(OLD_SUFFIX)),
            ),
        )
        .values(deleted=True, name=func.concat(Persona.name, OLD_SUFFIX))
    )

    db_session.execute(stmt)
    db_session.commit()


def update_persona_featured(
    persona_id: int,
    is_featured: bool,
    db_session: Session,
    user: User,
) -> None:
    persona = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )

    persona.is_featured = is_featured
    db_session.commit()


def update_persona_visibility(
    persona_id: int,
    is_listed: bool,
    db_session: Session,
    user: User,
) -> None:
    persona = fetch_persona_by_id_for_user(
        db_session=db_session, persona_id=persona_id, user=user, get_editable=True
    )

    persona.is_listed = is_listed
    db_session.commit()


def validate_persona_tools(tools: list[Tool], db_session: Session) -> None:
    # local import to avoid circular import. DB layer should not depend on tools layer.
    from onyx.tools.built_in_tools import get_built_in_tool_by_id

    for tool in tools:
        if tool.in_code_tool_id is not None:
            tool_cls = get_built_in_tool_by_id(tool.in_code_tool_id)
            if not tool_cls.is_available(db_session):
                raise ValueError(f"Tool {tool.in_code_tool_id} is not available")


# TODO: since this gets called with every chat message, could it be more efficient to pregenerate
# a direct mapping indicating whether a user has access to a specific persona?
def get_persona_by_id( persona_id: int, user: User | None, db_session: Session, include_deleted: bool = False, is_for_edit: bool = True, # NOTE: assume true for safety ) -> Persona: persona_stmt = ( select(Persona) .distinct() .outerjoin(Persona.groups) .outerjoin(Persona.users) .outerjoin(UserGroup.user_group_relationships) .where(Persona.id == persona_id) ) if not include_deleted: persona_stmt = persona_stmt.where(Persona.deleted.is_(False)) if not user or user.role == UserRole.ADMIN: result = db_session.execute(persona_stmt) persona = result.scalar_one_or_none() if persona is None: raise ValueError(f"Persona with ID {persona_id} does not exist") return persona # or check if user owns persona or_conditions = Persona.user_id == user.id # allow access if persona user id is None or_conditions |= Persona.user_id == None # noqa: E711 if not is_for_edit: # if the user is in a group related to the persona or_conditions |= User__UserGroup.user_id == user.id # if the user is in the .users of the persona or_conditions |= User.id == user.id or_conditions |= Persona.is_public == True # noqa: E712 elif user.role == UserRole.GLOBAL_CURATOR: # global curators can edit personas for the groups they are in or_conditions |= User__UserGroup.user_id == user.id elif user.role == UserRole.CURATOR: # curators can edit personas for the groups they are curators of or_conditions |= (User__UserGroup.user_id == user.id) & ( User__UserGroup.is_curator == True # noqa: E712 ) persona_stmt = persona_stmt.where(or_conditions) result = db_session.execute(persona_stmt) persona = result.scalar_one_or_none() if persona is None: raise ValueError( f"Persona with ID {persona_id} does not exist or does not belong to user" ) return persona def get_personas_by_ids( persona_ids: list[int], db_session: Session ) -> Sequence[Persona]: """WARNING: Unsafe, can fetch personas from all users.""" if not persona_ids: return [] personas = db_session.scalars( select(Persona).where(Persona.id.in_(persona_ids)) ).all() 
return personas def delete_persona_by_name( persona_name: str, db_session: Session, is_default: bool = True ) -> None: stmt = ( update(Persona) .where(Persona.name == persona_name, Persona.builtin_persona == is_default) .values(deleted=True) ) db_session.execute(stmt) db_session.commit() def get_assistant_labels(db_session: Session) -> list[PersonaLabel]: return db_session.query(PersonaLabel).all() def create_assistant_label(db_session: Session, name: str) -> PersonaLabel: label = PersonaLabel(name=name) db_session.add(label) db_session.commit() return label def update_persona_label( label_id: int, label_name: str, db_session: Session, ) -> None: persona_label = ( db_session.query(PersonaLabel).filter(PersonaLabel.id == label_id).one_or_none() ) if persona_label is None: raise ValueError(f"Persona label with ID {label_id} does not exist") persona_label.name = label_name db_session.commit() def delete_persona_label(label_id: int, db_session: Session) -> None: db_session.query(PersonaLabel).filter(PersonaLabel.id == label_id).delete() db_session.commit() def persona_has_search_tool(persona_id: int, db_session: Session) -> bool: persona = ( db_session.query(Persona) .options(selectinload(Persona.tools)) .filter(Persona.id == persona_id) .one_or_none() ) if persona is None: raise ValueError(f"Persona with ID {persona_id} does not exist") return any(tool.in_code_tool_id == "run_search" for tool in persona.tools) def get_default_assistant(db_session: Session) -> Persona | None: """Fetch the default assistant (persona with builtin_persona=True).""" return ( db_session.query(Persona) .options(selectinload(Persona.tools)) .filter(Persona.builtin_persona.is_(True)) # NOTE: need to add this since we had prior builtin personas # that have since been deleted .filter(Persona.deleted.is_(False)) .one_or_none() ) def update_default_assistant_configuration( db_session: Session, tool_ids: list[int] | None = None, system_prompt: str | None = None, update_system_prompt: bool = False, 
) -> Persona: """Update only tools and system_prompt for the default assistant. Args: db_session: Database session tool_ids: List of tool IDs to enable (if None, tools are not updated) system_prompt: New system prompt value (None means use default) update_system_prompt: If True, update the system_prompt field (allows setting to None) Returns: Updated Persona object Raises: ValueError: If default assistant not found or invalid tool IDs provided """ # Get the default assistant persona = get_default_assistant(db_session) if not persona: raise ValueError("Default assistant not found") # Update system prompt if explicitly requested if update_system_prompt: persona.system_prompt = system_prompt # Update tools if provided if tool_ids is not None: # Clear existing tool associations persona.tools = [] # Add new tool associations for tool_id in tool_ids: tool = db_session.query(Tool).filter(Tool.id == tool_id).one_or_none() if not tool: raise ValueError(f"Tool with ID {tool_id} not found") if not should_expose_tool_to_fe(tool): raise ValueError(f"Tool with ID {tool_id} cannot be assigned") if not tool.enabled: raise ValueError( f"Enable tool {tool.display_name or tool.name} before assigning it" ) persona.tools.append(tool) db_session.commit() return persona def user_can_access_persona( db_session: Session, persona_id: int, user: User, get_editable: bool = False ) -> bool: """Check if a user has access to a specific persona. 
Args: db_session: Database session persona_id: ID of the persona to check user: User to check access for get_editable: If True, check for edit access; if False, check for view access Returns: True if user can access the persona, False otherwise """ stmt = select(Persona).where(Persona.id == persona_id, Persona.deleted.is_(False)) stmt = _add_user_filters(stmt, user, get_editable=get_editable) return db_session.scalar(stmt) is not None ================================================ FILE: backend/onyx/db/projects.py ================================================ import datetime import uuid from typing import List from uuid import UUID from fastapi import HTTPException from fastapi import UploadFile from pydantic import BaseModel from pydantic import ConfigDict from pydantic import Field from sqlalchemy import func from sqlalchemy.orm import Session from starlette.background import BackgroundTasks from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES from onyx.configs.constants import FileOrigin from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.db.enums import UserFileStatus from onyx.db.models import Project__UserFile from onyx.db.models import User from onyx.db.models import UserFile from onyx.db.models import UserProject from onyx.server.documents.connector import upload_files from onyx.server.features.projects.projects_file_utils import categorize_uploaded_files from onyx.server.features.projects.projects_file_utils import RejectedFile from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() class CategorizedFilesResult(BaseModel): user_files: list[UserFile] rejected_files: list[RejectedFile] id_to_temp_id: dict[str, str] # Filenames that should be stored but not indexed. 
skip_indexing_filenames: set[str] = Field(default_factory=set) # Allow SQLAlchemy ORM models inside this result container model_config = ConfigDict(arbitrary_types_allowed=True) @property def indexable_files(self) -> list[UserFile]: return [ uf for uf in self.user_files if (uf.name or "") not in self.skip_indexing_filenames ] def build_hashed_file_key(file: UploadFile) -> str: name_prefix = (file.filename or "")[:50] return f"{file.size}|{name_prefix}" def create_user_files( files: List[UploadFile], project_id: int | None, user: User, db_session: Session, link_url: str | None = None, temp_id_map: dict[str, str] | None = None, ) -> CategorizedFilesResult: # Categorize the files categorized_files = categorize_uploaded_files(files, db_session) # NOTE: At the moment, zip metadata is not used for user files. # Should revisit to decide whether this should be a feature. upload_response = upload_files(categorized_files.acceptable, FileOrigin.USER_FILE) user_files = [] rejected_files = categorized_files.rejected id_to_temp_id: dict[str, str] = {} # Pair returned storage paths with the same set of acceptable files we uploaded for file_path, file in zip( upload_response.file_paths, categorized_files.acceptable ): new_id = uuid.uuid4() new_temp_id = ( temp_id_map.get(build_hashed_file_key(file)) if temp_id_map else None ) if new_temp_id is not None: id_to_temp_id[str(new_id)] = new_temp_id should_skip = (file.filename or "") in categorized_files.skip_indexing new_file = UserFile( id=new_id, user_id=user.id, file_id=file_path, name=file.filename, token_count=categorized_files.acceptable_file_to_token_count[ file.filename or "" ], link_url=link_url, content_type=file.content_type, file_type=file.content_type, status=UserFileStatus.SKIPPED if should_skip else UserFileStatus.PROCESSING, last_accessed_at=datetime.datetime.now(datetime.timezone.utc), ) # Persist the UserFile first to satisfy FK constraints for association table db_session.add(new_file) db_session.flush() if 
project_id: project_to_user_file = Project__UserFile( project_id=project_id, user_file_id=new_file.id, ) db_session.add(project_to_user_file) user_files.append(new_file) db_session.commit() return CategorizedFilesResult( user_files=user_files, rejected_files=rejected_files, id_to_temp_id=id_to_temp_id, skip_indexing_filenames=categorized_files.skip_indexing, ) def upload_files_to_user_files_with_indexing( files: List[UploadFile], project_id: int | None, user: User, temp_id_map: dict[str, str] | None, db_session: Session, background_tasks: BackgroundTasks | None = None, ) -> CategorizedFilesResult: if project_id is not None and user is not None: if not check_project_ownership(project_id, user.id, db_session): raise HTTPException(status_code=404, detail="Project not found") categorized_files_result = create_user_files( files, project_id, user, db_session, temp_id_map=temp_id_map, ) user_files = categorized_files_result.user_files rejected_files = categorized_files_result.rejected_files id_to_temp_id = categorized_files_result.id_to_temp_id indexable_files = categorized_files_result.indexable_files # Trigger per-file processing immediately for the current tenant tenant_id = get_current_tenant_id() for rejected_file in rejected_files: logger.warning( f"File {rejected_file.filename} rejected for {rejected_file.reason}" ) if DISABLE_VECTOR_DB and background_tasks is not None: from onyx.background.task_utils import drain_processing_loop background_tasks.add_task(drain_processing_loop, tenant_id) for user_file in indexable_files: logger.info(f"Queued in-process processing for user_file_id={user_file.id}") else: from onyx.background.celery.versioned_apps.client import app as client_app for user_file in indexable_files: task = client_app.send_task( OnyxCeleryTask.PROCESS_SINGLE_USER_FILE, kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id}, queue=OnyxCeleryQueues.USER_FILE_PROCESSING, priority=OnyxCeleryPriority.HIGH, 
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES, ) logger.info( f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}" ) return CategorizedFilesResult( user_files=user_files, rejected_files=rejected_files, id_to_temp_id=id_to_temp_id, skip_indexing_filenames=categorized_files_result.skip_indexing_filenames, ) def check_project_ownership( project_id: int, user_id: UUID | None, db_session: Session ) -> bool: # In no-auth mode, all projects are accessible if user_id is None: # Verify project exists return ( db_session.query(UserProject).filter(UserProject.id == project_id).first() is not None ) return ( db_session.query(UserProject) .filter(UserProject.id == project_id, UserProject.user_id == user_id) .first() is not None ) def get_user_files_from_project( project_id: int, user_id: UUID | None, db_session: Session ) -> list[UserFile]: # First check if the user owns the project if not check_project_ownership(project_id, user_id, db_session): return [] return ( db_session.query(UserFile) .join(Project__UserFile) .filter(Project__UserFile.project_id == project_id) .all() ) def get_project_instructions(db_session: Session, project_id: int | None) -> str | None: """Return the project's instruction text with surrounding whitespace stripped, or None if the project has no instructions. Safe helper that swallows DB errors and returns None on any failure. """ if not project_id: return None try: project = ( db_session.query(UserProject) .filter(UserProject.id == project_id) .one_or_none() ) if not project or not project.instructions: return None instructions = project.instructions.strip() return instructions or None except Exception: return None def get_project_token_count( project_id: int | None, user_id: UUID | None, db_session: Session, ) -> int: """Return the sum of token_count across the user's files in the given project. If project_id is None, returns 0.
""" if project_id is None: return 0 total_tokens = ( db_session.query(func.coalesce(func.sum(UserFile.token_count), 0)) .filter( UserFile.user_id == user_id, UserFile.projects.any(id=project_id), ) .scalar() or 0 ) return int(total_tokens) ================================================ FILE: backend/onyx/db/pydantic_type.py ================================================ import json from typing import Any from typing import Optional from typing import Type from pydantic import BaseModel from sqlalchemy.dialects.postgresql import JSONB from sqlalchemy.types import TypeDecorator class PydanticType(TypeDecorator): impl = JSONB def __init__( self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self.pydantic_model = pydantic_model def process_bind_param( self, value: Optional[BaseModel], dialect: Any, # noqa: ARG002 ) -> Optional[dict]: if value is not None: return json.loads(value.model_dump_json()) return None def process_result_value( self, value: Optional[dict], dialect: Any, # noqa: ARG002 ) -> Optional[BaseModel]: if value is not None: return self.pydantic_model.model_validate(value) return None class PydanticListType(TypeDecorator): impl = JSONB def __init__( self, pydantic_model: Type[BaseModel], *args: Any, **kwargs: Any ) -> None: super().__init__(*args, **kwargs) self.pydantic_model = pydantic_model def process_bind_param( self, value: Optional[list[BaseModel]], dialect: Any, # noqa: ARG002 ) -> Optional[list[dict]]: if value is not None: return [json.loads(item.model_dump_json()) for item in value] return None def process_result_value( self, value: Optional[list[dict]], dialect: Any, # noqa: ARG002 ) -> Optional[list[BaseModel]]: if value is not None: return [self.pydantic_model.model_validate(item) for item in value] return None ================================================ FILE: backend/onyx/db/relationships.py ================================================ from typing import List from sqlalchemy import or_
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects.postgresql import insert as pg_insert from sqlalchemy.orm import Session import onyx.db.document as dbdocument from onyx.db.models import KGEntity from onyx.db.models import KGEntityExtractionStaging from onyx.db.models import KGRelationship from onyx.db.models import KGRelationshipExtractionStaging from onyx.db.models import KGRelationshipType from onyx.db.models import KGRelationshipTypeExtractionStaging from onyx.db.models import KGStage from onyx.kg.utils.formatting_utils import extract_relationship_type_id from onyx.kg.utils.formatting_utils import format_relationship_id from onyx.kg.utils.formatting_utils import get_entity_type from onyx.kg.utils.formatting_utils import make_relationship_id from onyx.kg.utils.formatting_utils import make_relationship_type_id from onyx.kg.utils.formatting_utils import split_relationship_id from onyx.utils.logger import setup_logger logger = setup_logger() def upsert_staging_relationship( db_session: Session, relationship_id_name: str, source_document_id: str | None, occurrences: int = 1, ) -> KGRelationshipExtractionStaging: """ Add a staging relationship to the database, or increment its occurrence count if it already exists for the source document.
Args: db_session: SQLAlchemy database session relationship_id_name: The ID name of the relationship in format "source__relationship__target" source_document_id: ID of the source document occurrences: Number of occurrences to record; added to the existing count if the relationship already exists for this document Returns: The created or updated KGRelationshipExtractionStaging object Raises: sqlalchemy.exc.IntegrityError: If there's an error with the database operation RuntimeError: If the upsert does not return a row """ # Format the relationship ID name consistently relationship_id_name = format_relationship_id(relationship_id_name) ( source_entity_id_name, relationship_string, target_entity_id_name, ) = split_relationship_id(relationship_id_name) source_entity_type = get_entity_type(source_entity_id_name) target_entity_type = get_entity_type(target_entity_id_name) relationship_type = extract_relationship_type_id(relationship_id_name) # Insert the relationship, or increment its occurrences on conflict stmt = ( postgresql.insert(KGRelationshipExtractionStaging) .values( { "id_name": relationship_id_name, "source_node": source_entity_id_name, "target_node": target_entity_id_name, "source_node_type": source_entity_type, "target_node_type": target_entity_type, "type": relationship_string.lower(), "relationship_type_id_name": relationship_type, "source_document": source_document_id, "occurrences": occurrences, } ) .on_conflict_do_update( index_elements=["id_name", "source_document"], set_=dict( occurrences=KGRelationshipExtractionStaging.occurrences + occurrences, ), ) .returning(KGRelationshipExtractionStaging) ) result = db_session.execute(stmt).scalar() if result is None: raise RuntimeError( f"Failed to create or increment staging relationship with id_name: {relationship_id_name}" ) # Update the document's kg_stage if source_document is provided if source_document_id is not None: dbdocument.update_document_kg_info( db_session, document_id=source_document_id, kg_stage=KGStage.EXTRACTED, ) db_session.flush() # Flush to get any DB errors early return result def upsert_relationship( db_session: Session,
relationship_id_name: str, source_document_id: str | None, occurrences: int = 1, ) -> KGRelationship: """ Upsert a relationship directly into the normalized table. Args: db_session: SQLAlchemy database session relationship_id_name: The ID name of the relationship in format "source__relationship__target" source_document_id: ID of the source document occurrences: Number of occurrences to record; added to the existing count if the relationship already exists for this document Returns: The created or updated KGRelationship object Raises: sqlalchemy.exc.IntegrityError: If there's an error with the database operation RuntimeError: If the upsert does not return a row """ # Format the relationship ID name consistently relationship_id_name = format_relationship_id(relationship_id_name) ( source_entity_id_name, relationship_string, target_entity_id_name, ) = split_relationship_id(relationship_id_name) source_entity_type = get_entity_type(source_entity_id_name) target_entity_type = get_entity_type(target_entity_id_name) relationship_type = extract_relationship_type_id(relationship_id_name) # Insert the relationship, or increment its occurrences on conflict stmt = ( postgresql.insert(KGRelationship) .values( { "id_name": relationship_id_name, "source_node": source_entity_id_name, "target_node": target_entity_id_name, "source_node_type": source_entity_type, "target_node_type": target_entity_type, "type": relationship_string.lower(), "relationship_type_id_name": relationship_type, "source_document": source_document_id, "occurrences": occurrences, } ) .on_conflict_do_update( index_elements=["id_name", "source_document"], set_=dict( occurrences=KGRelationship.occurrences + occurrences, ), ) .returning(KGRelationship) ) new_relationship = db_session.execute(stmt).scalar() if new_relationship is None: raise RuntimeError( f"Failed to upsert relationship with id_name: {relationship_id_name}" ) db_session.flush() return new_relationship def transfer_relationship( db_session: Session, relationship: KGRelationshipExtractionStaging, entity_translations: dict[str, str], ) -> KGRelationship: """ Transfer a relationship from the staging table to the
normalized table. """ # Translate the source and target nodes source_node = entity_translations[relationship.source_node] target_node = entity_translations[relationship.target_node] relationship_id_name = make_relationship_id( source_node, relationship.type, target_node ) # Create the transferred relationship stmt = ( pg_insert(KGRelationship) .values( id_name=relationship_id_name, source_node=source_node, target_node=target_node, source_node_type=relationship.source_node_type, target_node_type=relationship.target_node_type, type=relationship.type, relationship_type_id_name=relationship.relationship_type_id_name, source_document=relationship.source_document, occurrences=relationship.occurrences, ) .on_conflict_do_update( index_elements=["id_name", "source_document"], set_=dict( occurrences=KGRelationship.occurrences + relationship.occurrences, ), ) .returning(KGRelationship) ) new_relationship = db_session.execute(stmt).scalar() if new_relationship is None: raise RuntimeError( f"Failed to transfer relationship with id_name: {relationship.id_name}" ) # Mark the staging row as transferred db_session.query(KGRelationshipExtractionStaging).filter( KGRelationshipExtractionStaging.id_name == relationship.id_name, KGRelationshipExtractionStaging.source_document == relationship.source_document, ).update({"transferred": True}) db_session.flush() return new_relationship def upsert_staging_relationship_type( db_session: Session, source_entity_type: str, relationship_type: str, target_entity_type: str, definition: bool = False, extraction_count: int = 1, ) -> KGRelationshipTypeExtractionStaging: """ Add a staging relationship type to the database, or increment its occurrence count if it already exists.
Args: db_session: SQLAlchemy session source_entity_type: Type of the source entity relationship_type: Type of relationship target_entity_type: Type of the target entity definition: Whether this relationship type represents a definition (default False) extraction_count: Number of occurrences to record; added to the existing count on conflict (default 1) Returns: The created or updated KGRelationshipTypeExtractionStaging object """ id_name = make_relationship_type_id( source_entity_type, relationship_type, target_entity_type ) # Insert the relationship type, or increment its occurrences on conflict stmt = ( postgresql.insert(KGRelationshipTypeExtractionStaging) .values( { "id_name": id_name, "name": relationship_type, "source_entity_type_id_name": source_entity_type.upper(), "target_entity_type_id_name": target_entity_type.upper(), "definition": definition, "occurrences": extraction_count, "type": relationship_type, # Using the relationship_type as the type "active": True, # Setting as active by default } ) .on_conflict_do_update( index_elements=["id_name"], set_=dict( occurrences=KGRelationshipTypeExtractionStaging.occurrences + extraction_count, ), ) .returning(KGRelationshipTypeExtractionStaging) ) result = db_session.execute(stmt).scalar() if result is None: raise RuntimeError( f"Failed to create or increment staging relationship type with id_name: {id_name}" ) db_session.flush() # Flush to get any DB errors early return result def upsert_relationship_type( db_session: Session, source_entity_type: str, relationship_type: str, target_entity_type: str, definition: bool = False, extraction_count: int = 1, ) -> KGRelationshipType: """ Upsert a relationship type directly into the normalized table.
Args: db_session: SQLAlchemy session source_entity_type: Type of the source entity relationship_type: Type of relationship target_entity_type: Type of the target entity definition: Whether this relationship type represents a definition (default False) extraction_count: Number of occurrences to record; added to the existing count on conflict (default 1) Returns: The created or updated KGRelationshipType object """ id_name = make_relationship_type_id( source_entity_type, relationship_type, target_entity_type ) # Insert the relationship type, or increment its occurrences on conflict stmt = ( postgresql.insert(KGRelationshipType) .values( { "id_name": id_name, "name": relationship_type, "source_entity_type_id_name": source_entity_type.upper(), "target_entity_type_id_name": target_entity_type.upper(), "definition": definition, "occurrences": extraction_count, "type": relationship_type, # Using the relationship_type as the type "active": True, # Setting as active by default } ) .on_conflict_do_update( index_elements=["id_name"], set_=dict( occurrences=KGRelationshipType.occurrences + extraction_count, ), ) .returning(KGRelationshipType) ) new_relationship_type = db_session.execute(stmt).scalar() if new_relationship_type is None: raise RuntimeError( f"Failed to upsert relationship type with id_name: {id_name}" ) db_session.flush() return new_relationship_type def transfer_relationship_type( db_session: Session, relationship_type: KGRelationshipTypeExtractionStaging, ) -> KGRelationshipType: """ Transfer a relationship type from the staging table to the normalized table.
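Sketch (hypothetical table and id_name, using stdlib sqlite3 rather than PostgreSQL): because of `on_conflict_do_update` on `id_name`, transferring a type that already exists adds the staged occurrences to the stored count instead of raising an IntegrityError. SQLite 3.24+ supports the same upsert form; its `excluded` pseudo-table holds the incoming values and plays the role of `relationship_type.occurrences` here.

```python
import sqlite3


def make_store() -> sqlite3.Connection:
    # In-memory stand-in for the normalized relationship-type table (hypothetical schema).
    conn = sqlite3.connect(":memory:")
    conn.execute(
        "CREATE TABLE relationship_type ("
        "id_name TEXT PRIMARY KEY, occurrences INTEGER NOT NULL)"
    )
    return conn


def upsert_type(conn: sqlite3.Connection, id_name: str, occurrences: int) -> int:
    # Insert a new row, or add the incoming count to the stored one on conflict.
    conn.execute(
        "INSERT INTO relationship_type (id_name, occurrences) VALUES (?, ?) "
        "ON CONFLICT(id_name) DO UPDATE "
        "SET occurrences = occurrences + excluded.occurrences",
        (id_name, occurrences),
    )
    (total,) = conn.execute(
        "SELECT occurrences FROM relationship_type WHERE id_name = ?", (id_name,)
    ).fetchone()
    return total


store = make_store()
upsert_type(store, "SOURCE__rel__TARGET", 2)  # first transfer inserts: returns 2
upsert_type(store, "SOURCE__rel__TARGET", 3)  # second transfer adds: returns 5
```

Unqualified `occurrences` in the SET clause refers to the stored row, and `excluded.occurrences` refers to the row that failed to insert.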
""" stmt = ( pg_insert(KGRelationshipType) .values( id_name=relationship_type.id_name, name=relationship_type.name, source_entity_type_id_name=relationship_type.source_entity_type_id_name, target_entity_type_id_name=relationship_type.target_entity_type_id_name, definition=relationship_type.definition, occurrences=relationship_type.occurrences, type=relationship_type.type, active=relationship_type.active, ) .on_conflict_do_update( index_elements=["id_name"], set_=dict( occurrences=KGRelationshipType.occurrences + relationship_type.occurrences, ), ) .returning(KGRelationshipType) ) new_relationship_type = db_session.execute(stmt).scalar() if new_relationship_type is None: raise RuntimeError( f"Failed to transfer relationship type with id_name: {relationship_type.id_name}" ) # Mark the staging row as transferred db_session.query(KGRelationshipTypeExtractionStaging).filter( KGRelationshipTypeExtractionStaging.id_name == relationship_type.id_name ).update({"transferred": True}) db_session.flush() return new_relationship_type def delete_relationships_by_id_names( db_session: Session, id_names: list[str], kg_stage: KGStage ) -> int: """ Delete the relationships with the given id_names from the table for the given KG stage.
Args: db_session: SQLAlchemy database session id_names: List of relationship id_names to delete kg_stage: KGStage.EXTRACTED deletes from the extraction staging table, KGStage.NORMALIZED from the normalized table; any other stage deletes nothing Returns: Number of relationships deleted Raises: sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion """ deleted_count = 0 if kg_stage == KGStage.EXTRACTED: deleted_count = ( db_session.query(KGRelationshipExtractionStaging) .filter(KGRelationshipExtractionStaging.id_name.in_(id_names)) .delete(synchronize_session=False) ) elif kg_stage == KGStage.NORMALIZED: deleted_count = ( db_session.query(KGRelationship) .filter(KGRelationship.id_name.in_(id_names)) .delete(synchronize_session=False) ) db_session.flush() # Flush to ensure deletion is processed return deleted_count def delete_relationship_types_by_id_names( db_session: Session, id_names: list[str], kg_stage: KGStage ) -> int: """ Delete the relationship types with the given id_names from the table for the given KG stage. Args: db_session: SQLAlchemy database session id_names: List of relationship type id_names to delete kg_stage: KGStage.EXTRACTED deletes from the extraction staging table, KGStage.NORMALIZED from the normalized table; any other stage deletes nothing Returns: Number of relationship types deleted Raises: sqlalchemy.exc.SQLAlchemyError: If there's an error during deletion """ deleted_count = 0 if kg_stage == KGStage.EXTRACTED: deleted_count = ( db_session.query(KGRelationshipTypeExtractionStaging) .filter(KGRelationshipTypeExtractionStaging.id_name.in_(id_names)) .delete(synchronize_session=False) ) elif kg_stage == KGStage.NORMALIZED: deleted_count = ( db_session.query(KGRelationshipType) .filter(KGRelationshipType.id_name.in_(id_names)) .delete(synchronize_session=False) ) db_session.flush() # Flush to ensure deletion is processed return deleted_count def get_relationships_for_entity_type_pairs( db_session: Session, entity_type_pairs: list[tuple[str, str]] ) -> list["KGRelationshipType"]: """ Get relationship types from the database based on a list of entity type pairs.
Args: db_session: SQLAlchemy database session entity_type_pairs: List of tuples where each tuple contains (source_entity_type, target_entity_type) Returns: List of KGRelationshipType objects where source and target types match the provided pairs """ conditions = [ ( (KGRelationshipType.source_entity_type_id_name == source_type) & (KGRelationshipType.target_entity_type_id_name == target_type) ) for source_type, target_type in entity_type_pairs ] return db_session.query(KGRelationshipType).filter(or_(*conditions)).all() def get_allowed_relationship_type_pairs( db_session: Session, entities: list[str] ) -> list[str]: """ Get the id_names of the relationship types allowed for the given entities. Args: db_session: SQLAlchemy database session entities: List of entity ID names; their entity types (via get_entity_type) are used for filtering Returns: Distinct list of id_names from KGRelationshipType whose source or target entity type is among the entities' types. Relationship types whose source entity type matches 'VENDOR::%' are excluded for now. """ entity_types = list({get_entity_type(entity) for entity in entities}) return [ row[0] for row in ( db_session.query(KGRelationshipType.id_name) .filter( or_( KGRelationshipType.source_entity_type_id_name.in_(entity_types), KGRelationshipType.target_entity_type_id_name.in_(entity_types), ) ) .filter(~KGRelationshipType.source_entity_type_id_name.like("VENDOR::%")) .distinct() .all() ) ] def get_relationships_of_entity(db_session: Session, entity_id: str) -> List[str]: """Get all relationship ID names where the given entity is either the source or target node.
Args: db_session: SQLAlchemy session entity_id: ID of the entity to find relationships for Returns: List of relationship ID names where the entity is either source or target """ return [ row[0] for row in ( db_session.query(KGRelationship.id_name) .filter( or_( KGRelationship.source_node == entity_id, KGRelationship.target_node == entity_id, ) ) .all() ) ] def get_relationship_types_of_entity_types( db_session: Session, entity_types_id: str ) -> List[str]: """Get all relationship type ID names where the given entity type is either the source or target entity type. Args: db_session: SQLAlchemy session entity_types_id: ID name of the entity type; a trailing ":*" wildcard suffix is stripped before matching Returns: List of relationship type ID names where the entity type is either source or target """ if entity_types_id.endswith(":*"): entity_types_id = entity_types_id[:-2] return [ row[0] for row in ( db_session.query(KGRelationshipType.id_name) .filter( or_( KGRelationshipType.source_entity_type_id_name == entity_types_id, KGRelationshipType.target_entity_type_id_name == entity_types_id, ) ) .all() ) ] def delete_document_references_from_kg(db_session: Session, document_id: str) -> None: # Delete relationships from normalized stage db_session.query(KGRelationship).filter( KGRelationship.source_document == document_id ).delete(synchronize_session=False) # Delete relationships from extraction staging db_session.query(KGRelationshipExtractionStaging).filter( KGRelationshipExtractionStaging.source_document == document_id ).delete(synchronize_session=False) # Delete entities from normalized stage db_session.query(KGEntity).filter(KGEntity.document_id == document_id).delete( synchronize_session=False ) # Delete entities from extraction staging db_session.query(KGEntityExtractionStaging).filter( KGEntityExtractionStaging.document_id == document_id ).delete(synchronize_session=False) db_session.flush() def delete_from_kg_relationships_extraction_staging__no_commit( db_session: Session, document_ids: list[str] ) -> None: """Delete relationships
from the extraction staging table.""" db_session.query(KGRelationshipExtractionStaging).filter( KGRelationshipExtractionStaging.source_document.in_(document_ids) ).delete(synchronize_session=False) def delete_from_kg_relationships__no_commit( db_session: Session, document_ids: list[str] ) -> None: """Delete relationships from the normalized table.""" db_session.query(KGRelationship).filter( KGRelationship.source_document.in_(document_ids) ).delete(synchronize_session=False) ================================================ FILE: backend/onyx/db/release_notes.py ================================================ """Database functions for release notes functionality.""" from urllib.parse import urlencode from sqlalchemy import select from sqlalchemy.orm import Session from onyx.configs.app_configs import INSTANCE_TYPE from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN from onyx.configs.constants import NotificationType from onyx.configs.constants import ONYX_UTM_SOURCE from onyx.db.enums import AccountType from onyx.db.models import User from onyx.db.notification import batch_create_notifications from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL from onyx.server.features.release_notes.models import ReleaseNoteEntry from onyx.utils.logger import setup_logger logger = setup_logger() def create_release_notifications_for_versions( db_session: Session, release_note_entries: list[ReleaseNoteEntry], ) -> int: """ Create release notes notifications for each release note entry. Uses batch_create_notifications for efficient bulk insertion. If a user already has a notification for a specific version (dismissed or not), no new one is created (handled by unique constraint on additional_data). Note: Entries should already be filtered by app_version before calling this function. The filtering happens in _parse_mdx_to_release_note_entries(). 
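Note on the generated link: in a URL the query string must come before the `#` fragment (RFC 3986). Anything after `#` is parsed as part of the fragment and is not sent to the server, so UTM parameters placed there are not treated as query parameters. A quick stdlib check, using a hypothetical base URL:

```python
from urllib.parse import urlencode, urlsplit

BASE_URL = "https://docs.example.com/changelog"  # hypothetical base URL
params = urlencode({"utm_source": "onyx", "utm_medium": "notification"})

fragment_first = f"{BASE_URL}#v2-7-0?{params}"  # query ends up inside the fragment
query_first = f"{BASE_URL}?{params}#v2-7-0"  # query, then fragment

for link in (fragment_first, query_first):
    parts = urlsplit(link)
    print(f"query={parts.query!r} fragment={parts.fragment!r}")
```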
Args: db_session: Database session release_note_entries: List of release note entries to notify about (pre-filtered) Returns: Total number of notifications created across all versions. """ if not release_note_entries: logger.debug("No release note entries to notify about") return 0 # Get active users, excluding bot, external-permission, and API key users user_ids = list( db_session.scalars( select(User.id).where( # type: ignore User.is_active == True, # noqa: E712 User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]), User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined] ) ).all() ) total_created = 0 for entry in release_note_entries: # Convert version to anchor format for external docs links # v2.7.0 -> v2-7-0 version_anchor = entry.version.replace(".", "-") # Build UTM parameters for tracking utm_params = { "utm_source": ONYX_UTM_SOURCE, "utm_medium": "notification", "utm_campaign": INSTANCE_TYPE, "utm_content": f"release_notes-{entry.version}", } # The query string must precede the fragment; otherwise the UTM params become part of the anchor link = f"{DOCS_CHANGELOG_BASE_URL}?{urlencode(utm_params)}#{version_anchor}" additional_data: dict[str, str] = { "version": entry.version, "link": link, } created_count = batch_create_notifications( user_ids, NotificationType.RELEASE_NOTES, db_session, title=entry.title, description=f"Check out what's new in {entry.version}", additional_data=additional_data, ) total_created += created_count logger.debug( f"Created {created_count} release notes notifications (version {entry.version}, {len(user_ids)} eligible users)" ) return total_created ================================================ FILE: backend/onyx/db/rotate_encryption_key.py ================================================ """Rotate encryption key for all encrypted columns. Dynamically discovers all columns using EncryptedString / EncryptedJson, decrypts each value with the old key, and re-encrypts with the current 
ENCRYPTION_KEY_SECRET. The operation is idempotent: rows already encrypted with the current key are skipped.
Commits are made in batches so a crash mid-rotation can be safely resumed by re-running. """ import json from typing import Any from sqlalchemy import LargeBinary from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET from onyx.db.models import Base from onyx.db.models import EncryptedJson from onyx.db.models import EncryptedString from onyx.utils.encryption import decrypt_bytes_to_string from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import global_version logger = setup_logger() _BATCH_SIZE = 500 def _can_decrypt_with_current_key(data: bytes) -> bool: """Check if data is already encrypted with the current key. Passes the key explicitly so the fallback-to-raw-decode path in _decrypt_bytes is NOT triggered — a clean success/failure signal. """ try: decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET) return True except Exception: return False def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]: """Walk all ORM models and find columns using EncryptedString/EncryptedJson. Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json). """ results: list[tuple[type, str, list[str], bool]] = [] for mapper in Base.registry.mappers: model_cls = mapper.class_ pk_names = [col.key for col in mapper.primary_key] for prop in mapper.column_attrs: for col in prop.columns: if isinstance(col.type, EncryptedJson): results.append((model_cls, prop.key, pk_names, True)) elif isinstance(col.type, EncryptedString): results.append((model_cls, prop.key, pk_names, False)) return results def rotate_encryption_key( db_session: Session, old_key: str | None, dry_run: bool = False, ) -> dict[str, int]: """Decrypt all encrypted columns with old_key and re-encrypt with the current key. Args: db_session: Active database session. old_key: The previous encryption key. 
Pass None or "" if values were not previously encrypted with a key. dry_run: If True, count rows that need rotation without modifying data. Returns: Dict of "table.column" -> number of rows re-encrypted (or would be). Commits every _BATCH_SIZE rows so that locks are held briefly and progress is preserved on crash. Already-rotated rows are detected and skipped, making the operation safe to re-run. """ if not global_version.is_ee_version(): raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.") if not ENCRYPTION_KEY_SECRET: raise RuntimeError( "ENCRYPTION_KEY_SECRET is not set — cannot rotate. Set the target encryption key in the environment before running." ) encrypted_columns = _discover_encrypted_columns() totals: dict[str, int] = {} for model_cls, col_name, pk_names, is_json in encrypted_columns: table_name: str = model_cls.__tablename__ # type: ignore[attr-defined] col_attr = getattr(model_cls, col_name) pk_attrs = [getattr(model_cls, pk) for pk in pk_names] # Read raw bytes directly, bypassing the TypeDecorator raw_col = col_attr.property.columns[0] stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None)) rows = db_session.execute(stmt).all() reencrypted = 0 batch_pending = 0 for row in rows: raw_bytes: bytes | None = row[-1] if raw_bytes is None: continue if _can_decrypt_with_current_key(raw_bytes): continue try: if not old_key: decrypted_str = raw_bytes.decode("utf-8") else: decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key) # For EncryptedJson, parse back to dict so the TypeDecorator # can json.dumps() it cleanly (avoids double-encoding). 
value: Any = json.loads(decrypted_str) if is_json else decrypted_str except (ValueError, UnicodeDecodeError) as e: pk_vals = [row[i] for i in range(len(pk_names))] logger.warning( f"Could not decrypt/parse {table_name}.{col_name} row {pk_vals} — skipping: {e}" ) continue if not dry_run: pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)] update_stmt = ( update(model_cls).where(*pk_filters).values({col_name: value}) ) db_session.execute(update_stmt) batch_pending += 1 if batch_pending >= _BATCH_SIZE: db_session.commit() batch_pending = 0 reencrypted += 1 # Flush remaining rows in this column if batch_pending > 0: db_session.commit() if reencrypted > 0: totals[f"{table_name}.{col_name}"] = reencrypted logger.info( f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} {reencrypted} value(s) in {table_name}.{col_name}" ) return totals ================================================ FILE: backend/onyx/db/saml.py ================================================ import datetime from typing import cast from uuid import UUID from sqlalchemy import and_ from sqlalchemy import func from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS from onyx.db.models import SamlAccount def upsert_saml_account( user_id: UUID, cookie: str, db_session: Session, expiration_offset: int = SESSION_EXPIRE_TIME_SECONDS, ) -> datetime.datetime: expires_at = func.now() + datetime.timedelta(seconds=expiration_offset) existing_saml_acc = ( db_session.query(SamlAccount) .filter(SamlAccount.user_id == user_id) .one_or_none() ) if existing_saml_acc: existing_saml_acc.encrypted_cookie = cookie existing_saml_acc.expires_at = cast(datetime.datetime, expires_at) existing_saml_acc.updated_at = func.now() saml_acc = existing_saml_acc else: saml_acc = SamlAccount( user_id=user_id, encrypted_cookie=cookie, 
expires_at=expires_at, ) db_session.add(saml_acc) db_session.commit() return saml_acc.expires_at async def get_saml_account( cookie: str, async_db_session: AsyncSession ) -> SamlAccount | None: """NOTE: this is async, since it's used during auth (which is necessarily async due to FastAPI Users)""" stmt = ( select(SamlAccount) .options(selectinload(SamlAccount.user)) # Use selectinload for collections .where( and_( SamlAccount.encrypted_cookie == cookie, SamlAccount.expires_at > func.now(), ) ) ) result = await async_db_session.execute(stmt) return result.scalars().unique().one_or_none() async def expire_saml_account( saml_account: SamlAccount, async_db_session: AsyncSession ) -> None: saml_account.expires_at = func.now() await async_db_session.commit() ================================================ FILE: backend/onyx/db/search_settings.py ================================================ from sqlalchemy import and_ from sqlalchemy import delete from sqlalchemy import select from sqlalchemy.orm import Session from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL from onyx.context.search.models import SavedSearchSettings from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.llm import fetch_embedding_provider from onyx.db.models import CloudEmbeddingProvider from onyx.db.models import IndexAttempt from onyx.db.models import IndexModelStatus from onyx.db.models import SearchSettings from onyx.server.manage.embedding.models import ( CloudEmbeddingProvider as ServerCloudEmbeddingProvider, ) from onyx.utils.logger import setup_logger from shared_configs.configs import PRESERVED_SEARCH_FIELDS from shared_configs.enums import EmbeddingProvider logger = setup_logger() class ActiveSearchSettings: primary: SearchSettings secondary: SearchSettings | None def __init__( self, primary: SearchSettings, secondary: SearchSettings | None ) -> None: self.primary = primary 
self.secondary = secondary def create_search_settings( search_settings: SavedSearchSettings, db_session: Session, status: IndexModelStatus = IndexModelStatus.FUTURE, ) -> SearchSettings: embedding_model = SearchSettings( model_name=search_settings.model_name, model_dim=search_settings.model_dim, normalize=search_settings.normalize, query_prefix=search_settings.query_prefix, passage_prefix=search_settings.passage_prefix, status=status, index_name=search_settings.index_name, provider_type=search_settings.provider_type, multipass_indexing=search_settings.multipass_indexing, embedding_precision=search_settings.embedding_precision, reduced_dimension=search_settings.reduced_dimension, enable_contextual_rag=search_settings.enable_contextual_rag, contextual_rag_llm_name=search_settings.contextual_rag_llm_name, contextual_rag_llm_provider=search_settings.contextual_rag_llm_provider, switchover_type=search_settings.switchover_type, ) db_session.add(embedding_model) db_session.commit() return embedding_model def get_embedding_provider_from_provider_type( db_session: Session, provider_type: EmbeddingProvider ) -> CloudEmbeddingProvider | None: query = select(CloudEmbeddingProvider).where( CloudEmbeddingProvider.provider_type == provider_type ) provider = db_session.execute(query).scalars().first() return provider if provider else None def get_current_db_embedding_provider( db_session: Session, ) -> ServerCloudEmbeddingProvider | None: search_settings = get_current_search_settings(db_session=db_session) if search_settings.provider_type is None: return None embedding_provider = fetch_embedding_provider( db_session=db_session, provider_type=search_settings.provider_type, ) if embedding_provider is None: raise RuntimeError("No embedding provider exists for this model.") current_embedding_provider = ServerCloudEmbeddingProvider.from_request( cloud_provider_model=embedding_provider ) return current_embedding_provider def delete_search_settings(db_session: Session, 
search_settings_id: int) -> None: current_settings = get_current_search_settings(db_session) if current_settings.id == search_settings_id: raise ValueError("Cannot delete currently active search settings") # First, delete associated index attempts index_attempts_query = delete(IndexAttempt).where( IndexAttempt.search_settings_id == search_settings_id ) db_session.execute(index_attempts_query) # Then, delete the search settings search_settings_query = delete(SearchSettings).where( and_( SearchSettings.id == search_settings_id, SearchSettings.status != IndexModelStatus.PRESENT, ) ) db_session.execute(search_settings_query) db_session.commit() def get_current_search_settings(db_session: Session) -> SearchSettings: query = ( select(SearchSettings) .where(SearchSettings.status == IndexModelStatus.PRESENT) .order_by(SearchSettings.id.desc()) ) result = db_session.execute(query) latest_settings = result.scalars().first() if not latest_settings: raise RuntimeError("No search settings specified; DB is not in a valid state.") return latest_settings def get_secondary_search_settings(db_session: Session) -> SearchSettings | None: query = ( select(SearchSettings) .where(SearchSettings.status == IndexModelStatus.FUTURE) .order_by(SearchSettings.id.desc()) ) result = db_session.execute(query) latest_settings = result.scalars().first() return latest_settings def get_active_search_settings(db_session: Session) -> ActiveSearchSettings: """Returns active search settings. Secondary search settings may be None.""" # Get the primary and secondary search settings primary_search_settings = get_current_search_settings(db_session) secondary_search_settings = get_secondary_search_settings(db_session) return ActiveSearchSettings( primary=primary_search_settings, secondary=secondary_search_settings ) def get_active_search_settings_list(db_session: Session) -> list[SearchSettings]: """Returns active search settings as a list. 
Primary settings are the first element, and if secondary search settings exist, they will be the second element.""" search_settings_list: list[SearchSettings] = [] active_search_settings = get_active_search_settings(db_session) search_settings_list.append(active_search_settings.primary) if active_search_settings.secondary: search_settings_list.append(active_search_settings.secondary) return search_settings_list def get_all_search_settings(db_session: Session) -> list[SearchSettings]: query = select(SearchSettings).order_by(SearchSettings.id.desc()) result = db_session.execute(query) all_settings = result.scalars().all() return list(all_settings) def get_multilingual_expansion(db_session: Session | None = None) -> list[str]: if db_session is None: with get_session_with_current_tenant() as db_session: search_settings = get_current_search_settings(db_session) else: search_settings = get_current_search_settings(db_session) if not search_settings: return [] return search_settings.multilingual_expansion def update_search_settings( current_settings: SearchSettings, updated_settings: SavedSearchSettings, preserved_fields: list[str], ) -> None: for field, value in updated_settings.dict().items(): if field not in preserved_fields: setattr(current_settings, field, value) def update_current_search_settings( db_session: Session, search_settings: SavedSearchSettings, preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS, ) -> None: current_settings = get_current_search_settings(db_session) if not current_settings: logger.warning("No current search settings found to update") return update_search_settings(current_settings, search_settings, preserved_fields) db_session.commit() logger.info("Current search settings updated successfully") def update_secondary_search_settings( db_session: Session, search_settings: SavedSearchSettings, preserved_fields: list[str] = PRESERVED_SEARCH_FIELDS, ) -> None: secondary_settings = get_secondary_search_settings(db_session) if not 
secondary_settings: logger.warning("No secondary search settings found to update") return update_search_settings(secondary_settings, search_settings, preserved_fields) db_session.commit() logger.info("Secondary search settings updated successfully") def update_search_settings_status( search_settings: SearchSettings, new_status: IndexModelStatus, db_session: Session ) -> None: search_settings.status = new_status db_session.commit() def user_has_overridden_embedding_model() -> bool: return DOCUMENT_ENCODER_MODEL != DEFAULT_DOCUMENT_ENCODER_MODEL ================================================ FILE: backend/onyx/db/seeding/chat_history_seeding.py ================================================ import random from datetime import datetime from datetime import timedelta from logging import getLogger from uuid import UUID from onyx.configs.constants import MessageType from onyx.db.chat import create_chat_session from onyx.db.chat import create_new_chat_message from onyx.db.chat import get_or_create_root_message from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import ChatSession logger = getLogger(__name__) def seed_chat_history( num_sessions: int, num_messages: int, days: int, user_id: UUID | None = None, persona_id: int | None = None, ) -> None: """Utility function to seed chat history for testing. num_sessions: the number of sessions to seed num_messages: the number of messages to seed per session days: the number of days looking backwards from the current time over which to randomize the times.
user_id: optional user to associate with sessions persona_id: optional persona/assistant to associate with sessions """ with get_session_with_current_tenant() as db_session: logger.info(f"Seeding {num_sessions} sessions.") for y in range(0, num_sessions): create_chat_session(db_session, f"pytest_session_{y}", user_id, persona_id) # randomize all session times logger.info(f"Seeding {num_messages} messages per session.") rows = db_session.query(ChatSession).all() for x in range(0, len(rows)): if x % 1024 == 0: logger.info(f"Seeded messages for {x} sessions so far.") row = rows[x] row.time_created = datetime.utcnow() - timedelta( days=random.randint(0, days) ) row.time_updated = row.time_created + timedelta( minutes=random.randint(0, 10) ) root_message = get_or_create_root_message(row.id, db_session) current_message_type = MessageType.USER parent_message = root_message for msg_index in range(0, num_messages): if current_message_type == MessageType.USER: msg = f"pytest_message_user_{msg_index}" else: msg = f"pytest_message_assistant_{msg_index}" chat_message = create_new_chat_message( chat_session_id=row.id, parent_message=parent_message, message=msg, token_count=0, message_type=current_message_type, commit=False, db_session=db_session, ) chat_message.time_sent = row.time_created + timedelta( minutes=random.randint(0, 10) ) db_session.commit() current_message_type = ( MessageType.ASSISTANT if current_message_type == MessageType.USER else MessageType.USER ) parent_message = chat_message db_session.commit() logger.info(f"Seeded messages for {len(rows)} sessions.
Finished.") ================================================ FILE: backend/onyx/db/slack_bot.py ================================================ from collections.abc import Sequence from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.models import SlackBot def insert_slack_bot( db_session: Session, name: str, enabled: bool, bot_token: str, app_token: str, user_token: str | None = None, ) -> SlackBot: slack_bot = SlackBot( name=name, enabled=enabled, bot_token=bot_token, app_token=app_token, user_token=user_token, ) db_session.add(slack_bot) db_session.commit() return slack_bot def update_slack_bot( db_session: Session, slack_bot_id: int, name: str, enabled: bool, bot_token: str, app_token: str, user_token: str | None = None, ) -> SlackBot: slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id)) if slack_bot is None: raise ValueError(f"Unable to find Slack Bot with ID {slack_bot_id}") # update the app slack_bot.name = name slack_bot.enabled = enabled slack_bot.bot_token = bot_token # type: ignore[assignment] slack_bot.app_token = app_token # type: ignore[assignment] slack_bot.user_token = user_token # type: ignore[assignment] db_session.commit() return slack_bot def fetch_slack_bot( db_session: Session, slack_bot_id: int, ) -> SlackBot: slack_bot = db_session.scalar(select(SlackBot).where(SlackBot.id == slack_bot_id)) if slack_bot is None: raise ValueError(f"Unable to find Slack Bot with ID {slack_bot_id}") return slack_bot def remove_slack_bot( db_session: Session, slack_bot_id: int, ) -> None: slack_bot = fetch_slack_bot( db_session=db_session, slack_bot_id=slack_bot_id, ) db_session.delete(slack_bot) db_session.commit() def fetch_slack_bots(db_session: Session) -> Sequence[SlackBot]: return db_session.scalars(select(SlackBot)).all() ================================================ FILE: backend/onyx/db/slack_channel_config.py ================================================ from collections.abc import Sequence 
from typing import Any from sqlalchemy import select from sqlalchemy.orm import joinedload from sqlalchemy.orm import Session from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX from onyx.db.models import ChannelConfig from onyx.db.models import Persona from onyx.db.models import Persona__DocumentSet from onyx.db.models import SlackChannelConfig from onyx.db.models import User from onyx.db.persona import mark_persona_as_deleted from onyx.db.persona import upsert_persona from onyx.db.tools import get_builtin_tool from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.utils.errors import EERequiredError from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) def _build_persona_name(channel_name: str | None) -> str: return f"{SLACK_BOT_PERSONA_PREFIX}{channel_name if channel_name else DEFAULT_PERSONA_SLACK_CHANNEL_NAME}" def _cleanup_relationships(db_session: Session, persona_id: int) -> None: """NOTE: does not commit changes""" # delete existing persona-document_set relationships existing_relationships = db_session.scalars( select(Persona__DocumentSet).where( Persona__DocumentSet.persona_id == persona_id ) ) for rel in existing_relationships: db_session.delete(rel) def create_slack_channel_persona( db_session: Session, channel_name: str | None, document_set_ids: list[int], existing_persona_id: int | None = None, ) -> Persona: """NOTE: does not commit changes""" search_tool = get_builtin_tool(db_session=db_session, tool_type=SearchTool) # create/update persona associated with the Slack channel persona_name = _build_persona_name(channel_name) persona_id_to_update = existing_persona_id if persona_id_to_update is None: # Reuse any previous Slack persona for this channel (even if the config was # temporarily switched to a different persona) so we don't trip duplicate name # validation inside `upsert_persona`. 
existing_persona = db_session.scalar( select(Persona).where(Persona.name == persona_name) ) if existing_persona: persona_id_to_update = existing_persona.id persona = upsert_persona( user=None, # Slack channel Personas are not attached to users persona_id=persona_id_to_update, name=persona_name, description="", system_prompt="", task_prompt="", datetime_aware=True, tool_ids=[search_tool.id], document_set_ids=document_set_ids, llm_model_provider_override=None, llm_model_version_override=None, starter_messages=None, is_public=True, is_featured=False, db_session=db_session, commit=False, ) return persona def _no_ee_standard_answer_categories( *args: Any, # noqa: ARG001 **kwargs: Any, # noqa: ARG001 ) -> list: return [] def insert_slack_channel_config( db_session: Session, slack_bot_id: int, persona_id: int | None, channel_config: ChannelConfig, standard_answer_category_ids: list[int], enable_auto_filters: bool, is_default: bool = False, ) -> SlackChannelConfig: versioned_fetch_standard_answer_categories_by_ids = ( fetch_versioned_implementation_with_fallback( "onyx.db.standard_answer", "fetch_standard_answer_categories_by_ids", _no_ee_standard_answer_categories, ) ) existing_standard_answer_categories = ( versioned_fetch_standard_answer_categories_by_ids( standard_answer_category_ids=standard_answer_category_ids, db_session=db_session, ) ) if len(existing_standard_answer_categories) != len(standard_answer_category_ids): if len(existing_standard_answer_categories) == 0: raise EERequiredError( "Standard answers are a paid Enterprise Edition feature - enable EE or remove standard answer categories" ) else: raise ValueError( f"Some or all categories with ids {standard_answer_category_ids} do not exist" ) if is_default: existing_default = db_session.scalar( select(SlackChannelConfig).where( SlackChannelConfig.slack_bot_id == slack_bot_id, SlackChannelConfig.is_default == True, # noqa: E712 ) ) if existing_default: raise ValueError("A default config already exists for this
Slack bot.") else: if "channel_name" not in channel_config: raise ValueError("Channel name is required for non-default configs.") slack_channel_config = SlackChannelConfig( slack_bot_id=slack_bot_id, persona_id=persona_id, channel_config=channel_config, standard_answer_categories=existing_standard_answer_categories, enable_auto_filters=enable_auto_filters, is_default=is_default, ) db_session.add(slack_channel_config) db_session.commit() return slack_channel_config def update_slack_channel_config( db_session: Session, slack_channel_config_id: int, persona_id: int | None, channel_config: ChannelConfig, standard_answer_category_ids: list[int], enable_auto_filters: bool, disabled: bool, # noqa: ARG001 ) -> SlackChannelConfig: slack_channel_config = db_session.scalar( select(SlackChannelConfig).where( SlackChannelConfig.id == slack_channel_config_id ) ) if slack_channel_config is None: raise ValueError( f"Unable to find Slack channel config with ID {slack_channel_config_id}" ) versioned_fetch_standard_answer_categories_by_ids = ( fetch_versioned_implementation_with_fallback( "onyx.db.standard_answer", "fetch_standard_answer_categories_by_ids", _no_ee_standard_answer_categories, ) ) existing_standard_answer_categories = ( versioned_fetch_standard_answer_categories_by_ids( standard_answer_category_ids=standard_answer_category_ids, db_session=db_session, ) ) if len(existing_standard_answer_categories) != len(standard_answer_category_ids): raise ValueError( f"Some or all categories with ids {standard_answer_category_ids} do not exist" ) # update the config slack_channel_config.persona_id = persona_id slack_channel_config.channel_config = channel_config slack_channel_config.standard_answer_categories = list( existing_standard_answer_categories ) slack_channel_config.enable_auto_filters = enable_auto_filters db_session.commit() return slack_channel_config def remove_slack_channel_config( db_session: Session, slack_channel_config_id: int, user: User, ) -> None: 
slack_channel_config = db_session.scalar( select(SlackChannelConfig).where( SlackChannelConfig.id == slack_channel_config_id ) ) if slack_channel_config is None: raise ValueError( f"Unable to find Slack channel config with ID {slack_channel_config_id}" ) existing_persona_id = slack_channel_config.persona_id if existing_persona_id: existing_persona = db_session.scalar( select(Persona).where(Persona.id == existing_persona_id) ) # if the existing persona was one created just for use with this Slack channel, # then clean it up if existing_persona and existing_persona.name.startswith( SLACK_BOT_PERSONA_PREFIX ): _cleanup_relationships( db_session=db_session, persona_id=existing_persona_id ) mark_persona_as_deleted( persona_id=existing_persona_id, user=user, db_session=db_session ) db_session.delete(slack_channel_config) db_session.commit() def fetch_slack_channel_configs( db_session: Session, slack_bot_id: int | None = None ) -> Sequence[SlackChannelConfig]: if not slack_bot_id: return db_session.scalars(select(SlackChannelConfig)).all() return db_session.scalars( select(SlackChannelConfig).where( SlackChannelConfig.slack_bot_id == slack_bot_id ) ).all() def fetch_slack_channel_config( db_session: Session, slack_channel_config_id: int ) -> SlackChannelConfig | None: return db_session.scalar( select(SlackChannelConfig).where( SlackChannelConfig.id == slack_channel_config_id ) ) def fetch_slack_channel_config_for_channel_or_default( db_session: Session, slack_bot_id: int, channel_name: str | None ) -> SlackChannelConfig | None: # attempt to find channel-specific config first if channel_name is not None: sc_config = db_session.scalar( select(SlackChannelConfig) .options(joinedload(SlackChannelConfig.persona)) .where( SlackChannelConfig.slack_bot_id == slack_bot_id, SlackChannelConfig.channel_config["channel_name"].astext == channel_name, ) ) else: sc_config = None if sc_config: return sc_config # if none found, see if there is a default default_sc = db_session.scalar( 
select(SlackChannelConfig) .options(joinedload(SlackChannelConfig.persona)) .where( SlackChannelConfig.slack_bot_id == slack_bot_id, SlackChannelConfig.is_default == True, # noqa: E712 ) ) return default_sc ================================================ FILE: backend/onyx/db/swap_index.py ================================================ import time from sqlalchemy.orm import Session from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import VESPA_NUM_ATTEMPTS_ON_STARTUP from onyx.configs.constants import KV_REINDEX_KEY from onyx.db.connector_credential_pair import get_connector_credential_pairs from onyx.db.connector_credential_pair import resync_cc_pair from onyx.db.document import delete_all_documents_for_connector_credential_pair from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexModelStatus from onyx.db.enums import SwitchoverType from onyx.db.index_attempt import cancel_indexing_attempts_for_search_settings from onyx.db.index_attempt import ( count_unique_active_cc_pairs_with_successful_index_attempts, ) from onyx.db.index_attempt import count_unique_cc_pairs_with_successful_index_attempts from onyx.db.llm import update_default_contextual_model from onyx.db.llm import update_no_default_contextual_rag_provider from onyx.db.models import ConnectorCredentialPair from onyx.db.models import SearchSettings from onyx.db.search_settings import get_current_search_settings from onyx.db.search_settings import get_secondary_search_settings from onyx.db.search_settings import update_search_settings_status from onyx.document_index.factory import get_all_document_indices from onyx.key_value_store.factory import get_kv_store from onyx.utils.logger import setup_logger logger = setup_logger() def _perform_index_swap( db_session: Session, new_search_settings: SearchSettings, all_cc_pairs: list[ConnectorCredentialPair], cleanup_documents: bool = False, ) -> SearchSettings | None: """Swap the indices and expire 
the old one. Returns the old search settings if the swap was successful, otherwise None. """ current_search_settings = get_current_search_settings(db_session) if len(all_cc_pairs) > 0: kv_store = get_kv_store() kv_store.store(KV_REINDEX_KEY, False) # Expire jobs for the now past index/embedding model cancel_indexing_attempts_for_search_settings( search_settings_id=current_search_settings.id, db_session=db_session, ) # Recount aggregates for cc_pair in all_cc_pairs: resync_cc_pair( cc_pair=cc_pair, # sync based on the new search settings search_settings_id=new_search_settings.id, db_session=db_session, ) if cleanup_documents: # clean up all DocumentByConnectorCredentialPair / Document rows, since we're # doing an instant swap and no documents will exist in the new index. for cc_pair in all_cc_pairs: delete_all_documents_for_connector_credential_pair( db_session=db_session, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, ) # swap over search settings update_search_settings_status( search_settings=current_search_settings, new_status=IndexModelStatus.PAST, db_session=db_session, ) update_search_settings_status( search_settings=new_search_settings, new_status=IndexModelStatus.PRESENT, db_session=db_session, ) # Update the default contextual model to match the newly promoted settings try: update_default_contextual_model( db_session=db_session, enable_contextual_rag=new_search_settings.enable_contextual_rag, contextual_rag_llm_provider=new_search_settings.contextual_rag_llm_provider, contextual_rag_llm_name=new_search_settings.contextual_rag_llm_name, ) except ValueError as e: logger.error(f"Model not found, defaulting to no contextual model: {e}") update_no_default_contextual_rag_provider( db_session=db_session, ) new_search_settings.enable_contextual_rag = False new_search_settings.contextual_rag_llm_provider = None new_search_settings.contextual_rag_llm_name = None db_session.commit() # This flow is for checking and possibly creating an index so 
    # we get all indices.
    document_indices = get_all_document_indices(new_search_settings, None, None)

    WAIT_SECONDS = 5
    for document_index in document_indices:
        success = False
        for x in range(VESPA_NUM_ATTEMPTS_ON_STARTUP):
            try:
                logger.notice(
                    f"Document index {document_index.__class__.__name__} swap (attempt {x + 1}/{VESPA_NUM_ATTEMPTS_ON_STARTUP})..."
                )
                document_index.ensure_indices_exist(
                    primary_embedding_dim=new_search_settings.final_embedding_dim,
                    primary_embedding_precision=new_search_settings.embedding_precision,
                    # just finished swap, no more secondary index
                    secondary_index_embedding_dim=None,
                    secondary_index_embedding_precision=None,
                )

                logger.notice("Document index swap complete.")
                success = True
                break
            except Exception:
                logger.exception(
                    f"Document index swap for {document_index.__class__.__name__} did not succeed. "
                    f"The document index services may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
                )
                time.sleep(WAIT_SECONDS)

        if not success:
            logger.error(
                f"Document index swap for {document_index.__class__.__name__} did not succeed. "
                f"Attempt limit reached. ({VESPA_NUM_ATTEMPTS_ON_STARTUP})"
            )
            return None

    return current_search_settings


def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
    """Get the count of cc-pairs and the count of successful index_attempts for the
    new model, grouped by connector + credential. If they are the same, assume the
    new index is done building, swap the indices, and expire the old one.

    Returns None if search settings did not change, or the old search settings if they
    did change.
    """
    if DISABLE_VECTOR_DB:
        return None

    # Default CC-pair created for Ingestion API unused here
    all_cc_pairs = get_connector_credential_pairs(db_session)
    cc_pair_count = max(len(all_cc_pairs) - 1, 0)
    new_search_settings = get_secondary_search_settings(db_session)

    if not new_search_settings:
        return None

    # Handle switchover based on switchover_type
    switchover_type = new_search_settings.switchover_type

    # INSTANT: Swap immediately without waiting
    if switchover_type == SwitchoverType.INSTANT:
        return _perform_index_swap(
            db_session=db_session,
            new_search_settings=new_search_settings,
            all_cc_pairs=all_cc_pairs,
            # clean up all DocumentByConnectorCredentialPair / Document rows, since we're
            # doing an instant swap.
            cleanup_documents=True,
        )

    # REINDEX: Wait for all connectors to complete
    elif switchover_type == SwitchoverType.REINDEX:
        unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
            search_settings_id=new_search_settings.id, db_session=db_session
        )

        # Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
        # function is correct.
        # The unique_cc_indexings are specifically for the existing cc-pairs
        if unique_cc_indexings > cc_pair_count:
            logger.error("More unique indexings than cc pairs, should not occur")

        if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
            # Swap indices
            return _perform_index_swap(
                db_session=db_session,
                new_search_settings=new_search_settings,
                all_cc_pairs=all_cc_pairs,
            )

        return None

    # ACTIVE_ONLY: Wait for only non-paused connectors to complete
    elif switchover_type == SwitchoverType.ACTIVE_ONLY:
        # Count non-paused cc_pairs (excluding the default Ingestion API cc_pair)
        active_cc_pairs = [
            cc_pair
            for cc_pair in all_cc_pairs
            if cc_pair.status != ConnectorCredentialPairStatus.PAUSED
        ]
        active_cc_pair_count = max(len(active_cc_pairs) - 1, 0)

        unique_active_cc_indexings = (
            count_unique_active_cc_pairs_with_successful_index_attempts(
                search_settings_id=new_search_settings.id, db_session=db_session
            )
        )

        if unique_active_cc_indexings > active_cc_pair_count:
            logger.error(
                "More unique active indexings than active cc pairs, should not occur"
            )

        if (
            active_cc_pair_count == 0
            or active_cc_pair_count == unique_active_cc_indexings
        ):
            # Swap indices
            return _perform_index_swap(
                db_session=db_session,
                new_search_settings=new_search_settings,
                all_cc_pairs=all_cc_pairs,
            )

        return None

    # Should not reach here, but handle gracefully
    logger.error(f"Unknown switchover_type: {switchover_type}")
    return None


================================================
FILE: backend/onyx/db/sync_record.py
================================================
from sqlalchemy import and_
from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.db.enums import SyncStatus
from onyx.db.enums import SyncType
from onyx.db.models import SyncRecord
from onyx.utils.logger import setup_logger

logger = setup_logger()


def insert_sync_record(
    db_session: Session,
    entity_id: int,
    sync_type: SyncType,
) -> SyncRecord:
    """Insert a new sync record into the database, cancelling any existing in-progress records.

    Args:
        db_session: The database session to use
        entity_id: The ID of the entity being synced (document set ID, user group ID, etc.)
        sync_type: The type of sync operation
    """
    # If an existing in-progress sync record exists, mark as cancelled
    existing_in_progress_sync_record = fetch_latest_sync_record(
        db_session, entity_id, sync_type, sync_status=SyncStatus.IN_PROGRESS
    )

    if existing_in_progress_sync_record is not None:
        logger.info(
            f"Cancelling existing in-progress sync record {existing_in_progress_sync_record.id} "
            f"for entity_id={entity_id} sync_type={sync_type}"
        )
        mark_sync_records_as_cancelled(db_session, entity_id, sync_type)

    return _create_sync_record(db_session, entity_id, sync_type)


def mark_sync_records_as_cancelled(
    db_session: Session,
    entity_id: int | None,
    sync_type: SyncType,
) -> None:
    stmt = (
        update(SyncRecord)
        .where(
            and_(
                SyncRecord.entity_id == entity_id,
                SyncRecord.sync_type == sync_type,
                SyncRecord.sync_status == SyncStatus.IN_PROGRESS,
            )
        )
        .values(sync_status=SyncStatus.CANCELED)
    )
    db_session.execute(stmt)
    db_session.commit()


def _create_sync_record(
    db_session: Session,
    entity_id: int | None,
    sync_type: SyncType,
) -> SyncRecord:
    """Create and insert a new sync record into the database."""
    sync_record = SyncRecord(
        entity_id=entity_id,
        sync_type=sync_type,
        sync_status=SyncStatus.IN_PROGRESS,
        num_docs_synced=0,
        sync_start_time=func.now(),
    )
    db_session.add(sync_record)
    db_session.commit()

    return sync_record


def fetch_latest_sync_record(
    db_session: Session,
    entity_id: int,
    sync_type: SyncType,
    sync_status: SyncStatus | None = None,
) -> SyncRecord | None:
    """Fetch the most recent sync record for a given entity ID and status.
    Args:
        db_session: The database session to use
        entity_id: The ID of the entity to fetch sync record for
        sync_type: The type of sync operation
        sync_status: Optional status to filter on; if None, the latest record of any status is returned
    """
    stmt = (
        select(SyncRecord)
        .where(
            and_(
                SyncRecord.entity_id == entity_id,
                SyncRecord.sync_type == sync_type,
            )
        )
        .order_by(desc(SyncRecord.sync_start_time))
        .limit(1)
    )

    if sync_status is not None:
        stmt = stmt.where(SyncRecord.sync_status == sync_status)

    result = db_session.execute(stmt)
    return result.scalar_one_or_none()


def update_sync_record_status(
    db_session: Session,
    entity_id: int,
    sync_type: SyncType,
    sync_status: SyncStatus,
    num_docs_synced: int | None = None,
) -> None:
    """Update the status of a sync record.

    Args:
        db_session: The database session to use
        entity_id: The ID of the entity being synced
        sync_type: The type of sync operation
        sync_status: The new status to set
        num_docs_synced: Optional number of documents synced to update
    """
    sync_record = fetch_latest_sync_record(db_session, entity_id, sync_type)
    if sync_record is None:
        raise ValueError(
            f"No sync record found for entity_id={entity_id} sync_type={sync_type}"
        )

    sync_record.sync_status = sync_status
    if num_docs_synced is not None:
        sync_record.num_docs_synced = num_docs_synced

    if sync_status.is_terminal():
        sync_record.sync_end_time = func.now()  # type: ignore

    db_session.commit()


def cleanup_sync_records(
    db_session: Session, entity_id: int, sync_type: SyncType
) -> None:
    """Clean up in-progress sync records for a given entity ID and sync type by marking them as canceled."""
    stmt = (
        update(SyncRecord)
        .where(SyncRecord.entity_id == entity_id)
        .where(SyncRecord.sync_type == sync_type)
        .where(SyncRecord.sync_status == SyncStatus.IN_PROGRESS)
        .values(sync_status=SyncStatus.CANCELED, sync_end_time=func.now())
    )
    db_session.execute(stmt)
    db_session.commit()


================================================
FILE: backend/onyx/db/tag.py
================================================
from typing import Any

from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

from onyx.configs.constants import DocumentSource
from onyx.db.models import Document
from onyx.db.models import Document__Tag
from onyx.db.models import Tag
from onyx.utils.logger import setup_logger

logger = setup_logger()


def check_tag_validity(tag_key: str, tag_value: str) -> bool:
    """If a tag is too long, it should not be used (it will cause an error in Postgres
    as the unique constraint can only apply to entries that are less than 2704 bytes).

    Additionally, extremely long tags are not really usable / useful."""
    if len(tag_key) + len(tag_value) > 255:
        logger.error(
            f"Tag with key '{tag_key}' and value '{tag_value}' is too long, cannot be used"
        )
        return False

    return True


def create_or_add_document_tag(
    tag_key: str,
    tag_value: str,
    source: DocumentSource,
    document_id: str,
    db_session: Session,
) -> Tag | None:
    if not check_tag_validity(tag_key, tag_value):
        return None

    document = db_session.get(Document, document_id)
    if not document:
        raise ValueError("Invalid Document, cannot attach Tags")

    # Use upsert to avoid race condition when multiple workers try to create the same tag
    insert_stmt = pg_insert(Tag).values(
        tag_key=tag_key,
        tag_value=tag_value,
        source=source,
        is_list=False,
    )
    insert_stmt = insert_stmt.on_conflict_do_nothing(
        constraint="_tag_key_value_source_list_uc"
    )
    db_session.execute(insert_stmt)

    # Now fetch the tag (either just inserted or already existed)
    tag_stmt = select(Tag).where(
        Tag.tag_key == tag_key,
        Tag.tag_value == tag_value,
        Tag.source == source,
        Tag.is_list.is_(False),
    )
    tag = db_session.execute(tag_stmt).scalar_one()

    if tag not in document.tags:
        document.tags.append(tag)

    db_session.commit()
    return tag


def create_or_add_document_tag_list(
    tag_key: str,
    tag_values: list[str],
    source: DocumentSource,
    document_id: str,
    db_session: Session,
) -> list[Tag]:
    valid_tag_values = [
        tag_value for tag_value in tag_values if check_tag_validity(tag_key, tag_value)
    ]
    if not valid_tag_values:
        return []

    document = db_session.get(Document, document_id)
    if not document:
        raise ValueError("Invalid Document, cannot attach Tags")

    # Use upsert to avoid race condition when multiple workers try to create the same tags
    for tag_value in valid_tag_values:
        insert_stmt = pg_insert(Tag).values(
            tag_key=tag_key,
            tag_value=tag_value,
            source=source,
            is_list=True,
        )
        insert_stmt = insert_stmt.on_conflict_do_nothing(
            constraint="_tag_key_value_source_list_uc"
        )
        db_session.execute(insert_stmt)

    # Now fetch all tags (either just inserted or already existed)
    all_tags_stmt = select(Tag).where(
        Tag.tag_key == tag_key,
        Tag.tag_value.in_(valid_tag_values),
        Tag.source == source,
        Tag.is_list.is_(True),
    )
    all_tags = list(db_session.execute(all_tags_stmt).scalars().all())

    for tag in all_tags:
        if tag not in document.tags:
            document.tags.append(tag)

    db_session.commit()
    return all_tags


def upsert_document_tags(
    document_id: str,
    source: DocumentSource,
    metadata: dict[str, str | list[str]],
    db_session: Session,
) -> list[Tag]:
    document = db_session.get(Document, document_id)
    if not document:
        raise ValueError("Invalid Document, cannot attach Tags")

    old_tag_ids: set[int] = {tag.id for tag in document.tags}

    new_tags: list[Tag] = []
    new_tag_ids: set[int] = set()
    for k, v in metadata.items():
        if isinstance(v, list):
            new_tags.extend(
                create_or_add_document_tag_list(k, v, source, document_id, db_session)
            )
            new_tag_ids.update({tag.id for tag in new_tags})
            continue

        new_tag = create_or_add_document_tag(k, v, source, document_id, db_session)
        if new_tag:
            new_tag_ids.add(new_tag.id)
            new_tags.append(new_tag)

    delete_tags = old_tag_ids - new_tag_ids
    if delete_tags:
        delete_stmt = delete(Document__Tag).where(
            Document__Tag.document_id == document_id,
            Document__Tag.tag_id.in_(delete_tags),
        )
        db_session.execute(delete_stmt)
        db_session.commit()

    return new_tags


def find_tags(
    tag_key_prefix: str | None,
    tag_value_prefix: str | None,
    sources: list[DocumentSource] | None,
    limit: int | None,
    db_session: Session,
    # if set, both tag_key_prefix and tag_value_prefix must be a match
    require_both_to_match: bool = False,
) -> list[Tag]:
    query = select(Tag)

    if tag_key_prefix or tag_value_prefix:
        conditions = []
        if tag_key_prefix:
            conditions.append(Tag.tag_key.ilike(f"{tag_key_prefix}%"))
        if tag_value_prefix:
            conditions.append(Tag.tag_value.ilike(f"{tag_value_prefix}%"))

        final_prefix_condition = (
            and_(*conditions) if require_both_to_match else or_(*conditions)
        )
        query = query.where(final_prefix_condition)

    if sources:
        query = query.where(Tag.source.in_(sources))

    if limit:
        query = query.limit(limit)

    result = db_session.execute(query)

    tags = result.scalars().all()
    return list(tags)


def get_structured_tags_for_document(
    document_id: str, db_session: Session
) -> dict[str, str | list[str]]:
    """Essentially returns the document metadata from postgres."""
    document = db_session.get(Document, document_id)
    if not document:
        raise ValueError("Invalid Document, cannot find tags")

    document_metadata: dict[str, Any] = {}
    for tag in document.tags:
        if tag.is_list:
            document_metadata.setdefault(tag.tag_key, [])
            # should always be a list (if tag.is_list is always True for this key), but just in case
            if not isinstance(document_metadata[tag.tag_key], list):
                logger.warning(
                    "Inconsistent is_list for document %s, tag_key %s",
                    document_id,
                    tag.tag_key,
                )
                document_metadata[tag.tag_key] = [document_metadata[tag.tag_key]]
            document_metadata[tag.tag_key].append(tag.tag_value)
            continue

        # set value (ignore duplicate keys, though there should be none)
        document_metadata.setdefault(tag.tag_key, tag.tag_value)

        # should always be a value, but just in case (treat it as a list in this case)
        if isinstance(document_metadata[tag.tag_key], list):
            logger.warning(
                "Inconsistent is_list for document %s, tag_key %s",
                document_id,
                tag.tag_key,
            )
            # the key already holds a list, so add this value to it instead of nesting the list
            document_metadata[tag.tag_key].append(tag.tag_value)
    return document_metadata


def delete_document_tags_for_documents__no_commit(
    document_ids: list[str], db_session: Session
) -> None:
    stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
    db_session.execute(stmt)


def delete_orphan_tags__no_commit(db_session: Session) -> None:
    orphan_tags_query = select(Tag.id).where(
        ~db_session.query(Document__Tag.tag_id)
        .filter(Document__Tag.tag_id == Tag.id)
        .exists()
    )

    orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

    if orphan_tags:
        delete_orphan_tags_stmt = delete(Tag).where(Tag.id.in_(orphan_tags))
        db_session.execute(delete_orphan_tags_stmt)


================================================
FILE: backend/onyx/db/tasks.py
================================================
from datetime import datetime

from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from sqlalchemy.sql import delete

from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.engine.time_utils import get_db_current_time
from onyx.db.models import TaskQueueState
from onyx.db.models import TaskStatus


def get_latest_task(
    task_name: str,
    db_session: Session,
) -> TaskQueueState | None:
    stmt = (
        select(TaskQueueState)
        .where(TaskQueueState.task_name == task_name)
        .order_by(desc(TaskQueueState.id))
        .limit(1)
    )

    result = db_session.execute(stmt)
    latest_task = result.scalars().first()

    return latest_task


def get_latest_task_by_type(
    task_name: str,
    db_session: Session,
) -> TaskQueueState | None:
    stmt = (
        select(TaskQueueState)
        .where(TaskQueueState.task_name.like(f"%{task_name}%"))
        .order_by(desc(TaskQueueState.id))
        .limit(1)
    )

    result = db_session.execute(stmt)
    latest_task = result.scalars().first()

    return latest_task


def register_task(
    task_name: str,
    db_session: Session,
    task_id: str = "",
    status: TaskStatus = TaskStatus.PENDING,
    start_time: datetime | None = None,
) -> TaskQueueState:
    new_task = TaskQueueState(
        task_id=task_id,
        task_name=task_name,
        status=status,
        start_time=start_time,
    )

    db_session.add(new_task)
    db_session.commit()

    return new_task


def get_task_with_id(
    db_session: Session,
    task_id: str,
) -> TaskQueueState | None:
    return db_session.scalar(
        select(TaskQueueState).where(TaskQueueState.task_id == task_id)
    )


def delete_task_with_id(
    db_session: Session,
    task_id: str,
) -> None:
    db_session.execute(delete(TaskQueueState).where(TaskQueueState.task_id == task_id))
    db_session.commit()


def get_all_tasks_with_prefix(
    db_session: Session, task_name_prefix: str
) -> list[TaskQueueState]:
    return list(
        db_session.scalars(
            select(TaskQueueState).where(
                TaskQueueState.task_name.like(f"{task_name_prefix}_%")
            )
        )
    )


def mark_task_as_started_with_id(
    db_session: Session,
    task_id: str,
) -> None:
    task = get_task_with_id(db_session=db_session, task_id=task_id)
    if not task:
        raise RuntimeError(f"A task with the task-id {task_id=} does not exist")

    task.status = TaskStatus.STARTED
    db_session.commit()


def mark_task_as_finished_with_id(
    db_session: Session,
    task_id: str,
    success: bool = True,
) -> None:
    task = get_task_with_id(db_session=db_session, task_id=task_id)
    if not task:
        raise RuntimeError(f"A task with the task-id {task_id=} does not exist")

    task.status = TaskStatus.SUCCESS if success else TaskStatus.FAILURE
    db_session.commit()


def mark_task_start(
    task_name: str,
    db_session: Session,
) -> None:
    task = get_latest_task(task_name, db_session)
    if not task:
        raise ValueError(f"No task found with name {task_name}")

    task.start_time = func.now()  # type: ignore
    db_session.commit()


def mark_task_finished(
    task_name: str,
    db_session: Session,
    success: bool = True,
) -> None:
    latest_task = get_latest_task(task_name, db_session)
    if latest_task is None:
        raise ValueError(f"tasks for {task_name} do not exist")

    latest_task.status = TaskStatus.SUCCESS if success else TaskStatus.FAILURE
    db_session.commit()


def check_task_is_live_and_not_timed_out(
    task: TaskQueueState,
    db_session: Session,
    timeout: int = JOB_TIMEOUT,
) -> bool:
    # We only care
    # about live tasks, so as not to create new periodic tasks
    if task.status in [TaskStatus.SUCCESS, TaskStatus.FAILURE]:
        return False

    current_db_time = get_db_current_time(db_session=db_session)

    last_update_time = task.register_time
    if task.start_time:
        last_update_time = max(task.register_time, task.start_time)

    time_elapsed = current_db_time - last_update_time
    return time_elapsed.total_seconds() < timeout


================================================
FILE: backend/onyx/db/token_limit.py
================================================
from collections.abc import Sequence

from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.configs.constants import TokenRateLimitScope
from onyx.db.models import TokenRateLimit
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.server.token_rate_limits.models import TokenRateLimitArgs


def fetch_all_user_token_rate_limits(
    db_session: Session,
    enabled_only: bool = False,
    ordered: bool = True,
) -> Sequence[TokenRateLimit]:
    query = select(TokenRateLimit).where(
        TokenRateLimit.scope == TokenRateLimitScope.USER
    )

    if enabled_only:
        query = query.where(TokenRateLimit.enabled.is_(True))

    if ordered:
        query = query.order_by(TokenRateLimit.created_at.desc())

    return db_session.scalars(query).all()


def fetch_all_global_token_rate_limits(
    db_session: Session,
    enabled_only: bool = False,
    ordered: bool = True,
) -> Sequence[TokenRateLimit]:
    query = select(TokenRateLimit).where(
        TokenRateLimit.scope == TokenRateLimitScope.GLOBAL
    )

    if enabled_only:
        query = query.where(TokenRateLimit.enabled.is_(True))

    if ordered:
        query = query.order_by(TokenRateLimit.created_at.desc())

    token_rate_limits = db_session.scalars(query).all()
    return token_rate_limits


def insert_user_token_rate_limit(
    db_session: Session,
    token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
    token_limit = TokenRateLimit(
        enabled=token_rate_limit_settings.enabled,
        token_budget=token_rate_limit_settings.token_budget,
        period_hours=token_rate_limit_settings.period_hours,
        scope=TokenRateLimitScope.USER,
    )
    db_session.add(token_limit)
    db_session.commit()

    return token_limit


def insert_global_token_rate_limit(
    db_session: Session,
    token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
    token_limit = TokenRateLimit(
        enabled=token_rate_limit_settings.enabled,
        token_budget=token_rate_limit_settings.token_budget,
        period_hours=token_rate_limit_settings.period_hours,
        scope=TokenRateLimitScope.GLOBAL,
    )
    db_session.add(token_limit)
    db_session.commit()

    return token_limit


def update_token_rate_limit(
    db_session: Session,
    token_rate_limit_id: int,
    token_rate_limit_settings: TokenRateLimitArgs,
) -> TokenRateLimit:
    token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
    if token_limit is None:
        raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")

    token_limit.enabled = token_rate_limit_settings.enabled
    token_limit.token_budget = token_rate_limit_settings.token_budget
    token_limit.period_hours = token_rate_limit_settings.period_hours
    db_session.commit()

    return token_limit


def delete_token_rate_limit(
    db_session: Session,
    token_rate_limit_id: int,
) -> None:
    token_limit = db_session.get(TokenRateLimit, token_rate_limit_id)
    if token_limit is None:
        raise ValueError(f"TokenRateLimit with id '{token_rate_limit_id}' not found")

    db_session.query(TokenRateLimit__UserGroup).filter(
        TokenRateLimit__UserGroup.rate_limit_id == token_rate_limit_id
    ).delete()

    db_session.delete(token_limit)
    db_session.commit()


================================================
FILE: backend/onyx/db/tools.py
================================================
from typing import Any
from typing import cast
from typing import Type
from typing import TYPE_CHECKING
from uuid import UUID

from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.db.constants import UNSET
from onyx.db.constants import UnsetType
from onyx.db.enums import MCPServerStatus
from onyx.db.models import MCPServer
from onyx.db.models import OAuthConfig
from onyx.db.models import Tool
from onyx.db.models import ToolCall
from onyx.server.features.tool.models import Header
from onyx.tools.built_in_tools import BUILT_IN_TOOL_TYPES
from onyx.utils.headers import HeaderItemDict
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_json_like
from onyx.utils.postgres_sanitization import sanitize_string

if TYPE_CHECKING:
    pass


logger = setup_logger()


def get_tools(
    db_session: Session,
    *,
    only_enabled: bool = False,
    only_connected_mcp: bool = False,
    only_openapi: bool = False,
) -> list[Tool]:
    query = select(Tool)

    if only_connected_mcp:
        # Keep tools that either:
        # 1. Don't have an MCP server (mcp_server_id IS NULL) - Non-MCP tools
        # 2. Have an MCP server that is connected - Connected MCP tools
        query = query.outerjoin(MCPServer, Tool.mcp_server_id == MCPServer.id).where(
            or_(
                Tool.mcp_server_id.is_(None),  # Non-MCP tools (built-in, custom)
                MCPServer.status == MCPServerStatus.CONNECTED,  # MCP tools connected
            )
        )

    if only_enabled:
        query = query.where(Tool.enabled.is_(True))

    if only_openapi:
        query = query.where(
            Tool.openapi_schema.is_not(None),
            # To avoid showing rows that have JSON literal `null` stored in the column to the user.
            # tools from mcp servers will not have an openapi schema but it has `null`, so we need to exclude them.
            func.jsonb_typeof(Tool.openapi_schema) == "object",
            # Exclude built-in tools that happen to have an openapi_schema
            Tool.in_code_tool_id.is_(None),
        )

    return list(db_session.scalars(query).all())


def get_tools_by_mcp_server_id(
    mcp_server_id: int,
    db_session: Session,
    *,
    only_enabled: bool = False,
    order_by_id: bool = False,
) -> list[Tool]:
    query = select(Tool).where(Tool.mcp_server_id == mcp_server_id)
    if only_enabled:
        query = query.where(Tool.enabled.is_(True))
    if order_by_id:
        query = query.order_by(Tool.id)
    return list(db_session.scalars(query).all())


def get_tools_by_ids(tool_ids: list[int], db_session: Session) -> list[Tool]:
    if not tool_ids:
        return []
    stmt = select(Tool).where(Tool.id.in_(tool_ids))
    return list(db_session.scalars(stmt).all())


def get_tool_by_id(tool_id: int, db_session: Session) -> Tool:
    tool = db_session.scalar(select(Tool).where(Tool.id == tool_id))
    if not tool:
        raise ValueError("Tool by specified id does not exist")
    return tool


def get_tool_by_name(tool_name: str, db_session: Session) -> Tool:
    tool = db_session.scalar(select(Tool).where(Tool.name == tool_name))
    if not tool:
        raise ValueError("Tool by specified name does not exist")
    return tool


def create_tool__no_commit(
    name: str,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    custom_headers: list[Header] | None,
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool,
    *,
    mcp_server_id: int | None = None,
    oauth_config_id: int | None = None,
    enabled: bool = True,
) -> Tool:
    new_tool = Tool(
        name=name,
        description=description,
        in_code_tool_id=None,
        openapi_schema=openapi_schema,
        custom_headers=(
            [header.model_dump() for header in custom_headers]
            if custom_headers
            else []
        ),
        user_id=user_id,
        passthrough_auth=passthrough_auth,
        mcp_server_id=mcp_server_id,
        oauth_config_id=oauth_config_id,
        enabled=enabled,
    )
    db_session.add(new_tool)
    db_session.flush()  # Don't commit yet, let caller decide when to commit
    return new_tool


def update_tool(
    tool_id: int,
    name: str | None,
    description: str | None,
    openapi_schema: dict[str, Any] | None,
    custom_headers: list[Header] | None,
    user_id: UUID | None,
    db_session: Session,
    passthrough_auth: bool | None,
    oauth_config_id: int | None | UnsetType = UNSET,
) -> Tool:
    tool = get_tool_by_id(tool_id, db_session)
    if tool is None:
        raise ValueError(f"Tool with ID {tool_id} does not exist")

    if name is not None:
        tool.name = name
    if description is not None:
        tool.description = description
    if openapi_schema is not None:
        tool.openapi_schema = openapi_schema
    if user_id is not None:
        tool.user_id = user_id
    if custom_headers is not None:
        tool.custom_headers = [
            cast(HeaderItemDict, header.model_dump()) for header in custom_headers
        ]
    if passthrough_auth is not None:
        tool.passthrough_auth = passthrough_auth
    old_oauth_config_id = tool.oauth_config_id
    if not isinstance(oauth_config_id, UnsetType):
        tool.oauth_config_id = oauth_config_id
    db_session.flush()

    # Clean up orphaned OAuthConfig if the oauth_config_id was changed
    if (
        old_oauth_config_id is not None
        and not isinstance(oauth_config_id, UnsetType)
        and old_oauth_config_id != oauth_config_id
    ):
        other_tools = db_session.scalars(
            select(Tool).where(Tool.oauth_config_id == old_oauth_config_id)
        ).all()
        if not other_tools:
            oauth_config = db_session.get(OAuthConfig, old_oauth_config_id)
            if oauth_config:
                db_session.delete(oauth_config)

    db_session.commit()

    return tool


def delete_tool__no_commit(tool_id: int, db_session: Session) -> None:
    tool = get_tool_by_id(tool_id, db_session)
    if tool is None:
        raise ValueError(f"Tool with ID {tool_id} does not exist")

    oauth_config_id = tool.oauth_config_id

    db_session.delete(tool)
    db_session.flush()

    # Clean up orphaned OAuthConfig if no other tools reference it
    if oauth_config_id is not None:
        other_tools = db_session.scalars(
            select(Tool).where(Tool.oauth_config_id == oauth_config_id)
        ).all()
        if not other_tools:
            oauth_config = db_session.get(OAuthConfig, oauth_config_id)
            if oauth_config:
                db_session.delete(oauth_config)
                db_session.flush()


def get_builtin_tool(
    db_session: Session,
    tool_type: Type[BUILT_IN_TOOL_TYPES],
) -> Tool:
    """
    Retrieves a built-in tool from the database based on the tool type.
    """
    # local import to avoid circular import. DB layer should not depend on tools layer.
    from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP

    tool_id = next(
        (
            in_code_tool_id
            for in_code_tool_id, tool_cls in BUILT_IN_TOOL_MAP.items()
            if tool_cls.__name__ == tool_type.__name__
        ),
        None,
    )

    if not tool_id:
        raise RuntimeError(
            f"Tool type {tool_type.__name__} not found in the BUILT_IN_TOOLS list."
        )

    db_tool = db_session.execute(
        select(Tool).where(Tool.in_code_tool_id == tool_id)
    ).scalar_one_or_none()

    if not db_tool:
        raise RuntimeError(f"Tool type {tool_type.__name__} not found in the database.")

    return db_tool


def create_tool_call_no_commit(
    chat_session_id: UUID,
    parent_chat_message_id: int | None,
    turn_number: int,
    tool_id: int,
    tool_call_id: str,
    tool_call_arguments: dict[str, Any],
    tool_call_response: Any,
    tool_call_tokens: int,
    db_session: Session,
    *,
    parent_tool_call_id: int | None = None,
    reasoning_tokens: str | None = None,
    generated_images: list[dict] | None = None,
    tab_index: int = 0,
    add_only: bool = True,
) -> ToolCall:
    """
    Create a ToolCall entry in the database.
    Args:
        chat_session_id: The chat session ID
        parent_chat_message_id: The parent chat message ID
        turn_number: The turn number for this tool call
        tool_id: The tool ID
        tool_call_id: The tool call ID (string identifier from LLM)
        tool_call_arguments: The tool call arguments
        tool_call_response: The tool call response
        tool_call_tokens: The number of tokens in the tool call arguments
        db_session: The database session
        parent_tool_call_id: Optional parent tool call ID (for nested tool calls)
        reasoning_tokens: Optional reasoning tokens
        generated_images: Optional list of generated image metadata for replay
        tab_index: Index order of tool calls from the LLM for parallel tool calls
        add_only: If True, only add the ToolCall to the session; if False, also
            flush it. The transaction is never committed here.

    Returns:
        The created ToolCall object
    """
    tool_call = ToolCall(
        chat_session_id=chat_session_id,
        parent_chat_message_id=parent_chat_message_id,
        parent_tool_call_id=parent_tool_call_id,
        turn_number=turn_number,
        tab_index=tab_index,
        tool_id=tool_id,
        tool_call_id=tool_call_id,
        reasoning_tokens=(
            sanitize_string(reasoning_tokens) if reasoning_tokens else reasoning_tokens
        ),
        tool_call_arguments=sanitize_json_like(tool_call_arguments),
        tool_call_response=sanitize_json_like(tool_call_response),
        tool_call_tokens=tool_call_tokens,
        generated_images=sanitize_json_like(generated_images),
    )

    db_session.add(tool_call)
    if not add_only:
        # Flush so database-generated fields (e.g. the primary key) are populated
        db_session.flush()

    return tool_call


================================================
FILE: backend/onyx/db/usage.py
================================================
"""Database interactions for tenant usage tracking (cloud usage limits)."""

from datetime import datetime
from datetime import timedelta
from datetime import timezone
from enum import Enum

from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.orm import Session

from onyx.db.models import TenantUsage
from onyx.utils.logger import setup_logger
from shared_configs.configs import USAGE_LIMIT_WINDOW_SECONDS

logger = setup_logger()


class UsageType(str, Enum):
    """Types of usage that can be tracked and limited."""

    LLM_COST = "llm_cost_cents"
    CHUNKS_INDEXED = "chunks_indexed"
    API_CALLS = "api_calls"
    NON_STREAMING_API_CALLS = "non_streaming_api_calls"


class TenantUsageStats(BaseModel):
    """Current usage statistics for a tenant."""

    window_start: datetime
    llm_cost_cents: float
    chunks_indexed: int
    api_calls: int
    non_streaming_api_calls: int


class UsageLimitExceededError(Exception):
    """Raised when a tenant exceeds their usage limit."""

    def __init__(self, usage_type: UsageType, current: float, limit: float):
        self.usage_type = usage_type
        self.current = current
        self.limit = limit
        super().__init__(
            f"Usage limit exceeded for {usage_type.value}: current usage {current}, limit {limit}"
        )


def get_current_window_start() -> datetime:
    """
    Calculate the start of the current usage window.

    The window duration is configured via USAGE_LIMIT_WINDOW_SECONDS. For the
    default weekly window, windows are aligned to Monday 00:00 UTC for
    predictability; any other duration uses epoch-aligned windows.
    """
    now = datetime.now(timezone.utc)

    # For weekly windows (default), align to Monday 00:00 UTC
    if USAGE_LIMIT_WINDOW_SECONDS == 604800:  # 1 week
        # Get the start of the current week (Monday)
        days_since_monday = now.weekday()
        window_start = now.replace(
            hour=0, minute=0, second=0, microsecond=0
        ) - timedelta(days=days_since_monday)
        return window_start

    # For other window sizes, use epoch-aligned windows
    epoch = datetime(1970, 1, 1, tzinfo=timezone.utc)
    seconds_since_epoch = int((now - epoch).total_seconds())
    window_number = seconds_since_epoch // USAGE_LIMIT_WINDOW_SECONDS
    window_start_seconds = window_number * USAGE_LIMIT_WINDOW_SECONDS
    return epoch + timedelta(seconds=window_start_seconds)


def get_or_create_tenant_usage(
    db_session: Session,
    window_start: datetime | None = None,
) -> TenantUsage:
    """
    Get or create the usage record for the current window.

    Uses INSERT ...
ON CONFLICT DO UPDATE to atomically create or get the record, avoiding TOCTOU race conditions where two concurrent requests could both attempt to insert a new record. """ if window_start is None: window_start = get_current_window_start() # Atomic upsert: insert if not exists, or update a field to itself if exists # This ensures we always get back a valid row without race conditions stmt = ( pg_insert(TenantUsage) .values( window_start=window_start, llm_cost_cents=0.0, chunks_indexed=0, api_calls=0, non_streaming_api_calls=0, ) .on_conflict_do_update( index_elements=["window_start"], # No-op update: just set a field to its current value # This ensures the row is returned even on conflict set_={"llm_cost_cents": TenantUsage.llm_cost_cents}, ) .returning(TenantUsage) ) result = db_session.execute(stmt).scalar_one() db_session.flush() return result def get_tenant_usage_stats( db_session: Session, window_start: datetime | None = None, ) -> TenantUsageStats: """Get the current usage statistics for the tenant (read-only, no lock).""" if window_start is None: window_start = get_current_window_start() usage = db_session.execute( select(TenantUsage).where(TenantUsage.window_start == window_start) ).scalar_one_or_none() if usage is None: # No usage recorded yet for this window return TenantUsageStats( window_start=window_start, llm_cost_cents=0.0, chunks_indexed=0, api_calls=0, non_streaming_api_calls=0, ) return TenantUsageStats( window_start=usage.window_start, llm_cost_cents=usage.llm_cost_cents, chunks_indexed=usage.chunks_indexed, api_calls=usage.api_calls, non_streaming_api_calls=usage.non_streaming_api_calls, ) def increment_usage( db_session: Session, usage_type: UsageType, amount: float | int, ) -> None: """ Atomically increment a usage counter. Uses row-level locking to prevent race conditions. The caller should handle the transaction commit. 
""" usage = get_or_create_tenant_usage(db_session) if usage_type == UsageType.LLM_COST: usage.llm_cost_cents += float(amount) elif usage_type == UsageType.CHUNKS_INDEXED: usage.chunks_indexed += int(amount) elif usage_type == UsageType.API_CALLS: usage.api_calls += int(amount) elif usage_type == UsageType.NON_STREAMING_API_CALLS: usage.non_streaming_api_calls += int(amount) db_session.flush() def check_usage_limit( db_session: Session, usage_type: UsageType, limit: float | int, pending_amount: float | int = 0, ) -> None: """ Check if the current usage plus pending amount would exceed the limit. Args: db_session: Database session usage_type: Type of usage to check limit: The maximum allowed usage pending_amount: Amount about to be used (to check before committing) Raises: UsageLimitExceededError: If usage would exceed the limit """ stats = get_tenant_usage_stats(db_session) current_value: float if usage_type == UsageType.LLM_COST: current_value = stats.llm_cost_cents elif usage_type == UsageType.CHUNKS_INDEXED: current_value = float(stats.chunks_indexed) elif usage_type == UsageType.API_CALLS: current_value = float(stats.api_calls) elif usage_type == UsageType.NON_STREAMING_API_CALLS: current_value = float(stats.non_streaming_api_calls) else: current_value = 0.0 if current_value + pending_amount > limit: raise UsageLimitExceededError( usage_type=usage_type, current=current_value + pending_amount, limit=float(limit), ) ================================================ FILE: backend/onyx/db/user_file.py ================================================ import datetime from uuid import UUID from sqlalchemy import func from sqlalchemy import select from sqlalchemy.orm import joinedload from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.db.models import Persona from onyx.db.models import Project__UserFile from onyx.db.models import UserFile def fetch_chunk_counts_for_user_files( user_file_ids: list[str], db_session: Session, ) -> 
list[tuple[str, int]]: """ Return a list of (user_file_id, chunk_count) tuples. If a user_file_id is not found in the database, it will be returned with a chunk_count of 0. """ stmt = select(UserFile.id, UserFile.chunk_count).where( UserFile.id.in_(user_file_ids) ) results = db_session.execute(stmt).all() # Create a dictionary of user_file_id to chunk_count chunk_counts = {str(row.id): row.chunk_count or 0 for row in results} # Return a list of tuples in input order, defaulting to 0 for files that are # not found or whose chunk count is unknown (NULL). Callers cannot tell those # cases apart from a true zero and should fall back to an existence check # against the vector DB if necessary. return [ (user_file_id, chunk_counts.get(user_file_id, 0)) for user_file_id in user_file_ids ] def calculate_user_files_token_count(file_ids: list[UUID], db_session: Session) -> int: """Calculate total token count for specified files""" total_tokens = 0 # Get tokens from individual files if file_ids: file_tokens = ( db_session.query(func.sum(UserFile.token_count)) .filter(UserFile.id.in_(file_ids)) .scalar() or 0 ) total_tokens += file_tokens return total_tokens def fetch_user_project_ids_for_user_files( user_file_ids: list[str], db_session: Session, ) -> dict[str, list[int]]: """Fetch user project ids for specified user files""" user_file_uuid_ids = [UUID(user_file_id) for user_file_id in user_file_ids] stmt = select(Project__UserFile.user_file_id, Project__UserFile.project_id).where( Project__UserFile.user_file_id.in_(user_file_uuid_ids) ) rows = db_session.execute(stmt).all() user_file_id_to_project_ids: dict[str, list[int]] = { user_file_id: [] for user_file_id in user_file_ids } for user_file_id, project_id in rows: user_file_id_to_project_ids[str(user_file_id)].append(project_id) return user_file_id_to_project_ids def fetch_persona_ids_for_user_files( user_file_ids: list[str], db_session: Session, ) -> dict[str, list[int]]: """Fetch persona (assistant) ids for specified user files.""" stmt = ( select(UserFile)
.where(UserFile.id.in_(user_file_ids)) .options(selectinload(UserFile.assistants)) ) results = db_session.execute(stmt).scalars().all() return { str(user_file.id): [persona.id for persona in user_file.assistants] for user_file in results } def update_last_accessed_at_for_user_files( user_file_ids: list[UUID], db_session: Session, ) -> None: """Update `last_accessed_at` to now (UTC) for the given user files.""" if not user_file_ids: return now = datetime.datetime.now(datetime.timezone.utc) ( db_session.query(UserFile) .filter(UserFile.id.in_(user_file_ids)) .update({UserFile.last_accessed_at: now}, synchronize_session=False) ) db_session.commit() def get_file_id_by_user_file_id(user_file_id: str, db_session: Session) -> str | None: user_file = db_session.query(UserFile).filter(UserFile.id == user_file_id).first() if user_file: return user_file.file_id return None def get_file_ids_by_user_file_ids( user_file_ids: list[UUID], db_session: Session ) -> list[str]: user_files = db_session.query(UserFile).filter(UserFile.id.in_(user_file_ids)).all() return [user_file.file_id for user_file in user_files] def fetch_user_files_with_access_relationships( user_file_ids: list[str], db_session: Session, eager_load_groups: bool = False, ) -> list[UserFile]: """Fetch user files with the owner and assistant relationships eagerly loaded (needed for computing access control). 
When eager_load_groups is True, Persona.groups is also loaded so that callers can extract user-group names without a second DB round-trip.""" persona_sub_options = [ selectinload(Persona.users), selectinload(Persona.user), ] if eager_load_groups: persona_sub_options.append(selectinload(Persona.groups)) return ( db_session.query(UserFile) .options( joinedload(UserFile.user), selectinload(UserFile.assistants).options(*persona_sub_options), ) .filter(UserFile.id.in_(user_file_ids)) .all() ) ================================================ FILE: backend/onyx/db/user_preferences.py ================================================ from collections.abc import Sequence from uuid import UUID from sqlalchemy import Column from sqlalchemy import delete from sqlalchemy import desc from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from onyx.auth.schemas import UserRole from onyx.db.enums import AccountType from onyx.db.enums import DefaultAppMode from onyx.db.enums import ThemePreference from onyx.db.models import AccessToken from onyx.db.models import Assistant__UserSpecificConfig from onyx.db.models import Memory from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup from onyx.db.permissions import recompute_user_permissions__no_commit from onyx.db.users import assign_user_to_default_groups__no_commit from onyx.server.manage.models import MemoryItem from onyx.server.manage.models import UserSpecificAssistantPreference from onyx.utils.logger import setup_logger logger = setup_logger() _ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = { UserRole.SLACK_USER: AccountType.BOT, UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER, } def update_user_role( user: User, new_role: UserRole, db_session: Session, ) -> None: """Update a user's role in the database. 
Dual-writes account_type to keep it in sync with role and reconciles default-group membership (Admin / Basic).""" old_role = user.role user.role = new_role # Note: setting account_type to BOT or EXT_PERM_USER causes # assign_user_to_default_groups__no_commit to early-return, which is # intentional — these account types should not be in default groups. if new_role in _ROLE_TO_ACCOUNT_TYPE: user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role] elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER): # Upgrading from a non-web-login account type to a web role user.account_type = AccountType.STANDARD # Reconcile default-group membership when the role changes. if old_role != new_role: # Remove from all default groups first. db_session.execute( delete(User__UserGroup).where( User__UserGroup.user_id == user.id, User__UserGroup.user_group_id.in_( select(UserGroup.id).where(UserGroup.is_default.is_(True)) ), ) ) # Re-assign to the correct default group (skip for LIMITED). if new_role != UserRole.LIMITED: assign_user_to_default_groups__no_commit( db_session, user, is_admin=(new_role == UserRole.ADMIN), ) recompute_user_permissions__no_commit(user.id, db_session) db_session.commit() def deactivate_user( user: User, db_session: Session, ) -> None: """Deactivate a user by setting is_active to False.""" user.is_active = False db_session.add(user) db_session.commit() def activate_user( user: User, db_session: Session, ) -> None: """Activate a user by setting is_active to True. Also reconciles default-group membership — the user may have been created while inactive or deactivated before the backfill migration. 
""" user.is_active = True if user.role != UserRole.LIMITED: assign_user_to_default_groups__no_commit( db_session, user, is_admin=(user.role == UserRole.ADMIN) ) db_session.add(user) db_session.commit() def get_latest_access_token_for_user( user_id: UUID, db_session: Session, ) -> AccessToken | None: """Get the most recent access token for a user.""" try: result = db_session.execute( select(AccessToken) .where(AccessToken.user_id == user_id) # type: ignore .order_by(desc(Column("created_at"))) .limit(1) ) return result.scalar_one_or_none() except Exception as e: logger.error(f"Error fetching AccessToken: {e}") return None def update_user_temperature_override_enabled( user_id: UUID, temperature_override_enabled: bool, db_session: Session, ) -> None: """Update user's temperature override enabled setting.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(temperature_override_enabled=temperature_override_enabled) ) db_session.commit() def update_user_shortcut_enabled( user_id: UUID, shortcut_enabled: bool, db_session: Session, ) -> None: """Update user's shortcut enabled setting.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(shortcut_enabled=shortcut_enabled) ) db_session.commit() def update_user_auto_scroll( user_id: UUID, auto_scroll: bool | None, db_session: Session, ) -> None: """Update user's auto scroll setting.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(auto_scroll=auto_scroll) ) db_session.commit() def update_user_default_model( user_id: UUID, default_model: str | None, db_session: Session, ) -> None: """Update user's default model setting.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(default_model=default_model) ) db_session.commit() def update_user_theme_preference( user_id: UUID, theme_preference: ThemePreference, db_session: Session, ) -> None: """Update user's theme preference setting.""" 
db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(theme_preference=theme_preference) ) db_session.commit() def update_user_chat_background( user_id: UUID, chat_background: str | None, db_session: Session, ) -> None: """Update user's chat background setting.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(chat_background=chat_background) ) db_session.commit() def update_user_default_app_mode( user_id: UUID, default_app_mode: DefaultAppMode, db_session: Session, ) -> None: """Update user's default app mode setting.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(default_app_mode=default_app_mode) ) db_session.commit() def update_user_personalization( user_id: UUID, *, personal_name: str | None, personal_role: str | None, use_memories: bool, enable_memory_tool: bool, memories: list[MemoryItem], user_preferences: str | None, db_session: Session, ) -> None: db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values( personal_name=personal_name, personal_role=personal_role, use_memories=use_memories, enable_memory_tool=enable_memory_tool, user_preferences=user_preferences, ) ) # ID-based upsert: use real DB IDs from the frontend to match memories. 
incoming_ids = {m.id for m in memories if m.id is not None} # Delete existing rows not in the incoming set (scoped to user_id) existing_memories = list( db_session.scalars(select(Memory).where(Memory.user_id == user_id)).all() ) existing_ids = {mem.id for mem in existing_memories} ids_to_delete = existing_ids - incoming_ids if ids_to_delete: db_session.execute( delete(Memory).where( Memory.id.in_(ids_to_delete), Memory.user_id == user_id, ) ) # Update existing rows whose IDs match existing_by_id = {mem.id: mem for mem in existing_memories} for item in memories: if item.id is not None and item.id in existing_by_id: existing_by_id[item.id].memory_text = item.content # Create new rows for items without an ID new_items = [m for m in memories if m.id is None] if new_items: db_session.add_all( [Memory(user_id=user_id, memory_text=item.content) for item in new_items] ) db_session.commit() def get_memories_for_user( user_id: UUID, db_session: Session, ) -> Sequence[Memory]: return db_session.scalars( select(Memory).where(Memory.user_id == user_id).order_by(Memory.id.desc()) ).all() def update_user_pinned_assistants( user_id: UUID, pinned_assistants: list[int], db_session: Session, ) -> None: """Update user's pinned assistants list.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values(pinned_assistants=pinned_assistants) ) db_session.commit() def update_user_assistant_visibility( user_id: UUID, hidden_assistants: list[int] | None, visible_assistants: list[int] | None, chosen_assistants: list[int] | None, db_session: Session, ) -> None: """Update user's assistant visibility settings.""" db_session.execute( update(User) .where(User.id == user_id) # type: ignore .values( hidden_assistants=hidden_assistants, visible_assistants=visible_assistants, chosen_assistants=chosen_assistants, ) ) db_session.commit() def get_all_user_assistant_specific_configs( user_id: UUID, db_session: Session, ) -> Sequence[Assistant__UserSpecificConfig]: """Get the 
user-specific assistant configs for the given user, across all assistants.""" return db_session.scalars( select(Assistant__UserSpecificConfig).where( Assistant__UserSpecificConfig.user_id == user_id ) ).all() def update_assistant_preferences( assistant_id: int, user_id: UUID, new_assistant_preference: UserSpecificAssistantPreference, db_session: Session, ) -> None: """Update the disabled tools for a specific assistant for a specific user.""" # First check if a config already exists result = db_session.execute( select(Assistant__UserSpecificConfig) .where(Assistant__UserSpecificConfig.assistant_id == assistant_id) .where(Assistant__UserSpecificConfig.user_id == user_id) ) config = result.scalar_one_or_none() if config: # Update existing config config.disabled_tool_ids = new_assistant_preference.disabled_tool_ids else: # Create new config config = Assistant__UserSpecificConfig( assistant_id=assistant_id, user_id=user_id, disabled_tool_ids=new_assistant_preference.disabled_tool_ids, ) db_session.add(config) db_session.commit() ================================================ FILE: backend/onyx/db/users.py ================================================ from collections.abc import Sequence from typing import Any from uuid import UUID from fastapi import HTTPException from fastapi_users.password import PasswordHelper from sqlalchemy import case from sqlalchemy import func from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from sqlalchemy.sql import expression from sqlalchemy.sql.elements import ColumnElement from sqlalchemy.sql.elements import KeyedColumnElement from sqlalchemy.sql.expression import or_ from onyx.auth.invited_users import remove_user_from_invited_users from onyx.auth.schemas import UserRole from onyx.configs.constants import ANONYMOUS_USER_EMAIL from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL from onyx.db.enums
import AccountType from onyx.db.models import DocumentSet from onyx.db.models import DocumentSet__User from onyx.db.models import Persona from onyx.db.models import Persona__User from onyx.db.models import SamlAccount from onyx.db.models import User from onyx.db.models import User__UserGroup from onyx.db.models import UserGroup from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop logger = setup_logger() def validate_user_role_update( requested_role: UserRole, current_role: UserRole, current_account_type: AccountType, explicit_override: bool = False, ) -> None: """ Validate that a user role update is valid. Assumed only admins can hit this endpoint. raise if: - requested role is a curator - requested role is a slack user - requested role is an external permissioned user - requested role is a limited user - current account type is BOT (slack user) - current account type is EXT_PERM_USER - current role is a limited user """ if current_account_type == AccountType.BOT: raise HTTPException( status_code=400, detail="To change a Slack User's role, they must first login to Onyx via the web app.", ) if current_account_type == AccountType.EXT_PERM_USER: raise HTTPException( status_code=400, detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.", ) if current_role == UserRole.LIMITED: raise HTTPException( status_code=400, detail="To change a Limited User's role, they must first login to Onyx via the web app.", ) if explicit_override: return if requested_role == UserRole.CURATOR: # This shouldn't happen, but just in case raise HTTPException( status_code=400, detail="Curator role must be set via the User Group Menu", ) if requested_role == UserRole.LIMITED: # This shouldn't happen, but just in case raise HTTPException( status_code=400, detail=( "A user cannot be set to a Limited User role. 
" "This role is automatically assigned to users through certain endpoints in the API." ), ) if requested_role == UserRole.SLACK_USER: # This shouldn't happen, but just in case raise HTTPException( status_code=400, detail=( "A user cannot be set to a Slack User role. " "This role is automatically assigned to users who only use Onyx via Slack." ), ) if requested_role == UserRole.EXT_PERM_USER: # This shouldn't happen, but just in case raise HTTPException( status_code=400, detail=( "A user cannot be set to an External Permissioned User role. " "This role is automatically assigned to users who have been " "pulled in to the system via an external permissions system." ), ) def get_all_users( db_session: Session, email_filter_string: str | None = None, include_external: bool = False, ) -> Sequence[User]: """List all users. No pagination as of now, as the # of users is assumed to be relatively small (<< 1 million)""" stmt = select(User) # Exclude system users (anonymous user, no-auth placeholder) stmt = stmt.where(User.email != ANONYMOUS_USER_EMAIL) # type: ignore stmt = stmt.where(User.email != NO_AUTH_PLACEHOLDER_USER_EMAIL) # type: ignore if not include_external: stmt = stmt.where(User.role != UserRole.EXT_PERM_USER) if email_filter_string is not None: stmt = stmt.where(User.email.ilike(f"%{email_filter_string}%")) # type: ignore return db_session.scalars(stmt).unique().all() def _get_accepted_user_where_clause( email_filter_string: str | None = None, roles_filter: list[UserRole] = [], include_external: bool = False, is_active_filter: bool | None = None, ) -> list[ColumnElement[bool]]: """ Generates a SQLAlchemy where clause for filtering users based on the provided parameters. This is used to build the filters for the function that retrieves the users for the users table in the admin panel. Parameters: - email_filter_string: A substring to filter user emails. Only users whose emails contain this substring will be included. 
- is_active_filter: When True, only active users will be included. When False, only inactive users will be included. - roles_filter: A list of user roles to filter by. Only users with roles in this list will be included. - include_external: If False, external permissioned users will be excluded. Returns: - list: A list of conditions to be used in a SQLAlchemy query to filter users. """ # Access table columns directly via __table__.c to get proper SQLAlchemy column types # This ensures type checking works correctly for SQL operations like ilike, endswith, and is_ email_col: KeyedColumnElement[Any] = User.__table__.c.email is_active_col: KeyedColumnElement[Any] = User.__table__.c.is_active where_clause: list[ColumnElement[bool]] = [ expression.not_(email_col.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN)), # Exclude system users (anonymous user, no-auth placeholder) email_col != ANONYMOUS_USER_EMAIL, email_col != NO_AUTH_PLACEHOLDER_USER_EMAIL, ] if not include_external: where_clause.append(User.role != UserRole.EXT_PERM_USER) if email_filter_string is not None: personal_name_col: KeyedColumnElement[Any] = User.__table__.c.personal_name where_clause.append( or_( email_col.ilike(f"%{email_filter_string}%"), personal_name_col.ilike(f"%{email_filter_string}%"), ) ) if roles_filter: where_clause.append(User.role.in_(roles_filter)) if is_active_filter is not None: where_clause.append(is_active_col.is_(is_active_filter)) return where_clause def get_all_accepted_users( db_session: Session, include_external: bool = False, ) -> Sequence[User]: """Returns all accepted users without pagination. 
Uses the same filtering as the paginated endpoint but without search, role, or active filters.""" stmt = select(User) where_clause = _get_accepted_user_where_clause( include_external=include_external, ) stmt = stmt.where(*where_clause).order_by(User.email) return db_session.scalars(stmt).unique().all() def get_page_of_filtered_users( db_session: Session, page_size: int, page_num: int, email_filter_string: str | None = None, is_active_filter: bool | None = None, roles_filter: list[UserRole] = [], include_external: bool = False, ) -> Sequence[User]: users_stmt = select(User) where_clause = _get_accepted_user_where_clause( email_filter_string=email_filter_string, roles_filter=roles_filter, include_external=include_external, is_active_filter=is_active_filter, ) # Apply pagination users_stmt = users_stmt.offset((page_num) * page_size).limit(page_size) # Apply filtering users_stmt = users_stmt.where(*where_clause) return db_session.scalars(users_stmt).unique().all() def get_total_filtered_users_count( db_session: Session, email_filter_string: str | None = None, is_active_filter: bool | None = None, roles_filter: list[UserRole] = [], include_external: bool = False, ) -> int: where_clause = _get_accepted_user_where_clause( email_filter_string=email_filter_string, roles_filter=roles_filter, include_external=include_external, is_active_filter=is_active_filter, ) total_count_stmt = select(func.count()).select_from(User) # Apply filtering total_count_stmt = total_count_stmt.where(*where_clause) return db_session.scalar(total_count_stmt) or 0 def get_user_counts_by_role_and_status( db_session: Session, ) -> dict[str, dict[str, int]]: """Returns user counts grouped by role and by active/inactive status. Excludes API key users, anonymous users, and no-auth placeholder users. Uses a single query with conditional aggregation. 
""" base_where = _get_accepted_user_where_clause() role_col = User.__table__.c.role is_active_col = User.__table__.c.is_active stmt = ( select( role_col, func.count().label("total"), func.sum(case((is_active_col.is_(True), 1), else_=0)).label("active"), func.sum(case((is_active_col.is_(False), 1), else_=0)).label("inactive"), ) .where(*base_where) .group_by(role_col) ) role_counts: dict[str, int] = {} status_counts: dict[str, int] = {"active": 0, "inactive": 0} for role_val, total, active, inactive in db_session.execute(stmt).all(): key = role_val.value if hasattr(role_val, "value") else str(role_val) role_counts[key] = total status_counts["active"] += active or 0 status_counts["inactive"] += inactive or 0 return {"role_counts": role_counts, "status_counts": status_counts} def get_user_by_email(email: str, db_session: Session) -> User | None: user = ( db_session.query(User) .filter(func.lower(User.email) == func.lower(email)) .first() ) return user def fetch_user_by_id(db_session: Session, user_id: UUID) -> User | None: return db_session.query(User).filter(User.id == user_id).first() # type: ignore def _generate_slack_user(email: str) -> User: fastapi_users_pw_helper = PasswordHelper() password = fastapi_users_pw_helper.generate() hashed_pass = fastapi_users_pw_helper.hash(password) return User( email=email, hashed_password=hashed_pass, role=UserRole.SLACK_USER, account_type=AccountType.BOT, ) def add_slack_user_if_not_exists(db_session: Session, email: str) -> User: email = email.lower() user = get_user_by_email(email, db_session) if user is not None: # If the user is an external permissioned user, we update it to a slack user if user.account_type == AccountType.EXT_PERM_USER: user.role = UserRole.SLACK_USER user.account_type = AccountType.BOT db_session.commit() return user user = _generate_slack_user(email=email) db_session.add(user) db_session.commit() return user def _get_users_by_emails( db_session: Session, lower_emails: list[str] ) -> tuple[list[User], 
list[str]]: """given a list of lowercase emails, returns a list[User] of Users whose emails match and a list[str] of the missing emails that had no User""" stmt = select(User).filter(func.lower(User.email).in_(lower_emails)) found_users = list(db_session.scalars(stmt).unique().all()) # Convert to list # Collect found emails as a lowercase set: avoids case-sensitivity issues and gives O(1) lookups found_users_emails = {user.email.lower() for user in found_users} # Separate emails for users that were not found missing_user_emails = [ email for email in lower_emails if email not in found_users_emails ] return found_users, missing_user_emails def _generate_ext_permissioned_user(email: str) -> User: fastapi_users_pw_helper = PasswordHelper() password = fastapi_users_pw_helper.generate() hashed_pass = fastapi_users_pw_helper.hash(password) return User( email=email, hashed_password=hashed_pass, role=UserRole.EXT_PERM_USER, account_type=AccountType.EXT_PERM_USER, ) def batch_add_ext_perm_user_if_not_exists( db_session: Session, emails: list[str], continue_on_error: bool = False ) -> list[User]: lower_emails = [email.lower() for email in emails] found_users, missing_lower_emails = _get_users_by_emails(db_session, lower_emails) # Use savepoints (begin_nested) so that a failed insert only rolls back # that single user, not the entire transaction. A plain rollback() would # discard all previously flushed users in the same transaction. # We also avoid add_all() because SQLAlchemy 2.0's insertmanyvalues # batch path hits a UUID sentinel mismatch with server_default columns.
for email in missing_lower_emails: user = _generate_ext_permissioned_user(email=email) savepoint = db_session.begin_nested() try: db_session.add(user) savepoint.commit() except IntegrityError: savepoint.rollback() if not continue_on_error: raise db_session.commit() # Fetch all users again to ensure we have the most up-to-date list all_users, _ = _get_users_by_emails(db_session, lower_emails) return all_users def assign_user_to_default_groups__no_commit( db_session: Session, user: User, is_admin: bool = False, ) -> None: """Assign a newly created user to the appropriate default group. Does NOT commit — callers must commit the session themselves so that group assignment can be part of the same transaction as user creation. Args: is_admin: If True, assign to Admin default group; otherwise Basic. Callers determine this from their own context (e.g. user_count, admin email list, explicit choice). Defaults to False (Basic). """ if user.account_type in ( AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS, ): return target_group_name = "Admin" if is_admin else "Basic" default_group = ( db_session.query(UserGroup) .filter( UserGroup.name == target_group_name, UserGroup.is_default.is_(True), ) .first() ) if default_group is None: raise RuntimeError( f"Default group '{target_group_name}' not found. " f"Cannot assign user {user.email} to a group. " f"Ensure the seed_default_groups migration has run." ) # Check if the user is already in the group existing = ( db_session.query(User__UserGroup) .filter( User__UserGroup.user_id == user.id, User__UserGroup.user_group_id == default_group.id, ) .first() ) if existing is not None: return savepoint = db_session.begin_nested() try: db_session.add( User__UserGroup( user_id=user.id, user_group_id=default_group.id, ) ) db_session.flush() except IntegrityError: # Race condition: another transaction inserted this membership # between our SELECT and INSERT. 
The savepoint isolates the failure # so the outer transaction (user creation) stays intact. savepoint.rollback() return from onyx.db.permissions import recompute_user_permissions__no_commit recompute_user_permissions__no_commit(user.id, db_session) logger.info(f"Assigned user {user.email} to default group '{default_group.name}'") def delete_user_from_db( user_to_delete: User, db_session: Session, ) -> None: for oauth_account in user_to_delete.oauth_accounts: db_session.delete(oauth_account) fetch_ee_implementation_or_noop( "onyx.db.external_perm", "delete_user__ext_group_for_user__no_commit", )( db_session=db_session, user_id=user_to_delete.id, ) db_session.query(SamlAccount).filter( SamlAccount.user_id == user_to_delete.id ).delete() # Null out ownership on document sets and personas so they're # preserved for other users instead of being cascade-deleted db_session.query(DocumentSet).filter( DocumentSet.user_id == user_to_delete.id ).update({DocumentSet.user_id: None}) db_session.query(Persona).filter(Persona.user_id == user_to_delete.id).update( {Persona.user_id: None} ) db_session.query(DocumentSet__User).filter( DocumentSet__User.user_id == user_to_delete.id ).delete() db_session.query(Persona__User).filter( Persona__User.user_id == user_to_delete.id ).delete() db_session.query(User__UserGroup).filter( User__UserGroup.user_id == user_to_delete.id ).delete() db_session.delete(user_to_delete) db_session.commit() # NOTE: edge case may exist with race conditions # with this `invited user` scheme generally. remove_user_from_invited_users(user_to_delete.email) def batch_get_user_groups( db_session: Session, user_ids: list[UUID], include_default: bool = False, ) -> dict[UUID, list[tuple[int, str]]]: """Fetch group memberships for a batch of users in a single query. 
Returns a mapping of user_id -> list of (group_id, group_name) tuples.""" if not user_ids: return {} stmt = ( select( User__UserGroup.user_id, UserGroup.id, UserGroup.name, ) .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id) .where(User__UserGroup.user_id.in_(user_ids)) ) if not include_default: stmt = stmt.where(UserGroup.is_default == False) # noqa: E712 rows = db_session.execute(stmt).all() result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids} for user_id, group_id, group_name in rows: result[user_id].append((group_id, group_name)) return result ================================================ FILE: backend/onyx/db/utils.py ================================================ from enum import Enum from typing import Any from psycopg2 import errorcodes from pydantic import BaseModel from sqlalchemy import inspect from sqlalchemy.exc import OperationalError from onyx.db.models import Base def model_to_dict(model: Base) -> dict[str, Any]: return {c.key: getattr(model, c.key) for c in inspect(model).mapper.column_attrs} # type: ignore RETRYABLE_PG_CODES = { errorcodes.SERIALIZATION_FAILURE, # '40001' errorcodes.DEADLOCK_DETECTED, # '40P01' errorcodes.CONNECTION_EXCEPTION, # '08000' errorcodes.CONNECTION_DOES_NOT_EXIST, # '08003' errorcodes.CONNECTION_FAILURE, # '08006' errorcodes.TRANSACTION_ROLLBACK, # '40000' } def is_retryable_sqlalchemy_error(exc: BaseException) -> bool: """Helper function for use with tenacity's retry_if_exception as the callback. SQLAlchemy's OperationalError wraps the psycopg2 error in `.orig`, which carries the SQLSTATE as `pgcode`.""" if isinstance(exc, OperationalError): pgcode = getattr(getattr(exc, "orig", None), "pgcode", None) return pgcode in RETRYABLE_PG_CODES return False class DocumentRow(BaseModel): id: str doc_metadata: dict[str, Any] external_user_group_ids: list[str] class SortOrder(str, Enum): ASC = "asc" DESC = "desc" class DiscordChannelView(BaseModel): channel_id: int channel_name: str channel_type: str = "text" # text, forum is_private: bool = False # True if @everyone cannot view the channel
================================================ FILE: backend/onyx/db/voice.py ================================================ from typing import Any from uuid import UUID from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from onyx.db.models import User from onyx.db.models import VoiceProvider from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError MIN_VOICE_PLAYBACK_SPEED = 0.5 MAX_VOICE_PLAYBACK_SPEED = 2.0 def fetch_voice_providers(db_session: Session) -> list[VoiceProvider]: """Fetch all voice providers.""" return list( db_session.scalars(select(VoiceProvider).order_by(VoiceProvider.name)).all() ) def fetch_voice_provider_by_id( db_session: Session, provider_id: int ) -> VoiceProvider | None: """Fetch a voice provider by ID.""" return db_session.scalar( select(VoiceProvider).where(VoiceProvider.id == provider_id) ) def fetch_default_stt_provider(db_session: Session) -> VoiceProvider | None: """Fetch the default STT provider.""" return db_session.scalar( select(VoiceProvider).where(VoiceProvider.is_default_stt.is_(True)) ) def fetch_default_tts_provider(db_session: Session) -> VoiceProvider | None: """Fetch the default TTS provider.""" return db_session.scalar( select(VoiceProvider).where(VoiceProvider.is_default_tts.is_(True)) ) def fetch_voice_provider_by_type( db_session: Session, provider_type: str ) -> VoiceProvider | None: """Fetch a voice provider by type.""" return db_session.scalar( select(VoiceProvider).where(VoiceProvider.provider_type == provider_type) ) def upsert_voice_provider( *, db_session: Session, provider_id: int | None, name: str, provider_type: str, api_key: str | None, api_key_changed: bool, api_base: str | None = None, custom_config: dict[str, Any] | None = None, stt_model: str | None = None, tts_model: str | None = None, default_voice: str | None = None, activate_stt: bool = False, activate_tts: bool = False, ) -> VoiceProvider: 
"""Create or update a voice provider.""" provider: VoiceProvider | None = None if provider_id is not None: provider = fetch_voice_provider_by_id(db_session, provider_id) if provider is None: raise OnyxError( OnyxErrorCode.NOT_FOUND, f"No voice provider with id {provider_id} exists.", ) else: provider = VoiceProvider() db_session.add(provider) # Apply updates provider.name = name provider.provider_type = provider_type provider.api_base = api_base provider.custom_config = custom_config provider.stt_model = stt_model provider.tts_model = tts_model provider.default_voice = default_voice # Only update API key if explicitly changed or if provider has no key if api_key_changed or provider.api_key is None: provider.api_key = api_key # type: ignore[assignment] db_session.flush() if activate_stt: set_default_stt_provider(db_session=db_session, provider_id=provider.id) if activate_tts: set_default_tts_provider(db_session=db_session, provider_id=provider.id) db_session.refresh(provider) return provider def delete_voice_provider(db_session: Session, provider_id: int) -> None: """Delete a voice provider by ID.""" provider = fetch_voice_provider_by_id(db_session, provider_id) if provider: db_session.delete(provider) db_session.flush() def set_default_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider: """Set a voice provider as the default STT provider.""" provider = fetch_voice_provider_by_id(db_session, provider_id) if provider is None: raise OnyxError( OnyxErrorCode.NOT_FOUND, f"No voice provider with id {provider_id} exists.", ) # Deactivate all other STT providers db_session.execute( update(VoiceProvider) .where( VoiceProvider.is_default_stt.is_(True), VoiceProvider.id != provider_id, ) .values(is_default_stt=False) ) # Activate this provider provider.is_default_stt = True db_session.flush() db_session.refresh(provider) return provider def set_default_tts_provider( *, db_session: Session, provider_id: int, tts_model: str | None = None ) -> VoiceProvider: 
"""Set a voice provider as the default TTS provider.""" provider = fetch_voice_provider_by_id(db_session, provider_id) if provider is None: raise OnyxError( OnyxErrorCode.NOT_FOUND, f"No voice provider with id {provider_id} exists.", ) # Deactivate all other TTS providers db_session.execute( update(VoiceProvider) .where( VoiceProvider.is_default_tts.is_(True), VoiceProvider.id != provider_id, ) .values(is_default_tts=False) ) # Activate this provider provider.is_default_tts = True # Update the TTS model if specified if tts_model is not None: provider.tts_model = tts_model db_session.flush() db_session.refresh(provider) return provider def deactivate_stt_provider(*, db_session: Session, provider_id: int) -> VoiceProvider: """Remove the default STT status from a voice provider.""" provider = fetch_voice_provider_by_id(db_session, provider_id) if provider is None: raise OnyxError( OnyxErrorCode.NOT_FOUND, f"No voice provider with id {provider_id} exists.", ) provider.is_default_stt = False db_session.flush() db_session.refresh(provider) return provider def deactivate_tts_provider(*, db_session: Session, provider_id: int) -> VoiceProvider: """Remove the default TTS status from a voice provider.""" provider = fetch_voice_provider_by_id(db_session, provider_id) if provider is None: raise OnyxError( OnyxErrorCode.NOT_FOUND, f"No voice provider with id {provider_id} exists.", ) provider.is_default_tts = False db_session.flush() db_session.refresh(provider) return provider # User voice preferences def update_user_voice_settings( db_session: Session, user_id: UUID, auto_send: bool | None = None, auto_playback: bool | None = None, playback_speed: float | None = None, ) -> None: """Update user's voice settings. For all fields, None means "don't update this field". 
""" values: dict[str, bool | float] = {} if auto_send is not None: values["voice_auto_send"] = auto_send if auto_playback is not None: values["voice_auto_playback"] = auto_playback if playback_speed is not None: values["voice_playback_speed"] = max( MIN_VOICE_PLAYBACK_SPEED, min(MAX_VOICE_PLAYBACK_SPEED, playback_speed) ) if values: db_session.execute(update(User).where(User.id == user_id).values(**values)) # type: ignore[arg-type] db_session.flush() ================================================ FILE: backend/onyx/db/web_search.py ================================================ from __future__ import annotations from sqlalchemy import select from sqlalchemy import update from sqlalchemy.orm import Session from onyx.db.models import InternetContentProvider from onyx.db.models import InternetSearchProvider from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig from shared_configs.enums import WebContentProviderType from shared_configs.enums import WebSearchProviderType def fetch_web_search_providers(db_session: Session) -> list[InternetSearchProvider]: stmt = select(InternetSearchProvider).order_by(InternetSearchProvider.id.asc()) return list(db_session.scalars(stmt).all()) def fetch_web_content_providers(db_session: Session) -> list[InternetContentProvider]: stmt = select(InternetContentProvider).order_by(InternetContentProvider.id.asc()) return list(db_session.scalars(stmt).all()) def fetch_active_web_search_provider( db_session: Session, ) -> InternetSearchProvider | None: stmt = select(InternetSearchProvider).where( InternetSearchProvider.is_active.is_(True) ) return db_session.scalars(stmt).first() def fetch_web_search_provider_by_id( provider_id: int, db_session: Session ) -> InternetSearchProvider | None: return db_session.get(InternetSearchProvider, provider_id) def fetch_web_search_provider_by_name( name: str, db_session: Session ) -> InternetSearchProvider | None: stmt = 
select(InternetSearchProvider).where(InternetSearchProvider.name.ilike(name)) return db_session.scalars(stmt).first() def fetch_web_search_provider_by_type( provider_type: WebSearchProviderType, db_session: Session ) -> InternetSearchProvider | None: stmt = select(InternetSearchProvider).where( InternetSearchProvider.provider_type == provider_type.value ) return db_session.scalars(stmt).first() def _ensure_unique_search_name( name: str, provider_id: int | None, db_session: Session ) -> None: existing = fetch_web_search_provider_by_name(name=name, db_session=db_session) if existing and existing.id != provider_id: raise ValueError(f"A web search provider named '{name}' already exists.") def _apply_search_provider_updates( provider: InternetSearchProvider, *, name: str, provider_type: WebSearchProviderType, api_key: str | None, api_key_changed: bool, config: dict[str, str] | None, ) -> None: provider.name = name provider.provider_type = provider_type.value provider.config = config if api_key_changed or provider.api_key is None: # EncryptedString accepts str for writes, returns SensitiveValue for reads provider.api_key = api_key # type: ignore[assignment] def upsert_web_search_provider( *, provider_id: int | None, name: str, provider_type: WebSearchProviderType, api_key: str | None, api_key_changed: bool, config: dict[str, str] | None, activate: bool, db_session: Session, ) -> InternetSearchProvider: _ensure_unique_search_name( name=name, provider_id=provider_id, db_session=db_session ) provider: InternetSearchProvider | None = None if provider_id is not None: provider = fetch_web_search_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web search provider with id {provider_id} exists.") else: provider = InternetSearchProvider() db_session.add(provider) _apply_search_provider_updates( provider, name=name, provider_type=provider_type, api_key=api_key, api_key_changed=api_key_changed, config=config, ) db_session.flush() if activate: 
set_active_web_search_provider(provider_id=provider.id, db_session=db_session) db_session.refresh(provider) return provider def set_active_web_search_provider( *, provider_id: int | None, db_session: Session ) -> InternetSearchProvider: if provider_id is None: raise ValueError("Cannot activate a provider without an id.") provider = fetch_web_search_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web search provider with id {provider_id} exists.") db_session.execute( update(InternetSearchProvider) .where( InternetSearchProvider.is_active.is_(True), InternetSearchProvider.id != provider_id, ) .values(is_active=False) ) provider.is_active = True db_session.flush() db_session.refresh(provider) return provider def deactivate_web_search_provider( *, provider_id: int | None, db_session: Session ) -> InternetSearchProvider: if provider_id is None: raise ValueError("Cannot deactivate a provider without an id.") provider = fetch_web_search_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web search provider with id {provider_id} exists.") provider.is_active = False db_session.flush() db_session.refresh(provider) return provider def delete_web_search_provider(provider_id: int, db_session: Session) -> None: provider = fetch_web_search_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web search provider with id {provider_id} exists.") db_session.delete(provider) db_session.flush() db_session.commit() # Content provider helpers def fetch_active_web_content_provider( db_session: Session, ) -> InternetContentProvider | None: stmt = select(InternetContentProvider).where( InternetContentProvider.is_active.is_(True) ) return db_session.scalars(stmt).first() def fetch_web_content_provider_by_id( provider_id: int, db_session: Session ) -> InternetContentProvider | None: return db_session.get(InternetContentProvider, provider_id) def fetch_web_content_provider_by_name( name: str, 
db_session: Session ) -> InternetContentProvider | None: stmt = select(InternetContentProvider).where( InternetContentProvider.name.ilike(name) ) return db_session.scalars(stmt).first() def fetch_web_content_provider_by_type( provider_type: WebContentProviderType, db_session: Session ) -> InternetContentProvider | None: stmt = select(InternetContentProvider).where( InternetContentProvider.provider_type == provider_type.value ) return db_session.scalars(stmt).first() def _ensure_unique_content_name( name: str, provider_id: int | None, db_session: Session ) -> None: existing = fetch_web_content_provider_by_name(name=name, db_session=db_session) if existing and existing.id != provider_id: raise ValueError(f"A web content provider named '{name}' already exists.") def _apply_content_provider_updates( provider: InternetContentProvider, *, name: str, provider_type: WebContentProviderType, api_key: str | None, api_key_changed: bool, config: WebContentProviderConfig | None, ) -> None: provider.name = name provider.provider_type = provider_type.value provider.config = config if api_key_changed or provider.api_key is None: # EncryptedString accepts str for writes, returns SensitiveValue for reads provider.api_key = api_key # type: ignore[assignment] def upsert_web_content_provider( *, provider_id: int | None, name: str, provider_type: WebContentProviderType, api_key: str | None, api_key_changed: bool, config: WebContentProviderConfig | None, activate: bool, db_session: Session, ) -> InternetContentProvider: _ensure_unique_content_name( name=name, provider_id=provider_id, db_session=db_session ) provider: InternetContentProvider | None = None if provider_id is not None: provider = fetch_web_content_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web content provider with id {provider_id} exists.") else: provider = InternetContentProvider() db_session.add(provider) _apply_content_provider_updates( provider, name=name, 
provider_type=provider_type, api_key=api_key, api_key_changed=api_key_changed, config=config, ) db_session.flush() if activate: set_active_web_content_provider(provider_id=provider.id, db_session=db_session) db_session.refresh(provider) return provider def set_active_web_content_provider( *, provider_id: int | None, db_session: Session ) -> InternetContentProvider: if provider_id is None: raise ValueError("Cannot activate a provider without an id.") provider = fetch_web_content_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web content provider with id {provider_id} exists.") db_session.execute( update(InternetContentProvider) .where( InternetContentProvider.is_active.is_(True), InternetContentProvider.id != provider_id, ) .values(is_active=False) ) provider.is_active = True db_session.flush() db_session.refresh(provider) return provider def deactivate_web_content_provider( *, provider_id: int | None, db_session: Session ) -> InternetContentProvider: if provider_id is None: raise ValueError("Cannot deactivate a provider without an id.") provider = fetch_web_content_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web content provider with id {provider_id} exists.") provider.is_active = False db_session.flush() db_session.refresh(provider) return provider def delete_web_content_provider(provider_id: int, db_session: Session) -> None: provider = fetch_web_content_provider_by_id(provider_id, db_session) if provider is None: raise ValueError(f"No web content provider with id {provider_id} exists.") db_session.delete(provider) db_session.flush() db_session.commit() ================================================ FILE: backend/onyx/deep_research/__init__.py ================================================ ================================================ FILE: backend/onyx/deep_research/dr_loop.py ================================================ # TODO: Notes for potential extensions and future 
improvements: # 1. Allow tools that aren't search specific tools # 2. Use user provided custom prompts # 3. Save the plan for replay import time from collections.abc import Callable from typing import cast from sqlalchemy.orm import Session from onyx.chat.chat_state import ChatStateContainer from onyx.chat.citation_processor import CitationMapping from onyx.chat.citation_processor import DynamicCitationProcessor from onyx.chat.emitter import Emitter from onyx.chat.llm_loop import construct_message_history from onyx.chat.llm_step import run_llm_step from onyx.chat.llm_step import run_llm_step_pkt_generator from onyx.chat.models import ChatMessageSimple from onyx.chat.models import FileToolMetadata from onyx.chat.models import LlmStepResult from onyx.chat.models import ToolCallSimple from onyx.configs.chat_configs import SKIP_DEEP_RESEARCH_CLARIFICATION from onyx.configs.constants import MessageType from onyx.db.tools import get_tool_by_name from onyx.deep_research.dr_mock_tools import get_clarification_tool_definitions from onyx.deep_research.dr_mock_tools import get_orchestrator_tools from onyx.deep_research.dr_mock_tools import RESEARCH_AGENT_TOOL_NAME from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_MESSAGE from onyx.deep_research.dr_mock_tools import THINK_TOOL_RESPONSE_TOKEN_COUNT from onyx.deep_research.utils import check_special_tool_calls from onyx.deep_research.utils import create_think_tool_token_processor from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMUserIdentity from onyx.llm.models import ToolChoiceOptions from onyx.llm.utils import model_is_reasoning_model from onyx.prompts.deep_research.orchestration_layer import CLARIFICATION_PROMPT from onyx.prompts.deep_research.orchestration_layer import FINAL_REPORT_PROMPT from onyx.prompts.deep_research.orchestration_layer import FIRST_CYCLE_REMINDER from onyx.prompts.deep_research.orchestration_layer import FIRST_CYCLE_REMINDER_TOKENS from 
onyx.prompts.deep_research.orchestration_layer import ( INTERNAL_SEARCH_CLARIFICATION_GUIDANCE, ) from onyx.prompts.deep_research.orchestration_layer import ( INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE, ) from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT from onyx.prompts.deep_research.orchestration_layer import ORCHESTRATOR_PROMPT_REASONING from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_PROMPT from onyx.prompts.deep_research.orchestration_layer import RESEARCH_PLAN_REMINDER from onyx.prompts.deep_research.orchestration_layer import USER_FINAL_REPORT_QUERY from onyx.prompts.prompt_utils import get_current_llm_day_time from onyx.server.query_and_chat.placement import Placement from onyx.server.query_and_chat.streaming_models import AgentResponseDelta from onyx.server.query_and_chat.streaming_models import AgentResponseStart from onyx.server.query_and_chat.streaming_models import DeepResearchPlanDelta from onyx.server.query_and_chat.streaming_models import DeepResearchPlanStart from onyx.server.query_and_chat.streaming_models import OverallStop from onyx.server.query_and_chat.streaming_models import Packet from onyx.server.query_and_chat.streaming_models import SectionEnd from onyx.server.query_and_chat.streaming_models import TopLevelBranching from onyx.tools.fake_tools.research_agent import run_research_agent_calls from onyx.tools.interface import Tool from onyx.tools.models import ToolCallInfo from onyx.tools.models import ToolCallKickoff from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool from onyx.tools.tool_implementations.search.search_tool import SearchTool from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool from onyx.tracing.framework.create import function_span from onyx.tracing.framework.create import trace from onyx.utils.logger import setup_logger from onyx.utils.timing import log_function_time from shared_configs.contextvars import 
get_current_tenant_id logger = setup_logger() MAX_USER_MESSAGES_FOR_CONTEXT = 5 MAX_FINAL_REPORT_TOKENS = 20000 # 30 minute timeout before forcing final report generation # NOTE: The overall execution may be much longer still because it could run a research cycle at minute 29 # and that runs for another nearly 30 minutes. DEEP_RESEARCH_FORCE_REPORT_SECONDS = 30 * 60 # Might be something like (this gives a lot of leeway for change but typically the models don't do this): # 0. Research topics 1-3 # 1. Think # 2. Research topics 4-5 # 3. Think # 4. Research topics 6 + something new or different from the plan # 5. Think # 6. Research, possibly something new or different from the plan # 7. Think # 8. Generate report MAX_ORCHESTRATOR_CYCLES = 8 # Similar but without the 4 thinking tool calls MAX_ORCHESTRATOR_CYCLES_REASONING = 4 def generate_final_report( history: list[ChatMessageSimple], research_plan: str, llm: LLM, token_counter: Callable[[str], int], state_container: ChatStateContainer, emitter: Emitter, turn_index: int, citation_mapping: CitationMapping, user_identity: LLMUserIdentity | None, saved_reasoning: str | None = None, pre_answer_processing_time: float | None = None, all_injected_file_metadata: dict[str, FileToolMetadata] | None = None, ) -> bool: """Generate the final research report. Returns: bool: True if reasoning occurred during report generation (turn_index was incremented), False otherwise. 
""" with function_span("generate_report") as span: span.span_data.input = f"history_length={len(history)}, turn_index={turn_index}" final_report_prompt = FINAL_REPORT_PROMPT.format( current_datetime=get_current_llm_day_time(full_sentence=False), ) system_prompt = ChatMessageSimple( message=final_report_prompt, token_count=token_counter(final_report_prompt), message_type=MessageType.SYSTEM, ) final_reminder = USER_FINAL_REPORT_QUERY.format(research_plan=research_plan) reminder_message = ChatMessageSimple( message=final_reminder, token_count=token_counter(final_reminder), message_type=MessageType.USER_REMINDER, ) final_report_history = construct_message_history( system_prompt=system_prompt, custom_agent_prompt=None, simple_chat_history=history, reminder_message=reminder_message, context_files=None, available_tokens=llm.config.max_input_tokens, all_injected_file_metadata=all_injected_file_metadata, ) citation_processor = DynamicCitationProcessor() citation_processor.update_citation_mapping(citation_mapping) # Only passing in the cited documents as the whole list would be too long final_documents = list(citation_processor.citation_to_doc.values()) llm_step_result, has_reasoned = run_llm_step( emitter=emitter, history=final_report_history, tool_definitions=[], tool_choice=ToolChoiceOptions.NONE, llm=llm, placement=Placement(turn_index=turn_index), citation_processor=citation_processor, state_container=state_container, final_documents=final_documents, user_identity=user_identity, max_tokens=MAX_FINAL_REPORT_TOKENS, is_deep_research=True, pre_answer_processing_time=pre_answer_processing_time, timeout_override=300, # 5 minute read timeout for long report generation ) # Save citation mapping to state_container so citations are persisted state_container.set_citation_mapping(citation_processor.citation_to_doc) final_report = llm_step_result.answer if final_report is None: raise ValueError("LLM failed to generate the final deep research report") if saved_reasoning: # The 
reasoning we want to save with the message is more about calling this # generate report and why it's done. Also some models don't have reasoning # but we'd still want to capture the reasoning from the think_tool of the previous turn. state_container.set_reasoning_tokens(saved_reasoning) span.span_data.output = final_report if final_report else None return has_reasoned @log_function_time(print_only=True) def run_deep_research_llm_loop( emitter: Emitter, state_container: ChatStateContainer, simple_chat_history: list[ChatMessageSimple], tools: list[Tool], custom_agent_prompt: str | None, # noqa: ARG001 llm: LLM, token_counter: Callable[[str], int], db_session: Session, skip_clarification: bool = False, user_identity: LLMUserIdentity | None = None, chat_session_id: str | None = None, all_injected_file_metadata: dict[str, FileToolMetadata] | None = None, ) -> None: with trace( "run_deep_research_llm_loop", group_id=chat_session_id, metadata={ "tenant_id": get_current_tenant_id(), "chat_session_id": chat_session_id, }, ): # Imported here so LiteLLM is lazy-loaded from onyx.llm.litellm_singleton.config import initialize_litellm # An approximate limit. In extreme cases it may still fail, but this should allow deep research # to work in most cases.
if llm.config.max_input_tokens < 50000: raise RuntimeError( "Cannot run Deep Research with an LLM that has less than 50,000 max input tokens" ) initialize_litellm() # Track processing start time for tool duration calculation processing_start_time = time.monotonic() available_tokens = llm.config.max_input_tokens llm_step_result: LlmStepResult | None = None # Filter tools to only allow web search, internal search, and open URL allowed_tool_names = {SearchTool.NAME, WebSearchTool.NAME, OpenURLTool.NAME} allowed_tools = [tool for tool in tools if tool.name in allowed_tool_names] # Only include internal-search guidance when the search tool is actually available include_internal_search_tunings = any( tool.name == SearchTool.NAME for tool in allowed_tools ) orchestrator_start_turn_index = 1 ######################################################### # CLARIFICATION STEP (optional) ######################################################### internal_search_clarification_guidance = ( INTERNAL_SEARCH_CLARIFICATION_GUIDANCE if include_internal_search_tunings else "" ) if not SKIP_DEEP_RESEARCH_CLARIFICATION and not skip_clarification: with function_span("clarification_step") as span: clarification_prompt = CLARIFICATION_PROMPT.format( current_datetime=get_current_llm_day_time(full_sentence=False), internal_search_clarification_guidance=internal_search_clarification_guidance, ) system_prompt = ChatMessageSimple( message=clarification_prompt, token_count=300, # Skips the exact token count but has enough leeway message_type=MessageType.SYSTEM, ) truncated_message_history = construct_message_history( system_prompt=system_prompt, custom_agent_prompt=None, simple_chat_history=simple_chat_history, reminder_message=None, context_files=None, available_tokens=available_tokens, last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT, all_injected_file_metadata=all_injected_file_metadata, ) # Calculate tool processing duration for clarification step # (used if the LLM emits a clarification question instead of calling tools) clarification_tool_duration = time.monotonic() - processing_start_time
llm_step_result, _ = run_llm_step( emitter=emitter, history=truncated_message_history, tool_definitions=get_clarification_tool_definitions(), tool_choice=ToolChoiceOptions.AUTO, llm=llm, placement=Placement(turn_index=0), # No citations in this step; tokens pass straight through, # so no citation processor is needed citation_processor=None, state_container=state_container, final_documents=None, user_identity=user_identity, is_deep_research=True, pre_answer_processing_time=clarification_tool_duration, ) if not llm_step_result.tool_calls: # Mark this turn as a clarification question state_container.set_is_clarification(True) span.span_data.output = "clarification_required" emitter.emit( Packet( placement=Placement(turn_index=0), obj=OverallStop(type="stop"), ) ) # If a clarification is asked, we need to end this turn and wait on user input return ######################################################### # RESEARCH PLAN STEP ######################################################### with function_span("research_plan_step") as span: system_prompt = ChatMessageSimple( message=RESEARCH_PLAN_PROMPT.format( current_datetime=get_current_llm_day_time(full_sentence=False) ), token_count=300, message_type=MessageType.SYSTEM, ) # Note: it is fine to use a USER message type here, since it can be interpreted as the # user's message directly to the LLM.
reminder_message = ChatMessageSimple( message=RESEARCH_PLAN_REMINDER, token_count=100, message_type=MessageType.USER, ) truncated_message_history = construct_message_history( system_prompt=system_prompt, custom_agent_prompt=None, simple_chat_history=simple_chat_history + [reminder_message], reminder_message=None, context_files=None, available_tokens=available_tokens, last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT + 1, all_injected_file_metadata=all_injected_file_metadata, ) research_plan_generator = run_llm_step_pkt_generator( history=truncated_message_history, tool_definitions=[], tool_choice=ToolChoiceOptions.NONE, llm=llm, placement=Placement(turn_index=0), citation_processor=None, state_container=state_container, final_documents=None, user_identity=user_identity, is_deep_research=True, ) while True: try: packet = next(research_plan_generator) # Translate AgentResponseStart/Delta packets to DeepResearchPlanStart/Delta # The LLM response from this prompt is the research plan if isinstance(packet.obj, AgentResponseStart): emitter.emit( Packet( placement=packet.placement, obj=DeepResearchPlanStart(), ) ) elif isinstance(packet.obj, AgentResponseDelta): emitter.emit( Packet( placement=packet.placement, obj=DeepResearchPlanDelta(content=packet.obj.content), ) ) else: # Pass through other packet types (e.g., ReasoningStart, ReasoningDelta, etc.) 
emitter.emit(packet) except StopIteration as e: llm_step_result, reasoned = e.value emitter.emit( Packet( # Marks the last turn end which should be the plan generation placement=Placement( turn_index=1 if reasoned else 0, ), obj=SectionEnd(), ) ) if reasoned: orchestrator_start_turn_index += 1 break llm_step_result = cast(LlmStepResult, llm_step_result) research_plan = llm_step_result.answer if research_plan is None: raise RuntimeError("Deep Research failed to generate a research plan") span.span_data.output = research_plan if research_plan else None ######################################################### # RESEARCH EXECUTION STEP ######################################################### with function_span("research_execution_step") as span: is_reasoning_model = model_is_reasoning_model( llm.config.model_name, llm.config.model_provider ) max_orchestrator_cycles = ( MAX_ORCHESTRATOR_CYCLES if not is_reasoning_model else MAX_ORCHESTRATOR_CYCLES_REASONING ) orchestrator_prompt_template = ( ORCHESTRATOR_PROMPT if not is_reasoning_model else ORCHESTRATOR_PROMPT_REASONING ) internal_search_research_task_guidance = ( INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE if include_internal_search_tunings else "" ) token_count_prompt = orchestrator_prompt_template.format( current_datetime=get_current_llm_day_time(full_sentence=False), current_cycle_count=1, max_cycles=max_orchestrator_cycles, research_plan=research_plan, internal_search_research_task_guidance=internal_search_research_task_guidance, ) orchestration_tokens = token_counter(token_count_prompt) reasoning_cycles = 0 most_recent_reasoning: str | None = None citation_mapping: CitationMapping = {} final_turn_index: int = ( orchestrator_start_turn_index # Track the final turn_index for stop packet ) for cycle in range(max_orchestrator_cycles): # Check if we've exceeded the time limit or reached the last cycle # - if so, skip LLM and generate final report elapsed_seconds = time.monotonic() - processing_start_time timed_out = 
elapsed_seconds > DEEP_RESEARCH_FORCE_REPORT_SECONDS is_last_cycle = cycle == max_orchestrator_cycles - 1 if timed_out or is_last_cycle: if timed_out: logger.info( f"Deep research exceeded {DEEP_RESEARCH_FORCE_REPORT_SECONDS}s " f"(elapsed: {elapsed_seconds:.1f}s), forcing final report generation" ) report_turn_index = ( orchestrator_start_turn_index + cycle + reasoning_cycles ) report_reasoned = generate_final_report( history=simple_chat_history, research_plan=research_plan, llm=llm, token_counter=token_counter, state_container=state_container, emitter=emitter, turn_index=report_turn_index, citation_mapping=citation_mapping, user_identity=user_identity, pre_answer_processing_time=elapsed_seconds, all_injected_file_metadata=all_injected_file_metadata, ) final_turn_index = report_turn_index + (1 if report_reasoned else 0) break if cycle == 1: first_cycle_reminder_message = ChatMessageSimple( message=FIRST_CYCLE_REMINDER, token_count=FIRST_CYCLE_REMINDER_TOKENS, message_type=MessageType.USER_REMINDER, ) else: first_cycle_reminder_message = None research_agent_calls: list[ToolCallKickoff] = [] orchestrator_prompt = orchestrator_prompt_template.format( current_datetime=get_current_llm_day_time(full_sentence=False), current_cycle_count=cycle, max_cycles=max_orchestrator_cycles, research_plan=research_plan, internal_search_research_task_guidance=internal_search_research_task_guidance, ) system_prompt = ChatMessageSimple( message=orchestrator_prompt, token_count=orchestration_tokens, message_type=MessageType.SYSTEM, ) truncated_message_history = construct_message_history( system_prompt=system_prompt, custom_agent_prompt=None, simple_chat_history=simple_chat_history, reminder_message=first_cycle_reminder_message, context_files=None, available_tokens=available_tokens, last_n_user_messages=MAX_USER_MESSAGES_FOR_CONTEXT, all_injected_file_metadata=all_injected_file_metadata, ) # Use think tool processor for non-reasoning models to convert # think_tool calls to reasoning 
content custom_processor = ( create_think_tool_token_processor() if not is_reasoning_model else None ) llm_step_result, has_reasoned = run_llm_step( emitter=emitter, history=truncated_message_history, tool_definitions=get_orchestrator_tools( include_think_tool=not is_reasoning_model ), tool_choice=ToolChoiceOptions.REQUIRED, llm=llm, placement=Placement( turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles ), # No citations in this step, it should just pass through all # tokens directly so initialized as an empty citation processor citation_processor=DynamicCitationProcessor(), state_container=state_container, final_documents=None, user_identity=user_identity, custom_token_processor=custom_processor, is_deep_research=True, # Even for the reasoning tool, this should be plenty # The generation here should never be very long as it's just the tool calls. # This prevents timeouts where the model gets into an endless loop of null or bad tokens. max_tokens=1024, ) if has_reasoned: reasoning_cycles += 1 tool_calls = llm_step_result.tool_calls or [] if not tool_calls and cycle == 0: raise RuntimeError( "Deep Research failed to generate any research tasks for the agents." 
) if not tool_calls: # Basically hope that this is an infrequent occurrence and that multiple research # cycles have already run logger.warning("No tool calls found, this should not happen.") report_turn_index = ( orchestrator_start_turn_index + cycle + reasoning_cycles ) report_reasoned = generate_final_report( history=simple_chat_history, research_plan=research_plan, llm=llm, token_counter=token_counter, state_container=state_container, emitter=emitter, turn_index=report_turn_index, citation_mapping=citation_mapping, user_identity=user_identity, pre_answer_processing_time=time.monotonic() - processing_start_time, all_injected_file_metadata=all_injected_file_metadata, ) final_turn_index = report_turn_index + (1 if report_reasoned else 0) break special_tool_calls = check_special_tool_calls(tool_calls=tool_calls) if special_tool_calls.generate_report_tool_call: report_turn_index = ( special_tool_calls.generate_report_tool_call.placement.turn_index ) report_reasoned = generate_final_report( history=simple_chat_history, research_plan=research_plan, llm=llm, token_counter=token_counter, state_container=state_container, emitter=emitter, turn_index=report_turn_index, citation_mapping=citation_mapping, user_identity=user_identity, saved_reasoning=most_recent_reasoning, pre_answer_processing_time=time.monotonic() - processing_start_time, all_injected_file_metadata=all_injected_file_metadata, ) final_turn_index = report_turn_index + (1 if report_reasoned else 0) break elif special_tool_calls.think_tool_call: think_tool_call = special_tool_calls.think_tool_call # Only process the THINK_TOOL and skip all other tool calls # This will not actually get saved to the db as a tool call but we'll attach it to the tool(s) called after # it as if it were just a reasoning model doing it. In the chat history, because it happens in 2 steps, # we will show it as a separate message.
# NOTE: This does not need to increment the reasoning cycles because the custom token processor causes # the LLM step to handle this with function_span("think_tool") as span: span.span_data.input = str(think_tool_call.tool_args) most_recent_reasoning = state_container.reasoning_tokens tool_call_message = think_tool_call.to_msg_str() tool_call_token_count = token_counter(tool_call_message) # Create ASSISTANT message with tool_calls (OpenAI parallel format) think_tool_simple = ToolCallSimple( tool_call_id=think_tool_call.tool_call_id, tool_name=think_tool_call.tool_name, tool_arguments=think_tool_call.tool_args, token_count=tool_call_token_count, ) think_assistant_msg = ChatMessageSimple( message="", token_count=tool_call_token_count, message_type=MessageType.ASSISTANT, tool_calls=[think_tool_simple], image_files=None, ) simple_chat_history.append(think_assistant_msg) think_tool_response_msg = ChatMessageSimple( message=THINK_TOOL_RESPONSE_MESSAGE, token_count=THINK_TOOL_RESPONSE_TOKEN_COUNT, message_type=MessageType.TOOL_CALL_RESPONSE, tool_call_id=think_tool_call.tool_call_id, image_files=None, ) simple_chat_history.append(think_tool_response_msg) span.span_data.output = THINK_TOOL_RESPONSE_MESSAGE continue else: for tool_call in tool_calls: if tool_call.tool_name != RESEARCH_AGENT_TOOL_NAME: logger.warning( f"Unexpected tool call: {tool_call.tool_name}" ) continue research_agent_calls.append(tool_call) if not research_agent_calls: logger.warning( "No research agent tool calls found, this should not happen." 
) report_turn_index = ( orchestrator_start_turn_index + cycle + reasoning_cycles ) report_reasoned = generate_final_report( history=simple_chat_history, research_plan=research_plan, llm=llm, token_counter=token_counter, state_container=state_container, emitter=emitter, turn_index=report_turn_index, citation_mapping=citation_mapping, user_identity=user_identity, pre_answer_processing_time=time.monotonic() - processing_start_time, all_injected_file_metadata=all_injected_file_metadata, ) final_turn_index = report_turn_index + ( 1 if report_reasoned else 0 ) break if len(research_agent_calls) > 1: emitter.emit( Packet( placement=Placement( turn_index=research_agent_calls[ 0 ].placement.turn_index ), obj=TopLevelBranching( num_parallel_branches=len(research_agent_calls) ), ) ) research_results = run_research_agent_calls( # The tool calls here contain the placement information research_agent_calls=research_agent_calls, parent_tool_call_ids=[ tool_call.tool_call_id for tool_call in tool_calls ], tools=allowed_tools, emitter=emitter, state_container=state_container, llm=llm, is_reasoning_model=is_reasoning_model, token_counter=token_counter, citation_mapping=citation_mapping, user_identity=user_identity, ) citation_mapping = research_results.citation_mapping # Build ONE ASSISTANT message with all tool calls (OpenAI parallel format) tool_calls_simple: list[ToolCallSimple] = [] for current_tool_call in research_agent_calls: tool_call_message = current_tool_call.to_msg_str() tool_call_token_count = token_counter(tool_call_message) tool_calls_simple.append( ToolCallSimple( tool_call_id=current_tool_call.tool_call_id, tool_name=current_tool_call.tool_name, tool_arguments=current_tool_call.tool_args, token_count=tool_call_token_count, ) ) total_tool_call_tokens = sum( tc.token_count for tc in tool_calls_simple ) assistant_with_tools = ChatMessageSimple( message="", token_count=total_tool_call_tokens, message_type=MessageType.ASSISTANT, tool_calls=tool_calls_simple, 
image_files=None, ) simple_chat_history.append(assistant_with_tools) # Now add TOOL_CALL_RESPONSE messages and tool call info for each result for tab_index, report in enumerate( research_results.intermediate_reports ): if report is None: # The LLM will not see that this research was even attempted, it may try # something similar again but this is not bad. logger.error( f"Research agent call at tab_index {tab_index} failed, skipping" ) continue current_tool_call = research_agent_calls[tab_index] tool_call_info = ToolCallInfo( parent_tool_call_id=None, turn_index=orchestrator_start_turn_index + cycle + reasoning_cycles, tab_index=tab_index, tool_name=current_tool_call.tool_name, tool_call_id=current_tool_call.tool_call_id, tool_id=get_tool_by_name( tool_name=RESEARCH_AGENT_TOOL_NAME, db_session=db_session, ).id, reasoning_tokens=llm_step_result.reasoning or most_recent_reasoning, tool_call_arguments=current_tool_call.tool_args, tool_call_response=report, search_docs=None, # Intermediate docs are not saved/shown generated_images=None, ) state_container.add_tool_call(tool_call_info) tool_call_response_msg = ChatMessageSimple( message=report, token_count=token_counter(report), message_type=MessageType.TOOL_CALL_RESPONSE, tool_call_id=current_tool_call.tool_call_id, image_files=None, ) simple_chat_history.append(tool_call_response_msg) # If it reached this point, it did not call reasoning, so here we wipe it to not save it to multiple turns most_recent_reasoning = None emitter.emit( Packet( placement=Placement(turn_index=final_turn_index), obj=OverallStop(type="stop"), ) ) ================================================ FILE: backend/onyx/deep_research/dr_mock_tools.py ================================================ GENERATE_PLAN_TOOL_NAME = "generate_plan" RESEARCH_AGENT_IN_CODE_ID = "ResearchAgent" RESEARCH_AGENT_TOOL_NAME = "research_agent" RESEARCH_AGENT_TASK_KEY = "task" GENERATE_REPORT_TOOL_NAME = "generate_report" THINK_TOOL_NAME = "think_tool" # ruff: noqa: 
E501, W605 start GENERATE_PLAN_TOOL_DESCRIPTION = { "type": "function", "function": { "name": GENERATE_PLAN_TOOL_NAME, "description": "No clarification needed, generate a research plan for the user's query.", "parameters": { "type": "object", "properties": {}, "required": [], }, }, } RESEARCH_AGENT_TOOL_DESCRIPTION = { "type": "function", "function": { "name": RESEARCH_AGENT_TOOL_NAME, "description": "Conduct research on a specific topic.", "parameters": { "type": "object", "properties": { RESEARCH_AGENT_TASK_KEY: { "type": "string", "description": "The research task to investigate, should be 1-2 descriptive sentences outlining the direction of investigation.", } }, "required": [RESEARCH_AGENT_TASK_KEY], }, }, } GENERATE_REPORT_TOOL_DESCRIPTION = { "type": "function", "function": { "name": GENERATE_REPORT_TOOL_NAME, "description": "Generate the final research report from all of the findings. Should be called when all aspects of the user's query have been researched, or maximum cycles are reached.", "parameters": { "type": "object", "properties": {}, "required": [], }, }, } THINK_TOOL_DESCRIPTION = { "type": "function", "function": { "name": THINK_TOOL_NAME, "description": "Use this for reasoning between research_agent calls and before calling generate_report. Think deeply about key results, identify knowledge gaps, and plan next steps.", "parameters": { "type": "object", "properties": { "reasoning": { "type": "string", "description": "Your chain of thought reasoning, use paragraph format, no lists.", } }, "required": ["reasoning"], }, }, } RESEARCH_AGENT_THINK_TOOL_DESCRIPTION = { "type": "function", "function": { "name": "think_tool", "description": "Use this for reasoning between research steps. 
Think deeply about key results, identify knowledge gaps, and plan next steps.", "parameters": { "type": "object", "properties": { "reasoning": { "type": "string", "description": "Your chain of thought reasoning, can be as long as a lengthy paragraph.", } }, "required": ["reasoning"], }, }, } RESEARCH_AGENT_GENERATE_REPORT_TOOL_DESCRIPTION = { "type": "function", "function": { "name": "generate_report", "description": "Generate the final research report from all findings. Should be called when research is complete.", "parameters": { "type": "object", "properties": {}, "required": [], }, }, } THINK_TOOL_RESPONSE_MESSAGE = "Acknowledged, please continue." THINK_TOOL_RESPONSE_TOKEN_COUNT = 10 def get_clarification_tool_definitions() -> list[dict]: return [GENERATE_PLAN_TOOL_DESCRIPTION] def get_orchestrator_tools(include_think_tool: bool) -> list[dict]: tools = [ RESEARCH_AGENT_TOOL_DESCRIPTION, GENERATE_REPORT_TOOL_DESCRIPTION, ] if include_think_tool: tools.append(THINK_TOOL_DESCRIPTION) return tools def get_research_agent_additional_tool_definitions( include_think_tool: bool, ) -> list[dict]: tools = [GENERATE_REPORT_TOOL_DESCRIPTION] if include_think_tool: tools.append(RESEARCH_AGENT_THINK_TOOL_DESCRIPTION) return tools # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/deep_research/models.py ================================================ from pydantic import BaseModel from onyx.chat.citation_processor import CitationMapping from onyx.tools.models import ToolCallKickoff class SpecialToolCalls(BaseModel): think_tool_call: ToolCallKickoff | None = None generate_report_tool_call: ToolCallKickoff | None = None class ResearchAgentCallResult(BaseModel): intermediate_report: str citation_mapping: CitationMapping class CombinedResearchAgentCallResult(BaseModel): # The None is needed here to keep the mappings consistent # we later skip the failed research results but we need to know # which ones failed intermediate_reports: 
list[str | None] citation_mapping: CitationMapping ================================================ FILE: backend/onyx/deep_research/utils.py ================================================ from collections.abc import Callable from typing import Any from pydantic import BaseModel from onyx.deep_research.dr_mock_tools import GENERATE_REPORT_TOOL_NAME from onyx.deep_research.dr_mock_tools import THINK_TOOL_NAME from onyx.deep_research.models import SpecialToolCalls from onyx.llm.model_response import ChatCompletionDeltaToolCall from onyx.llm.model_response import Delta from onyx.llm.model_response import FunctionCall from onyx.tools.models import ToolCallKickoff # JSON prefixes to detect in think_tool arguments # The schema is: {"reasoning": "...content..."} JSON_PREFIX_WITH_SPACE = '{"reasoning": "' JSON_PREFIX_NO_SPACE = '{"reasoning":"' class ThinkToolProcessorState(BaseModel): """State for tracking think tool processing across streaming deltas.""" think_tool_found: bool = False think_tool_index: int | None = None think_tool_id: str | None = None full_arguments: str = "" # Full accumulated arguments for final tool call accumulated_args: str = "" # Working buffer for JSON parsing json_prefix_stripped: bool = False # Buffer holds content that might be the JSON suffix "} # We hold back 2 chars to avoid emitting the closing "} buffer: str = "" def _unescape_json_string(s: str) -> str: """ Unescape JSON string escape sequences. JSON strings use backslash escapes like \\n for newlines, \\t for tabs, etc. When we extract content from JSON by string manipulation (without json.loads), we need to manually decode these escape sequences. Note: We use a placeholder approach to handle escaped backslashes correctly. For example, "\\\\n" (escaped backslash + n) should become "\\n" (literal backslash + n), not a newline character. 
""" # First, protect escaped backslashes with a placeholder placeholder = "\x00ESCAPED_BACKSLASH\x00" result = s.replace("\\\\", placeholder) # Now unescape common JSON escape sequences result = result.replace("\\n", "\n") result = result.replace("\\r", "\r") result = result.replace("\\t", "\t") result = result.replace('\\"', '"') # Finally, restore escaped backslashes as single backslashes result = result.replace(placeholder, "\\") return result def _extract_reasoning_chunk(state: ThinkToolProcessorState) -> str | None: """ Extract reasoning content from accumulated arguments, stripping JSON wrapper. Returns the next chunk of reasoning to emit, or None if nothing to emit yet. """ # If we haven't found the JSON prefix yet, look for it if not state.json_prefix_stripped: # Try both prefix variants for prefix in [JSON_PREFIX_WITH_SPACE, JSON_PREFIX_NO_SPACE]: prefix_pos = state.accumulated_args.find(prefix) if prefix_pos != -1: # Found prefix - extract content after it content_start = prefix_pos + len(prefix) state.buffer = state.accumulated_args[content_start:] state.accumulated_args = "" state.json_prefix_stripped = True break if not state.json_prefix_stripped: # Haven't seen full prefix yet, keep accumulating return None else: # Already stripped prefix, add new content to buffer state.buffer += state.accumulated_args state.accumulated_args = "" # Hold back enough chars to avoid splitting escape sequences AND the JSON suffix "} # We need at least 2 for the suffix, but we also need to ensure escape sequences # like \n, \t, \\, \" don't get split. The longest escape is \\ (2 chars). # So we hold back 3 chars to be safe: if the last char is \, we don't want to # emit it without knowing what follows. 
holdback = 3 if len(state.buffer) <= holdback: return None # Check if there's a trailing backslash that could be part of an escape sequence # If so, hold back one more character to avoid splitting the escape to_emit = state.buffer[:-holdback] remaining = state.buffer[-holdback:] # If to_emit ends with a backslash, it might be the start of an escape sequence # Move it to the remaining buffer to process with the next chunk if to_emit and to_emit[-1] == "\\": remaining = to_emit[-1] + remaining to_emit = to_emit[:-1] state.buffer = remaining # Unescape JSON escape sequences (e.g., \\n -> \n) if to_emit: to_emit = _unescape_json_string(to_emit) return to_emit if to_emit else None def create_think_tool_token_processor() -> ( Callable[[Delta | None, Any], tuple[Delta | None, Any]] ): """ Create a custom token processor that converts think_tool calls to reasoning content. When the think_tool is detected: - Tool call arguments are converted to reasoning_content (JSON wrapper stripped) - All other deltas (content, other tool calls) are dropped This allows non-reasoning models to emit chain-of-thought via the think_tool, which gets displayed as reasoning tokens in the UI. Returns: A function compatible with run_llm_step_pkt_generator's custom_token_processor parameter. The function takes (Delta, state) and returns (modified Delta | None, new state).
""" def process_token(delta: Delta | None, state: Any) -> tuple[Delta | None, Any]: if state is None: state = ThinkToolProcessorState() # Handle flush signal (delta=None) - emit the complete tool call if delta is None: if state.think_tool_found and state.think_tool_id: # Return the complete think tool call complete_tool_call = ChatCompletionDeltaToolCall( id=state.think_tool_id, index=state.think_tool_index or 0, type="function", function=FunctionCall( name=THINK_TOOL_NAME, arguments=state.full_arguments, ), ) return Delta(tool_calls=[complete_tool_call]), state return None, state # Check for think tool in tool_calls if delta.tool_calls: for tool_call in delta.tool_calls: # Detect think tool by name if tool_call.function and tool_call.function.name == THINK_TOOL_NAME: state.think_tool_found = True state.think_tool_index = tool_call.index # Capture tool call id when available if ( state.think_tool_found and tool_call.index == state.think_tool_index and tool_call.id ): state.think_tool_id = tool_call.id # Accumulate arguments for the think tool if ( state.think_tool_found and tool_call.index == state.think_tool_index and tool_call.function and tool_call.function.arguments ): # Track full arguments for final tool call state.full_arguments += tool_call.function.arguments # Also accumulate for JSON parsing state.accumulated_args += tool_call.function.arguments # Try to extract reasoning content reasoning_chunk = _extract_reasoning_chunk(state) if reasoning_chunk: # Return delta with reasoning_content to trigger reasoning streaming return Delta(reasoning_content=reasoning_chunk), state # If think tool found, drop all other content if state.think_tool_found: return None, state # No think tool detected, pass through original delta return delta, state return process_token def check_special_tool_calls(tool_calls: list[ToolCallKickoff]) -> SpecialToolCalls: think_tool_call: ToolCallKickoff | None = None generate_report_tool_call: ToolCallKickoff | None = None for tool_call in 
tool_calls: if tool_call.tool_name == THINK_TOOL_NAME: think_tool_call = tool_call elif tool_call.tool_name == GENERATE_REPORT_TOOL_NAME: generate_report_tool_call = tool_call return SpecialToolCalls( think_tool_call=think_tool_call, generate_report_tool_call=generate_report_tool_call, )
================================================
FILE: backend/onyx/document_index/FILTER_SEMANTICS.md
================================================
# Vector DB Filter Semantics

How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.

## Filter categories

| Category | Fields | Join logic |
|---|---|---|
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | Each OR within, AND with rest |
| **Knowledge scope** | `document_set`, `attached_document_ids`, `hierarchy_node_ids`, `persona_id_filter` | OR within group, AND with rest |
| **Additive scope** | `project_id_filter` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |

## How filters combine

All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.

```
NOT hidden
AND tenant = T                            -- if multi-tenant
AND (acl contains A1 OR acl contains A2)
AND (source_type = S1 OR ...)             -- if set
AND (tag = T1 OR ...)                     -- if set
AND                                       -- see below
AND time >= cutoff                        -- if set
```

## Knowledge scope rules

The knowledge scope filter controls **what knowledge an assistant can access**.

### Primary vs additive triggers

- **`persona_id_filter`** is a **primary** trigger. A persona with user files IS explicit knowledge, so `persona_id_filter` alone can start a knowledge scope. Note: this is NOT the raw ID of the persona being used — it is only set when the persona's user files overflowed the LLM context window.
- **`project_id_filter`** is **additive**. It widens an existing scope to include project files but never restricts on its own — a chat inside a project should still search team knowledge when no other knowledge is attached.

### No explicit knowledge attached

When `document_set`, `attached_document_ids`, `hierarchy_node_ids`, and `persona_id_filter` are all empty/None:

- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id_filter` is ignored — it never restricts on its own.

### One explicit knowledge type

```
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")

-- Only persona user files (overflowed context)
AND (personas contains 42)
```

### Multiple explicit knowledge types (OR'd)

```
-- Document sets + persona user files
AND (
    document_sets contains "Engineering"
    OR personas contains 42
)
```

### Explicit knowledge + overflowing project files

When an explicit knowledge restriction is in effect **and** `project_id_filter` is set (project files overflowed the LLM context window), `project_id_filter` widens the filter:

```
-- Document sets + project files overflowed
AND (
    document_sets contains "Engineering"
    OR user_project contains 7
)

-- Persona user files + project files (won't happen in practice;
-- custom personas ignore project files per the precedence rule)
AND (
    personas contains 42
    OR user_project contains 7
)
```

### Only project_id_filter (no explicit knowledge)

No knowledge scope filter. The assistant searches everything.

```
-- Just ACL, no restriction
NOT hidden
AND (acl contains ...)
```

## Field reference

| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset` | Connector doc sets attached to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array` | Folder/space nodes (OpenSearch only) |
| `persona_id_filter` | `personas` | `array` | Persona tag for overflowing user files (**primary** trigger) |
| `project_id_filter` | `user_project` | `array` | Project tag for overflowing project files (**additive** only) |
| `access_control_list` | `access_control_list` | `weightedset` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array` | Document metadata tags |
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |

================================================
FILE: backend/onyx/document_index/__init__.py
================================================

================================================
FILE: backend/onyx/document_index/chunk_content_enrichment.py
================================================
from onyx.configs.app_configs import BLURB_SIZE from onyx.configs.constants import RETURN_SEPARATOR from onyx.context.search.models import InferenceChunk from onyx.context.search.models import InferenceChunkUncleaned from onyx.indexing.models import DocAwareChunk from onyx.indexing.models import DocMetadataAwareIndexChunk def generate_enriched_content_for_chunk_text(chunk: DocMetadataAwareIndexChunk) -> str: return f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_keyword}" def generate_enriched_content_for_chunk_embedding(chunk: DocAwareChunk) -> str: return
f"{chunk.title_prefix}{chunk.doc_summary}{chunk.content}{chunk.chunk_context}{chunk.metadata_suffix_semantic}" def cleanup_content_for_chunks( chunks: list[InferenceChunkUncleaned], ) -> list[InferenceChunk]: """ Removes indexing-time content additions from chunks. Inverse of the generate_enriched_content_for_chunk_* functions above. During indexing, chunks are augmented with additional text to improve search quality: - Title prepended to content (for better keyword/semantic matching) - Metadata suffix appended to content - Contextual RAG: doc_summary (beginning) and chunk_context (end) This function strips these additions before returning chunks to users, restoring the original document content. Cleaning is applied in sequence: 1. Title removal: - Full match: Strips exact title from beginning - Partial match: If content starts with title[:BLURB_SIZE], splits on RETURN_SEPARATOR to remove title section 2. Metadata suffix removal: - Strips metadata_suffix from end, plus trailing RETURN_SEPARATOR 3. Contextual RAG removal: - Strips doc_summary from beginning (if present) - Strips chunk_context from end (if present) TODO(andrei): This entire function is not that fantastic, clean it up during QA before rolling out OpenSearch. Args: chunks: Chunks as retrieved from the document index with indexing augmentations intact. Returns: Clean InferenceChunk objects with augmentations removed, containing only the original document content that should be shown to users. """ def _remove_title(chunk: InferenceChunkUncleaned) -> str: # TODO(andrei): This was ported over from # backend/onyx/document_index/vespa/vespa_document_index.py but I don't # think this logic is correct. In Vespa at least we set the title field # from the output of get_title_for_document_index, which is not # necessarily the same data that is prepended to the content; that comes # from title_prefix.
# This was added in # https://github.com/onyx-dot-app/onyx/commit/e90c66c1b61c5b7da949652d703f7c906863e6e4#diff-2a2a29d5929de75cdaea77867a397934d9f8b785ce40a861c0d704033e3663ab, # see postprocessing.py. At that time the content enrichment logic was # also added in that commit, see # https://github.com/onyx-dot-app/onyx/commit/e90c66c1b61c5b7da949652d703f7c906863e6e4#diff-d807718aa263a15c1d991a4ab063c360c8419eaad210b4ba70e1e9f47d2aa6d2R77 # chunker.py. if not chunk.title or not chunk.content: return chunk.content if chunk.content.startswith(chunk.title): return chunk.content[len(chunk.title) :].lstrip() # BLURB SIZE is by token instead of char but each token is at least 1 char # If this prefix matches the content, it's assumed the title was prepended if chunk.content.startswith(chunk.title[:BLURB_SIZE]): return ( chunk.content.split(RETURN_SEPARATOR, 1)[-1] if RETURN_SEPARATOR in chunk.content else chunk.content ) return chunk.content def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str: if not chunk.metadata_suffix: return chunk.content return chunk.content.removesuffix(chunk.metadata_suffix).rstrip( RETURN_SEPARATOR ) def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str: # remove document summary if chunk.doc_summary and chunk.content.startswith(chunk.doc_summary): chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip() # remove chunk context if chunk.chunk_context and chunk.content.endswith(chunk.chunk_context): chunk.content = chunk.content[ : len(chunk.content) - len(chunk.chunk_context) ].rstrip() return chunk.content for chunk in chunks: chunk.content = _remove_title(chunk) chunk.content = _remove_metadata_suffix(chunk) chunk.content = _remove_contextual_rag(chunk) return [chunk.to_inference_chunk() for chunk in chunks] ================================================ FILE: backend/onyx/document_index/disabled.py ================================================ """A DocumentIndex implementation that raises on every 
operation. Used as a safety net when DISABLE_VECTOR_DB is True. Any code path that accidentally reaches the vector DB layer will fail loudly instead of timing out against a nonexistent Vespa/OpenSearch instance. """ from collections.abc import Iterable from typing import Any from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunk from onyx.context.search.models import QueryExpansionType from onyx.db.enums import EmbeddingPrecision from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.interfaces import DocumentInsertionRecord from onyx.document_index.interfaces import IndexBatchParams from onyx.document_index.interfaces import VespaChunkRequest from onyx.document_index.interfaces import VespaDocumentFields from onyx.document_index.interfaces import VespaDocumentUserFields from onyx.indexing.models import DocMetadataAwareIndexChunk from shared_configs.model_server_models import Embedding VECTOR_DB_DISABLED_ERROR = "Vector DB is disabled (DISABLE_VECTOR_DB=true). This operation requires a vector database." class DisabledDocumentIndex(DocumentIndex): """A DocumentIndex where every method raises RuntimeError. Returned by the factory when DISABLE_VECTOR_DB is True so that any accidental vector-DB call surfaces immediately. 
""" def __init__( self, index_name: str = "disabled", secondary_index_name: str | None = None, *args: Any, # noqa: ARG002 **kwargs: Any, # noqa: ARG002 ) -> None: self.index_name = index_name self.secondary_index_name = secondary_index_name # ------------------------------------------------------------------ # Verifiable # ------------------------------------------------------------------ def ensure_indices_exist( self, primary_embedding_dim: int, # noqa: ARG002 primary_embedding_precision: EmbeddingPrecision, # noqa: ARG002 secondary_index_embedding_dim: int | None, # noqa: ARG002 secondary_index_embedding_precision: EmbeddingPrecision | None, # noqa: ARG002 ) -> None: # No-op: there are no indices to create when the vector DB is disabled. pass @staticmethod def register_multitenant_indices( indices: list[str], # noqa: ARG002, ARG004 embedding_dims: list[int], # noqa: ARG002, ARG004 embedding_precisions: list[EmbeddingPrecision], # noqa: ARG002, ARG004 ) -> None: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # Indexable # ------------------------------------------------------------------ def index( self, chunks: Iterable[DocMetadataAwareIndexChunk], # noqa: ARG002 index_batch_params: IndexBatchParams, # noqa: ARG002 ) -> set[DocumentInsertionRecord]: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # Deletable # ------------------------------------------------------------------ def delete_single( self, doc_id: str, # noqa: ARG002 *, tenant_id: str, # noqa: ARG002 chunk_count: int | None, # noqa: ARG002 ) -> int: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # Updatable # ------------------------------------------------------------------ def update_single( self, doc_id: str, # noqa: ARG002 *, tenant_id: str, # noqa: ARG002 chunk_count: int | None, # noqa: ARG002 
fields: VespaDocumentFields | None, # noqa: ARG002 user_fields: VespaDocumentUserFields | None, # noqa: ARG002 ) -> None: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # IdRetrievalCapable # ------------------------------------------------------------------ def id_based_retrieval( self, chunk_requests: list[VespaChunkRequest], # noqa: ARG002 filters: IndexFilters, # noqa: ARG002 batch_retrieval: bool = False, # noqa: ARG002 ) -> list[InferenceChunk]: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # HybridCapable # ------------------------------------------------------------------ def hybrid_retrieval( self, query: str, # noqa: ARG002 query_embedding: Embedding, # noqa: ARG002 final_keywords: list[str] | None, # noqa: ARG002 filters: IndexFilters, # noqa: ARG002 hybrid_alpha: float, # noqa: ARG002 time_decay_multiplier: float, # noqa: ARG002 num_to_retrieve: int, # noqa: ARG002 ranking_profile_type: QueryExpansionType, # noqa: ARG002 title_content_ratio: float | None = None, # noqa: ARG002 ) -> list[InferenceChunk]: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # AdminCapable # ------------------------------------------------------------------ def admin_retrieval( self, query: str, # noqa: ARG002 query_embedding: Embedding, # noqa: ARG002 filters: IndexFilters, # noqa: ARG002 num_to_retrieve: int = 10, # noqa: ARG002 ) -> list[InferenceChunk]: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) # ------------------------------------------------------------------ # RandomCapable # ------------------------------------------------------------------ def random_retrieval( self, filters: IndexFilters, # noqa: ARG002 num_to_retrieve: int = 10, # noqa: ARG002 ) -> list[InferenceChunk]: raise RuntimeError(VECTOR_DB_DISABLED_ERROR) ================================================ 
FILE: backend/onyx/document_index/document_index_utils.py ================================================ import math import uuid from uuid import UUID from sqlalchemy.orm import Session from onyx.configs.app_configs import ENABLE_MULTIPASS_INDEXING from onyx.db.models import SearchSettings from onyx.db.search_settings import get_current_search_settings from onyx.db.search_settings import get_secondary_search_settings from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo from onyx.indexing.models import DocMetadataAwareIndexChunk from onyx.indexing.models import MultipassConfig from shared_configs.configs import MULTI_TENANT DEFAULT_BATCH_SIZE = 30 DEFAULT_INDEX_NAME = "danswer_chunk" def should_use_multipass(search_settings: SearchSettings | None) -> bool: """ Determines whether multipass should be used based on the search settings or the default config if settings are unavailable. """ if search_settings is not None: return search_settings.multipass_indexing return ENABLE_MULTIPASS_INDEXING def get_multipass_config(search_settings: SearchSettings) -> MultipassConfig: """ Determines whether to enable multipass and large chunks by examining the current search settings and the embedder configuration. 
""" multipass = should_use_multipass(search_settings) enable_large_chunks = SearchSettings.can_use_large_chunks( multipass, search_settings.model_name, search_settings.provider_type ) return MultipassConfig( multipass_indexing=multipass, enable_large_chunks=enable_large_chunks ) def get_both_index_properties( db_session: Session, ) -> tuple[str, str | None, bool, bool | None]: search_settings = get_current_search_settings(db_session) config_1 = get_multipass_config(search_settings) search_settings_new = get_secondary_search_settings(db_session) if not search_settings_new: return search_settings.index_name, None, config_1.enable_large_chunks, None config_2 = get_multipass_config(search_settings_new) return ( search_settings.index_name, search_settings_new.index_name, config_1.enable_large_chunks, config_2.enable_large_chunks, ) def translate_boost_count_to_multiplier(boost: int) -> float: """Maps a boost integer value to a score multiplier along a piecewise sigmoid curve, such that with many downvotes the score is scaled by 0.5x and with many upvotes it is scaled by 2x. This should stay in line with the Vespa calculation.""" # The 3 in the equations below stretches the curve so it approaches its asymptotes more slowly if boost < 0: # 0.5 + sigmoid -> range of 0.5 to 1 return 0.5 + (1 / (1 + math.exp(-1 * boost / 3))) # 2 x sigmoid -> range of 1 to 2 return 2 / (1 + math.exp(-1 * boost / 3)) # Assembles a list of Vespa chunk IDs for a document # given the required context. This can be used to directly query # Vespa's Document API.
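To illustrate the piecewise sigmoid in `translate_boost_count_to_multiplier`, the sketch below reproduces its logic standalone (it is a copy for illustration, not an import of the Onyx module) and prints the multiplier for a few boost values. A boost of 0 maps to exactly 1.0, and the negative branch also approaches 1.0 as the boost rises toward 0, so the curve is continuous; large negative boosts approach 0.5 and large positive boosts approach 2.0.

```python
import math


def translate_boost_count_to_multiplier(boost: int) -> float:
    # Standalone copy of the piecewise sigmoid above, for illustration only.
    if boost < 0:
        # sigmoid(boost / 3) lies in (0, 0.5) for negative boosts,
        # so 0.5 + sigmoid gives a multiplier in (0.5, 1)
        return 0.5 + (1 / (1 + math.exp(-1 * boost / 3)))
    # sigmoid(boost / 3) lies in [0.5, 1) for non-negative boosts,
    # so 2 * sigmoid gives a multiplier in [1, 2)
    return 2 / (1 + math.exp(-1 * boost / 3))


for boost in (-30, -3, 0, 3, 30):
    print(boost, round(translate_boost_count_to_multiplier(boost), 3))
```

The divisor of 3 controls how quickly the curve saturates: a larger divisor would require more votes before the multiplier gets close to either asymptote.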
def get_document_chunk_ids( enriched_document_info_list: list[EnrichedDocumentIndexingInfo], tenant_id: str, large_chunks_enabled: bool, ) -> list[UUID]: doc_chunk_ids = [] for enriched_document_info in enriched_document_info_list: for chunk_index in range( enriched_document_info.chunk_start_index, enriched_document_info.chunk_end_index, ): if not enriched_document_info.old_version: doc_chunk_ids.append( get_uuid_from_chunk_info( document_id=enriched_document_info.doc_id, chunk_id=chunk_index, tenant_id=tenant_id, ) ) else: doc_chunk_ids.append( get_uuid_from_chunk_info_old( document_id=enriched_document_info.doc_id, chunk_id=chunk_index, ) ) if large_chunks_enabled and chunk_index % 4 == 0: large_chunk_id = int(chunk_index / 4) large_chunk_reference_ids = [ large_chunk_id + i for i in range(4) if large_chunk_id + i < enriched_document_info.chunk_end_index ] if enriched_document_info.old_version: doc_chunk_ids.append( get_uuid_from_chunk_info_old( document_id=enriched_document_info.doc_id, chunk_id=large_chunk_id, large_chunk_reference_ids=large_chunk_reference_ids, ) ) else: doc_chunk_ids.append( get_uuid_from_chunk_info( document_id=enriched_document_info.doc_id, chunk_id=large_chunk_id, tenant_id=tenant_id, large_chunk_id=large_chunk_id, ) ) return doc_chunk_ids def get_uuid_from_chunk_info( *, document_id: str, chunk_id: int, tenant_id: str, large_chunk_id: int | None = None, ) -> UUID: """NOTE: be VERY careful about changing this function.
If changed without a migration, this can cause deletion/update/insertion to function incorrectly.""" doc_str = document_id # Web parsing URL duplicate catching if doc_str and doc_str[-1] == "/": doc_str = doc_str[:-1] chunk_index = ( "large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id) ) unique_identifier_string = "_".join([doc_str, chunk_index]) if MULTI_TENANT: unique_identifier_string += "_" + tenant_id uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string) return uuid_value def get_uuid_from_chunk_info_old( *, document_id: str, chunk_id: int, large_chunk_reference_ids: list[int] = [] ) -> UUID: doc_str = document_id # Web parsing URL duplicate catching if doc_str and doc_str[-1] == "/": doc_str = doc_str[:-1] unique_identifier_string = "_".join([doc_str, str(chunk_id), "0"]) if large_chunk_reference_ids: unique_identifier_string += "_large" + "_".join( [ str(referenced_chunk_id) for referenced_chunk_id in large_chunk_reference_ids ] ) return uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string) def get_uuid_from_chunk(chunk: DocMetadataAwareIndexChunk) -> uuid.UUID: return get_uuid_from_chunk_info( document_id=chunk.source_document.id, chunk_id=chunk.chunk_id, tenant_id=chunk.tenant_id, large_chunk_id=chunk.large_chunk_id, ) def get_uuid_from_chunk_old( chunk: DocMetadataAwareIndexChunk, large_chunk_reference_ids: list[int] = [] ) -> UUID: return get_uuid_from_chunk_info_old( document_id=chunk.source_document.id, chunk_id=chunk.chunk_id, large_chunk_reference_ids=large_chunk_reference_ids, ) ================================================ FILE: backend/onyx/document_index/factory.py ================================================ import httpx from sqlalchemy.orm import Session from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX from onyx.db.models import SearchSettings from onyx.db.opensearch_migration import 
get_opensearch_retrieval_state from onyx.document_index.disabled import DisabledDocumentIndex from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.opensearch.opensearch_document_index import ( OpenSearchOldDocumentIndex, ) from onyx.document_index.vespa.index import VespaIndex from onyx.indexing.models import IndexingSetting from shared_configs.configs import MULTI_TENANT def get_default_document_index( search_settings: SearchSettings, secondary_search_settings: SearchSettings | None, db_session: Session, httpx_client: httpx.Client | None = None, ) -> DocumentIndex: """Gets the default document index from env vars. To be used for retrieval only. Indexing should be done through both indices until Vespa is deprecated. Primary index is the index that is used for querying/updating etc. Secondary index is for when both the currently used index and the upcoming index need to be updated. Updates are applied to both indices. WARNING: In that case, get_all_document_indices should be used.
""" if DISABLE_VECTOR_DB: return DisabledDocumentIndex( index_name=search_settings.index_name, secondary_index_name=( secondary_search_settings.index_name if secondary_search_settings else None ), ) secondary_index_name: str | None = None secondary_large_chunks_enabled: bool | None = None if secondary_search_settings: secondary_index_name = secondary_search_settings.index_name secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled opensearch_retrieval_enabled = get_opensearch_retrieval_state(db_session) if opensearch_retrieval_enabled: indexing_setting = IndexingSetting.from_db_model(search_settings) secondary_indexing_setting = ( IndexingSetting.from_db_model(secondary_search_settings) if secondary_search_settings else None ) return OpenSearchOldDocumentIndex( index_name=search_settings.index_name, embedding_dim=indexing_setting.final_embedding_dim, embedding_precision=indexing_setting.embedding_precision, secondary_index_name=secondary_index_name, secondary_embedding_dim=( secondary_indexing_setting.final_embedding_dim if secondary_indexing_setting else None ), secondary_embedding_precision=( secondary_indexing_setting.embedding_precision if secondary_indexing_setting else None ), large_chunks_enabled=search_settings.large_chunks_enabled, secondary_large_chunks_enabled=secondary_large_chunks_enabled, multitenant=MULTI_TENANT, httpx_client=httpx_client, ) else: return VespaIndex( index_name=search_settings.index_name, secondary_index_name=secondary_index_name, large_chunks_enabled=search_settings.large_chunks_enabled, secondary_large_chunks_enabled=secondary_large_chunks_enabled, multitenant=MULTI_TENANT, httpx_client=httpx_client, ) def get_all_document_indices( search_settings: SearchSettings, secondary_search_settings: SearchSettings | None, httpx_client: httpx.Client | None = None, ) -> list[DocumentIndex]: """Gets all document indices. NOTE: Will only return an OpenSearch index interface if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX is True. 
This is so we don't break flows where we know it won't be enabled. Used for indexing only. Until Vespa is deprecated we will index into both document indices. Retrieval is done through only one index however. Large chunks are not currently supported so we hardcode appropriate values. NOTE: Make sure the Vespa index object is returned first. In the rare event that there is some conflict between indexing and the migration task, it is assumed that the state of Vespa is more up-to-date than the state of OpenSearch. """ if DISABLE_VECTOR_DB: return [ DisabledDocumentIndex( index_name=search_settings.index_name, secondary_index_name=( secondary_search_settings.index_name if secondary_search_settings else None ), ) ] vespa_document_index = VespaIndex( index_name=search_settings.index_name, secondary_index_name=( secondary_search_settings.index_name if secondary_search_settings else None ), large_chunks_enabled=search_settings.large_chunks_enabled, secondary_large_chunks_enabled=( secondary_search_settings.large_chunks_enabled if secondary_search_settings else None ), multitenant=MULTI_TENANT, httpx_client=httpx_client, ) opensearch_document_index: OpenSearchOldDocumentIndex | None = None if ENABLE_OPENSEARCH_INDEXING_FOR_ONYX: indexing_setting = IndexingSetting.from_db_model(search_settings) secondary_indexing_setting = ( IndexingSetting.from_db_model(secondary_search_settings) if secondary_search_settings else None ) opensearch_document_index = OpenSearchOldDocumentIndex( index_name=search_settings.index_name, embedding_dim=indexing_setting.final_embedding_dim, embedding_precision=indexing_setting.embedding_precision, secondary_index_name=( secondary_search_settings.index_name if secondary_search_settings else None ), secondary_embedding_dim=( secondary_indexing_setting.final_embedding_dim if secondary_indexing_setting else None ), secondary_embedding_precision=( secondary_indexing_setting.embedding_precision if secondary_indexing_setting else None ), 
large_chunks_enabled=search_settings.large_chunks_enabled, secondary_large_chunks_enabled=( secondary_search_settings.large_chunks_enabled if secondary_search_settings else None ), multitenant=MULTI_TENANT, httpx_client=httpx_client, ) result: list[DocumentIndex] = [vespa_document_index] if opensearch_document_index: result.append(opensearch_document_index) return result ================================================ FILE: backend/onyx/document_index/interfaces.py ================================================ import abc from collections.abc import Iterable from dataclasses import dataclass from datetime import datetime from typing import Any from onyx.access.models import DocumentAccess from onyx.access.models import ExternalAccess from onyx.configs.chat_configs import NUM_RETURNED_HITS from onyx.configs.chat_configs import TITLE_CONTENT_RATIO from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunk from onyx.context.search.models import QueryExpansionType from onyx.db.enums import EmbeddingPrecision from onyx.indexing.models import DocMetadataAwareIndexChunk from shared_configs.model_server_models import Embedding @dataclass(frozen=True) class DocumentInsertionRecord: document_id: str already_existed: bool @dataclass(frozen=True) class VespaChunkRequest: document_id: str min_chunk_ind: int | None = None max_chunk_ind: int | None = None @property def is_capped(self) -> bool: # If the max chunk index is not None, then the chunk request is capped # If the min chunk index is None, we can assume the min is 0 return self.max_chunk_ind is not None @property def range(self) -> int | None: if self.max_chunk_ind is not None: return (self.max_chunk_ind - (self.min_chunk_ind or 0)) + 1 return None @dataclass class IndexBatchParams: """ Information necessary for efficiently indexing a batch of documents """ doc_id_to_previous_chunk_cnt: dict[str, int] doc_id_to_new_chunk_cnt: dict[str, int] tenant_id: str 
large_chunks_enabled: bool @dataclass class MinimalDocumentIndexingInfo: """ Minimal information necessary for indexing a document """ doc_id: str chunk_start_index: int @dataclass class EnrichedDocumentIndexingInfo(MinimalDocumentIndexingInfo): """ Enriched information necessary for indexing a document, including version and chunk range. """ old_version: bool chunk_end_index: int @dataclass class DocumentMetadata: """ Document information that needs to be inserted into Postgres the first time this document is encountered during indexing across any of the connectors. """ connector_id: int credential_id: int document_id: str semantic_identifier: str first_link: str doc_updated_at: datetime | None = None # Emails, not necessarily attached to users # Users may not be in Onyx primary_owners: list[str] | None = None secondary_owners: list[str] | None = None from_ingestion_api: bool = False external_access: ExternalAccess | None = None doc_metadata: dict[str, Any] | None = None # The resolved database ID of the parent hierarchy node (folder/container) parent_hierarchy_node_id: int | None = None @dataclass class VespaDocumentFields: """ Specifies fields in Vespa for a document. Fields set to None will be ignored. Perhaps we should name this in an implementation-agnostic fashion, but it's more understandable like this for now. """ # all other fields except these 5 will always be left alone by the update request access: DocumentAccess | None = None document_sets: set[str] | None = None boost: float | None = None hidden: bool | None = None aggregated_chunk_boost_factor: float | None = None @dataclass class VespaDocumentUserFields: """ Fields that are specific to the user who is indexing the document.
""" user_projects: list[int] | None = None personas: list[int] | None = None @dataclass class UpdateRequest: """ For all document_ids, update the allowed_users and the boost to the new values Does not update any of the None fields """ minimal_document_indexing_info: list[MinimalDocumentIndexingInfo] # all other fields except these 4 will always be left alone by the update request access: DocumentAccess | None = None document_sets: set[str] | None = None boost: float | None = None hidden: bool | None = None class Verifiable(abc.ABC): """ Class must implement document index schema verification. For example, verify that all of the necessary attributes for indexing, querying, filtering, and fields to return from search are all valid in the schema. Parameters: - index_name: The name of the primary index currently used for querying - secondary_index_name: The name of the secondary index being built in the background, if it currently exists. Some functions on the document index act on both the primary and secondary index, some act on just one. """ @abc.abstractmethod def __init__( self, index_name: str, secondary_index_name: str | None, *args: Any, **kwargs: Any, ) -> None: super().__init__(*args, **kwargs) self.index_name = index_name self.secondary_index_name = secondary_index_name @abc.abstractmethod def ensure_indices_exist( self, primary_embedding_dim: int, primary_embedding_precision: EmbeddingPrecision, secondary_index_embedding_dim: int | None, secondary_index_embedding_precision: EmbeddingPrecision | None, ) -> None: """ Verify that the document index exists and is consistent with the expectations in the code. Parameters: - primary_embedding_dim: Vector dimensionality for the vector similarity part of the search - primary_embedding_precision: Precision of the vector similarity part of the search - secondary_index_embedding_dim: Vector dimensionality of the secondary index being built behind the scenes. 
The secondary index should only be built when switching embedding models, so this dim should be different from the primary index. - secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index """ raise NotImplementedError @staticmethod @abc.abstractmethod def register_multitenant_indices( indices: list[str], embedding_dims: list[int], embedding_precisions: list[EmbeddingPrecision], ) -> None: """ Register multitenant indices with the document index. """ raise NotImplementedError class Indexable(abc.ABC): """ Class must implement the ability to index document chunks """ @abc.abstractmethod def index( self, chunks: Iterable[DocMetadataAwareIndexChunk], index_batch_params: IndexBatchParams, ) -> set[DocumentInsertionRecord]: """ Takes a list of document chunks and indexes them in the document index NOTE: When a document is reindexed/updated here, it must clear all of the existing document chunks before reindexing. This is because the document may have gotten shorter since the last run. Therefore, upserting the first 0 through n chunks may leave some old chunks that have not been written over. NOTE: The chunks of a document are never separated into separate index() calls. So there is no worry of receiving the first 0 through n chunks in one index call and the next n through m chunks of a document in the next index call. NOTE: Due to some asymmetry between the primary and secondary indexing logic, this function only needs to index chunks into the PRIMARY index. Do not update the secondary index here; that is done automatically outside of this code. Parameters: - chunks: Document chunks with all of the information needed for indexing to the document index.
- index_batch_params: Batch-level information needed for indexing, including the tenant id of the user whose chunks are being indexed, whether large chunks are enabled, and the previous/new chunk counts per document Returns: Set of DocumentInsertionRecords which map to unique documents and are used for deduping chunks when updating, each recording whether the document is newly indexed or already existed and was just updated """ raise NotImplementedError class Deletable(abc.ABC): """ Class must implement the ability to delete a document by a given unique document id. """ @abc.abstractmethod def delete_single( self, doc_id: str, *, tenant_id: str, chunk_count: int | None, ) -> int: """ Given a single document id, hard delete it from the document index Parameters: - doc_id: document id as specified by the connector """ raise NotImplementedError class Updatable(abc.ABC): """ Class must implement the ability to update certain attributes of a document without needing to update all of the fields. Specifically, needs to be able to update: - Access Control List - Document-set membership - Boost value (learning from feedback mechanism) - Whether the document is hidden or not (hidden documents are not returned from search) """ @abc.abstractmethod def update_single( self, doc_id: str, *, tenant_id: str, chunk_count: int | None, fields: VespaDocumentFields | None, user_fields: VespaDocumentUserFields | None, ) -> None: """ Updates all chunks for a document with the specified fields. None values mean that the field does not need an update. The rationale for a single update function is that it allows retries and parallelism to happen at a higher / more strategic level, is simpler to read, and allows us to individually handle error conditions per document. Parameters: - fields: the fields to update in the document. Any field set to None will not be changed. Return: None """ raise NotImplementedError class IdRetrievalCapable(abc.ABC): """ Class must implement the ability to retrieve either: - all of the chunks of a document IN ORDER given a document id.
- a specific chunk given a document id and a chunk index (0 based) """ @abc.abstractmethod def id_based_retrieval( self, chunk_requests: list[VespaChunkRequest], filters: IndexFilters, batch_retrieval: bool = False, ) -> list[InferenceChunk]: """ Fetch chunk(s) based on document id NOTE: This is used to reconstruct a full document or an extended (multi-chunk) section of a document. Downstream currently assumes that the chunking does not introduce overlaps between the chunks. If there are overlaps for the chunks, then the reconstructed document or extended section will have duplicate segments. Parameters: - chunk_requests: requests containing the document id and the chunk range to retrieve - filters: Filters to apply to retrieval - batch_retrieval: If True, perform a batch retrieval Returns: list of chunks for the document id or the specific chunk by the specified chunk index and document id """ raise NotImplementedError class HybridCapable(abc.ABC): """ Class must implement hybrid (keyword + vector) search functionality """ @abc.abstractmethod def hybrid_retrieval( self, query: str, query_embedding: Embedding, final_keywords: list[str] | None, filters: IndexFilters, hybrid_alpha: float, time_decay_multiplier: float, num_to_retrieve: int, ranking_profile_type: QueryExpansionType, title_content_ratio: float | None = TITLE_CONTENT_RATIO, ) -> list[InferenceChunk]: """ Run hybrid search and return a list of inference chunks. NOTE: the query passed in here is the unprocessed plain text query. Preprocessing is expected to be handled by this function as it may depend on the index implementation. Things like query expansion, synonym injection, stop word removal, lemmatization, etc. are done here. Parameters: - query: unmodified user query. 
This is needed for getting the matching highlighted keywords - query_embedding: vector representation of the query, must be of the correct dimensionality for the primary index - final_keywords: Final keywords to be used from the query, defaults to query if not set - filters: standard filter object - hybrid_alpha: weighting between the keyword and vector search results. It is important that the two scores are normalized to the same range so that a meaningful comparison can be made. 1 for 100% weighting on vector score, 0 for 100% weighting on keyword score. - time_decay_multiplier: how much to decay the document scores as they age. Some queries, based on the persona settings, will have this set to 2x or 3x the default - num_to_retrieve: number of highest matching chunks to return Returns: best matching chunks based on weighted sum of keyword and vector/semantic search scores """ raise NotImplementedError class AdminCapable(abc.ABC): """ Class must implement a search for the admin "Explorer" page. The assumption here is that the admin is not "searching" for knowledge but has some document already in mind. They are either looking to positively boost it because they know it's a good reference document, looking to negatively boost it as a way of "deprecating", or hiding the document. Assuming the admin knows the document name, this search has high emphasis on the title match. Suggested implementation: Keyword-only BM25 search with 5x weighting on the title field compared to the contents """ @abc.abstractmethod def admin_retrieval( self, query: str, query_embedding: Embedding, filters: IndexFilters, num_to_retrieve: int = NUM_RETURNED_HITS, ) -> list[InferenceChunk]: """ Run the special search for the admin document explorer page Parameters: - query: unmodified user query.
Though in this flow probably unmodified is best - filters: standard filter object - num_to_retrieve: number of highest matching chunks to return Returns: list of best matching chunks for the explorer page query """ raise NotImplementedError class RandomCapable(abc.ABC): """Class must implement random document retrieval capability""" @abc.abstractmethod def random_retrieval( self, filters: IndexFilters, num_to_retrieve: int = 10, ) -> list[InferenceChunk]: """Retrieve random chunks matching the filters""" raise NotImplementedError class BaseIndex( Verifiable, Indexable, Updatable, Deletable, AdminCapable, IdRetrievalCapable, RandomCapable, abc.ABC, ): """ All basic document index functionalities excluding the actual querying approach. As a summary, document indices need to be able to - Verify the schema definition is valid - Index new documents - Update specific attributes of existing documents - Delete documents - Provide a search for the admin document explorer page - Retrieve documents based on document id """ class DocumentIndex(HybridCapable, BaseIndex, abc.ABC): """ A valid document index that can plug into all Onyx flows must implement all of these functionalities, though "technically" it does not need to be keyword or vector capable as currently all default search flows use Hybrid Search. 
""" ================================================ FILE: backend/onyx/document_index/interfaces_new.py ================================================ import abc from collections.abc import Iterable from typing import Self from pydantic import BaseModel from pydantic import model_validator from onyx.access.models import DocumentAccess from onyx.configs.constants import PUBLIC_DOC_PAT from onyx.context.search.enums import QueryType from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunk from onyx.db.enums import EmbeddingPrecision from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE from onyx.indexing.models import DocMetadataAwareIndexChunk from shared_configs.model_server_models import Embedding # NOTE: "Document" in the naming convention is used to refer to the entire # document as represented in Onyx. What is actually stored in the index is the # document chunks. By the terminology of most search engines / vector databases, # the individual objects stored are called documents, but in this case it refers # to a chunk. __all__ = [ # Main interfaces - these are what you should inherit from "DocumentIndex", # Data models - used in method signatures "DocumentInsertionRecord", "DocumentSectionRequest", "IndexingMetadata", "MetadataUpdateRequest", # Capability mixins - for custom compositions or type checking "SchemaVerifiable", "Indexable", "Deletable", "Updatable", "IdRetrievalCapable", "HybridCapable", "RandomCapable", ] class TenantState(BaseModel): """ Captures the tenant-related state for an instance of DocumentIndex. NOTE: Tenant ID must be set in multitenant mode. 
""" model_config = {"frozen": True} tenant_id: str multitenant: bool def __str__(self) -> str: return ( f"TenantState(tenant_id={self.tenant_id}, multitenant={self.multitenant})" ) @model_validator(mode="after") def check_tenant_id_is_set_in_multitenant_mode(self) -> Self: if self.multitenant and not self.tenant_id: raise ValueError("Bug: Tenant ID must be set in multitenant mode.") return self class DocumentInsertionRecord(BaseModel): """ Result of indexing a document. """ model_config = {"frozen": True} document_id: str already_existed: bool class DocumentSectionRequest(BaseModel): """Request for a document section or whole document. If no min_chunk_ind is provided it should start at the beginning of the document. If no max_chunk_ind is provided it should go to the end of the document. """ model_config = {"frozen": True} document_id: str min_chunk_ind: int | None = None max_chunk_ind: int | None = None # A given document can have multiple chunking strategies. max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE @model_validator(mode="after") def check_chunk_index_range_is_valid(self) -> Self: if ( self.min_chunk_ind is not None and self.max_chunk_ind is not None and self.min_chunk_ind > self.max_chunk_ind ): raise ValueError( "Bug: Min chunk index must be less than or equal to max chunk index." ) return self class IndexingMetadata(BaseModel): """ Information about chunk counts for efficient cleaning / updating of document chunks. A common pattern to ensure that no chunks are left over is to delete all of the chunks for a document and then re-index the document. This information allows us to only delete the extra "tail" chunks when the document has gotten shorter. 
""" class ChunkCounts(BaseModel): model_config = {"frozen": True} old_chunk_cnt: int new_chunk_cnt: int model_config = {"frozen": True} doc_id_to_chunk_cnt_diff: dict[str, ChunkCounts] class MetadataUpdateRequest(BaseModel): """ Updates to the documents that can happen without there being an update to the contents of the document. """ model_config = {"frozen": True} document_ids: list[str] # Passed in to help with potential optimizations of the implementation. The # keys should be redundant with document_ids. # NOTE: Generally the chunk count should always be known, however for # documents still using the legacy chunk ID system it may not be. Any chunk # count value < 0 should represent an unknown chunk count. doc_id_to_chunk_cnt: dict[str, int] # For the ones that are None, there is no update required to that field. access: DocumentAccess | None = None document_sets: set[str] | None = None boost: float | None = None hidden: bool | None = None secondary_index_updated: bool | None = None project_ids: set[int] | None = None persona_ids: set[int] | None = None class IndexRetrievalFilters(BaseModel): """ Filters for retrieving chunks from the index. Used to filter on permissions and other Onyx-specific metadata rather than chunk content. Should be passed in for every retrieval method. TODO(andrei): Currently unused, use this when making retrieval methods more strict. """ model_config = {"frozen": True} # frozenset gets around the issue of python's mutable defaults. # WARNING: Falls back to only public docs as default for security. If # callers want no access filtering they must explicitly supply an empty set. # Doing so should be done sparingly. access_control_list: frozenset[str] = frozenset({PUBLIC_DOC_PAT}) class SchemaVerifiable(abc.ABC): """ Class must implement document index schema verification. For example, verify that all of the necessary attributes for indexing, querying, filtering, and fields to return from search are all valid in the schema. 
""" @abc.abstractmethod def verify_and_create_index_if_necessary( self, embedding_dim: int, embedding_precision: EmbeddingPrecision, ) -> None: """ Verifies that the document index exists and is consistent with the expectations in the code. For certain search engines, the schema needs to be created before indexing can happen. This call should create the schema if it does not exist. Args: embedding_dim: Vector dimensionality for the vector similarity part of the search. embedding_precision: Precision of the values of the vectors for the similarity part of the search. """ raise NotImplementedError class Indexable(abc.ABC): """ Class must implement the ability to index document chunks. """ @abc.abstractmethod def index( self, chunks: Iterable[DocMetadataAwareIndexChunk], indexing_metadata: IndexingMetadata, ) -> list[DocumentInsertionRecord]: """Indexes an iterable of document chunks into the document index. This is often a batch operation including chunks from multiple documents. NOTE: When a document is reindexed/updated here and has gotten shorter, it is important to delete the extra chunks at the end to ensure there are no stale chunks in the index. The implementation should do this. NOTE: The chunks of a document are never separated into separate index() calls. So there is no worry of receiving the first 0 through n chunks in one index call and the next n through m chunks of a document in the next index call. Args: chunks: Document chunks with all of the information needed for indexing to the document index. indexing_metadata: Information about chunk counts for efficient cleaning / updating. Returns: List of document IDs which map to unique documents as well as if the document is newly indexed or had already existed and was just updated. """ raise NotImplementedError class Deletable(abc.ABC): """ Class must implement the ability to delete a document by a given unique document ID. 
""" @abc.abstractmethod def delete( self, # TODO(andrei): Fine for now but this can probably be a batch operation # that takes in a list of IDs. document_id: str, chunk_count: int | None = None, # TODO(andrei): Shouldn't this also have some acl filtering at minimum? ) -> int: """ Hard deletes all of the chunks for the corresponding document in the document index. TODO(andrei): Not a pressing issue now but think about what we want the contract of this method to be in the event the specified document ID does not exist. Args: document_id: The unique identifier for the document as represented in Onyx, not necessarily in the document index. chunk_count: The number of chunks in the document. May be useful for improving the efficiency of the delete operation. Defaults to None. Returns: The number of chunks deleted. """ raise NotImplementedError class Updatable(abc.ABC): """ Class must implement the ability to update certain attributes of a document without needing to update all of the fields. Specifically, needs to be able to update: - Access Control List - Document-set membership - Boost value (learning from feedback mechanism) - Whether the document is hidden or not; hidden documents are not returned from search - Which Projects the document is a part of """ @abc.abstractmethod def update( self, update_requests: list[MetadataUpdateRequest], ) -> None: """Updates some set of chunks. The document and fields to update are specified in the update requests. Each update request in the list applies its changes to a list of document IDs. None values mean that the field does not need an update. Args: update_requests: A list of update requests, each containing a list of document IDs and the fields to update. The field updates apply to all of the specified documents in each update request. """ raise NotImplementedError class IdRetrievalCapable(abc.ABC): """ Class must implement the ability to retrieve either: - All of the chunks of a document IN ORDER given a document ID. 
- A specific section (continuous set of chunks) for some document. """ @abc.abstractmethod def id_based_retrieval( self, chunk_requests: list[DocumentSectionRequest], # TODO(andrei): Make this more strict w.r.t. acl, temporary for now. filters: IndexFilters, # TODO(andrei): This is temporary, we will not expose this in the long # run. batch_retrieval: bool = False, # TODO(andrei): Add a param for whether to retrieve hidden docs. ) -> list[InferenceChunk]: """Fetches chunk(s) based on document ID. NOTE: This is used to reconstruct a full document or an extended (multi-chunk) section of a document. Downstream currently assumes that the chunking does not introduce overlaps between the chunks. If there are overlaps for the chunks, then the reconstructed document or extended section will have duplicate segments. Args: chunk_requests: Requests containing the document ID and the chunk range to retrieve. Returns: List of sections from the documents specified. """ raise NotImplementedError class HybridCapable(abc.ABC): """ Class must implement hybrid (keyword + vector) search functionality. """ @abc.abstractmethod def hybrid_retrieval( self, query: str, query_embedding: Embedding, # TODO(andrei): This param is not great design, get rid of it. final_keywords: list[str] | None, query_type: QueryType, # TODO(andrei): Make this more strict w.r.t. acl, temporary for now. filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: """Runs hybrid search and returns a list of inference chunks. Args: query: Unmodified user query. This may be needed for getting the matching highlighted keywords or for logging purposes. query_embedding: Vector representation of the query. Must be of the correct dimensionality for the primary index. final_keywords: Final keywords to be used from the query; defaults to query if not set. query_type: Semantic or keyword type query; may use different scoring logic for each. filters: Filters for things like permissions, source type, time, etc. 
num_to_retrieve: Number of highest matching chunks to return. Returns: Score-ranked (highest first) list of highest matching chunks. """ raise NotImplementedError @abc.abstractmethod def keyword_retrieval( self, query: str, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: """Runs keyword-only search and returns a list of inference chunks. Args: query: User query. filters: Filters for things like permissions, source type, time, etc. num_to_retrieve: Number of highest matching chunks to return. Returns: Score-ranked (highest first) list of highest matching chunks. """ raise NotImplementedError @abc.abstractmethod def semantic_retrieval( self, query_embedding: Embedding, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: """Runs semantic-only search and returns a list of inference chunks. Args: query_embedding: Vector representation of the query. Must be of the correct dimensionality for the primary index. filters: Filters for things like permissions, source type, time, etc. num_to_retrieve: Number of highest matching chunks to return. Returns: Score-ranked (highest first) list of highest matching chunks. """ raise NotImplementedError class RandomCapable(abc.ABC): """ Class must implement random document retrieval. """ @abc.abstractmethod def random_retrieval( self, # TODO(andrei): Make this more strict w.r.t. acl, temporary for now. filters: IndexFilters, num_to_retrieve: int = 10, dirty: bool | None = None, ) -> list[InferenceChunk]: """Retrieves random chunks matching the filters. Args: filters: Filters for things like permissions, source type, time, etc. num_to_retrieve: Number of chunks to retrieve. Defaults to 10. dirty: If set, retrieve chunks whose "dirty" flag matches this argument. If None, there is no restriction on retrieved chunks with respect to that flag. A chunk is considered dirty if there is a secondary index but the chunk's state has not been ported over to it yet. Defaults to None. 
Returns: List of chunks matching the filters. """ raise NotImplementedError class DocumentIndex( SchemaVerifiable, Indexable, Updatable, Deletable, HybridCapable, IdRetrievalCapable, RandomCapable, abc.ABC, ): """ A valid document index that can plug into all Onyx flows must implement all of these functionalities. As a high-level summary, document indices need to be able to: - Verify the schema definition is valid - Index new documents - Update specific attributes of existing documents - Delete documents - Run hybrid search - Retrieve documents or sections of documents based on document ID - Retrieve sets of random documents """ ================================================ FILE: backend/onyx/document_index/opensearch/README.md ================================================

# OpenSearch Idiosyncrasies

## How it works at a high level

OpenSearch has 2 phases: a `Search` phase and a `Fetch` phase. The `Search` phase computes the document scores on each shard separately, and a `Fetch` phase then typically grabs all of the relevant fields/data to return to the user. There is also an intermediate phase (seemingly built specifically to handle hybrid search queries) that can run between the two as a processor.

References:
- https://docs.opensearch.org/latest/search-plugins/search-pipelines/search-processors/
- https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/
- https://docs.opensearch.org/latest/query-dsl/compound/hybrid/

## How Hybrid queries work

Hybrid queries are essentially parallel queries that each run through their own `Search` phase and do not interact in any way. They also run across all the shards. It is not entirely clear what happens if no combination pipeline is specified for them; perhaps the scores are simply summed.

When the normalization processor is applied to keyword/vector hybrid searches, documents that show up due to a keyword match may not have shown up in the vector search, and vice versa.
In these situations, the document simply receives a 0 score for the missing query component; OpenSearch does not run another phase to recapture the missing values. As a result, after normalization the missing score is 0, which is effectively a higher score than the document would have received had it been given a real, non-zero score. This may not be immediately obvious, so here is the reasoning. If the document had received a real score instead, that score would have to be lower than every other score in the list (otherwise the document would have shown up in that list). It would therefore widen the normalization range and push the other scores higher, so the document would still be the lowest, but now a differentiated lowest score. This does not hold strictly in a multi-node setup, but the high-level concept approximately holds. In short, the 0 score is a form of "minimum value clipping".

## On time decay and boosting

Embedding model similarity scores are not uniformly distributed from 0 to 1. They typically cluster strongly around 0.6 to 0.8, and the range also varies between models and even between queries. It is not safe to assume the raw scores are already normalized, so we also cannot apply any additive or multiplicative boost to them. For example, if the results for a query cluster around 0.6 to 0.8 and a 50% penalty is applied to a score, it does not move a result from the top of the range to the 50th percentile; it pushes it below 0.6, making it the worst match. The same logic applies to additive boosting, so these boosts can only be applied after normalization.

Unfortunately, in OpenSearch the normalization processor runs last and only applies to the results of the completely independent `Search` phase queries. So if a time-based boost (a separate query that filters on recently updated documents) were added, it would not be able to introduce any new documents to the set: a document retrieved only by the time filter would have 0 keyword and vector scores and would therefore score very low, while any other matching document would already be present.
Such a boost could, however, cause some of the lower-scored documents in the union of all the `Search` phase results to rank higher and potentially avoid being dropped before they are fetched and returned to the user. But including it has other issues:

- The boost can only filter on this field, not sort by it, so there is no way to guarantee the best documents even irrespective of their contents. If there are many updates, this may miss documents.
- There is no good way to normalize this field; the best option is to clip it at the bottom.
- It would require min-max normalization, but z-score normalization is better for the other functions: it is less sensitive to outliers, handles distribution drift better (min-max assumes stable, meaningful ranges), and is better for comparing "unusualness" across distributions.

So while it is possible to apply time-based boosting at the normalization stage (or specifically to the keyword score), we have decided not to apply it during the OpenSearch query. Because of these limitations, Onyx applies further refinements, boosts, etc. in code, with OpenSearch providing the initial filtering. The impact of time decay and boosting should not be so large that we would need orders of magnitude more results back from OpenSearch.

## Other concepts to be aware of

Within the `Search` phase there are optional steps like Rescore, but these are not useful for the combination/normalization work relevant to hybrid search. Since Rescore happens before normalization, it cannot contribute any meaningful operations to the query for our usage.

Because the Title is included in the Contents for both embedding and keyword searches, the Title scores are very low relative to the full contents scoring, so the Title acts as a boost rather than a core scoring component. Time decay works similarly.
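
The "minimum value clipping" effect described under "How Hybrid queries work" can be reproduced with a small, simplified emulation. This is a sketch under stated assumptions, not OpenSearch's implementation: it assumes plain min-max normalization per sub-query and an unweighted arithmetic mean, and the document IDs and scores are made up. OpenSearch's own min-max technique may differ in edge-case details.

```python
# Simplified emulation of per-sub-query min-max normalization followed by an
# arithmetic-mean combination. Doc IDs and scores are hypothetical; this is
# not OpenSearch's or Onyx's actual code.


def min_max(scores: dict[str, float]) -> dict[str, float]:
    """Min-max normalize doc ID -> raw score into the range [0, 1]."""
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        # Degenerate case: every hit has the same score.
        return {doc: 1.0 for doc in scores}
    return {doc: (s - lo) / (hi - lo) for doc, s in scores.items()}


def combine(keyword: dict[str, float], vector: dict[str, float]) -> dict[str, float]:
    """Average the normalized scores; a doc missing from a sub-query gets 0."""
    kw, vec = min_max(keyword), min_max(vector)
    return {
        doc: (kw.get(doc, 0.0) + vec.get(doc, 0.0)) / 2
        for doc in kw.keys() | vec.keys()
    }


keyword_hits = {"A": 12.0, "B": 8.0, "C": 4.0}
# "C" never made it into the vector sub-query's results, so its vector
# component is filled with 0, the same value as the vector-side minimum "D".
clipped = combine(keyword_hits, {"A": 0.82, "B": 0.78, "D": 0.74})

# If "C" had been scored by the vector sub-query, its score would have had to
# be below every returned vector score (here 0.70), widening the range.
unclipped = combine(keyword_hits, {"A": 0.82, "B": 0.78, "D": 0.74, "C": 0.70})

print(clipped["C"], clipped["D"])      # C and D tie at the floor
print(unclipped["C"], unclipped["D"])  # D is now lifted above C
```

In the unclipped case the extra low score stretches the vector range, so D rises above the floor while C stays at 0; in the clipped case the missing score simply pins C to the observed minimum.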
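
The argument in "On time decay and boosting" (a multiplicative penalty applied to raw, clustered similarity scores versus the same penalty applied after normalization) can be checked with a few hypothetical numbers. This sketch assumes scores clustered between 0.62 and 0.80 and min-max normalization; it is illustrative only and does not reflect how Onyx actually weights its boosts.

```python
# Hypothetical cosine similarities for one query, clustered in a narrow band.
raw = [0.80, 0.75, 0.70, 0.65, 0.62]


def min_max(scores: list[float]) -> list[float]:
    """Min-max normalize a list of scores into [0, 1]."""
    lo, hi = min(scores), max(scores)
    return [(s - lo) / (hi - lo) for s in scores]


def rank_of(scores: list[float], i: int) -> int:
    """1-based rank of scores[i] (1 = best) among all scores."""
    return 1 + sum(1 for s in scores if s > scores[i])


PENALTY = 0.5  # hypothetical 50% penalty applied to the top result (index 0)

# Penalty on the raw scale: 0.80 -> 0.40, below the whole 0.62-0.80 cluster.
raw_penalized = [raw[0] * PENALTY] + raw[1:]

# Same penalty after normalization: 1.0 -> 0.5, which keeps it near the top.
normalized = min_max(raw)
norm_penalized = [normalized[0] * PENALTY] + normalized[1:]

print("rank after raw penalty:", rank_of(raw_penalized, 0))
print("rank after normalized penalty:", rank_of(norm_penalized, 0))
```

The raw-scale penalty turns the best match into the worst one, whereas after normalization the same penalty only demotes it by one position, which is why boosts are only applied after normalization.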
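
The normalization processor referenced at the top of this README is configured through a search pipeline, which `OpenSearchClient.create_search_pipeline` creates from a `pipeline_body` dict. A minimal sketch of such a body, following the request shape shown in the OpenSearch normalization-processor documentation, might look like the following. The helper name, the default `min_max`/`arithmetic_mean` techniques, and the weight check are assumptions for illustration, not Onyx's actual pipeline; consult the documentation for the techniques your OpenSearch version supports (the discussion above favors z-score normalization where available).

```python
from typing import Any


def build_hybrid_pipeline_body(
    keyword_weight: float,
    vector_weight: float,
    normalization_technique: str = "min_max",
    combination_technique: str = "arithmetic_mean",
) -> dict[str, Any]:
    """Build a search pipeline body containing a normalization processor.

    Hypothetical helper. The combination weights are applied positionally to
    the sub-queries of the hybrid query, so this assumes the keyword
    sub-query comes first and the vector sub-query second.
    """
    # Precaution in this sketch: keep the combination weights normalized.
    if abs(keyword_weight + vector_weight - 1.0) > 1e-9:
        raise ValueError("Combination weights should sum to 1.0.")
    return {
        "description": "Hypothetical keyword + vector hybrid search pipeline",
        "phase_results_processors": [
            {
                "normalization-processor": {
                    "normalization": {"technique": normalization_technique},
                    "combination": {
                        "technique": combination_technique,
                        "parameters": {"weights": [keyword_weight, vector_weight]},
                    },
                }
            }
        ],
    }


body = build_hybrid_pipeline_body(keyword_weight=0.3, vector_weight=0.7)
```

Registering it would then be along the lines of `client.create_search_pipeline(pipeline_id="hypothetical_hybrid_pipeline", pipeline_body=body)`, after which searches can reference the pipeline by ID, for example via the `search_pipeline` request parameter.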
================================================ FILE: backend/onyx/document_index/opensearch/client.py ================================================ import json import logging import time from contextlib import AbstractContextManager from contextlib import nullcontext from typing import Any from typing import Generic from typing import TypeVar from opensearchpy import OpenSearch from opensearchpy import TransportError from opensearchpy.helpers import bulk from pydantic import BaseModel from onyx.configs.app_configs import DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S from onyx.configs.app_configs import OPENSEARCH_ADMIN_PASSWORD from onyx.configs.app_configs import OPENSEARCH_ADMIN_USERNAME from onyx.configs.app_configs import OPENSEARCH_HOST from onyx.configs.app_configs import OPENSEARCH_REST_API_PORT from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.constants import OpenSearchSearchType from onyx.document_index.opensearch.schema import DocumentChunk from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id from onyx.document_index.opensearch.search import DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW from onyx.server.metrics.opensearch_search import observe_opensearch_search from onyx.server.metrics.opensearch_search import track_opensearch_search_in_progress from onyx.utils.logger import setup_logger from onyx.utils.timing import log_function_time CLIENT_THRESHOLD_TO_LOG_SLOW_SEARCH_MS = 2000 logger = setup_logger(__name__) # Set the logging level to WARNING to ignore INFO and DEBUG logs from # opensearch. By default it emits INFO-level logs for every request. 
# The opensearch-py library uses "opensearch" as the logger name for HTTP # requests (see opensearchpy/connection/base.py) opensearch_logger = logging.getLogger("opensearch") opensearch_logger.setLevel(logging.WARNING) SchemaDocumentModel = TypeVar("SchemaDocumentModel") class SearchHit(BaseModel, Generic[SchemaDocumentModel]): """Represents a hit from OpenSearch in response to a query. Templated on the specific document model as defined by a schema. """ model_config = {"frozen": True} # The document chunk source retrieved from OpenSearch. document_chunk: SchemaDocumentModel # The match score for the document chunk as calculated by OpenSearch. Only # relevant for "fuzzy searches"; this will be None for direct queries where # score is not relevant like direct retrieval on ID. score: float | None = None # Maps schema property name to a list of highlighted snippets with match # terms wrapped in highlight tags (e.g. in "something keyword other thing", # "keyword" would be wrapped in tags). match_highlights: dict[str, list[str]] = {} # Score explanation from OpenSearch when "explain": true is set in the # query. Contains detailed breakdown of how the score was calculated. explanation: dict[str, Any] | None = None class IndexInfo(BaseModel): """ Represents information about an OpenSearch index. """ model_config = {"frozen": True} name: str health: str status: str num_primary_shards: str num_replica_shards: str docs_count: str docs_deleted: str created_at: str total_size: str primary_shards_size: str def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]: """Recursively replaces vectors in the body with their length. TODO(andrei): Do better. Args: body: The body in which to replace the vectors. Returns: A copy of body with vectors replaced with their length.
""" new_body: dict[str, Any] = {} for k, v in body.items(): if k == "vector": new_body[k] = len(v) elif isinstance(v, dict): new_body[k] = get_new_body_without_vectors(v) elif isinstance(v, list) and len(v) > 0 and isinstance(v[0], dict): new_body[k] = [get_new_body_without_vectors(item) for item in v] else: new_body[k] = v return new_body class OpenSearchClient(AbstractContextManager): """Client for interacting with OpenSearch for cluster-level operations. Args: host: The host of the OpenSearch cluster. port: The port of the OpenSearch cluster. auth: The authentication credentials for the OpenSearch cluster. A tuple of (username, password). use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to True. verify_certs: Whether to verify the SSL certificates for the OpenSearch cluster. Defaults to False. ssl_show_warn: Whether to show warnings for SSL certificates. Defaults to False. timeout: The timeout for the OpenSearch cluster. Defaults to DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S. """ def __init__( self, host: str = OPENSEARCH_HOST, port: int = OPENSEARCH_REST_API_PORT, auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD), use_ssl: bool = True, verify_certs: bool = False, ssl_show_warn: bool = False, timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S, ): logger.debug( f"Creating OpenSearch client with host {host}, port {port} and timeout {timeout} seconds." ) self._client = OpenSearch( hosts=[{"host": host, "port": port}], http_auth=auth, use_ssl=use_ssl, verify_certs=verify_certs, ssl_show_warn=ssl_show_warn, # NOTE: This timeout applies to all requests the client makes, # including bulk indexing. When exceeded, the client will raise a # ConnectionTimeout and return no useful results. The OpenSearch # server will log that the client cancelled the request. To get # partial results from OpenSearch, pass in a timeout parameter to # your request body that is less than this value. 
timeout=timeout, ) def __exit__(self, *_: Any) -> None: self.close() def __del__(self) -> None: try: self.close() except Exception: pass @log_function_time(print_only=True, debug_only=True, include_args=True) def create_search_pipeline( self, pipeline_id: str, pipeline_body: dict[str, Any], ) -> None: """Creates a search pipeline. See the OpenSearch documentation for more information on the search pipeline body. https://docs.opensearch.org/latest/search-plugins/search-pipelines/index/ Args: pipeline_id: The ID of the search pipeline to create. pipeline_body: The body of the search pipeline to create. Raises: Exception: There was an error creating the search pipeline. """ response = self._client.search_pipeline.put(id=pipeline_id, body=pipeline_body) if not response.get("acknowledged", False): raise RuntimeError(f"Failed to create search pipeline {pipeline_id}.") @log_function_time(print_only=True, debug_only=True, include_args=True) def delete_search_pipeline(self, pipeline_id: str) -> None: """Deletes a search pipeline. Args: pipeline_id: The ID of the search pipeline to delete. Raises: Exception: There was an error deleting the search pipeline. """ response = self._client.search_pipeline.delete(id=pipeline_id) if not response.get("acknowledged", False): raise RuntimeError(f"Failed to delete search pipeline {pipeline_id}.") @log_function_time(print_only=True, debug_only=True, include_args=True) def put_cluster_settings(self, settings: dict[str, Any]) -> bool: """Puts cluster settings. Args: settings: The settings to put. Raises: Exception: There was an error putting the cluster settings. Returns: True if the settings were put successfully, False otherwise. 
""" response = self._client.cluster.put_settings(body=settings) if response.get("acknowledged", False): logger.info("Successfully put cluster settings.") return True else: logger.error(f"Failed to put cluster settings: {response}.") return False @log_function_time(print_only=True, debug_only=True) def list_indices_with_info(self) -> list[IndexInfo]: """ Lists the indices in the OpenSearch cluster with information about each index. Returns: A list of IndexInfo objects for each index. """ response = self._client.cat.indices(format="json") indices: list[IndexInfo] = [] for raw_index_info in response: indices.append( IndexInfo( name=raw_index_info.get("index", ""), health=raw_index_info.get("health", ""), status=raw_index_info.get("status", ""), num_primary_shards=raw_index_info.get("pri", ""), num_replica_shards=raw_index_info.get("rep", ""), docs_count=raw_index_info.get("docs.count", ""), docs_deleted=raw_index_info.get("docs.deleted", ""), created_at=raw_index_info.get("creation.date.string", ""), total_size=raw_index_info.get("store.size", ""), primary_shards_size=raw_index_info.get("pri.store.size", ""), ) ) return indices @log_function_time(print_only=True, debug_only=True) def ping(self) -> bool: """Pings the OpenSearch cluster. Returns: True if OpenSearch could be reached, False if it could not. """ return self._client.ping() def close(self) -> None: """Closes the client. Raises: Exception: There was an error closing the client. """ self._client.close() class OpenSearchIndexClient(OpenSearchClient): """Client for interacting with OpenSearch for index-level operations. OpenSearch's Python module has pretty bad typing support so this client attempts to protect the rest of the codebase from this. As a consequence, most methods here return the minimum data needed for the rest of Onyx, and tend to rely on Exceptions to handle errors. TODO(andrei): This class currently assumes the structure of the database schema when it returns a DocumentChunk. 
Make the class, or at least the search method, templated on the structure the caller can expect. Args: index_name: The name of the index to interact with. host: The host of the OpenSearch cluster. port: The port of the OpenSearch cluster. auth: The authentication credentials for the OpenSearch cluster. A tuple of (username, password). use_ssl: Whether to use SSL for the OpenSearch cluster. Defaults to True. verify_certs: Whether to verify the SSL certificates for the OpenSearch cluster. Defaults to False. ssl_show_warn: Whether to show warnings for SSL certificates. Defaults to False. timeout: The timeout for the OpenSearch cluster. Defaults to DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S. """ def __init__( self, index_name: str, host: str = OPENSEARCH_HOST, port: int = OPENSEARCH_REST_API_PORT, auth: tuple[str, str] = (OPENSEARCH_ADMIN_USERNAME, OPENSEARCH_ADMIN_PASSWORD), use_ssl: bool = True, verify_certs: bool = False, ssl_show_warn: bool = False, timeout: int = DEFAULT_OPENSEARCH_CLIENT_TIMEOUT_S, emit_metrics: bool = True, ): super().__init__( host=host, port=port, auth=auth, use_ssl=use_ssl, verify_certs=verify_certs, ssl_show_warn=ssl_show_warn, timeout=timeout, ) self._index_name = index_name self._emit_metrics = emit_metrics logger.debug( f"OpenSearch client created successfully for index {self._index_name}." ) @log_function_time(print_only=True, debug_only=True, include_args=True) def create_index(self, mappings: dict[str, Any], settings: dict[str, Any]) -> None: """Creates the index. See the OpenSearch documentation for more information on mappings and settings. Args: mappings: The mappings for the index to create. settings: The settings for the index to create. Raises: Exception: There was an error creating the index. 
""" body: dict[str, Any] = { "mappings": mappings, "settings": settings, } logger.debug(f"Creating index {self._index_name} with body {body}.") response = self._client.indices.create(index=self._index_name, body=body) if not response.get("acknowledged", False): raise RuntimeError(f"Failed to create index {self._index_name}.") response_index = response.get("index", "") if response_index != self._index_name: raise RuntimeError( f"OpenSearch responded with index name {response_index} when creating index {self._index_name}." ) logger.debug(f"Index {self._index_name} created successfully.") @log_function_time(print_only=True, debug_only=True) def delete_index(self) -> bool: """Deletes the index. Raises: Exception: There was an error deleting the index. Returns: True if the index was deleted, False if it did not exist. """ if not self._client.indices.exists(index=self._index_name): logger.warning( f"Tried to delete index {self._index_name} but it does not exist." ) return False logger.debug(f"Deleting index {self._index_name}.") response = self._client.indices.delete(index=self._index_name) if not response.get("acknowledged", False): raise RuntimeError(f"Failed to delete index {self._index_name}.") return True @log_function_time(print_only=True, debug_only=True) def index_exists(self) -> bool: """Checks if the index exists. Raises: Exception: There was an error checking if the index exists. Returns: True if the index exists, False if it does not. """ return self._client.indices.exists(index=self._index_name) @log_function_time(print_only=True, debug_only=True, include_args=True) def put_mapping(self, mappings: dict[str, Any]) -> None: """Updates the index mapping in an idempotent manner. - Existing fields with the same definition: No-op (succeeds silently). - New fields: Added to the index. - Existing fields with different types: Raises exception (requires reindex). 
See the OpenSearch documentation for more information: https://docs.opensearch.org/latest/api-reference/index-apis/put-mapping/ Args: mappings: The complete mapping definition to apply. This will be merged with existing mappings in the index. Raises: Exception: There was an error updating the mappings, such as attempting to change the type of an existing field. """ logger.debug( f"Putting mappings for index {self._index_name} with mappings {mappings}." ) response = self._client.indices.put_mapping( index=self._index_name, body=mappings ) if not response.get("acknowledged", False): raise RuntimeError( f"Failed to put the mapping update for index {self._index_name}." ) logger.debug(f"Successfully put mappings for index {self._index_name}.") @log_function_time(print_only=True, debug_only=True, include_args=True) def validate_index(self, expected_mappings: dict[str, Any]) -> bool: """Validates the index. Short-circuits, returning False on the first mismatch and logging it. See the OpenSearch documentation for more information on the index mappings. https://docs.opensearch.org/latest/mappings/ Args: expected_mappings: The expected mappings of the index to validate. Raises: Exception: There was an error validating the index. Returns: True if the index is valid according to the supplied mappings, False otherwise. """ # OpenSearch's documentation makes no mention of what happens when you # invoke client.indices.get on an index that does not exist, so we check # for existence explicitly just to be sure. exists_response = self.index_exists() if not exists_response: logger.warning( f"Tried to validate index {self._index_name} but it does not exist." ) return False logger.debug( f"Validating index {self._index_name} with expected mappings {expected_mappings}."
) get_result = self._client.indices.get(index=self._index_name) index_info: dict[str, Any] = get_result.get(self._index_name, {}) if not index_info: raise ValueError( f"Bug: OpenSearch did not return any index info for index {self._index_name}, " "even though it confirmed that the index exists." ) index_mapping_properties: dict[str, Any] = index_info.get("mappings", {}).get( "properties", {} ) expected_mapping_properties: dict[str, Any] = expected_mappings.get( "properties", {} ) assert ( expected_mapping_properties ), "Bug: No properties were found in the provided expected mappings." for property in expected_mapping_properties: if property not in index_mapping_properties: logger.warning( f'The field "{property}" was not found in the index {self._index_name}.' ) return False expected_property_type = expected_mapping_properties[property].get( "type", "" ) assert ( expected_property_type ), f'Bug: The field "{property}" in the supplied expected schema mappings has no type.' index_property_type = index_mapping_properties[property].get("type", "") if expected_property_type != index_property_type: logger.warning( f'The field "{property}" in the index {self._index_name} has type {index_property_type} ' f"but the expected type is {expected_property_type}." ) return False logger.debug(f"Index {self._index_name} validated successfully.") return True @log_function_time(print_only=True, debug_only=True, include_args=True) def update_settings(self, settings: dict[str, Any]) -> None: """Updates the settings of the index. See the OpenSearch documentation for more information on the index settings. https://docs.opensearch.org/latest/install-and-configure/configuring-opensearch/index-settings/ Args: settings: The settings to update the index with. Raises: Exception: There was an error updating the settings of the index. """ # TODO(andrei): Implement this. 
raise NotImplementedError @log_function_time( print_only=True, debug_only=True, include_args_subset={ "document": str, "tenant_state": str, "update_if_exists": str, }, ) def index_document( self, document: DocumentChunk, tenant_state: TenantState, update_if_exists: bool = False, ) -> None: """Indexes a document. Args: document: The document to index. In Onyx this is a chunk of a document, OpenSearch simply refers to this as a document as well. tenant_state: The tenant state of the caller. update_if_exists: Whether to update the document if it already exists. If False, will raise an exception if the document already exists. Defaults to False. Raises: Exception: There was an error indexing the document. This includes the case where a document with the same ID already exists if update_if_exists is False. """ logger.debug( f"Trying to index document ID {document.document_id} for tenant {tenant_state.tenant_id}. " f"update_if_exists={update_if_exists}." ) document_chunk_id: str = get_opensearch_doc_chunk_id( tenant_state=tenant_state, document_id=document.document_id, chunk_index=document.chunk_index, max_chunk_size=document.max_chunk_size, ) body: dict[str, Any] = document.model_dump(exclude_none=True) # client.create will raise if a doc with the same ID exists. # client.index does not do this. if update_if_exists: result = self._client.index( index=self._index_name, id=document_chunk_id, body=body ) else: result = self._client.create( index=self._index_name, id=document_chunk_id, body=body ) result_id = result.get("_id", "") # Sanity check. if result_id != document_chunk_id: raise RuntimeError( f'Upon trying to index a document, OpenSearch responded with ID "{result_id}" ' f'instead of "{document_chunk_id}" which is the ID it was given.' ) result_string: str = result.get("result", "") match result_string: # Sanity check. 
case "created": pass case "updated": if not update_if_exists: raise RuntimeError( f'The OpenSearch client returned result "updated" for indexing document chunk "{document_chunk_id}". ' "This indicates that a document chunk with that ID already exists, which is not expected." ) case _: raise RuntimeError( f'Unknown OpenSearch indexing result: "{result_string}".' ) logger.debug(f"Successfully indexed {document_chunk_id}.") @log_function_time( print_only=True, debug_only=True, include_args_subset={ "documents": len, "tenant_state": str, "update_if_exists": str, }, ) def bulk_index_documents( self, documents: list[DocumentChunk], tenant_state: TenantState, update_if_exists: bool = False, ) -> None: """Bulk indexes documents. Raises if there are any errors during the bulk index. If there is an error, callers should not assume that any document in the batch was indexed successfully. Note that the bulk API is not atomic, so some documents in the batch may still have been indexed. Retries on 429 (Too Many Requests) responses. Args: documents: The documents to index. In Onyx this is a chunk of a document, OpenSearch simply refers to this as a document as well. tenant_state: The tenant state of the caller. update_if_exists: Whether to update the document if it already exists. If False, will raise an exception if the document already exists. Defaults to False. Raises: Exception: There was an error during the bulk index. This includes the case where a document with the same ID already exists if update_if_exists is False. """ if not documents: return logger.debug( f"Bulk indexing {len(documents)} documents for tenant {tenant_state.tenant_id}. update_if_exists={update_if_exists}."
) data = [] for document in documents: document_chunk_id: str = get_opensearch_doc_chunk_id( tenant_state=tenant_state, document_id=document.document_id, chunk_index=document.chunk_index, max_chunk_size=document.max_chunk_size, ) body: dict[str, Any] = document.model_dump(exclude_none=True) data_for_document: dict[str, Any] = { "_index": self._index_name, "_id": document_chunk_id, "_op_type": "index" if update_if_exists else "create", "_source": body, } data.append(data_for_document) # max_retries is the number of times to retry a request if we get a 429. success, errors = bulk(self._client, data, max_retries=3) if errors: raise RuntimeError( f"Failed to bulk index documents for index {self._index_name}. Errors: {errors}" ) if success != len(documents): raise RuntimeError( f"OpenSearch reported no errors during bulk index but the number of successful operations " f"({success}) does not match the number of documents ({len(documents)})." ) logger.debug(f"Successfully bulk indexed {len(documents)} documents.") @log_function_time(print_only=True, debug_only=True, include_args=True) def delete_document(self, document_chunk_id: str) -> bool: """Deletes a document. Args: document_chunk_id: The OpenSearch ID of the document chunk to delete. Raises: Exception: There was an error deleting the document. Returns: True if the document was deleted, False if it was not found. """ try: logger.debug( f"Trying to delete document chunk {document_chunk_id} from index {self._index_name}." ) result = self._client.delete(index=self._index_name, id=document_chunk_id) except TransportError as e: if e.status_code == 404: logger.debug( f"Document chunk {document_chunk_id} not found in index {self._index_name}." ) return False else: raise e result_string: str = result.get("result", "") match result_string: case "deleted": logger.debug( f"Successfully deleted document chunk {document_chunk_id} from index {self._index_name}." 
) return True case "not_found": logger.debug( f"Document chunk {document_chunk_id} not found in index {self._index_name}." ) return False case _: raise RuntimeError( f'Unknown OpenSearch deletion result: "{result_string}".' ) @log_function_time(print_only=True, debug_only=True) def delete_by_query(self, query_body: dict[str, Any]) -> int: """Deletes documents by a query. Args: query_body: The body of the query to delete documents by. Raises: Exception: There was an error deleting the documents. Returns: The number of documents deleted. """ logger.debug( f"Trying to delete documents by query for index {self._index_name}." ) result = self._client.delete_by_query(index=self._index_name, body=query_body) if result.get("timed_out", False): raise RuntimeError( f"Delete by query timed out for index {self._index_name}." ) if len(result.get("failures", [])) > 0: raise RuntimeError( f"Failed to delete some or all of the documents for index {self._index_name}." ) num_deleted = result.get("deleted", 0) num_processed = result.get("total", 0) if num_deleted != num_processed: raise RuntimeError( f"Failed to delete some or all of the documents for index {self._index_name}. " f"{num_deleted} documents were deleted out of {num_processed} documents that were processed." ) logger.debug( f"Successfully deleted {num_deleted} documents by query for index {self._index_name}." ) return num_deleted @log_function_time( print_only=True, debug_only=True, include_args_subset={ "document_chunk_id": str, "properties_to_update": lambda x: x.keys(), }, ) def update_document( self, document_chunk_id: str, properties_to_update: dict[str, Any] ) -> None: """Updates an OpenSearch document chunk's properties. Args: document_chunk_id: The OpenSearch ID of the document chunk to update. properties_to_update: The properties of the document to update. Each property should exist in the schema. Raises: Exception: There was an error updating the document. 
""" logger.debug( f"Trying to update document chunk {document_chunk_id} for index {self._index_name}." ) update_body: dict[str, Any] = {"doc": properties_to_update} result = self._client.update( index=self._index_name, id=document_chunk_id, body=update_body, _source=False, ) result_id = result.get("_id", "") # Sanity check. if result_id != document_chunk_id: raise RuntimeError( f'Upon trying to update a document, OpenSearch responded with ID "{result_id}" ' f'instead of "{document_chunk_id}" which is the ID it was given.' ) result_string: str = result.get("result", "") match result_string: # Sanity check. case "updated": logger.debug( f"Successfully updated document chunk {document_chunk_id} for index {self._index_name}." ) return case "noop": logger.warning( f'OpenSearch reported a no-op when trying to update document with ID "{document_chunk_id}".' ) return case _: raise RuntimeError( f'The OpenSearch client returned result "{result_string}" for updating document chunk "{document_chunk_id}". ' "This is unexpected." ) @log_function_time(print_only=True, debug_only=True, include_args=True) def get_document(self, document_chunk_id: str) -> DocumentChunk: """Gets an OpenSearch document chunk. Will raise an exception if the document chunk is not found. Args: document_chunk_id: The OpenSearch ID of the document chunk to get. Raises: Exception: There was an error getting the document. This includes the case where the document is not found. Returns: The document chunk. """ logger.debug( f"Trying to get document chunk {document_chunk_id} from index {self._index_name}." ) result = self._client.get(index=self._index_name, id=document_chunk_id) found_result: bool = result.get("found", False) if not found_result: raise RuntimeError( f'Document chunk with ID "{document_chunk_id}" was not found.' ) document_chunk_source: dict[str, Any] | None = result.get("_source") if not document_chunk_source: raise RuntimeError( f'Document chunk with ID "{document_chunk_id}" has no data.' 
) logger.debug( f"Successfully got document chunk {document_chunk_id} from index {self._index_name}." ) return DocumentChunk.model_validate(document_chunk_source) @log_function_time(print_only=True, debug_only=True) def search( self, body: dict[str, Any], search_pipeline_id: str | None, search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN, ) -> list[SearchHit[DocumentChunkWithoutVectors]]: """Searches the index. NOTE: Does not return vector fields. In order to take advantage of performance benefits, the search body should exclude the schema's vector fields. TODO(andrei): Ideally we could check that every field in the body is present in the index, to avoid a class of runtime bugs that could easily be caught during development. Or change the function signature to accept a predefined pydantic model of allowed fields. Args: body: The body of the search request. See the OpenSearch documentation for more information on search request bodies. search_pipeline_id: The ID of the search pipeline to use. If None, the default search pipeline will be used. search_type: Label for Prometheus metrics. Does not affect search behavior. Raises: Exception: There was an error searching the index. Returns: List of search hits that match the search request. """ logger.debug( f"Trying to search index {self._index_name} with search pipeline {search_pipeline_id}." 
) result: dict[str, Any] params = {"phase_took": "true"} ctx = self._get_emit_metrics_context_manager(search_type) t0 = time.perf_counter() with ctx: if search_pipeline_id: result = self._client.search( index=self._index_name, search_pipeline=search_pipeline_id, body=body, params=params, ) else: result = self._client.search( index=self._index_name, body=body, params=params ) client_duration_s = time.perf_counter() - t0 hits, time_took, timed_out, phase_took, profile = ( self._get_hits_and_profile_from_search_result(result) ) if self._emit_metrics: observe_opensearch_search(search_type, client_duration_s, time_took) self._log_search_result_perf( time_took=time_took, timed_out=timed_out, phase_took=phase_took, profile=profile, body=body, search_pipeline_id=search_pipeline_id, raise_on_timeout=True, ) search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = [] for hit in hits: document_chunk_source: dict[str, Any] | None = hit.get("_source") if not document_chunk_source: raise RuntimeError( f'Document chunk with ID "{hit.get("_id", "")}" has no data.' ) document_chunk_score = hit.get("_score", None) match_highlights: dict[str, list[str]] = hit.get("highlight", {}) explanation: dict[str, Any] | None = hit.get("_explanation", None) search_hit = SearchHit[DocumentChunkWithoutVectors]( document_chunk=DocumentChunkWithoutVectors.model_validate( document_chunk_source ), score=document_chunk_score, match_highlights=match_highlights, explanation=explanation, ) search_hits.append(search_hit) logger.debug( f"Successfully searched index {self._index_name} and got {len(search_hits)} hits." ) return search_hits @log_function_time(print_only=True, debug_only=True) def search_for_document_ids( self, body: dict[str, Any], search_type: OpenSearchSearchType = OpenSearchSearchType.UNKNOWN, ) -> list[str]: """Searches the index and returns only document chunk IDs. 
In order to take advantage of the performance benefits of only returning IDs, the body should have a key-value pair of "_source": False. Otherwise, OpenSearch will return the entire document body and this method's performance will be the same as the search method's. TODO(andrei): Ideally we could check that every field in the body is present in the index, to avoid a class of runtime bugs that could easily be caught during development. Args: body: The body of the search request. See the OpenSearch documentation for more information on search request bodies. TODO(andrei): Make this a deeper interface; callers shouldn't need to know to set _source: False, for example. search_type: Label for Prometheus metrics. Does not affect search behavior. Raises: Exception: There was an error searching the index. Returns: List of document chunk IDs that match the search request. """ logger.debug( f"Trying to search for document chunk IDs in index {self._index_name}." ) if "_source" not in body or body["_source"] is not False: logger.warning( "The body of the search request for document chunk IDs is missing the key, value pair of " '"_source": False. This query will therefore be inefficient.' ) params = {"phase_took": "true"} ctx = self._get_emit_metrics_context_manager(search_type) t0 = time.perf_counter() with ctx: result: dict[str, Any] = self._client.search( index=self._index_name, body=body, params=params ) client_duration_s = time.perf_counter() - t0 hits, time_took, timed_out, phase_took, profile = ( self._get_hits_and_profile_from_search_result(result) ) if self._emit_metrics: observe_opensearch_search(search_type, client_duration_s, time_took) self._log_search_result_perf( time_took=time_took, timed_out=timed_out, phase_took=phase_took, profile=profile, body=body, raise_on_timeout=True, ) # TODO(andrei): Implement scroll/point in time for results so that we # can return arbitrarily-many IDs.
if len(hits) == DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW: logger.warning( "The search request for document chunk IDs returned the maximum number of results. " "It is extremely likely that there are more hits in OpenSearch than the returned results." ) # Extract only the _id field from each hit. document_chunk_ids: list[str] = [] for hit in hits: document_chunk_id = hit.get("_id") if not document_chunk_id: raise RuntimeError( "Received a hit from OpenSearch but the _id field is missing." ) document_chunk_ids.append(document_chunk_id) logger.debug( f"Successfully searched for document chunk IDs in index {self._index_name} and got {len(document_chunk_ids)} hits." ) return document_chunk_ids @log_function_time(print_only=True, debug_only=True) def refresh_index(self) -> None: """Refreshes the index to make recent changes searchable. In OpenSearch, documents are not immediately searchable after indexing. This method forces a refresh to make them available for search. Raises: Exception: There was an error refreshing the index. """ self._client.indices.refresh(index=self._index_name) def _get_hits_and_profile_from_search_result( self, result: dict[str, Any] ) -> tuple[list[Any], int | None, bool | None, dict[str, Any], dict[str, Any]]: """Extracts the hits and profiling information from a search result. Args: result: The search result to extract the hits from. Raises: Exception: There was an error extracting the hits from the search result. Returns: A tuple containing the hits from the search result, the time taken to execute the search in milliseconds, whether the search timed out, the time taken to execute each phase of the search, and the profile. 
""" time_took: int | None = result.get("took") timed_out: bool | None = result.get("timed_out") phase_took: dict[str, Any] = result.get("phase_took", {}) profile: dict[str, Any] = result.get("profile", {}) hits_first_layer: dict[str, Any] = result.get("hits", {}) if not hits_first_layer: raise RuntimeError( f"Hits field missing from response when trying to search index {self._index_name}." ) hits_second_layer: list[Any] = hits_first_layer.get("hits", []) return hits_second_layer, time_took, timed_out, phase_took, profile def _log_search_result_perf( self, time_took: int | None, timed_out: bool | None, phase_took: dict[str, Any], profile: dict[str, Any], body: dict[str, Any], search_pipeline_id: str | None = None, raise_on_timeout: bool = False, ) -> None: """Logs the performance of a search result. Args: time_took: The time taken to execute the search in milliseconds. timed_out: Whether the search timed out. phase_took: The time taken to execute each phase of the search. profile: The profile for the search. body: The body of the search request for logging. search_pipeline_id: The ID of the search pipeline used for the search, if any, for logging. Defaults to None. raise_on_timeout: Whether to raise an exception if the search timed out. Note that the result may still contain useful partial results. Defaults to False. Raises: Exception: If raise_on_timeout is True and the search timed out. """ if time_took and time_took > CLIENT_THRESHOLD_TO_LOG_SLOW_SEARCH_MS: logger.warning( f"OpenSearch client warning: Search for index {self._index_name} took {time_took} milliseconds.\n" f"Body: {get_new_body_without_vectors(body)}\n" f"Search pipeline ID: {search_pipeline_id}\n" f"Phase took: {phase_took}\n" f"Profile: {json.dumps(profile, indent=2)}\n" ) if timed_out: error_str = f"OpenSearch client error: Search timed out for index {self._index_name}." 
logger.error(error_str) if raise_on_timeout: raise RuntimeError(error_str) def _get_emit_metrics_context_manager( self, search_type: OpenSearchSearchType ) -> AbstractContextManager[None]: """ Returns a context manager that tracks in-flight OpenSearch searches via a Gauge if emit_metrics is True, otherwise returns a null context manager. """ return ( track_opensearch_search_in_progress(search_type) if self._emit_metrics else nullcontext() ) def wait_for_opensearch_with_timeout( wait_interval_s: int = 5, wait_limit_s: int = 60, client: OpenSearchClient | None = None, ) -> bool: """Waits for OpenSearch to become ready subject to a timeout. Will create a new dummy client if no client is provided. Will close this client at the end of the function. Will not close the client if it was supplied. Args: wait_interval_s: The interval in seconds to wait between checks. Defaults to 5. wait_limit_s: The total timeout in seconds to wait for OpenSearch to become ready. Defaults to 60. client: The OpenSearch client to use for pinging. If None, a new dummy client will be created. Defaults to None. Returns: True if OpenSearch is ready, False otherwise. """ with nullcontext(client) if client else OpenSearchClient() as client: time_start = time.monotonic() while True: if client.ping(): logger.info("[OpenSearch] Readiness probe succeeded. Continuing...") return True time_elapsed = time.monotonic() - time_start if time_elapsed > wait_limit_s: logger.info( f"[OpenSearch] Readiness probe did not succeed within the timeout ({wait_limit_s} seconds)." ) return False logger.info( f"[OpenSearch] Readiness probe ongoing. 
elapsed={time_elapsed:.1f} timeout={wait_limit_s:.1f}" ) time.sleep(wait_interval_s) ================================================ FILE: backend/onyx/document_index/opensearch/cluster_settings.py ================================================ from typing import Any OPENSEARCH_CLUSTER_SETTINGS: dict[str, Any] = { "persistent": { # By default, when you index a document to a non-existent index, # OpenSearch will automatically create the index. This behavior is # undesirable so this function exposes the ability to disable it. # See # https://docs.opensearch.org/latest/install-and-configure/configuring-opensearch/index/#updating-cluster-settings-using-the-api "action.auto_create_index": False, # Thresholds for OpenSearch to log slow queries at the server level. "cluster.search.request.slowlog.level": "INFO", "cluster.search.request.slowlog.threshold.warn": "5s", "cluster.search.request.slowlog.threshold.info": "2s", "cluster.search.request.slowlog.threshold.debug": "1s", "cluster.search.request.slowlog.threshold.trace": "500ms", } } ================================================ FILE: backend/onyx/document_index/opensearch/constants.py ================================================ # Default value for the maximum number of tokens a chunk can hold, if none is # specified when creating an index. import os from enum import Enum DEFAULT_MAX_CHUNK_SIZE = 512 # By default OpenSearch will only return a maximum of this many results in a # given search. This value is configurable in the index settings. DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW = 10_000 # For documents which do not have a value for LAST_UPDATED_FIELD_NAME, we assume # that the document was last updated this many days ago for the purpose of time # cutoff filtering during retrieval. ASSUMED_DOCUMENT_AGE_DAYS = 90 # Size of the dynamic list used to consider elements during kNN graph creation. # Higher values improve search quality but increase indexing time. Values # typically range between 100 - 512. 
EF_CONSTRUCTION = 256 # Number of bi-directional links per element. Higher values improve search # quality but increase memory footprint. Values typically range from 12 to 48. M = 32 # Set relatively high for better accuracy. # When performing hybrid search, we need to consider more candidates than the # number of results to be returned, because hybrid scoring reorders the # results. More candidates for hybrid fusion means better retrieval accuracy, # but also more computation per query. Imagine a simple case with a single # keyword query and a single vector query where we want 10 final docs. If we # only fetch 10 candidates from each of keyword and vector, they would have to # overlap perfectly to get a good hybrid ranking for the 10 results. If we # fetch 1000 candidates from each, we have a much higher chance of all 10 of # the final desired docs showing up and getting scored. In worse situations, # the desired docs never show up as candidates at all, so they cannot make the # final 10 (worse than just a miss at the reranking step). # Defaults to 500 for now. Initially this defaulted to 750, but we were seeing # poor search performance, so it was lowered to 100; it was later bumped from # 100 to 500 to improve recall. DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES = int( os.environ.get("DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES", 500) ) # Number of vectors to examine to decide the top k neighbors for the HNSW # method. # NOTE: "When creating a search query, you must specify k. If you provide both k # and ef_search, then the larger value is passed to the engine. If ef_search is # larger than k, you can provide the size parameter to limit the final number of # results to k."
from # https://docs.opensearch.org/latest/query-dsl/specialized/k-nn/index/#ef_search EF_SEARCH = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES class OpenSearchSearchType(str, Enum): """Search type label used for Prometheus metrics.""" HYBRID = "hybrid" KEYWORD = "keyword" SEMANTIC = "semantic" RANDOM = "random" DOC_ID_RETRIEVAL = "doc_id_retrieval" UNKNOWN = "unknown" class HybridSearchSubqueryConfiguration(Enum): TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 1 # Current default. CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD = 2 # Will raise and block application start if HYBRID_SEARCH_SUBQUERY_CONFIGURATION # is set but not a valid value. If not set, defaults to # CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD. HYBRID_SEARCH_SUBQUERY_CONFIGURATION: HybridSearchSubqueryConfiguration = ( HybridSearchSubqueryConfiguration( int(os.environ["HYBRID_SEARCH_SUBQUERY_CONFIGURATION"]) ) if os.environ.get("HYBRID_SEARCH_SUBQUERY_CONFIGURATION", None) is not None else HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD ) class HybridSearchNormalizationPipeline(Enum): # Current default. MIN_MAX = 1 # NOTE: Using z-score normalization is better for hybrid search from a # theoretical standpoint. Empirically on a small dataset of up to 10K docs, # it's not very different. Likely more impactful at scale. # https://opensearch.org/blog/introducing-the-z-score-normalization-technique-for-hybrid-search/ ZSCORE = 2 # Will raise and block application start if HYBRID_SEARCH_NORMALIZATION_PIPELINE # is set but not a valid value. If not set, defaults to MIN_MAX. 
HYBRID_SEARCH_NORMALIZATION_PIPELINE: HybridSearchNormalizationPipeline = ( HybridSearchNormalizationPipeline( int(os.environ["HYBRID_SEARCH_NORMALIZATION_PIPELINE"]) ) if os.environ.get("HYBRID_SEARCH_NORMALIZATION_PIPELINE", None) is not None else HybridSearchNormalizationPipeline.MIN_MAX ) ================================================ FILE: backend/onyx/document_index/opensearch/opensearch_document_index.py ================================================ import json from collections.abc import Iterable from typing import Any import httpx from opensearchpy import NotFoundError from onyx.access.models import DocumentAccess from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH from onyx.configs.app_configs import VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT from onyx.configs.chat_configs import NUM_RETURNED_HITS from onyx.configs.chat_configs import TITLE_CONTENT_RATIO from onyx.configs.constants import PUBLIC_DOC_PAT from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_experts_stores_representations, ) from onyx.connectors.models import convert_metadata_list_of_strings_to_dict from onyx.context.search.enums import QueryType from onyx.context.search.models import IndexFilters from onyx.context.search.models import InferenceChunk from onyx.context.search.models import InferenceChunkUncleaned from onyx.context.search.models import QueryExpansionType from onyx.db.enums import EmbeddingPrecision from onyx.db.models import DocumentSource from onyx.document_index.chunk_content_enrichment import cleanup_content_for_chunks from onyx.document_index.chunk_content_enrichment import ( generate_enriched_content_for_chunk_text, ) from onyx.document_index.interfaces import DocumentIndex as OldDocumentIndex from onyx.document_index.interfaces import ( DocumentInsertionRecord as OldDocumentInsertionRecord, ) from onyx.document_index.interfaces import IndexBatchParams from onyx.document_index.interfaces import VespaChunkRequest from 
onyx.document_index.interfaces import VespaDocumentFields from onyx.document_index.interfaces import VespaDocumentUserFields from onyx.document_index.interfaces_new import DocumentIndex from onyx.document_index.interfaces_new import DocumentInsertionRecord from onyx.document_index.interfaces_new import DocumentSectionRequest from onyx.document_index.interfaces_new import IndexingMetadata from onyx.document_index.interfaces_new import MetadataUpdateRequest from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.client import OpenSearchClient from onyx.document_index.opensearch.client import OpenSearchIndexClient from onyx.document_index.opensearch.client import SearchHit from onyx.document_index.opensearch.cluster_settings import OPENSEARCH_CLUSTER_SETTINGS from onyx.document_index.opensearch.constants import OpenSearchSearchType from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME from onyx.document_index.opensearch.schema import DocumentChunk from onyx.document_index.opensearch.schema import DocumentChunkWithoutVectors from onyx.document_index.opensearch.schema import DocumentSchema from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id from onyx.document_index.opensearch.schema import GLOBAL_BOOST_FIELD_NAME from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME from onyx.document_index.opensearch.search import DocumentQuery from onyx.document_index.opensearch.search import ( get_min_max_normalization_pipeline_name_and_config, ) from onyx.document_index.opensearch.search import ( get_normalization_pipeline_name_and_config, ) from onyx.document_index.opensearch.search import ( 
get_zscore_normalization_pipeline_name_and_config, ) from onyx.indexing.models import DocMetadataAwareIndexChunk from onyx.indexing.models import Document from onyx.utils.logger import setup_logger from onyx.utils.text_processing import remove_invalid_unicode_chars from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id from shared_configs.model_server_models import Embedding logger = setup_logger(__name__) class ChunkCountNotFoundError(ValueError): """Raised when a document has no chunk count.""" def generate_opensearch_filtered_access_control_list( access: DocumentAccess, ) -> list[str]: """Generates an access control list with PUBLIC_DOC_PAT removed. In the OpenSearch schema this is represented by PUBLIC_FIELD_NAME. """ access_control_list = access.to_acl() access_control_list.discard(PUBLIC_DOC_PAT) return list(access_control_list) def set_cluster_state(client: OpenSearchClient) -> None: if not client.put_cluster_settings(settings=OPENSEARCH_CLUSTER_SETTINGS): logger.error( "Failed to put cluster settings. If the settings have never been set before, " "this may cause unexpected index creation when indexing documents into an " "index that does not exist, or may cause expected logs to not appear. If this " "is not the first time running Onyx against this instance of OpenSearch, these " "settings have likely already been set. Not taking any further action..." 
) min_max_normalization_pipeline_name, min_max_normalization_pipeline_config = ( get_min_max_normalization_pipeline_name_and_config() ) zscore_normalization_pipeline_name, zscore_normalization_pipeline_config = ( get_zscore_normalization_pipeline_name_and_config() ) client.create_search_pipeline( pipeline_id=min_max_normalization_pipeline_name, pipeline_body=min_max_normalization_pipeline_config, ) client.create_search_pipeline( pipeline_id=zscore_normalization_pipeline_name, pipeline_body=zscore_normalization_pipeline_config, ) def _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned( chunk: DocumentChunkWithoutVectors, score: float | None, highlights: dict[str, list[str]], ) -> InferenceChunkUncleaned: """ Generates an inference chunk from an OpenSearch document chunk, its score, and its match highlights. Args: chunk: The document chunk returned by OpenSearch. score: The document chunk match score as calculated by OpenSearch. Only relevant for searches like hybrid search. It is acceptable for this value to be None for results from other queries, such as ID-based retrieval, since a match score makes no sense in those contexts. highlights: Maps schema property name to a list of highlighted snippets with match terms wrapped in highlight tags (e.g. in "something keyword other thing", the matched term "keyword" would be wrapped in tags). Returns: An Onyx inference chunk representation. """ return InferenceChunkUncleaned( chunk_id=chunk.chunk_index, blurb=chunk.blurb, # Includes extra content prepended/appended during indexing. content=chunk.content, # JSON object keys are always strings, so after parsing the stored # string the keys must be converted back to ints. source_links=( {int(k): v for k, v in json.loads(chunk.source_links).items()} if chunk.source_links else None ), image_file_id=chunk.image_file_id, # Deprecated. Fill in some reasonable default.
section_continuation=False, document_id=chunk.document_id, source_type=DocumentSource(chunk.source_type), semantic_identifier=chunk.semantic_identifier, title=chunk.title, boost=chunk.global_boost, score=score, hidden=chunk.hidden, metadata=( convert_metadata_list_of_strings_to_dict(chunk.metadata_list) if chunk.metadata_list else {} ), # Extract highlighted snippets from the content field, if available. In # the future we may want to match on other fields too, currently we only # use the content field. match_highlights=highlights.get(CONTENT_FIELD_NAME, []), # TODO(andrei) Consider storing a chunk content index instead of a full # string when working on chunk content augmentation. doc_summary=chunk.doc_summary, # TODO(andrei) Same thing as above. chunk_context=chunk.chunk_context, updated_at=chunk.last_updated, primary_owners=chunk.primary_owners, secondary_owners=chunk.secondary_owners, # TODO(andrei) Same thing as chunk_context above. metadata_suffix=chunk.metadata_suffix, ) def _convert_onyx_chunk_to_opensearch_document( chunk: DocMetadataAwareIndexChunk, ) -> DocumentChunk: filtered_blurb = remove_invalid_unicode_chars(chunk.blurb) _title = chunk.source_document.get_title_for_document_index() filtered_title = remove_invalid_unicode_chars(_title) if _title else None filtered_content = remove_invalid_unicode_chars( generate_enriched_content_for_chunk_text(chunk) ) filtered_semantic_identifier = remove_invalid_unicode_chars( chunk.source_document.semantic_identifier ) filtered_metadata_suffix = remove_invalid_unicode_chars( chunk.metadata_suffix_keyword ) _metadata_list = chunk.source_document.get_metadata_str_attributes() filtered_metadata_list = ( [remove_invalid_unicode_chars(metadata) for metadata in _metadata_list] if _metadata_list else None ) return DocumentChunk( document_id=chunk.source_document.id, chunk_index=chunk.chunk_id, # Use get_title_for_document_index to match the logic used when creating # the title_embedding in the embedder. 
This method falls back to # semantic_identifier when title is None (but not empty string). title=filtered_title, title_vector=chunk.title_embedding, content=filtered_content, content_vector=chunk.embeddings.full_embedding, source_type=chunk.source_document.source.value, metadata_list=filtered_metadata_list, metadata_suffix=filtered_metadata_suffix, last_updated=chunk.source_document.doc_updated_at, public=chunk.access.is_public, access_control_list=generate_opensearch_filtered_access_control_list( chunk.access ), global_boost=chunk.boost, semantic_identifier=filtered_semantic_identifier, image_file_id=chunk.image_file_id, # Small optimization, if this list is empty we can supply None to # OpenSearch and it will not store any data at all for this field, which # is different from supplying an empty list. source_links=json.dumps(chunk.source_links) if chunk.source_links else None, blurb=filtered_blurb, doc_summary=chunk.doc_summary, chunk_context=chunk.chunk_context, # Small optimization, if this list is empty we can supply None to # OpenSearch and it will not store any data at all for this field, which # is different from supplying an empty list. document_sets=list(chunk.document_sets) if chunk.document_sets else None, # Small optimization, if this list is empty we can supply None to # OpenSearch and it will not store any data at all for this field, which # is different from supplying an empty list. user_projects=chunk.user_project or None, personas=chunk.personas or None, primary_owners=get_experts_stores_representations( chunk.source_document.primary_owners ), secondary_owners=get_experts_stores_representations( chunk.source_document.secondary_owners ), # TODO(andrei): Consider not even getting this from # DocMetadataAwareIndexChunk and instead using OpenSearchDocumentIndex's # instance variable. One source of truth -> less chance of a very bad # bug in prod. 
tenant_id=TenantState(tenant_id=chunk.tenant_id, multitenant=MULTI_TENANT), # Store ancestor hierarchy node IDs for hierarchy-based filtering. ancestor_hierarchy_node_ids=chunk.ancestor_hierarchy_node_ids or None, ) class OpenSearchOldDocumentIndex(OldDocumentIndex): """ Wrapper for OpenSearch that adapts invocations of the old DocumentIndex interface in the hotpath to the new DocumentIndex interface. The analogous class for Vespa is VespaIndex, which delegates to VespaDocumentIndex. TODO(andrei): This is very dumb and purely temporary until there are no more references to the old interface in the hotpath. """ def __init__( self, index_name: str, embedding_dim: int, embedding_precision: EmbeddingPrecision, secondary_index_name: str | None, secondary_embedding_dim: int | None, secondary_embedding_precision: EmbeddingPrecision | None, # NOTE: We do not support large chunks right now. large_chunks_enabled: bool, # noqa: ARG002 secondary_large_chunks_enabled: bool | None, # noqa: ARG002 multitenant: bool = False, httpx_client: httpx.Client | None = None, # noqa: ARG002 ) -> None: super().__init__( index_name=index_name, secondary_index_name=secondary_index_name, ) if multitenant != MULTI_TENANT: raise ValueError( "Bug: Multitenant mismatch when initializing an OpenSearchDocumentIndex. " f"Expected {MULTI_TENANT}, got {multitenant}." ) tenant_id = get_current_tenant_id() tenant_state = TenantState(tenant_id=tenant_id, multitenant=multitenant) self._real_index = OpenSearchDocumentIndex( tenant_state=tenant_state, index_name=index_name, embedding_dim=embedding_dim, embedding_precision=embedding_precision, ) self._secondary_real_index: OpenSearchDocumentIndex | None = None if self.secondary_index_name: if secondary_embedding_dim is None or secondary_embedding_precision is None: raise ValueError( "Bug: Secondary index embedding dimension and precision are not set."
) self._secondary_real_index = OpenSearchDocumentIndex( tenant_state=tenant_state, index_name=self.secondary_index_name, embedding_dim=secondary_embedding_dim, embedding_precision=secondary_embedding_precision, ) @staticmethod def register_multitenant_indices( indices: list[str], embedding_dims: list[int], embedding_precisions: list[EmbeddingPrecision], ) -> None: raise NotImplementedError( "Bug: Multitenant index registration is not supported for OpenSearch." ) def ensure_indices_exist( self, primary_embedding_dim: int, primary_embedding_precision: EmbeddingPrecision, secondary_index_embedding_dim: int | None, secondary_index_embedding_precision: EmbeddingPrecision | None, ) -> None: self._real_index.verify_and_create_index_if_necessary( primary_embedding_dim, primary_embedding_precision ) if self.secondary_index_name: if ( secondary_index_embedding_dim is None or secondary_index_embedding_precision is None ): raise ValueError( "Bug: Secondary index embedding dimension and precision are not set." ) assert ( self._secondary_real_index is not None ), "Bug: Secondary index is not initialized." self._secondary_real_index.verify_and_create_index_if_necessary( secondary_index_embedding_dim, secondary_index_embedding_precision ) def index( self, chunks: Iterable[DocMetadataAwareIndexChunk], index_batch_params: IndexBatchParams, ) -> set[OldDocumentInsertionRecord]: """ NOTE: Do NOT consider the secondary index here. A separate indexing pipeline will be responsible for indexing to the secondary index. This design is not ideal and we should reconsider this when revamping index swapping. """ # Convert IndexBatchParams to IndexingMetadata. 
chunk_counts: dict[str, IndexingMetadata.ChunkCounts] = {} for doc_id in index_batch_params.doc_id_to_new_chunk_cnt: old_count = index_batch_params.doc_id_to_previous_chunk_cnt[doc_id] new_count = index_batch_params.doc_id_to_new_chunk_cnt[doc_id] chunk_counts[doc_id] = IndexingMetadata.ChunkCounts( old_chunk_cnt=old_count, new_chunk_cnt=new_count, ) indexing_metadata = IndexingMetadata(doc_id_to_chunk_cnt_diff=chunk_counts) results = self._real_index.index(chunks, indexing_metadata) # Convert list[DocumentInsertionRecord] to # set[OldDocumentInsertionRecord]. return { OldDocumentInsertionRecord( document_id=record.document_id, already_existed=record.already_existed, ) for record in results } def delete_single( self, doc_id: str, *, tenant_id: str, # noqa: ARG002 chunk_count: int | None, ) -> int: """ NOTE: Remember to handle the secondary index here. There is no separate pipeline for deleting chunks in the secondary index. This design is not ideal and we should reconsider this when revamping index swapping. """ total_chunks_deleted = self._real_index.delete(doc_id, chunk_count) if self.secondary_index_name: assert ( self._secondary_real_index is not None ), "Bug: Secondary index is not initialized." total_chunks_deleted += self._secondary_real_index.delete( doc_id, chunk_count ) return total_chunks_deleted def update_single( self, doc_id: str, *, tenant_id: str, # noqa: ARG002 chunk_count: int | None, fields: VespaDocumentFields | None, user_fields: VespaDocumentUserFields | None, ) -> None: """ NOTE: Remember to handle the secondary index here. There is no separate pipeline for updating chunks in the secondary index. This design is not ideal and we should reconsider this when revamping index swapping. """ if fields is None and user_fields is None: logger.warning( f"Tried to update document {doc_id} with no updated fields or user fields." ) return # Convert VespaDocumentFields to MetadataUpdateRequest. 
update_request = MetadataUpdateRequest( document_ids=[doc_id], doc_id_to_chunk_cnt={ doc_id: chunk_count if chunk_count is not None else -1 }, access=fields.access if fields else None, document_sets=fields.document_sets if fields else None, boost=fields.boost if fields else None, hidden=fields.hidden if fields else None, project_ids=( set(user_fields.user_projects) # NOTE: Empty user_projects is semantically different from None # user_projects. if user_fields and user_fields.user_projects is not None else None ), persona_ids=( set(user_fields.personas) # NOTE: Empty personas is semantically different from None # personas. if user_fields and user_fields.personas is not None else None ), ) try: self._real_index.update([update_request]) if self.secondary_index_name: assert ( self._secondary_real_index is not None ), "Bug: Secondary index is not initialized." self._secondary_real_index.update([update_request]) except NotFoundError: logger.exception( f"Tried to update document {doc_id} but at least one of its chunks was not found in OpenSearch. " "This is likely due to it not having been indexed yet. Skipping update for now..." ) return except ChunkCountNotFoundError: logger.exception( f"Tried to update document {doc_id} but its chunk count is not known. We tolerate this for now " "but this will not be an acceptable state once OpenSearch is the primary document index and the " "indexing/updating race condition is fixed." 
) return def id_based_retrieval( self, chunk_requests: list[VespaChunkRequest], filters: IndexFilters, batch_retrieval: bool = False, get_large_chunks: bool = False, # noqa: ARG002 ) -> list[InferenceChunk]: section_requests = [ DocumentSectionRequest( document_id=req.document_id, min_chunk_ind=req.min_chunk_ind, max_chunk_ind=req.max_chunk_ind, ) for req in chunk_requests ] return self._real_index.id_based_retrieval( section_requests, filters, batch_retrieval ) def hybrid_retrieval( self, query: str, query_embedding: Embedding, final_keywords: list[str] | None, filters: IndexFilters, hybrid_alpha: float, time_decay_multiplier: float, # noqa: ARG002 num_to_retrieve: int, ranking_profile_type: QueryExpansionType = QueryExpansionType.SEMANTIC, # noqa: ARG002 title_content_ratio: float | None = TITLE_CONTENT_RATIO, # noqa: ARG002 ) -> list[InferenceChunk]: # Determine query type based on hybrid_alpha. if hybrid_alpha >= 0.8: query_type = QueryType.SEMANTIC elif hybrid_alpha <= 0.2: query_type = QueryType.KEYWORD else: query_type = QueryType.SEMANTIC # Default to semantic for hybrid. return self._real_index.hybrid_retrieval( query=query, query_embedding=query_embedding, final_keywords=final_keywords, query_type=query_type, filters=filters, num_to_retrieve=num_to_retrieve, ) def admin_retrieval( self, query: str, query_embedding: Embedding, filters: IndexFilters, num_to_retrieve: int = NUM_RETURNED_HITS, ) -> list[InferenceChunk]: return self._real_index.hybrid_retrieval( query=query, query_embedding=query_embedding, final_keywords=None, query_type=QueryType.KEYWORD, filters=filters, num_to_retrieve=num_to_retrieve, ) def random_retrieval( self, filters: IndexFilters, num_to_retrieve: int = 10, ) -> list[InferenceChunk]: return self._real_index.random_retrieval( filters=filters, num_to_retrieve=num_to_retrieve, dirty=None, ) class OpenSearchDocumentIndex(DocumentIndex): """OpenSearch-specific implementation of the DocumentIndex interface. 
This class provides document indexing, retrieval, and management operations for an OpenSearch search engine instance. It handles the complete lifecycle of document chunks within a specific OpenSearch index/schema. Each kind of embedding used should correspond to a different instance of this class, and therefore a different index in OpenSearch. If in a multitenant environment and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT, will verify and create the index if necessary on initialization. This is because there is no logic which runs on cluster restart which scans through all search settings over all tenants and creates the relevant indices. Args: tenant_state: The tenant state of the caller. index_name: The name of the index to interact with. embedding_dim: The dimensionality of the embeddings used for the index. embedding_precision: The precision of the embeddings used for the index. """ def __init__( self, tenant_state: TenantState, index_name: str, embedding_dim: int, embedding_precision: EmbeddingPrecision, ) -> None: self._index_name: str = index_name self._tenant_state: TenantState = tenant_state self._client = OpenSearchIndexClient(index_name=self._index_name) if self._tenant_state.multitenant and VERIFY_CREATE_OPENSEARCH_INDEX_ON_INIT_MT: self.verify_and_create_index_if_necessary( embedding_dim=embedding_dim, embedding_precision=embedding_precision ) def verify_and_create_index_if_necessary( self, embedding_dim: int, embedding_precision: EmbeddingPrecision, # noqa: ARG002 ) -> None: """Verifies and creates the index if necessary. Also puts the desired cluster settings if not in a multitenant environment. Also puts the desired search pipeline state if not in a multitenant environment, creating the pipelines if they do not exist and updating them otherwise. In a multitenant environment, the above steps happen explicitly on setup. Args: embedding_dim: Vector dimensionality for the vector similarity part of the search. 
embedding_precision: Precision of the values of the vectors for the similarity part of the search. Raises: Exception: There was an error verifying or creating the index or search pipelines. """ logger.debug( f"[OpenSearchDocumentIndex] Verifying and creating index {self._index_name} if " f"necessary, with embedding dimension {embedding_dim}." ) if not self._tenant_state.multitenant: set_cluster_state(self._client) expected_mappings = DocumentSchema.get_document_schema( embedding_dim, self._tenant_state.multitenant ) if not self._client.index_exists(): index_settings = DocumentSchema.get_index_settings_based_on_environment() self._client.create_index( mappings=expected_mappings, settings=index_settings, ) else: # Ensure schema is up to date by applying the current mappings. try: self._client.put_mapping(expected_mappings) except Exception as e: logger.error( f"Failed to update mappings for index {self._index_name}. This likely means a " f"field type was changed which requires reindexing. Error: {e}" ) raise def index( self, chunks: Iterable[DocMetadataAwareIndexChunk], indexing_metadata: IndexingMetadata, ) -> list[DocumentInsertionRecord]: """Indexes an iterable of document chunks into the document index. Groups chunks by document ID and for each document, deletes existing chunks and indexes the new chunks in bulk. NOTE: It is assumed that chunks for a given document are not spread out over multiple index() calls. Args: chunks: Document chunks with all of the information needed for indexing to the document index. indexing_metadata: Information about chunk counts for efficient cleaning / updating. Raises: Exception: Failed to index some or all of the chunks for the specified documents. Returns: List of document IDs which map to unique documents as well as if the document is newly indexed or had already existed and was just updated. 
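Example:
The buffering rule implemented below can be sketched in isolation. This is a hypothetical, simplified sketch that is not part of this module: plain (doc_id, chunk_id) tuples stand in for DocMetadataAwareIndexChunk, and a small MAX_BATCH stands in for MAX_CHUNKS_PER_DOC_BATCH. Each returned entry is one flushed batch, flagged with whether the document's existing chunks would be deleted before that batch is bulk indexed (only the first flushed batch per document triggers a delete).

```python
# Hypothetical stand-in for MAX_CHUNKS_PER_DOC_BATCH.
MAX_BATCH = 3


def group_into_flush_batches(
    chunks: list[tuple[str, int]], max_batch: int = MAX_BATCH
) -> list[tuple[str, list[int], bool]]:
    # Each result entry is (doc_id, chunk_ids, delete_before_indexing).
    batches: list[tuple[str, list[int], bool]] = []
    deleted_doc_ids: set[str] = set()
    current_doc_id: str | None = None
    current_chunks: list[tuple[str, int]] = []

    def flush(doc_chunks: list[tuple[str, int]]) -> None:
        assert doc_chunks, "doc_chunks is empty"
        doc_id = doc_chunks[0][0]
        # Only the first flushed batch of a document deletes its old chunks.
        delete_first = doc_id not in deleted_doc_ids
        deleted_doc_ids.add(doc_id)
        chunk_ids = [chunk_id for _, chunk_id in doc_chunks]
        batches.append((doc_id, chunk_ids, delete_first))

    for chunk in chunks:
        doc_id = chunk[0]
        if doc_id != current_doc_id:
            # A new document starts: flush what the previous one buffered.
            if current_chunks:
                flush(current_chunks)
            current_doc_id = doc_id
            current_chunks = [chunk]
        elif len(current_chunks) >= max_batch:
            # Same document, but the buffer is full: flush, start a new batch.
            flush(current_chunks)
            current_chunks = [chunk]
        else:
            current_chunks.append(chunk)
    if current_chunks:
        flush(current_chunks)
    return batches


print(group_into_flush_batches([("a", 0), ("a", 1), ("a", 2), ("a", 3), ("b", 0)]))
# [('a', [0, 1, 2], True), ('a', [3], False), ('b', [0], True)]
```

With a cap of 3, the four chunks of document 'a' are split into two batches, and only the first one is preceded by a delete.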
""" total_chunks = sum( cc.new_chunk_cnt for cc in indexing_metadata.doc_id_to_chunk_cnt_diff.values() ) logger.debug( f"[OpenSearchDocumentIndex] Indexing {total_chunks} chunks from {len(indexing_metadata.doc_id_to_chunk_cnt_diff)} " f"documents for index {self._index_name}." ) document_indexing_results: list[DocumentInsertionRecord] = [] deleted_doc_ids: set[str] = set() # Buffer chunks per document as they arrive from the iterable. # When the document ID changes flush the buffered chunks. current_doc_id: str | None = None current_chunks: list[DocMetadataAwareIndexChunk] = [] def _flush_chunks(doc_chunks: list[DocMetadataAwareIndexChunk]) -> None: assert len(doc_chunks) > 0, "doc_chunks is empty" # Create a batch of OpenSearch-formatted chunks for bulk insertion. # Since we are doing this in batches, an error occurring midway # can result in a state where chunks are deleted and not all the # new chunks have been indexed. chunk_batch: list[DocumentChunk] = [ _convert_onyx_chunk_to_opensearch_document(chunk) for chunk in doc_chunks ] onyx_document: Document = doc_chunks[0].source_document # First delete the doc's chunks from the index. This is so that # there are no dangling chunks in the index, in the event that the # new document's content contains fewer chunks than the previous # content. # TODO(andrei): This can possibly be made more efficient by checking # if the chunk count has actually decreased. This assumes that # overlapping chunks are perfectly overwritten. If we can't # guarantee that then we need the code as-is. if onyx_document.id not in deleted_doc_ids: num_chunks_deleted = self.delete( onyx_document.id, onyx_document.chunk_count ) deleted_doc_ids.add(onyx_document.id) # If we see that chunks were deleted we assume the doc already # existed. We record the result before bulk_index_documents # runs. If indexing raises, this entire result list is discarded # by the caller's retry logic, so early recording is safe. 
document_indexing_results.append( DocumentInsertionRecord( document_id=onyx_document.id, already_existed=num_chunks_deleted > 0, ) ) # Now index. This will raise if a chunk of the same ID exists, which # we do not expect because we should have deleted all chunks. self._client.bulk_index_documents( documents=chunk_batch, tenant_state=self._tenant_state, ) for chunk in chunks: doc_id = chunk.source_document.id if doc_id != current_doc_id: if current_chunks: _flush_chunks(current_chunks) current_doc_id = doc_id current_chunks = [chunk] elif len(current_chunks) >= MAX_CHUNKS_PER_DOC_BATCH: _flush_chunks(current_chunks) current_chunks = [chunk] else: current_chunks.append(chunk) if current_chunks: _flush_chunks(current_chunks) return document_indexing_results def delete( self, document_id: str, chunk_count: int | None = None, # noqa: ARG002 ) -> int: """Deletes all chunks for a given document. Does nothing if the specified document ID does not exist. TODO(andrei): Consider implementing this method to delete on document chunk IDs vs querying for matching document chunks. Unclear if this is any better though. Args: document_id: The unique identifier for the document as represented in Onyx, not necessarily in the document index. chunk_count: The number of chunks in OpenSearch for the document. Currently unused. Defaults to None. Raises: Exception: Failed to delete some or all of the chunks for the document. Returns: The number of chunks successfully deleted. """ logger.debug( f"[OpenSearchDocumentIndex] Deleting document {document_id} from index {self._index_name}." ) query_body = DocumentQuery.delete_from_document_id_query( document_id=document_id, tenant_state=self._tenant_state, ) return self._client.delete_by_query(query_body) def update( self, update_requests: list[MetadataUpdateRequest], ) -> None: """Updates some set of chunks. NOTE: Will raise if one of the specified document chunks does not exist. This may be due to a concurrent ongoing indexing operation.
In that event, callers are expected to retry shortly, once the state of the document index has been updated. NOTE: Requires the document chunk count to be known; will raise if it is not. This may be caused by the same situation outlined above. NOTE: Will no-op if an update request has no fields to update. TODO(andrei): Consider exploring a batch API for OpenSearch for this operation. Args: update_requests: A list of update requests, each containing a list of document IDs and the fields to update. The field updates apply to all of the specified documents in each update request. Raises: Exception: Failed to update some or all of the chunks for the specified documents. """ logger.debug( f"[OpenSearchDocumentIndex] Processing {len(update_requests)} update requests for index {self._index_name}." ) for update_request in update_requests: properties_to_update: dict[str, Any] = dict() # TODO(andrei): Nit but consider if we can use DocumentChunk # here so we don't have to think about passing in the # appropriate types into this dict.
if update_request.access is not None: properties_to_update[ACCESS_CONTROL_LIST_FIELD_NAME] = ( generate_opensearch_filtered_access_control_list( update_request.access ) ) if update_request.document_sets is not None: properties_to_update[DOCUMENT_SETS_FIELD_NAME] = list( update_request.document_sets ) if update_request.boost is not None: properties_to_update[GLOBAL_BOOST_FIELD_NAME] = int( update_request.boost ) if update_request.hidden is not None: properties_to_update[HIDDEN_FIELD_NAME] = update_request.hidden if update_request.project_ids is not None: properties_to_update[USER_PROJECTS_FIELD_NAME] = list( update_request.project_ids ) if update_request.persona_ids is not None: properties_to_update[PERSONAS_FIELD_NAME] = list( update_request.persona_ids ) if not properties_to_update: if len(update_request.document_ids) > 1: update_string = f"{len(update_request.document_ids)} documents" else: update_string = f"document {update_request.document_ids[0]}" logger.warning( f"[OpenSearchDocumentIndex] Tried to update {update_string} " "with no specified update fields. This will be a no-op." ) continue for doc_id in update_request.document_ids: doc_chunk_count = update_request.doc_id_to_chunk_cnt.get(doc_id, -1) if doc_chunk_count < 0: # This means the chunk count is not known. This is due to a # race condition between doc indexing and updating steps # which run concurrently when a doc is indexed. The indexing # step should update chunk count shortly. This could also # have been due to an older version of the indexing pipeline # which did not compute chunk count, but that codepath has # since been deprecated, so this should no longer happen # here. # TODO(andrei): Fix the aforementioned race condition. raise ChunkCountNotFoundError( f"Tried to update document {doc_id} but its chunk count is not known. " "Older versions of the application permitted this, but it is not a " "supported state for a document when using OpenSearch.
The document was " "likely just added to the indexing pipeline and the chunk count will be " "updated shortly." ) if doc_chunk_count == 0: raise ValueError( f"Bug: Tried to update document {doc_id} but its chunk count was 0." ) for chunk_index in range(doc_chunk_count): document_chunk_id = get_opensearch_doc_chunk_id( tenant_state=self._tenant_state, document_id=doc_id, chunk_index=chunk_index, ) self._client.update_document( document_chunk_id=document_chunk_id, properties_to_update=properties_to_update, ) def id_based_retrieval( self, chunk_requests: list[DocumentSectionRequest], filters: IndexFilters, # TODO(andrei): Remove this from the new interface at some point; we # should not be exposing this. batch_retrieval: bool = False, # noqa: ARG002 # TODO(andrei): Add a param for whether to retrieve hidden docs. ) -> list[InferenceChunk]: """ TODO(andrei): Consider implementing this method to retrieve on document chunk IDs vs querying for matching document chunks. """ logger.debug( f"[OpenSearchDocumentIndex] Retrieving chunks for {len(chunk_requests)} section requests from index {self._index_name}." ) results: list[InferenceChunk] = [] for chunk_request in chunk_requests: search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = [] query_body = DocumentQuery.get_from_document_id_query( document_id=chunk_request.document_id, tenant_state=self._tenant_state, # NOTE: Index filters include metadata tags which were filtered # for invalid unicode at indexing time. In theory it would be # ideal to do filtering here as well; in practice we never did # that in the Vespa codepath and have not seen issues in # production, so we deliberately conform to the existing logic # in order to not unknowingly introduce a possible bug.
index_filters=filters, include_hidden=False, max_chunk_size=chunk_request.max_chunk_size, min_chunk_index=chunk_request.min_chunk_ind, max_chunk_index=chunk_request.max_chunk_ind, ) search_hits = self._client.search( body=query_body, search_pipeline_id=None, search_type=OpenSearchSearchType.DOC_ID_RETRIEVAL, ) inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [ _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned( search_hit.document_chunk, None, {} ) for search_hit in search_hits ] inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks( inference_chunks_uncleaned ) results.extend(inference_chunks) return results def hybrid_retrieval( self, query: str, query_embedding: Embedding, # TODO(andrei): This param is not great design, get rid of it. final_keywords: list[str] | None, query_type: QueryType, # noqa: ARG002 filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: # TODO(andrei): There is some duplicated logic in this function with # others in this file. logger.debug( f"[OpenSearchDocumentIndex] Hybrid retrieving {num_to_retrieve} chunks for index {self._index_name}." ) # TODO(andrei): This could be better; the caller should just make this # decision when passing in the query param. See the above comment in the # function signature. final_query = " ".join(final_keywords) if final_keywords else query query_body = DocumentQuery.get_hybrid_search_query( query_text=final_query, query_vector=query_embedding, num_hits=num_to_retrieve, tenant_state=self._tenant_state, # NOTE: Index filters include metadata tags which were filtered # for invalid unicode at indexing time. In theory it would be # ideal to do filtering here as well; in practice we never did # that in the Vespa codepath and have not seen issues in # production, so we deliberately conform to the existing logic # in order to not unknowingly introduce a possible bug.
index_filters=filters, include_hidden=False, ) normalization_pipeline_name, _ = get_normalization_pipeline_name_and_config() search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search( body=query_body, search_pipeline_id=normalization_pipeline_name, search_type=OpenSearchSearchType.HYBRID, ) # Good place for a breakpoint to inspect the search hits if you have # "explain" enabled. inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [ _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned( search_hit.document_chunk, search_hit.score, search_hit.match_highlights ) for search_hit in search_hits ] inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks( inference_chunks_uncleaned ) return inference_chunks def keyword_retrieval( self, query: str, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: # TODO(andrei): There is some duplicated logic in this function with # others in this file. logger.debug( f"[OpenSearchDocumentIndex] Keyword retrieving {num_to_retrieve} chunks for index {self._index_name}." ) query_body = DocumentQuery.get_keyword_search_query( query_text=query, num_hits=num_to_retrieve, tenant_state=self._tenant_state, # NOTE: Index filters include metadata tags which were filtered # for invalid unicode at indexing time. In theory it would be # ideal to do filtering here as well; in practice we never did # that in the Vespa codepath and have not seen issues in # production, so we deliberately conform to the existing logic # in order to not unknowingly introduce a possible bug.
index_filters=filters, include_hidden=False, ) search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search( body=query_body, search_pipeline_id=None, search_type=OpenSearchSearchType.KEYWORD, ) inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [ _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned( search_hit.document_chunk, search_hit.score, search_hit.match_highlights ) for search_hit in search_hits ] inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks( inference_chunks_uncleaned ) return inference_chunks def semantic_retrieval( self, query_embedding: Embedding, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: # TODO(andrei): There is some duplicated logic in this function with # others in this file. logger.debug( f"[OpenSearchDocumentIndex] Semantic retrieving {num_to_retrieve} chunks for index {self._index_name}." ) query_body = DocumentQuery.get_semantic_search_query( query_embedding=query_embedding, num_hits=num_to_retrieve, tenant_state=self._tenant_state, # NOTE: Index filters include metadata tags which were filtered # for invalid unicode at indexing time. In theory it would be # ideal to do filtering here as well; in practice we never did # that in the Vespa codepath and have not seen issues in # production, so we deliberately conform to the existing logic # in order to not unknowingly introduce a possible bug.
index_filters=filters, include_hidden=False, ) search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search( body=query_body, search_pipeline_id=None, search_type=OpenSearchSearchType.SEMANTIC, ) inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [ _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned( search_hit.document_chunk, search_hit.score, search_hit.match_highlights ) for search_hit in search_hits ] inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks( inference_chunks_uncleaned ) return inference_chunks def random_retrieval( self, filters: IndexFilters, num_to_retrieve: int = 10, dirty: bool | None = None, # noqa: ARG002 ) -> list[InferenceChunk]: logger.debug( f"[OpenSearchDocumentIndex] Randomly retrieving {num_to_retrieve} chunks for index {self._index_name}." ) query_body = DocumentQuery.get_random_search_query( tenant_state=self._tenant_state, index_filters=filters, num_to_retrieve=num_to_retrieve, ) search_hits: list[SearchHit[DocumentChunkWithoutVectors]] = self._client.search( body=query_body, search_pipeline_id=None, search_type=OpenSearchSearchType.RANDOM, ) inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [ _convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned( search_hit.document_chunk, search_hit.score, search_hit.match_highlights ) for search_hit in search_hits ] inference_chunks: list[InferenceChunk] = cleanup_content_for_chunks( inference_chunks_uncleaned ) return inference_chunks def index_raw_chunks(self, chunks: list[DocumentChunk]) -> None: """Indexes raw document chunks into OpenSearch. Used in the Vespa migration task. Can be deleted after migrations are complete. """ logger.debug( f"[OpenSearchDocumentIndex] Indexing {len(chunks)} raw chunks for index {self._index_name}." ) # Do not raise if the document already exists, just update. This is # because the document may already have been indexed during the # OpenSearch transition period. 
self._client.bulk_index_documents( documents=chunks, tenant_state=self._tenant_state, update_if_exists=True ) ================================================ FILE: backend/onyx/document_index/opensearch/schema.py ================================================ import hashlib from datetime import datetime from datetime import timezone from typing import Any from typing import Self from pydantic import BaseModel from pydantic import Field from pydantic import field_serializer from pydantic import field_validator from pydantic import model_serializer from pydantic import model_validator from pydantic import SerializerFunctionWrapHandler from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_REPLICAS from onyx.configs.app_configs import OPENSEARCH_INDEX_NUM_SHARDS from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER from onyx.configs.app_configs import USING_AWS_MANAGED_OPENSEARCH from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE from onyx.document_index.opensearch.constants import EF_CONSTRUCTION from onyx.document_index.opensearch.constants import EF_SEARCH from onyx.document_index.opensearch.constants import M from onyx.document_index.opensearch.string_filtering import DocumentIDTooLongError from onyx.document_index.opensearch.string_filtering import ( filter_and_validate_document_id, ) from onyx.document_index.opensearch.string_filtering import ( MAX_DOCUMENT_ID_ENCODED_LENGTH, ) from onyx.utils.tenant import get_tenant_id_short_string from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id TITLE_FIELD_NAME = "title" TITLE_VECTOR_FIELD_NAME = "title_vector" CONTENT_FIELD_NAME = "content" CONTENT_VECTOR_FIELD_NAME = "content_vector" SOURCE_TYPE_FIELD_NAME = "source_type" METADATA_LIST_FIELD_NAME = "metadata_list" LAST_UPDATED_FIELD_NAME = "last_updated" PUBLIC_FIELD_NAME = "public" ACCESS_CONTROL_LIST_FIELD_NAME = 
"access_control_list" HIDDEN_FIELD_NAME = "hidden" GLOBAL_BOOST_FIELD_NAME = "global_boost" SEMANTIC_IDENTIFIER_FIELD_NAME = "semantic_identifier" IMAGE_FILE_ID_FIELD_NAME = "image_file_id" SOURCE_LINKS_FIELD_NAME = "source_links" DOCUMENT_SETS_FIELD_NAME = "document_sets" USER_PROJECTS_FIELD_NAME = "user_projects" PERSONAS_FIELD_NAME = "personas" DOCUMENT_ID_FIELD_NAME = "document_id" CHUNK_INDEX_FIELD_NAME = "chunk_index" MAX_CHUNK_SIZE_FIELD_NAME = "max_chunk_size" TENANT_ID_FIELD_NAME = "tenant_id" BLURB_FIELD_NAME = "blurb" DOC_SUMMARY_FIELD_NAME = "doc_summary" CHUNK_CONTEXT_FIELD_NAME = "chunk_context" METADATA_SUFFIX_FIELD_NAME = "metadata_suffix" PRIMARY_OWNERS_FIELD_NAME = "primary_owners" SECONDARY_OWNERS_FIELD_NAME = "secondary_owners" # Hierarchy filtering - list of ancestor hierarchy node IDs ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids" # Faiss was also tried but it didn't have any benefits # NMSLIB is deprecated, not recommended OPENSEARCH_KNN_ENGINE = "lucene" def get_opensearch_doc_chunk_id( tenant_state: TenantState, document_id: str, chunk_index: int, max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE, ) -> str: """ Returns a unique identifier for the chunk. This will be the string used to identify the chunk in OpenSearch. Any direct chunk queries should use this function. If the document ID is too long, a hash of the ID is used instead. """ opensearch_doc_chunk_id_suffix: str = f"__{max_chunk_size}__{chunk_index}" encoded_suffix_length: int = len(opensearch_doc_chunk_id_suffix.encode("utf-8")) max_encoded_permissible_doc_id_length: int = ( MAX_DOCUMENT_ID_ENCODED_LENGTH - encoded_suffix_length ) opensearch_doc_chunk_id_tenant_prefix: str = "" if tenant_state.multitenant: short_tenant_id: str = get_tenant_id_short_string(tenant_state.tenant_id) # Use tenant ID because in multitenant mode each tenant has its own # Documents table, so there is a very small chance that doc IDs are not # actually unique across all tenants. 
opensearch_doc_chunk_id_tenant_prefix = f"{short_tenant_id}__" encoded_prefix_length: int = len( opensearch_doc_chunk_id_tenant_prefix.encode("utf-8") ) max_encoded_permissible_doc_id_length -= encoded_prefix_length try: sanitized_document_id: str = filter_and_validate_document_id( document_id, max_encoded_length=max_encoded_permissible_doc_id_length ) except DocumentIDTooLongError: # If the document ID is too long, use a hash instead. # We use blake2b because it is faster than SHA-256 with comparable # security, and it accepts digest_size, which controls the number of # bytes in the returned hash. # Since we hex-encode the hash bytes, the resulting string is twice # digest_size characters long, so digest_size should be half the max # target length of the hash string. # Subtract 1 because filter_and_validate_document_id rejects IDs whose # encoded length is >= max_encoded_length. # 64 is the max digest_size blake2b supports. digest_size: int = min((max_encoded_permissible_doc_id_length - 1) // 2, 64) sanitized_document_id = hashlib.blake2b( document_id.encode("utf-8"), digest_size=digest_size ).hexdigest() opensearch_doc_chunk_id: str = ( f"{opensearch_doc_chunk_id_tenant_prefix}{sanitized_document_id}{opensearch_doc_chunk_id_suffix}" ) # Do one more validation to ensure we haven't exceeded the max length. opensearch_doc_chunk_id = filter_and_validate_document_id(opensearch_doc_chunk_id) return opensearch_doc_chunk_id def set_or_convert_timezone_to_utc(value: datetime) -> datetime: if value.tzinfo is None: # Treat naive datetimes as UTC. Calling astimezone on a naive value # would instead assume it is in the system's local time. value = value.replace(tzinfo=timezone.utc) else: # Convert to UTC if value was set in a different timezone. value = value.astimezone(timezone.utc) return value class DocumentChunkWithoutVectors(BaseModel): """ Represents a chunk of a document in the OpenSearch index without vectors. The names of these fields are based on the OpenSearch schema. Changes to the schema require changes here.
See get_document_schema. WARNING: Relies on MULTI_TENANT which is global state. Also uses get_current_tenant_id. Relying on global state is generally bad; in this case we accept it because of the importance of validating tenant logic. """ model_config = {"frozen": True} document_id: str chunk_index: int # The maximum number of tokens this chunk's content can hold. Previously # there was a concept of large chunks; this generalizes that concept. The # index can hold chunks of any size, and chunks of different sizes are # stored as distinct entries. max_chunk_size: int = DEFAULT_MAX_CHUNK_SIZE # title and title_vector (see DocumentChunk) should either both be None # or both be non-None. title: str | None = None content: str source_type: str # A list of key-value pairs separated by INDEX_SEPARATOR. See # convert_metadata_dict_to_list_of_strings. metadata_list: list[str] | None = None # If it exists, time zone should always be UTC. last_updated: datetime | None = None public: bool access_control_list: list[str] # Defaults to False; currently written during updates, not at index # time. hidden: bool = False global_boost: int semantic_identifier: str image_file_id: str | None = None # Contains a string representation of a dict which maps offset into the raw # chunk text to the link corresponding to that point. source_links: str | None = None blurb: str # doc_summary, chunk_context, and metadata_suffix are all stored simply to # reverse the augmentations to content. Ideally these would just be start # and stop indices into the content string. For legacy reasons they are not # right now. doc_summary: str chunk_context: str metadata_suffix: str | None = None document_sets: list[str] | None = None user_projects: list[int] | None = None personas: list[int] | None = None primary_owners: list[str] | None = None secondary_owners: list[str] | None = None # List of ancestor hierarchy node IDs for hierarchy-based filtering. # None means no hierarchy info (document will be excluded from # hierarchy-filtered searches).
ancestor_hierarchy_node_ids: list[int] | None = None tenant_id: TenantState = Field( default_factory=lambda: TenantState( tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT ) ) def __str__(self) -> str: return ( f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, " f"content length={len(self.content)}, tenant_id={self.tenant_id.tenant_id})." ) @model_serializer(mode="wrap") def serialize_model( self, handler: SerializerFunctionWrapHandler ) -> dict[str, object]: """Invokes pydantic's serialization logic, then excludes Nones. We do this because .model_dump(exclude_none=True) does not work after @field_serializer logic, so for some field serializers which return None and which we would like to exclude from the final dump, they would be included without this. Args: handler: Callable from pydantic which takes the instance of the model as an argument and performs standard serialization. Returns: The return of handler but with None items excluded. """ serialized: dict[str, object] = handler(self) serialized_exclude_none = {k: v for k, v in serialized.items() if v is not None} return serialized_exclude_none @field_serializer("last_updated", mode="wrap") def serialize_datetime_fields_to_epoch_seconds( self, value: datetime | None, handler: SerializerFunctionWrapHandler, # noqa: ARG002 ) -> int | None: """ Serializes datetime fields to seconds since the Unix epoch. If there is no datetime, returns None. """ if value is None: return None value = set_or_convert_timezone_to_utc(value) return int(value.timestamp()) @field_validator("last_updated", mode="before") @classmethod def parse_epoch_seconds_to_datetime(cls, value: Any) -> datetime | None: """Parses seconds since the Unix epoch to a datetime object. If the input is None, returns None. The datetime returned will be in UTC. 
""" if value is None: return None if isinstance(value, datetime): value = set_or_convert_timezone_to_utc(value) return value if not isinstance(value, int): raise ValueError( f"Bug: Expected an int for the last_updated property from OpenSearch, got {type(value)} instead." ) return datetime.fromtimestamp(value, tz=timezone.utc) @field_serializer("tenant_id", mode="wrap") def serialize_tenant_state( self, value: TenantState, handler: SerializerFunctionWrapHandler, # noqa: ARG002 ) -> str | None: """ Serializes tenant_state to the tenant str if multitenant, or None if not. The idea is that in single tenant mode, the schema does not have a tenant_id field, so we don't want to supply it in our serialized DocumentChunk. This assumes the final serialized model excludes None fields, which serialize_model should enforce. """ if not value.multitenant: return None else: return value.tenant_id @field_validator("tenant_id", mode="before") @classmethod def parse_tenant_id(cls, value: Any) -> TenantState: """ Generates a TenantState from OpenSearch's tenant_id if it exists, or generates a default state if it does not (implies we are in single tenant mode). """ if value is None: if MULTI_TENANT: raise ValueError( "Bug: No tenant_id was supplied but multi-tenant mode is enabled." ) return TenantState( tenant_id=get_current_tenant_id(), multitenant=MULTI_TENANT ) elif isinstance(value, TenantState): if MULTI_TENANT != value.multitenant: raise ValueError( f"Bug: An existing TenantState object was supplied to the DocumentChunk model " f"but its multi-tenant mode ({value.multitenant}) does not match the program's " "current global tenancy state." ) return value elif not isinstance(value, str): raise ValueError( f"Bug: Expected a str for the tenant_id property from OpenSearch, got {type(value)} instead." ) else: if not MULTI_TENANT: raise ValueError( "Bug: Got a non-null str for the tenant_id property from OpenSearch but " "multi-tenant mode is not enabled. 
This is unexpected because in single-tenant " "mode we don't expect to see a tenant_id." ) return TenantState(tenant_id=value, multitenant=MULTI_TENANT) class DocumentChunk(DocumentChunkWithoutVectors): """Represents a chunk of a document in the OpenSearch index. The names of these fields are based on the OpenSearch schema. Changes to the schema require changes here. See get_document_schema. """ model_config = {"frozen": True} title_vector: list[float] | None = None content_vector: list[float] def __str__(self) -> str: return ( f"DocumentChunk(document_id={self.document_id}, chunk_index={self.chunk_index}, " f"content length={len(self.content)}, content vector length={len(self.content_vector)}, " f"tenant_id={self.tenant_id.tenant_id})" ) @model_validator(mode="after") def check_title_and_title_vector_are_consistent(self) -> Self: # title and title_vector should both either be None or not. if self.title is not None and self.title_vector is None: raise ValueError("Bug: Title vector must not be None if title is not None.") if self.title_vector is not None and self.title is None: raise ValueError("Bug: Title must not be None if title vector is not None.") return self class DocumentSchema: """ Represents the schema and indexing strategies of the OpenSearch index. TODO(andrei): Implement multi-phase indexing strategies. """ @staticmethod def get_document_schema(vector_dimension: int, multitenant: bool) -> dict[str, Any]: """Returns the document schema for the OpenSearch index. WARNING: Changes / additions to field names here require changes to the DocumentChunk class above. Notes: - By default all fields have indexing enabled. - By default almost all fields except text fields have doc_values enabled, enabling operations like sorting and aggregations. - By default all fields are nullable. - "type": "keyword" fields are stored as-is, used for exact matches, filtering, etc. - "type": "text" fields are OpenSearch-processed strings, used for full-text searches. 
- "index": True fields can be queried on. - "doc_values": True fields can be sorted and aggregated efficiently. Not supported for "text" type fields. - "store": True fields are stored separately from the source document and can thus be returned from a query separately from _source. Generally this is not necessary. Args: vector_dimension: The dimension of vector embeddings. Must be a positive integer. multitenant: Whether the index is multitenant. Returns: A dictionary representing the document schema, to be supplied to the OpenSearch client. The structure of this dictionary follows the OpenSearch mapping format. """ schema: dict[str, Any] = { # By default OpenSearch allows dynamically adding new properties # based on indexed documents. This is awful and we disable it here. # An exception will be raised if you try to index a new doc which # contains unexpected fields. "dynamic": "strict", "properties": { TITLE_FIELD_NAME: { "type": "text", # Language analyzer (e.g. english) stems at index and search # time for variant matching. Configure via # OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing # after a change. "analyzer": OPENSEARCH_TEXT_ANALYZER, "fields": { # Subfield accessed as title.keyword. Not indexed for # values longer than 256 chars. # TODO(andrei): Ask Yuhong whether we want this. "keyword": {"type": "keyword", "ignore_above": 256} }, # This makes highlighting text during queries more efficient # at the cost of disk space.
See # https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets "index_options": "offsets", }, CONTENT_FIELD_NAME: { "type": "text", "store": True, "analyzer": OPENSEARCH_TEXT_ANALYZER, "index_options": "offsets", }, TITLE_VECTOR_FIELD_NAME: { "type": "knn_vector", "dimension": vector_dimension, "method": { "name": "hnsw", "space_type": "cosinesimil", "engine": OPENSEARCH_KNN_ENGINE, "parameters": {"ef_construction": EF_CONSTRUCTION, "m": M}, }, }, # TODO(andrei): This is a tensor in Vespa. Also look at feature # parity for these other method fields. CONTENT_VECTOR_FIELD_NAME: { "type": "knn_vector", "dimension": vector_dimension, "method": { "name": "hnsw", "space_type": "cosinesimil", "engine": OPENSEARCH_KNN_ENGINE, "parameters": {"ef_construction": EF_CONSTRUCTION, "m": M}, }, }, SOURCE_TYPE_FIELD_NAME: {"type": "keyword"}, METADATA_LIST_FIELD_NAME: {"type": "keyword"}, LAST_UPDATED_FIELD_NAME: { "type": "date", "format": "epoch_second", # Set explicitly so results can be sorted and aggregated by # date. "doc_values": True, }, # Access control fields. # Whether the doc is public. Could have fallen under access # control list but is such a broad and critical filter that it # is its own field. If true, ACCESS_CONTROL_LIST_FIELD_NAME # should have no effect on queries. PUBLIC_FIELD_NAME: {"type": "boolean"}, # Access control list for the doc, excluding public access, # which is covered above. # If a user's access set contains at least one entry from this # set, the user should be able to retrieve this document. This # only applies if public is set to false; public non-hidden # documents are always visible to anyone in a given tenancy # regardless of this field. ACCESS_CONTROL_LIST_FIELD_NAME: {"type": "keyword"}, # Whether the doc is hidden from search results.
# Should override all other access filters, namely # PUBLIC_FIELD_NAME and ACCESS_CONTROL_LIST_FIELD_NAME; it is up # to search implementations to guarantee this. HIDDEN_FIELD_NAME: {"type": "boolean"}, GLOBAL_BOOST_FIELD_NAME: {"type": "integer"}, # This field is only used for displaying a useful name for the # doc in the UI and is not used for searching. Indexing and # doc_values are disabled to improve performance, so this field is # essentially just metadata. SEMANTIC_IDENTIFIER_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, # Generally False by default; just making sure. "store": False, }, # Same as above; used to display an image along with the doc. IMAGE_FILE_ID_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, # Generally False by default; just making sure. "store": False, }, # Same as above; used to link to the source doc. SOURCE_LINKS_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, # Generally False by default; just making sure. "store": False, }, # Same as above; used to quickly summarize the doc in the UI. BLURB_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, # Generally False by default; just making sure. "store": False, }, # Same as above. # TODO(andrei): If we want to search on this, this field needs to be # changed. DOC_SUMMARY_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, # Generally False by default; just making sure. "store": False, }, # Same as above. # TODO(andrei): If we want to search on this, this field needs to be # changed. CHUNK_CONTEXT_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, # Generally False by default; just making sure. "store": False, }, # Same as above. METADATA_SUFFIX_FIELD_NAME: { "type": "keyword", "index": False, "doc_values": False, "store": False, }, # Product-specific fields.
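The visibility rules described in the access-control comments (public documents are visible to everyone in the tenancy, private ones need at least one overlapping ACL entry, and hidden overrides both unless hidden documents are explicitly included) can be sketched as an in-memory predicate plus a matching OpenSearch filter clause. This is an illustration under assumptions, not the actual `_get_search_filters` logic: `is_visible`, `access_filter_clauses`, and the ACL entry strings used below are hypothetical, while the field names match the values of the `*_FIELD_NAME` constants.

```python
from typing import Any


def is_visible(doc: dict[str, Any], user_acl: set[str], include_hidden: bool) -> bool:
    # Hidden overrides everything else unless hidden docs are requested.
    if doc.get("hidden", False) and not include_hidden:
        return False
    # Public docs are visible to everyone in the tenancy; otherwise at least
    # one ACL entry must overlap with the user's access set.
    return bool(doc["public"]) or bool(user_acl & set(doc["access_control_list"]))


def access_filter_clauses(
    user_acl: list[str], include_hidden: bool
) -> list[dict[str, Any]]:
    # Clauses meant to be ANDed together inside a bool "filter" list.
    clauses: list[dict[str, Any]] = [
        {
            "bool": {
                "should": [
                    {"term": {"public": True}},
                    {"terms": {"access_control_list": user_acl}},
                ],
                "minimum_should_match": 1,
            }
        }
    ]
    if not include_hidden:
        clauses.append({"bool": {"must_not": [{"term": {"hidden": True}}]}})
    return clauses
```

Expressing "hidden" as a `must_not` clause keeps it independent of the public/ACL disjunction, which is one way to guarantee it overrides both.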
DOCUMENT_SETS_FIELD_NAME: {"type": "keyword"}, USER_PROJECTS_FIELD_NAME: {"type": "integer"}, PERSONAS_FIELD_NAME: {"type": "integer"}, PRIMARY_OWNERS_FIELD_NAME: {"type": "keyword"}, SECONDARY_OWNERS_FIELD_NAME: {"type": "keyword"}, # OpenSearch metadata fields. DOCUMENT_ID_FIELD_NAME: {"type": "keyword"}, CHUNK_INDEX_FIELD_NAME: {"type": "integer"}, # The maximum number of tokens this chunk's content can hold. MAX_CHUNK_SIZE_FIELD_NAME: {"type": "integer"}, # Hierarchy filtering - list of ancestor hierarchy node IDs. # Used for scoped search within folder/space hierarchies. # OpenSearch's terms query with value_type: "bitmap" can # efficiently check if any value in this array matches a # query bitmap. ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: {"type": "integer"}, }, } if multitenant: schema["properties"][TENANT_ID_FIELD_NAME] = {"type": "keyword"} return schema @staticmethod def get_index_settings_based_on_environment() -> dict[str, Any]: """ Returns the index settings based on the environment. """ if USING_AWS_MANAGED_OPENSEARCH: # NOTE: The number of data copies, including the primary (not a # replica) copy, must be divisible by the number of AZs. if MULTI_TENANT: number_of_shards = 324 number_of_replicas = 2 else: number_of_shards = 3 number_of_replicas = 2 else: number_of_shards = 1 number_of_replicas = 1 if OPENSEARCH_INDEX_NUM_SHARDS is not None: number_of_shards = OPENSEARCH_INDEX_NUM_SHARDS if OPENSEARCH_INDEX_NUM_REPLICAS is not None: number_of_replicas = OPENSEARCH_INDEX_NUM_REPLICAS return { "index": { "number_of_shards": number_of_shards, "number_of_replicas": number_of_replicas, # Required for vector search. 
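The AWS-managed note concerns copies per shard: each shard has one primary plus `number_of_replicas` replicas, and that count should divide evenly across availability zones. Below is a minimal check, assuming three AZs, together with the create-index body shape that pairs these settings with the mapping from `get_document_schema`. The helper names are hypothetical.

```python
from typing import Any


def copies_per_shard(number_of_replicas: int) -> int:
    # One primary copy plus its replicas.
    return 1 + number_of_replicas


def check_az_balance(number_of_replicas: int, num_azs: int) -> bool:
    # True if every AZ can hold the same number of copies of each shard.
    return copies_per_shard(number_of_replicas) % num_azs == 0


def build_create_index_body(
    settings: dict[str, Any], mappings: dict[str, Any]
) -> dict[str, Any]:
    # Top-level shape of an OpenSearch create-index request body.
    return {"settings": settings, "mappings": mappings}
```

With `number_of_replicas = 2` there are 3 copies of each shard, which balances across 3 AZs; with 1 replica it would not. The resulting body could be passed to opensearch-py's `client.indices.create(index=..., body=...)`; how Onyx actually creates its index is not shown in this section.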
"knn": True, "knn.algo_param.ef_search": EF_SEARCH, } } ================================================ FILE: backend/onyx/document_index/opensearch/search.py ================================================ import random from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from typing import TypeAlias from typing import TypeVar from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED from onyx.configs.app_configs import OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED from onyx.configs.constants import DocumentSource from onyx.configs.constants import INDEX_SEPARATOR from onyx.context.search.models import IndexFilters from onyx.context.search.models import Tag from onyx.document_index.interfaces_new import TenantState from onyx.document_index.opensearch.constants import ASSUMED_DOCUMENT_AGE_DAYS from onyx.document_index.opensearch.constants import ( DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES, ) from onyx.document_index.opensearch.constants import ( DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW, ) from onyx.document_index.opensearch.constants import ( HYBRID_SEARCH_NORMALIZATION_PIPELINE, ) from onyx.document_index.opensearch.constants import ( HYBRID_SEARCH_SUBQUERY_CONFIGURATION, ) from onyx.document_index.opensearch.constants import HybridSearchNormalizationPipeline from onyx.document_index.opensearch.constants import HybridSearchSubqueryConfiguration from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME from onyx.document_index.opensearch.schema import CHUNK_INDEX_FIELD_NAME from onyx.document_index.opensearch.schema import CONTENT_FIELD_NAME from onyx.document_index.opensearch.schema import CONTENT_VECTOR_FIELD_NAME from onyx.document_index.opensearch.schema import 
DOCUMENT_ID_FIELD_NAME from onyx.document_index.opensearch.schema import DOCUMENT_SETS_FIELD_NAME from onyx.document_index.opensearch.schema import HIDDEN_FIELD_NAME from onyx.document_index.opensearch.schema import LAST_UPDATED_FIELD_NAME from onyx.document_index.opensearch.schema import MAX_CHUNK_SIZE_FIELD_NAME from onyx.document_index.opensearch.schema import METADATA_LIST_FIELD_NAME from onyx.document_index.opensearch.schema import PERSONAS_FIELD_NAME from onyx.document_index.opensearch.schema import PUBLIC_FIELD_NAME from onyx.document_index.opensearch.schema import set_or_convert_timezone_to_utc from onyx.document_index.opensearch.schema import SOURCE_TYPE_FIELD_NAME from onyx.document_index.opensearch.schema import TENANT_ID_FIELD_NAME from onyx.document_index.opensearch.schema import TITLE_FIELD_NAME from onyx.document_index.opensearch.schema import TITLE_VECTOR_FIELD_NAME from onyx.document_index.opensearch.schema import USER_PROJECTS_FIELD_NAME # See https://docs.opensearch.org/latest/query-dsl/term/terms/. MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY = 65_536 _T = TypeVar("_T") TermsQuery: TypeAlias = dict[str, dict[str, list[_T]]] TermQuery: TypeAlias = dict[str, dict[str, dict[str, _T]]] # TODO(andrei): Turn all magic dictionaries to pydantic models. # Normalization pipelines combine document scores from multiple query clauses. # The number and ordering of weights should match the query clauses. The values # of the weights should sum to 1. def _get_hybrid_search_normalization_weights() -> list[float]: if ( HYBRID_SEARCH_SUBQUERY_CONFIGURATION is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD ): # Since the titles are included in the contents, the embedding matches # are heavily downweighted as they act as a boost rather than an # independent scoring component. 
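`MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY` mirrors OpenSearch's default per-query limit of 65,536 values in a `terms` query (the `index.max_terms_count` setting). If a filter ever needed more values than that, one option is to split them across several `terms` clauses and OR them together. This is a hypothetical helper (`batched_terms_filter` is not part of this file), sketched under the assumption that the per-query term count is the only constraint; very large filters may still be slow.

```python
from typing import Any

# Same value as MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY.
MAX_TERMS = 65_536


def batched_terms_filter(
    field: str, values: list[Any], max_terms: int = MAX_TERMS
) -> dict[str, Any]:
    """Matches documents whose `field` contains any of `values`."""
    if not values:
        raise ValueError("values must be non-empty")
    batches = [values[i : i + max_terms] for i in range(0, len(values), max_terms)]
    if len(batches) == 1:
        return {"terms": {field: batches[0]}}
    # OR the batches together: a match in any one batch is enough.
    return {
        "bool": {
            "should": [{"terms": {field: batch}} for batch in batches],
            "minimum_should_match": 1,
        }
    }
```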
search_title_vector_weight = 0.1 search_content_vector_weight = 0.45 # Single keyword weight for both title and content (merged from former # title keyword + content keyword). search_keyword_weight = 0.45 # NOTE: It is critical that the order of these weights matches the order # of the sub-queries in the hybrid search. hybrid_search_normalization_weights = [ search_title_vector_weight, search_content_vector_weight, search_keyword_weight, ] elif ( HYBRID_SEARCH_SUBQUERY_CONFIGURATION is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD ): search_content_vector_weight = 0.5 # Single keyword weight for both title and content (merged from former # title keyword + content keyword). search_keyword_weight = 0.5 # NOTE: It is critical that the order of these weights matches the order # of the sub-queries in the hybrid search. hybrid_search_normalization_weights = [ search_content_vector_weight, search_keyword_weight, ] else: raise ValueError( f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}." ) assert ( sum(hybrid_search_normalization_weights) == 1.0 ), "Bug: Hybrid search normalization weights do not sum to 1.0." 
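To see what these weights do, here is a minimal sketch of min-max normalization followed by a weighted arithmetic mean across subqueries, using the three-subquery weights above. It approximates OpenSearch's `normalization-processor` rather than reproducing it: how documents missing from a subquery and zero-range result sets are scored are assumptions noted in the comments. It also checks the weight sum with `math.isclose` instead of exact float equality.

```python
import math

# Weights in subquery order: title vector, content vector, combined keyword.
WEIGHTS = [0.1, 0.45, 0.45]


def min_max_normalize(scores: dict[str, float]) -> dict[str, float]:
    """Rescales one subquery's raw scores into [0, 1]."""
    if not scores:
        return {}
    lo, hi = min(scores.values()), max(scores.values())
    if hi == lo:
        # Assumption: a zero-range result set maps every hit to 1.0.
        return {doc_id: 1.0 for doc_id in scores}
    return {doc_id: (s - lo) / (hi - lo) for doc_id, s in scores.items()}


def combine_scores(
    subquery_scores: list[dict[str, float]], weights: list[float]
) -> dict[str, float]:
    """Weighted arithmetic mean of min-max-normalized subquery scores."""
    if len(subquery_scores) != len(weights):
        raise ValueError("Need exactly one weight per subquery.")
    if not math.isclose(sum(weights), 1.0):
        raise ValueError("Weights must sum to 1.")
    normalized = [min_max_normalize(s) for s in subquery_scores]
    doc_ids = set().union(*normalized)
    # Assumption: a document absent from a subquery scores 0 for it. Since
    # the weights sum to 1, the weighted sum equals the weighted mean.
    return {
        doc_id: sum(w * n.get(doc_id, 0.0) for w, n in zip(weights, normalized))
        for doc_id in doc_ids
    }
```

Normalizing first matters because raw BM25 scores (e.g. 12.0) and cosine similarities (below 1.0) live on different scales; without it the keyword subquery would dominate regardless of the weights.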
return hybrid_search_normalization_weights def get_min_max_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]: min_max_normalization_pipeline_name = "normalization_pipeline_min_max" min_max_normalization_pipeline_config: dict[str, Any] = { "description": "Normalization for keyword and vector scores using min-max", "phase_results_processors": [ { # https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/ "normalization-processor": { "normalization": {"technique": "min_max"}, "combination": { "technique": "arithmetic_mean", "parameters": { "weights": _get_hybrid_search_normalization_weights() }, }, } } ], } return min_max_normalization_pipeline_name, min_max_normalization_pipeline_config def get_zscore_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]: zscore_normalization_pipeline_name = "normalization_pipeline_zscore" zscore_normalization_pipeline_config: dict[str, Any] = { "description": "Normalization for keyword and vector scores using z-score", "phase_results_processors": [ { # https://docs.opensearch.org/latest/search-plugins/search-pipelines/normalization-processor/ "normalization-processor": { "normalization": {"technique": "z_score"}, "combination": { "technique": "arithmetic_mean", "parameters": { "weights": _get_hybrid_search_normalization_weights() }, }, } } ], } return zscore_normalization_pipeline_name, zscore_normalization_pipeline_config def get_normalization_pipeline_name_and_config() -> tuple[str, dict[str, Any]]: if ( HYBRID_SEARCH_NORMALIZATION_PIPELINE is HybridSearchNormalizationPipeline.MIN_MAX ): return get_min_max_normalization_pipeline_name_and_config() elif ( HYBRID_SEARCH_NORMALIZATION_PIPELINE is HybridSearchNormalizationPipeline.ZSCORE ): return get_zscore_normalization_pipeline_name_and_config() else: raise ValueError( f"Bug: Unhandled hybrid search normalization pipeline: {HYBRID_SEARCH_NORMALIZATION_PIPELINE}." 
) class DocumentQuery: """ TODO(andrei): Implement multi-phase search strategies. TODO(andrei): Implement document boost. TODO(andrei): Implement document age. """ @staticmethod def get_from_document_id_query( document_id: str, tenant_state: TenantState, index_filters: IndexFilters, include_hidden: bool, max_chunk_size: int, min_chunk_index: int | None, max_chunk_index: int | None, get_full_document: bool = True, ) -> dict[str, Any]: """ Returns a final search query which gets chunks from a given document ID. This query can be directly supplied to the OpenSearch client. TODO(andrei): Currently capped at 10k results. Implement scroll/point-in-time search so that we can return arbitrarily many IDs. Args: document_id: Onyx document ID. Notably not an OpenSearch document ID, which points to what Onyx would refer to as a chunk. tenant_state: Tenant state containing the tenant ID. index_filters: Filters for the document retrieval query. include_hidden: Whether to include hidden documents. max_chunk_size: Document chunks are categorized by the maximum number of tokens they can hold. This parameter specifies the maximum size category of document chunks to retrieve. min_chunk_index: The minimum chunk index to retrieve, inclusive. If None, no minimum chunk index will be applied. max_chunk_index: The maximum chunk index to retrieve, inclusive. If None, no maximum chunk index will be applied. get_full_document: Whether to get the full document body. If False, OpenSearch will only return the matching document chunk IDs plus metadata; the source data will be omitted from the response. Use this for performance optimization if OpenSearch IDs are sufficient. Defaults to True. Returns: A dictionary representing the final ID search query.
""" filter_clauses = DocumentQuery._get_search_filters( tenant_state=tenant_state, include_hidden=include_hidden, access_control_list=index_filters.access_control_list, source_types=index_filters.source_type or [], tags=index_filters.tags or [], document_sets=index_filters.document_set or [], project_id_filter=index_filters.project_id_filter, persona_id_filter=index_filters.persona_id_filter, time_cutoff=index_filters.time_cutoff, min_chunk_index=min_chunk_index, max_chunk_index=max_chunk_index, max_chunk_size=max_chunk_size, document_id=document_id, attached_document_ids=index_filters.attached_document_ids, hierarchy_node_ids=index_filters.hierarchy_node_ids, ) final_get_ids_query: dict[str, Any] = { "query": {"bool": {"filter": filter_clauses}}, # We include this to make sure OpenSearch does not revert to # returning some number of results less than the index max allowed # return size. "size": DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW, # By default exclude retrieving the vector fields in order to save # on retrieval cost as we don't need them upstream. "_source": { "excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME] }, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s", } if not get_full_document: # If we explicitly do not want the underlying document, we will only # retrieve IDs. final_get_ids_query["_source"] = False if not OPENSEARCH_PROFILING_DISABLED: final_get_ids_query["profile"] = True return final_get_ids_query @staticmethod def delete_from_document_id_query( document_id: str, tenant_state: TenantState, ) -> dict[str, Any]: """ Returns a final search query which deletes chunks from a given document ID. This query can be directly supplied to the OpenSearch client. Intended to be supplied to the OpenSearch client's delete_by_query method. TODO(andrei): There is no limit to the number of document chunks that can be deleted by this query. This could get expensive. Consider implementing batching. Args: document_id: Onyx document ID. 
Notably not an OpenSearch document ID, which points to what Onyx would refer to as a chunk. tenant_state: Tenant state containing the tenant ID. Returns: A dictionary representing the final delete query. """ filter_clauses = DocumentQuery._get_search_filters( tenant_state=tenant_state, # Delete hidden docs too. include_hidden=True, access_control_list=None, source_types=[], tags=[], document_sets=[], project_id_filter=None, persona_id_filter=None, time_cutoff=None, min_chunk_index=None, max_chunk_index=None, max_chunk_size=None, document_id=document_id, ) final_delete_query: dict[str, Any] = { "query": {"bool": {"filter": filter_clauses}}, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s", } if not OPENSEARCH_PROFILING_DISABLED: final_delete_query["profile"] = True return final_delete_query @staticmethod def get_hybrid_search_query( query_text: str, query_vector: list[float], num_hits: int, tenant_state: TenantState, index_filters: IndexFilters, include_hidden: bool, ) -> dict[str, Any]: """Returns a final hybrid search query. NOTE: This query can be directly supplied to the OpenSearch client, but it MUST be supplied in addition to a search pipeline. The results from hybrid search are not meaningful without that step. TODO(andrei): There is some duplicated logic in this function with others in this file. Args: query_text: The text to query for. query_vector: The vector embedding of the text to query for. num_hits: The final number of hits to return. tenant_state: Tenant state containing the tenant ID. index_filters: Filters for the hybrid search query. include_hidden: Whether to include hidden documents. Returns: A dictionary representing the final hybrid search query. """ # WARNING: Profiling does not work with hybrid search; do not add it at # this level. 
See https://github.com/opensearch-project/neural-search/issues/1255 if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW: raise ValueError( f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed " f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})." ) # TODO(andrei, yuhong): We can tune this more dynamically based on # num_hits. max_results_per_subquery = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries( query_text, query_vector, vector_candidates=max_results_per_subquery ) hybrid_search_filters = DocumentQuery._get_search_filters( tenant_state=tenant_state, include_hidden=include_hidden, # TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to # now. This should not cause any issues but it can introduce # redundant filters in queries that may affect performance. access_control_list=index_filters.access_control_list, source_types=index_filters.source_type or [], tags=index_filters.tags or [], document_sets=index_filters.document_set or [], project_id_filter=index_filters.project_id_filter, persona_id_filter=index_filters.persona_id_filter, time_cutoff=index_filters.time_cutoff, min_chunk_index=None, max_chunk_index=None, attached_document_ids=index_filters.attached_document_ids, hierarchy_node_ids=index_filters.hierarchy_node_ids, ) # See https://docs.opensearch.org/latest/query-dsl/compound/hybrid/ hybrid_search_query: dict[str, Any] = { "hybrid": { "queries": hybrid_search_subqueries, # Max results per subquery per shard before aggregation. Ensures # keyword and vector subqueries contribute equally to the # candidate pool for hybrid fusion. 
# Sources: # https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/ # https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/ "pagination_depth": max_results_per_subquery, # Applied to all the sub-queries independently (this avoids # subqueries having a lot of results thrown out during # aggregation). # Sources: # https://docs.opensearch.org/latest/query-dsl/compound/hybrid/ # https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries # Does AND for each filter in the list. "filter": {"bool": {"filter": hybrid_search_filters}}, } } final_hybrid_search_body: dict[str, Any] = { "query": hybrid_search_query, "size": num_hits, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s", # Exclude retrieving the vector fields in order to save on # retrieval cost as we don't need them upstream. "_source": { "excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME] }, } if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED: final_hybrid_search_body["highlight"] = ( DocumentQuery._get_match_highlights_configuration() ) # Explain is for scoring breakdowns. Setting this significantly # increases query latency. if OPENSEARCH_EXPLAIN_ENABLED: final_hybrid_search_body["explain"] = True return final_hybrid_search_body @staticmethod def get_keyword_search_query( query_text: str, num_hits: int, tenant_state: TenantState, index_filters: IndexFilters, include_hidden: bool, ) -> dict[str, Any]: """Returns a final keyword search query. This query can be directly supplied to the OpenSearch client. TODO(andrei): There is some duplicated logic in this function with others in this file. Args: query_text: The text to query for. num_hits: The final number of hits to return. tenant_state: Tenant state containing the tenant ID. index_filters: Filters for the keyword search query. include_hidden: Whether to include hidden documents. Returns: A dictionary representing the final keyword search query. 
""" if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW: raise ValueError( f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed " f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})." ) keyword_search_filters = DocumentQuery._get_search_filters( tenant_state=tenant_state, include_hidden=include_hidden, # TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to # now. This should not cause any issues but it can introduce # redundant filters in queries that may affect performance. access_control_list=index_filters.access_control_list, source_types=index_filters.source_type or [], tags=index_filters.tags or [], document_sets=index_filters.document_set or [], project_id_filter=index_filters.project_id_filter, persona_id_filter=index_filters.persona_id_filter, time_cutoff=index_filters.time_cutoff, min_chunk_index=None, max_chunk_index=None, attached_document_ids=index_filters.attached_document_ids, hierarchy_node_ids=index_filters.hierarchy_node_ids, ) keyword_search_query = ( DocumentQuery._get_title_content_combined_keyword_search_query( query_text, search_filters=keyword_search_filters ) ) final_keyword_search_query: dict[str, Any] = { "query": keyword_search_query, "size": num_hits, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s", # Exclude retrieving the vector fields in order to save on # retrieval cost as we don't need them upstream. "_source": { "excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME] }, } if not OPENSEARCH_MATCH_HIGHLIGHTS_DISABLED: final_keyword_search_query["highlight"] = ( DocumentQuery._get_match_highlights_configuration() ) if not OPENSEARCH_PROFILING_DISABLED: final_keyword_search_query["profile"] = True # Explain is for scoring breakdowns. Setting this significantly # increases query latency. 
if OPENSEARCH_EXPLAIN_ENABLED: final_keyword_search_query["explain"] = True return final_keyword_search_query @staticmethod def get_semantic_search_query( query_embedding: list[float], num_hits: int, tenant_state: TenantState, index_filters: IndexFilters, include_hidden: bool, ) -> dict[str, Any]: """Returns a final semantic search query. This query can be directly supplied to the OpenSearch client. TODO(andrei): There is some duplicated logic in this function with others in this file. Args: query_embedding: The vector embedding of the text to query for. num_hits: The final number of hits to return. tenant_state: Tenant state containing the tenant ID. index_filters: Filters for the semantic search query. include_hidden: Whether to include hidden documents. Returns: A dictionary representing the final semantic search query. """ if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW: raise ValueError( f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed " f"result window ({DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW})." ) semantic_search_filters = DocumentQuery._get_search_filters( tenant_state=tenant_state, include_hidden=include_hidden, # TODO(andrei): We've done no filtering for PUBLIC_DOC_PAT up to # now. This should not cause any issues but it can introduce # redundant filters in queries that may affect performance. 
access_control_list=index_filters.access_control_list, source_types=index_filters.source_type or [], tags=index_filters.tags or [], document_sets=index_filters.document_set or [], project_id_filter=index_filters.project_id_filter, persona_id_filter=index_filters.persona_id_filter, time_cutoff=index_filters.time_cutoff, min_chunk_index=None, max_chunk_index=None, attached_document_ids=index_filters.attached_document_ids, hierarchy_node_ids=index_filters.hierarchy_node_ids, ) semantic_search_query = ( DocumentQuery._get_content_vector_similarity_search_query( query_embedding, vector_candidates=num_hits, search_filters=semantic_search_filters, ) ) final_semantic_search_query: dict[str, Any] = { "query": semantic_search_query, "size": num_hits, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s", # Exclude retrieving the vector fields in order to save on # retrieval cost as we don't need them upstream. "_source": { "excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME] }, } if not OPENSEARCH_PROFILING_DISABLED: final_semantic_search_query["profile"] = True # Explain is for scoring breakdowns. Setting this significantly # increases query latency. if OPENSEARCH_EXPLAIN_ENABLED: final_semantic_search_query["explain"] = True return final_semantic_search_query @staticmethod def get_random_search_query( tenant_state: TenantState, index_filters: IndexFilters, num_to_retrieve: int, ) -> dict[str, Any]: """Returns a final search query that gets document chunks randomly. Args: tenant_state: Tenant state containing the tenant ID. index_filters: Filters for the random search query. num_to_retrieve: Number of document chunks to retrieve. Returns: A dictionary representing the final random search query. 
""" search_filters = DocumentQuery._get_search_filters( tenant_state=tenant_state, include_hidden=False, access_control_list=index_filters.access_control_list, source_types=index_filters.source_type or [], tags=index_filters.tags or [], document_sets=index_filters.document_set or [], project_id_filter=index_filters.project_id_filter, persona_id_filter=index_filters.persona_id_filter, time_cutoff=index_filters.time_cutoff, min_chunk_index=None, max_chunk_index=None, attached_document_ids=index_filters.attached_document_ids, hierarchy_node_ids=index_filters.hierarchy_node_ids, ) final_random_search_query: dict[str, Any] = { "query": { "function_score": { "query": {"bool": {"filter": search_filters}}, # See # https://docs.opensearch.org/latest/query-dsl/compound/function-score/#the-random-score-function "random_score": { # We'll use a different seed per invocation. "seed": random.randint(0, 1_000_000), # Some field which has a unique value per document # chunk. "field": "_seq_no", }, # Replaces whatever score was computed in the query. "boost_mode": "replace", } }, "size": num_to_retrieve, "timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s", # Exclude retrieving the vector fields in order to save on # retrieval cost as we don't need them upstream. "_source": { "excludes": [TITLE_VECTOR_FIELD_NAME, CONTENT_VECTOR_FIELD_NAME] }, } if not OPENSEARCH_PROFILING_DISABLED: final_random_search_query["profile"] = True return final_random_search_query @staticmethod def _get_hybrid_search_subqueries( query_text: str, query_vector: list[float], # The default number of neighbors to consider for knn vector similarity # search. This is higher than the number of results because the scoring # is hybrid. For a detailed breakdown, see where the default value is # set. vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES, ) -> list[dict[str, Any]]: """Returns subqueries for hybrid search. Each of these subqueries is one component of the hybrid search.
We search on various fields and combine the results. The output of this function cannot be supplied directly to the OpenSearch client; see get_hybrid_search_query. Normalization is not performed here. The weight of each subquery should be configured in a search pipeline. The exact subqueries executed depend on the HYBRID_SEARCH_SUBQUERY_CONFIGURATION setting. NOTE: For OpenSearch, 5 is the maximum number of query clauses allowed in a single hybrid query. Source: https://docs.opensearch.org/latest/query-dsl/compound/hybrid/ NOTE: Each query is independent during the search phase; there is no backfilling of scores for missing query components. This means that if a document was a good vector match but did not show up in the keyword results, it gets a score of 0 for the keyword component of the hybrid scoring. This is not as bad as simply disregarding the score, because normalization is applied afterwards. So in effect the missing score is "increased" compared to if it had been included and the range renormalized. However, this does mean that among docs with high scores for, say, the vector field, differences in their keyword scores are completely ignored unless they also showed up in the keyword query as a reasonably high match. TLDR, this is a bit of unique funky behavior but it seems ok. NOTE: Options considered and rejected: - minimum_should_match: Since it's hybrid search and users often provide semantic queries, there are often many terms but very few meaningful keywords (a low ratio of keywords to total terms). - fuzziness AUTO: Typo tolerance (0/1/2 edit distance by term length). It mostly helps with typos, since the analyzer ("english" by default) already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It is also less performant, so there is little reason to use it. Args: query_text: The text of the query to search for. query_vector: The vector embedding of the query to search for.
vector_candidates: The number of candidates to consider for vector similarity search. """ # Build sub-queries for hybrid search. Order must match normalization # pipeline weights. if ( HYBRID_SEARCH_SUBQUERY_CONFIGURATION is HybridSearchSubqueryConfiguration.TITLE_VECTOR_CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD ): return [ DocumentQuery._get_title_vector_similarity_search_query( query_vector, vector_candidates ), DocumentQuery._get_content_vector_similarity_search_query( query_vector, vector_candidates ), DocumentQuery._get_title_content_combined_keyword_search_query( query_text ), ] elif ( HYBRID_SEARCH_SUBQUERY_CONFIGURATION is HybridSearchSubqueryConfiguration.CONTENT_VECTOR_TITLE_CONTENT_COMBINED_KEYWORD ): return [ DocumentQuery._get_content_vector_similarity_search_query( query_vector, vector_candidates ), DocumentQuery._get_title_content_combined_keyword_search_query( query_text ), ] else: raise ValueError( f"Bug: Unhandled hybrid search subquery configuration: {HYBRID_SEARCH_SUBQUERY_CONFIGURATION}" ) @staticmethod def _get_title_vector_similarity_search_query( query_vector: list[float], vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES, ) -> dict[str, Any]: return { "knn": { TITLE_VECTOR_FIELD_NAME: { "vector": query_vector, "k": vector_candidates, } } } @staticmethod def _get_content_vector_similarity_search_query( query_vector: list[float], vector_candidates: int = DEFAULT_NUM_HYBRID_SUBQUERY_CANDIDATES, search_filters: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: query = { "knn": { CONTENT_VECTOR_FIELD_NAME: { "vector": query_vector, "k": vector_candidates, } } } if search_filters is not None: query["knn"][CONTENT_VECTOR_FIELD_NAME]["filter"] = { "bool": {"filter": search_filters} } return query @staticmethod def _get_title_content_combined_keyword_search_query( query_text: str, search_filters: list[dict[str, Any]] | None = None, ) -> dict[str, Any]: query = { "bool": { "should": [ { "match": { TITLE_FIELD_NAME: { "query":
query_text, "operator": "or", # The title fields are strongly discounted as # they are included in the content. This just # acts as a minor boost. "boost": 0.1, } } }, { "match_phrase": { TITLE_FIELD_NAME: { "query": query_text, "slop": 1, "boost": 0.2, } } }, { # Analyzes the query and returns results which match any # of the query's terms. More matches result in higher # scores. "match": { CONTENT_FIELD_NAME: { "query": query_text, "operator": "or", "boost": 1.0, } } }, { # Matches an exact phrase in a specified order. "match_phrase": { CONTENT_FIELD_NAME: { "query": query_text, # The number of words permitted between words of # a query phrase and still result in a match. "slop": 1, "boost": 1.5, } } }, ], # Ensures at least one match subquery from the query is present # in the document. This defaults to 1, unless a filter or must # clause is supplied, in which case it defaults to 0. "minimum_should_match": 1, } } if search_filters is not None: query["bool"]["filter"] = search_filters return query @staticmethod def _get_search_filters( tenant_state: TenantState, include_hidden: bool, access_control_list: list[str] | None, source_types: list[DocumentSource], tags: list[Tag], document_sets: list[str], project_id_filter: int | None, persona_id_filter: int | None, time_cutoff: datetime | None, min_chunk_index: int | None, max_chunk_index: int | None, max_chunk_size: int | None = None, document_id: str | None = None, # Assistant knowledge filters attached_document_ids: list[str] | None = None, hierarchy_node_ids: list[int] | None = None, ) -> list[dict[str, Any]]: """Returns filters to be passed into the "filter" key of a search query. The "filter" key applies a logical AND operator to its elements, so every subfilter must evaluate to true in order for the document to be retrieved. This function returns a list of such subfilters. See https://docs.opensearch.org/latest/query-dsl/compound/bool/. 
TODO(ENG-3874): The terms queries returned by this function can be made more performant for large cardinality sets by sorting the values by their UTF-8 byte order. TODO(ENG-3875): This function can take even better advantage of filter caching by grouping "static" filters together into one sub-clause. Args: tenant_state: Tenant state containing the tenant ID. include_hidden: Whether to include hidden documents. access_control_list: Access control list for the documents to retrieve. If None, there is no restriction on the documents that can be retrieved. If not None, only public documents can be retrieved, or non-public documents where at least one acl provided here is present in the document's acl list. source_types: If supplied, only documents of one of these source types will be retrieved. tags: If supplied, only documents with an entry in their metadata list corresponding to a tag will be retrieved. document_sets: If supplied, documents with at least one document set ID from this list will be retrieved. Like the other knowledge scope filters, it is OR'd with attached_document_ids, hierarchy_node_ids, persona_id_filter, and project_id_filter. project_id_filter: If not None, documents with this project ID in user projects will also be retrieved. Additive: only applied when a knowledge scope already exists, and it widens that scope rather than restricting it. persona_id_filter: If not None, documents whose personas array contains this persona ID will be retrieved. Primary: creates a knowledge scope on its own. time_cutoff: Time cutoff for the documents to retrieve. If not None, documents which were last updated before this date will not be returned. Documents which do not have a value for their last updated time are assumed to have been last updated ASSUMED_DOCUMENT_AGE_DAYS days ago. min_chunk_index: The minimum chunk index to retrieve, inclusive. If None, no minimum chunk index will be applied. max_chunk_index: The maximum chunk index to retrieve, inclusive. If None, no maximum chunk index will be applied. max_chunk_size: The type of chunk to retrieve, specified by the maximum number of tokens it can hold.
If None, no filter will be applied for this. Defaults to None. NOTE: See DocumentChunk.max_chunk_size. document_id: The document ID to retrieve. If None, no filter will be applied for this. Defaults to None. attached_document_ids: Document IDs explicitly attached to the assistant. If provided along with hierarchy_node_ids, documents matching EITHER criterion will be retrieved (OR logic). hierarchy_node_ids: Hierarchy node IDs (folders/spaces) attached to the assistant. Matches chunks where ancestor_hierarchy_node_ids contains any of these values. Raises: ValueError: document_id and attached_document_ids were supplied together. This is not allowed because they operate on the same schema field, and it does not semantically make sense to use them together. ValueError: Too many values were supplied for one of the collection arguments. Returns: A list of filters to be passed into the "filter" key of a search query. """ def _get_acl_visibility_filter( access_control_list: list[str], ) -> dict[str, dict[str, list[TermQuery[bool] | TermsQuery[str]] | int]]: """Returns a filter for the access control list. Since this returns an isolated bool should clause, it can be cached in OpenSearch independently of other clauses in _get_search_filters. Args: access_control_list: The access control list to restrict documents to. Raises: ValueError: The number of access control list entries is greater than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY. Returns: A filter for the access control list. """ # Logical OR operator on its elements. acl_visibility_filter: dict[str, dict[str, Any]] = { "bool": { "should": [{"term": {PUBLIC_FIELD_NAME: {"value": True}}}], "minimum_should_match": 1, } } if access_control_list: if len(access_control_list) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY: raise ValueError( f"Too many access control list entries: {len(access_control_list)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}."
) # Use terms instead of a list of term within a should clause # because Lucene will optimize the filtering for large sets of # terms. Small sets of terms are not expected to perform any # differently than individual term clauses. acl_subclause: TermsQuery[str] = { "terms": {ACCESS_CONTROL_LIST_FIELD_NAME: list(access_control_list)} } acl_visibility_filter["bool"]["should"].append(acl_subclause) return acl_visibility_filter def _get_source_type_filter( source_types: list[DocumentSource], ) -> TermsQuery[str]: """Returns a filter for the source types. Since this returns an isolated terms clause, it can be cached in OpenSearch independently of other clauses in _get_search_filters. Args: source_types: The source types to restrict documents to. Raises: ValueError: The number of source types is greater than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY. ValueError: An empty list was supplied. Returns: A filter for the source types. """ if not source_types: raise ValueError( "source_types cannot be empty if trying to create a source type filter." ) if len(source_types) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY: raise ValueError( f"Too many source types: {len(source_types)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}." ) # Use terms instead of a list of term within a should clause because # Lucene will optimize the filtering for large sets of terms. Small # sets of terms are not expected to perform any differently than # individual term clauses. return { "terms": { SOURCE_TYPE_FIELD_NAME: [ source_type.value for source_type in source_types ] } } def _get_tag_filter(tags: list[Tag]) -> TermsQuery[str]: """Returns a filter for the tags. Since this returns an isolated terms clause, it can be cached in OpenSearch independently of other clauses in _get_search_filters. Args: tags: The tags to restrict documents to. Raises: ValueError: The number of tags is greater than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY. ValueError: An empty list was supplied. Returns: A filter for the tags. 
""" if not tags: raise ValueError( "tags cannot be empty if trying to create a tag filter." ) if len(tags) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY: raise ValueError( f"Too many tags: {len(tags)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}." ) # Kind of an abstraction leak, see # convert_metadata_dict_to_list_of_strings for why metadata list # entries are expected to look this way. tag_str_list = [ f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in tags ] # Use terms instead of a list of term within a should clause because # Lucene will optimize the filtering for large sets of terms. Small # sets of terms are not expected to perform any differently than # individual term clauses. return {"terms": {METADATA_LIST_FIELD_NAME: tag_str_list}} def _get_document_set_filter(document_sets: list[str]) -> TermsQuery[str]: """Returns a filter for the document sets. Since this returns an isolated terms clause, it can be cached in OpenSearch independently of other clauses in _get_search_filters. Args: document_sets: The document sets to restrict documents to. Raises: ValueError: The number of document sets is greater than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY. ValueError: An empty list was supplied. Returns: A filter for the document sets. """ if not document_sets: raise ValueError( "document_sets cannot be empty if trying to create a document set filter." ) if len(document_sets) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY: raise ValueError( f"Too many document sets: {len(document_sets)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}." ) # Use terms instead of a list of term within a should clause because # Lucene will optimize the filtering for large sets of terms. Small # sets of terms are not expected to perform any differently than # individual term clauses. 
return {"terms": {DOCUMENT_SETS_FIELD_NAME: list(document_sets)}} def _get_user_project_filter(project_id: int) -> TermQuery[int]: return {"term": {USER_PROJECTS_FIELD_NAME: {"value": project_id}}} def _get_persona_filter(persona_id: int) -> TermQuery[int]: return {"term": {PERSONAS_FIELD_NAME: {"value": persona_id}}} def _get_time_cutoff_filter(time_cutoff: datetime) -> dict[str, Any]: # Convert to UTC if not already so the cutoff is comparable to the # document data. time_cutoff = set_or_convert_timezone_to_utc(time_cutoff) # Logical OR operator on its elements. time_cutoff_filter: dict[str, Any] = { "bool": {"should": [], "minimum_should_match": 1} } time_cutoff_filter["bool"]["should"].append( { "range": { LAST_UPDATED_FIELD_NAME: {"gte": int(time_cutoff.timestamp())} } } ) if time_cutoff < datetime.now(timezone.utc) - timedelta( days=ASSUMED_DOCUMENT_AGE_DAYS ): # Since the time cutoff is older than ASSUMED_DOCUMENT_AGE_DAYS # ago, we include documents which have no # LAST_UPDATED_FIELD_NAME value. time_cutoff_filter["bool"]["should"].append( { "bool": { "must_not": {"exists": {"field": LAST_UPDATED_FIELD_NAME}} } } ) return time_cutoff_filter def _get_chunk_index_filter( min_chunk_index: int | None, max_chunk_index: int | None ) -> dict[str, Any]: range_clause: dict[str, Any] = {"range": {CHUNK_INDEX_FIELD_NAME: {}}} if min_chunk_index is not None: range_clause["range"][CHUNK_INDEX_FIELD_NAME]["gte"] = min_chunk_index if max_chunk_index is not None: range_clause["range"][CHUNK_INDEX_FIELD_NAME]["lte"] = max_chunk_index return range_clause def _get_attached_document_id_filter( doc_ids: list[str], ) -> TermsQuery[str]: """ Returns a filter for documents explicitly attached to an assistant. Since this returns an isolated terms clause, it can be cached in OpenSearch independently of other clauses in _get_search_filters. Args: doc_ids: The document IDs to restrict documents to. 
Raises: ValueError: The number of document IDs is greater than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY. ValueError: An empty list was supplied. Returns: A filter for the document IDs. """ if not doc_ids: raise ValueError( "doc_ids cannot be empty if trying to create a document ID filter." ) if len(doc_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY: raise ValueError( f"Too many document IDs: {len(doc_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}." ) # Use terms instead of a list of term within a should clause because # Lucene will optimize the filtering for large sets of terms. Small # sets of terms are not expected to perform any differently than # individual term clauses. return {"terms": {DOCUMENT_ID_FIELD_NAME: list(doc_ids)}} def _get_hierarchy_node_filter( node_ids: list[int], ) -> TermsQuery[int]: """ Returns a filter for chunks whose ancestors include any of the given hierarchy nodes. Since this returns an isolated terms clause, it can be cached in OpenSearch independently of other clauses in _get_search_filters. Args: node_ids: The hierarchy node IDs to restrict documents to. Raises: ValueError: The number of hierarchy node IDs is greater than MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY. ValueError: An empty list was supplied. Returns: A filter for the hierarchy node IDs. """ if not node_ids: raise ValueError( "node_ids cannot be empty if trying to create a hierarchy node ID filter." ) if len(node_ids) > MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY: raise ValueError( f"Too many hierarchy node IDs: {len(node_ids)}. Max allowed: {MAX_NUM_TERMS_ALLOWED_IN_TERMS_QUERY}." ) # Use terms instead of a list of term within a should clause because # Lucene will optimize the filtering for large sets of terms. Small # sets of terms are not expected to perform any differently than # individual term clauses. 
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: list(node_ids)}} if document_id is not None and attached_document_ids is not None: raise ValueError( "document_id and attached_document_ids cannot be used together." ) filter_clauses: list[dict[str, Any]] = [] if not include_hidden: filter_clauses.append({"term": {HIDDEN_FIELD_NAME: {"value": False}}}) if access_control_list is not None: # If an access control list is provided, the caller can only # retrieve public documents, and non-public documents where at least # one acl provided here is present in the document's acl list. If # there is explicitly no list provided, we make no restrictions on # the documents that can be retrieved. filter_clauses.append(_get_acl_visibility_filter(access_control_list)) if source_types: # If at least one source type is provided, the caller will only # retrieve documents whose source type is present in this input # list. filter_clauses.append(_get_source_type_filter(source_types)) if tags: # If at least one tag is provided, the caller will only retrieve # documents where at least one tag provided here is present in the # document's metadata list. filter_clauses.append(_get_tag_filter(tags)) # Knowledge scope: explicit knowledge attachments restrict what an # assistant can see. When none are set the assistant searches # everything. # # persona_id_filter is a primary trigger — a persona with user files IS # explicit knowledge, so it can start a knowledge scope on its own. # # project_id_filter is additive — it widens the scope to also cover # overflowing project files but never restricts on its own (a chat # inside a project should still search team knowledge). has_knowledge_scope = ( attached_document_ids or hierarchy_node_ids or document_sets or persona_id_filter is not None ) if has_knowledge_scope: # Since this returns an isolated bool should clause, it can be # cached in OpenSearch independently of other clauses in # _get_search_filters. 
knowledge_filter: dict[str, Any] = { "bool": {"should": [], "minimum_should_match": 1} } if attached_document_ids: knowledge_filter["bool"]["should"].append( _get_attached_document_id_filter(attached_document_ids) ) if hierarchy_node_ids: knowledge_filter["bool"]["should"].append( _get_hierarchy_node_filter(hierarchy_node_ids) ) if document_sets: knowledge_filter["bool"]["should"].append( _get_document_set_filter(document_sets) ) if persona_id_filter is not None: knowledge_filter["bool"]["should"].append( _get_persona_filter(persona_id_filter) ) if project_id_filter is not None: knowledge_filter["bool"]["should"].append( _get_user_project_filter(project_id_filter) ) filter_clauses.append(knowledge_filter) if time_cutoff is not None: # If a time cutoff is provided, the caller will only retrieve # documents where the document was last updated at or after the time # cutoff. For documents which do not have a value for # LAST_UPDATED_FIELD_NAME, we assume some default age for the # purposes of time cutoff. filter_clauses.append(_get_time_cutoff_filter(time_cutoff)) if min_chunk_index is not None or max_chunk_index is not None: filter_clauses.append( _get_chunk_index_filter(min_chunk_index, max_chunk_index) ) if document_id is not None: filter_clauses.append( {"term": {DOCUMENT_ID_FIELD_NAME: {"value": document_id}}} ) if max_chunk_size is not None: filter_clauses.append( {"term": {MAX_CHUNK_SIZE_FIELD_NAME: {"value": max_chunk_size}}} ) if tenant_state.multitenant: filter_clauses.append( {"term": {TENANT_ID_FIELD_NAME: {"value": tenant_state.tenant_id}}} ) return filter_clauses @staticmethod def _get_match_highlights_configuration() -> dict[str, Any]: """ Gets configuration for returning match highlights for a hit. """ match_highlights_configuration: dict[str, Any] = { "fields": { CONTENT_FIELD_NAME: { # See https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#highlighter-types "type": "unified", # The length in chars of a match snippet. 
Somewhat # arbitrarily chosen. The Vespa codepath limited total # highlights length to 400 chars. fragment_size * # number_of_fragments = 400 should be good enough. "fragment_size": 100, # The number of snippets to return per field per document # hit. "number_of_fragments": 4, # These tags wrap matched keywords and they match what Vespa # used to return. Use them to minimize changes to our code. "pre_tags": ["<hi>"], "post_tags": ["</hi>"], } } } return match_highlights_configuration ================================================ FILE: backend/onyx/document_index/opensearch/string_filtering.py ================================================ import re MAX_DOCUMENT_ID_ENCODED_LENGTH: int = 512 class DocumentIDTooLongError(ValueError): """Raised when a document ID is too long for OpenSearch after filtering.""" def filter_and_validate_document_id( document_id: str, max_encoded_length: int = MAX_DOCUMENT_ID_ENCODED_LENGTH ) -> str: """ Filters and validates a document ID such that it can be used as an ID in OpenSearch. OpenSearch imposes the following restrictions on IDs: - Must not be an empty string. - Must not exceed 512 bytes. - Must not contain any control characters (newline, etc.). - Must not contain URL-unsafe characters (#, ?, /, %, &, etc.). For extra resilience, this function simply removes all characters that are not alphanumeric or one of _.-~. Any query on document ID should use this function. Args: document_id: The document ID to filter and validate. max_encoded_length: The maximum length, in bytes, of the UTF-8-encoded document ID after filtering. Compared with >= for extra resilience, so encoded values of exactly this length will fail. Raises: DocumentIDTooLongError: If the document ID is too long after filtering. ValueError: If the document ID is empty after filtering. Returns: str: The filtered document ID.
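
    Example:
        Illustrative only; the inputs below are hypothetical. Characters
        outside [A-Za-z0-9_.-~] are stripped, and an ID that filters down
        to nothing raises ValueError.

        >>> filter_and_validate_document_id("https://example.com/a b?c=1")
        'httpsexample.comabc1'
        >>> filter_and_validate_document_id("///")
        Traceback (most recent call last):
            ...
        ValueError: Document ID /// is empty after filtering.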
""" filtered_document_id = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id) if not filtered_document_id: raise ValueError(f"Document ID {document_id} is empty after filtering.") if len(filtered_document_id.encode("utf-8")) >= max_encoded_length: raise DocumentIDTooLongError( f"Document ID {document_id} is too long after filtering." ) return filtered_document_id ================================================ FILE: backend/onyx/document_index/vespa/__init__.py ================================================ ================================================ FILE: backend/onyx/document_index/vespa/app_config/schemas/danswer_chunk.sd.jinja ================================================ schema {{ schema_name }} { # source, type, target triplets for kg_relationships struct kg_relationship { field source type string {} field rel_type type string {} field target type string {} } document {{ schema_name }} { {% if multi_tenant %} field tenant_id type string { indexing: summary | attribute rank: filter attribute: fast-search } {% endif %} # Not to be confused with the UUID generated for this chunk which is called documentid by default field document_id type string { indexing: summary | attribute rank: filter attribute: fast-search } field chunk_id type int { indexing: summary | attribute } # Displayed in the UI as the main identifier for the doc field semantic_identifier type string { indexing: summary | attribute } # Must have an additional field for whether to skip title embeddings # This information cannot be extracted from either the title field or the title embedding field field skip_title type bool { indexing: attribute } # May not always match the `semantic_identifier` e.g.
for Slack docs the # `semantic_identifier` will be the channel name, but the `title` will be empty field title type string { indexing: summary | index | attribute index: enable-bm25 } field content type string { indexing: summary | index index: enable-bm25 } # duplication of `content` is far from ideal, but is needed for # non-gram based highlighting for now. If the capability to re-use a # single field to do both is added, `content_summary` should be removed field content_summary type string { indexing: summary | index summary: dynamic } # Title embedding (x1) field title_embedding type tensor<{{ embedding_precision }}>(x[{{ dim }}]) { indexing: attribute | index attribute { distance-metric: angular } } # Content embeddings (chunk + optional mini chunks embeddings) # "t" and "x" are arbitrary names, not special keywords field embeddings type tensor<{{ embedding_precision }}>(t{},x[{{ dim }}]) { indexing: attribute | index attribute { distance-metric: angular } } # Starting section of the doc, currently unused as it has been replaced by match highlighting field blurb type string { indexing: summary | attribute } field image_file_name type string { indexing: summary | attribute } # https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it field source_type type string { indexing: summary | attribute rank: filter attribute: fast-search } # Can also index links https://docs.vespa.ai/en/reference/schema-reference.html#attribute # URL type matching field source_links type string { indexing: summary | attribute } field section_continuation type bool { indexing: summary | attribute } # Technically this one should be int, but can't change without causing breaks to existing index field boost type float { indexing: summary | attribute } field hidden type bool { indexing: summary | attribute rank: filter } # Field to indicate whether a short chunk is a low content chunk field aggregated_chunk_boost_factor type float { indexing: 
attribute } # Separate array fields for knowledge graph data field kg_entities type array { indexing: summary | attribute attribute: fast-search } field kg_relationships type array { indexing: summary struct-field source { indexing: attribute attribute: fast-search } struct-field rel_type { indexing: attribute attribute: fast-search } struct-field target { indexing: attribute attribute: fast-search } } field kg_terms type array { indexing: summary | attribute attribute: fast-search } # Needs to have a separate Attribute list for efficient filtering field metadata_list type array { indexing: summary | attribute rank:filter attribute: fast-search } # If chunk is a large chunk, this will contain the ids of the smaller chunks field large_chunk_reference_ids type array { indexing: summary | attribute } field metadata type string { indexing: summary | attribute } field chunk_context type string { indexing: summary | attribute } field doc_summary type string { indexing: summary | attribute } field metadata_suffix type string { indexing: summary | attribute } field doc_updated_at type int { indexing: summary | attribute } field primary_owners type array { indexing: summary | attribute } field secondary_owners type array { indexing: summary | attribute } field access_control_list type weightedset { indexing: summary | attribute rank: filter attribute: fast-search } field document_sets type weightedset { indexing: summary | attribute rank: filter attribute: fast-search } field user_file type int { indexing: summary | attribute rank: filter attribute: fast-search } field user_folder type int { indexing: summary | attribute rank: filter attribute: fast-search } field user_project type array { indexing: summary | attribute rank: filter attribute: fast-search } field personas type array { indexing: summary | attribute rank: filter attribute: fast-search } } # If using different tokenization settings, the fieldset has to be removed, and the field must # be specified in the yql 
like: # + 'or ({grammar: "weakAnd", defaultIndex:"title"}userInput(@query)) ' # + 'or ({grammar: "weakAnd", defaultIndex:"content"}userInput(@query)) ' # Note: for BM-25, the ngram size (and whether ngrams are used) changes the range of the scores fieldset default { fields: content, title } rank-profile default_rank { inputs { query(decay_factor) double } function inline document_boost() { # 0.5 to 2x score: piecewise sigmoid function stretched out by factor of 3 # meaning requires 3x the number of feedback votes to have default sigmoid effect expression: if(attribute(boost) < 0, 0.5 + (1 / (1 + exp(-attribute(boost) / 3))), 2 / (1 + exp(-attribute(boost) / 3))) } function inline document_age() { # Time in years (91.3 days ~= 3 Months ~= 1 fiscal quarter if no age found) expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0) } function inline aggregated_chunk_boost() { # Aggregated boost factor, currently only used for information content classification expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor)) } # Document score decays from 1 to 0.75 as age of last updated time increases function inline recency_bias() { expression: max(1 / (1 + query(decay_factor) * document_age), 0.75) } match-features: recency_bias } rank-profile hybrid_search_semantic_base_{{ dim }} inherits default, default_rank { inputs { query(query_embedding) tensor(x[{{ dim }}]) } function title_vector_score() { expression { # If no good matching titles, then it should use the context embeddings rather than having some # irrelevant title have a vector score of 1. 
This way at least it will be the doc with the highest # matching content score getting the full score max(closeness(field, embeddings), closeness(field, title_embedding)) } } # First phase must be vector to allow hits that have no keyword matches first-phase { expression: query(title_content_ratio) * closeness(field, title_embedding) + (1 - query(title_content_ratio)) * closeness(field, embeddings) } # Weighted average between Vector Search and BM-25 global-phase { expression { ( # Weighted Vector Similarity Score ( query(alpha) * ( (query(title_content_ratio) * normalize_linear(title_vector_score)) + ((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings))) ) ) + # Weighted Keyword Similarity Score # Note: for the BM25 Title score, it requires decent stopword removal in the query # This needs to be the case so there aren't irrelevant titles being normalized to a score of 1 ( (1 - query(alpha)) * ( (query(title_content_ratio) * normalize_linear(bm25(title))) + ((1 - query(title_content_ratio)) * normalize_linear(bm25(content))) ) ) ) # Boost based on user feedback * document_boost # Decay factor based on time document was last updated * recency_bias # Boost based on aggregated boost calculation * aggregated_chunk_boost } # Target hits for hybrid retrieval should be at least this value. rerank-count: 1000 } match-features { bm25(title) bm25(content) closeness(field, title_embedding) closeness(field, embeddings) document_boost recency_bias aggregated_chunk_boost closest(embeddings) } } rank-profile hybrid_search_keyword_base_{{ dim }} inherits default, default_rank { inputs { query(query_embedding) tensor(x[{{ dim }}]) } function title_vector_score() { expression { # If no good matching titles, then it should use the context embeddings rather than having some # irrelevant title have a vector score of 1. 
This way at least it will be the doc with the highest # matching content score getting the full score max(closeness(field, embeddings), closeness(field, title_embedding)) } } # First phase is keyword (BM25) based, so candidate hits are selected by keyword matches first-phase { expression: query(title_content_ratio) * bm25(title) + (1 - query(title_content_ratio)) * bm25(content) } # Weighted average between Vector Search and BM-25 global-phase { expression { ( # Weighted Vector Similarity Score ( query(alpha) * ( (query(title_content_ratio) * normalize_linear(title_vector_score)) + ((1 - query(title_content_ratio)) * normalize_linear(closeness(field, embeddings))) ) ) + # Weighted Keyword Similarity Score # Note: for the BM25 Title score, it requires decent stopword removal in the query # This needs to be the case so there aren't irrelevant titles being normalized to a score of 1 ( (1 - query(alpha)) * ( (query(title_content_ratio) * normalize_linear(bm25(title))) + ((1 - query(title_content_ratio)) * normalize_linear(bm25(content))) ) ) ) # Boost based on user feedback * document_boost # Decay factor based on time document was last updated * recency_bias # Boost based on aggregated boost calculation * aggregated_chunk_boost } # Target hits for hybrid retrieval should be at least this value.
rerank-count: 1000 } match-features { bm25(title) bm25(content) closeness(field, title_embedding) closeness(field, embeddings) document_boost recency_bias aggregated_chunk_boost closest(embeddings) } } # Used when searching from the admin UI for a specific doc to hide / boost # Very heavily prioritize title rank-profile admin_search inherits default, default_rank { first-phase { expression: bm25(content) + (5 * bm25(title)) } } rank-profile random_ inherits default { first-phase { expression: random } } } ================================================ FILE: backend/onyx/document_index/vespa/app_config/services.xml.jinja ================================================ 1 {{ document_elements }} 0.85 {{ num_search_threads }} 3 750 350 300 ================================================ FILE: backend/onyx/document_index/vespa/app_config/validation-overrides.xml.jinja ================================================ schema-removal indexing-change field-type-change ================================================ FILE: backend/onyx/document_index/vespa/chunk_retrieval.py ================================================ import json import string import time from collections.abc import Callable from collections.abc import Mapping from datetime import datetime from datetime import timezone from typing import Any from typing import cast import httpx from retry import retry from onyx.background.celery.tasks.opensearch_migration.constants import ( FINISHED_VISITING_SLICE_CONTINUATION_TOKEN, ) from onyx.background.celery.tasks.opensearch_migration.transformer import ( FIELDS_NEEDED_FOR_TRANSFORMATION, ) from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE from onyx.configs.app_configs import VESPA_MIGRATION_REQUEST_TIMEOUT_S from onyx.configs.app_configs import VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT from onyx.context.search.models import IndexFilters from onyx.context.search.models import 
InferenceChunkUncleaned from onyx.document_index.interfaces import VespaChunkRequest from onyx.document_index.interfaces_new import TenantState from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client from onyx.document_index.vespa.shared_utils.vespa_request_builders import ( build_vespa_filters, ) from onyx.document_index.vespa.shared_utils.vespa_request_builders import ( build_vespa_id_based_retrieval_yql, ) from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST from onyx.document_index.vespa_constants import BLURB from onyx.document_index.vespa_constants import BOOST from onyx.document_index.vespa_constants import CHUNK_CONTEXT from onyx.document_index.vespa_constants import CHUNK_ID from onyx.document_index.vespa_constants import CONTENT from onyx.document_index.vespa_constants import CONTENT_SUMMARY from onyx.document_index.vespa_constants import DOC_SUMMARY from onyx.document_index.vespa_constants import DOC_UPDATED_AT from onyx.document_index.vespa_constants import DOCUMENT_ID from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.document_index.vespa_constants import HIDDEN from onyx.document_index.vespa_constants import IMAGE_FILE_NAME from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS from onyx.document_index.vespa_constants import MAX_ID_SEARCH_QUERY_SIZE from onyx.document_index.vespa_constants import MAX_OR_CONDITIONS from onyx.document_index.vespa_constants import METADATA from onyx.document_index.vespa_constants import METADATA_SUFFIX from onyx.document_index.vespa_constants import PRIMARY_OWNERS from onyx.document_index.vespa_constants import SEARCH_ENDPOINT from onyx.document_index.vespa_constants import SECONDARY_OWNERS from onyx.document_index.vespa_constants import SECTION_CONTINUATION from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER from onyx.document_index.vespa_constants import SOURCE_LINKS from onyx.document_index.vespa_constants import 
SOURCE_TYPE from onyx.document_index.vespa_constants import TENANT_ID from onyx.document_index.vespa_constants import TITLE from onyx.document_index.vespa_constants import YQL_BASE from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from shared_configs.configs import MULTI_TENANT logger = setup_logger() def _process_dynamic_summary( dynamic_summary: str, max_summary_length: int = 400 ) -> list[str]: if not dynamic_summary: return [] current_length = 0 processed_summary: list[str] = [] for summary_section in dynamic_summary.split("<sep />"): # if we're past the desired max length, break at the last word if current_length + len(summary_section) >= max_summary_length: summary_section = summary_section[: max_summary_length - current_length] summary_section = summary_section.lstrip() # remove any leading whitespace # handle the case where the truncated section is either just a # single (partial) word or if it's empty first_space = summary_section.find(" ") if first_space == -1: # add ``...`` to previous section if processed_summary: processed_summary[-1] += "..." break # handle the valid truncated section case summary_section = summary_section.rsplit(" ", 1)[0] if summary_section[-1] in string.punctuation: summary_section = summary_section[:-1] summary_section += "..."
processed_summary.append(summary_section) break processed_summary.append(summary_section) current_length += len(summary_section) return processed_summary def _vespa_hit_to_inference_chunk( hit: dict[str, Any], null_score: bool = False ) -> InferenceChunkUncleaned: fields = cast(dict[str, Any], hit["fields"]) # parse fields that are stored as strings, but are really json / datetime metadata = json.loads(fields[METADATA]) if METADATA in fields else {} updated_at = ( datetime.fromtimestamp(fields[DOC_UPDATED_AT], tz=timezone.utc) if DOC_UPDATED_AT in fields else None ) match_highlights = _process_dynamic_summary( # fallback to regular `content` if the `content_summary` field # isn't present dynamic_summary=hit["fields"].get(CONTENT_SUMMARY, hit["fields"][CONTENT]), ) semantic_identifier = fields.get(SEMANTIC_IDENTIFIER, "") if not semantic_identifier: logger.error( f"Chunk with blurb: {fields.get(BLURB, 'Unknown')[:50]}... has no Semantic Identifier" ) source_links = fields.get(SOURCE_LINKS, {}) source_links_dict_unprocessed = ( json.loads(source_links) if isinstance(source_links, str) else source_links ) source_links_dict = { int(k): v for k, v in cast(dict[str, str], source_links_dict_unprocessed).items() } return InferenceChunkUncleaned( chunk_id=fields[CHUNK_ID], blurb=fields.get(BLURB, ""), # Unused content=fields[CONTENT], # Includes extra title prefix and metadata suffix; # also sometimes context for contextual rag source_links=source_links_dict or {0: ""}, section_continuation=fields[SECTION_CONTINUATION], document_id=fields[DOCUMENT_ID], source_type=fields[SOURCE_TYPE], # still called `image_file_name` in Vespa for backwards compatibility image_file_id=fields.get(IMAGE_FILE_NAME), title=fields.get(TITLE), semantic_identifier=fields[SEMANTIC_IDENTIFIER], boost=fields.get(BOOST, 1), score=None if null_score else hit.get("relevance", 0), hidden=fields.get(HIDDEN, False), primary_owners=fields.get(PRIMARY_OWNERS), secondary_owners=fields.get(SECONDARY_OWNERS), 
large_chunk_reference_ids=fields.get(LARGE_CHUNK_REFERENCE_IDS, []), metadata=metadata, metadata_suffix=fields.get(METADATA_SUFFIX), doc_summary=fields.get(DOC_SUMMARY, ""), chunk_context=fields.get(CHUNK_CONTEXT, ""), match_highlights=match_highlights, updated_at=updated_at, ) def get_chunks_via_visit_api( chunk_request: VespaChunkRequest, index_name: str, filters: IndexFilters, field_names: list[str] | None = None, get_large_chunks: bool = False, short_tensor_format: bool = False, ) -> list[dict]: # Constructing the URL for the Visit API # NOTE: visit API uses the same URL as the document API, but with different params url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) # build the list of fields to retrieve field_set_list = ( [f"{field_name}" for field_name in field_names] if field_names else [] ) acl_fieldset_entry = f"{ACCESS_CONTROL_LIST}" if ( field_set_list and filters.access_control_list and acl_fieldset_entry not in field_set_list ): field_set_list.append(acl_fieldset_entry) if MULTI_TENANT: tenant_id_fieldset_entry = f"{TENANT_ID}" if field_set_list and tenant_id_fieldset_entry not in field_set_list: field_set_list.append(tenant_id_fieldset_entry) if field_set_list: field_set = f"{index_name}:" + ",".join(field_set_list) else: field_set = None # build filters selection = f"{index_name}.document_id=='{chunk_request.document_id}'" if chunk_request.is_capped: selection += f" and {index_name}.chunk_id>={chunk_request.min_chunk_ind or 0}" selection += f" and {index_name}.chunk_id<={chunk_request.max_chunk_ind}" if not get_large_chunks: selection += f" and {index_name}.large_chunk_reference_ids == null" # enforcing tenant_id through a == condition if MULTI_TENANT: if filters.tenant_id: selection += f" and {index_name}.tenant_id=='{filters.tenant_id}'" else: raise ValueError("Tenant ID is required for multi-tenant") # Setting up the selection criteria in the query parameters params = { # NOTE: Document Selector Language doesn't allow `contains`, so we 
can't check # for the ACL in the selection. Instead, we have to check as a postfilter "selection": selection, "continuation": None, "wantedDocumentCount": 1_000, "fieldSet": field_set, } # Vespa can supply tensors in various different formats. This explicitly # asks to retrieve tensor data in "short-value" format. if short_tensor_format: params["format.tensors"] = "short-value" document_chunks: list[dict] = [] while True: try: filtered_params = {k: v for k, v in params.items() if v is not None} with get_vespa_http_client() as http_client: response = http_client.get(url, params=filtered_params) response.raise_for_status() except httpx.HTTPError as e: error_base = "Failed to query Vespa" logger.error( f"{error_base}:\n" f"Request URL: {e.request.url}\n" f"Request Headers: {e.request.headers}\n" f"Request Payload: {params}\n" f"Exception: {str(e)}" ) raise httpx.HTTPError(error_base) from e # Check if the response contains any documents response_data = response.json() if "documents" in response_data: for document in response_data["documents"]: if filters.access_control_list: document_acl = document["fields"].get(ACCESS_CONTROL_LIST) if not document_acl or not any( user_acl_entry in document_acl for user_acl_entry in filters.access_control_list ): continue if MULTI_TENANT: if not filters.tenant_id: raise ValueError("Tenant ID is required for multi-tenant") document_tenant_id = document["fields"].get(TENANT_ID) if document_tenant_id != filters.tenant_id: logger.error( f"Skipping document {document['document_id']} because " f"it does not belong to tenant {filters.tenant_id}. " "This should never happen." 
) continue document_chunks.append(document) # Check for continuation token to handle pagination if "continuation" in response_data and response_data["continuation"]: params["continuation"] = response_data["continuation"] else: break # Exit loop if no continuation token return document_chunks def get_all_chunks_paginated( index_name: str, tenant_state: TenantState, continuation_token_map: dict[int, str | None], page_size: int, ) -> tuple[list[dict], dict[int, str | None]]: """Gets all chunks in Vespa for the given tenant, paginated. Large chunks are excluded. Uses the Visit API with slicing. Each continuation token map entry is for a different slice. The number of entries determines the number of slices, and slice IDs are expected to be 0 through N - 1. Args: index_name: The name of the Vespa index to visit. tenant_state: The tenant state to filter by. continuation_token_map: Map of slice ID to a token returned by Vespa representing a page offset. None to start from the beginning of the slice. page_size: Best-effort batch size per slice for the visit. Returns: Tuple of (list of chunk dicts, map of slice ID to next continuation token). A slice's token is FINISHED_VISITING_SLICE_CONTINUATION_TOKEN once that slice has been fully visited. If the request for a slice fails, its token is left unchanged so the same page can be retried. """ def _get_all_chunks_paginated_for_slice( index_name: str, tenant_state: TenantState, slice_id: int, total_slices: int, continuation_token: str | None, page_size: int, ) -> tuple[list[dict], str | None]: if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN: logger.debug( f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
) return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name) selection: str = f"{index_name}.large_chunk_reference_ids == null" if MULTI_TENANT: selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'" field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION) params: dict[str, str | int | None] = { "selection": selection, "fieldSet": field_set, "wantedDocumentCount": page_size, "format.tensors": "short-value", "slices": total_slices, "sliceId": slice_id, # When exceeded, Vespa should return gracefully with partial # results. Even if no hits are returned, Vespa should still return a # new continuation token representing a new spot in the linear # traversal. "timeout": VESPA_MIGRATION_SERVER_SIDE_REQUEST_TIMEOUT, } if continuation_token is not None: params["continuation"] = continuation_token response: httpx.Response | None = None start_time = time.monotonic() try: with get_vespa_http_client( # When exceeded, an exception is raised in our code. No progress # is saved, and the task will retry this spot in the traversal # later. timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S ) as http_client: response = http_client.get(url, params=params) response.raise_for_status() except httpx.HTTPError as e: error_base = ( f"Failed to get chunks from Vespa slice {slice_id} with continuation token " f"{continuation_token} in {time.monotonic() - start_time:.3f} seconds." ) logger.exception( f"Request URL: {e.request.url}\nRequest Headers: {e.request.headers}\nRequest Payload: {params}\n" ) error_message = ( response.json().get("message") if response else "No response" ) logger.error("Error message from response: %s", error_message) raise httpx.HTTPError(error_base) from e response_data = response.json() # NOTE: If we see a falsey value for "continuation" in the response we # assume we are done and return # FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead. 
next_continuation_token = ( response_data.get("continuation") or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN ) chunks = [chunk["fields"] for chunk in response_data.get("documents", [])] if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN: logger.debug( f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}." ) return chunks, next_continuation_token total_slices = len(continuation_token_map) if total_slices < 1: raise ValueError("continuation_token_map must have at least one entry.") # We want to guarantee that these invocations are ordered by slice_id, # because we read in the same order below when parsing parallel_results. functions_with_args: list[tuple[Callable, tuple]] = [ ( _get_all_chunks_paginated_for_slice, ( index_name, tenant_state, slice_id, total_slices, continuation_token, page_size, ), ) for slice_id, continuation_token in sorted(continuation_token_map.items()) ] parallel_results = run_functions_tuples_in_parallel( functions_with_args, allow_failures=True ) if len(parallel_results) != total_slices: raise RuntimeError( f"Expected {total_slices} parallel results, but got {len(parallel_results)}." ) chunks: list[dict] = [] next_continuation_token_map: dict[int, str | None] = { key: value for key, value in continuation_token_map.items() } for i, parallel_result in enumerate(parallel_results): if i not in next_continuation_token_map: raise RuntimeError(f"Slice {i} is not in the continuation token map.") if parallel_result is None: logger.error( f"Failed to get chunks for slice {i} of {total_slices}. " "The continuation token for this slice will not be updated." 
) continue chunks.extend(parallel_result[0]) next_continuation_token_map[i] = parallel_result[1] return chunks, next_continuation_token_map # TODO(rkuo): candidate for removal if not being used # @retry(tries=10, delay=1, backoff=2) # def get_all_vespa_ids_for_document_id( # document_id: str, # index_name: str, # filters: IndexFilters | None = None, # get_large_chunks: bool = False, # ) -> list[str]: # document_chunks = get_chunks_via_visit_api( # chunk_request=VespaChunkRequest(document_id=document_id), # index_name=index_name, # filters=filters or IndexFilters(access_control_list=None), # field_names=[DOCUMENT_ID], # get_large_chunks=get_large_chunks, # ) # return [chunk["id"].split("::", 1)[-1] for chunk in document_chunks] def parallel_visit_api_retrieval( index_name: str, chunk_requests: list[VespaChunkRequest], filters: IndexFilters, get_large_chunks: bool = False, ) -> list[InferenceChunkUncleaned]: functions_with_args: list[tuple[Callable, tuple]] = [ ( get_chunks_via_visit_api, # Positional args must match the signature: field_names=None # retrieves all fields, so get_large_chunks lands in its own slot (chunk_request, index_name, filters, None, get_large_chunks), ) for chunk_request in chunk_requests ] parallel_results = run_functions_tuples_in_parallel( functions_with_args, allow_failures=True ) # Any failures to retrieve would give a None, drop the Nones and empty lists vespa_chunk_sets = [res for res in parallel_results if res] flattened_vespa_chunks = [] for chunk_set in vespa_chunk_sets: flattened_vespa_chunks.extend(chunk_set) inference_chunks = [ _vespa_hit_to_inference_chunk(chunk, null_score=True) for chunk in flattened_vespa_chunks ] return inference_chunks @retry(tries=3, delay=1, backoff=2) def query_vespa( query_params: Mapping[str, str | int | float], ) -> list[InferenceChunkUncleaned]: if "query" in query_params and not cast(str, query_params["query"]).strip(): raise ValueError("No/empty query received") params = dict( **query_params, **( { 
"presentation.timing": True, } if LOG_VESPA_TIMING_INFORMATION else {} ), ) if VESPA_LANGUAGE_OVERRIDE: params["language"] =
VESPA_LANGUAGE_OVERRIDE try: with get_vespa_http_client() as http_client: response = http_client.post(SEARCH_ENDPOINT, json=params) response.raise_for_status() except httpx.HTTPError as e: response_text = ( e.response.text if isinstance(e, httpx.HTTPStatusError) else None ) status_code = ( e.response.status_code if isinstance(e, httpx.HTTPStatusError) else None ) yql_value = params.get("yql", "") yql_length = len(str(yql_value)) # Log each detail on its own line so log collectors capture them # as separate entries rather than truncating a single multiline msg logger.error( f"Failed to query Vespa | " f"status={status_code} | " f"yql_length={yql_length} | " f"exception={str(e)}" ) if response_text: logger.error(f"Vespa error response: {response_text[:1000]}") logger.error(f"Vespa request URL: {e.request.url}") # Re-raise with diagnostics so callers see what actually went wrong raise httpx.HTTPError( f"Failed to query Vespa (status={status_code}, " f"yql_length={yql_length})" ) from e response_json: dict[str, Any] = response.json() if LOG_VESPA_TIMING_INFORMATION: logger.debug("Vespa timing info: %s", response_json.get("timing")) hits = response_json["root"].get("children", []) if not hits: logger.warning( f"No hits found for YQL Query: {query_params.get('yql', 'No YQL Query')}" ) logger.debug(f"Vespa Response: {response.text}") for hit in hits: if hit["fields"].get(CONTENT) is None: identifier = hit["fields"].get("documentid") or hit["id"] logger.error( f"Vespa Index with Vespa ID {identifier} has no contents. 
" f"This is invalid because the vector is not meaningful and keyword search cannot " f"fetch this document" ) filtered_hits = [hit for hit in hits if hit["fields"].get(CONTENT) is not None] inference_chunks = [_vespa_hit_to_inference_chunk(hit) for hit in filtered_hits] try: num_retrieved_inference_chunks = len(inference_chunks) num_retrieved_document_ids = len( set([chunk.document_id for chunk in inference_chunks]) ) logger.info( f"Retrieved {num_retrieved_inference_chunks} inference chunks for {num_retrieved_document_ids} documents" ) except Exception as e: # Debug logging only, should not fail the retrieval logger.error(f"Error logging retrieval statistics: {e}") # Good Debugging Spot return inference_chunks def _get_chunks_via_batch_search( index_name: str, chunk_requests: list[VespaChunkRequest], filters: IndexFilters, get_large_chunks: bool = False, ) -> list[InferenceChunkUncleaned]: if not chunk_requests: return [] filters_str = build_vespa_filters(filters=filters, include_hidden=True) yql = ( YQL_BASE.format(index_name=index_name) + filters_str + build_vespa_id_based_retrieval_yql(chunk_requests[0]) ) chunk_requests.pop(0) for request in chunk_requests: yql += " or " + build_vespa_id_based_retrieval_yql(request) params: dict[str, str | int | float] = { "yql": yql, "hits": MAX_ID_SEARCH_QUERY_SIZE, } inference_chunks = query_vespa(params) if not get_large_chunks: inference_chunks = [ chunk for chunk in inference_chunks if not chunk.large_chunk_reference_ids ] inference_chunks.sort(key=lambda chunk: chunk.chunk_id) return inference_chunks def batch_search_api_retrieval( index_name: str, chunk_requests: list[VespaChunkRequest], filters: IndexFilters, get_large_chunks: bool = False, ) -> list[InferenceChunkUncleaned]: retrieved_chunks: list[InferenceChunkUncleaned] = [] capped_requests: list[VespaChunkRequest] = [] uncapped_requests: list[VespaChunkRequest] = [] chunk_count = 0 for req_ind, request in enumerate(chunk_requests, start=1): # All requests without a
        # chunk range are uncapped.
        # Uncapped requests are retrieved using the Visit API.
        # `chunk_range` is used instead of `range` to avoid shadowing the built-in.
        chunk_range = request.range
        if chunk_range is None:
            uncapped_requests.append(request)
            continue

        if (
            chunk_count + chunk_range > MAX_ID_SEARCH_QUERY_SIZE
            or req_ind % MAX_OR_CONDITIONS == 0
        ):
            retrieved_chunks.extend(
                _get_chunks_via_batch_search(
                    index_name=index_name,
                    chunk_requests=capped_requests,
                    filters=filters,
                    get_large_chunks=get_large_chunks,
                )
            )
            capped_requests = []
            chunk_count = 0
        capped_requests.append(request)
        chunk_count += chunk_range

    if capped_requests:
        retrieved_chunks.extend(
            _get_chunks_via_batch_search(
                index_name=index_name,
                chunk_requests=capped_requests,
                filters=filters,
                get_large_chunks=get_large_chunks,
            )
        )

    if uncapped_requests:
        logger.debug(f"Retrieving {len(uncapped_requests)} uncapped requests")
        retrieved_chunks.extend(
            parallel_visit_api_retrieval(
                index_name, uncapped_requests, filters, get_large_chunks
            )
        )

    return retrieved_chunks

================================================
FILE: backend/onyx/document_index/vespa/deletion.py
================================================
import concurrent.futures
from uuid import UUID

import httpx
from retry import retry

from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.utils.logger import setup_logger

logger = setup_logger()


CONTENT_SUMMARY = "content_summary"


@retry(tries=10, delay=1, backoff=2)
def _retryable_http_delete(http_client: httpx.Client, url: str) -> None:
    res = http_client.delete(url)
    res.raise_for_status()


def _delete_vespa_chunk(
    doc_chunk_id: UUID, index_name: str, http_client: httpx.Client
) -> None:
    try:
        _retryable_http_delete(
            http_client,
            f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}",
        )
    except httpx.HTTPStatusError as e:
        logger.error(f"Failed to delete chunk, details: {e.response.text}")
        raise


def delete_vespa_chunks(
    doc_chunk_ids: list[UUID],
    index_name: str,
    http_client: httpx.Client,
    executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
    """Deletes a list of chunks from a Vespa index in parallel.

    Args:
        doc_chunk_ids: List of chunk IDs to delete.
        index_name: Name of the index to delete from.
        http_client: HTTP client to use for the delete requests.
        executor: Optional executor to run the deletions on. If None, a
            temporary ThreadPoolExecutor is created and shut down before this
            function returns.
    """
    external_executor = True

    if not executor:
        external_executor = False
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)

    try:
        chunk_deletion_future = {
            executor.submit(
                _delete_vespa_chunk, doc_chunk_id, index_name, http_client
            ): doc_chunk_id
            for doc_chunk_id in doc_chunk_ids
        }
        for future in concurrent.futures.as_completed(chunk_deletion_future):
            # Re-raises any exception raised by the deletion
            future.result()

    finally:
        if not external_executor:
            executor.shutdown(wait=True)

================================================
FILE: backend/onyx/document_index/vespa/index.py
================================================
import concurrent.futures
import io
import logging
import os
import re
import time
import urllib.parse
import zipfile
from collections.abc import Iterable
from dataclasses import dataclass
from datetime import datetime
from datetime import timedelta
from typing import BinaryIO
from typing import cast
from typing import List

import httpx
import jinja2
import requests
from pydantic import BaseModel
from retry import retry

from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.configs.chat_configs import VESPA_SEARCHER_THREADS
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.context.search.models import QueryExpansionType
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import (
    DocumentInsertionRecord as OldDocumentInsertionRecord,
)
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import IndexBatchParams
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces import VespaDocumentFields
from onyx.document_index.interfaces import VespaDocumentUserFields
from onyx.document_index.interfaces_new import DocumentSectionRequest
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
    build_vespa_filters,
)
from onyx.document_index.vespa.vespa_document_index import TenantState
from onyx.document_index.vespa.vespa_document_index import VespaDocumentIndex
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import VESPA_APPLICATION_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.key_value_store.factory import get_shared_kv_store
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
from shared_configs.model_server_models import Embedding

logger = setup_logger()

# Set the logging level to WARNING to ignore INFO and DEBUG logs
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)


@dataclass
class _VespaUpdateRequest:
    document_id: str
    url: str
    update_request: dict[str, dict]


class KGVespaChunkUpdateRequest(BaseModel):
    document_id: str
    chunk_id: int
    url: str
    update_request: dict[str, dict]


class KGUChunkUpdateRequest(BaseModel):
    """
    Update KG fields for a single chunk of a document
    """

    document_id: str
    chunk_id: int
    core_entity: str
    entities: set[str] | None = None
    relationships: set[str] | None = None
    terms: set[str] | None = None


class KGUDocumentUpdateRequest(BaseModel):
    """
    Update KG fields for a document
    """

    document_id: str
    entities: set[str]
    relationships: set[str]
    terms: set[str]


def generate_kg_update_request(
    kg_update_request: KGUChunkUpdateRequest,
) -> dict[str, dict]:
    kg_update_dict: dict[str, dict] = {}

    if kg_update_request.entities is not None:
        kg_update_dict["kg_entities"] = {"assign": list(kg_update_request.entities)}

    if kg_update_request.relationships is not None:
        kg_update_dict["kg_relationships"] = {"assign": []}
        for relationship in kg_update_request.relationships:
            source, rel_type, target = split_relationship_id(relationship)
            kg_update_dict["kg_relationships"]["assign"].append(
                {
                    "source": source,
                    "rel_type": rel_type,
                    "target": target,
                }
            )

    return kg_update_dict


def in_memory_zip_from_file_bytes(file_contents: dict[str, bytes]) -> BinaryIO:
    zip_buffer = io.BytesIO()
    with zipfile.ZipFile(zip_buffer, "w", zipfile.ZIP_DEFLATED) as zipf:
        for filename, content in file_contents.items():
            zipf.writestr(filename, content)
    zip_buffer.seek(0)
    return zip_buffer


def _create_document_xml_lines(doc_names: list[str | None] | list[str]) -> str:
    doc_lines = [
        f'<document type="{doc_name}" mode="index" />'
        for doc_name in doc_names
        if doc_name
    ]
    return "\n".join(doc_lines)


def add_ngrams_to_schema(schema_content: str) -> str:
    # Add the match blocks containing gram and gram-size to title and content fields
    schema_content = re.sub(
        r"(field title type string \{[^}]*indexing: summary \| index \| attribute)",
        r"\1\n match {\n gram\n gram-size: 3\n }",
        schema_content,
    )
    schema_content = re.sub(
        r"(field content type string \{[^}]*indexing: summary \| index)",
        r"\1\n match {\n gram\n gram-size: 3\n }",
        schema_content,
    )
    return schema_content


def cleanup_chunks(chunks: list[InferenceChunkUncleaned]) -> list[InferenceChunk]:
    def _remove_title(chunk: InferenceChunkUncleaned) -> str:
        if not chunk.title or not chunk.content:
            return chunk.content

        if chunk.content.startswith(chunk.title):
            return chunk.content[len(chunk.title) :].lstrip()

        # BLURB SIZE is by token instead of char but each token is at least 1 char
        # If this prefix matches the content, it's assumed the title was prepended
        if chunk.content.startswith(chunk.title[:BLURB_SIZE]):
            return (
                chunk.content.split(RETURN_SEPARATOR, 1)[-1]
                if RETURN_SEPARATOR in chunk.content
                else chunk.content
            )

        return chunk.content

    def _remove_metadata_suffix(chunk: InferenceChunkUncleaned) -> str:
        if not chunk.metadata_suffix:
            return chunk.content
        return chunk.content.removesuffix(chunk.metadata_suffix).rstrip(
            RETURN_SEPARATOR
        )

    def _remove_contextual_rag(chunk: InferenceChunkUncleaned) -> str:
        # remove document summary
        if chunk.content.startswith(chunk.doc_summary):
            chunk.content = chunk.content[len(chunk.doc_summary) :].lstrip()
        # remove chunk context
        if chunk.content.endswith(chunk.chunk_context):
            chunk.content = chunk.content[
                : len(chunk.content) - len(chunk.chunk_context)
            ].rstrip()
        return chunk.content

    for chunk in chunks:
        chunk.content = _remove_title(chunk)
        chunk.content = _remove_metadata_suffix(chunk)
        chunk.content = _remove_contextual_rag(chunk)

    return [chunk.to_inference_chunk() for chunk in chunks]


class VespaIndex(DocumentIndex):
    VESPA_SCHEMA_JINJA_FILENAME = "danswer_chunk.sd.jinja"

    def __init__(
        self,
        index_name: str,
        secondary_index_name: str | None,
        large_chunks_enabled: bool,
        secondary_large_chunks_enabled: bool | None,
        multitenant: bool = False,
        httpx_client: httpx.Client | None = None,
    ) -> None:
        self.index_name = index_name
        self.secondary_index_name = secondary_index_name

        self.large_chunks_enabled = large_chunks_enabled
        self.secondary_large_chunks_enabled = secondary_large_chunks_enabled

        self.multitenant = multitenant

        # Temporary until we refactor the entirety of this class.
        self.httpx_client = httpx_client

        self.httpx_client_context: BaseHTTPXClientContext
        if httpx_client:
            self.httpx_client_context = GlobalHTTPXClientContext(httpx_client)
        else:
            self.httpx_client_context = TemporaryHTTPXClientContext(
                get_vespa_http_client
            )

        self.index_to_large_chunks_enabled: dict[str, bool] = {}
        self.index_to_large_chunks_enabled[index_name] = large_chunks_enabled
        if secondary_index_name and secondary_large_chunks_enabled:
            self.index_to_large_chunks_enabled[secondary_index_name] = (
                secondary_large_chunks_enabled
            )

    def ensure_indices_exist(
        self,
        primary_embedding_dim: int,
        primary_embedding_precision: EmbeddingPrecision,
        secondary_index_embedding_dim: int | None,
        secondary_index_embedding_precision: EmbeddingPrecision | None,
    ) -> None:
        if MULTI_TENANT:
            logger.info(
                "Skipping Vespa index setup for multitenant (would wipe all indices)"
            )
            return None

        jinja_env = jinja2.Environment()

        deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
        logger.notice(f"Deploying Vespa application package to {deploy_url}")

        vespa_schema_path = os.path.join(
            os.getcwd(), "onyx", "document_index", "vespa", "app_config"
        )
        schema_jinja_file = os.path.join(
            vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
        )
        services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
        overrides_jinja_file = os.path.join(
            vespa_schema_path, "validation-overrides.xml.jinja"
        )

        with open(services_jinja_file, "r") as services_f:
            schema_names = [self.index_name, self.secondary_index_name]
            doc_lines = _create_document_xml_lines(schema_names)

            services_template_str = services_f.read()
            services_template = jinja_env.from_string(services_template_str)
            services = services_template.render(
                document_elements=doc_lines,
                num_search_threads=str(VESPA_SEARCHER_THREADS),
            )

        kv_store = get_shared_kv_store()

        needs_reindexing = False
        try:
            needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
        except Exception:
            logger.debug("Could not load the reindexing flag. Not adding ngrams")

        # Vespa requires a validation override to erase data, including the
        # indices we're no longer using. An override may expire at most 30 days
        # in the future, so we set it to 7 days from the current date.
        with open(overrides_jinja_file, "r") as overrides_f:
            overrides_template_str = overrides_f.read()
            overrides_template = jinja_env.from_string(overrides_template_str)

            now = datetime.now()
            date_in_7_days = now + timedelta(days=7)
            formatted_date = date_in_7_days.strftime("%Y-%m-%d")
            overrides = overrides_template.render(
                until_date=formatted_date,
            )

        zip_dict = {
            "services.xml": services.encode("utf-8"),
            "validation-overrides.xml": overrides.encode("utf-8"),
        }

        with open(schema_jinja_file, "r") as schema_f:
            template_str = schema_f.read()

        template = jinja_env.from_string(template_str)
        schema = template.render(
            multi_tenant=MULTI_TENANT,
            schema_name=self.index_name,
            dim=primary_embedding_dim,
            embedding_precision=primary_embedding_precision.value,
        )

        schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
        zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")

        if self.secondary_index_name:
            if secondary_index_embedding_dim is None:
                raise ValueError("Secondary index embedding dimension is required")
            if secondary_index_embedding_precision is None:
                raise ValueError("Secondary index embedding precision is required")

            upcoming_schema = template.render(
                multi_tenant=MULTI_TENANT,
                schema_name=self.secondary_index_name,
                dim=secondary_index_embedding_dim,
                embedding_precision=secondary_index_embedding_precision.value,
            )
            zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")

        zip_file = in_memory_zip_from_file_bytes(zip_dict)

        headers = {"Content-Type": "application/zip"}
        response = requests.post(deploy_url, headers=headers, data=zip_file)
        if response.status_code != 200:
            logger.error(
                f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
            )
            raise RuntimeError(
                f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
            )

    @staticmethod
    def register_multitenant_indices(
        indices: list[str],
        embedding_dims: list[int],
        embedding_precisions: list[EmbeddingPrecision],
    ) -> None:
        if not MULTI_TENANT:
            raise ValueError("Multi-tenant is not enabled")

        deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
        logger.info(f"Deploying Vespa application package to {deploy_url}")

        vespa_schema_path = os.path.join(
            os.getcwd(), "onyx", "document_index", "vespa", "app_config"
        )
        schema_jinja_file = os.path.join(
            vespa_schema_path, "schemas", VespaIndex.VESPA_SCHEMA_JINJA_FILENAME
        )
        services_jinja_file = os.path.join(vespa_schema_path, "services.xml.jinja")
        overrides_jinja_file = os.path.join(
            vespa_schema_path, "validation-overrides.xml.jinja"
        )

        jinja_env = jinja2.Environment()

        # Generate schema names from index settings
        with open(services_jinja_file, "r") as services_f:
            schema_names = [index_name for index_name in indices]
            doc_lines = _create_document_xml_lines(schema_names)

            services_template_str = services_f.read()
            services_template = jinja_env.from_string(services_template_str)
            services = services_template.render(
                document_elements=doc_lines,
                num_search_threads=str(VESPA_SEARCHER_THREADS),
            )

        kv_store = get_shared_kv_store()

        needs_reindexing = False
        try:
            needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
        except Exception:
            logger.debug("Could not load the reindexing flag. Not adding ngrams")

        # Vespa requires a validation override to erase data, including the
        # indices we're no longer using. An override may expire at most 30 days
        # in the future, so we set it to 7 days from the current date.
        with open(overrides_jinja_file, "r") as overrides_f:
            overrides_template_str = overrides_f.read()
            overrides_template = jinja_env.from_string(overrides_template_str)

            now = datetime.now()
            date_in_7_days = now + timedelta(days=7)
            formatted_date = date_in_7_days.strftime("%Y-%m-%d")
            overrides = overrides_template.render(
                until_date=formatted_date,
            )

        zip_dict = {
            "services.xml": services.encode("utf-8"),
            "validation-overrides.xml": overrides.encode("utf-8"),
        }

        with open(schema_jinja_file, "r") as schema_f:
            schema_template_str = schema_f.read()

        schema_template = jinja_env.from_string(schema_template_str)

        for i, index_name in enumerate(indices):
            embedding_dim = embedding_dims[i]
            embedding_precision = embedding_precisions[i]
            logger.info(
                f"Creating index: {index_name} with embedding dimension: {embedding_dim}"
            )

            schema = schema_template.render(
                multi_tenant=MULTI_TENANT,
                schema_name=index_name,
                dim=embedding_dim,
                embedding_precision=embedding_precision.value,
            )
            schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
            zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")

        zip_file = in_memory_zip_from_file_bytes(zip_dict)

        headers = {"Content-Type": "application/zip"}
        response = requests.post(deploy_url, headers=headers, data=zip_file)

        if response.status_code != 200:
            raise RuntimeError(
                f"Failed to prepare Vespa Onyx Indexes. Response: {response.text}"
            )

    def index(
        self,
        chunks: Iterable[DocMetadataAwareIndexChunk],
        index_batch_params: IndexBatchParams,
    ) -> set[OldDocumentInsertionRecord]:
        """
        NOTE: Do NOT consider the secondary index here. A separate indexing
        pipeline will be responsible for indexing to the secondary index. This
        design is not ideal and we should reconsider this when revamping index
        swapping.
        """
        if len(index_batch_params.doc_id_to_previous_chunk_cnt) != len(
            index_batch_params.doc_id_to_new_chunk_cnt
        ):
            raise ValueError("Bug: Length of doc ID to chunk maps does not match.")
        doc_id_to_chunk_cnt_diff = {
            doc_id: IndexingMetadata.ChunkCounts(
                old_chunk_cnt=index_batch_params.doc_id_to_previous_chunk_cnt[doc_id],
                new_chunk_cnt=index_batch_params.doc_id_to_new_chunk_cnt[doc_id],
            )
            for doc_id in index_batch_params.doc_id_to_previous_chunk_cnt.keys()
        }
        indexing_metadata = IndexingMetadata(
            doc_id_to_chunk_cnt_diff=doc_id_to_chunk_cnt_diff,
        )
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
        )
        if tenant_state.multitenant != self.multitenant:
            raise ValueError(
                f"Bug: Multitenant mismatch. Expected {tenant_state.multitenant}, got {self.multitenant}."
            )
        if (
            tenant_state.multitenant
            and tenant_state.tenant_id != index_batch_params.tenant_id
        ):
            raise ValueError(
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {index_batch_params.tenant_id}."
            )

        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        # Converting the returned list to a set here, only for it to be
        # converted back to a list upstream, is suboptimal; this is temporary
        # until we refactor the entirety of this class.
        document_insertion_records = vespa_document_index.index(
            chunks, indexing_metadata
        )
        return {
            OldDocumentInsertionRecord(
                document_id=doc_insertion_record.document_id,
                already_existed=doc_insertion_record.already_existed,
            )
            for doc_insertion_record in document_insertion_records
        }

    @classmethod
    def _apply_updates_batched(
        cls,
        updates: list[_VespaUpdateRequest],
        httpx_client: httpx.Client,
        batch_size: int = BATCH_SIZE,
    ) -> None:
        """Runs a batch of updates in parallel via the ThreadPoolExecutor."""

        def _update_chunk(
            update: _VespaUpdateRequest, http_client: httpx.Client
        ) -> httpx.Response:
            logger.debug(
                f"Updating with request to {update.url} with body {update.update_request}"
            )
            return http_client.put(
                update.url,
                headers={"Content-Type": "application/json"},
                json=update.update_request,
            )

        # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
        # indexing / updates / deletes since we have to make a large volume of requests.
        with (
            concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor,
            httpx_client as http_client,
        ):
            for update_batch in batch_generator(updates, batch_size):
                future_to_document_id = {
                    executor.submit(
                        _update_chunk,
                        update,
                        http_client,
                    ): update.document_id
                    for update in update_batch
                }
                for future in concurrent.futures.as_completed(future_to_document_id):
                    res = future.result()
                    # `res` is an httpx.Response, so raise_for_status() raises
                    # httpx.HTTPStatusError (a `requests.HTTPError` handler never matches).
                    try:
                        res.raise_for_status()
                    except httpx.HTTPStatusError as e:
                        failure_msg = f"Failed to update document: {future_to_document_id[future]}"
                        raise httpx.HTTPStatusError(
                            failure_msg, request=e.request, response=e.response
                        ) from e

    @classmethod
    def _apply_kg_chunk_updates_batched(
        cls,
        updates: list[KGVespaChunkUpdateRequest],
        httpx_client: httpx.Client,
        batch_size: int = BATCH_SIZE,
    ) -> None:
        """Runs a batch of updates in parallel via the ThreadPoolExecutor."""

        @retry(tries=3, delay=1, backoff=2, jitter=(0.0, 1.0))
        def _kg_update_chunk(
            update: KGVespaChunkUpdateRequest, http_client: httpx.Client
        ) -> httpx.Response:
            return http_client.put(
                update.url,
                headers={"Content-Type": "application/json"},
                json=update.update_request,
            )

        # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This is beneficial for
        # indexing / updates / deletes since we have to make a large volume of requests.
        with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
            for update_batch in batch_generator(updates, batch_size):
                future_to_document_id = {
                    executor.submit(
                        _kg_update_chunk,
                        update,
                        httpx_client,
                    ): update.document_id
                    for update in update_batch
                }
                for future in concurrent.futures.as_completed(future_to_document_id):
                    res = future.result()
                    # raise_for_status() on an httpx.Response raises httpx.HTTPStatusError
                    try:
                        res.raise_for_status()
                    except httpx.HTTPStatusError as e:
                        failure_msg = f"Failed to update document {future_to_document_id[future]}\nResponse: {res.text}"
                        raise httpx.HTTPStatusError(
                            failure_msg, request=e.request, response=e.response
                        ) from e

    def kg_chunk_updates(
        self, kg_update_requests: list[KGUChunkUpdateRequest], tenant_id: str
    ) -> None:
        processed_updates_requests: list[KGVespaChunkUpdateRequest] = []
        update_start = time.monotonic()

        # Build the KGVespaChunkUpdateRequest objects
        for kg_update_request in kg_update_requests:
            kg_update_dict: dict[str, dict] = {
                "fields": generate_kg_update_request(kg_update_request)
            }
            if not kg_update_dict["fields"]:
                logger.error("Update request received but nothing to update")
                continue

            doc_chunk_id = get_uuid_from_chunk_info(
                document_id=kg_update_request.document_id,
                chunk_id=kg_update_request.chunk_id,
                tenant_id=tenant_id,
                large_chunk_id=None,
            )

            processed_updates_requests.append(
                KGVespaChunkUpdateRequest(
                    document_id=kg_update_request.document_id,
                    chunk_id=kg_update_request.chunk_id,
                    url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=self.index_name)}/{doc_chunk_id}",
                    update_request=kg_update_dict,
                )
            )

        with self.httpx_client_context as httpx_client:
            self._apply_kg_chunk_updates_batched(
                processed_updates_requests, httpx_client
            )
        logger.debug(
            "Updated %d Vespa chunks in %.2f seconds",
            len(processed_updates_requests),
            time.monotonic() - update_start,
        )

    def update_single(
        self,
        doc_id: str,
        *,
        chunk_count: int | None,
        tenant_id: str,
        fields: VespaDocumentFields | None,
        user_fields: VespaDocumentUserFields | None,
    ) -> None:
        """Note: if the document ID does not exist, the update is a no-op and the
        function completes without raising. Other exceptions are propagated, so
        callers that want retry behavior must handle them.

        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for updating chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        if fields is None and user_fields is None:
            logger.warning(
                f"Tried to update document {doc_id} with no updated fields or user fields."
            )
            return

        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
        )
        if tenant_state.multitenant != self.multitenant:
            raise ValueError(
                f"Bug: Multitenant mismatch. Expected {tenant_state.multitenant}, got {self.multitenant}."
            )
        if tenant_state.multitenant and tenant_state.tenant_id != tenant_id:
            raise ValueError(
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
            )

        project_ids: set[int] | None = None
        # NOTE: Empty user_projects is semantically different from None
        # user_projects.
        if user_fields is not None and user_fields.user_projects is not None:
            project_ids = set(user_fields.user_projects)
        persona_ids: set[int] | None = None
        # NOTE: Empty personas is semantically different from None personas.
        if user_fields is not None and user_fields.personas is not None:
            persona_ids = set(user_fields.personas)
        update_request = MetadataUpdateRequest(
            document_ids=[doc_id],
            doc_id_to_chunk_cnt={
                doc_id: chunk_count if chunk_count is not None else -1
            },  # NOTE: -1 represents an unknown chunk count.
            access=fields.access if fields is not None else None,
            document_sets=fields.document_sets if fields is not None else None,
            boost=fields.boost if fields is not None else None,
            hidden=fields.hidden if fields is not None else None,
            project_ids=project_ids,
            persona_ids=persona_ids,
        )

        indices = [self.index_name]
        if self.secondary_index_name:
            indices.append(self.secondary_index_name)

        for index_name in indices:
            vespa_document_index = VespaDocumentIndex(
                index_name=index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
                    index_name, False
                ),
                httpx_client=self.httpx_client,
            )
            vespa_document_index.update([update_request])

    def delete_single(
        self,
        doc_id: str,
        *,
        tenant_id: str,
        chunk_count: int | None,
    ) -> int:
        """
        NOTE: Remember to handle the secondary index here. There is no separate
        pipeline for deleting chunks in the secondary index. This design is not
        ideal and we should reconsider this when revamping index swapping.
        """
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
        )
        if tenant_state.multitenant != self.multitenant:
            raise ValueError(
                f"Bug: Multitenant mismatch. Expected {tenant_state.multitenant}, got {self.multitenant}."
            )
        if tenant_state.multitenant and tenant_state.tenant_id != tenant_id:
            raise ValueError(
                f"Bug: Tenant ID mismatch. Expected {tenant_state.tenant_id}, got {tenant_id}."
            )
        indices = [self.index_name]
        if self.secondary_index_name:
            indices.append(self.secondary_index_name)

        total_chunks_deleted = 0
        for index_name in indices:
            vespa_document_index = VespaDocumentIndex(
                index_name=index_name,
                tenant_state=tenant_state,
                large_chunks_enabled=self.index_to_large_chunks_enabled.get(
                    index_name, False
                ),
                httpx_client=self.httpx_client,
            )
            total_chunks_deleted += vespa_document_index.delete(
                document_id=doc_id, chunk_count=chunk_count
            )

        return total_chunks_deleted

    def id_based_retrieval(
        self,
        chunk_requests: list[VespaChunkRequest],
        filters: IndexFilters,
        batch_retrieval: bool = False,
        get_large_chunks: bool = False,  # noqa: ARG002
    ) -> list[InferenceChunk]:
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        generic_chunk_requests: list[DocumentSectionRequest] = []
        for chunk_request in chunk_requests:
            generic_chunk_requests.append(
                DocumentSectionRequest(
                    document_id=chunk_request.document_id,
                    min_chunk_ind=chunk_request.min_chunk_ind,
                    max_chunk_ind=chunk_request.max_chunk_ind,
                )
            )
        return vespa_document_index.id_based_retrieval(
            chunk_requests=generic_chunk_requests,
            filters=filters,
            batch_retrieval=batch_retrieval,
        )

    @log_function_time(print_only=True, debug_only=True)
    def hybrid_retrieval(
        self,
        query: str,
        query_embedding: Embedding,
        final_keywords: list[str] | None,
        filters: IndexFilters,
        hybrid_alpha: float,  # noqa: ARG002
        time_decay_multiplier: float,  # noqa: ARG002
        num_to_retrieve: int,
        ranking_profile_type: QueryExpansionType = QueryExpansionType.SEMANTIC,
        title_content_ratio: float | None = TITLE_CONTENT_RATIO,  # noqa: ARG002
    ) -> list[InferenceChunk]:
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        if ranking_profile_type not in (
            QueryExpansionType.KEYWORD,
            QueryExpansionType.SEMANTIC,
        ):
            raise ValueError(
                f"Bug: Received invalid ranking profile type: {ranking_profile_type}"
            )
        query_type = (
            QueryType.KEYWORD
            if ranking_profile_type == QueryExpansionType.KEYWORD
            else QueryType.SEMANTIC
        )
        return vespa_document_index.hybrid_retrieval(
            query,
            query_embedding,
            final_keywords,
            query_type,
            filters,
            num_to_retrieve,
        )

    def admin_retrieval(
        self,
        query: str,
        query_embedding: Embedding,  # noqa: ARG002
        filters: IndexFilters,
        num_to_retrieve: int = NUM_RETURNED_HITS,
    ) -> list[InferenceChunk]:
        vespa_where_clauses = build_vespa_filters(filters, include_hidden=True)
        yql = (
            YQL_BASE.format(index_name=self.index_name)
            + vespa_where_clauses
            + '({grammar: "weakAnd"}userInput(@query) '
            # The `({defaultIndex: "content_summary"}userInput(@query))` clause is
            # needed for highlighting, because N-gram highlighting does not work
            # as desired
            + f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))'
        )

        params: dict[str, str | int] = {
            "yql": yql,
            "query": query,
            "hits": num_to_retrieve,
            "ranking.profile": "admin_search",
            "timeout": VESPA_TIMEOUT,
        }

        return cleanup_chunks(query_vespa(params))

    # Retrieves chunk information for a document:
    # - Determines the last indexed chunk
    # - Identifies if the document uses the old or new chunk ID system
    # This data is needed to update Vespa documents without relying on the Visit API.
    @classmethod
    def enrich_basic_chunk_info(
        cls,
        index_name: str,
        http_client: httpx.Client,
        document_id: str,
        previous_chunk_count: int | None = None,
        new_chunk_count: int = 0,
    ) -> EnrichedDocumentIndexingInfo:
        last_indexed_chunk = previous_chunk_count

        # If the document has no `chunk_count` in the database, we know that it
        # has the old chunk ID system and we must check for the final chunk index
        is_old_version = False
        if last_indexed_chunk is None:
            is_old_version = True
            minimal_doc_info = MinimalDocumentIndexingInfo(
                doc_id=document_id, chunk_start_index=new_chunk_count
            )
            last_indexed_chunk = check_for_final_chunk_existence(
                minimal_doc_info=minimal_doc_info,
                start_index=new_chunk_count,
                index_name=index_name,
                http_client=http_client,
            )

        enriched_doc_info = EnrichedDocumentIndexingInfo(
            doc_id=document_id,
            chunk_start_index=new_chunk_count,
            chunk_end_index=last_indexed_chunk,
            old_version=is_old_version,
        )
        return enriched_doc_info

    @classmethod
    def delete_entries_by_tenant_id(
        cls,
        *,
        tenant_id: str,
        index_name: str,
    ) -> int:
        """
        Deletes all entries in the specified index with the given tenant_id.

        Currently unused, but we anticipate this being useful. The entire flow
        does not use the httpx connection pool of an instance.

        Parameters:
            tenant_id (str): The tenant ID whose documents are to be deleted.
            index_name (str): The name of the index from which to delete documents.

        Returns:
            int: The number of documents found for deletion. Individual HTTP
            delete failures are logged rather than raised, so they are still
            included in this count.
        """
        logger.info(
            f"Deleting entries with tenant_id: {tenant_id} from index: {index_name}"
        )

        # Step 1: Retrieve all document IDs with the given tenant_id
        document_ids = cls._get_all_document_ids_by_tenant_id(tenant_id, index_name)

        if not document_ids:
            logger.info(
                f"No documents found with tenant_id: {tenant_id} in index: {index_name}"
            )
            return 0

        # Step 2: Delete documents in batches
        delete_requests = [
            _VespaDeleteRequest(document_id=doc_id, index_name=index_name)
            for doc_id in document_ids
        ]

        cls._apply_deletes_batched(delete_requests)

        return len(document_ids)

    @classmethod
    def _get_all_document_ids_by_tenant_id(
        cls, tenant_id: str, index_name: str
    ) -> List[str]:
        """
        Retrieves all document IDs with the specified tenant_id, handling pagination.

        Internal helper function for delete_entries_by_tenant_id.

        Parameters:
            tenant_id (str): The tenant ID to search for.
            index_name (str): The name of the index to search in.

        Returns:
            List[str]: A list of document IDs matching the tenant_id.
        """
        offset = 0
        limit = 1000  # Vespa's maximum hits per query
        document_ids = []

        logger.debug(
            f"Starting document ID retrieval for tenant_id: {tenant_id} in index: {index_name}"
        )

        while True:
            # Construct the query to fetch document IDs
            query_params = {
                "yql": f'select id from sources * where tenant_id contains "{tenant_id}";',
                "offset": str(offset),
                "hits": str(limit),
                "timeout": "10s",
                "format": "json",
                "summary": "id",
            }

            url = f"{VESPA_APPLICATION_ENDPOINT}/search/"

            logger.debug(
                f"Querying for document IDs with tenant_id: {tenant_id}, offset: {offset}"
            )

            with get_vespa_http_client() as http_client:
                response = http_client.get(url, params=query_params, timeout=None)
                response.raise_for_status()

                search_result = response.json()
                hits = search_result.get("root", {}).get("children", [])

                if not hits:
                    break

                for hit in hits:
                    doc_id = hit.get("id")
                    if doc_id:
                        document_ids.append(doc_id)

                offset += limit  # Move to the next page

        logger.debug(
            f"Retrieved {len(document_ids)} document IDs for tenant_id: {tenant_id}"
        )
        return document_ids

    @classmethod
    def _apply_deletes_batched(
        cls,
        delete_requests: List["_VespaDeleteRequest"],
        batch_size: int = BATCH_SIZE,
    ) -> None:
        """
        Deletes documents in batches using multiple threads.

        Internal helper function for delete_entries_by_tenant_id.

        This is a class method and does not use the httpx pool of the instance.
        This is OK because we don't use this method often.

        Parameters:
            delete_requests (List[_VespaDeleteRequest]): The list of delete requests.
            batch_size (int): The number of documents to delete in each batch.
        """

        def _delete_document(
            delete_request: "_VespaDeleteRequest", http_client: httpx.Client
        ) -> None:
            logger.debug(f"Deleting document with ID {delete_request.document_id}")
            response = http_client.delete(
                delete_request.url,
                headers={"Content-Type": "application/json"},
                timeout=None,
            )
            response.raise_for_status()

        logger.debug(f"Starting batch deletion for {len(delete_requests)} documents")

        with concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor:
            with get_vespa_http_client() as http_client:
                for batch_start in range(0, len(delete_requests), batch_size):
                    batch = delete_requests[batch_start : batch_start + batch_size]

                    future_to_document_id = {
                        executor.submit(
                            _delete_document,
                            delete_request,
                            http_client,
                        ): delete_request.document_id
                        for delete_request in batch
                    }

                    for future in concurrent.futures.as_completed(
                        future_to_document_id
                    ):
                        doc_id = future_to_document_id[future]
                        try:
                            future.result()
                            logger.debug(f"Successfully deleted document: {doc_id}")
                        except httpx.HTTPError as e:
                            logger.error(f"Failed to delete document {doc_id}: {e}")
                            # Optionally, implement retry logic or error handling here

        logger.info("Batch deletion completed")

    def random_retrieval(
        self,
        filters: IndexFilters,
        num_to_retrieve: int = 10,
    ) -> list[InferenceChunk]:
        """Retrieve random chunks matching the filters using Vespa's random ranking

        This method is currently used for random chunk retrieval in the context of assistant
        starter message creation (passed as sample context for usage by the assistant).
        """
        tenant_state = TenantState(
            tenant_id=get_current_tenant_id(),
            multitenant=MULTI_TENANT,
        )
        vespa_document_index = VespaDocumentIndex(
            index_name=self.index_name,
            tenant_state=tenant_state,
            large_chunks_enabled=self.large_chunks_enabled,
            httpx_client=self.httpx_client,
        )
        return vespa_document_index.random_retrieval(
            filters=filters,
            num_to_retrieve=num_to_retrieve,
        )


class _VespaDeleteRequest:
    def __init__(self, document_id: str, index_name: str) -> None:
        self.document_id = document_id
        # Encode the document ID to ensure it's safe for use in the URL
        encoded_doc_id = urllib.parse.quote_plus(self.document_id)
        self.url = f"{VESPA_APPLICATION_ENDPOINT}/document/v1/{index_name}/{index_name}/docid/{encoded_doc_id}"


================================================
FILE: backend/onyx/document_index/vespa/indexing_utils.py
================================================
import concurrent.futures
import json
import random
import time
import uuid
from abc import ABC
from abc import abstractmethod
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from http import HTTPStatus

import httpx
from retry import retry

from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    get_experts_stores_representations,
)
from onyx.document_index.chunk_content_enrichment import (
    generate_enriched_content_for_chunk_text,
)
from onyx.document_index.document_index_utils import get_uuid_from_chunk
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info_old
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
from onyx.document_index.vespa.shared_utils.utils import (
    replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_CONTEXT
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOC_SUMMARY
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import EMBEDDINGS
from onyx.document_index.vespa_constants import FULL_CHUNK_EMBEDDING_KEY
from onyx.document_index.vespa_constants import IMAGE_FILE_NAME
from onyx.document_index.vespa_constants import LARGE_CHUNK_REFERENCE_IDS
from onyx.document_index.vespa_constants import METADATA
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import METADATA_SUFFIX
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import PRIMARY_OWNERS
from onyx.document_index.vespa_constants import SECONDARY_OWNERS
from onyx.document_index.vespa_constants import SECTION_CONTINUATION
from onyx.document_index.vespa_constants import SEMANTIC_IDENTIFIER
from onyx.document_index.vespa_constants import SKIP_TITLE_EMBEDDING
from onyx.document_index.vespa_constants import SOURCE_LINKS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import TITLE
from onyx.document_index.vespa_constants import TITLE_EMBEDDING
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import remove_invalid_unicode_chars

logger = setup_logger()

# Retry configuration constants
INDEXING_MAX_RETRIES = 5
INDEXING_BASE_DELAY = 1.0
INDEXING_MAX_DELAY = 60.0


@retry(tries=3, delay=1, backoff=2)
def _does_doc_chunk_exist(
    doc_chunk_id: uuid.UUID, index_name: str, http_client: httpx.Client
) -> bool:
    doc_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}"
    doc_fetch_response = http_client.get(doc_url)
    if doc_fetch_response.status_code == 404:
        return False

    if doc_fetch_response.status_code != 200:
        logger.debug(f"Failed to check for document with URL {doc_url}")
        raise RuntimeError(
            f"Unexpected fetch document by ID value from Vespa: "
            f"error={doc_fetch_response.status_code} "
            f"index={index_name} "
            f"doc_chunk_id={doc_chunk_id}"
        )
    return True


def _vespa_get_updated_at_attribute(t: datetime | None) -> int | None:
    if not t:
        return None

    if t.tzinfo != timezone.utc:
        raise ValueError("Connectors must provide document update time in UTC")

    return int(t.timestamp())


def get_existing_documents_from_chunks(
    chunks: list[DocMetadataAwareIndexChunk],
    index_name: str,
    http_client: httpx.Client,
    executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> set[str]:
    external_executor = True

    if not executor:
        external_executor = False
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)

    document_ids: set[str] = set()
    try:
        chunk_existence_future = {
            executor.submit(
                _does_doc_chunk_exist,
                get_uuid_from_chunk(chunk),
                index_name,
                http_client,
            ): chunk
            for chunk in chunks
        }
        for future in concurrent.futures.as_completed(chunk_existence_future):
            chunk = chunk_existence_future[future]
            chunk_already_existed = future.result()
            if chunk_already_existed:
                document_ids.add(chunk.source_document.id)

    finally:
        if not external_executor:
            executor.shutdown(wait=True)

    return document_ids


def _index_vespa_chunk(
    chunk: DocMetadataAwareIndexChunk,
    index_name: str,
    http_client: httpx.Client,
    multitenant: bool,
) -> None:
    json_header = {
        "Content-Type": "application/json",
    }
    document = chunk.source_document

    # No minichunk documents in vespa, minichunk vectors are stored in the chunk itself

    vespa_chunk_id = str(get_uuid_from_chunk(chunk))

    embeddings = chunk.embeddings

    embeddings_name_vector_map = {FULL_CHUNK_EMBEDDING_KEY: embeddings.full_embedding}

    if embeddings.mini_chunk_embeddings:
        for ind, m_c_embed in enumerate(embeddings.mini_chunk_embeddings):
            embeddings_name_vector_map[f"mini_chunk_{ind}"] = m_c_embed

    title = document.get_title_for_document_index()

    metadata_json = document.metadata
    cleaned_metadata_json: dict[str, str | list[str]] = {}
    for key, value in metadata_json.items():
        cleaned_key = remove_invalid_unicode_chars(key)
        if isinstance(value, list):
            cleaned_metadata_json[cleaned_key] = [
                remove_invalid_unicode_chars(item) for item in value
            ]
        else:
            cleaned_metadata_json[cleaned_key] = remove_invalid_unicode_chars(value)

    metadata_list = document.get_metadata_str_attributes()
    if metadata_list:
        metadata_list = [
            remove_invalid_unicode_chars(metadata) for metadata in metadata_list
        ]

    vespa_document_fields = {
        DOCUMENT_ID: document.id,
        CHUNK_ID: chunk.chunk_id,
        BLURB: remove_invalid_unicode_chars(chunk.blurb),
        TITLE: remove_invalid_unicode_chars(title) if title else None,
        SKIP_TITLE_EMBEDDING: not title,
        # For the BM25 index, the keyword suffix is used, the vector is already generated with the more
        # natural language representation of the metadata section
        CONTENT: remove_invalid_unicode_chars(
            generate_enriched_content_for_chunk_text(chunk)
        ),
        # This duplication of `content` is needed for keyword highlighting
        # Note that it's not exactly the same as the actual content
        # which contains the title prefix and metadata suffix
        CONTENT_SUMMARY: remove_invalid_unicode_chars(chunk.content),
        SOURCE_TYPE: str(document.source.value),
        SOURCE_LINKS: json.dumps(chunk.source_links),
        SEMANTIC_IDENTIFIER: remove_invalid_unicode_chars(document.semantic_identifier),
        SECTION_CONTINUATION: chunk.section_continuation,
        LARGE_CHUNK_REFERENCE_IDS: chunk.large_chunk_reference_ids,
        METADATA: json.dumps(cleaned_metadata_json),
        # Save as a list for efficient extraction as an Attribute
        METADATA_LIST: metadata_list,
        METADATA_SUFFIX: remove_invalid_unicode_chars(chunk.metadata_suffix_keyword),
        CHUNK_CONTEXT: chunk.chunk_context,
        DOC_SUMMARY: chunk.doc_summary,
        EMBEDDINGS: embeddings_name_vector_map,
        TITLE_EMBEDDING: chunk.title_embedding,
        DOC_UPDATED_AT: _vespa_get_updated_at_attribute(document.doc_updated_at),
        PRIMARY_OWNERS: get_experts_stores_representations(document.primary_owners),
        SECONDARY_OWNERS: get_experts_stores_representations(document.secondary_owners),
        # the only `set` vespa has is `weightedset`, so we have to give each
        # element an arbitrary weight
        # rkuo: acl, docset and boost metadata are also updated through the metadata sync queue
        # which only calls VespaIndex.update
        ACCESS_CONTROL_LIST: {acl_entry: 1 for acl_entry in chunk.access.to_acl()},
        DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
        # still called `image_file_name` in Vespa for backwards compatibility
        IMAGE_FILE_NAME: chunk.image_file_id,
        USER_PROJECT: chunk.user_project if chunk.user_project is not None else [],
        PERSONAS: chunk.personas if chunk.personas is not None else [],
        BOOST: chunk.boost,
        AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
    }

    if multitenant:
        if chunk.tenant_id:
            vespa_document_fields[TENANT_ID] = chunk.tenant_id
    vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{vespa_chunk_id}"
    logger.debug(f'Indexing to URL "{vespa_url}"')

    # Retry logic with exponential backoff for rate limiting
    for attempt in range(INDEXING_MAX_RETRIES):
        try:
            res = http_client.post(
                vespa_url, headers=json_header, json={"fields": vespa_document_fields}
            )
            res.raise_for_status()
            return  # Success, exit the function
        except httpx.HTTPStatusError as e:
            # Handle 429 rate limiting specifically
            if e.response.status_code == HTTPStatus.TOO_MANY_REQUESTS:
                if attempt < INDEXING_MAX_RETRIES - 1:
                    # Calculate exponential backoff with jitter
                    delay = min(
                        INDEXING_BASE_DELAY * (2**attempt), INDEXING_MAX_DELAY
                    ) * random.uniform(0.5, 1.0)
                    logger.warning(
                        f"Rate limited while indexing document '{document.id}' "
                        f"(attempt {attempt + 1}/{INDEXING_MAX_RETRIES}). "
                        f"Vespa response: '{e.response.text}'. "
                        f"Backing off for {delay:.2f} seconds."
                    )
                    time.sleep(delay)
                    continue
                else:
                    raise RuntimeError(
                        f"Failed to index document '{document.id}' after {INDEXING_MAX_RETRIES} attempts due to rate limiting"
                    ) from e
            elif e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE:
                logger.error(
                    f"Failed to index document: '{document.id}'. Got response: '{e.response.text}'"
                )
                logger.error(
                    "NOTE: HTTP Status 507 Insufficient Storage usually means "
                    "you need to allocate more memory or disk space to the "
                    "Vespa/index container."
                )
                raise
            else:
                # For other HTTP errors, check if retryable
                if e.response.status_code in (
                    HTTPStatus.BAD_REQUEST,
                    HTTPStatus.UNAUTHORIZED,
                    HTTPStatus.FORBIDDEN,
                    HTTPStatus.NOT_FOUND,
                ):
                    # Non-retryable errors - fail immediately
                    logger.error(
                        f"Non-retryable HTTP {e.response.status_code} error for document '{document.id}'"
                    )
                    raise

                # Retry other errors with shorter backoff
                if attempt < INDEXING_MAX_RETRIES - 1:
                    delay = INDEXING_BASE_DELAY * (1.5**attempt)
                    logger.warning(
                        f"HTTP error {e.response.status_code} while indexing document '{document.id}' "
                        f"(attempt {attempt + 1}/{INDEXING_MAX_RETRIES}). Retrying in {delay:.2f} seconds."
                    )
                    time.sleep(delay)
                    continue
                else:
                    logger.exception(
                        f"Failed to index document: '{document.id}'. Got response: '{e.response.text}'"
                    )
                    raise
        except Exception as e:
            # For non-HTTP errors, use simple retry logic
            if attempt < INDEXING_MAX_RETRIES - 1:
                delay = INDEXING_BASE_DELAY * (1.5**attempt)
                logger.warning(
                    f"Error while indexing document '{document.id}' "
                    f"(attempt {attempt + 1}/{INDEXING_MAX_RETRIES}): {str(e)}. "
                    f"Retrying in {delay:.2f} seconds."
                )
                time.sleep(delay)
                continue
            else:
                logger.exception(f"Failed to index document: '{document.id}'")
                raise


def batch_index_vespa_chunks(
    chunks: list[DocMetadataAwareIndexChunk],
    index_name: str,
    http_client: httpx.Client,
    multitenant: bool,
    executor: concurrent.futures.ThreadPoolExecutor | None = None,
) -> None:
    """Indexes a list of chunks in a Vespa index in parallel.

    Args:
        chunks: List of chunks to index.
        index_name: Name of the index to index into.
        http_client: HTTP client to use for the request.
        multitenant: Whether the index is multitenant.
        executor: Executor to use for the request.
    """
    external_executor = True

    if not executor:
        external_executor = False
        executor = concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS)

    try:
        chunk_index_future = {
            executor.submit(
                _index_vespa_chunk, chunk, index_name, http_client, multitenant
            ): chunk
            for chunk in chunks
        }
        for future in concurrent.futures.as_completed(chunk_index_future):
            # Will raise exception if any indexing raised an exception
            future.result()

    finally:
        if not external_executor:
            executor.shutdown(wait=True)


def clean_chunk_id_copy(
    chunk: DocMetadataAwareIndexChunk,
) -> DocMetadataAwareIndexChunk:
    clean_chunk = chunk.model_copy(
        update={
            "source_document": chunk.source_document.model_copy(
                update={
                    "id": replace_invalid_doc_id_characters(chunk.source_document.id)
                }
            )
        }
    )
    return clean_chunk


def check_for_final_chunk_existence(
    minimal_doc_info: MinimalDocumentIndexingInfo,
    start_index: int,
    index_name: str,
    http_client: httpx.Client,
) -> int:
    index = start_index
    while True:
        doc_chunk_id = get_uuid_from_chunk_info_old(
            document_id=minimal_doc_info.doc_id,
            chunk_id=index,
            large_chunk_reference_ids=[],
        )
        if not _does_doc_chunk_exist(doc_chunk_id, index_name, http_client):
            return index
        index += 1


class BaseHTTPXClientContext(ABC):
    """Abstract base class for an HTTPX client context manager."""

    @abstractmethod
    def __enter__(self) -> httpx.Client:
        pass

    @abstractmethod
    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore
        pass


class GlobalHTTPXClientContext(BaseHTTPXClientContext):
    """Context manager for a global HTTPX client that does not close it."""

    def __init__(self, client: httpx.Client):
        self._client = client

    def __enter__(self) -> httpx.Client:
        return self._client  # Reuse the global client

    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore
        pass  # Do nothing; don't close the global client


class TemporaryHTTPXClientContext(BaseHTTPXClientContext):
    """Context manager for a temporary HTTPX client that closes it after use."""

    def __init__(self, client_factory: Callable[[], httpx.Client]):
        self._client_factory = client_factory
        self._client: httpx.Client | None = None  # Client will be created in __enter__

    def __enter__(self) -> httpx.Client:
        self._client = self._client_factory()  # Create a new client
        return self._client

    def __exit__(self, exc_type, exc_value, traceback):  # type: ignore
        if self._client:
            self._client.close()


================================================
FILE: backend/onyx/document_index/vespa/kg_interactions.py
================================================
from onyx.db.document import get_document_kg_entities_and_relationships
from onyx.db.document import get_num_chunks_for_document
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.document_index.vespa.index import KGUChunkUpdateRequest
from onyx.document_index.vespa.index import VespaIndex
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()


def update_kg_chunks_vespa_info(
    kg_update_requests: list[KGUChunkUpdateRequest],
    index_name: str,
    tenant_id: str,
) -> None:
    """Applies the given knowledge-graph (KG) chunk update requests to the Vespa index."""
    # Use the existing visit API infrastructure
    vespa_index = VespaIndex(
        index_name=index_name,
        secondary_index_name=None,
        large_chunks_enabled=False,
        secondary_large_chunks_enabled=False,
        multitenant=MULTI_TENANT,
        httpx_client=None,
    )

    vespa_index.kg_chunk_updates(
        kg_update_requests=kg_update_requests,
        tenant_id=tenant_id
    )


def get_kg_vespa_info_update_requests_for_document(
    document_id: str,
) -> list[KGUChunkUpdateRequest]:
    """Get the kg_info update requests for a document."""
    # get all entities and relationships tied to the document
    with get_session_with_current_tenant() as db_session:
        entities, relationships = get_document_kg_entities_and_relationships(
            db_session, document_id
        )

    # create the kg vespa info
    kg_entities = {entity.id_name for entity in entities}
    kg_relationships = {relationship.id_name for relationship in relationships}

    # get chunks in the document
    with get_session_with_current_tenant() as db_session:
        num_chunks = get_num_chunks_for_document(db_session, document_id)

    # get vespa update requests
    return [
        KGUChunkUpdateRequest(
            document_id=document_id,
            chunk_id=chunk_id,
            core_entity="unused",
            entities=kg_entities,
            relationships=kg_relationships or None,
        )
        for chunk_id in range(num_chunks)
    ]


================================================
FILE: backend/onyx/document_index/vespa/shared_utils/utils.py
================================================
import time
from typing import cast

import httpx

from onyx.configs.app_configs import MANAGED_VESPA
from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.app_configs import VESPA_REQUEST_TIMEOUT
from onyx.document_index.vespa_constants import VESPA_APP_CONTAINER_URL
from onyx.utils.logger import setup_logger

logger = setup_logger()

# NOTE: This does not seem to be used in reality despite the Vespa Docs pointing to this code
# See here for reference: https://docs.vespa.ai/en/documents.html
# https://github.com/vespa-engine/vespa/blob/master/vespajlib/src/main/java/com/yahoo/text/Text.java

# Define allowed ASCII characters
ALLOWED_ASCII_CHARS: list[bool] = [False] * 0x80
ALLOWED_ASCII_CHARS[0x9] = True  # tab
ALLOWED_ASCII_CHARS[0xA] = True  # newline
ALLOWED_ASCII_CHARS[0xD] = True  # carriage return
for i in range(0x20, 0x7F):
    ALLOWED_ASCII_CHARS[i] = True  # printable ASCII chars
ALLOWED_ASCII_CHARS[0x7F] = True  # del - discouraged, but allowed


def is_text_character(codepoint: int) -> bool:
    """Returns whether the given codepoint is a valid text character."""
    if codepoint < 0x80:
        return ALLOWED_ASCII_CHARS[codepoint]
    if codepoint < 0xD800:
        return True
    if codepoint <= 0xDFFF:
        return False
    if codepoint < 0xFDD0:
        return True
    if codepoint <= 0xFDEF:
        return False
    if codepoint >= 0x10FFFE:
        return False
    return (codepoint & 0xFFFF) < 0xFFFE


def replace_invalid_doc_id_characters(text: str) -> str:
    """Replaces invalid document ID characters in text.
    NOTE: this must be called at the start of every vespa-related operation or else we
    risk discrepancies -> silent failures on deletion/update/insertion."""
    # There may be a more complete set of replacements that need to be made but Vespa docs are unclear
    # and users only seem to be running into this error with single quotes
    return text.replace("'", "_")


def get_vespa_http_client(
    no_timeout: bool = False, http2: bool = True, timeout: int | None = None
) -> httpx.Client:
    """
    Configures and returns an HTTP client for communicating with Vespa,
    including authentication if needed.
    """
    return httpx.Client(
        cert=(
            cast(tuple[str, str], (VESPA_CLOUD_CERT_PATH, VESPA_CLOUD_KEY_PATH))
            if MANAGED_VESPA
            else None
        ),
        verify=False if not MANAGED_VESPA else True,
        timeout=None if no_timeout else (timeout or VESPA_REQUEST_TIMEOUT),
        http2=http2,
    )


def wait_for_vespa_with_timeout(wait_interval: int = 5, wait_limit: int = 60) -> bool:
    """Waits for Vespa to become ready subject to a timeout.
    Returns True if Vespa is ready, False otherwise."""

    time_start = time.monotonic()
    logger.info("Vespa: Readiness probe starting.")
    while True:
        url = f"{VESPA_APP_CONTAINER_URL}/state/v1/health"
        try:
            # Close the client after each probe so repeated attempts don't leak connections
            with get_vespa_http_client() as client:
                response = client.get(url)
                response.raise_for_status()

                response_dict = response.json()
                if response_dict["status"]["code"] == "up":
                    logger.info("Vespa: Readiness probe succeeded. Continuing...")
                    return True
        except Exception as e:
            logger.warning(
                f"Vespa: Readiness probe failed trying to connect to {url}. Exception: {e}"
            )

        time_elapsed = time.monotonic() - time_start
        if time_elapsed > wait_limit:
            logger.info(
                f"Vespa: Readiness probe did not succeed within the timeout ({wait_limit} seconds)."
            )
            return False

        logger.info(
            f"Vespa: Readiness probe ongoing. elapsed={time_elapsed:.1f} timeout={wait_limit:.1f}"
        )

        time.sleep(wait_interval)


================================================
FILE: backend/onyx/document_index/vespa/shared_utils/vespa_request_builders.py
================================================
from datetime import datetime
from datetime import timedelta
from datetime import timezone

from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import PERSONAS
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_PROJECT
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT

logger = setup_logger()


def build_tenant_id_filter(tenant_id: str) -> str:
    return f'({TENANT_ID} contains "{tenant_id}")'


def build_vespa_filters(
    filters: IndexFilters,
    *,
    include_hidden: bool = False,
    remove_trailing_and: bool = False,  # Set to True when using as a complete Vespa query
) -> str:
    def _build_or_filters(key: str, vals: list[str] | None) -> str:
        """For string-based 'contains' filters, e.g. WSET fields or array fields.
        Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
        if not key or not vals:
            return ""
        eq_elems = [f'{key} contains "{val}"' for val in vals if val]
        if not eq_elems:
            return ""
        return f"({' or '.join(eq_elems)})"

    def _build_weighted_set_filter(key: str, vals: list[str] | None) -> str:
        """Build a Vespa weightedSet filter for large value lists.

        Uses Vespa's native weightedSet() operator instead of OR-chained
        'contains' clauses. This is critical for fields like
        access_control_list where a single user may have tens of thousands
        of ACL entries; OR clauses at that scale cause Vespa to reject
        the query with HTTP 400."""
        if not key or not vals:
            return ""
        filtered = [val for val in vals if val]
        if not filtered:
            return ""
        items = ", ".join(f'"{val}":1' for val in filtered)
        return f"weightedSet({key}, {{{items}}})"

    def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
        """For an integer field filter.
        Returns a bare clause or ""."""
        if vals is None or not vals:
            return ""
        eq_elems = [f"{key} = {val}" for val in vals]
        return f"({' or '.join(eq_elems)})"

    def _build_kg_filter(
        kg_entities: list[str] | None,
        kg_relationships: list[str] | None,
        kg_terms: list[str] | None,
    ) -> str:
        if not kg_entities and not kg_relationships and not kg_terms:
            return ""

        combined_filter_parts = []

        def _build_kge(entity: str) -> str:
            GENERAL = "::*"
            if entity.endswith(GENERAL):
                return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
            else:
                return f'"{entity}"'

        if kg_entities:
            filter_parts = []
            for kg_entity in kg_entities:
                filter_parts.append(f"(kg_entities contains {_build_kge(kg_entity)})")
            combined_filter_parts.append(f"({' or '.join(filter_parts)})")

        # TODO: handle complex nested relationship logic (e.g., A participated, and B or C participated)
        if kg_relationships:
            filter_parts = []
            for kg_relationship in kg_relationships:
                source, rel_type, target = split_relationship_id(kg_relationship)
                filter_parts.append(
                    "(kg_relationships contains sameElement("
                    f"source contains {_build_kge(source)},"
                    f'rel_type contains "{rel_type}",'
                    f"target contains {_build_kge(target)}))"
                )
            combined_filter_parts.append(f"{' and '.join(filter_parts)}")

        # TODO: remove kg terms entirely from prompts and codebase

        return f"({' and '.join(combined_filter_parts)})"

    def _build_kg_source_filters(
        kg_sources: list[str] | None,
    ) -> str:
        if not kg_sources:
            return ""

        source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
        return f"({' or '.join(source_phrases)})"

    def _build_kg_chunk_id_zero_only_filter(
        kg_chunk_id_zero_only: bool,
    ) -> str:
        if not kg_chunk_id_zero_only:
            return ""
        return "(chunk_id = 0)"

    def _build_time_filter(
        cutoff: datetime | None,
        untimed_doc_cutoff: timedelta = timedelta(days=92),
    ) -> str:
        if not cutoff:
            return ""
        include_untimed = datetime.now(timezone.utc) - untimed_doc_cutoff > cutoff
        cutoff_secs = int(cutoff.timestamp())

        if include_untimed:
            return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
        return f"({DOC_UPDATED_AT} >= {cutoff_secs})"

    def _build_user_project_filter(
        project_id: int | None,
    ) -> str:
        if project_id is None:
            return ""
        try:
            pid = int(project_id)
        except Exception:
            return ""
        return f'({USER_PROJECT} contains "{pid}")'

    def _build_persona_filter(
        persona_id: int | None,
    ) -> str:
        if persona_id is None:
            return ""
        try:
            pid = int(persona_id)
        except Exception:
            logger.warning(f"Invalid persona ID: {persona_id}")
            return ""
        return f'({PERSONAS} contains "{pid}")'

    def _append(parts: list[str], clause: str) -> None:
        if clause:
            parts.append(clause)

    # Collect all top-level filter clauses, then join with " and " at the end.
    filter_parts: list[str] = []

    if not include_hidden:
        filter_parts.append(f"!({HIDDEN}=true)")

    # TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
    if filters.tenant_id and MULTI_TENANT:
        filter_parts.append(build_tenant_id_filter(filters.tenant_id))

    # ACL filters: use weightedSet for efficient matching against the
    # access_control_list weightedset field. OR-chaining thousands
    # of 'contains' clauses causes Vespa to reject the query (HTTP 400)
    # for users with large numbers of external permission groups.
    if filters.access_control_list is not None:
        _append(
            filter_parts,
            _build_weighted_set_filter(
                ACCESS_CONTROL_LIST, filters.access_control_list
            ),
        )

    # Source type filters
    source_strs = (
        [s.value for s in filters.source_type] if filters.source_type else None
    )
    _append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))

    # Tag filters
    tag_attributes = None
    if filters.tags:
        tag_attributes = [
            f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
        ]
    _append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))

    # Knowledge scope: explicit knowledge attachments restrict what an
    # assistant can see. When none are set, the assistant can see
    # everything.
    #
    # persona_id_filter is a primary trigger: a persona with user files IS
    # explicit knowledge, so it can start a knowledge scope on its own.
    #
    # project_id_filter is additive: it widens the scope to also cover
    # overflowing project files but never restricts on its own (a chat
    # inside a project should still search team knowledge).
    knowledge_scope_parts: list[str] = []

    _append(
        knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
    )
    _append(knowledge_scope_parts, _build_persona_filter(filters.persona_id_filter))

    # project_id_filter only widens an existing scope.
    if knowledge_scope_parts:
        _append(
            knowledge_scope_parts,
            _build_user_project_filter(filters.project_id_filter),
        )

    if len(knowledge_scope_parts) > 1:
        filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
    elif len(knowledge_scope_parts) == 1:
        filter_parts.append(knowledge_scope_parts[0])

    # Time filter
    _append(filter_parts, _build_time_filter(filters.time_cutoff))

    # # Knowledge Graph Filters
    # _append(filter_parts, _build_kg_filter(
    #     kg_entities=filters.kg_entities,
    #     kg_relationships=filters.kg_relationships,
    #     kg_terms=filters.kg_terms,
    # ))

    # _append(filter_parts, _build_kg_source_filters(filters.kg_sources))

    # _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
    #     filters.kg_chunk_id_zero_only or False
    # ))

    filter_str = " and ".join(filter_parts)

    if filter_str and not remove_trailing_and:
        filter_str += " and "

    return filter_str


def build_vespa_id_based_retrieval_yql(
    chunk_request: VespaChunkRequest,
) -> str:
    id_based_retrieval_yql_section = (
        f'({DOCUMENT_ID} contains "{chunk_request.document_id}"'
    )

    if chunk_request.is_capped:
        id_based_retrieval_yql_section += (
            f" and {CHUNK_ID} >= {chunk_request.min_chunk_ind or 0}"
        )
        id_based_retrieval_yql_section += (
            f" and {CHUNK_ID} <= {chunk_request.max_chunk_ind}"
        )

    id_based_retrieval_yql_section += ")"
    return id_based_retrieval_yql_section


================================================
FILE: backend/onyx/document_index/vespa/vespa_document_index.py
================================================
import concurrent.futures
import logging
import random
from collections.abc import Generator
from collections.abc import Iterable
from typing import Any
from uuid import UUID

import httpx
from pydantic import BaseModel
from retry import retry

from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import RECENCY_BIAS_MULTIPLIER
from onyx.configs.app_configs import RERANK_COUNT
from onyx.configs.chat_configs import DOC_TIME_DECAY
from onyx.configs.chat_configs import HYBRID_ALPHA
from onyx.configs.chat_configs import TITLE_CONTENT_RATIO
from onyx.context.search.enums import QueryType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.db.enums import EmbeddingPrecision
from onyx.document_index.chunk_content_enrichment import cleanup_content_for_chunks
from onyx.document_index.document_index_utils import get_document_chunk_ids
from onyx.document_index.document_index_utils import get_uuid_from_chunk_info
from onyx.document_index.interfaces import EnrichedDocumentIndexingInfo
from onyx.document_index.interfaces import MinimalDocumentIndexingInfo
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.interfaces_new import DocumentIndex
from onyx.document_index.interfaces_new import DocumentInsertionRecord
from onyx.document_index.interfaces_new import DocumentSectionRequest
from onyx.document_index.interfaces_new import IndexingMetadata
from onyx.document_index.interfaces_new import MetadataUpdateRequest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.vespa.chunk_retrieval import batch_search_api_retrieval
from onyx.document_index.vespa.chunk_retrieval import get_all_chunks_paginated
from onyx.document_index.vespa.chunk_retrieval import get_chunks_via_visit_api
from onyx.document_index.vespa.chunk_retrieval import (
    parallel_visit_api_retrieval,
)
from onyx.document_index.vespa.chunk_retrieval import query_vespa
from onyx.document_index.vespa.deletion import delete_vespa_chunks
from onyx.document_index.vespa.indexing_utils import BaseHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import batch_index_vespa_chunks
from onyx.document_index.vespa.indexing_utils import check_for_final_chunk_existence
from onyx.document_index.vespa.indexing_utils import clean_chunk_id_copy
from onyx.document_index.vespa.indexing_utils import GlobalHTTPXClientContext
from onyx.document_index.vespa.indexing_utils import TemporaryHTTPXClientContext
from onyx.document_index.vespa.shared_utils.utils import get_vespa_http_client
from onyx.document_index.vespa.shared_utils.utils import (
    replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
    build_vespa_filters,
)
from onyx.document_index.vespa_constants import BATCH_SIZE
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.tools.tool_implementations.search.constants import KEYWORD_QUERY_HYBRID_ALPHA
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.model_server_models import Embedding

logger = setup_logger(__name__)
# Set the logging level to WARNING to ignore INFO and DEBUG logs from httpx. By
# default it emits INFO-level logs for every request.
httpx_logger = logging.getLogger("httpx")
httpx_logger.setLevel(logging.WARNING)


def _enrich_basic_chunk_info(
    index_name: str,
    http_client: httpx.Client,
    document_id: str,
    previous_chunk_count: int | None,
    new_chunk_count: int,
) -> EnrichedDocumentIndexingInfo:
    """Determines which chunks need to be deleted during document reindexing.

    When a document is reindexed, it may have fewer chunks than before. This
    function identifies the range of old chunks that need to be deleted by
    comparing the new chunk count with the previous chunk count.

    Example:
        If a document previously had 10 chunks (0-9) and now has 7 chunks (0-6),
        this function identifies that chunks 7-9 need to be deleted.

    Args:
        index_name: The Vespa index/schema name.
        http_client: HTTP client for making requests to Vespa.
        document_id: The Vespa-sanitized ID of the document being reindexed.
        previous_chunk_count: The total number of chunks the document had before
            reindexing. None for documents using the legacy chunk ID system.
        new_chunk_count: The total number of chunks the document has after
            reindexing. This becomes the starting index for deletion since chunks
            are 0-indexed.

    Returns:
        EnrichedDocumentIndexingInfo with chunk_start_index set to new_chunk_count
        (where deletion begins) and chunk_end_index set to previous_chunk_count
        (where deletion ends).
    """
    # Technically last indexed chunk index +1.
    last_indexed_chunk = previous_chunk_count
    # If the document has no `chunk_count` in the database, we know that it
    # has the old chunk ID system and we must check for the final chunk index.
is_old_version = False if last_indexed_chunk is None: is_old_version = True minimal_doc_info = MinimalDocumentIndexingInfo( doc_id=document_id, chunk_start_index=new_chunk_count ) last_indexed_chunk = check_for_final_chunk_existence( minimal_doc_info=minimal_doc_info, start_index=new_chunk_count, index_name=index_name, http_client=http_client, ) assert ( last_indexed_chunk is not None and last_indexed_chunk >= 0 ), f"Bug: Last indexed chunk index is None or less than 0 for document: {document_id}." enriched_doc_info = EnrichedDocumentIndexingInfo( doc_id=document_id, chunk_start_index=new_chunk_count, chunk_end_index=last_indexed_chunk, old_version=is_old_version, ) return enriched_doc_info @retry( tries=3, delay=1, backoff=2, exceptions=httpx.HTTPError, ) def _update_single_chunk( doc_chunk_id: UUID, index_name: str, doc_id: str, http_client: httpx.Client, update_request: MetadataUpdateRequest, ) -> None: """Updates a single document chunk in Vespa. TODO(andrei): Couldn't this be batched? Args: doc_chunk_id: The ID of the chunk to update. index_name: The index the chunk belongs to. doc_id: The ID of the document the chunk belongs to. Used only for logging. http_client: The HTTP client to use to make the request. update_request: Metadata update request object received in the bulk update method containing fields to update. """ class _Boost(BaseModel): model_config = {"frozen": True} assign: float class _DocumentSets(BaseModel): model_config = {"frozen": True} assign: dict[str, int] class _AccessControl(BaseModel): model_config = {"frozen": True} assign: dict[str, int] class _Hidden(BaseModel): model_config = {"frozen": True} assign: bool class _UserProjects(BaseModel): model_config = {"frozen": True} assign: list[int] class _Personas(BaseModel): model_config = {"frozen": True} assign: list[int] class _VespaPutFields(BaseModel): model_config = {"frozen": True} # The names of these fields are based on the Vespa schema. Changes to the # schema require changes here.
These names were originally found in # backend/onyx/document_index/vespa_constants.py. boost: _Boost | None = None document_sets: _DocumentSets | None = None access_control_list: _AccessControl | None = None hidden: _Hidden | None = None user_project: _UserProjects | None = None personas: _Personas | None = None class _VespaPutRequest(BaseModel): model_config = {"frozen": True} fields: _VespaPutFields boost_update: _Boost | None = ( _Boost(assign=update_request.boost) if update_request.boost is not None else None ) document_sets_update: _DocumentSets | None = ( _DocumentSets( assign={document_set: 1 for document_set in update_request.document_sets} ) if update_request.document_sets is not None else None ) access_update: _AccessControl | None = ( _AccessControl( assign={acl_entry: 1 for acl_entry in update_request.access.to_acl()} ) if update_request.access is not None else None ) hidden_update: _Hidden | None = ( _Hidden(assign=update_request.hidden) if update_request.hidden is not None else None ) user_projects_update: _UserProjects | None = ( _UserProjects(assign=list(update_request.project_ids)) if update_request.project_ids is not None else None ) personas_update: _Personas | None = ( _Personas(assign=list(update_request.persona_ids)) if update_request.persona_ids is not None else None ) vespa_put_fields = _VespaPutFields( boost=boost_update, document_sets=document_sets_update, access_control_list=access_update, hidden=hidden_update, user_project=user_projects_update, personas=personas_update, ) vespa_put_request = _VespaPutRequest( fields=vespa_put_fields, ) vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=index_name)}/{doc_chunk_id}?create=true" try: resp = http_client.put( vespa_url, headers={"Content-Type": "application/json"}, json=vespa_put_request.model_dump( exclude_none=True ), # NOTE: Important to not produce null fields in the json. 
) resp.raise_for_status() except httpx.HTTPStatusError as e: logger.error( f"Failed to update doc chunk {doc_chunk_id} (doc_id={doc_id}). " f"Code: {e.response.status_code}. Details: {e.response.text}" ) # Re-raise so the @retry decorator will catch and retry, unless the # status code is < 5xx, in which case wrap the exception in something # other than an HTTPError to skip retries. if e.response.status_code >= 500: raise raise RuntimeError( f"Non-retryable error updating chunk {doc_chunk_id}: {e}" ) from e class VespaDocumentIndex(DocumentIndex): """Vespa-specific implementation of the DocumentIndex interface. This class provides document indexing, retrieval, and management operations for a Vespa search engine instance. It handles the complete lifecycle of document chunks within a specific Vespa index/schema. """ def __init__( self, index_name: str, tenant_state: TenantState, large_chunks_enabled: bool, httpx_client: httpx.Client | None = None, ) -> None: self._index_name = index_name self._tenant_id = tenant_state.tenant_id self._large_chunks_enabled = large_chunks_enabled # NOTE: using `httpx` here since `requests` doesn't support HTTP2. This # is beneficial for indexing / updates / deletes since we have to make a # large volume of requests. self._httpx_client_context: BaseHTTPXClientContext if httpx_client: # Use the provided client. Because this client is presumed global, # it does not close after exiting a context manager. self._httpx_client_context = GlobalHTTPXClientContext(httpx_client) else: # We did not receive a client, so create one that will close after # exiting a context manager.
self._httpx_client_context = TemporaryHTTPXClientContext( get_vespa_http_client ) self._multitenant = tenant_state.multitenant def verify_and_create_index_if_necessary( self, embedding_dim: int, embedding_precision: EmbeddingPrecision ) -> None: raise NotImplementedError def index( self, chunks: Iterable[DocMetadataAwareIndexChunk], indexing_metadata: IndexingMetadata, ) -> list[DocumentInsertionRecord]: doc_id_to_chunk_cnt_diff = indexing_metadata.doc_id_to_chunk_cnt_diff doc_id_to_previous_chunk_cnt = { doc_id: chunk_cnt_diff.old_chunk_cnt for doc_id, chunk_cnt_diff in doc_id_to_chunk_cnt_diff.items() } doc_id_to_new_chunk_cnt = { doc_id: chunk_cnt_diff.new_chunk_cnt for doc_id, chunk_cnt_diff in doc_id_to_chunk_cnt_diff.items() } assert ( len(doc_id_to_chunk_cnt_diff) == len(doc_id_to_previous_chunk_cnt) == len(doc_id_to_new_chunk_cnt) ), "Bug: Doc ID to chunk maps have different lengths." # Vespa has restrictions on valid characters, yet document IDs come from # outside this class. We need to sanitize them. # # Instead of materializing all cleaned chunks upfront, we stream them # through a generator that cleans IDs and builds the original-ID mapping # incrementally as chunks flow into Vespa. def _clean_and_track( chunks_iter: Iterable[DocMetadataAwareIndexChunk], id_map: dict[str, str], seen_ids: set[str], ) -> Generator[DocMetadataAwareIndexChunk, None, None]: """Cleans chunk IDs and builds the original-ID mapping incrementally as chunks flow through, avoiding a separate materialization pass.""" for chunk in chunks_iter: original_id = chunk.source_document.id cleaned = clean_chunk_id_copy(chunk) cleaned_id = cleaned.source_document.id # Needed so the final DocumentInsertionRecord returned can have # the original document ID. cleaned_chunks might not contain IDs # exactly as callers supplied them.
id_map[cleaned_id] = original_id seen_ids.add(cleaned_id) yield cleaned new_document_id_to_original_document_id: dict[str, str] = {} all_cleaned_doc_ids: set[str] = set() existing_docs: set[str] = set() with ( concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor, self._httpx_client_context as http_client, ): # We require the start and end index for each document in order to # know precisely which chunks to delete. This information exists for # documents that have `chunk_count` in the database, but not for # `old_version` documents. enriched_doc_infos: list[EnrichedDocumentIndexingInfo] = [ _enrich_basic_chunk_info( index_name=self._index_name, http_client=http_client, document_id=doc_id, previous_chunk_count=doc_id_to_previous_chunk_cnt[doc_id], new_chunk_count=doc_id_to_new_chunk_cnt[doc_id], ) for doc_id in doc_id_to_chunk_cnt_diff.keys() # TODO(andrei), WARNING: Don't we need to sanitize these doc IDs? ] for enriched_doc_info in enriched_doc_infos: # If the document has previously indexed chunks, we know it # previously existed and this is a reindex. if enriched_doc_info.chunk_end_index: existing_docs.add(enriched_doc_info.doc_id) # Now, for each doc, we know exactly where to start and end our # deletion. So let's generate the chunk IDs for each chunk to # delete. # WARNING: This code seems to use # indexing_metadata.doc_id_to_chunk_cnt_diff as the source of truth # for which chunks to delete. This implies that the onus is on the # caller to ensure doc_id_to_chunk_cnt_diff only contains docs # relevant to the chunks argument to this method. This should not be # the contract of DocumentIndex; and this code is only a refactor # from old code. It would seem we should use all_cleaned_doc_ids as # the source of truth. chunks_to_delete = get_document_chunk_ids( enriched_document_info_list=enriched_doc_infos, tenant_id=self._tenant_id, large_chunks_enabled=self._large_chunks_enabled, ) # Delete old Vespa documents. 
for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE): delete_vespa_chunks( doc_chunk_ids=doc_chunk_ids_batch, index_name=self._index_name, http_client=http_client, executor=executor, ) # Insert new Vespa documents, streaming through the cleaning # pipeline so chunks are never fully materialized. cleaned_chunks = _clean_and_track( chunks, new_document_id_to_original_document_id, all_cleaned_doc_ids, ) for chunk_batch in batch_generator( cleaned_chunks, min(BATCH_SIZE, MAX_CHUNKS_PER_DOC_BATCH) ): batch_index_vespa_chunks( chunks=chunk_batch, index_name=self._index_name, http_client=http_client, multitenant=self._multitenant, executor=executor, ) return [ DocumentInsertionRecord( document_id=new_document_id_to_original_document_id[cleaned_doc_id], already_existed=cleaned_doc_id in existing_docs, ) for cleaned_doc_id in all_cleaned_doc_ids ] def delete(self, document_id: str, chunk_count: int | None = None) -> int: total_chunks_deleted = 0 sanitized_doc_id = replace_invalid_doc_id_characters(document_id) with ( concurrent.futures.ThreadPoolExecutor(max_workers=NUM_THREADS) as executor, self._httpx_client_context as http_client, ): enriched_doc_info = _enrich_basic_chunk_info( index_name=self._index_name, http_client=http_client, document_id=sanitized_doc_id, previous_chunk_count=chunk_count, new_chunk_count=0, ) chunks_to_delete = get_document_chunk_ids( enriched_document_info_list=[enriched_doc_info], tenant_id=self._tenant_id, large_chunks_enabled=self._large_chunks_enabled, ) for doc_chunk_ids_batch in batch_generator(chunks_to_delete, BATCH_SIZE): total_chunks_deleted += len(doc_chunk_ids_batch) delete_vespa_chunks( doc_chunk_ids=doc_chunk_ids_batch, index_name=self._index_name, http_client=http_client, executor=executor, ) return total_chunks_deleted def update( self, update_requests: list[MetadataUpdateRequest], ) -> None: # WARNING: This method can be called by vespa_metadata_sync_task, which # is kicked off by check_for_vespa_sync_task, 
notably before a document # has finished indexing. In this way, chunk_count below could be unknown # even for chunks not on the "old" chunk ID system; i.e. there could be # a race condition. Passing in None to _enrich_basic_chunk_info should # handle this, but a higher level TODO might be to not run update at all # on connectors that are still indexing, and therefore do not yet have a # chunk count because update_docs_chunk_count__no_commit has not been # run yet. with self._httpx_client_context as httpx_client: # Each invocation of this method can contain multiple update requests. for update_request in update_requests: # Each update request can correspond to multiple documents. for doc_id in update_request.document_ids: # NOTE: -1 represents an unknown chunk count. chunk_count = update_request.doc_id_to_chunk_cnt[doc_id] sanitized_doc_id = replace_invalid_doc_id_characters(doc_id) enriched_doc_info = _enrich_basic_chunk_info( index_name=self._index_name, http_client=httpx_client, document_id=sanitized_doc_id, previous_chunk_count=chunk_count if chunk_count >= 0 else None, new_chunk_count=0, # WARNING: This semantically makes no sense and is misusing this function. ) doc_chunk_ids = get_document_chunk_ids( enriched_document_info_list=[enriched_doc_info], tenant_id=self._tenant_id, large_chunks_enabled=self._large_chunks_enabled, ) for doc_chunk_id in doc_chunk_ids: _update_single_chunk( doc_chunk_id, self._index_name, # NOTE: Used only for logging, raw ID is ok here. doc_id, httpx_client, update_request, ) logger.info( f"Updated {len(doc_chunk_ids)} chunks for document {doc_id}." 
) def id_based_retrieval( self, chunk_requests: list[DocumentSectionRequest], filters: IndexFilters, batch_retrieval: bool = False, ) -> list[InferenceChunk]: sanitized_chunk_requests = [ VespaChunkRequest( document_id=replace_invalid_doc_id_characters( chunk_request.document_id ), min_chunk_ind=chunk_request.min_chunk_ind, max_chunk_ind=chunk_request.max_chunk_ind, ) for chunk_request in chunk_requests ] if batch_retrieval: return cleanup_content_for_chunks( batch_search_api_retrieval( index_name=self._index_name, chunk_requests=sanitized_chunk_requests, filters=filters, # No one was passing in this parameter in the legacy # interface, it always defaulted to False. get_large_chunks=False, ) ) return cleanup_content_for_chunks( parallel_visit_api_retrieval( index_name=self._index_name, chunk_requests=sanitized_chunk_requests, filters=filters, # No one was passing in this parameter in the legacy interface, # it always defaulted to False. get_large_chunks=False, ) ) def hybrid_retrieval( self, query: str, query_embedding: Embedding, final_keywords: list[str] | None, query_type: QueryType, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: vespa_where_clauses = build_vespa_filters(filters) # Avoid over-fetching a very large candidate set for global-phase reranking. # Keep enough headroom for quality while capping cost on larger indices. 
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT) yql = ( YQL_BASE.format(index_name=self._index_name) + vespa_where_clauses + f"(({{targetHits: {target_hits}}}nearestNeighbor(embeddings, query_embedding)) " + f"or ({{targetHits: {target_hits}}}nearestNeighbor(title_embedding, query_embedding)) " + 'or ({grammar: "weakAnd"}userInput(@query)) ' + f'or ({{defaultIndex: "{CONTENT_SUMMARY}"}}userInput(@query)))' ) final_query = " ".join(final_keywords) if final_keywords else query ranking_profile = ( f"hybrid_search_{query_type.value}_base_{len(query_embedding)}" ) logger.info(f"Selected ranking profile: {ranking_profile}") logger.debug(f"Query YQL: {yql}") # In this interface we do not pass in hybrid alpha. Tracing the codepath # of the legacy Vespa interface, it so happens that KEYWORD always # corresponds to an alpha of 0.2 (from KEYWORD_QUERY_HYBRID_ALPHA), and # SEMANTIC to 0.5 (from HYBRID_ALPHA). HYBRID_ALPHA_KEYWORD was only # used in dead code so we do not use it here. 
hybrid_alpha = ( KEYWORD_QUERY_HYBRID_ALPHA if query_type == QueryType.KEYWORD else HYBRID_ALPHA ) params: dict[str, str | int | float] = { "yql": yql, "query": final_query, "input.query(query_embedding)": str(query_embedding), "input.query(decay_factor)": str(DOC_TIME_DECAY * RECENCY_BIAS_MULTIPLIER), "input.query(alpha)": hybrid_alpha, "input.query(title_content_ratio)": TITLE_CONTENT_RATIO, "hits": num_to_retrieve, "ranking.profile": ranking_profile, "timeout": VESPA_TIMEOUT, } return cleanup_content_for_chunks(query_vespa(params)) def keyword_retrieval( self, query: str, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: raise NotImplementedError def semantic_retrieval( self, query_embedding: Embedding, filters: IndexFilters, num_to_retrieve: int, ) -> list[InferenceChunk]: raise NotImplementedError def random_retrieval( self, filters: IndexFilters, num_to_retrieve: int = 100, dirty: bool | None = None, # noqa: ARG002 ) -> list[InferenceChunk]: vespa_where_clauses = build_vespa_filters(filters, remove_trailing_and=True) yql = YQL_BASE.format(index_name=self._index_name) + vespa_where_clauses random_seed = random.randint(0, 1_000_000) params: dict[str, str | int | float] = { "yql": yql, "hits": num_to_retrieve, "timeout": VESPA_TIMEOUT, "ranking.profile": "random_", "ranking.properties.random.seed": random_seed, } return cleanup_content_for_chunks(query_vespa(params)) def get_raw_document_chunks(self, document_id: str) -> list[dict[str, Any]]: """Gets all raw document chunks for a document as returned by Vespa. Used in the Vespa migration task. Args: document_id: The ID of the document to get chunks for. Returns: List of raw document chunks. """ # Vespa doc IDs are sanitized using replace_invalid_doc_id_characters. 
sanitized_document_id = replace_invalid_doc_id_characters(document_id) chunk_request = VespaChunkRequest(document_id=sanitized_document_id) raw_chunks = get_chunks_via_visit_api( chunk_request=chunk_request, index_name=self._index_name, filters=IndexFilters(access_control_list=None, tenant_id=self._tenant_id), get_large_chunks=False, short_tensor_format=True, ) # Vespa returns other metadata around the actual document chunk. The raw # chunk we're interested in is in the "fields" field. raw_document_chunks = [chunk["fields"] for chunk in raw_chunks] return raw_document_chunks def get_all_raw_document_chunks_paginated( self, continuation_token_map: dict[int, str | None], page_size: int, ) -> tuple[list[dict[str, Any]], dict[int, str | None]]: """Gets all the chunks in Vespa, paginated. Used in the chunk-level Vespa-to-OpenSearch migration task. Args: continuation_token_map: Map from an integer key to the continuation token returned by Vespa for that key, representing a page offset. A None token starts from the beginning. page_size: Best-effort batch size for the visit. Returns: Tuple of (list of chunk dicts, map of next continuation tokens). A token is None when the visit for that key is complete. """ raw_chunks, next_continuation_token_map = get_all_chunks_paginated( index_name=self._index_name, tenant_state=TenantState( tenant_id=self._tenant_id, multitenant=MULTI_TENANT ), continuation_token_map=continuation_token_map, page_size=page_size, ) return raw_chunks, next_continuation_token_map def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None: """Indexes raw document chunks into Vespa. To only be used in tests. Not for production.
""" json_header = { "Content-Type": "application/json", } with self._httpx_client_context as http_client: for chunk in chunks: chunk_id = str( get_uuid_from_chunk_info( document_id=chunk[DOCUMENT_ID], chunk_id=chunk[CHUNK_ID], tenant_id=self._tenant_id, ) ) vespa_url = f"{DOCUMENT_ID_ENDPOINT.format(index_name=self._index_name)}/{chunk_id}" response = http_client.post( vespa_url, headers=json_header, json={"fields": chunk}, ) response.raise_for_status() def get_chunk_count(self) -> int: """Returns the exact number of document chunks in Vespa for this tenant. Uses the Vespa Search API with `limit 0` and `ranking.profile=unranked` to get an exact count without fetching any document data. Includes large chunks. There is no way to filter these out using the Search API. """ where_clause = ( f'tenant_id contains "{self._tenant_id}"' if self._multitenant else "true" ) yql = f"select documentid from {self._index_name} where {where_clause} limit 0" params: dict[str, str | int] = { "yql": yql, "ranking.profile": "unranked", "timeout": VESPA_TIMEOUT, } with get_vespa_http_client() as http_client: response = http_client.post(SEARCH_ENDPOINT, json=params) response.raise_for_status() response_data = response.json() return response_data["root"]["fields"]["totalCount"] ================================================ FILE: backend/onyx/document_index/vespa_constants.py ================================================ from onyx.configs.app_configs import VESPA_CLOUD_URL from onyx.configs.app_configs import VESPA_CONFIG_SERVER_HOST from onyx.configs.app_configs import VESPA_HOST from onyx.configs.app_configs import VESPA_PORT from onyx.configs.app_configs import VESPA_TENANT_PORT from onyx.configs.constants import SOURCE_TYPE # config server VESPA_CONFIG_SERVER_URL = ( VESPA_CLOUD_URL or f"http://{VESPA_CONFIG_SERVER_HOST}:{VESPA_TENANT_PORT}" ) VESPA_APPLICATION_ENDPOINT = f"{VESPA_CONFIG_SERVER_URL}/application/v2" # main search application VESPA_APP_CONTAINER_URL = 
VESPA_CLOUD_URL or f"http://{VESPA_HOST}:{VESPA_PORT}" # danswer_chunk below is defined in vespa/app_configs/schemas/danswer_chunk.sd.jinja DOCUMENT_ID_ENDPOINT = ( f"{VESPA_APP_CONTAINER_URL}/document/v1/default/{{index_name}}/docid" ) # the default document id endpoint is http://localhost:8080/document/v1/default/danswer_chunk/docid SEARCH_ENDPOINT = f"{VESPA_APP_CONTAINER_URL}/search/" # Since Vespa doesn't allow batching of inserts / updates, we use threads to # parallelize the operations. NUM_THREADS = 32 MAX_ID_SEARCH_QUERY_SIZE = 400 # We suspect that adding too many "or" conditions will cause Vespa to time out and return # an empty list of hits (with no error status and coverage: 0 and degraded) MAX_OR_CONDITIONS = 10 # Up from 500ms for now, since we've seen quite a few timeouts. # In the long term, we are looking to improve the performance of Vespa # so that we can bring this back to the default. VESPA_TIMEOUT = "10s" # The size of the batch to use for batched operations like inserts / updates. # The batch will likely be sent to a threadpool of size NUM_THREADS.
BATCH_SIZE = 128 TENANT_ID = "tenant_id" DOCUMENT_ID = "document_id" CHUNK_ID = "chunk_id" BLURB = "blurb" CONTENT = "content" SOURCE_LINKS = "source_links" SEMANTIC_IDENTIFIER = "semantic_identifier" TITLE = "title" SKIP_TITLE_EMBEDDING = "skip_title" SECTION_CONTINUATION = "section_continuation" EMBEDDINGS = "embeddings" TITLE_EMBEDDING = "title_embedding" ACCESS_CONTROL_LIST = "access_control_list" DOCUMENT_SETS = "document_sets" USER_FILE = "user_file" USER_FOLDER = "user_folder" USER_PROJECT = "user_project" PERSONAS = "personas" LARGE_CHUNK_REFERENCE_IDS = "large_chunk_reference_ids" METADATA = "metadata" METADATA_LIST = "metadata_list" METADATA_SUFFIX = "metadata_suffix" DOC_SUMMARY = "doc_summary" CHUNK_CONTEXT = "chunk_context" BOOST = "boost" AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor" DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch PRIMARY_OWNERS = "primary_owners" SECONDARY_OWNERS = "secondary_owners" RECENCY_BIAS = "recency_bias" HIDDEN = "hidden" # for legacy reasons, called `name` in Vespa despite it really being an ID IMAGE_FILE_NAME = "image_file_name" # Specific to Vespa, needed for highlighting matching keywords / section CONTENT_SUMMARY = "content_summary" FULL_CHUNK_EMBEDDING_KEY = "full_chunk" YQL_BASE = ( f"select " f"documentid, " f"{DOCUMENT_ID}, " f"{CHUNK_ID}, " f"{BLURB}, " f"{CONTENT}, " f"{SOURCE_TYPE}, " f"{SOURCE_LINKS}, " f"{SEMANTIC_IDENTIFIER}, " f"{TITLE}, " f"{SECTION_CONTINUATION}, " f"{IMAGE_FILE_NAME}, " f"{BOOST}, " f"{AGGREGATED_CHUNK_BOOST_FACTOR}, " f"{HIDDEN}, " f"{DOC_UPDATED_AT}, " f"{PRIMARY_OWNERS}, " f"{SECONDARY_OWNERS}, " f"{LARGE_CHUNK_REFERENCE_IDS}, " f"{METADATA}, " f"{METADATA_SUFFIX}, " f"{DOC_SUMMARY}, " f"{CHUNK_CONTEXT}, " f"{CONTENT_SUMMARY} " f"from {{index_name}} where " ) ================================================ FILE: backend/onyx/error_handling/__init__.py ================================================ ================================================ 
FILE: backend/onyx/error_handling/error_codes.py ================================================ """ Standardized error codes for the Onyx backend. Usage: from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError raise OnyxError(OnyxErrorCode.UNAUTHENTICATED, "Token expired") """ from enum import Enum class OnyxErrorCode(Enum): """ Each member is a tuple of (error_code_string, http_status_code). The error_code_string is a stable, machine-readable identifier that API consumers can match on. The http_status_code is the default HTTP status to return. """ # ------------------------------------------------------------------ # Authentication (401) # ------------------------------------------------------------------ UNAUTHENTICATED = ("UNAUTHENTICATED", 401) INVALID_TOKEN = ("INVALID_TOKEN", 401) TOKEN_EXPIRED = ("TOKEN_EXPIRED", 401) CSRF_FAILURE = ("CSRF_FAILURE", 403) # ------------------------------------------------------------------ # Authorization (403) # ------------------------------------------------------------------ UNAUTHORIZED = ("UNAUTHORIZED", 403) INSUFFICIENT_PERMISSIONS = ("INSUFFICIENT_PERMISSIONS", 403) ADMIN_ONLY = ("ADMIN_ONLY", 403) EE_REQUIRED = ("EE_REQUIRED", 403) SINGLE_TENANT_ONLY = ("SINGLE_TENANT_ONLY", 403) ENV_VAR_GATED = ("ENV_VAR_GATED", 403) # ------------------------------------------------------------------ # Validation / Bad Request (400) # ------------------------------------------------------------------ VALIDATION_ERROR = ("VALIDATION_ERROR", 400) INVALID_INPUT = ("INVALID_INPUT", 400) MISSING_REQUIRED_FIELD = ("MISSING_REQUIRED_FIELD", 400) QUERY_REJECTED = ("QUERY_REJECTED", 400) # ------------------------------------------------------------------ # Not Found (404) # ------------------------------------------------------------------ NOT_FOUND = ("NOT_FOUND", 404) CONNECTOR_NOT_FOUND = ("CONNECTOR_NOT_FOUND", 404) CREDENTIAL_NOT_FOUND = ("CREDENTIAL_NOT_FOUND", 404) 
PERSONA_NOT_FOUND = ("PERSONA_NOT_FOUND", 404) DOCUMENT_NOT_FOUND = ("DOCUMENT_NOT_FOUND", 404) SESSION_NOT_FOUND = ("SESSION_NOT_FOUND", 404) USER_NOT_FOUND = ("USER_NOT_FOUND", 404) # ------------------------------------------------------------------ # Conflict (409) # ------------------------------------------------------------------ CONFLICT = ("CONFLICT", 409) DUPLICATE_RESOURCE = ("DUPLICATE_RESOURCE", 409) # ------------------------------------------------------------------ # Rate Limiting / Quotas (429 / 402) # ------------------------------------------------------------------ RATE_LIMITED = ("RATE_LIMITED", 429) SEAT_LIMIT_EXCEEDED = ("SEAT_LIMIT_EXCEEDED", 402) # ------------------------------------------------------------------ # Payload (413) # ------------------------------------------------------------------ PAYLOAD_TOO_LARGE = ("PAYLOAD_TOO_LARGE", 413) # ------------------------------------------------------------------ # Connector / Credential Errors (400-range) # ------------------------------------------------------------------ CONNECTOR_VALIDATION_FAILED = ("CONNECTOR_VALIDATION_FAILED", 400) CREDENTIAL_INVALID = ("CREDENTIAL_INVALID", 400) CREDENTIAL_EXPIRED = ("CREDENTIAL_EXPIRED", 401) # ------------------------------------------------------------------ # Server Errors (5xx) # ------------------------------------------------------------------ INTERNAL_ERROR = ("INTERNAL_ERROR", 500) NOT_IMPLEMENTED = ("NOT_IMPLEMENTED", 501) SERVICE_UNAVAILABLE = ("SERVICE_UNAVAILABLE", 503) BAD_GATEWAY = ("BAD_GATEWAY", 502) LLM_PROVIDER_ERROR = ("LLM_PROVIDER_ERROR", 502) HOOK_EXECUTION_FAILED = ("HOOK_EXECUTION_FAILED", 502) GATEWAY_TIMEOUT = ("GATEWAY_TIMEOUT", 504) def __init__(self, code: str, status_code: int) -> None: self.code = code self.status_code = status_code def detail(self, message: str | None = None) -> dict[str, str]: """Build a structured error detail dict. 
Returns a dict like: {"error_code": "UNAUTHENTICATED", "detail": "Token expired"} If no message is supplied, the error code itself is used as the detail. """ return { "error_code": self.code, "detail": message or self.code, } ================================================ FILE: backend/onyx/error_handling/exceptions.py ================================================ """OnyxError — the single exception type for all Onyx business errors. Raise ``OnyxError`` instead of ``HTTPException`` in business code. A global FastAPI exception handler (registered via ``register_onyx_exception_handlers``) converts it into a JSON response with the standard ``{"error_code": "...", "detail": "..."}`` shape. Usage:: from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError raise OnyxError(OnyxErrorCode.NOT_FOUND, "Session not found") For upstream errors with a dynamic HTTP status (e.g. billing service), use ``status_code_override``:: raise OnyxError( OnyxErrorCode.BAD_GATEWAY, detail, status_code_override=upstream_status, ) """ from fastapi import FastAPI from fastapi import Request from fastapi.responses import JSONResponse from onyx.error_handling.error_codes import OnyxErrorCode from onyx.utils.logger import setup_logger logger = setup_logger() class OnyxError(Exception): """Structured error that maps to a specific ``OnyxErrorCode``. Attributes: error_code: The ``OnyxErrorCode`` enum member. detail: Human-readable detail (defaults to the error code string). status_code: HTTP status — either overridden or from the error code. 
""" def __init__( self, error_code: OnyxErrorCode, detail: str | None = None, *, status_code_override: int | None = None, ) -> None: resolved_detail = detail or error_code.code super().__init__(resolved_detail) self.error_code = error_code self.detail = resolved_detail self._status_code_override = status_code_override @property def status_code(self) -> int: return self._status_code_override or self.error_code.status_code def log_onyx_error(exc: OnyxError) -> None: detail = exc.detail status_code = exc.status_code if status_code >= 500: logger.error(f"OnyxError {exc.error_code.code}: {detail}") elif status_code >= 400: logger.warning(f"OnyxError {exc.error_code.code}: {detail}") def onyx_error_to_json_response(exc: OnyxError) -> JSONResponse: return JSONResponse( status_code=exc.status_code, content=exc.error_code.detail(exc.detail), ) def register_onyx_exception_handlers(app: FastAPI) -> None: """Register a global handler that converts ``OnyxError`` to JSON responses. Must be called *after* the app is created but *before* it starts serving. The handler logs at WARNING for 4xx and ERROR for 5xx. """ @app.exception_handler(OnyxError) async def _handle_onyx_error( request: Request, # noqa: ARG001 exc: OnyxError, ) -> JSONResponse: log_onyx_error(exc) return onyx_error_to_json_response(exc) ================================================ FILE: backend/onyx/evals/README.md ================================================ # Onyx Evaluations This directory contains the evaluation framework for testing and measuring the performance of Onyx's chat and retrieval systems. ## Overview The evaluation system uses [Braintrust](https://www.braintrust.dev/) to run automated evaluations against test datasets. It measures the quality of responses generated by Onyx's chat system and can be used to track performance improvements over time. ## Prerequisites **Important**: The model server must be running in order for evals to work properly. 
Make sure your model server is up and running before executing any evaluations.

## Running Evaluations

Kick off a remote job:

```bash
onyx/backend$ python -m dotenv -f .vscode/.env run -- python onyx/evals/eval_cli.py --remote --api-key <api-key> --search-permissions-email <email> --remote-dataset-name Simple
```

To run an evaluation locally against a JSON data file:

```bash
onyx$ python -m dotenv -f .vscode/.env run -- python backend/onyx/evals/eval_cli.py --local-data-path backend/onyx/evals/data/eval.json --search-permissions-email richard@onyx.app
```

Set `ONYX_EVAL_API_KEY` in your `.env` file so you don't have to pass `--api-key` on every remote run. You'll need to create an API key in the admin panel to run evals.

### Production Environment

### Local Development

For local development, use the `eval_cli.py` script. We recommend starting it from the VS Code launch configuration for the best debugging experience.

#### Using VS Code Launch Configuration

1. Open VS Code in the project root
2. Go to the "Run and Debug" panel (Ctrl/Cmd + Shift + D)
3. Select "Eval CLI" from the dropdown
4. Click the play button or press F5

This runs the evaluation with the following default settings:

- Uses the local data file at `evals/data/data.json`
- Enables verbose output
- Sets up the required environment variables and Python path

#### CLI Options

- `--local-data-path`: Path to a local JSON file containing test data (the VS Code launch configuration passes `evals/data/data.json`; the CLI itself has no default)
- `--remote-dataset-name`: Name of a remote Braintrust dataset
- `--braintrust-project`: Braintrust project name (overrides `BRAINTRUST_PROJECT` env var)
- `--verbose`: Enable verbose output
- `--no-send-logs`: Skip sending logs to Braintrust (useful for local testing)
- `--local-only`: Run evals locally without Braintrust and print results to the CLI only (cannot be combined with `--remote-dataset-name`)
- `--remote`: Trigger the evaluation on a remote server instead of running it locally
- `--base-url`: Base URL of the remote server (default: `https://test.onyx.app`)
- `--api-key`: API key for the remote server (falls back to `ONYX_EVAL_API_KEY` when omitted)
- `--search-permissions-email`: Email address of the user to impersonate for the evaluation

## Test Data

The evaluation system uses test data stored in `evals/data/data.json`.
This file contains a list of test cases, each with:

- `input`: The question or prompt to test

Example test case:

```json
{
  "input": {
    "message": "What is the capital of France?"
  }
}
```

### Per-Test Configuration

Configure tool forcing, assertions, and model settings per test by adding optional fields to each test case.

#### Tool Configuration

- `force_tools`: List of tool type names to force for this specific test (only the first name that matches a built-in tool is forced; unrecognized names are ignored)
- `expected_tools`: List of tool type names expected to be called
- `require_all_tools`: If true, all expected tools must be called; otherwise at least one of them must be called (default: false)

#### Model Configuration

- `model`: Model version to use (e.g., "gpt-4o", "claude-3-5-sonnet")
- `model_provider`: Model provider (e.g., "openai", "anthropic")
- `temperature`: Temperature for the model (default: 0.0)

Example with tool and model configuration:

```json
[
  {
    "input": {
      "message": "Find information about Python programming"
    },
    "expected_tools": ["SearchTool"],
    "force_tools": ["SearchTool"],
    "model": "gpt-4o"
  },
  {
    "input": {
      "message": "Search the web for recent news about AI"
    },
    "expected_tools": ["WebSearchTool"],
    "model": "claude-3-5-sonnet",
    "model_provider": "anthropic"
  },
  {
    "input": {
      "message": "Calculate 2 + 2"
    },
    "expected_tools": ["PythonTool"],
    "temperature": 0.5
  }
]
```

### Multi-Turn Evaluations

For testing realistic multi-turn conversations where each turn may require different tools, use the `messages` array format instead of a single `message`:

```json
{
  "input": {
    "messages": [
      {
        "message": "What's the latest news about OpenAI today?",
        "expected_tools": ["WebSearchTool", "OpenURLTool"]
      },
      {
        "message": "Now search our internal docs for our OpenAI integration guide",
        "expected_tools": ["SearchTool"]
      },
      {
        "message": "Thanks, that's helpful!",
        "expected_tools": []
      }
    ]
  }
}
```

Each message in the `messages` array can have its own configuration:

- `message`: The user message text (required)
- `expected_tools`: List of tool types expected to be called for this turn
- `require_all_tools`: If true, all expected tools must be called (default: false)
- `force_tools`: List of tool types to force for this turn (only the first recognized tool is forced)
- `model`: Model version override for this turn
- `model_provider`: Model provider override for this turn
- `temperature`: Temperature override for this turn

Multi-turn evals run within a single chat session, so the model has full context of previous turns when responding.

### Available Tool Types

The following built-in tool types can be used:

- `SearchTool`: Internal document search
- `WebSearchTool`: Internet/web search
- `ImageGenerationTool`: Image generation
- `PythonTool`: Python code execution
- `OpenURLTool`: Open and read URLs

### Braintrust Dashboard

After running evaluations, you can view results in the Braintrust dashboard. The evaluation reports:

- `tool_assertion`: Score of 1.0 if tool assertions passed (or no assertions were configured), 0.0 if they failed
- Metadata including `tools_called`, `tools_called_count`, and assertion details

================================================ FILE: backend/onyx/evals/eval.py ================================================ import time from collections.abc import Callable from collections.abc import Generator from contextlib import contextmanager from typing import Any from sqlalchemy import Engine from sqlalchemy import event from sqlalchemy.orm import Session from sqlalchemy.orm import sessionmaker from sqlalchemy.orm.session import SessionTransaction from onyx.chat.chat_state import ChatStateContainer from onyx.chat.models import ChatFullResponse from onyx.chat.process_message import gather_stream_full from onyx.chat.process_message import handle_stream_message_objects from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.db.chat import create_chat_session from onyx.db.engine.sql_engine import get_sqlalchemy_engine from onyx.db.users import get_user_by_email from onyx.evals.models import ChatFullEvalResult from onyx.evals.models import EvalationAck from
onyx.evals.models import EvalConfigurationOptions from onyx.evals.models import EvalMessage from onyx.evals.models import EvalProvider from onyx.evals.models import EvalTimings from onyx.evals.models import EvalToolResult from onyx.evals.models import MultiTurnEvalResult from onyx.evals.models import ToolAssertion from onyx.evals.provider import get_provider from onyx.llm.override_models import LLMOverride from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE from onyx.server.query_and_chat.models import ChatSessionCreationRequest from onyx.server.query_and_chat.models import SendMessageRequest from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() @contextmanager def isolated_ephemeral_session_factory( engine: Engine, ) -> Generator[Callable[[], Session], None, None]: """ Create a session factory that creates sessions that run in a transaction that gets rolled back. This is useful for running evals without any lasting db side effects. 
""" tenant_id = get_current_tenant_id() schema_translate_map = {None: tenant_id} conn = engine.connect().execution_options(schema_translate_map=schema_translate_map) outer_tx = conn.begin() Maker = sessionmaker(bind=conn, expire_on_commit=False, future=True) def make_session() -> Session: s = Maker() s.begin_nested() @event.listens_for(s, "after_transaction_end") def _restart_savepoint( session: Session, transaction: SessionTransaction ) -> None: if transaction.nested and not ( transaction._parent is not None and transaction._parent.nested ): session.begin_nested() return s try: yield make_session finally: outer_tx.rollback() conn.close() def _chat_full_response_to_eval_result( full: ChatFullResponse, stream_start_time: float, ) -> ChatFullEvalResult: """Map ChatFullResponse from gather_stream_full to eval result components.""" tools_called = [tc.tool_name for tc in full.tool_calls] tool_call_details: list[dict[str, Any]] = [ {"tool_name": tc.tool_name, "tool_arguments": tc.tool_arguments} for tc in full.tool_calls ] stream_end_time = time.time() total_ms = (stream_end_time - stream_start_time) * 1000 timings = EvalTimings( total_ms=total_ms, llm_first_token_ms=None, tool_execution_ms={}, stream_processing_ms=total_ms, ) return ChatFullEvalResult( answer=full.answer, tools_called=tools_called, tool_call_details=tool_call_details, citations=full.citation_info, timings=timings, ) def evaluate_tool_assertions( tools_called: list[str], assertions: ToolAssertion | None, ) -> tuple[bool | None, str | None]: """ Evaluate tool assertions against the tools that were called. 
Args: tools_called: List of tool names that were called during evaluation assertions: Tool assertions to check, or None if no assertions Returns: Tuple of (passed, details) where: - passed: True if assertions passed, False if failed, None if no assertions - details: Human-readable explanation of the result """ if assertions is None: return None, None expected_tools = set(assertions.expected_tools) called_tools = set(tools_called) if assertions.require_all: # All expected tools must be called missing_tools = expected_tools - called_tools if missing_tools: return False, ( f"Missing expected tools: {sorted(missing_tools)}. Called tools: {sorted(called_tools)}" ) return True, ( f"All expected tools called: {sorted(expected_tools)}. Called tools: {sorted(called_tools)}" ) else: # At least one expected tool must be called matched_tools = expected_tools & called_tools if not matched_tools: return False, ( f"None of expected tools called. Expected one of: {sorted(expected_tools)}. Called tools: {sorted(called_tools)}" ) return True, ( f"Expected tool(s) called: {sorted(matched_tools)}. Called tools: {sorted(called_tools)}" ) def _get_answer_with_tools( eval_input: dict[str, Any], configuration: EvalConfigurationOptions, ) -> EvalToolResult: """ Get answer from the chat system with full tool call tracking. 
Args: eval_input: Dictionary containing: - 'message': The user message to send - 'force_tools' (optional): List of tool types to force for this input - 'expected_tools' (optional): List of tool types expected to be called - 'require_all_tools' (optional): If true, all expected tools must be called - 'model' (optional): Model version to use (e.g., "gpt-4o", "claude-3-5-sonnet") - 'model_provider' (optional): Model provider (e.g., "openai", "anthropic") - 'temperature' (optional): Temperature for the model configuration: Evaluation configuration options Returns: EvalToolResult containing the answer and tool call information """ engine = get_sqlalchemy_engine() with isolated_ephemeral_session_factory(engine) as SessionLocal: with SessionLocal() as db_session: full_configuration = configuration.get_configuration(db_session) # Handle per-input tool forcing (from data file) forced_tool_ids: list[int] = [] input_force_tools = eval_input.get("force_tools", []) if input_force_tools: from onyx.db.tools import get_builtin_tool from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP for tool_type in input_force_tools: if tool_type in BUILT_IN_TOOL_MAP: tool_id = get_builtin_tool( db_session, BUILT_IN_TOOL_MAP[tool_type] ).id if tool_id not in forced_tool_ids: forced_tool_ids.append(tool_id) # Build tool assertions from per-input config tool_assertions: ToolAssertion | None = None input_expected_tools = eval_input.get("expected_tools", []) if input_expected_tools: tool_assertions = ToolAssertion( expected_tools=input_expected_tools, require_all=eval_input.get("require_all_tools", False), ) # Handle per-input model configuration llm_override = full_configuration.llm input_model = eval_input.get("model") input_model_provider = eval_input.get("model_provider") input_temperature = eval_input.get("temperature") if input_model or input_model_provider or input_temperature is not None: # Create a new LLMOverride with per-input values, falling back to config llm_override = LLMOverride( 
model_provider=input_model_provider or llm_override.model_provider, model_version=input_model or llm_override.model_version, temperature=( input_temperature if input_temperature is not None else llm_override.temperature ), ) user = get_user_by_email(configuration.search_permissions_email, db_session) if not user: raise ValueError( f"User not found for email: {configuration.search_permissions_email}" ) forced_tool_id = forced_tool_ids[0] if forced_tool_ids else None request = SendMessageRequest( message=eval_input["message"], llm_override=llm_override, allowed_tool_ids=full_configuration.allowed_tool_ids, forced_tool_id=forced_tool_id, chat_session_info=ChatSessionCreationRequest( persona_id=DEFAULT_PERSONA_ID, description="Eval session", ), ) stream_start_time = time.time() state_container = ChatStateContainer() packets = handle_stream_message_objects( new_msg_req=request, user=user, db_session=db_session, external_state_container=state_container, ) full = gather_stream_full(packets, state_container) result = _chat_full_response_to_eval_result(full, stream_start_time) # Evaluate tool assertions assertion_passed, assertion_details = evaluate_tool_assertions( result.tools_called, tool_assertions ) logger.info( f"Eval completed. Tools called: {result.tools_called}.\n" f"Assertion passed: {assertion_passed}. Details: {assertion_details}\n" ) return EvalToolResult( answer=result.answer, tools_called=result.tools_called, tool_call_details=result.tool_call_details, citations=result.citations, assertion_passed=assertion_passed, assertion_details=assertion_details, timings=result.timings, ) def _get_multi_turn_answer_with_tools( eval_input: dict[str, Any], configuration: EvalConfigurationOptions, ) -> MultiTurnEvalResult: """ Get answers from a multi-turn conversation with tool call tracking for each turn. 
Args: eval_input: Dictionary containing: - 'messages': List of message dicts, each with: - 'message': The user message text - 'expected_tools' (optional): List of expected tool types - 'require_all_tools' (optional): If true, all expected tools must be called - 'model' (optional): Model version override for this turn - 'model_provider' (optional): Provider override for this turn - 'temperature' (optional): Temperature override for this turn - 'force_tools' (optional): List of tool types to force configuration: Evaluation configuration options Returns: MultiTurnEvalResult containing per-turn results and aggregate metrics """ messages_data = eval_input.get("messages", []) if not messages_data: raise ValueError("Multi-turn eval requires 'messages' array in input") # Parse messages into EvalMessage objects messages: list[EvalMessage] = [] for msg_data in messages_data: messages.append( EvalMessage( message=msg_data["message"], expected_tools=msg_data.get("expected_tools", []), require_all_tools=msg_data.get("require_all_tools", False), model=msg_data.get("model"), model_provider=msg_data.get("model_provider"), temperature=msg_data.get("temperature"), force_tools=msg_data.get("force_tools", []), ) ) turn_results: list[EvalToolResult] = [] engine = get_sqlalchemy_engine() with isolated_ephemeral_session_factory(engine) as SessionLocal: with SessionLocal() as db_session: full_configuration = configuration.get_configuration(db_session) user = get_user_by_email(configuration.search_permissions_email, db_session) if not user: raise ValueError( f"User not found for email: {configuration.search_permissions_email}" ) # Cache user_id to avoid SQLAlchemy expiration issues user_id = user.id # Create a single chat session for all turns chat_session = create_chat_session( db_session=db_session, description="Multi-turn eval session", user_id=user_id, persona_id=DEFAULT_PERSONA_ID, onyxbot_flow=True, ) chat_session_id = chat_session.id # Process each turn sequentially for turn_idx, 
msg in enumerate(messages): logger.info( f"Processing turn {turn_idx + 1}/{len(messages)}: {msg.message[:50]}..." ) # Handle per-turn tool forcing forced_tool_ids: list[int] = [] if msg.force_tools: from onyx.db.tools import get_builtin_tool from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP for tool_type in msg.force_tools: if tool_type in BUILT_IN_TOOL_MAP: tool_id = get_builtin_tool( db_session, BUILT_IN_TOOL_MAP[tool_type] ).id if tool_id not in forced_tool_ids: forced_tool_ids.append(tool_id) # Build tool assertions for this turn tool_assertions: ToolAssertion | None = None if msg.expected_tools: tool_assertions = ToolAssertion( expected_tools=msg.expected_tools, require_all=msg.require_all_tools, ) # Handle per-turn model configuration llm_override = full_configuration.llm if msg.model or msg.model_provider or msg.temperature is not None: llm_override = LLMOverride( model_provider=msg.model_provider or llm_override.model_provider, model_version=msg.model or llm_override.model_version, temperature=( msg.temperature if msg.temperature is not None else llm_override.temperature ), ) # Create request for this turn using SendMessageRequest (same API as handle_stream_message_objects) # Use AUTO_PLACE_AFTER_LATEST_MESSAGE to chain messages forced_tool_id = forced_tool_ids[0] if forced_tool_ids else None request = SendMessageRequest( chat_session_id=chat_session_id, parent_message_id=AUTO_PLACE_AFTER_LATEST_MESSAGE, message=msg.message, llm_override=llm_override, allowed_tool_ids=full_configuration.allowed_tool_ids, forced_tool_id=forced_tool_id, ) # Stream and gather results for this turn via handle_stream_message_objects + gather_stream_full stream_start_time = time.time() state_container = ChatStateContainer() packets = handle_stream_message_objects( new_msg_req=request, user=user, db_session=db_session, external_state_container=state_container, ) full = gather_stream_full(packets, state_container) result = _chat_full_response_to_eval_result(full, 
stream_start_time) # Evaluate tool assertions for this turn assertion_passed, assertion_details = evaluate_tool_assertions( result.tools_called, tool_assertions ) logger.info( f"Turn {turn_idx + 1} completed. Tools called: {result.tools_called}.\n" f"Assertion passed: {assertion_passed}. Details: {assertion_details}\n" ) turn_results.append( EvalToolResult( answer=result.answer, tools_called=result.tools_called, tool_call_details=result.tool_call_details, citations=result.citations, assertion_passed=assertion_passed, assertion_details=assertion_details, timings=result.timings, ) ) # Calculate aggregate metrics pass_count = sum(1 for r in turn_results if r.assertion_passed is True) fail_count = sum(1 for r in turn_results if r.assertion_passed is False) # Consider "all passed" only if there are no failures # (turns with no assertions don't count as failures) all_passed = fail_count == 0 return MultiTurnEvalResult( turn_results=turn_results, all_passed=all_passed, pass_count=pass_count, fail_count=fail_count, total_turns=len(turn_results), ) def run_eval( configuration: EvalConfigurationOptions, data: list[dict[str, Any]] | None = None, remote_dataset_name: str | None = None, provider: EvalProvider = get_provider(), ) -> EvalationAck: if data is not None and remote_dataset_name is not None: raise ValueError("Cannot specify both data and remote_dataset_name") if data is None and remote_dataset_name is None: raise ValueError("Must specify either data or remote_dataset_name") return provider.eval( task=lambda eval_input: _get_answer_with_tools(eval_input, configuration), configuration=configuration, data=data, remote_dataset_name=remote_dataset_name, multi_turn_task=lambda eval_input: _get_multi_turn_answer_with_tools( eval_input, configuration ), ) ================================================ FILE: backend/onyx/evals/eval_cli.py ================================================ #!/usr/bin/env python3 """ CLI for running evaluations with local configurations. 
""" import argparse import json import logging import os from typing import Any import braintrust import requests from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE from onyx.configs.constants import POSTGRES_WEB_APP_NAME from onyx.db.engine.sql_engine import SqlEngine from onyx.evals.eval import run_eval from onyx.evals.models import EvalationAck from onyx.evals.models import EvalConfigurationOptions from onyx.evals.provider import get_provider from onyx.tracing.setup import setup_tracing def setup_session_factory() -> None: SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME) SqlEngine.init_engine( pool_size=POSTGRES_API_SERVER_POOL_SIZE, max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW, ) def load_data_local( local_data_path: str, ) -> list[dict[str, Any]]: if not os.path.isfile(local_data_path): raise ValueError(f"Local data file does not exist: {local_data_path}") with open(local_data_path, "r") as f: return json.load(f) def configure_logging_for_evals(verbose: bool) -> None: """Set logging level to WARNING to reduce noise during evals.""" if verbose: return # Set environment variable for any future logger creation os.environ["LOG_LEVEL"] = "WARNING" # Force WARNING level for root logger and its handlers root = logging.getLogger() root.setLevel(logging.WARNING) for handler in root.handlers: handler.setLevel(logging.WARNING) # Force WARNING level for all existing loggers and their handlers for name in list(logging.Logger.manager.loggerDict.keys()): logger = logging.getLogger(name) logger.setLevel(logging.WARNING) for handler in logger.handlers: handler.setLevel(logging.WARNING) # Set a basic config to ensure new loggers also use WARNING logging.basicConfig(level=logging.WARNING, force=True) def run_local( local_data_path: str | None, remote_dataset_name: str | None, search_permissions_email: str | None = None, no_send_logs: bool = False, local_only: bool = False, verbose: bool = 
False, ) -> EvalationAck: """ Run evaluation with local configurations. Tool forcing and assertions are configured per-test in the data file using: - force_tools: List of tool type names to force - expected_tools: List of tool type names expected to be called - require_all_tools: If true, all expected tools must be called Args: local_data_path: Path to local JSON file remote_dataset_name: Name of remote Braintrust dataset search_permissions_email: Email address to impersonate for the evaluation (required; a ValueError is raised if it is None) no_send_logs: Whether to skip sending logs to Braintrust local_only: If True, use LocalEvalProvider (CLI output only, no Braintrust) verbose: If True, keep the existing log levels instead of forcing WARNING Returns: EvalationAck: The evaluation result """ setup_session_factory() configure_logging_for_evals( verbose=verbose, ) # Only set up tracing if not running in local-only mode if not local_only: setup_tracing() if search_permissions_email is None: raise ValueError("search_permissions_email is required for local evaluation") configuration = EvalConfigurationOptions( search_permissions_email=search_permissions_email, dataset_name=remote_dataset_name or "local", no_send_logs=no_send_logs, ) # Get the appropriate provider provider = get_provider(local_only=local_only) if remote_dataset_name: score = run_eval( configuration=configuration, remote_dataset_name=remote_dataset_name, provider=provider, ) else: if local_data_path is None: raise ValueError( "local_data_path or remote_dataset_name is required for local evaluation" ) data = load_data_local(local_data_path) score = run_eval(configuration=configuration, data=data, provider=provider) return score def run_remote( base_url: str, api_key: str, remote_dataset_name: str, search_permissions_email: str, payload: dict[str, Any] | None = None, ) -> dict[str, Any]: """ Trigger an eval pipeline execution on a remote server. 
Tool forcing and assertions are configured per-test in the dataset.
Args: base_url: Base URL of the remote server (e.g., "https://test.onyx.app") api_key: API key for authentication remote_dataset_name: Name of remote Braintrust dataset search_permissions_email: Email address to use for the evaluation. payload: Optional payload to send with the request Returns: Response from the remote server Raises: requests.RequestException: If the request fails """ if payload is None: payload = {} payload["search_permissions_email"] = search_permissions_email payload["dataset_name"] = remote_dataset_name url = f"{base_url}/api/evals/eval_run" headers = { "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } response = requests.post(url, headers=headers, json=payload) response.raise_for_status() return response.json() def main() -> None: """Main CLI entry point.""" parser = argparse.ArgumentParser( description="Run evaluations with local configurations" ) parser.add_argument( "--local-data-path", type=str, help="Path to local JSON file containing test data", ) parser.add_argument( "--remote-dataset-name", type=str, help="Name of remote Braintrust dataset", ) parser.add_argument( "--braintrust-project", type=str, help="Braintrust project name", default="Onyx", ) parser.add_argument("--verbose", action="store_true", help="Enable verbose output") # Remote eval arguments parser.add_argument( "--base-url", type=str, default="https://test.onyx.app", help="Base URL of the remote server (default: https://test.onyx.app)", ) parser.add_argument( "--api-key", type=str, help="API key for authentication with the remote server", ) parser.add_argument( "--remote", action="store_true", help="Run evaluation on remote server instead of locally", ) parser.add_argument( "--search-permissions-email", type=str, help="Email address to impersonate for the evaluation", ) parser.add_argument( "--no-send-logs", action="store_true", help="Do not send logs to the remote server", default=False, ) parser.add_argument( "--local-only", action="store_true", 
help="Run evals locally without Braintrust, output results to CLI only", default=False, ) args = parser.parse_args() if args.local_data_path: print(f"Loading data from local file: {args.local_data_path}") elif args.remote_dataset_name: if args.local_only: raise ValueError( "--local-only cannot be used with --remote-dataset-name. Use --local-data-path with a local JSON file instead." ) print(f"Loading data from remote dataset: {args.remote_dataset_name}") dataset = braintrust.init_dataset( project=args.braintrust_project, name=args.remote_dataset_name ) dataset_size = len(list(dataset.fetch())) print(f"Dataset size: {dataset_size}") if args.remote: if not args.api_key: print("Using API Key from ONYX_EVAL_API_KEY") api_key: str = ( args.api_key if args.api_key else os.environ.get("ONYX_EVAL_API_KEY", "") ) print(f"Running evaluation on remote server: {args.base_url}") if args.search_permissions_email: print(f"Using search permissions email: {args.search_permissions_email}") try: result = run_remote( args.base_url, api_key, args.remote_dataset_name, search_permissions_email=args.search_permissions_email, ) print(f"Remote evaluation triggered successfully: {result}") except requests.RequestException as e: print(f"Error triggering remote evaluation: {e}") return else: if args.local_only: print("Running in local-only mode (no Braintrust)") else: print(f"Using Braintrust project: {args.braintrust_project}") if args.search_permissions_email: print(f"Using search permissions email: {args.search_permissions_email}") run_local( local_data_path=args.local_data_path, remote_dataset_name=args.remote_dataset_name, search_permissions_email=args.search_permissions_email, no_send_logs=args.no_send_logs, local_only=args.local_only, verbose=args.verbose, ) if __name__ == "__main__": main() ================================================ FILE: backend/onyx/evals/models.py ================================================ from abc import ABC from abc import abstractmethod from 
collections.abc import Callable from typing import Any from pydantic import BaseModel from pydantic import Field from sqlalchemy.orm import Session from onyx.db.tools import get_builtin_tool from onyx.llm.override_models import LLMOverride from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.tools.built_in_tools import BUILT_IN_TOOL_MAP class ToolAssertion(BaseModel): """Assertion about expected tool usage during evaluation.""" expected_tools: list[str] # Tool type names that should be called require_all: bool = False # If True, ALL expected tools must be called class EvalTimings(BaseModel): """Timing information for eval execution.""" total_ms: float # Total time for the eval llm_first_token_ms: float | None = None # Time to first token from LLM tool_execution_ms: dict[str, float] = Field( default_factory=dict ) # Per-tool timings stream_processing_ms: float | None = None # Time to process the stream class ChatFullEvalResult(BaseModel): """Raw eval components from ChatFullResponse (before tool assertions).""" answer: str tools_called: list[str] tool_call_details: list[dict[str, Any]] citations: list[CitationInfo] timings: EvalTimings class EvalToolResult(BaseModel): """Result of a single eval with tool call information.""" answer: str tools_called: list[str] # Names of tools that were called tool_call_details: list[dict[str, Any]] # Full tool call info citations: list[CitationInfo] # Citations used in the answer assertion_passed: bool | None = None # None if no assertion configured assertion_details: str | None = None # Explanation of pass/fail timings: EvalTimings | None = None # Timing information for the eval class EvalMessage(BaseModel): """Single message in a multi-turn evaluation conversation.""" message: str # The message text to send expected_tools: list[str] = Field( default_factory=list ) # Expected tools for this turn require_all_tools: bool = False # If True, ALL expected tools must be called # Per-message model configuration 
overrides model: str | None = None model_provider: str | None = None temperature: float | None = None force_tools: list[str] = Field(default_factory=list) # Tools to force for this turn class MultiTurnEvalResult(BaseModel): """Result of a multi-turn evaluation containing per-message results.""" turn_results: list[EvalToolResult] # Results for each turn/message all_passed: bool # True if all turn assertions passed pass_count: int # Number of turns that passed fail_count: int # Number of turns that failed total_turns: int # Total number of turns class EvalConfiguration(BaseModel): llm: LLMOverride = Field(default_factory=LLMOverride) search_permissions_email: str allowed_tool_ids: list[int] class EvalConfigurationOptions(BaseModel): builtin_tool_types: list[str] = list(BUILT_IN_TOOL_MAP.keys()) llm: LLMOverride = LLMOverride( model_provider=None, model_version="gpt-4o", temperature=0.0, ) search_permissions_email: str dataset_name: str no_send_logs: bool = False # Optional override for Braintrust project (defaults to BRAINTRUST_PROJECT env var) braintrust_project: str | None = None # Optional experiment name for the eval run (shows in Braintrust UI) experiment_name: str | None = None def get_configuration(self, db_session: Session) -> EvalConfiguration: return EvalConfiguration( llm=self.llm, search_permissions_email=self.search_permissions_email, allowed_tool_ids=[ get_builtin_tool(db_session, BUILT_IN_TOOL_MAP[tool]).id for tool in self.builtin_tool_types ], ) class EvalationAck(BaseModel): success: bool class EvalProvider(ABC): @abstractmethod def eval( self, task: Callable[[dict[str, Any]], EvalToolResult], configuration: EvalConfigurationOptions, data: list[dict[str, Any]] | None = None, remote_dataset_name: str | None = None, multi_turn_task: "Callable[[dict[str, Any]], MultiTurnEvalResult] | None" = None, ) -> EvalationAck: pass ================================================ FILE: backend/onyx/evals/one_off/create_braintrust_dataset.py 
================================================ #!/usr/bin/env python3 """ Script to create a Braintrust dataset from the DR Master Question & Metric Sheet CSV. This script: 1. Parses the CSV file 2. Filters records where "Should we use it" is TRUE and "web-only" is in categories 3. Creates a Braintrust dataset whose input holds the question and its research_type, with the reference OpenAI answer as the expected output Usage: python create_braintrust_dataset.py --dataset-name "MyDataset" python create_braintrust_dataset.py --dataset-name "MyDataset" --csv-path "/path/to/csv" """ import argparse import csv import os import sys from typing import Any from typing import Dict from typing import List from onyx.configs.app_configs import BRAINTRUST_API_KEY try: from braintrust import init_dataset except ImportError: print( "Error: braintrust package not found. Please install it with: pip install braintrust" ) sys.exit(1) def column_letter_to_index(column_letter: str) -> int: """Convert Google Sheets column letter (A, B, C, etc.) to 0-based index.""" result = 0 for char in column_letter.upper(): result = result * 26 + (ord(char) - ord("A") + 1) return result - 1 def parse_csv_file(csv_path: str) -> List[Dict[str, Any]]: """Parse the CSV file and extract relevant records.""" records = [] with open(csv_path, "r", encoding="utf-8") as file: # Skip the first few header rows and read the actual data lines = file.readlines() # Find the actual data start (skip header rows) data_start = 0 for i, line in enumerate(lines): if "Should we use it?" in line: data_start = i + 1 break # Parse the CSV data starting from the data_start line csv_reader = csv.reader(lines[data_start:]) # Define Google Sheets column references for easy modification SHOULD_USE_COL = "C" # "Should we use it?"
QUESTION_COL = "H" # "Question" EXPECTED_DEPTH_COL = "J" # "Expected Depth" CATEGORIES_COL = "M" # "Categories" OPENAI_DEEP_COL = "AA" # "OpenAI Deep Answer" OPENAI_THINKING_COL = "O" # "OpenAI Thinking Answer" for row_num, row in enumerate(csv_reader, start=data_start + 1): if len(row) < 15: # Skip rows missing columns A-O; column AA is bounds-checked separately below continue # Extract relevant fields using Google Sheets column references should_use = ( row[column_letter_to_index(SHOULD_USE_COL)].strip().upper() if len(row) > column_letter_to_index(SHOULD_USE_COL) else "" ) question = ( row[column_letter_to_index(QUESTION_COL)].strip() if len(row) > column_letter_to_index(QUESTION_COL) else "" ) expected_depth = ( row[column_letter_to_index(EXPECTED_DEPTH_COL)].strip() if len(row) > column_letter_to_index(EXPECTED_DEPTH_COL) else "" ) categories = ( row[column_letter_to_index(CATEGORIES_COL)].strip() if len(row) > column_letter_to_index(CATEGORIES_COL) else "" ) openai_deep_answer = ( row[column_letter_to_index(OPENAI_DEEP_COL)].strip() if len(row) > column_letter_to_index(OPENAI_DEEP_COL) else "" ) openai_thinking_answer = ( row[column_letter_to_index(OPENAI_THINKING_COL)].strip() if len(row) > column_letter_to_index(OPENAI_THINKING_COL) else "" ) # Filter records: should_use = TRUE and categories contains "web-only" if ( should_use == "TRUE" and "web-only" in categories and question ): # Ensure question is not empty if expected_depth == "Deep": records.extend( [ { "question": question + ". All info is contained in the question.
DO NOT ask any clarifying questions.", "research_type": "DEEP", "categories": categories, "expected_depth": expected_depth, "expected_answer": openai_deep_answer, "row_number": row_num, } ] ) else: records.extend( [ { "question": question, "research_type": "THOUGHTFUL", "categories": categories, "expected_depth": expected_depth, "expected_answer": openai_thinking_answer, "row_number": row_num, } ] ) return records def create_braintrust_dataset(records: List[Dict[str, Any]], dataset_name: str) -> None: """Create a Braintrust dataset with the filtered records.""" # Check if BRAINTRUST_API_KEY is set if BRAINTRUST_API_KEY == "": print("WARNING: BRAINTRUST_API_KEY environment variable is not set.") print( "The script will show what would be inserted but won't actually create the dataset." ) print( "To actually create the dataset, set your BRAINTRUST_API_KEY environment variable." ) print() # Show what would be inserted print( f"Would create Braintrust dataset '{dataset_name}' with {len(records)} records:" ) for i, record in enumerate(records, 1): print(f"Record {i}/{len(records)}:") print(f" Question: {record['question'][:100]}...") print(f" Research Type: {record['research_type']}") print(f" Expected Answer: {record['expected_answer'][:100]}...") print() return # Initialize the dataset dataset = init_dataset("Onyx", dataset_name, api_key=BRAINTRUST_API_KEY) print(f"Creating Braintrust dataset with {len(records)} records...") # Insert records into the dataset for i, record in enumerate(records, 1): record_id = dataset.insert( {"message": record["question"], "research_type": record["research_type"]}, expected=record["expected_answer"], ) print(f"Inserted record {i}/{len(records)}: ID {record_id}") print(f" Question: {record['question'][:100]}...") print(f" Research Type: {record['research_type']}") print(f" Expected Answer: {record['expected_answer'][:100]}...") print() # Flush to ensure all records are sent dataset.flush() print(f"Successfully created dataset with 
{len(records)} records!") def main() -> None: """Main function to run the script.""" parser = argparse.ArgumentParser( description="Create a Braintrust dataset from the DR Master Question & Metric Sheet CSV" ) parser.add_argument( "--dataset-name", required=True, help="Name of the Braintrust dataset to create" ) parser.add_argument( "--csv-path", default="/Users/richardguan/onyx/backend/onyx/evals/data/DR Master Question & Metric Sheet - Sheet1.csv", help="Path to the CSV file (default: %(default)s)", ) args = parser.parse_args() csv_path = args.csv_path dataset_name = args.dataset_name if not os.path.exists(csv_path): print(f"Error: CSV file not found at {csv_path}") sys.exit(1) print("Parsing CSV file...") records = parse_csv_file(csv_path) print(f"Found {len(records)} records matching criteria:") print("- Should we use it = TRUE") print("- Categories contains 'web-only'") print("- Question is not empty") print() if not records: print("No records found matching the criteria!") sys.exit(1) # Show summary of research types deep_count = sum(1 for r in records if r["research_type"] == "DEEP") thoughtful_count = sum(1 for r in records if r["research_type"] == "THOUGHTFUL") print("Research type breakdown:") print(f" DEEP: {deep_count}") print(f" THOUGHTFUL: {thoughtful_count}") print() # Create the Braintrust dataset create_braintrust_dataset(records, dataset_name) if __name__ == "__main__": main() ================================================ FILE: backend/onyx/evals/provider.py ================================================ from onyx.evals.models import EvalProvider from onyx.evals.providers.braintrust import BraintrustEvalProvider from onyx.evals.providers.local import LocalEvalProvider def get_provider(local_only: bool = False) -> EvalProvider: """ Get the appropriate eval provider. Args: local_only: If True, use LocalEvalProvider (CLI output only, no Braintrust). If False, use BraintrustEvalProvider. Returns: The appropriate EvalProvider instance. 
""" if local_only: return LocalEvalProvider() return BraintrustEvalProvider() ================================================ FILE: backend/onyx/evals/providers/braintrust.py ================================================ from collections.abc import Callable from typing import Any from typing import Union from braintrust import Eval from braintrust import EvalCase from braintrust import init_dataset from braintrust import Score from onyx.configs.app_configs import BRAINTRUST_MAX_CONCURRENCY from onyx.configs.app_configs import BRAINTRUST_PROJECT from onyx.evals.models import EvalationAck from onyx.evals.models import EvalConfigurationOptions from onyx.evals.models import EvalProvider from onyx.evals.models import EvalToolResult from onyx.evals.models import MultiTurnEvalResult from onyx.utils.logger import setup_logger logger = setup_logger() # Union type for both single and multi-turn results EvalResult = Union[EvalToolResult, MultiTurnEvalResult] def tool_assertion_scorer( input: dict[str, Any], output: EvalResult, expected: EvalResult | None ) -> Score: """ Scorer that checks if tool assertions passed. Handles both single-turn (EvalToolResult) and multi-turn (MultiTurnEvalResult) outputs. Args: input: The input data for the evaluation case. output: The actual output from the task. expected: The expected output (unused for this scorer). Returns: For single-turn outputs, a Score of 1.0 if the assertion passed or none was configured, and 0.0 if it failed. For multi-turn outputs, the fraction of evaluated turn assertions that passed (1.0 if none were evaluated).
""" # input and expected are unused but required by Braintrust scorer signature _ = input, expected # Handle multi-turn results if isinstance(output, MultiTurnEvalResult): # Calculate score based on pass rate if output.total_turns == 0: score = 1.0 else: # Score is the ratio of passed assertions assertions_evaluated = output.pass_count + output.fail_count if assertions_evaluated == 0: score = 1.0 # No assertions configured else: score = output.pass_count / assertions_evaluated return Score( name="tool_assertion", score=score, metadata={ "is_multi_turn": True, "total_turns": output.total_turns, "pass_count": output.pass_count, "fail_count": output.fail_count, "all_passed": output.all_passed, "turn_details": [ { "tools_called": r.tools_called, "assertion_passed": r.assertion_passed, "assertion_details": r.assertion_details, } for r in output.turn_results ], }, ) # Handle single-turn results (EvalToolResult) if output.assertion_passed is None: # No assertions configured - return passing score return Score( name="tool_assertion", score=1.0, metadata={ "is_multi_turn": False, "tools_called": output.tools_called, "tools_called_count": len(output.tools_called), "assertion_configured": False, }, ) return Score( name="tool_assertion", score=1.0 if output.assertion_passed else 0.0, metadata={ "is_multi_turn": False, "tools_called": output.tools_called, "tools_called_count": len(output.tools_called), "assertion_passed": output.assertion_passed, "assertion_details": output.assertion_details, "tool_call_details": output.tool_call_details, }, ) class BraintrustEvalProvider(EvalProvider): def eval( self, task: Callable[[dict[str, Any]], EvalToolResult], configuration: EvalConfigurationOptions, data: list[dict[str, Any]] | None = None, remote_dataset_name: str | None = None, multi_turn_task: Callable[[dict[str, Any]], MultiTurnEvalResult] | None = None, ) -> EvalationAck: if data is not None and remote_dataset_name is not None: raise ValueError("Cannot specify both data and 
remote_dataset_name") if data is None and remote_dataset_name is None: raise ValueError("Must specify either data or remote_dataset_name") # Create a wrapper task that dispatches to the appropriate handler def dispatch_task(eval_input: dict[str, Any]) -> EvalResult: if "messages" in eval_input and multi_turn_task is not None: return multi_turn_task(eval_input) return task(eval_input) project_name = configuration.braintrust_project or BRAINTRUST_PROJECT experiment_name = configuration.experiment_name eval_data: Any = None if remote_dataset_name is not None: eval_data = init_dataset(project=project_name, name=remote_dataset_name) else: if data: eval_data = [ EvalCase( input={ **item.get("input", {}), # Pass through per-test tool configuration (for single-turn) "force_tools": item.get("force_tools", []), "expected_tools": item.get("expected_tools", []), "require_all_tools": item.get("require_all_tools", False), # Pass through per-test model configuration "model": item.get("model"), "model_provider": item.get("model_provider"), "temperature": item.get("temperature"), }, expected=item.get("expected"), ) for item in data ] metadata = configuration.model_dump() Eval( # type: ignore[misc] name=project_name, experiment_name=experiment_name, data=eval_data, task=dispatch_task, scores=[tool_assertion_scorer], metadata=metadata, max_concurrency=BRAINTRUST_MAX_CONCURRENCY, no_send_logs=configuration.no_send_logs, ) return EvalationAck(success=True) ================================================ FILE: backend/onyx/evals/providers/local.py ================================================ """ Local eval provider that runs evaluations and outputs results to the CLI. No external dependencies like Braintrust required. 
""" from collections.abc import Callable from typing import Any from onyx.evals.models import EvalationAck from onyx.evals.models import EvalConfigurationOptions from onyx.evals.models import EvalProvider from onyx.evals.models import EvalToolResult from onyx.evals.models import MultiTurnEvalResult from onyx.utils.logger import setup_logger logger = setup_logger() # ANSI color codes GREEN = "\033[92m" RED = "\033[91m" YELLOW = "\033[93m" BLUE = "\033[94m" BOLD = "\033[1m" RESET = "\033[0m" DIM = "\033[2m" def _display_single_turn_result( result: EvalToolResult, passed_count: list[int], failed_count: list[int], no_assertion_count: list[int], ) -> None: """Display results for a single turn and update counters.""" # Display timing trace if result.timings: print(f" {BOLD}Trace:{RESET}") print(f" Total: {result.timings.total_ms:.0f}ms") if result.timings.llm_first_token_ms is not None: print(f" First token: {result.timings.llm_first_token_ms:.0f}ms") if result.timings.tool_execution_ms: for tool_name, duration_ms in result.timings.tool_execution_ms.items(): print(f" {tool_name}: {duration_ms:.0f}ms") # Display tools called tools_str = ", ".join(result.tools_called) if result.tools_called else "(none)" print(f" Tools called: {BLUE}{tools_str}{RESET}") # Display assertion result if result.assertion_passed is None: print(f" Assertion: {YELLOW}N/A{RESET} - No assertion configured") no_assertion_count[0] += 1 elif result.assertion_passed: print(f" Assertion: {GREEN}PASS{RESET} - {result.assertion_details}") passed_count[0] += 1 else: print(f" Assertion: {RED}FAIL{RESET} - {result.assertion_details}") failed_count[0] += 1 # Display truncated answer answer = result.answer truncated_answer = answer[:200] + "..." if len(answer) > 200 else answer truncated_answer = truncated_answer.replace("\n", " ") print(f" Answer: {truncated_answer}") class LocalEvalProvider(EvalProvider): """ Eval provider that runs evaluations locally and prints results to the CLI. 
Does not require Braintrust or any external service. """ def eval( self, task: Callable[[dict[str, Any]], EvalToolResult], configuration: EvalConfigurationOptions, # noqa: ARG002 data: list[dict[str, Any]] | None = None, remote_dataset_name: str | None = None, multi_turn_task: Callable[[dict[str, Any]], MultiTurnEvalResult] | None = None, ) -> EvalationAck: if remote_dataset_name is not None: raise ValueError( "LocalEvalProvider does not support remote datasets. Use --local-data-path with a local JSON file." ) if data is None: raise ValueError("data is required for LocalEvalProvider") total = len(data) # Use lists to allow mutation in helper function passed = [0] failed = [0] no_assertion = [0] print(f"\n{BOLD}Running {total} evaluation(s)...{RESET}\n") print("=" * 60) for i, item in enumerate(data, 1): input_data = item.get("input", {}) # Check if this is a multi-turn eval (has 'messages' array) if "messages" in input_data: self._run_multi_turn_eval( i, total, item, multi_turn_task, passed, failed, no_assertion ) else: self._run_single_turn_eval( i, total, item, task, passed, failed, no_assertion ) # Summary print("\n" + "=" * 60) total_with_assertions = passed[0] + failed[0] if total_with_assertions > 0: pass_rate = (passed[0] / total_with_assertions) * 100 print( f"{BOLD}Summary:{RESET} {passed[0]}/{total_with_assertions} passed ({pass_rate:.1f}%)" ) else: print(f"{BOLD}Summary:{RESET} No assertions configured") print(f" {GREEN}Passed:{RESET} {passed[0]}") print(f" {RED}Failed:{RESET} {failed[0]}") if no_assertion[0] > 0: print(f" {YELLOW}No assertion:{RESET} {no_assertion[0]}") print("=" * 60 + "\n") # Return success if no failures return EvalationAck(success=(failed[0] == 0)) def _run_single_turn_eval( self, i: int, total: int, item: dict[str, Any], task: Callable[[dict[str, Any]], EvalToolResult], passed: list[int], failed: list[int], no_assertion: list[int], ) -> None: """Run a single-turn evaluation.""" # Build input with tool and model config eval_input = 
{ **item.get("input", {}), # Tool configuration "force_tools": item.get("force_tools", []), "expected_tools": item.get("expected_tools", []), "require_all_tools": item.get("require_all_tools", False), # Model configuration "model": item.get("model"), "model_provider": item.get("model_provider"), "temperature": item.get("temperature"), } message = eval_input.get("message", "(no message)") truncated_message = message[:50] + "..." if len(message) > 50 else message # Show model if specified model_info = "" if item.get("model"): model_info = f" [{item.get('model')}]" print(f'\n{BOLD}[{i}/{total}]{RESET} "{truncated_message}"{model_info}') try: result = task(eval_input) _display_single_turn_result(result, passed, failed, no_assertion) except Exception as e: print(f" {RED}ERROR:{RESET} {e}") failed[0] += 1 logger.exception(f"Error running eval for input: {message}") def _run_multi_turn_eval( self, i: int, total: int, item: dict[str, Any], multi_turn_task: Callable[[dict[str, Any]], MultiTurnEvalResult] | None, passed: list[int], failed: list[int], no_assertion: list[int], ) -> None: """Run a multi-turn evaluation.""" if multi_turn_task is None: print( f"\n{BOLD}[{i}/{total}]{RESET} {RED}ERROR:{RESET} Multi-turn task not configured" ) failed[0] += 1 return input_data = item.get("input", {}) messages = input_data.get("messages", []) num_turns = len(messages) # Show first message as preview first_msg = ( messages[0].get("message", "(no message)") if messages else "(no messages)" ) truncated_first = first_msg[:40] + "..." if len(first_msg) > 40 else first_msg print(f"\n{BOLD}[{i}/{total}] Multi-turn ({num_turns} turns){RESET}") print(f' First: "{truncated_first}"') try: # Pass the full input with messages eval_input = {**input_data} result = multi_turn_task(eval_input) # Display each turn's result for turn_idx, turn_result in enumerate(result.turn_results): turn_msg = messages[turn_idx].get("message", "") truncated_turn = ( turn_msg[:40] + "..." 
if len(turn_msg) > 40 else turn_msg ) print(f'\n {DIM}Turn {turn_idx + 1}:{RESET} "{truncated_turn}"') _display_single_turn_result(turn_result, passed, failed, no_assertion) # Show multi-turn summary status = ( f"{GREEN}ALL PASSED{RESET}" if result.all_passed else f"{RED}SOME FAILED{RESET}" ) print( f"\n {BOLD}Multi-turn result:{RESET} {status} ({result.pass_count}/{result.total_turns} turns passed)" ) except Exception as e: print(f" {RED}ERROR:{RESET} {e}") failed[0] += 1 logger.exception(f"Error running multi-turn eval: {first_msg}") ================================================ FILE: backend/onyx/feature_flags/__init__.py ================================================ ================================================ FILE: backend/onyx/feature_flags/factory.py ================================================ from onyx.configs.app_configs import DEV_MODE from onyx.feature_flags.interface import FeatureFlagProvider from onyx.feature_flags.interface import NoOpFeatureFlagProvider from onyx.utils.variable_functionality import ( fetch_versioned_implementation_with_fallback, ) from shared_configs.configs import MULTI_TENANT def get_default_feature_flag_provider() -> FeatureFlagProvider: """ Get the default feature flag provider implementation. In multi-tenant or dev mode, returns the PostHog-based provider from the Enterprise Edition when it can be loaded; otherwise returns a NoOpFeatureFlagProvider, which disables every flag unless ENVIRONMENT is "local". This function is designed for dependency injection: callers should use this factory rather than directly instantiating providers.
Returns: FeatureFlagProvider: The configured feature flag provider instance """ if MULTI_TENANT or DEV_MODE: return fetch_versioned_implementation_with_fallback( module="onyx.feature_flags.factory", attribute="get_posthog_feature_flag_provider", fallback=lambda: NoOpFeatureFlagProvider(), )() return NoOpFeatureFlagProvider() ================================================ FILE: backend/onyx/feature_flags/feature_flags_keys.py ================================================ """ Feature flag keys used throughout the application. Centralizes feature flag key definitions to avoid magic strings. """ ================================================ FILE: backend/onyx/feature_flags/flags.py ================================================ ================================================ FILE: backend/onyx/feature_flags/interface.py ================================================ import abc from typing import Any from uuid import UUID from onyx.db.models import User from shared_configs.configs import ENVIRONMENT class FeatureFlagProvider(abc.ABC): """ Abstract base class for feature flag providers. Implementations should provide vendor-specific logic for checking whether a feature flag is enabled for a given user. """ @abc.abstractmethod def feature_enabled( self, flag_key: str, user_id: UUID, user_properties: dict[str, Any] | None = None, ) -> bool: """ Check if a feature flag is enabled for a user. Args: flag_key: The identifier for the feature flag to check user_id: The unique identifier for the user user_properties: Optional dictionary of user properties/attributes that may influence flag evaluation Returns: True if the feature is enabled for the user, False otherwise """ raise NotImplementedError def feature_enabled_for_user_tenant( self, flag_key: str, user: User, tenant_id: str ) -> bool: """ Check if a feature flag is enabled for a user within the given tenant, passing the tenant ID and the user's email as user properties.
""" return self.feature_enabled( flag_key, # For anonymous/unauthenticated users, use a fixed UUID as fallback user.id if user else UUID("caa1e0cd-6ee6-4550-b1ec-8affaef4bf83"), user_properties={ "tenant_id": tenant_id, "email": user.email if user else "anonymous@onyx.app", }, ) class NoOpFeatureFlagProvider(FeatureFlagProvider): """ No-operation feature flag provider. Returns False for every flag, except when ENVIRONMENT is "local", where every flag is reported as enabled. Used as a fallback when no real feature flag provider is available (e.g., in the MIT version without PostHog). """ def feature_enabled( self, flag_key: str, # noqa: ARG002 user_id: UUID, # noqa: ARG002 user_properties: dict[str, Any] | None = None, # noqa: ARG002 ) -> bool: # Enable all flags in local environments; disable them everywhere else return ENVIRONMENT == "local" ================================================ FILE: backend/onyx/federated_connectors/__init__.py ================================================ ================================================ FILE: backend/onyx/federated_connectors/factory.py ================================================ """Factory for creating federated connector instances.""" import importlib from typing import Any from typing import Type from onyx.configs.constants import FederatedConnectorSource from onyx.federated_connectors.interfaces import FederatedConnector from onyx.federated_connectors.registry import FEDERATED_CONNECTOR_CLASS_MAP from onyx.utils.logger import setup_logger logger = setup_logger() class FederatedConnectorMissingException(Exception): pass # Cache for already imported federated connector classes _federated_connector_cache: dict[FederatedConnectorSource, Type[FederatedConnector]] = ( {} ) def _load_federated_connector_class( source: FederatedConnectorSource, ) -> Type[FederatedConnector]: """Dynamically load and cache a federated connector class.""" if source in _federated_connector_cache: return _federated_connector_cache[source] if source not in 
FEDERATED_CONNECTOR_CLASS_MAP: raise FederatedConnectorMissingException( f"Federated
connector not found for source={source}" ) mapping = FEDERATED_CONNECTOR_CLASS_MAP[source] try: module = importlib.import_module(mapping.module_path) connector_class = getattr(module, mapping.class_name) _federated_connector_cache[source] = connector_class return connector_class except (ImportError, AttributeError) as e: raise FederatedConnectorMissingException( f"Failed to import {mapping.class_name} from {mapping.module_path}: {e}" ) def get_federated_connector( source: FederatedConnectorSource, credentials: dict[str, Any], ) -> FederatedConnector: """Get an instance of the appropriate federated connector.""" connector_cls = get_federated_connector_cls(source) return connector_cls(credentials) def get_federated_connector_cls( source: FederatedConnectorSource, ) -> Type[FederatedConnector]: """Get the class of the appropriate federated connector.""" return _load_federated_connector_class(source) ================================================ FILE: backend/onyx/federated_connectors/federated_retrieval.py ================================================ from collections import defaultdict from collections.abc import Callable from typing import Any from uuid import UUID from pydantic import BaseModel from pydantic import ConfigDict from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.configs.constants import FederatedConnectorSource from onyx.context.search.models import ChunkIndexRequest from onyx.context.search.models import InferenceChunk from onyx.db.federated import ( get_federated_connector_document_set_mappings_by_document_set_names, ) from onyx.db.federated import list_federated_connector_oauth_tokens from onyx.db.models import FederatedConnector__DocumentSet from onyx.db.slack_bot import fetch_slack_bots from onyx.federated_connectors.factory import get_federated_connector from onyx.federated_connectors.interfaces import FederatedConnector from onyx.onyxbot.slack.models import SlackContext from onyx.utils.logger 
import setup_logger logger = setup_logger() class FederatedRetrievalInfo(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) retrieval_function: Callable[[ChunkIndexRequest], list[InferenceChunk]] source: FederatedConnectorSource def get_federated_retrieval_functions( db_session: Session, user_id: UUID | None, source_types: list[DocumentSource] | None, document_set_names: list[str] | None, slack_context: SlackContext | None = None, ) -> list[FederatedRetrievalInfo]: # Check for Slack bot context first (regardless of user_id) if slack_context: logger.debug("Slack context detected, checking for Slack bot setup...") # Slack federated search requires a Slack federated connector to be linked # via document sets. If no document sets are provided, skip Slack federated search. if not document_set_names: logger.debug( "Skipping Slack federated search: no document sets provided, " "Slack federated connector must be linked via document sets" ) return [] # Check if any Slack federated connector is associated with the document sets # and extract its config (entities) for channel filtering slack_federated_connector_config: dict[str, Any] | None = None slack_federated_mappings = ( get_federated_connector_document_set_mappings_by_document_set_names( db_session, document_set_names ) ) for mapping in slack_federated_mappings: if ( mapping.federated_connector is not None and mapping.federated_connector.source == FederatedConnectorSource.FEDERATED_SLACK ): slack_federated_connector_config = ( mapping.federated_connector.config or {} ) logger.debug( f"Found Slack federated connector config: {slack_federated_connector_config}" ) break if slack_federated_connector_config is None: logger.debug( f"Skipping Slack federated search: document sets {document_set_names} " "are not associated with any Slack federated connector" ) # Return empty list - no Slack federated search for this context return [] try: slack_bots = fetch_slack_bots(db_session) logger.debug(f"Found 
{len(slack_bots)} Slack bots") # First try to find a bot with user token tenant_slack_bot = next( (bot for bot in slack_bots if bot.enabled and bot.user_token), None ) if tenant_slack_bot: logger.debug(f"Selected bot with user_token: {tenant_slack_bot.name}") else: # Fall back to any enabled bot without user token tenant_slack_bot = next( (bot for bot in slack_bots if bot.enabled), None ) if tenant_slack_bot: logger.debug( f"Selected bot without user_token: {tenant_slack_bot.name} (limited functionality)" ) else: logger.warning("No enabled Slack bots found") if tenant_slack_bot: federated_retrieval_infos_slack = [] # Use user_token if available, otherwise fall back to bot_token # Unwrap SensitiveValue for backend API calls access_token = ( tenant_slack_bot.user_token.get_value(apply_mask=False) if tenant_slack_bot.user_token else ( tenant_slack_bot.bot_token.get_value(apply_mask=False) if tenant_slack_bot.bot_token else "" ) ) if not tenant_slack_bot.user_token: logger.warning( f"Using bot_token for Slack search (limited functionality): {tenant_slack_bot.name}" ) # For bot context, we don't need real OAuth credentials credentials = { "client_id": "bot-context", # Placeholder for bot context "client_secret": "bot-context", # Placeholder for bot context } # Create Slack federated connector connector = get_federated_connector( FederatedConnectorSource.FEDERATED_SLACK, credentials, ) # Capture variables by value to avoid lambda closure issues # Unwrap SensitiveValue for backend API calls bot_token = ( tenant_slack_bot.bot_token.get_value(apply_mask=False) if tenant_slack_bot.bot_token else "" ) # Use connector config for channel filtering (guaranteed to exist at this point) connector_entities = slack_federated_connector_config logger.debug( f"Using Slack federated connector entities for bot context: {connector_entities}" ) def create_slack_retrieval_function( conn: FederatedConnector, token: str, ctx: SlackContext, bot_tok: str, entities: dict[str, Any], ) -> 
Callable[[ChunkIndexRequest], list[InferenceChunk]]: def retrieval_fn(query: ChunkIndexRequest) -> list[InferenceChunk]: return conn.search( query, entities, # Use connector-level entities for channel filtering access_token=token, limit=None, # Let connector use its own max_messages_per_query config slack_event_context=ctx, bot_token=bot_tok, ) return retrieval_fn federated_retrieval_infos_slack.append( FederatedRetrievalInfo( retrieval_function=create_slack_retrieval_function( connector, access_token, slack_context, bot_token, connector_entities, ), source=FederatedConnectorSource.FEDERATED_SLACK, ) ) logger.debug( f"Added Slack federated search for bot, returning {len(federated_retrieval_infos_slack)} retrieval functions" ) return federated_retrieval_infos_slack except Exception as e: logger.warning(f"Could not set up Slack bot federated search: {e}") # Fall through to regular federated connector logic if user_id is None: # Without a user ID the per-user OAuth path below cannot run, so return empty logger.warning( "No user ID provided and no Slack context, returning empty retrieval functions" ) return [] federated_connector__document_set_pairs = ( ( get_federated_connector_document_set_mappings_by_document_set_names( db_session, document_set_names ) ) if document_set_names else [] ) federated_connector_id_to_document_sets: dict[ int, list[FederatedConnector__DocumentSet] ] = defaultdict(list) for pair in federated_connector__document_set_pairs: federated_connector_id_to_document_sets[pair.federated_connector_id].append( pair ) # user_id cannot be None here because of the early return above; this assert narrows the type assert user_id is not None # If no source types are specified, don't use any federated connectors if source_types is None: logger.debug("No source types specified, skipping all federated connectors") return [] federated_retrieval_infos: list[FederatedRetrievalInfo] = [] federated_oauth_tokens = 
list_federated_connector_oauth_tokens(db_session, user_id) for oauth_token in
federated_oauth_tokens: # Slack is handled separately inside SearchTool if ( oauth_token.federated_connector.source == FederatedConnectorSource.FEDERATED_SLACK ): logger.debug( "Skipping Slack federated connector in user OAuth path - handled by SearchTool" ) continue if ( oauth_token.federated_connector.source.to_non_federated_source() not in source_types ): continue document_set_associations = federated_connector_id_to_document_sets[ oauth_token.federated_connector_id ] # if document set names are specified by the user, skip federated connectors that are # not associated with any of the document sets if document_set_names and not document_set_associations: continue # Only use connector-level config (no junction table entities) entities = oauth_token.federated_connector.config or {} connector = get_federated_connector( oauth_token.federated_connector.source, oauth_token.federated_connector.credentials.get_value(apply_mask=False), ) # Capture variables by value to avoid lambda closure issues access_token = oauth_token.token.get_value(apply_mask=False) def create_retrieval_function( conn: FederatedConnector, ent: dict[str, Any], token: str, ) -> Callable[[ChunkIndexRequest], list[InferenceChunk]]: return lambda query: conn.search( query, ent, access_token=token, limit=None, # Let connector use its own max_messages_per_query config ) federated_retrieval_infos.append( FederatedRetrievalInfo( retrieval_function=create_retrieval_function( connector, entities, access_token ), source=oauth_token.federated_connector.source, ) ) return federated_retrieval_infos ================================================ FILE: backend/onyx/federated_connectors/interfaces.py ================================================ from abc import ABC from abc import abstractmethod from typing import Any from typing import Dict from onyx.context.search.models import ChunkIndexRequest from onyx.context.search.models import InferenceChunk from onyx.federated_connectors.models import CredentialField 
from onyx.federated_connectors.models import EntityField from onyx.federated_connectors.models import OAuthResult from onyx.onyxbot.slack.models import SlackContext class FederatedConnector(ABC): """Base interface that all federated connectors must implement.""" @abstractmethod def __init__(self, credentials: dict[str, Any]): """ Initialize the connector with credentials + validate their structure. Args: credentials: Dictionary of credentials to initialize the connector with """ self.credentials = credentials @abstractmethod def validate_entities(self, entities: Dict[str, Any]) -> bool: """ Validate that the provided entities match the expected structure. Args: entities: Dictionary of entities to validate Returns: True if entities are valid, False otherwise Note: This method is used for backward compatibility with document-set level entities. For connector-level config validation, use validate_config() instead. """ def validate_config(self, config: Dict[str, Any]) -> bool: """ Validate that the provided config matches the expected structure. This is an alias for validate_entities() to provide clearer semantics when validating connector-level configuration. Args: config: Dictionary of configuration to validate Returns: True if config is valid, False otherwise """ return self.validate_entities(config) @classmethod @abstractmethod def configuration_schema(cls) -> Dict[str, EntityField]: """ Return the specification of what configuration fields are available for this connector. Returns: Dictionary where keys are configuration field names and values are EntityField objects describing the expected structure and constraints. """ @classmethod @abstractmethod def credentials_schema(cls) -> Dict[str, CredentialField]: """ Return the specification of what credentials are required for this connector. Returns: Dictionary where keys are credential field names and values are CredentialField objects describing the expected structure, validation rules, and security properties. 
""" @abstractmethod def authorize(self, redirect_uri: str) -> str: """ Generate the OAuth authorization URL. Args: redirect_uri: The OAuth redirect URI the provider should send the user back to after authorization Returns: The URL where users should be redirected to authorize the application """ @abstractmethod def callback(self, callback_data: Dict[str, Any], redirect_uri: str) -> OAuthResult: """ Handle the OAuth callback and exchange the authorization code for tokens. Args: callback_data: The data received from the OAuth callback (query params, etc.) redirect_uri: The OAuth redirect URI used in the authorization request Returns: Standardized OAuthResult containing tokens and metadata """ @abstractmethod def search( self, query: ChunkIndexRequest, entities: dict[str, Any], access_token: str, limit: int | None = None, # Slack-specific parameters slack_event_context: SlackContext | None = None, bot_token: str | None = None, ) -> list[InferenceChunk]: """ Perform a federated search using the provided query and entities. Args: query: The search query entities: Connector-level config (entity filtering configuration) access_token: The OAuth access token limit: Maximum number of results to return slack_event_context: Slack-specific context (only used by Slack bot) bot_token: Slack bot token (only used by Slack bot) Returns: Matching results as a list of InferenceChunk objects """ ================================================ FILE: backend/onyx/federated_connectors/models.py ================================================ from datetime import datetime from typing import Any from typing import Dict from typing import Optional from pydantic import BaseModel from pydantic import Field class FieldSpec(BaseModel): """Model for describing a field specification.""" type: str = Field( ..., description="The type of the field (e.g., 'str', 'bool', 'list[str]')" ) description: str = Field( ..., description="Description of what this field represents" ) required: bool = Field(default=False, description="Whether this field is required") default: Optional[Any] = Field( default=None,
description="Default value if not provided" ) example: Optional[Any] = Field( default=None, description="Example value for documentation" ) secret: bool = Field( default=False, description="Whether this field contains sensitive data" ) class EntityField(FieldSpec): """Model for describing an entity field in the entities specification.""" class CredentialField(FieldSpec): """Model for describing a credential field in the credentials specification.""" class OAuthResult(BaseModel): """Standardized OAuth result that all federated connectors should return from callback.""" access_token: Optional[str] = Field( default=None, description="The bot access token for bot operations" ) user_token: Optional[str] = Field( default=None, description="The user access token for user-scoped operations like federated search", ) token_type: Optional[str] = Field( default=None, description="Token type (usually 'bearer')" ) scope: Optional[str] = Field(default=None, description="Granted scopes") expires_at: Optional[datetime] = Field( default=None, description="When the token expires" ) refresh_token: Optional[str] = Field( default=None, description="Refresh token if applicable" ) # Additional fields that might be useful team: Optional[Dict[str, Any]] = Field( default=None, description="Team/workspace information" ) user: Optional[Dict[str, Any]] = Field(default=None, description="User information") raw_response: Optional[Dict[str, Any]] = Field( default=None, description="Raw response for debugging" ) # Pydantic V2 automatically serializes datetime to ISO format, so no custom encoder needed ================================================ FILE: backend/onyx/federated_connectors/oauth_utils.py ================================================ """Generic OAuth utilities for federated connectors API layer.""" import base64 import json import uuid from typing import Any from onyx.cache.factory import get_cache_backend from onyx.configs.app_configs import WEB_DOMAIN from onyx.utils.logger 
import setup_logger logger = setup_logger() OAUTH_STATE_PREFIX = "federated_oauth" OAUTH_STATE_TTL = 300 # 5 minutes class OAuthSession: """Represents an OAuth session stored in the cache backend.""" def __init__( self, federated_connector_id: int, user_id: str, redirect_uri: str | None = None, additional_data: dict[str, Any] | None = None, ): self.federated_connector_id = federated_connector_id self.user_id = user_id self.redirect_uri = redirect_uri self.additional_data = additional_data or {} def to_dict(self) -> dict[str, Any]: return { "federated_connector_id": self.federated_connector_id, "user_id": self.user_id, "redirect_uri": self.redirect_uri, "additional_data": self.additional_data, } @classmethod def from_dict(cls, data: dict[str, Any]) -> "OAuthSession": return cls( federated_connector_id=data["federated_connector_id"], user_id=data["user_id"], redirect_uri=data.get("redirect_uri"), additional_data=data.get("additional_data", {}), ) def generate_oauth_state( federated_connector_id: int, user_id: str, redirect_uri: str | None = None, additional_data: dict[str, Any] | None = None, ttl: int = OAUTH_STATE_TTL, ) -> str: """ Generate a secure state parameter and store session data in the cache backend. 
Args: federated_connector_id: ID of the federated connector user_id: ID of the user initiating OAuth redirect_uri: Optional redirect URI after OAuth completion additional_data: Any additional data to store with the session ttl: Time-to-live in seconds for the cache key Returns: URL-safe base64-encoded state parameter (padding stripped) """ # Generate a random UUID for the state state_uuid = uuid.uuid4() state_b64 = base64.urlsafe_b64encode(state_uuid.bytes).decode("utf-8").rstrip("=") session = OAuthSession( federated_connector_id=federated_connector_id, user_id=user_id, redirect_uri=redirect_uri, additional_data=additional_data, ) cache = get_cache_backend() cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}" cache.set(cache_key, json.dumps(session.to_dict()), ex=ttl) logger.info( f"Generated OAuth state for federated_connector_id={federated_connector_id}, user_id={user_id}, state={state_b64}" ) return state_b64 def verify_oauth_state(state: str) -> OAuthSession: """ Verify the OAuth state parameter and retrieve its session data. The state is single-use: its cache entry is deleted once it has been read. Args: state: URL-safe base64-encoded state parameter from the OAuth callback Returns: The OAuthSession stored for this state Raises: ValueError: If the state is malformed or is not found in the cache (expired, already used, or never issued) """ # Add padding if needed for base64 decoding padded_state = state + "=" * (-len(state) % 4) # Decode base64 to get UUID bytes state_bytes = base64.urlsafe_b64decode(padded_state) state_uuid = uuid.UUID(bytes=state_bytes) cache = get_cache_backend() cache_key = f"{OAUTH_STATE_PREFIX}:{state_uuid}" session_data = cache.get(cache_key) if not session_data: raise ValueError(f"OAuth state not found: {state}") cache.delete(cache_key) session_dict = json.loads(session_data) return OAuthSession.from_dict(session_dict) def get_oauth_callback_uri() -> str: """ Return the OAuth callback URI shared by all federated connectors.
Returns: The callback URI """ # Use the frontend callback page as the OAuth redirect URI # The frontend will then make an API call to process the callback return f"{WEB_DOMAIN}/federated/oauth/callback" def add_state_to_oauth_url(base_oauth_url: str, state: str) -> str: """ Add state parameter to an OAuth URL. Args: base_oauth_url: The base OAuth URL from the connector state: The state parameter to add Returns: The OAuth URL with state parameter added """ # Check if URL already has query parameters separator = "&" if "?" in base_oauth_url else "?" return f"{base_oauth_url}{separator}state={state}" ================================================ FILE: backend/onyx/federated_connectors/registry.py ================================================ """Registry mapping for federated connector classes.""" from pydantic import BaseModel from onyx.configs.constants import FederatedConnectorSource class FederatedConnectorMapping(BaseModel): module_path: str class_name: str # Mapping of FederatedConnectorSource to connector details for lazy loading FEDERATED_CONNECTOR_CLASS_MAP = { FederatedConnectorSource.FEDERATED_SLACK: FederatedConnectorMapping( module_path="onyx.federated_connectors.slack.federated_connector", class_name="SlackFederatedConnector", ), } ================================================ FILE: backend/onyx/federated_connectors/slack/__init__.py ================================================ # Slack federated connector module ================================================ FILE: backend/onyx/federated_connectors/slack/federated_connector.py ================================================ from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Any from urllib.parse import urlencode import requests from pydantic import ValidationError from slack_sdk import WebClient from typing_extensions import override from onyx.context.search.federated.slack_search import slack_retrieval from 
onyx.context.search.models import ChunkIndexRequest from onyx.context.search.models import InferenceChunk from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.federated_connectors.interfaces import FederatedConnector from onyx.federated_connectors.models import CredentialField from onyx.federated_connectors.models import EntityField from onyx.federated_connectors.models import OAuthResult from onyx.federated_connectors.slack.models import SlackCredentials from onyx.federated_connectors.slack.models import SlackEntities from onyx.onyxbot.slack.models import SlackContext from onyx.utils.logger import setup_logger logger = setup_logger() SCOPES = [ "channels:read", "groups:read", "im:read", "mpim:read", "search:read", "channels:history", "groups:history", "im:history", "mpim:history", "users:read", "users.profile:read", ] class SlackFederatedConnector(FederatedConnector): def __init__(self, credentials: dict[str, Any]): self.slack_credentials = SlackCredentials(**credentials) @override def validate_entities(self, entities: dict[str, Any]) -> bool: """Check the entities and verify that they match the expected structure/all values are valid. For Slack federated search, we expect: - channels: list[str] (list of channel names or IDs) - include_dm: bool (whether to include direct messages) """ try: # Use Pydantic model for validation SlackEntities(**entities) return True except ValidationError as e: logger.warning(f"Validation error for Slack entities: {e}") return False except Exception as e: logger.error(f"Error validating Slack entities: {e}") return False @classmethod def entities_schema(cls) -> dict[str, EntityField]: """Return the specifications of what entity configuration fields are available for Slack. This is the canonical schema definition for Slack entities. """ return { "exclude_channels": EntityField( type="list[str]", description="Exclude the following channels from search. 
Glob patterns are supported.", required=False, example=["secure-channel", "private-*", "customer*"], ), "search_all_channels": EntityField( type="bool", description="Search all accessible channels. If not set, must specify channels below.", required=False, default=False, example=False, ), "channels": EntityField( type="list[str]", description="Search the following channels", required=False, example=["general", "eng*", "product-*"], ), "include_dm": EntityField( type="bool", description="Include user direct messages in search results", required=False, default=False, example=False, ), "include_group_dm": EntityField( type="bool", description="Include group direct messages (multi-person DMs) in search results", required=False, default=False, example=False, ), "include_private_channels": EntityField( type="bool", description="Include private channels in search results (user must have access)", required=False, default=False, example=False, ), "default_search_days": EntityField( type="int", description="Maximum number of days to search back. Increasing this value degrades answer quality.", required=False, default=30, example=30, ), "max_messages_per_query": EntityField( type="int", description=( "Maximum number of messages to retrieve per search query. " "Higher values provide more context but may be slower." 
), required=False, default=25, example=25, ), } @classmethod def configuration_schema(cls) -> dict[str, EntityField]: """Wrapper for backwards compatibility - delegates to entities_schema().""" return cls.entities_schema() @classmethod @override def credentials_schema(cls) -> dict[str, CredentialField]: """Return the specification of what credentials are required for Slack connector.""" return { "client_id": CredentialField( type="str", description="Slack app client ID from your Slack app configuration", required=True, example="1234567890.1234567890123", secret=False, ), "client_secret": CredentialField( type="str", description="Slack app client secret from your Slack app configuration", required=True, example="1a2b3c4d5e6f7g8h9i0j1k2l3m4n5o6p", secret=True, ), } @override def authorize(self, redirect_uri: str) -> str: """Get back the OAuth URL for Slack authorization. Returns the URL where users should be redirected to authorize the application. Note: State parameter will be added by the API layer. """ # Build OAuth URL with proper parameters (no state - handled by API layer) params = { "client_id": self.slack_credentials.client_id, "user_scope": " ".join(SCOPES), "redirect_uri": redirect_uri, } # Build query string oauth_url = f"https://slack.com/oauth/v2/authorize?{urlencode(params)}" logger.info("Generated Slack OAuth authorization URL") return oauth_url @override def callback(self, callback_data: dict[str, Any], redirect_uri: str) -> OAuthResult: """Handle the response from the OAuth flow and return it in a standard format. 
Args: callback_data: The data received from the OAuth callback (state already validated by API layer) redirect_uri: The OAuth redirect URI used in the authorization request Returns: Standardized OAuthResult Raises: RuntimeError: If Slack reports an OAuth error, the token exchange fails, or authed_user is missing from the response ValueError: If no authorization code was received """ # Extract authorization code from callback auth_code = callback_data.get("code") error = callback_data.get("error") if error: raise RuntimeError(f"OAuth error received: {error}") if not auth_code: raise ValueError("No authorization code received") # Exchange authorization code for access token token_response = self._exchange_code_for_token(auth_code, redirect_uri) if not token_response.get("ok"): raise RuntimeError( f"Failed to exchange authorization code for token: {token_response.get('error')}" ) # Build team info team_info = None if "team" in token_response: team_info = { "id": token_response["team"]["id"], "name": token_response["team"]["name"], } # Build user info and extract OAuth tokens if "authed_user" not in token_response: raise RuntimeError("Missing authed_user in OAuth response from Slack") authed_user = token_response["authed_user"] user_info = { "id": authed_user["id"], "scope": authed_user.get("scope"), "token_type": authed_user.get("token_type"), } # Extract the user-scoped OAuth tokens from authed_user (authorize() requests only user scopes) user_token = authed_user.get("access_token") # User token refresh_token = authed_user.get("refresh_token") token_type = authed_user.get("token_type", "bearer") scope = authed_user.get("scope") # Calculate expires_at from expires_in if present expires_at = None if "expires_in" in authed_user: expires_at = datetime.now(timezone.utc) + timedelta( seconds=authed_user["expires_in"] ) return OAuthResult( access_token=user_token, # User token; only user scopes are requested token_type=token_type, scope=scope, expires_at=expires_at, refresh_token=refresh_token, team=team_info, user=user_info, raw_response=token_response, ) def _exchange_code_for_token(self, code: str, redirect_uri: str) -> dict[str, Any]: """Exchange authorization code for access token.
Args: code: Authorization code from OAuth callback redirect_uri: The redirect URI used in the authorization request (Slack requires it to match) Returns: Parsed JSON token response from Slack's oauth.v2.access endpoint """ response = requests.post( "https://slack.com/api/oauth.v2.access", data={ "client_id": self.slack_credentials.client_id, "client_secret": self.slack_credentials.client_secret, "code": code, "redirect_uri": redirect_uri, }, ) response.raise_for_status() return response.json() @override def search( self, query: ChunkIndexRequest, entities: dict[str, Any], access_token: str, limit: int | None = None, slack_event_context: SlackContext | None = None, bot_token: str | None = None, ) -> list[InferenceChunk]: """Perform a federated search on Slack. Args: query: The search query entities: Connector-level config (entity filtering configuration) access_token: The OAuth access token limit: Maximum number of results to return slack_event_context: Optional Slack context for the Slack bot bot_token: Optional bot token for the Slack bot Returns: Matching Slack messages as a list of InferenceChunk objects """ logger.debug(f"Slack federated search called with entities: {entities}") # Get team_id from Slack API for caching and filtering team_id = None try: slack_client = WebClient(token=access_token) auth_response = slack_client.auth_test() auth_response.validate() # Cast response.data to dict for type checking auth_data: dict[str, Any] = auth_response.data # type: ignore team_id = auth_data.get("team_id") logger.debug(f"Slack team_id: {team_id}") except Exception as e: logger.warning(f"Could not fetch team_id from Slack API: {e}") with get_session_with_current_tenant() as db_session: return slack_retrieval( query, access_token, db_session, entities=entities, limit=limit, slack_event_context=slack_event_context, bot_token=bot_token, team_id=team_id, ) ================================================ FILE: backend/onyx/federated_connectors/slack/models.py ================================================ from typing import Optional from pydantic import BaseModel from pydantic import Field from pydantic import
field_validator from pydantic import model_validator class SlackEntities(BaseModel): """Pydantic model for Slack federated search entities.""" # Channel filtering search_all_channels: bool = Field( default=True, description="Search all accessible channels. If not set, must specify channels below.", ) channels: Optional[list[str]] = Field( default=None, description="List of Slack channel names to search across.", ) exclude_channels: Optional[list[str]] = Field( default=None, description="List of channel names or patterns to exclude e.g. 'private-*, customer-*, secure-channel'.", ) # Direct message filtering include_dm: bool = Field( default=True, description="Include user direct messages in search results", ) include_group_dm: bool = Field( default=True, description="Include group direct messages (multi-person DMs) in search results", ) # Private channel filtering include_private_channels: bool = Field( default=True, description="Include private channels in search results (user must have access)", ) # Date range filtering default_search_days: int = Field( default=30, description="Maximum number of days to search back. Increasing this value degrades answer quality.", ) # Message count per slack request max_messages_per_query: int = Field( default=10, description=( "Maximum number of messages to retrieve per search query. " "Higher values increase API calls and may trigger rate limits." 
), ) @field_validator("default_search_days") @classmethod def validate_default_search_days(cls, v: int) -> int: """Validate default_search_days is positive and reasonable""" if v < 1: raise ValueError("default_search_days must be at least 1") if v > 365: raise ValueError("default_search_days cannot exceed 365 days") return v @field_validator("max_messages_per_query") @classmethod def validate_max_messages_per_query(cls, v: int) -> int: """Validate max_messages_per_query is positive and reasonable""" if v < 1: raise ValueError("max_messages_per_query must be at least 1") if v > 100: raise ValueError("max_messages_per_query cannot exceed 100") return v @field_validator("channels") @classmethod def validate_channels(cls, v: Optional[list[str]]) -> Optional[list[str]]: """Validate each channel is a non-empty string""" if v is not None: if not isinstance(v, list): raise ValueError("channels must be a list") for channel in v: if not isinstance(channel, str) or not channel.strip(): raise ValueError("Each channel must be a non-empty string") return v @field_validator("exclude_channels") @classmethod def validate_exclude_patterns(cls, v: Optional[list[str]]) -> Optional[list[str]]: """Validate each exclude pattern is a non-empty string""" if v is None: return v for pattern in v: if not isinstance(pattern, str) or not pattern.strip(): raise ValueError("Each exclude pattern must be a non-empty string") return v @model_validator(mode="after") def validate_channel_config(self) -> "SlackEntities": """Validate search_all_channels configuration""" # If search_all_channels is False, channels list must be provided if not self.search_all_channels: if self.channels is None or len(self.channels) == 0: raise ValueError( "Must specify at least one channel when search_all_channels is False" ) return self class SlackCredentials(BaseModel): """Slack federated connector credentials.""" client_id: str = Field(..., description="Slack app client ID") client_secret: str = Field(..., 
description="Slack app client secret") @field_validator("client_id") @classmethod def validate_client_id(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("Client ID cannot be empty") return v.strip() @field_validator("client_secret") @classmethod def validate_client_secret(cls, v: str) -> str: if not v or not v.strip(): raise ValueError("Client secret cannot be empty") return v.strip() class SlackTeamInfo(BaseModel): """Information about a Slack team/workspace.""" id: str = Field(..., description="Team ID") name: str = Field(..., description="Team name") domain: Optional[str] = Field(default=None, description="Team domain") class SlackUserInfo(BaseModel): """Information about a Slack user.""" id: str = Field(..., description="User ID") team_id: Optional[str] = Field(default=None, description="Team ID") name: Optional[str] = Field(default=None, description="User name") email: Optional[str] = Field(default=None, description="User email") class SlackSearchResult(BaseModel): """Individual search result from Slack.""" channel: str = Field(..., description="Channel where the message was found") timestamp: str = Field(..., description="Message timestamp") user: Optional[str] = Field(default=None, description="User who sent the message") text: str = Field(..., description="Message text") permalink: Optional[str] = Field( default=None, description="Permalink to the message" ) score: Optional[float] = Field(default=None, description="Search relevance score") # Additional context thread_ts: Optional[str] = Field( default=None, description="Thread timestamp if in a thread" ) reply_count: Optional[int] = Field( default=None, description="Number of replies if it's a thread" ) class SlackSearchResponse(BaseModel): """Response from Slack federated search.""" query: str = Field(..., description="The search query") total_count: int = Field(..., description="Total number of results") results: list[SlackSearchResult] = Field(..., description="Search results") 
next_cursor: Optional[str] = Field( default=None, description="Cursor for pagination" ) # Metadata channels_searched: Optional[list[str]] = Field( default=None, description="Channels that were searched" ) search_time_ms: Optional[int] = Field( default=None, description="Time taken to search in milliseconds" ) ================================================ FILE: backend/onyx/file_processing/__init__.py ================================================ ================================================ FILE: backend/onyx/file_processing/enums.py ================================================ from enum import Enum class HtmlBasedConnectorTransformLinksStrategy(str, Enum): # remove links entirely STRIP = "strip" # turn HTML links into markdown links MARKDOWN = "markdown" ================================================ FILE: backend/onyx/file_processing/extract_file_text.py ================================================ import csv import gc import io import json import os import re import zipfile from collections.abc import Callable from collections.abc import Iterator from collections.abc import Sequence from email.parser import Parser as EmailParser from io import BytesIO from pathlib import Path from typing import Any from typing import IO from typing import NamedTuple from typing import Optional from typing import TYPE_CHECKING from zipfile import BadZipFile import chardet import openpyxl from openpyxl.worksheet.worksheet import Worksheet from PIL import Image from onyx.configs.constants import ONYX_METADATA_FILENAME from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled from onyx.file_processing.file_types import OnyxFileExtensions from onyx.file_processing.file_types import OnyxMimeTypes from onyx.file_processing.file_types import PRESENTATION_MIME_TYPE from onyx.file_processing.file_types import WORD_PROCESSING_MIME_TYPE from onyx.file_processing.html_utils import parse_html_page_basic from onyx.file_processing.unstructured import 
get_unstructured_api_key from onyx.file_processing.unstructured import unstructured_to_text from onyx.utils.logger import setup_logger if TYPE_CHECKING: from markitdown import MarkItDown logger = setup_logger() TEXT_SECTION_SEPARATOR = "\n\n" _MARKITDOWN_CONVERTER: Optional["MarkItDown"] = None KNOWN_OPENPYXL_BUGS = [ "Value must be either numerical or a string containing a wildcard", "File contains no valid workbook part", "Unable to read workbook: could not read stylesheet from None", "Colors must be aRGB hex values", ] def get_markitdown_converter() -> "MarkItDown": global _MARKITDOWN_CONVERTER from markitdown import MarkItDown if _MARKITDOWN_CONVERTER is None: _MARKITDOWN_CONVERTER = MarkItDown(enable_plugins=False) return _MARKITDOWN_CONVERTER def get_file_ext(file_path_or_name: str | Path) -> str: _, extension = os.path.splitext(file_path_or_name) return extension.lower() def is_text_file(file: IO[bytes]) -> bool: """ checks if the first 1024 bytes only contain printable or whitespace characters if it does, then we say it's a plaintext file """ raw_data = file.read(1024) file.seek(0) text_chars = bytearray({7, 8, 9, 10, 12, 13, 27} | set(range(0x20, 0x100)) - {0x7F}) return all(c in text_chars for c in raw_data) def detect_encoding(file: IO[bytes]) -> str: raw_data = file.read(50000) file.seek(0) encoding = chardet.detect(raw_data)["encoding"] or "utf-8" return encoding def is_macos_resource_fork_file(file_name: str) -> bool: return os.path.basename(file_name).startswith("._") and file_name.startswith( "__MACOSX" ) def to_bytesio(stream: IO[bytes]) -> BytesIO: if isinstance(stream, BytesIO): return stream data = stream.read() # consumes the stream! return BytesIO(data) def load_files_from_zip( zip_file_io: IO, ignore_macos_resource_fork_files: bool = True, ignore_dirs: bool = True, ) -> Iterator[tuple[zipfile.ZipInfo, IO[Any]]]: """ Iterates through files in a zip archive, yielding (ZipInfo, file handle) pairs. 
""" with zipfile.ZipFile(zip_file_io, "r") as zip_file: for file_info in zip_file.infolist(): if ignore_dirs and file_info.is_dir(): continue if ( ignore_macos_resource_fork_files and is_macos_resource_fork_file(file_info.filename) ) or file_info.filename == ONYX_METADATA_FILENAME: continue with zip_file.open(file_info.filename, "r") as subfile: # Try to match by exact filename first yield file_info, subfile def _extract_onyx_metadata(line: str) -> dict | None: """ Example: first line has: <!-- ONYX_METADATA={"title": "..."} --> or #ONYX_METADATA={"title":"..."} """ html_comment_pattern = r"<!--\s*ONYX_METADATA=\{(.*?)\}\s*-->" hashtag_pattern = r"#ONYX_METADATA=\{(.*?)\}" html_comment_match = re.search(html_comment_pattern, line) hashtag_match = re.search(hashtag_pattern, line) if html_comment_match: json_str = html_comment_match.group(1) elif hashtag_match: json_str = hashtag_match.group(1) else: return None try: return json.loads("{" + json_str + "}") except json.JSONDecodeError: return None def read_text_file( file: IO, encoding: str = "utf-8", errors: str = "replace", ignore_onyx_metadata: bool = True, ) -> tuple[str, dict]: """ For plain text files. Optionally extracts Onyx metadata from the first line. """ metadata = {} file_content_raw = "" for ind, line in enumerate(file): # decode try: line = line.decode(encoding) if isinstance(line, bytes) else line except UnicodeDecodeError: line = ( line.decode(encoding, errors=errors) if isinstance(line, bytes) else line ) # optionally parse metadata in the first line if ind == 0 and not ignore_onyx_metadata: potential_meta = _extract_onyx_metadata(line) if potential_meta is not None: metadata = potential_meta continue file_content_raw += line return file_content_raw, metadata def pdf_to_text(file: IO[Any], pdf_pass: str | None = None) -> str: """ Extract text from a PDF. For embedded images, a more complex approach is needed. This is a minimal approach returning text only.
""" text, _, _ = read_pdf_file(file, pdf_pass) return text def read_pdf_file( file: IO[Any], pdf_pass: str | None = None, extract_images: bool = False, image_callback: Callable[[bytes, str], None] | None = None, ) -> tuple[str, dict[str, Any], Sequence[tuple[bytes, str]]]: """ Returns the text, basic PDF metadata, and optionally extracted images. """ from pypdf import PdfReader from pypdf.errors import PdfStreamError metadata: dict[str, Any] = {} extracted_images: list[tuple[bytes, str]] = [] try: pdf_reader = PdfReader(file) if pdf_reader.is_encrypted and pdf_pass is not None: decrypt_success = False try: decrypt_success = pdf_reader.decrypt(pdf_pass) != 0 except Exception: logger.error("Unable to decrypt pdf") if not decrypt_success: return "", metadata, [] elif pdf_reader.is_encrypted: logger.warning("No Password for an encrypted PDF, returning empty text.") return "", metadata, [] # Basic PDF metadata if pdf_reader.metadata is not None: for key, value in pdf_reader.metadata.items(): clean_key = key.lstrip("/") if isinstance(value, str) and value.strip(): metadata[clean_key] = value elif isinstance(value, list) and all( isinstance(item, str) for item in value ): metadata[clean_key] = ", ".join(value) text = TEXT_SECTION_SEPARATOR.join( page.extract_text() for page in pdf_reader.pages ) if extract_images: for page_num, page in enumerate(pdf_reader.pages): for image_file_object in page.images: image = Image.open(io.BytesIO(image_file_object.data)) img_byte_arr = io.BytesIO() image.save(img_byte_arr, format=image.format) img_bytes = img_byte_arr.getvalue() image_format = image.format.lower() if image.format else "png" image_name = f"page_{page_num + 1}_image_{image_file_object.name}.{image_format}" if image_callback is not None: # Stream image out immediately image_callback(img_bytes, image_name) else: extracted_images.append((img_bytes, image_name)) return text, metadata, extracted_images except PdfStreamError: logger.exception("Invalid PDF file") except 
Exception: logger.exception("Failed to read PDF") return "", metadata, [] def extract_docx_images(docx_bytes: IO[Any]) -> Iterator[tuple[bytes, str]]: """ Given a docx file object, extract all embedded images from its word/media/ folder. Yields tuples of (image_bytes, image_name). """ try: with zipfile.ZipFile(docx_bytes) as z: for name in z.namelist(): if name.startswith("word/media/"): yield (z.read(name), name.split("/")[-1]) except Exception: logger.exception("Failed to extract all docx images") def read_docx_file( file: IO[Any], file_name: str = "", extract_images: bool = False, image_callback: Callable[[bytes, str], None] | None = None, ) -> tuple[str, Sequence[tuple[bytes, str]]]: """ Extract text from a docx as markdown. Returns (text_content, list_of_images). Images are only extracted when extract_images is True. The caller can provide image_callback to stream images one at a time instead of materializing the whole list in memory; the returned images list is empty in that case. """ md = get_markitdown_converter() from markitdown import ( StreamInfo, FileConversionException, UnsupportedFormatException, ) try: doc = md.convert( to_bytesio(file), stream_info=StreamInfo(mimetype=WORD_PROCESSING_MIME_TYPE) ) except ( BadZipFile, ValueError, FileConversionException, UnsupportedFormatException, ) as e: logger.warning( f"Failed to extract docx {file_name or 'docx file'}: {e}. Attempting to read as text file."
) # May be an invalid docx, but still a valid text file file.seek(0) encoding = detect_encoding(file) text_content_raw, _ = read_text_file( file, encoding=encoding, ignore_onyx_metadata=False ) return text_content_raw or "", [] file.seek(0) if extract_images: if image_callback is None: return doc.markdown, list(extract_docx_images(to_bytesio(file))) # If a callback is provided, iterate and stream images without accumulating try: for img_file_bytes, img_file_name in extract_docx_images(to_bytesio(file)): image_callback(img_file_bytes, img_file_name) except Exception: logger.exception("Failed to stream docx images") return doc.markdown, [] def pptx_to_text(file: IO[Any], file_name: str = "") -> str: md = get_markitdown_converter() from markitdown import ( StreamInfo, FileConversionException, UnsupportedFormatException, ) stream_info = StreamInfo( mimetype=PRESENTATION_MIME_TYPE, filename=file_name or None, extension=".pptx" ) try: presentation = md.convert(to_bytesio(file), stream_info=stream_info) except ( BadZipFile, ValueError, FileConversionException, UnsupportedFormatException, ) as e: error_str = f"Failed to extract text from {file_name or 'pptx file'}: {e}" logger.warning(error_str) return "" return presentation.markdown def _worksheet_to_matrix( worksheet: Worksheet, ) -> list[list[str]]: """ Converts a single worksheet to a matrix of string values (None cells become empty strings) """ rows: list[list[str]] = [] for worksheet_row in worksheet.iter_rows(min_row=1, values_only=True): row = ["" if cell is None else str(cell) for cell in worksheet_row] rows.append(row) return rows def _clean_worksheet_matrix(matrix: list[list[str]]) -> list[list[str]]: """ Cleans a worksheet matrix by capping runs of consecutive empty rows at MAX_EMPTY_ROWS and runs of consecutive empty columns at MAX_EMPTY_COLS. Trailing empty rows and columns are dropped. """ MAX_EMPTY_ROWS = 2 # Runs longer than this are capped to max_empty; shorter runs are preserved as-is MAX_EMPTY_COLS = 2 # Row cleanup matrix = _remove_empty_runs(matrix, max_empty=MAX_EMPTY_ROWS) if not
matrix: return matrix # Column cleanup: determine which columns to keep without transposing. num_cols = len(matrix[0]) keep_cols = _columns_to_keep(matrix, num_cols, max_empty=MAX_EMPTY_COLS) if len(keep_cols) < num_cols: matrix = [[row[c] for c in keep_cols] for row in matrix] return matrix def _columns_to_keep( matrix: list[list[str]], num_cols: int, max_empty: int ) -> list[int]: """Return the indices of columns to keep after capping runs of empty columns. Uses the same logic as ``_remove_empty_runs`` but operates on column indices so no transpose is needed. """ kept: list[int] = [] empty_buffer: list[int] = [] for col_idx in range(num_cols): col_is_empty = all(not row[col_idx] for row in matrix) if col_is_empty: empty_buffer.append(col_idx) else: kept.extend(empty_buffer[:max_empty]) kept.append(col_idx) empty_buffer = [] return kept def _remove_empty_runs( rows: list[list[str]], max_empty: int, ) -> list[list[str]]: """Caps each run of consecutive empty rows at max_empty rows; the extra empty rows of a longer run are dropped. Leading empty runs are capped to max_empty, just like interior runs. Trailing empty rows are always dropped since there is no subsequent non-empty row to flush them.
""" result: list[list[str]] = [] empty_buffer: list[list[str]] = [] for row in rows: # Check if empty if not any(row): if len(empty_buffer) < max_empty: empty_buffer.append(row) else: # Flush the buffered empty rows (at most max_empty) onto the result result.extend(empty_buffer[:max_empty]) # Add the new non-empty row result.append(row) empty_buffer = [] return result def xlsx_to_text(file: IO[Any], file_name: str = "") -> str: # TODO: switch back to this approach in a few months when markitdown # fixes its handling of excel files # md = get_markitdown_converter() # stream_info = StreamInfo( # mimetype=SPREADSHEET_MIME_TYPE, filename=file_name or None, extension=".xlsx" # ) # try: # workbook = md.convert(to_bytesio(file), stream_info=stream_info) # except ( # BadZipFile, # ValueError, # FileConversionException, # UnsupportedFormatException, # ) as e: # error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}" # if file_name.startswith("~"): # logger.debug(error_str + " (this is expected for files with ~)") # else: # logger.warning(error_str) # return "" # return workbook.markdown try: workbook = openpyxl.load_workbook(file, read_only=True) except BadZipFile as e: error_str = f"Failed to extract text from {file_name or 'xlsx file'}: {e}" if file_name.startswith("~"): logger.debug(error_str + " (this is expected for files with ~)") else: logger.warning(error_str) return "" except Exception as e: if any(s in str(e) for s in KNOWN_OPENPYXL_BUGS): logger.error( f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl.
{e}" ) return "" raise text_content = [] for sheet in workbook.worksheets: sheet_matrix = _clean_worksheet_matrix(_worksheet_to_matrix(sheet)) buf = io.StringIO() writer = csv.writer(buf, lineterminator="\n") writer.writerows(sheet_matrix) text_content.append(buf.getvalue().rstrip("\n")) return TEXT_SECTION_SEPARATOR.join(text_content) def eml_to_text(file: IO[Any]) -> str: encoding = detect_encoding(file) text_file = io.TextIOWrapper(file, encoding=encoding) parser = EmailParser() try: message = parser.parse(text_file) finally: try: # Keep underlying upload handle open for downstream consumers. raw_file = text_file.detach() except Exception as detach_error: logger.warning( f"Failed to detach TextIOWrapper for EML upload, using original file: {detach_error}" ) raw_file = file try: raw_file.seek(0) except Exception: pass text_content = [] for part in message.walk(): if part.get_content_type().startswith("text/plain"): payload = part.get_payload() if isinstance(payload, str): text_content.append(payload) elif isinstance(payload, list): text_content.extend(item for item in payload if isinstance(item, str)) else: logger.warning(f"Unexpected payload type: {type(payload)}") return TEXT_SECTION_SEPARATOR.join(text_content) def epub_to_text(file: IO[Any]) -> str: with zipfile.ZipFile(file) as epub: text_content = [] for item in epub.infolist(): if item.filename.endswith(".xhtml") or item.filename.endswith(".html"): with epub.open(item) as html_file: text_content.append(parse_html_page_basic(html_file)) return TEXT_SECTION_SEPARATOR.join(text_content) def file_io_to_text(file: IO[Any]) -> str: encoding = detect_encoding(file) file_content, _ = read_text_file(file, encoding=encoding) return file_content def extract_file_text( file: IO[Any], file_name: str, break_on_unprocessable: bool = True, extension: str | None = None, ) -> str: """ Legacy function that returns *only text*, ignoring embedded images. For backward-compatibility in code that only wants text. 
NOTE: here, ignoring means returning an empty string for files the function can't handle (such as images) when break_on_unprocessable is False. """ extension_to_function: dict[str, Callable[[IO[Any]], str]] = { ".pdf": pdf_to_text, ".docx": lambda f: read_docx_file(f, file_name)[0], # no images ".pptx": lambda f: pptx_to_text(f, file_name), ".xlsx": lambda f: xlsx_to_text(f, file_name), ".eml": eml_to_text, ".epub": epub_to_text, ".html": parse_html_page_basic, } try: if get_unstructured_api_key(): try: return unstructured_to_text(file, file_name) except Exception as unstructured_error: logger.error( f"Failed to process with Unstructured: {str(unstructured_error)}. Falling back to normal processing." ) if extension is None: extension = get_file_ext(file_name) if extension in OnyxFileExtensions.TEXT_AND_DOCUMENT_EXTENSIONS: func = extension_to_function.get(extension, file_io_to_text) file.seek(0) return func(file) # If unknown extension, maybe it's a text file file.seek(0) if is_text_file(file): return file_io_to_text(file) raise ValueError("Unknown file extension or not recognized as text data") except Exception as e: if break_on_unprocessable: raise RuntimeError( f"Failed to process file {file_name or 'Unknown'}: {str(e)}" ) from e logger.warning(f"Failed to process file {file_name or 'Unknown'}: {str(e)}") return "" class ExtractionResult(NamedTuple): """Structured result from text and image extraction from various file types.""" text_content: str embedded_images: Sequence[tuple[bytes, str]] metadata: dict[str, Any] def extract_result_from_text_file(file: IO[Any]) -> ExtractionResult: encoding = detect_encoding(file) text_content_raw, file_metadata = read_text_file( file, encoding=encoding, ignore_onyx_metadata=False ) return ExtractionResult( text_content=text_content_raw, embedded_images=[], metadata=file_metadata, ) def extract_text_and_images( file: IO[Any], file_name: str, pdf_pass: str | None = None, content_type: str | None = None, image_callback: Callable[[bytes, str], None] | None = None,
) -> ExtractionResult: """ Primary extraction function for connectors. Returns a structured extraction result with text content, embedded images, and metadata. Args: file: File-like object to extract content from. file_name: Name of the file (used to determine extension/type). pdf_pass: Optional password for encrypted PDFs. content_type: Optional MIME type override for the file. image_callback: Optional callback for streaming image extraction. When provided, embedded images are passed to this callback one at a time as (bytes, filename) instead of being accumulated in the returned ExtractionResult.embedded_images list. This is a memory optimization for large documents with many images: the caller can process/store each image immediately rather than holding all images in memory. When using a callback, ExtractionResult.embedded_images will be an empty list. Returns: ExtractionResult containing text_content, embedded_images (empty if callback used), and metadata extracted from the file. """ res = _extract_text_and_images( file, file_name, pdf_pass, content_type, image_callback ) # Force a garbage collection pass to release temporary objects; gc.collect() returns the number of unreachable objects it found unreachable = gc.collect() logger.info(f"Unreachable objects: {unreachable}") return res def _extract_text_and_images( file: IO[Any], file_name: str, pdf_pass: str | None = None, content_type: str | None = None, image_callback: Callable[[bytes, str], None] | None = None, ) -> ExtractionResult: file.seek(0) if get_unstructured_api_key(): try: text_content = unstructured_to_text(file, file_name) return ExtractionResult( text_content=text_content, embedded_images=[], metadata={} ) except Exception as e: logger.error( f"Failed to process with Unstructured: {str(e)}. Falling back to normal processing." ) file.seek(0) # Reset file pointer just in case # When we upload a document via a connector or MyDocuments, we extract and store the content of files # with content types in OnyxMimeTypes.DOCUMENT_MIME_TYPES as plain text files.
# As a result, the file name extension may differ from the original content type. # We process files with a plain text content type first to handle this scenario. if content_type in OnyxMimeTypes.TEXT_MIME_TYPES: return extract_result_from_text_file(file) # Default processing try: extension = get_file_ext(file_name) # DOCX: extract text and embedded images if extension == ".docx": text_content, images = read_docx_file( file, file_name, extract_images=True, image_callback=image_callback ) return ExtractionResult( text_content=text_content, embedded_images=images, metadata={} ) # PDF: extract text and metadata; embedded images are extracted only # when image extraction and analysis is enabled. if extension == ".pdf": text_content, pdf_metadata, images = read_pdf_file( file, pdf_pass, extract_images=get_image_extraction_and_analysis_enabled(), image_callback=image_callback, ) return ExtractionResult( text_content=text_content, embedded_images=images, metadata=pdf_metadata ) # PPTX, XLSX, EML, EPUB and HTML: extract text only; embedded images # are not extracted for these formats.
if extension == ".pptx": return ExtractionResult( text_content=pptx_to_text(file, file_name=file_name), embedded_images=[], metadata={}, ) if extension == ".xlsx": return ExtractionResult( text_content=xlsx_to_text(file, file_name=file_name), embedded_images=[], metadata={}, ) if extension == ".eml": return ExtractionResult( text_content=eml_to_text(file), embedded_images=[], metadata={} ) if extension == ".epub": return ExtractionResult( text_content=epub_to_text(file), embedded_images=[], metadata={} ) if extension == ".html": return ExtractionResult( text_content=parse_html_page_basic(file), embedded_images=[], metadata={}, ) # If we reach here and it's a recognized text extension if extension in OnyxFileExtensions.PLAIN_TEXT_EXTENSIONS: return extract_result_from_text_file(file) # If it's an image file or something else, we do not parse embedded images from them # just return empty text return ExtractionResult(text_content="", embedded_images=[], metadata={}) except Exception as e: logger.exception(f"Failed to extract text/images from {file_name}: {e}") return ExtractionResult(text_content="", embedded_images=[], metadata={}) def docx_to_txt_filename(file_path: str) -> str: return file_path.rsplit(".", 1)[0] + ".txt" ================================================ FILE: backend/onyx/file_processing/file_types.py ================================================ PRESENTATION_MIME_TYPE = ( "application/vnd.openxmlformats-officedocument.presentationml.presentation" ) SPREADSHEET_MIME_TYPE = ( "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet" ) WORD_PROCESSING_MIME_TYPE = ( "application/vnd.openxmlformats-officedocument.wordprocessingml.document" ) PDF_MIME_TYPE = "application/pdf" PLAIN_TEXT_MIME_TYPE = "text/plain" class OnyxMimeTypes: IMAGE_MIME_TYPES = {"image/jpg", "image/jpeg", "image/png", "image/webp"} CSV_MIME_TYPES = {"text/csv"} TABULAR_MIME_TYPES = CSV_MIME_TYPES | {SPREADSHEET_MIME_TYPE} TEXT_MIME_TYPES = { PLAIN_TEXT_MIME_TYPE, 
"text/markdown", "text/x-markdown", "text/x-log", "text/x-config", "text/tab-separated-values", "application/json", "application/xml", "text/xml", "application/x-yaml", "application/yaml", "text/yaml", "text/x-yaml", } DOCUMENT_MIME_TYPES = { PDF_MIME_TYPE, WORD_PROCESSING_MIME_TYPE, PRESENTATION_MIME_TYPE, "message/rfc822", "application/epub+zip", } ALLOWED_MIME_TYPES = IMAGE_MIME_TYPES.union( TEXT_MIME_TYPES, DOCUMENT_MIME_TYPES, TABULAR_MIME_TYPES ) EXCLUDED_IMAGE_TYPES = { "image/bmp", "image/tiff", "image/gif", "image/svg+xml", "image/avif", } class OnyxFileExtensions: TABULAR_EXTENSIONS = { ".csv", ".tsv", ".xlsx", } PLAIN_TEXT_EXTENSIONS = { ".txt", ".md", ".mdx", ".conf", ".log", ".json", ".csv", ".tsv", ".xml", ".yml", ".yaml", ".sql", } DOCUMENT_EXTENSIONS = { ".pdf", ".docx", ".pptx", ".xlsx", ".eml", ".epub", ".html", } IMAGE_EXTENSIONS = { ".png", ".jpg", ".jpeg", ".webp", } TEXT_AND_DOCUMENT_EXTENSIONS = PLAIN_TEXT_EXTENSIONS.union(DOCUMENT_EXTENSIONS) ALL_ALLOWED_EXTENSIONS = TEXT_AND_DOCUMENT_EXTENSIONS.union(IMAGE_EXTENSIONS) ================================================ FILE: backend/onyx/file_processing/html_utils.py ================================================ import re from copy import copy from dataclasses import dataclass from io import BytesIO from typing import IO import bs4 from onyx.configs.app_configs import HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY from onyx.configs.app_configs import PARSE_WITH_TRAFILATURA from onyx.configs.app_configs import WEB_CONNECTOR_IGNORED_CLASSES from onyx.configs.app_configs import WEB_CONNECTOR_IGNORED_ELEMENTS from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy from onyx.utils.logger import setup_logger logger = setup_logger() MINTLIFY_UNWANTED = ["sticky", "hidden"] @dataclass class ParsedHTML: title: str | None cleaned_text: str def strip_excessive_newlines_and_spaces(document: str) -> str: # collapse repeated spaces into one document = re.sub(r" +", " ", document) # 
remove trailing spaces document = re.sub(r" +[\n\r]", "\n", document) # remove repeated newlines document = re.sub(r"[\n\r]+", "\n", document) return document.strip() def strip_newlines(document: str) -> str: # HTML may contain newlines that a browser renders as plain whitespace return re.sub(r"[\n\r]+", " ", document) def format_element_text(element_text: str, link_href: str | None) -> str: element_text_no_newlines = strip_newlines(element_text) if ( not link_href or HTML_BASED_CONNECTOR_TRANSFORM_LINKS_STRATEGY == HtmlBasedConnectorTransformLinksStrategy.STRIP ): return element_text_no_newlines return f"[{element_text_no_newlines}]({link_href})" def parse_html_with_trafilatura(html_content: str) -> str: """Parse HTML content using trafilatura.""" import trafilatura # type: ignore from trafilatura.settings import use_config # type: ignore config = use_config() config.set("DEFAULT", "include_links", "True") config.set("DEFAULT", "include_tables", "True") config.set("DEFAULT", "include_images", "True") config.set("DEFAULT", "include_formatting", "True") extracted_text = trafilatura.extract(html_content, config=config) return strip_excessive_newlines_and_spaces(extracted_text) if extracted_text else "" def format_document_soup( document: bs4.BeautifulSoup, table_cell_separator: str = "\t" ) -> str: """Format HTML into a flat text document. Goals: - Newlines from within the HTML are removed (as a browser would ignore them as well). - Repeated newlines/spaces are removed (as a browser would collapse them).
- Newlines appear only before and after headings and paragraphs, or where explicit (br or pre tag) - Table rows are separated by newlines and cells by table_cell_separator (a tab by default) - List elements are separated by newlines and start with a hyphen """ text = "" list_element_start = False verbatim_output = 0 in_table = False last_added_newline = False link_href: str | None = None for e in document.descendants: verbatim_output -= 1 if isinstance(e, bs4.element.NavigableString): if isinstance(e, (bs4.element.Comment, bs4.element.Doctype)): continue element_text = e.text if in_table: # Tables are represented in natural language with rows separated by newlines, # so cell text must not contain newlines element_text = element_text.replace("\n", " ").strip() # Some tags (such as br) would otherwise become spaces; the logic further below # translates them to newlines, as a browser renders them. Drop a leading space # that directly follows such a newline. if last_added_newline and element_text.startswith(" "): element_text = element_text[1:] last_added_newline = False if element_text: content_to_add = ( element_text if verbatim_output > 0 else format_element_text(element_text, link_href) ) # Don't join separate elements without any spacing if (text and not text[-1].isspace()) and ( content_to_add and not content_to_add[0].isspace() ): text += " " text += content_to_add list_element_start = False elif isinstance(e, bs4.element.Tag): # table is standard HTML element if e.name == "table": in_table = True # tr is for rows elif e.name == "tr" and in_table: text += "\n" # td for data cell, th for header elif e.name in ["td", "th"] and in_table: text += table_cell_separator elif e.name == "/table": in_table = False elif in_table: # don't handle other cases while in table pass elif e.name == "a": href_value = e.get("href", None) # mostly for typing, having multiple hrefs is not valid HTML link_href = ( href_value[0] if isinstance(href_value, list) else href_value ) elif e.name == "/a":
link_href = None elif e.name in ["p", "div"]: if not list_element_start: text += "\n" elif e.name in ["h1", "h2", "h3", "h4"]: text += "\n" list_element_start = False last_added_newline = True elif e.name == "br": text += "\n" list_element_start = False last_added_newline = True elif e.name == "li": text += "\n- " list_element_start = True elif e.name == "pre": if verbatim_output <= 0: verbatim_output = len(list(e.childGenerator())) return strip_excessive_newlines_and_spaces(text) def parse_html_page_basic(text: str | BytesIO | IO[bytes]) -> str: soup = bs4.BeautifulSoup(text, "lxml") return format_document_soup(soup) def web_html_cleanup( page_content: str | bs4.BeautifulSoup, mintlify_cleanup_enabled: bool = True, additional_element_types_to_discard: list[str] | None = None, ) -> ParsedHTML: if isinstance(page_content, str): soup = bs4.BeautifulSoup(page_content, "lxml") else: soup = page_content title_tag = soup.find("title") title = None if title_tag and title_tag.text: title = title_tag.text title_tag.extract() # Heuristics based cleaning of elements based on css classes unwanted_classes = copy(WEB_CONNECTOR_IGNORED_CLASSES) if mintlify_cleanup_enabled: unwanted_classes.extend(MINTLIFY_UNWANTED) for undesired_element in unwanted_classes: [ tag.extract() for tag in soup.find_all( class_=lambda x: x and undesired_element in x.split() ) ] for undesired_tag in WEB_CONNECTOR_IGNORED_ELEMENTS: [tag.extract() for tag in soup.find_all(undesired_tag)] if additional_element_types_to_discard: for undesired_tag in additional_element_types_to_discard: [tag.extract() for tag in soup.find_all(undesired_tag)] soup_string = str(soup) page_text = "" if PARSE_WITH_TRAFILATURA: try: page_text = parse_html_with_trafilatura(soup_string) if not page_text: raise ValueError("Empty content returned by trafilatura.") except Exception as e: logger.info(f"Trafilatura parsing failed: {e}. 
Falling back on bs4.") page_text = format_document_soup(soup) else: page_text = format_document_soup(soup) # U+200B is a zero-width space; drop it since it carries no content cleaned_text = page_text.replace("\u200b", "") return ParsedHTML(title=title, cleaned_text=cleaned_text) ================================================ FILE: backend/onyx/file_processing/image_summarization.py ================================================ import base64 from io import BytesIO from PIL import Image from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT from onyx.llm.interfaces import LLM from onyx.llm.models import ChatCompletionMessage from onyx.llm.models import ContentPart from onyx.llm.models import ImageContentPart from onyx.llm.models import ImageUrlDetail from onyx.llm.models import SystemMessage from onyx.llm.models import TextContentPart from onyx.llm.models import UserMessage from onyx.llm.utils import llm_response_to_string from onyx.tracing.llm_utils import llm_generation_span from onyx.tracing.llm_utils import record_llm_response from onyx.utils.b64 import get_image_type_from_bytes from onyx.utils.logger import setup_logger logger = setup_logger() class UnsupportedImageFormatError(ValueError): """Raised when an image uses a MIME type unsupported by the summarization flow.""" def prepare_image_bytes(image_data: bytes) -> str: """Prepare image bytes for summarization. Resizes the image if it is larger than 20MB, then encodes it as a base64 data URL.""" image_data = _resize_image_if_needed(image_data) # encode image (base64) encoded_image = _encode_image_for_llm_prompt(image_data) return encoded_image def summarize_image_pipeline( llm: LLM, image_data: bytes, query: str | None = None, system_prompt: str | None = None, ) -> str: """Pipeline to generate a summary of an image. Resizes the image if it is larger than 20MB. Encodes the image as a base64 data URL.
Finally, uses the provided LLM to generate a textual summary of the image.""" # resize (if larger than 20MB) and encode the image encoded_image = prepare_image_bytes(image_data) summary = _summarize_image( encoded_image, llm, query, system_prompt, ) return summary def summarize_image_with_error_handling( llm: LLM | None, image_data: bytes, context_name: str, system_prompt: str = IMAGE_SUMMARIZATION_SYSTEM_PROMPT, user_prompt_template: str = IMAGE_SUMMARIZATION_USER_PROMPT, ) -> str | None: """Wrapper function that handles error cases and configuration consistently. Args: llm: The LLM with vision capabilities to use for summarization image_data: The raw image bytes context_name: Name or title of the image for context system_prompt: System prompt to use for the LLM user_prompt_template: User prompt to use (the image file name is prepended to it) Returns: The image summary text, or None if no LLM is provided or the image format is unsupported. Other summarization failures are raised. """ if llm is None: return None # Prepend the image filename to the user prompt user_prompt = ( f"The image has the file name '{context_name}'.\n{user_prompt_template}" ) try: return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt) except UnsupportedImageFormatError: magic_hex = image_data[:8].hex() if image_data else "empty" logger.info( "Skipping image summarization due to unsupported MIME type " "for %s (magic_bytes=%s, size=%d bytes)", context_name, magic_hex, len(image_data), ) return None def _summarize_image( encoded_image: str, llm: LLM, query: str | None = None, system_prompt: str | None = None, ) -> str: """Use the provided LLM (which must be multimodal) to generate a summary of an image.""" messages: list[ChatCompletionMessage] = [] if system_prompt: messages.append(SystemMessage(content=system_prompt)) content: list[ContentPart] = [] if query: content.append(TextContentPart(text=query)) content.append(ImageContentPart(image_url=ImageUrlDetail(url=encoded_image))) messages.append( UserMessage( content=content, ), ) try: # Call LLM with Braintrust tracing with
llm_generation_span( llm=llm, flow="image_summarization", input_messages=[{"type": "image_summarization_request"}], ) as span_generation: # Note: We don't include the actual image in the span input to avoid bloating traces response = llm.invoke(messages) record_llm_response(span_generation, response) summary = llm_response_to_string(response) return summary except Exception as e: # Extract structured details from LiteLLM exceptions when available, # rather than dumping the full messages payload (which contains base64 # image data and produces enormous, unreadable error logs). str_e = str(e) if len(str_e) > 512: str_e = str_e[:512] + "... (truncated)" parts = [f"Summarization failed: {type(e).__name__}: {str_e}"] status_code = getattr(e, "status_code", None) llm_provider = getattr(e, "llm_provider", None) model = getattr(e, "model", None) if status_code is not None: parts.append(f"status_code={status_code}") if llm_provider is not None: parts.append(f"llm_provider={llm_provider}") if model is not None: parts.append(f"model={model}") raise ValueError(" | ".join(parts)) from e def _encode_image_for_llm_prompt(image_data: bytes) -> str: """Prepare a data URL with the correct MIME type for the LLM message.""" try: mime_type = get_image_type_from_bytes(image_data) except ValueError as exc: raise UnsupportedImageFormatError( "Unsupported image format for summarization" ) from exc base64_encoded_data = base64.b64encode(image_data).decode("utf-8") return f"data:{mime_type};base64,{base64_encoded_data}" def _resize_image_if_needed(image_data: bytes, max_size_mb: int = 20) -> bytes: """If the image is larger than max_size_mb, downscale it to fit within 1024x1024 and re-encode it as JPEG.""" max_size_bytes = max_size_mb * 1024 * 1024 if len(image_data) > max_size_bytes: with Image.open(BytesIO(image_data)) as img: # Reduce dimensions for better size reduction img.thumbnail((1024, 1024), Image.Resampling.LANCZOS) output = BytesIO() # JPEG cannot store alpha or palette modes (e.g. RGBA, P), so convert to RGB first if img.mode not in ("RGB", "L"): img = img.convert("RGB") # Save with lower quality for compression img.save(output, format="JPEG", quality=85)
resized_data = output.getvalue() return resized_data return image_data ================================================ FILE: backend/onyx/file_processing/image_utils.py ================================================ from io import BytesIO from typing import Tuple from onyx.configs.constants import FileOrigin from onyx.connectors.models import ImageSection from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger logger = setup_logger() def store_image_and_create_section( image_data: bytes, file_id: str, display_name: str, link: str | None = None, media_type: str = "application/octet-stream", file_origin: FileOrigin = FileOrigin.OTHER, ) -> Tuple[ImageSection, str | None]: """ Stores an image in FileStore and creates an ImageSection object without summarization. Args: image_data: Raw image bytes file_id: Base identifier for the file display_name: Human-readable name for the image link: Optional link associated with the image (stored on the ImageSection) media_type: MIME type of the image file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)
Returns: Tuple containing: - ImageSection object with image reference - The file_id in FileStore (storage failures are logged and re-raised) """ # Storage logic try: file_store = get_default_file_store() file_id = file_store.save_file( content=BytesIO(image_data), display_name=display_name, file_origin=file_origin, file_type=media_type, file_id=file_id, ) except Exception as e: logger.error(f"Failed to store image: {e}") raise # Create an ImageSection with empty text (will be filled by LLM later in the pipeline) return ( ImageSection(image_file_id=file_id, link=link), file_id, ) ================================================ FILE: backend/onyx/file_processing/password_validation.py ================================================ from collections.abc import Callable from collections.abc import Generator from contextlib import contextmanager from typing import Any from typing import IO from onyx.file_processing.extract_file_text import get_file_ext from onyx.utils.logger import setup_logger logger = setup_logger() PASSWORD_PROTECTED_FILES = [ ".pdf", ".docx", ".pptx", ".xlsx", ] @contextmanager def preserve_position(file: IO[Any]) -> Generator[IO[Any], None, None]: """Seeks to the start of the file for the duration of the block, then restores the original cursor position""" pos = file.tell() try: file.seek(0) yield file finally: file.seek(pos) def is_pdf_protected(file: IO[Any]) -> bool: from pypdf import PdfReader with preserve_position(file): reader = PdfReader(file) return bool(reader.is_encrypted) def is_docx_protected(file: IO[Any]) -> bool: return is_office_file_protected(file) def is_pptx_protected(file: IO[Any]) -> bool: return is_office_file_protected(file) def is_xlsx_protected(file: IO[Any]) -> bool: return is_office_file_protected(file) def is_office_file_protected(file: IO[Any]) -> bool: import msoffcrypto # type: ignore[import-untyped] with preserve_position(file): office = msoffcrypto.OfficeFile(file) return office.is_encrypted() def is_file_password_protected( file: IO[Any], file_name: str, extension: str | None = None, ) -> bool:
extension_to_function: dict[str, Callable[[IO[Any]], bool]] = { ".pdf": is_pdf_protected, ".docx": is_docx_protected, ".pptx": is_pptx_protected, ".xlsx": is_xlsx_protected, } if not extension: extension = get_file_ext(file_name) if extension not in PASSWORD_PROTECTED_FILES: return False if extension not in extension_to_function: logger.warning( f"Extension={extension} can be password protected, but no function found" ) return False func = extension_to_function[extension] return func(file) ================================================ FILE: backend/onyx/file_processing/unstructured.py ================================================ from typing import Any from typing import cast from typing import IO from typing import TYPE_CHECKING from onyx.configs.constants import KV_UNSTRUCTURED_API_KEY from onyx.key_value_store.factory import get_kv_store from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.utils.logger import setup_logger if TYPE_CHECKING: from unstructured_client.models import operations logger = setup_logger() def get_unstructured_api_key() -> str | None: kv_store = get_kv_store() try: return cast(str, kv_store.load(KV_UNSTRUCTURED_API_KEY)) except KvKeyNotFoundError: return None def update_unstructured_api_key(api_key: str) -> None: kv_store = get_kv_store() kv_store.store(KV_UNSTRUCTURED_API_KEY, api_key) def delete_unstructured_api_key() -> None: kv_store = get_kv_store() kv_store.delete(KV_UNSTRUCTURED_API_KEY) def _sdk_partition_request( file: IO[Any], file_name: str, **kwargs: Any ) -> "operations.PartitionRequest": from unstructured_client.models import operations from unstructured_client.models import shared file.seek(0, 0) try: request = operations.PartitionRequest( partition_parameters=shared.PartitionParameters( files=shared.Files(content=file.read(), file_name=file_name), **kwargs, ), ) return request except Exception as e: logger.error(f"Error creating partition request for file {file_name}: {str(e)}") raise def 
unstructured_to_text(file: IO[Any], file_name: str) -> str: from unstructured.staging.base import dict_to_elements from unstructured_client import UnstructuredClient logger.debug(f"Starting to read file: {file_name}") req = _sdk_partition_request(file, file_name, strategy="fast") unstructured_client = UnstructuredClient(api_key_auth=get_unstructured_api_key()) response = unstructured_client.general.partition(request=req) if response.status_code != 200: err = f"Received unexpected status code {response.status_code} from Unstructured API." logger.error(err) raise ValueError(err) elements = dict_to_elements(response.elements or []) return "\n\n".join(str(el) for el in elements) ================================================ FILE: backend/onyx/file_store/README.md ================================================ # Onyx File Store The Onyx file store provides a unified interface for storing files and large binary objects in S3-compatible storage systems. It supports AWS S3, MinIO, Azure Blob Storage, Digital Ocean Spaces, and other S3-compatible services. ## Architecture The file store uses a single database table (`file_record`) to store file metadata while the actual file content is stored in external S3-compatible storage. This approach provides scalability, cost-effectiveness, and decouples file storage from the database. 
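To make the metadata/content split concrete, here is a minimal, self-contained sketch. It is not the Onyx implementation: `sqlite3` stands in for the `file_record` table (only a subset of its columns), a plain dict stands in for the S3 bucket, and the helpers `save` and `read` are hypothetical. The point is that reading a file is a two-step lookup: fetch the record to learn `bucket_name` and `object_key`, then fetch the bytes from object storage.

```python
import sqlite3
import uuid

# Stand-in for external object storage: {(bucket_name, object_key): bytes}
object_storage: dict[tuple[str, str], bytes] = {}

# Stand-in for the `file_record` metadata table (subset of its columns)
db = sqlite3.connect(":memory:")
db.execute(
    "CREATE TABLE file_record ("
    " file_id TEXT PRIMARY KEY,"
    " display_name TEXT,"
    " file_type TEXT,"
    " bucket_name TEXT,"
    " object_key TEXT)"
)


def save(
    content: bytes,
    display_name: str,
    file_type: str,
    bucket_name: str = "onyx-file-store-bucket",
) -> str:
    """Hypothetical helper: write bytes to storage, then record metadata."""
    file_id = str(uuid.uuid4())
    # Simplified key layout (assumption); real keys also carry a tenant ID
    object_key = f"onyx-files/{file_id}"
    object_storage[(bucket_name, object_key)] = content
    db.execute(
        "INSERT INTO file_record VALUES (?, ?, ?, ?, ?)",
        (file_id, display_name, file_type, bucket_name, object_key),
    )
    db.commit()
    return file_id


def read(file_id: str) -> bytes:
    """Hypothetical helper: look up the record, then fetch the object."""
    row = db.execute(
        "SELECT bucket_name, object_key FROM file_record WHERE file_id = ?",
        (file_id,),
    ).fetchone()
    if row is None:
        raise KeyError(file_id)
    bucket_name, object_key = row
    return object_storage[(bucket_name, object_key)]
```

Because the database only stores where the bytes live, the same record layout works for any backend that can be addressed by a bucket name and an object key.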
### Database Schema The `file_record` table contains the following columns: - `file_id` (primary key): Unique identifier for the file - `display_name`: Human-readable name for the file - `file_origin`: Origin/source of the file (enum) - `file_type`: MIME type of the file - `file_metadata`: Additional metadata as JSON - `bucket_name`: External storage bucket/container name - `object_key`: External storage object key/path - `created_at`: Timestamp when the file was created - `updated_at`: Timestamp when the file was last updated ## Storage Backend ### S3-Compatible Storage Stores files in external S3-compatible storage systems while keeping metadata in the database. **Pros:** - Scalable storage - Cost-effective for large files - CDN integration possible - Decoupled from database - Wide ecosystem support **Cons:** - Additional infrastructure required - Network dependency - Eventual consistency considerations ## Configuration All configuration is handled via environment variables. The S3-compatible backend described here is used when `FILE_STORE_BACKEND=s3`. With `FILE_STORE_BACKEND=postgres` (the default, per `get_default_file_store`), file content is stored in PostgreSQL Large Objects and no S3 service is required. ### AWS S3 ```bash S3_FILE_STORE_BUCKET_NAME=your-bucket-name # Defaults to 'onyx-file-store-bucket' S3_FILE_STORE_PREFIX=onyx-files # Optional, defaults to 'onyx-files' # AWS credentials (use one of these methods): # 1. Environment variables S3_AWS_ACCESS_KEY_ID=your-access-key S3_AWS_SECRET_ACCESS_KEY=your-secret-key AWS_REGION_NAME=us-east-2 # Optional, defaults to 'us-east-2' # 2.
IAM roles (recommended for EC2/ECS deployments) # No additional configuration needed if using IAM roles ``` ### MinIO ```bash S3_FILE_STORE_BUCKET_NAME=your-bucket-name S3_ENDPOINT_URL=http://localhost:9000 # MinIO endpoint S3_AWS_ACCESS_KEY_ID=minioadmin S3_AWS_SECRET_ACCESS_KEY=minioadmin AWS_REGION_NAME=us-east-1 # Any region name S3_VERIFY_SSL=false # Optional, defaults to false ``` ### Digital Ocean Spaces ```bash S3_FILE_STORE_BUCKET_NAME=your-space-name S3_ENDPOINT_URL=https://nyc3.digitaloceanspaces.com S3_AWS_ACCESS_KEY_ID=your-spaces-key S3_AWS_SECRET_ACCESS_KEY=your-spaces-secret AWS_REGION_NAME=nyc3 ``` ### Other S3-Compatible Services The file store works with any S3-compatible service. Simply configure: - `S3_FILE_STORE_BUCKET_NAME`: Your bucket/container name - `S3_ENDPOINT_URL`: The service endpoint URL - `S3_AWS_ACCESS_KEY_ID` and `S3_AWS_SECRET_ACCESS_KEY`: Your credentials - `AWS_REGION_NAME`: The region (any valid region name) ## Implementation The system uses the `S3BackedFileStore` class that implements the abstract `FileStore` interface. The database uses generic column names (`bucket_name`, `object_key`) to maintain compatibility with different S3-compatible services. 
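Because callers depend only on the abstract `FileStore` interface, a backend can be swapped without touching call sites; the same pattern is what lets `get_default_file_store` return either the S3 store or the PostgreSQL store. The sketch below illustrates the pattern with a deliberately reduced, hypothetical interface (`MiniFileStore`) and an in-memory backend (`InMemoryFileStore`). It mirrors a few `FileStore` method names but is not Onyx code, and it omits `file_origin`, file metadata, and database sessions.

```python
from __future__ import annotations

import uuid
from abc import ABC, abstractmethod
from io import BytesIO
from typing import IO


class MiniFileStore(ABC):
    """Hypothetical, reduced version of the FileStore interface."""

    @abstractmethod
    def save_file(
        self, content: IO[bytes], file_type: str, file_id: str | None = None
    ) -> str: ...

    @abstractmethod
    def read_file(self, file_id: str) -> IO[bytes]: ...

    @abstractmethod
    def has_file(self, file_id: str, file_type: str) -> bool: ...

    @abstractmethod
    def delete_file(self, file_id: str) -> None: ...


class InMemoryFileStore(MiniFileStore):
    """Keeps (file_type, bytes) per file ID; handy as a test double."""

    def __init__(self) -> None:
        self._files: dict[str, tuple[str, bytes]] = {}

    def save_file(
        self, content: IO[bytes], file_type: str, file_id: str | None = None
    ) -> str:
        # Generate a random UUID when no ID is given, as save_file does
        file_id = file_id or str(uuid.uuid4())
        self._files[file_id] = (file_type, content.read())
        return file_id

    def read_file(self, file_id: str) -> IO[bytes]:
        _, data = self._files[file_id]  # raises KeyError for unknown IDs
        return BytesIO(data)

    def has_file(self, file_id: str, file_type: str) -> bool:
        # Like FileStore.has_file, the stored type must also match
        entry = self._files.get(file_id)
        return entry is not None and entry[0] == file_type

    def delete_file(self, file_id: str) -> None:
        self._files.pop(file_id, None)
```

Code written against `MiniFileStore` runs unchanged against any subclass, and the ABC refuses to instantiate a backend that forgets to implement one of the abstract methods.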
### File Store Interface The `FileStore` abstract base class defines the following methods: - `initialize()`: Initialize the storage backend (create bucket if needed) - `has_file(file_id, file_origin, file_type)`: Check whether a file with the given origin and type exists - `save_file(content, display_name, file_origin, file_type, file_metadata, file_id)`: Save a file and return its file ID - `read_file(file_id, mode, use_tempfile)`: Read file content as a binary file-like object - `read_file_record(file_id)`: Get file metadata from the database - `get_file_size(file_id)`: Get the file size in bytes (the S3 backend returns `None` if the size cannot be determined) - `delete_file(file_id)`: Delete a file and its metadata - `get_file_with_mime_type(file_id)`: Get file content together with its detected MIME type - `change_file_id(old_file_id, new_file_id)`: Move an existing file to a new file ID - `list_files_by_prefix(prefix)`: List file records whose file ID starts with the given prefix ## Usage Example ```python from onyx.file_store.file_store import get_default_file_store from onyx.configs.constants import FileOrigin # Get the configured file store (takes no arguments) file_store = get_default_file_store() # Initialize the storage backend (creates bucket if needed) file_store.initialize() # Save a file with open("example.pdf", "rb") as f: file_id = file_store.save_file( content=f, display_name="Important Document.pdf", file_origin=FileOrigin.OTHER, file_type="application/pdf", file_metadata={"department": "engineering", "version": "1.0"}, ) # Check if a file exists exists = file_store.has_file( file_id=file_id, file_origin=FileOrigin.OTHER, file_type="application/pdf" ) # Read a file (read_file returns a binary file-like object) file_content = file_store.read_file(file_id).read() # For large files, stream the content into a temporary file instead of memory temp_file = file_store.read_file(file_id, use_tempfile=True) # Get file metadata file_record = 
file_store.read_file_record(file_id) # Get file with MIME type detection file_with_mime = file_store.get_file_with_mime_type(file_id) # Delete a file file_store.delete_file(file_id) ``` ## Initialization When deploying the application, ensure that: 1. The S3-compatible storage service is accessible 2. Credentials are properly configured 3. The bucket specified in `S3_FILE_STORE_BUCKET_NAME` exists or the service account has permissions to create it 4.
Call `file_store.initialize()` during application startup to ensure the bucket exists The file store will automatically create the bucket if it doesn't exist and the credentials have sufficient permissions. ================================================ FILE: backend/onyx/file_store/constants.py ================================================ MAX_IN_MEMORY_SIZE = 30 * 1024 * 1024 # 30MB STANDARD_CHUNK_SIZE = 10 * 1024 * 1024 # 10MB chunks ================================================ FILE: backend/onyx/file_store/document_batch_storage.py ================================================ import json from abc import ABC from abc import abstractmethod from enum import Enum from io import StringIO from typing import List from typing import Optional from typing import TypeAlias from pydantic import BaseModel from onyx.configs.constants import FileOrigin from onyx.connectors.models import DocExtractionContext from onyx.connectors.models import DocIndexingContext from onyx.connectors.models import Document from onyx.file_store.file_store import FileStore from onyx.file_store.file_store import get_default_file_store from onyx.utils.logger import setup_logger logger = setup_logger() class DocumentBatchStorageStateType(str, Enum): EXTRACTION = "extraction" INDEXING = "indexing" DocumentStorageState: TypeAlias = DocExtractionContext | DocIndexingContext STATE_TYPE_TO_MODEL: dict[str, type[DocumentStorageState]] = { DocumentBatchStorageStateType.EXTRACTION.value: DocExtractionContext, DocumentBatchStorageStateType.INDEXING.value: DocIndexingContext, } class BatchStoragePathInfo(BaseModel): cc_pair_id: int index_attempt_id: int batch_num: int class DocumentBatchStorage(ABC): """Abstract base class for document batch storage implementations.""" def __init__(self, cc_pair_id: int, index_attempt_id: int): self.cc_pair_id = cc_pair_id self.index_attempt_id = index_attempt_id self.base_path = f"{self._per_cc_pair_base_path()}/{index_attempt_id}" @abstractmethod def 
store_batch(self, batch_num: int, documents: List[Document]) -> None: """Store a batch of documents.""" @abstractmethod def get_batch(self, batch_num: int) -> Optional[List[Document]]: """Retrieve a batch of documents.""" @abstractmethod def delete_batch_by_name(self, batch_file_name: str) -> None: """Delete a specific batch.""" @abstractmethod def delete_batch_by_num(self, batch_num: int) -> None: """Delete a specific batch.""" @abstractmethod def cleanup_all_batches(self) -> None: """Clean up all batches and state for this index attempt.""" @abstractmethod def get_all_batches_for_cc_pair(self) -> list[str]: """Get all IDs of batches stored in the file store.""" @abstractmethod def update_old_batches_to_new_index_attempt(self, batch_names: list[str]) -> None: """Update all batches to the new index attempt.""" """ This is used when we need to re-issue docprocessing tasks for a new index attempt. We need to update the batch file names to the new index attempt ID. """ @abstractmethod def extract_path_info(self, path: str) -> BatchStoragePathInfo | None: """Extract path info from a path.""" def _serialize_documents(self, documents: list[Document]) -> str: """Serialize documents to JSON string.""" # Use mode='json' to properly serialize datetime and other complex types return json.dumps([doc.model_dump(mode="json") for doc in documents], indent=2) def _deserialize_documents(self, data: str) -> list[Document]: """Deserialize documents from JSON string.""" doc_dicts = json.loads(data) return [ Document.model_validate(self._normalize_doc_dict(doc_dict)) for doc_dict in doc_dicts ] def _normalize_doc_dict(self, doc_dict: dict) -> dict: """Normalize document dict to handle legacy data with non-string metadata values. Before the _convert_to_metadata_value fix, Salesforce connector stored raw types (bool, float, None) in metadata. This converts them to strings for backward compatibility. 
""" if "metadata" not in doc_dict: return doc_dict metadata = doc_dict["metadata"] if not isinstance(metadata, dict): return doc_dict normalized_metadata: dict[str, str | list[str]] = {} converted_keys: list[str] = [] for key, value in metadata.items(): if isinstance(value, list): normalized_metadata[key] = [str(item) for item in value] elif isinstance(value, str): normalized_metadata[key] = value else: # Convert bool, int, float, None to string converted_keys.append(f"{key}={type(value).__name__}") normalized_metadata[key] = str(value) if converted_keys: doc_id = doc_dict.get("id", "unknown") logger.warning( f"Normalized legacy metadata for document {doc_id}: {converted_keys}" ) doc_dict["metadata"] = normalized_metadata return doc_dict def _per_cc_pair_base_path(self) -> str: """Get the base path for the cc pair.""" return f"iab/{self.cc_pair_id}" class FileStoreDocumentBatchStorage(DocumentBatchStorage): """FileStore-based implementation of document batch storage.""" def __init__(self, cc_pair_id: int, index_attempt_id: int, file_store: FileStore): super().__init__(cc_pair_id, index_attempt_id) self.file_store = file_store def _get_batch_file_name(self, batch_num: int) -> str: """Generate file name for a document batch.""" return f"{self.base_path}/{batch_num}.json" def store_batch(self, batch_num: int, documents: list[Document]) -> None: """Store a batch of documents using FileStore.""" file_name = self._get_batch_file_name(batch_num) try: data = self._serialize_documents(documents) content = StringIO(data) self.file_store.save_file( file_id=file_name, content=content, display_name=f"Document Batch {batch_num}", file_origin=FileOrigin.OTHER, file_type="application/json", file_metadata={ "batch_num": batch_num, "document_count": str(len(documents)), }, ) logger.debug( f"Stored batch {batch_num} with {len(documents)} documents to FileStore as {file_name}" ) except Exception as e: logger.error(f"Failed to store batch {batch_num}: {e}") raise def get_batch(self, 
batch_num: int) -> list[Document] | None: """Retrieve a batch of documents from FileStore.""" file_name = self._get_batch_file_name(batch_num) try: # Check if file exists if not self.file_store.has_file( file_id=file_name, file_origin=FileOrigin.OTHER, file_type="application/json", ): logger.warning( f"Batch {batch_num} not found in FileStore with name {file_name}" ) return None content_io = self.file_store.read_file(file_name) data = content_io.read().decode("utf-8") documents = self._deserialize_documents(data) logger.debug( f"Retrieved batch {batch_num} with {len(documents)} documents from FileStore" ) return documents except Exception as e: logger.error(f"Failed to retrieve batch {batch_num}: {e}") raise def delete_batch_by_name(self, batch_file_name: str) -> None: """Delete a specific batch from FileStore.""" self.file_store.delete_file(batch_file_name) logger.debug(f"Deleted batch {batch_file_name} from FileStore") def delete_batch_by_num(self, batch_num: int) -> None: """Delete a specific batch from FileStore.""" batch_file_name = self._get_batch_file_name(batch_num) self.delete_batch_by_name(batch_file_name) logger.debug(f"Deleted batch num {batch_num} {batch_file_name} from FileStore") def cleanup_all_batches(self) -> None: """Clean up all batches for this index attempt.""" for batch_file_name in self.get_all_batches_for_cc_pair(): self.delete_batch_by_name(batch_file_name) def get_all_batches_for_cc_pair(self) -> list[str]: """Get all IDs of batches stored in the file store for the cc pair this batch store was initialized with. This includes any batches left over from a previous indexing attempt that need to be processed. 
""" return [ file.file_id for file in self.file_store.list_files_by_prefix( self._per_cc_pair_base_path() ) ] def update_old_batches_to_new_index_attempt(self, batch_names: list[str]) -> None: """Update all batches to the new index attempt.""" for batch_file_name in batch_names: path_info = self.extract_path_info(batch_file_name) if path_info is None: logger.warning( f"Could not extract path info from batch file: {batch_file_name}" ) continue new_batch_file_name = self._get_batch_file_name(path_info.batch_num) self.file_store.change_file_id(batch_file_name, new_batch_file_name) def extract_path_info(self, path: str) -> BatchStoragePathInfo | None: """Extract path info from a path.""" path_spl = path.split("/") # TODO: remove this in a few months, just for backwards compatibility if len(path_spl) == 3: path_spl = ["iab"] + path_spl try: _, cc_pair_id, index_attempt_id, batch_num = path_spl return BatchStoragePathInfo( cc_pair_id=int(cc_pair_id), index_attempt_id=int(index_attempt_id), batch_num=int(batch_num.split(".")[0]), # remove .json ) except Exception as e: logger.error(f"Failed to extract path info from {path}: {e}") return None def get_document_batch_storage( cc_pair_id: int, index_attempt_id: int ) -> DocumentBatchStorage: """Factory function to get the configured document batch storage implementation.""" # The get_default_file_store will now correctly use S3BackedFileStore # or other configured stores based on environment variables file_store = get_default_file_store() return FileStoreDocumentBatchStorage(cc_pair_id, index_attempt_id, file_store) ================================================ FILE: backend/onyx/file_store/file_store.py ================================================ import hashlib import tempfile import uuid from abc import ABC from abc import abstractmethod from io import BytesIO from typing import Any from typing import cast from typing import IO from typing import NotRequired from typing import TypedDict import boto3 import puremagic 
from botocore.config import Config from botocore.exceptions import ClientError from mypy_boto3_s3 import S3Client from sqlalchemy.orm import Session from onyx.configs.app_configs import AWS_REGION_NAME from onyx.configs.app_configs import S3_AWS_ACCESS_KEY_ID from onyx.configs.app_configs import S3_AWS_SECRET_ACCESS_KEY from onyx.configs.app_configs import S3_ENDPOINT_URL from onyx.configs.app_configs import S3_FILE_STORE_BUCKET_NAME from onyx.configs.app_configs import S3_FILE_STORE_PREFIX from onyx.configs.app_configs import S3_GENERATE_LOCAL_CHECKSUM from onyx.configs.app_configs import S3_VERIFY_SSL from onyx.configs.constants import FileOrigin from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none from onyx.db.file_record import delete_filerecord_by_file_id from onyx.db.file_record import get_filerecord_by_file_id from onyx.db.file_record import get_filerecord_by_file_id_optional from onyx.db.file_record import get_filerecord_by_prefix from onyx.db.file_record import upsert_filerecord from onyx.db.models import FileRecord from onyx.db.models import FileRecord as FileStoreModel from onyx.file_store.s3_key_utils import generate_s3_key from onyx.utils.file import FileWithMimeType from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() class S3PutKwargs(TypedDict): ChecksumSHA256: NotRequired[str] class FileStore(ABC): """ An abstraction for storing files and large binary objects. """ @abstractmethod def initialize(self) -> None: """ Should generally be called once before any other methods are called. 
""" raise NotImplementedError @abstractmethod def has_file( self, file_id: str, file_origin: FileOrigin, file_type: str, ) -> bool: """ Check if a file exists in the blob store Parameters: - file_id: Unique ID of the file to check for - file_origin: Origin of the file - file_type: Type of the file """ raise NotImplementedError @abstractmethod def save_file( self, content: IO, display_name: str | None, file_origin: FileOrigin, file_type: str, file_metadata: dict[str, Any] | None = None, file_id: str | None = None, ) -> str: """ Save a file to the blob store Parameters: - content: Contents of the file - display_name: Display name of the file to save - file_origin: Origin of the file - file_type: Type of the file - file_metadata: Additional metadata for the file - file_id: Unique ID of the file to save. If not provided, a random UUID will be generated. It is generally NOT recommended to provide this. Returns: The unique ID of the file that was saved. """ raise NotImplementedError @abstractmethod def read_file( self, file_id: str, mode: str | None = None, use_tempfile: bool = False ) -> IO[bytes]: """ Read the content of a given file by the ID Parameters: - file_id: Unique ID of file to read - mode: Mode to open the file (e.g. 'b' for binary) - use_tempfile: Whether to use a temporary file to store the contents in order to avoid loading the entire file into memory Returns: A binary file-like object containing the file's contents """ @abstractmethod def read_file_record(self, file_id: str) -> FileStoreModel: """ Read the file record by the ID """ @abstractmethod def get_file_size( self, file_id: str, db_session: Session | None = None ) -> int | None: """ Get the size of a file in bytes. Optionally provide a db_session for database access. """ @abstractmethod def delete_file(self, file_id: str) -> None: """ Delete a file by its ID.
Parameters: - file_id: Unique ID of the file to delete """ @abstractmethod def get_file_with_mime_type(self, file_id: str) -> FileWithMimeType | None: """ Get the file contents along with the MIME type detected from them. """ @abstractmethod def change_file_id(self, old_file_id: str, new_file_id: str) -> None: """ Change the file ID of an existing file. Parameters: - old_file_id: Current file ID - new_file_id: New file ID to assign """ raise NotImplementedError @abstractmethod def list_files_by_prefix(self, prefix: str) -> list[FileRecord]: """ List all file records whose file ID starts with the given prefix. """ class S3BackedFileStore(FileStore): """Isn't necessarily S3, but is any S3-compatible storage (e.g. MinIO)""" def __init__( self, bucket_name: str, aws_access_key_id: str | None = None, aws_secret_access_key: str | None = None, aws_region_name: str | None = None, s3_endpoint_url: str | None = None, s3_prefix: str | None = None, s3_verify_ssl: bool = True, ) -> None: self._s3_client: S3Client | None = None self._bucket_name = bucket_name self._aws_access_key_id = aws_access_key_id self._aws_secret_access_key = aws_secret_access_key self._aws_region_name = aws_region_name or "us-east-2" self._s3_endpoint_url = s3_endpoint_url self._s3_prefix = s3_prefix or "onyx-files" self._s3_verify_ssl = s3_verify_ssl def _get_s3_client(self) -> S3Client: """Initialize S3 client if not already done""" if self._s3_client is None: try: client_kwargs: dict[str, Any] = { "service_name": "s3", "region_name": self._aws_region_name, } # Add endpoint URL if specified (for MinIO, etc.)
if self._s3_endpoint_url: client_kwargs["endpoint_url"] = self._s3_endpoint_url client_kwargs["config"] = Config( signature_version="s3v4", s3={"addressing_style": "path"}, # Required for MinIO ) # Disable SSL verification if requested (for local development) if not self._s3_verify_ssl: import urllib3 urllib3.disable_warnings( urllib3.exceptions.InsecureRequestWarning ) client_kwargs["verify"] = False if self._aws_access_key_id and self._aws_secret_access_key: # Use explicit credentials client_kwargs.update( { "aws_access_key_id": self._aws_access_key_id, "aws_secret_access_key": self._aws_secret_access_key, } ) self._s3_client = boto3.client(**client_kwargs) else: # Use IAM role or default credentials (not typically used with MinIO) self._s3_client = boto3.client(**client_kwargs) except Exception as e: logger.error(f"Failed to initialize S3 client: {e}") raise RuntimeError(f"Failed to initialize S3 client: {e}") return self._s3_client def _get_bucket_name(self) -> str: """Get S3 bucket name from configuration""" if not self._bucket_name: raise RuntimeError("S3 bucket name is required for S3 file store") return self._bucket_name def _get_s3_key(self, file_name: str) -> str: """Generate S3 key from file name with tenant ID prefix""" tenant_id = get_current_tenant_id() s3_key = generate_s3_key( file_name=file_name, prefix=self._s3_prefix, tenant_id=tenant_id, max_key_length=1024, ) # Log if truncation occurred (when the key is exactly at the limit) if len(s3_key) == 1024: logger.info(f"File name was too long and was truncated: {file_name}") return s3_key def initialize(self) -> None: """Initialize the S3 file store by ensuring the bucket exists""" s3_client = self._get_s3_client() bucket_name = self._get_bucket_name() # Check if bucket exists try: s3_client.head_bucket(Bucket=bucket_name) logger.info(f"S3 bucket '{bucket_name}' already exists") except ClientError as e: error_code = e.response["Error"]["Code"] if error_code == "404": # Bucket doesn't exist, create it 
logger.info(f"Creating S3 bucket '{bucket_name}'") # For AWS S3, we need to handle region-specific bucket creation region = ( s3_client._client_config.region_name if hasattr(s3_client, "_client_config") else None ) if region and region != "us-east-1": # For regions other than us-east-1, we need to specify LocationConstraint s3_client.create_bucket( Bucket=bucket_name, CreateBucketConfiguration={"LocationConstraint": region}, ) else: # For us-east-1 or MinIO/other S3-compatible services s3_client.create_bucket(Bucket=bucket_name) logger.info(f"Successfully created S3 bucket '{bucket_name}'") elif error_code == "403": # Bucket exists but we don't have permission to access it logger.warning( f"S3 bucket '{bucket_name}' exists but access is forbidden" ) raise RuntimeError( f"Access denied to S3 bucket '{bucket_name}'. Check credentials and permissions." ) else: # Some other error occurred logger.error(f"Failed to check S3 bucket '{bucket_name}': {e}") raise RuntimeError(f"Failed to check S3 bucket '{bucket_name}': {e}") def has_file( self, file_id: str, file_origin: FileOrigin, file_type: str, db_session: Session | None = None, ) -> bool: with get_session_with_current_tenant_if_none(db_session) as db_session: file_record = get_filerecord_by_file_id_optional( file_id=file_id, db_session=db_session ) return ( file_record is not None and file_record.file_origin == file_origin and file_record.file_type == file_type ) def save_file( self, content: IO, display_name: str | None, file_origin: FileOrigin, file_type: str, file_metadata: dict[str, Any] | None = None, file_id: str | None = None, db_session: Session | None = None, ) -> str: if file_id is None: file_id = str(uuid.uuid4()) s3_client = self._get_s3_client() bucket_name = self._get_bucket_name() s3_key = self._get_s3_key(file_id) hash256 = "" sha256_hash = hashlib.sha256() kwargs: S3PutKwargs = {} # FIX: Optimize checksum generation to avoid creating extra copies in memory # Read content from IO object if 
hasattr(content, "read"): file_content = content.read() if S3_GENERATE_LOCAL_CHECKSUM: # FIX: Don't convert to string first (creates unnecessary copy) # Work directly with bytes if isinstance(file_content, bytes): sha256_hash.update(file_content) else: sha256_hash.update(str(file_content).encode()) hash256 = sha256_hash.hexdigest() kwargs["ChecksumSHA256"] = hash256 if hasattr(content, "seek"): content.seek(0) # Reset position for potential re-reads else: file_content = content # Upload to S3 s3_client.put_object( Bucket=bucket_name, Key=s3_key, Body=file_content, ContentType=file_type, **kwargs, ) with get_session_with_current_tenant_if_none(db_session) as db_session: # Save metadata to database upsert_filerecord( file_id=file_id, display_name=display_name or file_id, file_origin=file_origin, file_type=file_type, bucket_name=bucket_name, object_key=s3_key, db_session=db_session, file_metadata=file_metadata, ) db_session.commit() return file_id def read_file( self, file_id: str, mode: str | None = None, # noqa: ARG002 use_tempfile: bool = False, db_session: Session | None = None, ) -> IO[bytes]: with get_session_with_current_tenant_if_none(db_session) as db_session: file_record = get_filerecord_by_file_id( file_id=file_id, db_session=db_session ) s3_client = self._get_s3_client() try: response = s3_client.get_object( Bucket=file_record.bucket_name, Key=file_record.object_key ) except ClientError: logger.error(f"Failed to read file {file_id} from S3") raise # FIX: Stream file content instead of loading entire file into memory # This prevents OOM issues with large files (500MB+ PDFs, etc.) 
if use_tempfile: # Stream directly to temp file to avoid holding entire file in memory temp_file = tempfile.NamedTemporaryFile(mode="w+b", delete=True) # Stream in 8MB chunks to reduce memory footprint for chunk in response["Body"].iter_chunks(chunk_size=8 * 1024 * 1024): temp_file.write(chunk) temp_file.seek(0) return temp_file else: # For BytesIO, we still need to read into memory (legacy behavior) # but at least we're not creating duplicate copies file_content = response["Body"].read() return BytesIO(file_content) def read_file_record( self, file_id: str, db_session: Session | None = None ) -> FileStoreModel: with get_session_with_current_tenant_if_none(db_session) as db_session: file_record = get_filerecord_by_file_id( file_id=file_id, db_session=db_session ) return file_record def get_file_size( self, file_id: str, db_session: Session | None = None ) -> int | None: """ Get the size of a file in bytes by querying S3 metadata. """ try: with get_session_with_current_tenant_if_none(db_session) as db_session: file_record = get_filerecord_by_file_id( file_id=file_id, db_session=db_session ) s3_client = self._get_s3_client() response = s3_client.head_object( Bucket=file_record.bucket_name, Key=file_record.object_key ) return response.get("ContentLength") except Exception as e: logger.warning(f"Error getting file size for {file_id}: {e}") return None def delete_file(self, file_id: str, db_session: Session | None = None) -> None: with get_session_with_current_tenant_if_none(db_session) as db_session: try: file_record = get_filerecord_by_file_id( file_id=file_id, db_session=db_session ) if not file_record.bucket_name: logger.error( f"File record {file_id} with key {file_record.object_key} " "has no bucket name, cannot delete from filestore" ) delete_filerecord_by_file_id(file_id=file_id, db_session=db_session) db_session.commit() return # Delete from external storage s3_client = self._get_s3_client() try: s3_client.delete_object( Bucket=file_record.bucket_name, 
Key=file_record.object_key ) except ClientError as e: # If the object doesn't exist in file store, treat it as success # since the end goal (object not existing) is achieved if e.response.get("Error", {}).get("Code") == "NoSuchKey": logger.warning( f"delete_file: File {file_id} not found in file store (key: {file_record.object_key}), " "cleaning up database record." ) else: raise # Delete metadata from database delete_filerecord_by_file_id(file_id=file_id, db_session=db_session) db_session.commit() except Exception: db_session.rollback() raise def change_file_id( self, old_file_id: str, new_file_id: str, db_session: Session | None = None ) -> None: with get_session_with_current_tenant_if_none(db_session) as db_session: try: # Get the existing file record old_file_record = get_filerecord_by_file_id( file_id=old_file_id, db_session=db_session ) # Generate new S3 key for the new file ID new_s3_key = self._get_s3_key(new_file_id) # Copy S3 object to new key s3_client = self._get_s3_client() bucket_name = self._get_bucket_name() copy_source = ( f"{old_file_record.bucket_name}/{old_file_record.object_key}" ) s3_client.copy_object( CopySource=copy_source, Bucket=bucket_name, Key=new_s3_key, MetadataDirective="COPY", ) # Create new file record with new file_id # Cast file_metadata to the expected type file_metadata = cast( dict[Any, Any] | None, old_file_record.file_metadata ) upsert_filerecord( file_id=new_file_id, display_name=old_file_record.display_name, file_origin=old_file_record.file_origin, file_type=old_file_record.file_type, bucket_name=bucket_name, object_key=new_s3_key, db_session=db_session, file_metadata=file_metadata, ) # Delete old S3 object s3_client.delete_object( Bucket=old_file_record.bucket_name, Key=old_file_record.object_key ) # Delete old file record delete_filerecord_by_file_id(file_id=old_file_id, db_session=db_session) db_session.commit() except Exception as e: db_session.rollback() logger.exception( f"Failed to change file ID from {old_file_id} 
to {new_file_id}: {e}" ) raise def get_file_with_mime_type(self, file_id: str) -> FileWithMimeType | None: mime_type: str = "application/octet-stream" try: file_io = self.read_file(file_id, mode="b") file_content = file_io.read() matches = puremagic.magic_string(file_content) if matches: mime_type = cast(str, matches[0].mime_type) return FileWithMimeType(data=file_content, mime_type=mime_type) except Exception: return None def list_files_by_prefix(self, prefix: str) -> list[FileRecord]: """ List all file records whose file ID starts with the given prefix. """ with get_session_with_current_tenant() as db_session: file_records = get_filerecord_by_prefix( prefix=prefix, db_session=db_session ) return file_records def get_s3_file_store() -> S3BackedFileStore: """ Returns the S3 file store implementation. """ # Get bucket name - this is required bucket_name = S3_FILE_STORE_BUCKET_NAME if not bucket_name: raise RuntimeError( "S3_FILE_STORE_BUCKET_NAME configuration is required for S3 file store" ) return S3BackedFileStore( bucket_name=bucket_name, aws_access_key_id=S3_AWS_ACCESS_KEY_ID, aws_secret_access_key=S3_AWS_SECRET_ACCESS_KEY, aws_region_name=AWS_REGION_NAME, s3_endpoint_url=S3_ENDPOINT_URL, s3_prefix=S3_FILE_STORE_PREFIX, s3_verify_ssl=S3_VERIFY_SSL, ) def get_default_file_store() -> FileStore: """ Returns the configured file store implementation based on FILE_STORE_BACKEND. When FILE_STORE_BACKEND=postgres (default): - Files are stored in PostgreSQL using Large Objects. - No external storage service (S3/MinIO) is required. When FILE_STORE_BACKEND=s3: - Supports AWS S3, MinIO, and other S3-compatible storage. - Configuration via environment variables: - S3_FILE_STORE_BUCKET_NAME, S3_ENDPOINT_URL, S3_AWS_ACCESS_KEY_ID, etc.
""" from onyx.configs.app_configs import FILE_STORE_BACKEND from onyx.configs.constants import FileStoreType if FileStoreType(FILE_STORE_BACKEND) == FileStoreType.POSTGRES: from onyx.file_store.postgres_file_store import PostgresBackedFileStore return PostgresBackedFileStore() return get_s3_file_store() ================================================ FILE: backend/onyx/file_store/models.py ================================================ import base64 from enum import Enum from typing import NotRequired from typing_extensions import TypedDict # noreorder from pydantic import BaseModel class ChatFileType(str, Enum): # Image types only contain the binary data IMAGE = "image" # Doc types are saved as both the binary, and the parsed text DOC = "document" # Plain text only contain the text PLAIN_TEXT = "plain_text" # Tabular data files (CSV, XLSX) TABULAR = "tabular" def is_text_file(self) -> bool: return self in ( ChatFileType.PLAIN_TEXT, ChatFileType.DOC, ChatFileType.TABULAR, ) def use_metadata_only(self) -> bool: """File types where we can ignore the file content and only use the metadata.""" return self in (ChatFileType.TABULAR,) class FileDescriptor(TypedDict): """NOTE: is a `TypedDict` so it can be used as a type hint for a JSONB column in Postgres""" id: str type: ChatFileType name: NotRequired[str | None] user_file_id: NotRequired[str | None] class InMemoryChatFile(BaseModel): file_id: str content: bytes file_type: ChatFileType filename: str | None = None def to_base64(self) -> str: if self.file_type == ChatFileType.IMAGE: return base64.b64encode(self.content).decode() else: raise RuntimeError( "Should not be trying to convert a non-image file to base64" ) def to_file_descriptor(self) -> FileDescriptor: return { "id": str(self.file_id), "type": self.file_type, "name": self.filename, "user_file_id": str(self.file_id) if self.file_id else None, } ================================================ FILE: backend/onyx/file_store/postgres_file_store.py 
================================================ """PostgreSQL-backed file store using Large Objects. Stores file content directly in PostgreSQL via the Large Object facility, eliminating the need for an external S3/MinIO service. """ import tempfile import uuid from io import BytesIO from typing import Any from typing import cast from typing import IO import puremagic from psycopg2.extensions import connection as Psycopg2Connection from sqlalchemy.orm import Session from onyx.configs.constants import FileOrigin from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import get_session_with_current_tenant_if_none from onyx.db.file_content import delete_file_content_by_file_id from onyx.db.file_content import get_file_content_by_file_id from onyx.db.file_content import get_file_content_by_file_id_optional from onyx.db.file_content import transfer_file_content_file_id from onyx.db.file_content import upsert_file_content from onyx.db.file_record import delete_filerecord_by_file_id from onyx.db.file_record import get_filerecord_by_file_id from onyx.db.file_record import get_filerecord_by_file_id_optional from onyx.db.file_record import get_filerecord_by_prefix from onyx.db.file_record import upsert_filerecord from onyx.db.models import FileRecord from onyx.db.models import FileRecord as FileStoreModel from onyx.file_store.file_store import FileStore from onyx.utils.file import FileWithMimeType from onyx.utils.logger import setup_logger logger = setup_logger() POSTGRES_BUCKET_SENTINEL = "postgres" STREAM_CHUNK_SIZE = 8 * 1024 * 1024 # 8 MB def _get_raw_connection(db_session: Session) -> Psycopg2Connection: """Extract the raw psycopg2 connection from a SQLAlchemy session.""" raw_conn = db_session.connection().connection.dbapi_connection if raw_conn is None: raise ValueError("Failed to get raw connection from session") return cast(Psycopg2Connection, raw_conn) def _create_large_object(raw_conn: Psycopg2Connection, data: bytes) 
-> int: """Create a new Large Object, write data, and return the OID.""" lobj = raw_conn.lobject(0, "wb") lobj.write(data) oid: int = lobj.oid lobj.close() return oid def _read_large_object(raw_conn: Psycopg2Connection, oid: int) -> bytes: """Read all bytes from a Large Object.""" lobj = raw_conn.lobject(oid, "rb") data: bytes = lobj.read() lobj.close() return data def _read_large_object_to_tempfile(raw_conn: Psycopg2Connection, oid: int) -> IO[bytes]: """Stream a Large Object into a temporary file to avoid OOM on large files.""" lobj = raw_conn.lobject(oid, "rb") temp = tempfile.NamedTemporaryFile(mode="w+b", delete=True) while True: chunk = lobj.read(STREAM_CHUNK_SIZE) if not chunk: break temp.write(chunk) lobj.close() temp.seek(0) return temp def _delete_large_object(raw_conn: Any, oid: int) -> None: """Unlink (delete) a Large Object by OID.""" lobj = raw_conn.lobject(oid, "n") lobj.unlink() class PostgresBackedFileStore(FileStore): """File store backed entirely by PostgreSQL. Metadata lives in `file_record`, content lives in PostgreSQL Large Objects with OID references tracked in `file_content`. """ def initialize(self) -> None: # Nothing to do — tables are created by Alembic migrations. 
pass def has_file( self, file_id: str, file_origin: FileOrigin, file_type: str, db_session: Session | None = None, ) -> bool: with get_session_with_current_tenant_if_none(db_session) as session: record = get_filerecord_by_file_id_optional( file_id=file_id, db_session=session ) return ( record is not None and record.file_origin == file_origin and record.file_type == file_type ) def save_file( self, content: IO, display_name: str | None, file_origin: FileOrigin, file_type: str, file_metadata: dict[str, Any] | None = None, file_id: str | None = None, db_session: Session | None = None, ) -> str: if file_id is None: file_id = str(uuid.uuid4()) file_bytes = self._read_content_bytes(content) created_lo = False with get_session_with_current_tenant_if_none(db_session) as session: raw_conn, oid = None, None try: raw_conn = _get_raw_connection(session) # Look up existing content so we can unlink the old # Large Object after a successful overwrite. existing = get_file_content_by_file_id_optional( file_id=file_id, db_session=session ) old_oid = existing.lobj_oid if existing else None oid = _create_large_object(raw_conn, file_bytes) created_lo = True upsert_filerecord( file_id=file_id, display_name=display_name or file_id, file_origin=file_origin, file_type=file_type, bucket_name=POSTGRES_BUCKET_SENTINEL, object_key=str(oid), db_session=session, file_metadata=file_metadata, ) upsert_file_content( file_id=file_id, lobj_oid=oid, file_size=len(file_bytes), db_session=session, ) # Unlink the previous Large Object to avoid orphans if old_oid is not None and old_oid != oid: try: _delete_large_object(raw_conn, old_oid) except Exception: logger.warning( f"Failed to unlink old large object {old_oid} for file {file_id}" ) session.commit() except Exception as e: session.rollback() try: if created_lo and raw_conn is not None and oid is not None: _delete_large_object(raw_conn, oid) except Exception: logger.exception( f"Failed to delete large object {oid} for file {file_id}" ) raise e return 
file_id def read_file( self, file_id: str, mode: str | None = None, # noqa: ARG002 use_tempfile: bool = False, db_session: Session | None = None, ) -> IO[bytes]: with get_session_with_current_tenant_if_none(db_session) as session: file_content = get_file_content_by_file_id( file_id=file_id, db_session=session ) raw_conn = _get_raw_connection(session) if use_tempfile: return _read_large_object_to_tempfile(raw_conn, file_content.lobj_oid) data = _read_large_object(raw_conn, file_content.lobj_oid) return BytesIO(data) def read_file_record( self, file_id: str, db_session: Session | None = None ) -> FileStoreModel: with get_session_with_current_tenant_if_none(db_session) as session: return get_filerecord_by_file_id(file_id=file_id, db_session=session) def get_file_size( self, file_id: str, db_session: Session | None = None ) -> int | None: try: with get_session_with_current_tenant_if_none(db_session) as session: record = get_file_content_by_file_id( file_id=file_id, db_session=session ) return record.file_size except Exception as e: logger.warning(f"Error getting file size for {file_id}: {e}") return None def delete_file(self, file_id: str, db_session: Session | None = None) -> None: with get_session_with_current_tenant_if_none(db_session) as session: try: file_content = get_file_content_by_file_id( file_id=file_id, db_session=session ) raw_conn = _get_raw_connection(session) try: _delete_large_object(raw_conn, file_content.lobj_oid) except Exception: logger.warning( f"Large object {file_content.lobj_oid} for file {file_id} not found, cleaning up records only." 
) delete_file_content_by_file_id(file_id=file_id, db_session=session) delete_filerecord_by_file_id(file_id=file_id, db_session=session) session.commit() except Exception: session.rollback() raise def get_file_with_mime_type(self, file_id: str) -> FileWithMimeType | None: mime_type = "application/octet-stream" try: file_io = self.read_file(file_id, mode="b") except Exception: return None file_content = file_io.read() try: matches = puremagic.magic_string(file_content) if matches: mime_type = cast(str, matches[0].mime_type) except puremagic.PureError: pass return FileWithMimeType(data=file_content, mime_type=mime_type) def change_file_id( self, old_file_id: str, new_file_id: str, db_session: Session | None = None ) -> None: with get_session_with_current_tenant_if_none(db_session) as session: try: old_record = get_filerecord_by_file_id( file_id=old_file_id, db_session=session ) file_metadata = cast(dict[Any, Any] | None, old_record.file_metadata) # 1. Create the new file_record so the FK target exists upsert_filerecord( file_id=new_file_id, display_name=old_record.display_name, file_origin=old_record.file_origin, file_type=old_record.file_type, bucket_name=POSTGRES_BUCKET_SENTINEL, object_key=old_record.object_key, db_session=session, file_metadata=file_metadata, ) # 2. Move file_content in-place — the LO OID is never # shared between two rows. transfer_file_content_file_id( old_file_id=old_file_id, new_file_id=new_file_id, db_session=session, ) # 3. 
Remove the now-orphaned old file_record delete_filerecord_by_file_id(file_id=old_file_id, db_session=session) session.commit() except Exception as e: session.rollback() logger.exception( f"Failed to change file ID from {old_file_id} to {new_file_id}: {e}" ) raise def list_files_by_prefix(self, prefix: str) -> list[FileRecord]: with get_session_with_current_tenant() as session: return get_filerecord_by_prefix(prefix=prefix, db_session=session) @staticmethod def _read_content_bytes(content: IO) -> bytes: """Normalize an IO object into raw bytes.""" if hasattr(content, "read"): raw = content.read() else: raw = content if isinstance(raw, str): return raw.encode("utf-8") return raw ================================================ FILE: backend/onyx/file_store/s3_key_utils.py ================================================ """ S3 key sanitization utilities for ensuring AWS S3 compatibility. This module provides utilities for sanitizing file names to be compatible with AWS S3 object key naming guidelines while ensuring uniqueness when significant sanitization occurs. Reference: https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html """ import hashlib import re import urllib.parse from re import Match # Constants for S3 key generation HASH_LENGTH = 64 # SHA256 hex digest length HASH_SEPARATOR_LENGTH = 1 # Length of underscore separator HASH_WITH_SEPARATOR_LENGTH = HASH_LENGTH + HASH_SEPARATOR_LENGTH def _encode_special_char(match: Match[str]) -> str: """Helper function to URL encode special characters.""" return urllib.parse.quote(match.group(0), safe="") def sanitize_s3_key_name(file_name: str) -> str: """ Sanitize file name to be S3-compatible according to AWS guidelines. This method: 1. Replaces problematic characters with safe alternatives 2. URL-encodes characters that might require special handling 3. Ensures the result is safe for S3 object keys 4. 
Adds uniqueness when significant sanitization occurs Args: file_name: The original file name to sanitize Returns: A sanitized file name that is S3-compatible Reference: https://docs.aws.amazon.com/AmazonS3/latest/userguide/object-keys.html """ if not file_name: return "unnamed_file" original_name = file_name # Characters to avoid completely (replace with underscore) # These are characters that AWS recommends avoiding avoid_chars = r'[\\{}^%`\[\]"<>#|~/]' # Replace avoided characters with underscore sanitized = re.sub(avoid_chars, "_", file_name) # Characters that might require special handling but are allowed # We'll URL encode these to be safe special_chars = r"[&$@=;:+,?\s]" sanitized = re.sub(special_chars, _encode_special_char, sanitized) # Handle non-ASCII characters by URL encoding them # This ensures Unicode characters are properly handled needs_unicode_encoding = False try: # Try to encode as ASCII to check if it contains non-ASCII chars sanitized.encode("ascii") except UnicodeEncodeError: needs_unicode_encoding = True # Contains non-ASCII characters, URL encode the entire string # but preserve safe ASCII characters sanitized = urllib.parse.quote( sanitized, safe="ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_.()!*", ) # Strip all leading periods (avoids relative-path segments such as "." or "..") sanitized = re.sub(r"^\.+", "", sanitized) # Remove any trailing periods to avoid download issues sanitized = sanitized.rstrip(".") # Collapse runs of two or more "-" or "_" characters into a single "-" sanitized = re.sub(r"[-_]{2,}", "-", sanitized) # If sanitization resulted in empty string, use a default if not sanitized: sanitized = "sanitized_file" # Check if significant sanitization occurred and add uniqueness if needed significant_changes = ( # Check if we replaced many characters len(re.findall(avoid_chars, original_name)) > 3 or # Check if we had to URL encode Unicode characters needs_unicode_encoding or # Check if the sanitized name is very 
different in length (expansion due to
encoding) len(sanitized) > len(original_name) * 2 or # Check if the original had many special characters len(re.findall(special_chars, original_name)) > 5 ) if significant_changes: # Add a short hash to ensure uniqueness while keeping some readability name_hash = hashlib.sha256(original_name.encode("utf-8")).hexdigest()[:8] # Try to preserve file extension if it exists and is reasonable if "." in sanitized and len(sanitized.split(".")[-1]) <= 10: name_parts = sanitized.rsplit(".", 1) sanitized = f"{name_parts[0]}_{name_hash}.{name_parts[1]}" else: sanitized = f"{sanitized}_{name_hash}" return sanitized def generate_s3_key( file_name: str, prefix: str, tenant_id: str, max_key_length: int = 1024 ) -> str: """ Generate a complete S3 key from file name with prefix and tenant ID. Args: file_name: The original file name prefix: S3 key prefix (e.g., 'onyx-files') tenant_id: Tenant identifier max_key_length: Maximum allowed S3 key length (default: 1024) Returns: A complete S3 key that fits within the length limit """ # Strip slashes from prefix and tenant_id to avoid double slashes prefix_clean = prefix.strip("/") tenant_clean = tenant_id.strip("/") # Sanitize the file name first sanitized_file_name = sanitize_s3_key_name(file_name) # Handle long file names that could exceed S3's key limit # S3 key format: {prefix}/{tenant_id}/{file_name} prefix_and_tenant_parts = [prefix_clean, tenant_clean] prefix_and_tenant = "/".join(prefix_and_tenant_parts) + "/" max_file_name_length = max_key_length - len(prefix_and_tenant) if len(sanitized_file_name) < max_file_name_length: return "/".join(prefix_and_tenant_parts + [sanitized_file_name]) # For very long file names, use hash-based approach to ensure uniqueness # Use the original file name for the hash to maintain consistency file_hash = hashlib.sha256(file_name.encode("utf-8")).hexdigest() # Calculate how much space we have for the readable part # Reserve space for hash (64 chars) + underscore separator (1 char) 
readable_part_max_length = max(0, max_file_name_length - HASH_WITH_SEPARATOR_LENGTH) if readable_part_max_length > 0: # Use first part of sanitized name + hash to maintain some readability readable_part = sanitized_file_name[:readable_part_max_length] truncated_name = f"{readable_part}_{file_hash}" else: # If no space for readable part, just use hash truncated_name = file_hash return "/".join(prefix_and_tenant_parts + [truncated_name]) ================================================ FILE: backend/onyx/file_store/utils.py ================================================ import base64 from collections.abc import Callable from io import BytesIO from typing import cast from uuid import UUID import requests from sqlalchemy.orm import Session from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import FileOrigin from onyx.db.models import UserFile from onyx.file_store.file_store import get_default_file_store from onyx.file_store.models import ChatFileType from onyx.file_store.models import FileDescriptor from onyx.file_store.models import InMemoryChatFile from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type from onyx.utils.b64 import get_image_type from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.timing import log_function_time logger = setup_logger() def plaintext_file_name_for_id(file_id: str) -> str: """Generate a consistent file name for storing plaintext content of a file.""" return f"plaintext_{file_id}" def store_plaintext(file_id: str, plaintext_content: str) -> bool: """ Store plaintext content for a file in the file store. 
Args: file_id: The ID of the file (user_file or artifact_file) plaintext_content: The plaintext content to store Returns: bool: True if storage was successful, False otherwise """ if not plaintext_content: return False plaintext_file_name = plaintext_file_name_for_id(file_id) try: file_store = get_default_file_store() file_content = BytesIO(plaintext_content.encode("utf-8")) file_store.save_file( content=file_content, display_name=f"Plaintext for {file_id}", file_origin=FileOrigin.PLAINTEXT_CACHE, file_type="text/plain", file_id=plaintext_file_name, ) return True except Exception as e: logger.warning(f"Failed to store plaintext for {file_id}: {e}") return False # --- Convenience wrappers for callers that use user-file UUIDs --- def user_file_id_to_plaintext_file_name(user_file_id: UUID) -> str: """Generate a consistent file name for storing plaintext content of a user file.""" return plaintext_file_name_for_id(str(user_file_id)) def store_user_file_plaintext(user_file_id: UUID, plaintext_content: str) -> bool: """Store plaintext content for a user file (delegates to :func:`store_plaintext`).""" return store_plaintext(str(user_file_id), plaintext_content) def load_chat_file_by_id(file_id: str) -> InMemoryChatFile: """Load a file directly from the file store using its file_record ID. 
This is the fallback path for chat-attached files that don't have a corresponding row in the ``user_file`` table.""" file_store = get_default_file_store() file_record = file_store.read_file_record(file_id) chat_file_type = mime_type_to_chat_file_type(file_record.file_type) file_io = file_store.read_file(file_id, mode="b") return InMemoryChatFile( file_id=file_id, content=file_io.read(), file_type=chat_file_type, filename=file_record.display_name, ) def load_user_file(file_id: UUID, db_session: Session) -> InMemoryChatFile: status = "not_loaded" user_file = db_session.query(UserFile).filter(UserFile.id == file_id).first() if not user_file: raise ValueError(f"User file with id {file_id} not found") # Get the file record to determine the appropriate chat file type file_store = get_default_file_store() file_record = file_store.read_file_record(user_file.file_id) # Determine appropriate chat file type based on the original file's MIME type chat_file_type = mime_type_to_chat_file_type(file_record.file_type) # Try to load plaintext version first plaintext_file_name = user_file_id_to_plaintext_file_name(file_id) # check for plain text normalized version first, then use original file otherwise try: file_io = file_store.read_file(plaintext_file_name, mode="b") # Metadata-only file types preserve their original type so # downstream injection paths can route them correctly. 
if chat_file_type.use_metadata_only(): plaintext_chat_file_type = chat_file_type elif file_io is not None: # if we have plaintext for image (which happens when image # extraction is enabled), we use PLAIN_TEXT type plaintext_chat_file_type = ChatFileType.PLAIN_TEXT else: plaintext_chat_file_type = ( ChatFileType.PLAIN_TEXT if chat_file_type != ChatFileType.IMAGE else chat_file_type ) chat_file = InMemoryChatFile( file_id=str(user_file.file_id), content=file_io.read(), file_type=plaintext_chat_file_type, filename=user_file.name, ) status = "plaintext" return chat_file except Exception as e: logger.warning(f"Failed to load plaintext for user file {user_file.id}: {e}") # Fall back to original file if plaintext not available file_io = file_store.read_file(user_file.file_id, mode="b") chat_file = InMemoryChatFile( file_id=str(user_file.file_id), content=file_io.read(), file_type=chat_file_type, filename=user_file.name, ) status = "original" return chat_file finally: logger.debug( f"load_user_file finished: file_id={user_file.file_id} chat_file_type={chat_file_type} status={status}" ) def load_in_memory_chat_files( user_file_ids: list[UUID], db_session: Session, ) -> list[InMemoryChatFile]: """ Loads the content of the user files specified by ID into memory. Args: user_file_ids: A list of specific UserFile IDs to load. db_session: The SQLAlchemy database session. Returns: A list of InMemoryChatFile objects, each containing the file content (as bytes), file ID, file type, and filename. Prioritizes loading plaintext versions if available. """ # Use parallel execution to load files concurrently return cast( list[InMemoryChatFile], run_functions_tuples_in_parallel( # 1.
Load files specified by individual IDs [(load_user_file, (file_id, db_session)) for file_id in user_file_ids] ), ) def get_user_files( user_file_ids: list[UUID], db_session: Session, ) -> list[UserFile]: """ Fetches UserFile database records for the provided file IDs. Args: user_file_ids: A list of specific UserFile IDs to fetch. db_session: The SQLAlchemy database session. Returns: A list of UserFile SQLAlchemy model objects for the specified file IDs; IDs with no matching record are silently skipped. It does NOT return the actual file content. """ user_files: list[UserFile] = [] # Fetch UserFile records for the specified file IDs for user_file_id in user_file_ids: # Query the database for a UserFile with the matching ID user_file = ( db_session.query(UserFile).filter(UserFile.id == user_file_id).first() ) # If found, add it to the list if user_file is not None: user_files.append(user_file) # Return the collected UserFile database objects return user_files def validate_user_files_ownership( user_file_ids: list[UUID], user_id: UUID | None, db_session: Session, ) -> list[UserFile]: """ Fetches the UserFile records for the given IDs and verifies that each one belongs to the given user. Raises ValueError if any file belongs to a different user; IDs with no matching record are skipped.
""" user_files = get_user_files(user_file_ids, db_session) current_user_files = [] for user_file in user_files: # Note: if user_id is None, then all files should be None as well # (since auth must be disabled in this case) if user_file.user_id != user_id: raise ValueError( f"User {user_id} does not have access to file {user_file.id}" ) current_user_files.append(user_file) return current_user_files def save_file_from_url(url: str) -> str: response = requests.get(url) response.raise_for_status() file_io = BytesIO(response.content) file_store = get_default_file_store() file_id = file_store.save_file( content=file_io, display_name="GeneratedImage", file_origin=FileOrigin.CHAT_IMAGE_GEN, file_type="image/png;base64", ) return file_id def save_file_from_base64(base64_string: str) -> str: file_store = get_default_file_store() file_id = file_store.save_file( content=BytesIO(base64.b64decode(base64_string)), display_name="GeneratedImage", file_origin=FileOrigin.CHAT_IMAGE_GEN, file_type=get_image_type(base64_string), ) return file_id def save_file( url: str | None = None, base64_data: str | None = None, ) -> str: """Save a file from either a URL or base64 encoded string. 
Args: url: URL to download file from base64_data: Base64 encoded file data Returns: The unique ID of the saved file Raises: ValueError: If neither url nor base64_data is provided, or if both are provided """ if url is not None and base64_data is not None: raise ValueError("Cannot specify both url and base64_data") if url is not None: return save_file_from_url(url) elif base64_data is not None: return save_file_from_base64(base64_data) else: raise ValueError("Must specify either url or base64_data") def save_files(urls: list[str], base64_files: list[str]) -> list[str]: # NOTE: be explicit about typing so that if we change things, we get notified funcs: list[ tuple[ Callable[[str | None, str | None], str], tuple[str | None, str | None], ] ] = [(save_file, (url, None)) for url in urls] + [ (save_file, (None, base64_file)) for base64_file in base64_files ] return run_functions_tuples_in_parallel(funcs) @log_function_time(print_only=True) def verify_user_files( user_files: list[FileDescriptor], user_id: UUID | None, db_session: Session, project_id: int | None = None, ) -> None: """ Verify that all provided file descriptors belong to the specified user. For project files (those without user_file_id), verifies access through project ownership. 
Args: user_files: List of file descriptors to verify user_id: The user ID to check ownership against db_session: The SQLAlchemy database session project_id: Optional project ID to verify project file access against Raises: ValueError: If any file does not belong to the user or is not found """ from onyx.db.models import Project__UserFile from onyx.db.projects import check_project_ownership # Extract user_file_ids and project file_ids from the file descriptors user_file_ids = [] project_file_ids = [] for file_descriptor in user_files: # Check if this file descriptor has a user_file_id if file_descriptor.get("user_file_id"): try: user_file_ids.append(UUID(file_descriptor["user_file_id"])) except (ValueError, TypeError): logger.warning( f"Invalid user_file_id in file descriptor: {file_descriptor['user_file_id']}" ) continue else: # This is a project file - use the 'id' field which is the file_id if file_descriptor.get("id"): project_file_ids.append(file_descriptor["id"]) # Verify user files (existing logic) if user_file_ids: validate_user_files_ownership(user_file_ids, user_id, db_session) # Verify project files if project_file_ids: if project_id is None: raise ValueError( "Project files provided but no project_id specified for verification" ) # Verify user owns the project if not check_project_ownership(project_id, user_id, db_session): raise ValueError( f"User {user_id} does not have access to project {project_id}" ) # Verify all project files belong to the specified project user_files_in_project = ( db_session.query(UserFile) .join(Project__UserFile) .filter( Project__UserFile.project_id == project_id, UserFile.file_id.in_(project_file_ids), ) .all() ) # Check if all files were found in the project found_file_ids = {uf.file_id for uf in user_files_in_project} missing_files = set(project_file_ids) - found_file_ids if missing_files: raise ValueError( f"Files {missing_files} are not associated with project {project_id}" ) def build_frontend_file_url(file_id: str) -> 
str: return f"/api/chat/file/{file_id}" def build_full_frontend_file_url(file_id: str) -> str: return f"{WEB_DOMAIN}/api/chat/file/{file_id}" ================================================ FILE: backend/onyx/hooks/__init__.py ================================================ ================================================ FILE: backend/onyx/hooks/api_dependencies.py ================================================ from onyx.error_handling.error_codes import OnyxErrorCode from onyx.error_handling.exceptions import OnyxError from shared_configs.configs import MULTI_TENANT def require_hook_enabled() -> None: """FastAPI dependency that gates all hook management endpoints. Hooks are only available in single-tenant / self-hosted EE deployments. Use as: Depends(require_hook_enabled) """ if MULTI_TENANT: raise OnyxError( OnyxErrorCode.SINGLE_TENANT_ONLY, "Hooks are not available in multi-tenant deployments", ) ================================================ FILE: backend/onyx/hooks/executor.py ================================================ """CE hook executor. HookSkipped and HookSoftFailed are real classes kept here because process_message.py (CE code) uses isinstance checks against them. execute_hook is the public entry point. 
It dispatches to _execute_hook_impl via fetch_versioned_implementation so that: - CE: onyx.hooks.executor._execute_hook_impl → no-op, returns HookSkipped() - EE: ee.onyx.hooks.executor._execute_hook_impl → real HTTP call """ from typing import Any from typing import TypeVar from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.db.enums import HookPoint from onyx.utils.variable_functionality import fetch_versioned_implementation class HookSkipped: """No active hook configured for this hook point.""" class HookSoftFailed: """Hook was called but failed with SOFT fail strategy — continuing.""" T = TypeVar("T", bound=BaseModel) def _execute_hook_impl( *, db_session: Session, # noqa: ARG001 hook_point: HookPoint, # noqa: ARG001 payload: dict[str, Any], # noqa: ARG001 response_type: type[T], # noqa: ARG001 ) -> T | HookSkipped | HookSoftFailed: """CE no-op — hooks are not available without EE.""" return HookSkipped() def execute_hook( *, db_session: Session, hook_point: HookPoint, payload: dict[str, Any], response_type: type[T], ) -> T | HookSkipped | HookSoftFailed: """Execute the hook for the given hook point. Dispatches to the versioned implementation so EE gets the real executor and CE gets the no-op stub, without any changes at the call site. 
""" impl = fetch_versioned_implementation("onyx.hooks.executor", "_execute_hook_impl") return impl( db_session=db_session, hook_point=hook_point, payload=payload, response_type=response_type, ) ================================================ FILE: backend/onyx/hooks/models.py ================================================ from datetime import datetime from enum import Enum from typing import Annotated from typing import Any from pydantic import BaseModel from pydantic import Field from pydantic import field_validator from pydantic import model_validator from pydantic import SecretStr from onyx.db.enums import HookFailStrategy from onyx.db.enums import HookPoint NonEmptySecretStr = Annotated[SecretStr, Field(min_length=1)] # --------------------------------------------------------------------------- # Request models # --------------------------------------------------------------------------- class HookCreateRequest(BaseModel): name: str = Field(min_length=1) hook_point: HookPoint endpoint_url: str = Field(min_length=1) api_key: NonEmptySecretStr | None = None fail_strategy: HookFailStrategy | None = None # if None, uses HookPointSpec default timeout_seconds: float | None = Field( default=None, gt=0 ) # if None, uses HookPointSpec default @field_validator("name", "endpoint_url") @classmethod def no_whitespace_only(cls, v: str) -> str: if not v.strip(): raise ValueError("cannot be whitespace-only.") return v class HookUpdateRequest(BaseModel): name: str | None = None endpoint_url: str | None = None api_key: NonEmptySecretStr | None = None fail_strategy: HookFailStrategy | None = None timeout_seconds: float | None = Field(default=None, gt=0) @model_validator(mode="after") def require_at_least_one_field(self) -> "HookUpdateRequest": if not self.model_fields_set: raise ValueError("At least one field must be provided for an update.") if "name" in self.model_fields_set and not (self.name or "").strip(): raise ValueError("name cannot be cleared.") if ( "endpoint_url" in 
self.model_fields_set and not (self.endpoint_url or "").strip() ): raise ValueError("endpoint_url cannot be cleared.") if "fail_strategy" in self.model_fields_set and self.fail_strategy is None: raise ValueError( "fail_strategy cannot be null; omit the field to leave it unchanged." ) if "timeout_seconds" in self.model_fields_set and self.timeout_seconds is None: raise ValueError( "timeout_seconds cannot be null; omit the field to leave it unchanged." ) return self # --------------------------------------------------------------------------- # Response models # --------------------------------------------------------------------------- class HookPointMetaResponse(BaseModel): hook_point: HookPoint display_name: str description: str docs_url: str | None input_schema: dict[str, Any] output_schema: dict[str, Any] default_timeout_seconds: float default_fail_strategy: HookFailStrategy fail_hard_description: str class HookResponse(BaseModel): id: int name: str hook_point: HookPoint # Nullable to match the DB column — endpoint_url is required on creation but # future hook point types may not use an external endpoint (e.g. built-in handlers). endpoint_url: str | None # Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set. 
api_key_masked: str | None fail_strategy: HookFailStrategy timeout_seconds: float # always resolved: None from request is replaced with spec default before DB write is_active: bool is_reachable: bool | None creator_email: str | None created_at: datetime updated_at: datetime class HookValidateStatus(str, Enum): passed = "passed" # server responded (any status except 401/403) auth_failed = "auth_failed" # server responded with 401 or 403 timeout = ( "timeout" # TCP connected, but read/write timed out (server exists but slow) ) cannot_connect = "cannot_connect" # could not connect to the server class HookValidateResponse(BaseModel): status: HookValidateStatus error_message: str | None = None class HookExecutionRecord(BaseModel): error_message: str | None = None status_code: int | None = None duration_ms: int | None = None created_at: datetime ================================================ FILE: backend/onyx/hooks/points/__init__.py ================================================ ================================================ FILE: backend/onyx/hooks/points/base.py ================================================ from typing import Any from typing import ClassVar from pydantic import BaseModel from onyx.db.enums import HookFailStrategy from onyx.db.enums import HookPoint _REQUIRED_ATTRS = ( "hook_point", "display_name", "description", "default_timeout_seconds", "fail_hard_description", "default_fail_strategy", "payload_model", "response_model", ) class HookPointSpec: """Static metadata and contract for a pipeline hook point. Each concrete subclass represents exactly one hook point and is instantiated once at startup and registered in onyx.hooks.registry._REGISTRY. Prefer get_hook_point_spec() or get_all_specs() from the registry over direct instantiation. Onyx engineers own these definitions; customers never touch this code. Subclasses must define all attributes as class-level constants.
    payload_model and response_model must be Pydantic BaseModel subclasses;
    input_schema and output_schema are derived from them automatically.
    """

    hook_point: HookPoint
    display_name: str
    description: str
    default_timeout_seconds: float
    fail_hard_description: str
    default_fail_strategy: HookFailStrategy
    docs_url: str | None = None
    payload_model: ClassVar[type[BaseModel]]
    response_model: ClassVar[type[BaseModel]]

    # Computed once at class definition time from payload_model / response_model.
    input_schema: ClassVar[dict[str, Any]]
    output_schema: ClassVar[dict[str, Any]]

    def __init_subclass__(cls, **kwargs: object) -> None:
        """Enforce that every subclass declares all required class attributes.

        Called automatically by Python whenever a class inherits from HookPointSpec.
        Raises TypeError at import time if any required attribute is missing or if
        payload_model / response_model are not Pydantic BaseModel subclasses.

        input_schema and output_schema are derived automatically from the models.
        """
        super().__init_subclass__(**kwargs)
        missing = [attr for attr in _REQUIRED_ATTRS if not hasattr(cls, attr)]
        if missing:
            raise TypeError(f"{cls.__name__} must define class attributes: {missing}")
        for attr in ("payload_model", "response_model"):
            val = getattr(cls, attr, None)
            if val is None or not (
                isinstance(val, type) and issubclass(val, BaseModel)
            ):
                raise TypeError(
                    f"{cls.__name__}.{attr} must be a Pydantic BaseModel subclass, got {val!r}"
                )
        cls.input_schema = cls.payload_model.model_json_schema()
        cls.output_schema = cls.response_model.model_json_schema()

================================================
FILE: backend/onyx/hooks/points/document_ingestion.py
================================================
from pydantic import BaseModel
from pydantic import Field

from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec


class DocumentIngestionSection(BaseModel):
    """Represents a single section of a document — either text or
    image, not both.

    Text section: set `text`, leave `image_file_id` null.
    Image section: set `image_file_id`, leave `text` null.
    """

    text: str | None = Field(
        default=None,
        description="Text content of this section. Set for text sections, null for image sections.",
    )
    link: str | None = Field(
        default=None,
        description="Optional URL associated with this section. Preserve the original link from the payload if you want it retained.",
    )
    image_file_id: str | None = Field(
        default=None,
        description=(
            "Opaque identifier for an image stored in the file store. "
            "The image content is not included — this field signals that the section is an image. "
            "Hooks can use its presence to reorder or drop image sections, but cannot read or modify the image itself."
        ),
    )


class DocumentIngestionOwner(BaseModel):
    display_name: str | None = Field(
        default=None,
        description="Human-readable name of the owner.",
    )
    email: str | None = Field(
        default=None,
        description="Email address of the owner.",
    )


class DocumentIngestionPayload(BaseModel):
    document_id: str = Field(
        description="Unique identifier for the document. Read-only — changes are ignored."
    )
    title: str | None = Field(description="Title of the document.")
    semantic_identifier: str = Field(
        description="Human-readable identifier used for display (e.g. file name, page title)."
    )
    source: str = Field(
        description=(
            "Connector source type (e.g. confluence, slack, google_drive). "
            "Read-only — changes are ignored. "
            "Full list of values: https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/configs/constants.py#L195"
        )
    )
    sections: list[DocumentIngestionSection] = Field(
        description="Sections of the document. Includes both text sections (text set, image_file_id null) and image sections (image_file_id set, text null)."
    )
    metadata: dict[str, list[str]] = Field(
        description="Key-value metadata attached to the document. Values are always a list of strings."
    )
    doc_updated_at: str | None = Field(
        description="ISO 8601 UTC timestamp of the last update at the source, or null if unknown. Example: '2024-03-15T10:30:00+00:00'."
    )
    primary_owners: list[DocumentIngestionOwner] | None = Field(
        description="Primary owners of the document, or null if not available."
    )
    secondary_owners: list[DocumentIngestionOwner] | None = Field(
        description="Secondary owners of the document, or null if not available."
    )


class DocumentIngestionResponse(BaseModel):
    # Intentionally permissive — customer endpoints may return extra fields.
    sections: list[DocumentIngestionSection] | None = Field(
        description="The sections to index, in the desired order. Reorder, drop, or modify sections freely. Null or empty list drops the document."
    )
    rejection_reason: str | None = Field(
        default=None,
        description="Logged when sections is null or empty. Falls back to a generic message if omitted.",
    )


class DocumentIngestionSpec(HookPointSpec):
    """Hook point that runs on every document before it enters the indexing pipeline.

    Call site: immediately after Onyx's internal validation and before the
    indexing pipeline begins — no partial writes have occurred yet.

    If a Document Ingestion hook is configured, it takes precedence —
    Document Ingestion Light will not run. Configure only one per deployment.

    Supported use cases:
    - Document filtering: drop documents based on content or metadata
    - Content rewriting: redact PII or normalize text before indexing
    """

    hook_point = HookPoint.DOCUMENT_INGESTION
    display_name = "Document Ingestion"
    description = (
        "Runs on every document before it enters the indexing pipeline. "
        "Allows filtering, rewriting, or dropping documents."
    )
    default_timeout_seconds = 30.0
    fail_hard_description = "The document will not be indexed."
    default_fail_strategy = HookFailStrategy.HARD
    docs_url = "https://docs.onyx.app/admins/advanced_configs/hook_extensions#document-ingestion"
    payload_model = DocumentIngestionPayload
    response_model = DocumentIngestionResponse

================================================
FILE: backend/onyx/hooks/points/query_processing.py
================================================
from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from onyx.db.enums import HookFailStrategy
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec


class QueryProcessingPayload(BaseModel):
    model_config = ConfigDict(extra="forbid")

    query: str = Field(description="The raw query string exactly as the user typed it.")
    user_email: str | None = Field(
        description="Email of the user submitting the query, or null if unauthenticated."
    )
    chat_session_id: str = Field(
        description="UUID of the chat session, formatted as a hyphenated lowercase string (e.g. '550e8400-e29b-41d4-a716-446655440000'). Always present — the session is guaranteed to exist by the time this hook fires."
    )


class QueryProcessingResponse(BaseModel):
    # Intentionally permissive — customer endpoints may return extra fields.
    query: str | None = Field(
        default=None,
        description=(
            "The query to use in the pipeline. "
            "Null, empty string, whitespace-only, or absent = reject the query."
        ),
    )
    rejection_message: str | None = Field(
        default=None,
        description="Message shown to the user when the query is rejected. Falls back to a generic message if not provided.",
    )


class QueryProcessingSpec(HookPointSpec):
    """Hook point that runs on every user query before it enters the pipeline.

    Call site: inside handle_stream_message_objects() in
    backend/onyx/chat/process_message.py, immediately after message_text is
    assigned from the request and before create_new_chat_message() saves it.

    This is the earliest possible point in the query pipeline:
    - Raw query — unmodified, exactly as the user typed it
    - No side effects yet — message has not been saved to DB
    - User identity is available for user-specific logic

    Supported use cases:
    - Query rejection: block queries based on content or user context
    - Query rewriting: normalize, expand, or modify the query
    - PII removal: scrub sensitive data before the LLM sees it
    - Access control: reject queries from certain users or groups
    - Query auditing: log or track queries based on business rules
    """

    hook_point = HookPoint.QUERY_PROCESSING
    display_name = "Query Processing"
    description = (
        "Runs on every user query before it enters the pipeline. "
        "Allows rewriting, filtering, or rejecting queries."
    )
    default_timeout_seconds = 5.0  # user is actively waiting — keep tight
    fail_hard_description = (
        "The query will be blocked and the user will see an error message."
    )
    default_fail_strategy = HookFailStrategy.HARD
    docs_url = (
        "https://docs.onyx.app/admins/advanced_configs/hook_extensions#query-processing"
    )
    payload_model = QueryProcessingPayload
    response_model = QueryProcessingResponse

================================================
FILE: backend/onyx/hooks/registry.py
================================================
from onyx.db.enums import HookPoint
from onyx.hooks.points.base import HookPointSpec
from onyx.hooks.points.document_ingestion import DocumentIngestionSpec
from onyx.hooks.points.query_processing import QueryProcessingSpec

# Internal: use `monkeypatch.setattr(registry_module, "_REGISTRY", {...})` to override in tests.
_REGISTRY: dict[HookPoint, HookPointSpec] = {
    HookPoint.DOCUMENT_INGESTION: DocumentIngestionSpec(),
    HookPoint.QUERY_PROCESSING: QueryProcessingSpec(),
}


def validate_registry() -> None:
    """Assert that every HookPoint enum value has a registered spec.

    Call once at application startup (e.g. from the FastAPI lifespan hook).
    Raises RuntimeError if any hook point is missing a spec.
    """
    missing = set(HookPoint) - set(_REGISTRY)
    if missing:
        raise RuntimeError(
            f"Hook point(s) have no registered spec: {missing}. "
            "Add an entry to onyx.hooks.registry._REGISTRY."
        )


def get_hook_point_spec(hook_point: HookPoint) -> HookPointSpec:
    """Returns the spec for a given hook point.

    Raises ValueError if the hook point has no registered spec — this is a
    programmer error; every HookPoint enum value must have a corresponding spec
    in _REGISTRY.
    """
    try:
        return _REGISTRY[hook_point]
    except KeyError:
        raise ValueError(
            f"No spec registered for hook point {hook_point!r}. "
            "Add an entry to onyx.hooks.registry._REGISTRY."
        )


def get_all_specs() -> list[HookPointSpec]:
    """Returns the specs for all registered hook points."""
    return list(_REGISTRY.values())

================================================
FILE: backend/onyx/httpx/httpx_pool.py
================================================
import threading
from typing import Any

import httpx


def make_default_kwargs() -> dict[str, Any]:
    return {
        "http2": True,
        "limits": httpx.Limits(),
    }


class HttpxPool:
    """Class to manage a global httpx Client instance"""

    _clients: dict[str, httpx.Client] = {}
    _lock: threading.Lock = threading.Lock()

    # Default parameters for creation
    def __init__(self) -> None:
        pass

    @classmethod
    def _init_client(cls, **kwargs: Any) -> httpx.Client:
        """Private helper method to create and return an httpx.Client."""
        merged_kwargs = {**(make_default_kwargs()), **kwargs}
        return httpx.Client(**merged_kwargs)

    @classmethod
    def init_client(cls, name: str, **kwargs: Any) -> None:
        """Allow the caller to init the client with extra params."""
        with cls._lock:
            if name not in cls._clients:
                cls._clients[name] = cls._init_client(**kwargs)

    @classmethod
    def close_client(cls, name: str) -> None:
        """Allow the caller to close the client."""
        with cls._lock:
            client = cls._clients.pop(name, None)
            if client:
                client.close()

    @classmethod
    def close_all(cls) -> None:
        """Close all registered clients."""
        with cls._lock:
            for client in cls._clients.values():
                client.close()
            cls._clients.clear()

    @classmethod
    def get(cls, name: str) -> httpx.Client:
        """Gets the httpx.Client. Will init to default settings if not init'd."""
        with cls._lock:
            if name not in cls._clients:
                cls._clients[name] = cls._init_client()
            return cls._clients[name]

================================================
FILE: backend/onyx/image_gen/__init__.py
================================================

================================================
FILE: backend/onyx/image_gen/exceptions.py
================================================
class ImageProviderError(Exception):
    pass


class ImageProviderCredentialsError(ImageProviderError):
    pass

================================================
FILE: backend/onyx/image_gen/factory.py
================================================
from enum import Enum

from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.providers.azure_img_gen import AzureImageGenerationProvider
from onyx.image_gen.providers.openai_img_gen import OpenAIImageGenerationProvider
from onyx.image_gen.providers.vertex_img_gen import VertexImageGenerationProvider


class ImageGenerationProviderName(str, Enum):
    AZURE = "azure"
    OPENAI = "openai"
    VERTEX_AI = "vertex_ai"


PROVIDERS: dict[ImageGenerationProviderName, type[ImageGenerationProvider]] = {
    ImageGenerationProviderName.AZURE: AzureImageGenerationProvider,
    ImageGenerationProviderName.OPENAI: OpenAIImageGenerationProvider,
    ImageGenerationProviderName.VERTEX_AI: VertexImageGenerationProvider,
}


def get_image_generation_provider(
    provider: str,
    credentials: ImageGenerationProviderCredentials,
) -> ImageGenerationProvider:
    provider_cls = _get_provider_cls(provider)
    return provider_cls.build_from_credentials(credentials)


def validate_credentials(
    provider: str,
    credentials: ImageGenerationProviderCredentials,
) -> bool:
    provider_cls = _get_provider_cls(provider)
    return provider_cls.validate_credentials(credentials)


def _get_provider_cls(provider: str) -> type[ImageGenerationProvider]:
    try:
        provider_enum = ImageGenerationProviderName(provider)
    except ValueError:
        raise ValueError(f"Invalid image generation provider: {provider}")
    return PROVIDERS[provider_enum]

================================================
FILE: backend/onyx/image_gen/interfaces.py
================================================
from __future__ import annotations

import abc
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel

from onyx.image_gen.exceptions import ImageProviderCredentialsError

if TYPE_CHECKING:
    from litellm.types.utils import ImageResponse as ImageGenerationResponse


class ImageGenerationProviderCredentials(BaseModel):
    api_key: str | None = None
    api_base: str | None = None
    api_version: str | None = None
    deployment_name: str | None = None
    custom_config: dict[str, str] | None = None


class ReferenceImage(BaseModel):
    data: bytes
    mime_type: str


class ImageGenerationProvider(abc.ABC):
    @property
    def supports_reference_images(self) -> bool:
        return False

    @property
    def max_reference_images(self) -> int:
        return 0

    @classmethod
    @abc.abstractmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        """Returns true if sufficient credentials are given to build this provider."""
        raise NotImplementedError("validate_credentials not implemented")

    @classmethod
    def build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> ImageGenerationProvider:
        if not cls.validate_credentials(credentials):
            # Do not interpolate the credentials object: its string form
            # includes the raw api_key and would leak it into logs.
            raise ImageProviderCredentialsError(
                f"Invalid image generation credentials for {cls.__name__}"
            )
        return cls._build_from_credentials(credentials)

    @classmethod
    @abc.abstractmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> ImageGenerationProvider:
        """
        Given credentials, builds an instance of the provider.
        Should NOT be called directly - use build_from_credentials instead.
        Raises AssertionError if credentials are invalid.
        """
        raise NotImplementedError("_build_from_credentials not implemented")

    @abc.abstractmethod
    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        reference_images: list[ReferenceImage] | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        """Generates an image based on a prompt."""
        raise NotImplementedError("generate_image not implemented")

================================================
FILE: backend/onyx/image_gen/providers/azure_img_gen.py
================================================
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING

from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage

if TYPE_CHECKING:
    from onyx.image_gen.interfaces import ImageGenerationResponse


class AzureImageGenerationProvider(ImageGenerationProvider):
    _GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
    _DALL_E_2_MODEL_NAME = "dall-e-2"

    def __init__(
        self,
        api_key: str,
        api_base: str,
        api_version: str,
        deployment_name: str | None = None,
    ):
        self._api_key = api_key
        self._api_base = api_base
        self._api_version = api_version
        self._deployment_name = deployment_name

    @classmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        return all(
            [
                credentials.api_key,
                credentials.api_base,
                credentials.api_version,
            ]
        )

    @classmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> AzureImageGenerationProvider:
        assert credentials.api_key
        assert credentials.api_base
        assert credentials.api_version

        return cls(
            api_key=credentials.api_key,
            api_base=credentials.api_base,
            api_version=credentials.api_version,
            deployment_name=credentials.deployment_name,
        )

    @property
    def supports_reference_images(self) -> bool:
        return True

    @property
    def max_reference_images(self) -> int:
        # Azure GPT image models support up to 16 input images for edits.
        return 16

    def _normalize_model_name(self, model: str) -> str:
        return model.rsplit("/", 1)[-1]

    def _model_supports_image_edits(self, model: str) -> bool:
        normalized_model = self._normalize_model_name(model)
        return (
            normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
            or normalized_model == self._DALL_E_2_MODEL_NAME
        )

    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        reference_images: list[ReferenceImage] | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        deployment = self._deployment_name or model
        model_name = f"azure/{deployment}"

        if reference_images:
            if not self._model_supports_image_edits(model):
                raise ValueError(
                    f"Model '{model}' does not support image edits with reference images."
                )

            normalized_model = self._normalize_model_name(model)
            if (
                normalized_model == self._DALL_E_2_MODEL_NAME
                and len(reference_images) > 1
            ):
                raise ValueError(
                    "Model 'dall-e-2' only supports a single reference image for edits."
                )

            from litellm import image_edit

            return image_edit(
                image=[image.data for image in reference_images],
                prompt=prompt,
                model=model_name,
                api_key=self._api_key,
                api_base=self._api_base,
                api_version=self._api_version,
                size=size,
                n=n,
                quality=quality,
                **kwargs,
            )

        from litellm import image_generation

        return image_generation(
            prompt=prompt,
            model=model_name,
            api_key=self._api_key,
            api_base=self._api_base,
            api_version=self._api_version,
            size=size,
            n=n,
            quality=quality,
            **kwargs,
        )

================================================
FILE: backend/onyx/image_gen/providers/openai_img_gen.py
================================================
from __future__ import annotations

from typing import Any
from typing import TYPE_CHECKING

from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage

if TYPE_CHECKING:
    from onyx.image_gen.interfaces import ImageGenerationResponse


class OpenAIImageGenerationProvider(ImageGenerationProvider):
    _GPT_IMAGE_MODEL_PREFIX = "gpt-image-"
    _DALL_E_2_MODEL_NAME = "dall-e-2"

    def __init__(
        self,
        api_key: str,
        api_base: str | None = None,
    ):
        self._api_key = api_key
        self._api_base = api_base

    @classmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        return bool(credentials.api_key)

    @classmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> OpenAIImageGenerationProvider:
        assert credentials.api_key

        return cls(
            api_key=credentials.api_key,
            api_base=credentials.api_base,
        )

    @property
    def supports_reference_images(self) -> bool:
        return True

    @property
    def max_reference_images(self) -> int:
        # GPT image models support up to 16 input images for edits.
        return 16

    def _normalize_model_name(self, model: str) -> str:
        return model.rsplit("/", 1)[-1]

    def _model_supports_image_edits(self, model: str) -> bool:
        normalized_model = self._normalize_model_name(model)
        return (
            normalized_model.startswith(self._GPT_IMAGE_MODEL_PREFIX)
            or normalized_model == self._DALL_E_2_MODEL_NAME
        )

    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        reference_images: list[ReferenceImage] | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        if reference_images:
            if not self._model_supports_image_edits(model):
                raise ValueError(
                    f"Model '{model}' does not support image edits with reference images."
                )

            normalized_model = self._normalize_model_name(model)
            if (
                normalized_model == self._DALL_E_2_MODEL_NAME
                and len(reference_images) > 1
            ):
                raise ValueError(
                    "Model 'dall-e-2' only supports a single reference image for edits."
                )

            from litellm import image_edit

            return image_edit(
                image=[image.data for image in reference_images],
                prompt=prompt,
                model=model,
                api_key=self._api_key,
                api_base=self._api_base,
                size=size,
                n=n,
                quality=quality,
                **kwargs,
            )

        from litellm import image_generation

        return image_generation(
            prompt=prompt,
            model=model,
            api_key=self._api_key,
            api_base=self._api_base,
            size=size,
            n=n,
            quality=quality,
            **kwargs,
        )

================================================
FILE: backend/onyx/image_gen/providers/vertex_img_gen.py
================================================
from __future__ import annotations

import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING

from pydantic import BaseModel

from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage

if TYPE_CHECKING:
    from onyx.image_gen.interfaces import ImageGenerationResponse


class VertexCredentials(BaseModel):
    vertex_credentials: str
    vertex_location: str
    project_id: str


class VertexImageGenerationProvider(ImageGenerationProvider):
    def __init__(
        self,
        vertex_credentials: VertexCredentials,
    ):
        self._vertex_credentials = vertex_credentials.vertex_credentials
        self._vertex_location = vertex_credentials.vertex_location
        self._vertex_project = vertex_credentials.project_id

    @classmethod
    def validate_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> bool:
        try:
            _parse_to_vertex_credentials(credentials)
            return True
        except ImageProviderCredentialsError:
            return False

    @classmethod
    def _build_from_credentials(
        cls,
        credentials: ImageGenerationProviderCredentials,
    ) -> VertexImageGenerationProvider:
        vertex_credentials = _parse_to_vertex_credentials(credentials)

        return cls(
            vertex_credentials=vertex_credentials,
        )

    @property
    def supports_reference_images(self) -> bool:
        return True

    @property
    def max_reference_images(self) -> int:
        # Gemini image editing supports up to 14 input images.
        return 14

    def generate_image(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        quality: str | None = None,
        reference_images: list[ReferenceImage] | None = None,
        **kwargs: Any,
    ) -> ImageGenerationResponse:
        if reference_images:
            return self._generate_image_with_reference_images(
                prompt=prompt,
                model=model,
                size=size,
                n=n,
                reference_images=reference_images,
            )

        from litellm import image_generation

        return image_generation(
            prompt=prompt,
            model=model,
            size=size,
            n=n,
            quality=quality,
            vertex_location=self._vertex_location,
            vertex_credentials=self._vertex_credentials,
            vertex_project=self._vertex_project,
            **kwargs,
        )

    def _generate_image_with_reference_images(
        self,
        prompt: str,
        model: str,
        size: str,
        n: int,
        reference_images: list[ReferenceImage],
    ) -> ImageGenerationResponse:
        from google import genai
        from google.genai import types as genai_types
        from google.oauth2 import service_account
        from litellm.types.utils import ImageObject
        from litellm.types.utils import ImageResponse

        service_account_info = json.loads(self._vertex_credentials)
        credentials = service_account.Credentials.from_service_account_info(
            service_account_info,
            scopes=["https://www.googleapis.com/auth/cloud-platform"],
        )

        client = genai.Client(
            vertexai=True,
            project=self._vertex_project,
            location=self._vertex_location,
            credentials=credentials,
        )

        parts: list[genai_types.Part] = [
            genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
            for image in reference_images
        ]
        parts.append(genai_types.Part.from_text(text=prompt))

        config = genai_types.GenerateContentConfig(
            response_modalities=["TEXT", "IMAGE"],
            candidate_count=max(1, n),
            image_config=genai_types.ImageConfig(
                aspect_ratio=_map_size_to_aspect_ratio(size)
            ),
        )
        model_name = model.replace("vertex_ai/", "")
        response = client.models.generate_content(
            model=model_name,
            contents=genai_types.Content(
                role="user",
                parts=parts,
            ),
            config=config,
        )

        generated_data: list[ImageObject] = []
        for candidate in response.candidates or []:
            candidate_content = candidate.content
            if not candidate_content:
                continue

            for part in candidate_content.parts or []:
                inline_data = part.inline_data
                if not inline_data or inline_data.data is None:
                    continue

                if isinstance(inline_data.data, bytes):
                    b64_json = base64.b64encode(inline_data.data).decode("utf-8")
                elif isinstance(inline_data.data, str):
                    b64_json = inline_data.data
                else:
                    continue

                generated_data.append(
                    ImageObject(
                        b64_json=b64_json,
                        revised_prompt=prompt,
                    )
                )

        if not generated_data:
            raise RuntimeError("No image data returned from Vertex AI.")

        return ImageResponse(
            created=int(datetime.now().timestamp()),
            data=generated_data,
        )


def _map_size_to_aspect_ratio(size: str) -> str:
    return {
        "1024x1024": "1:1",
        "1792x1024": "16:9",
        "1024x1792": "9:16",
        "1536x1024": "3:2",
        "1024x1536": "2:3",
    }.get(size, "1:1")


def _parse_to_vertex_credentials(
    credentials: ImageGenerationProviderCredentials,
) -> VertexCredentials:
    custom_config = credentials.custom_config

    if not custom_config:
        raise ImageProviderCredentialsError("Custom config is required")

    vertex_credentials = custom_config.get("vertex_credentials")
    vertex_location = custom_config.get("vertex_location")

    if not vertex_credentials:
        raise ImageProviderCredentialsError("Vertex credentials are required")

    if not vertex_location:
        raise ImageProviderCredentialsError("Vertex location is required")

    # Report malformed JSON as a credentials error so that validate_credentials()
    # returns False instead of propagating a JSONDecodeError / AttributeError.
    try:
        vertex_json = json.loads(vertex_credentials)
    except json.JSONDecodeError as e:
        raise ImageProviderCredentialsError(
            "Vertex credentials must be valid JSON"
        ) from e
    if not isinstance(vertex_json, dict):
        raise ImageProviderCredentialsError("Vertex credentials must be a JSON object")

    vertex_project = vertex_json.get("project_id")

    if not vertex_project:
        raise ImageProviderCredentialsError("Project ID is required")

    return VertexCredentials(
        vertex_credentials=vertex_credentials,
        vertex_location=vertex_location,
        project_id=vertex_project,
    )

================================================
FILE: backend/onyx/indexing/__init__.py
================================================

================================================
FILE: backend/onyx/indexing/adapters/document_indexing_adapter.py
================================================
import contextlib
from collections.abc import Generator

from sqlalchemy.engine.util import TransactionalContext
from sqlalchemy.orm import Session

from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DEFAULT_BOOST
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.db.chunk import update_chunk_boost_components__no_commit
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
from onyx.db.document import prepare_to_modify_documents
from onyx.db.document import update_docs_chunk_count__no_commit
from onyx.db.document import update_docs_last_modified__no_commit
from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.indexing_pipeline import index_doc_batch_prepare
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.redis.redis_hierarchy import get_ancestors_from_raw_id
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger

logger = setup_logger()


class DocumentIndexingBatchAdapter:
    """Default adapter: handles DB prep, locking, metadata enrichment, and finalize.

    Keeps orchestration logic in the pipeline and side-effects in the adapter.
    """

    def __init__(
        self,
        db_session: Session,
        connector_id: int,
        credential_id: int,
        tenant_id: str,
        index_attempt_metadata: IndexAttemptMetadata,
    ):
        self.db_session = db_session
        self.connector_id = connector_id
        self.credential_id = credential_id
        self.tenant_id = tenant_id
        self.index_attempt_metadata = index_attempt_metadata

    def prepare(
        self, documents: list[Document], ignore_time_skip: bool
    ) -> DocumentBatchPrepareContext | None:
        """Upsert docs, map CC pairs, return context or mark as indexed if no-op."""
        context = index_doc_batch_prepare(
            documents=documents,
            index_attempt_metadata=self.index_attempt_metadata,
            db_session=self.db_session,
            ignore_time_skip=ignore_time_skip,
        )

        if not context:
            # even though we didn't actually index anything, we should still
            # mark them as "completed" for the CC Pair in order to make the
            # counts match
            mark_document_as_indexed_for_cc_pair__no_commit(
                connector_id=self.index_attempt_metadata.connector_id,
                credential_id=self.index_attempt_metadata.credential_id,
                document_ids=[doc.id for doc in documents],
                db_session=self.db_session,
            )
            self.db_session.commit()

        return context

    @contextlib.contextmanager
    def lock_context(
        self, documents: list[Document]
    ) -> Generator[TransactionalContext, None, None]:
        """Acquire transaction/row locks on docs for the critical section."""
        with prepare_to_modify_documents(
            db_session=self.db_session, document_ids=[doc.id for doc in documents]
        ) as transaction:
            yield transaction

    def prepare_enrichment(
        self,
        context: DocumentBatchPrepareContext,
        tenant_id: str,
        chunks: list[DocAwareChunk],
    ) -> "DocumentChunkEnricher":
        """Do all DB lookups once and return a per-chunk enricher."""
        updatable_ids = [doc.id for doc in context.updatable_docs]

        doc_id_to_new_chunk_cnt: dict[str, int] = {
            doc_id: 0 for doc_id in updatable_ids
        }
        for chunk in chunks:
            if chunk.source_document.id in doc_id_to_new_chunk_cnt:
                doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1

        no_access = DocumentAccess.build(
            user_emails=[],
            user_groups=[],
            external_user_emails=[],
            external_user_group_ids=[],
            is_public=False,
        )

        return DocumentChunkEnricher(
            doc_id_to_access_info=get_access_for_documents(
                document_ids=updatable_ids, db_session=self.db_session
            ),
            doc_id_to_document_set={
                document_id: document_sets
                for document_id, document_sets in fetch_document_sets_for_documents(
                    document_ids=updatable_ids, db_session=self.db_session
                )
            },
            doc_id_to_ancestor_ids=self._get_ancestor_ids_for_documents(
                context.updatable_docs, tenant_id
            ),
            id_to_boost_map=context.id_to_boost_map,
            doc_id_to_previous_chunk_cnt={
                document_id: chunk_count
                for document_id, chunk_count in fetch_chunk_counts_for_documents(
                    document_ids=updatable_ids,
                    db_session=self.db_session,
                )
            },
            doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
            no_access=no_access,
            tenant_id=tenant_id,
        )

    def _get_ancestor_ids_for_documents(
        self,
        documents: list[Document],
        tenant_id: str,
    ) -> dict[str, list[int]]:
        """
        Get ancestor hierarchy node IDs for a batch of documents.

        Uses Redis cache for fast lookups - no DB calls are made unless
        there's a cache miss. Documents provide parent_hierarchy_raw_node_id
        directly from the connector.

        Returns a mapping from document_id to list of ancestor node IDs.
""" if not documents: return {} redis_client = get_redis_client(tenant_id=tenant_id) result: dict[str, list[int]] = {} for doc in documents: # Use parent_hierarchy_raw_node_id directly from the document # If None, get_ancestors_from_raw_id will return just the SOURCE node ancestors = get_ancestors_from_raw_id( redis_client=redis_client, source=doc.source, parent_hierarchy_raw_node_id=doc.parent_hierarchy_raw_node_id, db_session=self.db_session, ) result[doc.id] = ancestors return result def post_index( self, context: DocumentBatchPrepareContext, updatable_chunk_data: list[UpdatableChunkData], filtered_documents: list[Document], enrichment: ChunkEnrichmentContext, ) -> None: """Finalize DB updates, store plaintext, and mark docs as indexed.""" updatable_ids = [doc.id for doc in context.updatable_docs] last_modified_ids = [] ids_to_new_updated_at = {} for doc in context.updatable_docs: last_modified_ids.append(doc.id) # doc_updated_at is the source's idea (on the other end of the connector) # of when the doc was last modified if doc.doc_updated_at is None: continue ids_to_new_updated_at[doc.id] = doc.doc_updated_at update_docs_updated_at__no_commit( ids_to_new_updated_at=ids_to_new_updated_at, db_session=self.db_session ) update_docs_last_modified__no_commit( document_ids=last_modified_ids, db_session=self.db_session ) update_docs_chunk_count__no_commit( document_ids=updatable_ids, doc_id_to_chunk_count=enrichment.doc_id_to_new_chunk_cnt, db_session=self.db_session, ) # these documents can now be counted as part of the CC Pairs # document count, so we need to mark them as indexed # NOTE: even documents we skipped since they were already up # to date should be counted here in order to maintain parity # between CC Pair and index attempt counts mark_document_as_indexed_for_cc_pair__no_commit( connector_id=self.index_attempt_metadata.connector_id, credential_id=self.index_attempt_metadata.credential_id, document_ids=[doc.id for doc in filtered_documents], 
            db_session=self.db_session,
        )

        # save the chunk boost components to postgres
        update_chunk_boost_components__no_commit(
            chunk_data=updatable_chunk_data, db_session=self.db_session
        )

        self.db_session.commit()


class DocumentChunkEnricher:
    """Pre-computed metadata for per-chunk enrichment of connector documents."""

    def __init__(
        self,
        doc_id_to_access_info: dict[str, DocumentAccess],
        doc_id_to_document_set: dict[str, list[str]],
        doc_id_to_ancestor_ids: dict[str, list[int]],
        id_to_boost_map: dict[str, int],
        doc_id_to_previous_chunk_cnt: dict[str, int],
        doc_id_to_new_chunk_cnt: dict[str, int],
        no_access: DocumentAccess,
        tenant_id: str,
    ) -> None:
        self._doc_id_to_access_info = doc_id_to_access_info
        self._doc_id_to_document_set = doc_id_to_document_set
        self._doc_id_to_ancestor_ids = doc_id_to_ancestor_ids
        self._id_to_boost_map = id_to_boost_map
        self._no_access = no_access
        self._tenant_id = tenant_id
        self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt
        self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt

    def enrich_chunk(
        self, chunk: IndexChunk, score: float
    ) -> DocMetadataAwareIndexChunk:
        return DocMetadataAwareIndexChunk.from_index_chunk(
            index_chunk=chunk,
            access=self._doc_id_to_access_info.get(
                chunk.source_document.id, self._no_access
            ),
            document_sets=set(
                self._doc_id_to_document_set.get(chunk.source_document.id, [])
            ),
            user_project=[],
            personas=[],
            boost=(
                self._id_to_boost_map[chunk.source_document.id]
                if chunk.source_document.id in self._id_to_boost_map
                else DEFAULT_BOOST
            ),
            tenant_id=self._tenant_id,
            aggregated_chunk_boost_factor=score,
            ancestor_hierarchy_node_ids=self._doc_id_to_ancestor_ids[
                chunk.source_document.id
            ],
        )

================================================
FILE: backend/onyx/indexing/adapters/user_file_indexing_adapter.py
================================================
from __future__ import annotations

import contextlib
import datetime
import time
from collections import defaultdict
from collections.abc import Generator
from uuid import UUID

from sqlalchemy import select
from sqlalchemy.exc import OperationalError
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from sqlalchemy.orm.session import TransactionalContext

from onyx.access.access import get_access_for_user_files
from onyx.access.models import DocumentAccess
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import NotificationType
from onyx.connectors.models import Document
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.db.notification import create_notification
from onyx.db.user_file import fetch_chunk_counts_for_user_files
from onyx.db.user_file import fetch_persona_ids_for_user_files
from onyx.db.user_file import fetch_user_project_ids_for_user_files
from onyx.file_store.utils import store_user_file_plaintext
from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext
from onyx.indexing.models import ChunkEnrichmentContext
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.llm.factory import get_default_llm
from onyx.natural_language_processing.utils import count_tokens
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.utils.logger import setup_logger

logger = setup_logger()

_NUM_LOCK_ATTEMPTS = 3
retry_delay = 0.5


def _acquire_user_file_locks(db_session: Session, user_file_ids: list[str]) -> bool:
    """Acquire locks for the specified user files."""
    # Convert to UUIDs for the DB comparison
    user_file_uuid_list = [UUID(user_file_id) for user_file_id in user_file_ids]
    stmt = (
        select(UserFile.id)
        .where(UserFile.id.in_(user_file_uuid_list))
        .with_for_update(nowait=True)
    )
    # will raise exception if any of the documents are already locked
    documents = db_session.scalars(stmt).all()

    # make sure we found every document
    if len(documents) != len(set(user_file_ids)):
        logger.warning("Didn't find row for all specified user file IDs. Aborting.")
        return False

    return True


class UserFileIndexingAdapter:
    def __init__(self, tenant_id: str, db_session: Session):
        self.tenant_id = tenant_id
        self.db_session = db_session

    def prepare(
        self,
        documents: list[Document],
        ignore_time_skip: bool,  # noqa: ARG002
    ) -> DocumentBatchPrepareContext:
        return DocumentBatchPrepareContext(
            updatable_docs=documents,
            id_to_boost_map={},  # TODO(subash): add boost map
        )

    @contextlib.contextmanager
    def lock_context(
        self, documents: list[Document]
    ) -> Generator[TransactionalContext, None, None]:
        self.db_session.commit()  # ensure that we're not in a transaction

        lock_acquired = False
        for i in range(_NUM_LOCK_ATTEMPTS):
            try:
                with self.db_session.begin() as transaction:
                    lock_acquired = _acquire_user_file_locks(
                        db_session=self.db_session,
                        user_file_ids=[doc.id for doc in documents],
                    )
                    if lock_acquired:
                        yield transaction
                        break
            except OperationalError as e:
                logger.warning(
                    f"Failed to acquire locks for user files on attempt {i}, retrying. Error: {e}"
                )
                time.sleep(retry_delay)

        if not lock_acquired:
            raise RuntimeError(
                f"Failed to acquire locks after {_NUM_LOCK_ATTEMPTS} attempts for user files: {[doc.id for doc in documents]}"
            )

    def prepare_enrichment(
        self,
        context: DocumentBatchPrepareContext,
        tenant_id: str,
        chunks: list[DocAwareChunk],
    ) -> UserFileChunkEnricher:
        """Do all DB lookups and pre-compute file metadata from chunks."""
        updatable_ids = [doc.id for doc in context.updatable_docs]

        doc_id_to_new_chunk_cnt: dict[str, int] = defaultdict(int)
        content_by_file: dict[str, list[str]] = defaultdict(list)
        for chunk in chunks:
            doc_id_to_new_chunk_cnt[chunk.source_document.id] += 1
            content_by_file[chunk.source_document.id].append(chunk.content)

        no_access = DocumentAccess.build(
            user_emails=[],
            user_groups=[],
            external_user_emails=[],
            external_user_group_ids=[],
            is_public=False,
        )

        user_file_id_to_project_ids = fetch_user_project_ids_for_user_files(
            user_file_ids=updatable_ids,
            db_session=self.db_session,
        )
        user_file_id_to_persona_ids = fetch_persona_ids_for_user_files(
            user_file_ids=updatable_ids,
            db_session=self.db_session,
        )
        user_file_id_to_access: dict[str, DocumentAccess] = get_access_for_user_files(
            user_file_ids=updatable_ids,
            db_session=self.db_session,
        )
        user_file_id_to_previous_chunk_cnt: dict[str, int] = {
            user_file_id: chunk_count
            for user_file_id, chunk_count in fetch_chunk_counts_for_user_files(
                user_file_ids=updatable_ids,
                db_session=self.db_session,
            )
        }

        # Initialize tokenizer used for token count calculation
        try:
            llm = get_default_llm()
            llm_tokenizer = get_tokenizer(
                model_name=llm.config.model_name,
                provider_type=llm.config.model_provider,
            )
        except Exception as e:
            logger.error(f"Error getting tokenizer: {e}")
            llm_tokenizer = None

        user_file_id_to_raw_text: dict[str, str] = {}
        user_file_id_to_token_count: dict[str, int | None] = {}
        for user_file_id in updatable_ids:
            contents = content_by_file.get(user_file_id)
            if contents:
                combined_content = " ".join(contents)
                user_file_id_to_raw_text[str(user_file_id)] = combined_content
                token_count: int = (
                    count_tokens(combined_content, llm_tokenizer)
                    if llm_tokenizer
                    else 0
                )
                user_file_id_to_token_count[str(user_file_id)] = token_count
            else:
                user_file_id_to_raw_text[str(user_file_id)] = ""
                user_file_id_to_token_count[str(user_file_id)] = None

        return UserFileChunkEnricher(
            user_file_id_to_access=user_file_id_to_access,
            user_file_id_to_project_ids=user_file_id_to_project_ids,
            user_file_id_to_persona_ids=user_file_id_to_persona_ids,
            doc_id_to_previous_chunk_cnt=user_file_id_to_previous_chunk_cnt,
            doc_id_to_new_chunk_cnt=dict(doc_id_to_new_chunk_cnt),
            user_file_id_to_raw_text=user_file_id_to_raw_text,
            user_file_id_to_token_count=user_file_id_to_token_count,
            no_access=no_access,
            tenant_id=tenant_id,
        )

    def _notify_assistant_owners_if_files_ready(
        self, user_files: list[UserFile]
    ) -> None:
        """
        Check if all files for associated assistants are processed and notify owners.
        Only sends notification when all files for an assistant are COMPLETED.
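Illustrative sketch of the readiness rule above, assuming plain dataclasses as hypothetical stand-ins for the UserFile and Persona ORM models (Status, FakeFile, FakeAssistant and assistants_to_notify are not part of onyx). An owner is notified only when the just-completed file and every other file attached to the assistant are COMPLETED.

```python
from __future__ import annotations

from dataclasses import dataclass, field
from enum import Enum


class Status(Enum):
    # Hypothetical stand-in for UserFileStatus
    PROCESSING = "processing"
    COMPLETED = "completed"


@dataclass
class FakeFile:
    id: int
    status: Status


@dataclass
class FakeAssistant:
    user_id: int | None
    files: list[FakeFile] = field(default_factory=list)


def assistants_to_notify(
    current: FakeFile, assistants: list[FakeAssistant]
) -> list[FakeAssistant]:
    # Only a file that has just reached COMPLETED can trigger a notification
    if current.status != Status.COMPLETED:
        return []
    ready = []
    for assistant in assistants:
        # Assistants without an owner have nobody to notify
        if assistant.user_id is None:
            continue
        # Every OTHER file attached to the assistant must also be COMPLETED
        if all(
            f.status == Status.COMPLETED for f in assistant.files if f.id != current.id
        ):
            ready.append(assistant)
    return ready
```

Skipping the current file in the all() check mirrors the method: the outer status check has already established that this file is complete.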
""" for user_file in user_files: if user_file.status == UserFileStatus.COMPLETED: for assistant in user_file.assistants: # Skip assistants without owners if assistant.user_id is None: continue # Check if all OTHER files for this assistant are completed # (we already know current file is completed from the outer check) all_files_completed = all( f.status == UserFileStatus.COMPLETED for f in assistant.user_files if f.id != user_file.id ) if all_files_completed: create_notification( user_id=assistant.user_id, notif_type=NotificationType.ASSISTANT_FILES_READY, db_session=self.db_session, title="Your files are ready!", description=f"All files for agent {assistant.name} have been processed and are now available.", additional_data={ "persona_id": assistant.id, "link": f"/assistants/{assistant.id}", }, autocommit=False, ) def post_index( self, context: DocumentBatchPrepareContext, updatable_chunk_data: list[UpdatableChunkData], # noqa: ARG002 filtered_documents: list[Document], # noqa: ARG002 enrichment: ChunkEnrichmentContext, ) -> None: assert isinstance(enrichment, UserFileChunkEnricher) user_file_ids = [doc.id for doc in context.updatable_docs] user_files = ( self.db_session.query(UserFile) .options(selectinload(UserFile.assistants).selectinload(Persona.user_files)) .filter(UserFile.id.in_(user_file_ids)) .all() ) for user_file in user_files: # don't update the status if the user file is being deleted if user_file.status != UserFileStatus.DELETING: user_file.status = UserFileStatus.COMPLETED user_file.last_project_sync_at = datetime.datetime.now( datetime.timezone.utc ) user_file.chunk_count = enrichment.doc_id_to_new_chunk_cnt.get( str(user_file.id), 0 ) user_file.token_count = enrichment.user_file_id_to_token_count[ str(user_file.id) ] # Notify assistant owners if all their files are now processed self._notify_assistant_owners_if_files_ready(user_files) self.db_session.commit() # Store the plaintext in the file store for faster retrieval # NOTE: this creates its own 
session to avoid committing the overall # transaction. for user_file_id, raw_text in enrichment.user_file_id_to_raw_text.items(): store_user_file_plaintext( user_file_id=UUID(user_file_id), plaintext_content=raw_text, ) class UserFileChunkEnricher: """Pre-computed metadata for per-chunk enrichment of user-uploaded files.""" def __init__( self, user_file_id_to_access: dict[str, DocumentAccess], user_file_id_to_project_ids: dict[str, list[int]], user_file_id_to_persona_ids: dict[str, list[int]], doc_id_to_previous_chunk_cnt: dict[str, int], doc_id_to_new_chunk_cnt: dict[str, int], user_file_id_to_raw_text: dict[str, str], user_file_id_to_token_count: dict[str, int | None], no_access: DocumentAccess, tenant_id: str, ) -> None: self._user_file_id_to_access = user_file_id_to_access self._user_file_id_to_project_ids = user_file_id_to_project_ids self._user_file_id_to_persona_ids = user_file_id_to_persona_ids self._no_access = no_access self._tenant_id = tenant_id self.doc_id_to_previous_chunk_cnt = doc_id_to_previous_chunk_cnt self.doc_id_to_new_chunk_cnt = doc_id_to_new_chunk_cnt self.user_file_id_to_raw_text = user_file_id_to_raw_text self.user_file_id_to_token_count = user_file_id_to_token_count def enrich_chunk( self, chunk: IndexChunk, score: float ) -> DocMetadataAwareIndexChunk: return DocMetadataAwareIndexChunk.from_index_chunk( index_chunk=chunk, access=self._user_file_id_to_access.get( chunk.source_document.id, self._no_access ), document_sets=set(), user_project=self._user_file_id_to_project_ids.get( chunk.source_document.id, [] ), personas=self._user_file_id_to_persona_ids.get( chunk.source_document.id, [] ), boost=DEFAULT_BOOST, tenant_id=self._tenant_id, aggregated_chunk_boost_factor=score, ) ================================================ FILE: backend/onyx/indexing/chunk_batch_store.py ================================================ import pickle import shutil import tempfile from collections.abc import Iterator from pathlib import Path from 
onyx.indexing.models import IndexChunk class ChunkBatchStore: """Manages serialization of embedded chunks to a temporary directory. Owns the temp directory lifetime and provides save/load/stream/scrub operations. Use as a context manager to ensure cleanup:: with ChunkBatchStore() as store: store.save(chunks, batch_idx=0) for chunk in store.stream(): ... """ _EXT = ".pkl" def __init__(self) -> None: self._tmpdir: Path | None = None # -- context manager ----------------------------------------------------- def __enter__(self) -> "ChunkBatchStore": self._tmpdir = Path(tempfile.mkdtemp(prefix="onyx_embeddings_")) return self def __exit__(self, *_exc: object) -> None: if self._tmpdir is not None: shutil.rmtree(self._tmpdir, ignore_errors=True) self._tmpdir = None @property def _dir(self) -> Path: assert self._tmpdir is not None, "ChunkBatchStore used outside context manager" return self._tmpdir # -- storage primitives -------------------------------------------------- def save(self, chunks: list[IndexChunk], batch_idx: int) -> None: """Serialize a batch of embedded chunks to disk.""" with open(self._dir / f"batch_{batch_idx}{self._EXT}", "wb") as f: pickle.dump(chunks, f) def _load(self, batch_file: Path) -> list[IndexChunk]: """Deserialize a batch of embedded chunks from a file.""" with open(batch_file, "rb") as f: return pickle.load(f) def _batch_files(self) -> list[Path]: """Return batch files sorted by numeric index.""" return sorted( self._dir.glob(f"batch_*{self._EXT}"), key=lambda p: int(p.stem.removeprefix("batch_")), ) # -- higher-level operations --------------------------------------------- def stream(self) -> Iterator[IndexChunk]: """Yield all chunks across all batch files. Each call returns a fresh generator, so the data can be iterated multiple times (e.g. once per document index). 
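Illustrative sketch of the naming and ordering scheme used by save() and _batch_files(). The helper names save_batch and stream_batches are hypothetical, and plain strings stand in for IndexChunk. Files are ordered by their numeric suffix, so batch_10 is read after batch_2, and each call to the streaming function starts a fresh pass over the files.

```python
from __future__ import annotations

import pickle
from collections.abc import Iterator
from pathlib import Path


def save_batch(directory: Path, items: list[str], batch_idx: int) -> None:
    # Same file-naming scheme as ChunkBatchStore.save: batch_<idx>.pkl
    with open(directory / f"batch_{batch_idx}.pkl", "wb") as f:
        pickle.dump(items, f)


def stream_batches(directory: Path) -> Iterator[str]:
    # Sort on the integer suffix; a plain string sort would place
    # batch_10 before batch_2
    files = sorted(
        directory.glob("batch_*.pkl"),
        key=lambda p: int(p.stem.removeprefix("batch_")),
    )
    for path in files:
        with open(path, "rb") as f:
            yield from pickle.load(f)
```

Because the generator re-reads the files on every call, streaming twice yields the same sequence without holding all batches in memory.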
""" for batch_file in self._batch_files(): yield from self._load(batch_file) def scrub_failed_docs(self, failed_doc_ids: set[str]) -> None: """Remove chunks belonging to *failed_doc_ids* from all batch files. When a document fails embedding in batch N, earlier batches may already contain successfully embedded chunks for that document. This ensures the output is all-or-nothing per document. """ for batch_file in self._batch_files(): batch_chunks = self._load(batch_file) cleaned = [ c for c in batch_chunks if c.source_document.id not in failed_doc_ids ] if len(cleaned) != len(batch_chunks): with open(batch_file, "wb") as f: pickle.dump(cleaned, f) ================================================ FILE: backend/onyx/indexing/chunker.py ================================================ from typing import cast from chonkie import SentenceChunker from onyx.configs.app_configs import AVERAGE_SUMMARY_EMBEDDINGS from onyx.configs.app_configs import BLURB_SIZE from onyx.configs.app_configs import LARGE_CHUNK_RATIO from onyx.configs.app_configs import MINI_CHUNK_SIZE from onyx.configs.app_configs import SKIP_METADATA_IN_CHUNK from onyx.configs.app_configs import USE_CHUNK_SUMMARY from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.constants import DocumentSource from onyx.configs.constants import RETURN_SEPARATOR from onyx.configs.constants import SECTION_SEPARATOR from onyx.connectors.cross_connector_utils.miscellaneous_utils import ( get_metadata_keys_to_ignore, ) from onyx.connectors.models import IndexingDocument from onyx.connectors.models import Section from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.indexing.models import DocAwareChunk from onyx.llm.utils import MAX_CONTEXT_TOKENS from onyx.natural_language_processing.utils import BaseTokenizer from onyx.utils.logger import setup_logger from onyx.utils.text_processing import clean_text from onyx.utils.text_processing import shared_precompare_cleanup from 
shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE from shared_configs.configs import STRICT_CHUNK_TOKEN_LIMIT # Not supporting overlaps, we need a clean combination of chunks and it is unclear if overlaps # actually help quality at all CHUNK_OVERLAP = 0 # Fairly arbitrary numbers but the general concept is we don't want the title/metadata to # overwhelm the actual contents of the chunk MAX_METADATA_PERCENTAGE = 0.25 CHUNK_MIN_CONTENT = 256 logger = setup_logger() def _get_metadata_suffix_for_document_index( metadata: dict[str, str | list[str]], include_separator: bool = False ) -> tuple[str, str]: """ Returns the metadata as a natural language string representation with all of the keys and values for the vector embedding and a string of all of the values for the keyword search. """ if not metadata: return "", "" metadata_str = "Metadata:\n" metadata_values = [] for key, value in metadata.items(): if key in get_metadata_keys_to_ignore(): continue value_str = ", ".join(value) if isinstance(value, list) else value if isinstance(value, list): metadata_values.extend(value) else: metadata_values.append(value) metadata_str += f"\t{key} - {value_str}\n" metadata_semantic = metadata_str.strip() metadata_keyword = " ".join(metadata_values) if include_separator: return RETURN_SEPARATOR + metadata_semantic, RETURN_SEPARATOR + metadata_keyword return metadata_semantic, metadata_keyword def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk: """ Combines multiple DocAwareChunks into one large chunk (for "multipass" mode), appending the content and adjusting source_links accordingly. 
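Illustrative sketch of the link-offset arithmetic, using a hypothetical helper (merge_contents_and_links) that works on plain strings and dicts and takes the separator as a parameter instead of SECTION_SEPARATOR. Each chunk's link offsets are shifted by the combined length of all preceding contents plus one separator per join.

```python
def merge_contents_and_links(
    contents: list[str], links: list[dict[int, str]], sep: str
) -> tuple[str, dict[int, str]]:
    # The first chunk's text and link offsets are kept unchanged
    merged_text = contents[0]
    merged_links = dict(links[0])
    offset = 0
    for i in range(1, len(contents)):
        merged_text += sep + contents[i]
        # Shift by the previous chunk's length plus one separator
        offset += len(sep) + len(contents[i - 1])
        for link_offset, link_text in links[i].items():
            merged_links[link_offset + offset] = link_text
    return merged_text, merged_links
```

For example, merging "abc", "de" and "f" with "--" moves the second chunk's offset 0 to 5, which is exactly where "d" starts in "abc--de--f".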
""" merged_chunk = DocAwareChunk( source_document=chunks[0].source_document, chunk_id=chunks[0].chunk_id, blurb=chunks[0].blurb, content=chunks[0].content, source_links=chunks[0].source_links or {}, image_file_id=None, section_continuation=(chunks[0].chunk_id > 0), title_prefix=chunks[0].title_prefix, metadata_suffix_semantic=chunks[0].metadata_suffix_semantic, metadata_suffix_keyword=chunks[0].metadata_suffix_keyword, large_chunk_reference_ids=[chunk.chunk_id for chunk in chunks], mini_chunk_texts=None, large_chunk_id=large_chunk_id, chunk_context="", doc_summary="", contextual_rag_reserved_tokens=0, ) offset = 0 for i in range(1, len(chunks)): merged_chunk.content += SECTION_SEPARATOR + chunks[i].content offset += len(SECTION_SEPARATOR) + len(chunks[i - 1].content) for link_offset, link_text in (chunks[i].source_links or {}).items(): if merged_chunk.source_links is None: merged_chunk.source_links = {} merged_chunk.source_links[link_offset + offset] = link_text return merged_chunk def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]: """ Generates larger "grouped" chunks by combining sets of smaller chunks. """ large_chunks = [] for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)): chunk_group = chunks[i : i + LARGE_CHUNK_RATIO] if len(chunk_group) > 1: large_chunk = _combine_chunks(chunk_group, idx) large_chunks.append(large_chunk) return large_chunks class Chunker: """ Chunks documents into smaller chunks for indexing. 
""" def __init__( self, tokenizer: BaseTokenizer, enable_multipass: bool = False, enable_large_chunks: bool = False, enable_contextual_rag: bool = False, blurb_size: int = BLURB_SIZE, include_metadata: bool = not SKIP_METADATA_IN_CHUNK, chunk_token_limit: int = DOC_EMBEDDING_CONTEXT_SIZE, chunk_overlap: int = CHUNK_OVERLAP, mini_chunk_size: int = MINI_CHUNK_SIZE, callback: IndexingHeartbeatInterface | None = None, ) -> None: self.include_metadata = include_metadata self.chunk_token_limit = chunk_token_limit self.enable_multipass = enable_multipass self.enable_large_chunks = enable_large_chunks self.enable_contextual_rag = enable_contextual_rag if enable_contextual_rag: assert ( USE_CHUNK_SUMMARY or USE_DOCUMENT_SUMMARY ), "Contextual RAG requires at least one of chunk summary and document summary enabled" self.default_contextual_rag_reserved_tokens = MAX_CONTEXT_TOKENS * ( int(USE_CHUNK_SUMMARY) + int(USE_DOCUMENT_SUMMARY) ) self.tokenizer = tokenizer self.callback = callback self.max_context = 0 self.prompt_tokens = 0 # Create a token counter function that returns the count instead of the tokens def token_counter(text: str) -> int: return len(tokenizer.encode(text)) self.blurb_splitter = SentenceChunker( tokenizer_or_token_counter=token_counter, chunk_size=blurb_size, chunk_overlap=0, return_type="texts", ) self.chunk_splitter = SentenceChunker( tokenizer_or_token_counter=token_counter, chunk_size=chunk_token_limit, chunk_overlap=chunk_overlap, return_type="texts", ) self.mini_chunk_splitter = ( SentenceChunker( tokenizer_or_token_counter=token_counter, chunk_size=mini_chunk_size, chunk_overlap=0, return_type="texts", ) if enable_multipass else None ) def _split_oversized_chunk(self, text: str, content_token_limit: int) -> list[str]: """ Splits the text into smaller chunks based on token count to ensure no chunk exceeds the content_token_limit. 
""" tokens = self.tokenizer.tokenize(text) chunks = [] start = 0 total_tokens = len(tokens) while start < total_tokens: end = min(start + content_token_limit, total_tokens) token_chunk = tokens[start:end] chunk_text = " ".join(token_chunk) chunks.append(chunk_text) start = end return chunks def _extract_blurb(self, text: str) -> str: """ Extract a short blurb from the text (first chunk of size `blurb_size`). """ # chunker is in `text` mode texts = cast(list[str], self.blurb_splitter.chunk(text)) if not texts: return "" return texts[0] def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None: """ For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings. """ if self.mini_chunk_splitter and chunk_text.strip(): # chunker is in `text` mode return cast(list[str], self.mini_chunk_splitter.chunk(chunk_text)) return None # ADDED: extra param image_url to store in the chunk def _create_chunk( self, document: IndexingDocument, chunks_list: list[DocAwareChunk], text: str, links: dict[int, str], is_continuation: bool = False, title_prefix: str = "", metadata_suffix_semantic: str = "", metadata_suffix_keyword: str = "", image_file_id: str | None = None, ) -> None: """ Helper to create a new DocAwareChunk, append it to chunks_list. 
""" new_chunk = DocAwareChunk( source_document=document, chunk_id=len(chunks_list), blurb=self._extract_blurb(text), content=text, source_links=links or {0: ""}, image_file_id=image_file_id, section_continuation=is_continuation, title_prefix=title_prefix, metadata_suffix_semantic=metadata_suffix_semantic, metadata_suffix_keyword=metadata_suffix_keyword, mini_chunk_texts=self._get_mini_chunk_texts(text), large_chunk_id=None, doc_summary="", chunk_context="", contextual_rag_reserved_tokens=0, # set per-document in _handle_single_document ) chunks_list.append(new_chunk) def _chunk_document_with_sections( self, document: IndexingDocument, sections: list[Section], title_prefix: str, metadata_suffix_semantic: str, metadata_suffix_keyword: str, content_token_limit: int, ) -> list[DocAwareChunk]: """ Loops through sections of the document, converting them into one or more chunks. Works with processed sections that are base Section objects. """ chunks: list[DocAwareChunk] = [] link_offsets: dict[int, str] = {} chunk_text = "" for section_idx, section in enumerate(sections): # Get section text and other attributes section_text = clean_text(str(section.text or "")) section_link_text = section.link or "" image_url = section.image_file_id # If there is no useful content, skip if not section_text and (not document.title or section_idx > 0): logger.warning( f"Skipping empty or irrelevant section in doc {document.semantic_identifier}, link={section_link_text}" ) continue # CASE 1: If this section has an image, force a separate chunk if image_url: # First, if we have any partially built text chunk, finalize it if chunk_text.strip(): self._create_chunk( document, chunks, chunk_text, link_offsets, is_continuation=False, title_prefix=title_prefix, metadata_suffix_semantic=metadata_suffix_semantic, metadata_suffix_keyword=metadata_suffix_keyword, ) chunk_text = "" link_offsets = {} # Create a chunk specifically for this image section # (Using the text summary that was generated during 
processing) self._create_chunk( document, chunks, section_text, links={0: section_link_text} if section_link_text else {}, image_file_id=image_url, title_prefix=title_prefix, metadata_suffix_semantic=metadata_suffix_semantic, metadata_suffix_keyword=metadata_suffix_keyword, ) # Continue to next section continue # CASE 2: Normal text section section_token_count = len(self.tokenizer.encode(section_text)) # If the section is large on its own, split it separately if section_token_count > content_token_limit: if chunk_text.strip(): self._create_chunk( document, chunks, chunk_text, link_offsets, False, title_prefix, metadata_suffix_semantic, metadata_suffix_keyword, ) chunk_text = "" link_offsets = {} # chunker is in `text` mode split_texts = cast(list[str], self.chunk_splitter.chunk(section_text)) for i, split_text in enumerate(split_texts): # If even the split_text is bigger than strict limit, further split if ( STRICT_CHUNK_TOKEN_LIMIT and len(self.tokenizer.encode(split_text)) > content_token_limit ): smaller_chunks = self._split_oversized_chunk( split_text, content_token_limit ) for j, small_chunk in enumerate(smaller_chunks): self._create_chunk( document, chunks, small_chunk, {0: section_link_text}, is_continuation=(j != 0), title_prefix=title_prefix, metadata_suffix_semantic=metadata_suffix_semantic, metadata_suffix_keyword=metadata_suffix_keyword, ) else: self._create_chunk( document, chunks, split_text, {0: section_link_text}, is_continuation=(i != 0), title_prefix=title_prefix, metadata_suffix_semantic=metadata_suffix_semantic, metadata_suffix_keyword=metadata_suffix_keyword, ) continue # If we can still fit this section into the current chunk, do so current_token_count = len(self.tokenizer.encode(chunk_text)) current_offset = len(shared_precompare_cleanup(chunk_text)) next_section_tokens = ( len(self.tokenizer.encode(SECTION_SEPARATOR)) + section_token_count ) if next_section_tokens + current_token_count <= content_token_limit: if chunk_text: chunk_text += 
SECTION_SEPARATOR chunk_text += section_text link_offsets[current_offset] = section_link_text else: # finalize the existing chunk self._create_chunk( document, chunks, chunk_text, link_offsets, False, title_prefix, metadata_suffix_semantic, metadata_suffix_keyword, ) # start a new chunk link_offsets = {0: section_link_text} chunk_text = section_text # finalize any leftover text chunk if chunk_text.strip() or not chunks: self._create_chunk( document, chunks, chunk_text, link_offsets or {0: ""}, # safe default False, title_prefix, metadata_suffix_semantic, metadata_suffix_keyword, ) return chunks def _handle_single_document( self, document: IndexingDocument ) -> list[DocAwareChunk]: # Specifically for reproducing an issue with gmail if document.source == DocumentSource.GMAIL: logger.debug(f"Chunking {document.semantic_identifier}") # Title prep title = self._extract_blurb(document.get_title_for_document_index() or "") title_prefix = title + RETURN_SEPARATOR if title else "" title_tokens = len(self.tokenizer.encode(title_prefix)) # Metadata prep metadata_suffix_semantic = "" metadata_suffix_keyword = "" metadata_tokens = 0 if self.include_metadata: ( metadata_suffix_semantic, metadata_suffix_keyword, ) = _get_metadata_suffix_for_document_index( document.metadata, include_separator=True ) metadata_tokens = len(self.tokenizer.encode(metadata_suffix_semantic)) # If metadata is too large, skip it in the semantic content if metadata_tokens >= self.chunk_token_limit * MAX_METADATA_PERCENTAGE: metadata_suffix_semantic = "" metadata_tokens = 0 single_chunk_fits = True doc_token_count = 0 if self.enable_contextual_rag: doc_content = document.get_text_content() tokenized_doc = self.tokenizer.tokenize(doc_content) doc_token_count = len(tokenized_doc) # check if doc + title + metadata fits in a single chunk. 
If so, no need for contextual RAG single_chunk_fits = ( doc_token_count + title_tokens + metadata_tokens <= self.chunk_token_limit ) # expand the size of the context used for contextual rag based on whether chunk context and doc summary are used context_size = 0 if ( self.enable_contextual_rag and not single_chunk_fits and not AVERAGE_SUMMARY_EMBEDDINGS ): context_size += self.default_contextual_rag_reserved_tokens # Adjust content token limit to accommodate title + metadata content_token_limit = ( self.chunk_token_limit - title_tokens - metadata_tokens - context_size ) # first check: if there is not enough actual chunk content when including contextual rag, # then don't do contextual rag if content_token_limit <= CHUNK_MIN_CONTENT: context_size = 0 # Don't do contextual RAG # revert to previous content token limit content_token_limit = ( self.chunk_token_limit - title_tokens - metadata_tokens ) # If there is not enough context remaining then just index the chunk with no prefix/suffix if content_token_limit <= CHUNK_MIN_CONTENT: # Not enough space left, so revert to full chunk without the prefix content_token_limit = self.chunk_token_limit title_prefix = "" metadata_suffix_semantic = "" # Use processed_sections if available (IndexingDocument), otherwise use original sections sections_to_chunk = document.processed_sections normal_chunks = self._chunk_document_with_sections( document, sections_to_chunk, title_prefix, metadata_suffix_semantic, metadata_suffix_keyword, content_token_limit, ) # Optional "multipass" large chunk creation if self.enable_multipass and self.enable_large_chunks: large_chunks = generate_large_chunks(normal_chunks) normal_chunks.extend(large_chunks) for chunk in normal_chunks: chunk.contextual_rag_reserved_tokens = context_size return normal_chunks def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]: """ Takes in a list of documents and chunks them into smaller chunks for indexing while persisting the document metadata. 
        Works with both standard Document objects and IndexingDocument objects
        with processed_sections.
        """
        final_chunks: list[DocAwareChunk] = []
        for document in documents:
            if self.callback and self.callback.should_stop():
                raise RuntimeError("Chunker.chunk: Stop signal detected")

            chunks = self._handle_single_document(document)
            final_chunks.extend(chunks)

            if self.callback:
                self.callback.progress("Chunker.chunk", len(chunks))

        return final_chunks

================================================
FILE: backend/onyx/indexing/content_classification.py
================================================

================================================
FILE: backend/onyx/indexing/embedder.py
================================================
import time
from abc import ABC
from abc import abstractmethod
from collections import defaultdict

from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import DocumentFailure
from onyx.db.models import SearchSettings
from onyx.document_index.chunk_content_enrichment import (
    generate_enriched_content_for_chunk_embedding,
)
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import ChunkEmbedding
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import IndexChunk
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.pydantic_util import shallow_model_dump
from onyx.utils.timing import log_function_time
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding

logger = setup_logger()


class IndexingEmbedder(ABC):
    """Converts chunks into chunks with embeddings.
    Note that one chunk may have multiple embeddings associated with it."""

    def __init__(
        self,
        model_name: str,
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
        provider_type: EmbeddingProvider | None,
        api_key: str | None,
        api_url: str | None,
        api_version: str | None,
        deployment_name: str | None,
        reduced_dimension: int | None,
        callback: IndexingHeartbeatInterface | None,
    ):
        self.model_name = model_name
        self.normalize = normalize
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.provider_type = provider_type
        self.api_key = api_key
        self.api_url = api_url
        self.api_version = api_version
        self.deployment_name = deployment_name

        self.embedding_model = EmbeddingModel(
            model_name=model_name,
            query_prefix=query_prefix,
            passage_prefix=passage_prefix,
            normalize=normalize,
            api_key=api_key,
            provider_type=provider_type,
            api_url=api_url,
            api_version=api_version,
            deployment_name=deployment_name,
            reduced_dimension=reduced_dimension,
            # The below are globally set, this flow always uses the indexing one
            server_host=INDEXING_MODEL_SERVER_HOST,
            server_port=INDEXING_MODEL_SERVER_PORT,
            retrim_content=True,
            callback=callback,
        )

    @abstractmethod
    def embed_chunks(
        self,
        chunks: list[DocAwareChunk],
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> list[IndexChunk]:
        raise NotImplementedError


class DefaultIndexingEmbedder(IndexingEmbedder):
    def __init__(
        self,
        model_name: str,
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
        provider_type: EmbeddingProvider | None = None,
        api_key: str | None = None,
        api_url: str | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        reduced_dimension: int | None = None,
        callback: IndexingHeartbeatInterface | None = None,
    ):
        super().__init__(
            model_name,
            normalize,
            query_prefix,
            passage_prefix,
            provider_type,
            api_key,
            api_url,
            api_version,
            deployment_name,
            reduced_dimension,
            callback,
        )

    @log_function_time()
    def embed_chunks(
        self,
        chunks: list[DocAwareChunk],
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> list[IndexChunk]:
        """Adds embeddings to the chunks. The title and metadata suffixes are also included
        in the embedded text if they exist; if there was no room for them, they were already
        dropped at the chunking step.
        """
        # All chunks at this point must have some non-empty content
        flat_chunk_texts: list[str] = []
        large_chunks_present = False
        for chunk in chunks:
            if chunk.large_chunk_reference_ids:
                large_chunks_present = True
            chunk_text = (
                generate_enriched_content_for_chunk_embedding(chunk)
            ) or chunk.source_document.get_title_for_document_index()

            if not chunk_text:
                # This should never happen, the document would have been dropped
                # before getting to this point
                raise ValueError(f"Chunk has no content: {chunk.to_short_descriptor()}")

            flat_chunk_texts.append(chunk_text)

            if chunk.mini_chunk_texts:
                if chunk.large_chunk_reference_ids:
                    # A large chunk does not contain mini chunks, if it matches the large chunk
                    # with a high score, then mini chunks would not be used anyway
                    # otherwise it should match the normal chunk
                    raise RuntimeError("Large chunk contains mini chunks")
                flat_chunk_texts.extend(chunk.mini_chunk_texts)

        embeddings = self.embedding_model.encode(
            texts=flat_chunk_texts,
            text_type=EmbedTextType.PASSAGE,
            large_chunks_present=large_chunks_present,
            tenant_id=tenant_id,
            request_id=request_id,
        )

        chunk_titles = {
            chunk.source_document.get_title_for_document_index() for chunk in chunks
        }

        # Drop any None or empty strings
        # If there is no title or the title is empty, the title embedding field will be null
        # which is ok, it just won't contribute at all to the scoring.
        chunk_titles_list = [title for title in chunk_titles if title]

        # Cache the Title embeddings to only have to do it once
        title_embed_dict: dict[str, Embedding] = {}
        if chunk_titles_list:
            title_embeddings = self.embedding_model.encode(
                chunk_titles_list,
                text_type=EmbedTextType.PASSAGE,
                tenant_id=tenant_id,
                request_id=request_id,
            )
            title_embed_dict.update(
                {
                    title: vector
                    for title, vector in zip(chunk_titles_list, title_embeddings)
                }
            )

        # Mapping embeddings to chunks
        embedded_chunks: list[IndexChunk] = []
        embedding_ind_start = 0
        for chunk in chunks:
            num_embeddings = 1 + (
                len(chunk.mini_chunk_texts) if chunk.mini_chunk_texts else 0
            )
            chunk_embeddings = embeddings[
                embedding_ind_start : embedding_ind_start + num_embeddings
            ]

            title = chunk.source_document.get_title_for_document_index()

            title_embedding = None
            if title:
                if title in title_embed_dict:
                    # Using cached value to avoid recalculating for every chunk
                    title_embedding = title_embed_dict[title]
                else:
                    logger.error(
                        "Title had to be embedded separately, this should not happen!"
                    )
                    title_embedding = self.embedding_model.encode(
                        [title],
                        text_type=EmbedTextType.PASSAGE,
                        tenant_id=tenant_id,
                        request_id=request_id,
                    )[0]
                    title_embed_dict[title] = title_embedding

            new_embedded_chunk = IndexChunk.model_construct(
                **shallow_model_dump(chunk),
                embeddings=ChunkEmbedding(
                    full_embedding=chunk_embeddings[0],
                    mini_chunk_embeddings=chunk_embeddings[1:],
                ),
                title_embedding=title_embedding,
            )
            embedded_chunks.append(new_embedded_chunk)
            embedding_ind_start += num_embeddings

        return embedded_chunks

    @classmethod
    def from_db_search_settings(
        cls,
        search_settings: SearchSettings,
        callback: IndexingHeartbeatInterface | None = None,
    ) -> "DefaultIndexingEmbedder":
        return cls(
            model_name=search_settings.model_name,
            normalize=search_settings.normalize,
            query_prefix=search_settings.query_prefix,
            passage_prefix=search_settings.passage_prefix,
            provider_type=search_settings.provider_type,
            api_key=search_settings.api_key,
            api_url=search_settings.api_url,
            api_version=search_settings.api_version,
            deployment_name=search_settings.deployment_name,
            reduced_dimension=search_settings.reduced_dimension,
            callback=callback,
        )


def embed_chunks_with_failure_handling(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str | None = None,
    request_id: str | None = None,
) -> tuple[list[IndexChunk], list[ConnectorFailure]]:
    """Tries to embed all chunks in one large batch. If that batch fails for any reason,
    goes document by document to isolate the failure(s).
    """

    # TODO(rkuo): this doesn't disambiguate calls to the model server on retries.
    # Improve this if needed.

    # First try to embed all chunks in one batch
    try:
        return (
            embedder.embed_chunks(
                chunks=chunks, tenant_id=tenant_id, request_id=request_id
            ),
            [],
        )
    except ConnectorStopSignal as e:
        logger.warning(
            "Connector stop signal detected in embed_chunks_with_failure_handling"
        )
        raise e
    except Exception:
        logger.exception("Failed to embed chunk batch. Trying individual docs.")
        # wait a couple seconds to let any rate limits or temporary issues resolve
        time.sleep(2)

    # Try embedding each document's chunks individually
    chunks_by_doc: dict[str, list[DocAwareChunk]] = defaultdict(list)
    for chunk in chunks:
        chunks_by_doc[chunk.source_document.id].append(chunk)

    embedded_chunks: list[IndexChunk] = []
    failures: list[ConnectorFailure] = []

    for doc_id, chunks_for_doc in chunks_by_doc.items():
        try:
            doc_embedded_chunks = embedder.embed_chunks(
                chunks=chunks_for_doc, tenant_id=tenant_id, request_id=request_id
            )
            embedded_chunks.extend(doc_embedded_chunks)
        except Exception as e:
            logger.exception(f"Failed to embed chunks for document '{doc_id}'")
            failures.append(
                ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=doc_id,
                        document_link=(
                            chunks_for_doc[0].get_link() if chunks_for_doc else None
                        ),
                    ),
                    failure_message=str(e),
                    exception=e,
                )
            )

    return embedded_chunks, failures


================================================
FILE: backend/onyx/indexing/indexing_heartbeat.py
================================================
from abc import ABC
from abc import abstractmethod


class IndexingHeartbeatInterface(ABC):
    """Defines a callback interface to be passed to run_indexing_entrypoint."""

    @abstractmethod
    def should_stop(self) -> bool:
        """Signal to stop the looping function in flight."""

    @abstractmethod
    def progress(self, tag: str, amount: int) -> None:
        """Send progress updates to the caller.
        Amount can be a positive number to indicate progress or <= 0
        just to act as a keep-alive.
        """


================================================
FILE: backend/onyx/indexing/indexing_pipeline.py
================================================
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Generator
from collections.abc import Iterator
from contextlib import contextmanager
from typing import Protocol

from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_NAME
from onyx.configs.app_configs import DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER
from onyx.configs.app_configs import ENABLE_CONTEXTUAL_RAG
from onyx.configs.app_configs import MAX_CHUNKS_PER_DOC_BATCH
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION
from onyx.configs.app_configs import USE_CHUNK_SUMMARY
from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import ConnectorStopSignal
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.document import get_documents_by_ids
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.enums import HookPoint
from onyx.db.hierarchy import link_hierarchy_nodes_to_documents
from onyx.db.models import Document as DBDocument
from onyx.db.models import IndexModelStatus
from onyx.db.search_settings import get_active_search_settings
from onyx.db.tag import upsert_document_tags
from onyx.document_index.document_index_utils import (
    get_multipass_config,
)
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentInsertionRecord
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_store.file_store import get_default_file_store
from onyx.hooks.executor import execute_hook
from onyx.hooks.executor import HookSkipped
from onyx.hooks.executor import HookSoftFailed
from onyx.hooks.points.document_ingestion import DocumentIngestionOwner
from onyx.hooks.points.document_ingestion import DocumentIngestionPayload
from onyx.hooks.points.document_ingestion import DocumentIngestionResponse
from onyx.hooks.points.document_ingestion import DocumentIngestionSection
from onyx.indexing.chunk_batch_store import ChunkBatchStore
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexingBatchAdapter
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.factory import get_llm_for_contextual_rag
from onyx.llm.interfaces import LLM
from onyx.llm.models import UserMessage
from onyx.llm.multi_llm import LLMRateLimitError
from onyx.llm.utils import llm_response_to_string
from onyx.llm.utils import MAX_CONTEXT_TOKENS
from onyx.natural_language_processing.utils import BaseTokenizer
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_middle
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT1
from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_PROMPT2
from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_PROMPT
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_documents_for_postgres
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.timing import log_function_time

logger = setup_logger()


class DocumentBatchPrepareContext(BaseModel):
    updatable_docs: list[Document]
    id_to_boost_map: dict[str, int]
    indexable_docs: list[IndexingDocument] = []
    model_config = ConfigDict(arbitrary_types_allowed=True)


class IndexingPipelineResult(BaseModel):
    # number of documents that are completely new (e.g. did
    # not exist as a part of this OR any other connector)
    new_docs: int
    # NOTE: need total_docs, since the pipeline can skip some docs
    # (e.g. not even insert them into Postgres)
    total_docs: int
    # number of chunks that were inserted into Vespa
    total_chunks: int

    failures: list[ConnectorFailure]

    @classmethod
    def empty(cls, total_docs: int) -> "IndexingPipelineResult":
        return cls(
            new_docs=0,
            total_docs=total_docs,
            total_chunks=0,
            failures=[],
        )


class ChunkEmbeddingResult(BaseModel):
    successful_chunk_ids: list[tuple[int, str]]  # (chunk_id, document_id)
    connector_failures: list[ConnectorFailure]


class IndexingPipelineProtocol(Protocol):
    def __call__(
        self,
        document_batch: list[Document],
        index_attempt_metadata: IndexAttemptMetadata,
    ) -> IndexingPipelineResult: ...
def _upsert_documents_in_db(
    documents: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
    db_session: Session,
) -> None:
    # Metadata here refers to basic document info, not metadata about the actual content
    document_metadata_list: list[DocumentMetadata] = []
    for doc in documents:
        first_link = next(
            (section.link for section in doc.sections if section.link), ""
        )
        db_doc_metadata = DocumentMetadata(
            connector_id=index_attempt_metadata.connector_id,
            credential_id=index_attempt_metadata.credential_id,
            document_id=doc.id,
            semantic_identifier=doc.semantic_identifier,
            first_link=first_link,
            primary_owners=get_experts_stores_representations(doc.primary_owners),
            secondary_owners=get_experts_stores_representations(doc.secondary_owners),
            from_ingestion_api=doc.from_ingestion_api,
            external_access=doc.external_access,
            doc_metadata=doc.doc_metadata,
            # parent_hierarchy_node_id is resolved in docfetching using Redis cache
            parent_hierarchy_node_id=doc.parent_hierarchy_node_id,
        )
        document_metadata_list.append(db_doc_metadata)

    upsert_documents(db_session, document_metadata_list)

    # Insert document content metadata
    for doc in documents:
        upsert_document_tags(
            document_id=doc.id,
            source=doc.source,
            metadata=doc.metadata,
            db_session=db_session,
        )


def _get_failed_doc_ids(failures: list[ConnectorFailure]) -> set[str]:
    """Extract document IDs from a list of connector failures."""
    return {f.failed_document.document_id for f in failures if f.failed_document}


def _embed_chunks_to_store(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str,
    request_id: str | None,
    store: ChunkBatchStore,
) -> ChunkEmbeddingResult:
    """Embed chunks in batches, spilling each batch to *store*.

    If a document fails embedding in any batch, its chunks are excluded from
    all batches (including earlier ones already written) so that the output is
    all-or-nothing per document.
    """
    successful_chunk_ids: list[tuple[int, str]] = []
    all_embedding_failures: list[ConnectorFailure] = []
    # Track failed doc IDs across all batches so that a failure in batch N
    # causes chunks for that doc to be skipped in batch N+1 and stripped
    # from earlier batches.
    all_failed_doc_ids: set[str] = set()

    for batch_idx, chunk_batch in enumerate(
        batch_generator(chunks, MAX_CHUNKS_PER_DOC_BATCH)
    ):
        # Skip chunks belonging to documents that failed in earlier batches.
        chunk_batch = [
            c for c in chunk_batch if c.source_document.id not in all_failed_doc_ids
        ]
        if not chunk_batch:
            continue

        logger.debug(f"Embedding batch {batch_idx}: {len(chunk_batch)} chunks")

        chunks_with_embeddings, embedding_failures = embed_chunks_with_failure_handling(
            chunks=chunk_batch,
            embedder=embedder,
            tenant_id=tenant_id,
            request_id=request_id,
        )
        all_embedding_failures.extend(embedding_failures)
        all_failed_doc_ids.update(_get_failed_doc_ids(embedding_failures))

        # Only keep successfully embedded chunks for non-failed docs.
        chunks_with_embeddings = [
            c
            for c in chunks_with_embeddings
            if c.source_document.id not in all_failed_doc_ids
        ]

        successful_chunk_ids.extend(
            (c.chunk_id, c.source_document.id) for c in chunks_with_embeddings
        )

        store.save(chunks_with_embeddings, batch_idx)
        del chunks_with_embeddings

    # Scrub earlier batches for docs that failed in later batches.
    if all_failed_doc_ids:
        store.scrub_failed_docs(all_failed_doc_ids)
        successful_chunk_ids = [
            (chunk_id, doc_id)
            for chunk_id, doc_id in successful_chunk_ids
            if doc_id not in all_failed_doc_ids
        ]

    return ChunkEmbeddingResult(
        successful_chunk_ids=successful_chunk_ids,
        connector_failures=all_embedding_failures,
    )


@contextmanager
def embed_and_stream(
    chunks: list[DocAwareChunk],
    embedder: IndexingEmbedder,
    tenant_id: str,
    request_id: str | None,
) -> Generator[tuple[ChunkEmbeddingResult, ChunkBatchStore], None, None]:
    """Embed chunks to disk and yield a ``(result, store)`` pair.
    The store owns the temp directory; files are cleaned up when the context
    manager exits.

    Usage::

        with embed_and_stream(chunks, embedder, tenant_id, req_id) as (result, store):
            for chunk in store.stream():
                ...
    """
    with ChunkBatchStore() as store:
        result = _embed_chunks_to_store(
            chunks=chunks,
            embedder=embedder,
            tenant_id=tenant_id,
            request_id=request_id,
            store=store,
        )
        yield result, store


def get_doc_ids_to_update(
    documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
    """Figures out which documents actually need to be updated. If a document is already present
    and the `updated_at` hasn't changed, we shouldn't need to do anything with it.

    NB: Still need to associate the document in the DB if multiple connectors are
    indexing the same doc."""
    id_update_time_map = {
        doc.id: doc.doc_updated_at for doc in db_docs if doc.doc_updated_at
    }

    updatable_docs: list[Document] = []
    for doc in documents:
        if (
            doc.id in id_update_time_map
            and doc.doc_updated_at
            and doc.doc_updated_at <= id_update_time_map[doc.id]
        ):
            continue
        updatable_docs.append(doc)

    return updatable_docs


def index_doc_batch_with_handler(
    *,
    chunker: Chunker,
    embedder: IndexingEmbedder,
    document_indices: list[DocumentIndex],
    document_batch: list[Document],
    request_id: str | None,
    tenant_id: str,
    db_session: Session,
    adapter: IndexingBatchAdapter,
    ignore_time_skip: bool = False,
    enable_contextual_rag: bool = False,
    llm: LLM | None = None,
) -> IndexingPipelineResult:
    try:
        index_pipeline_result = index_doc_batch(
            chunker=chunker,
            embedder=embedder,
            document_indices=document_indices,
            document_batch=document_batch,
            request_id=request_id,
            tenant_id=tenant_id,
            db_session=db_session,
            adapter=adapter,
            ignore_time_skip=ignore_time_skip,
            enable_contextual_rag=enable_contextual_rag,
            llm=llm,
        )
    except ConnectorStopSignal as e:
        logger.warning("Connector stop signal detected in index_doc_batch_with_handler")
        raise e
    except Exception as e:
        # don't log the batch directly, it's too much text
        document_ids = [doc.id for doc in document_batch]
        logger.exception(f"Failed to index document batch: {document_ids}")

        index_pipeline_result = IndexingPipelineResult(
            new_docs=0,
            total_docs=len(document_batch),
            total_chunks=0,
            failures=[
                ConnectorFailure(
                    failed_document=DocumentFailure(
                        document_id=document.id,
                        document_link=(
                            document.sections[0].link if document.sections else None
                        ),
                    ),
                    failure_message=str(e),
                    exception=e,
                )
                for document in document_batch
            ],
        )

    return index_pipeline_result


def index_doc_batch_prepare(
    documents: list[Document],
    index_attempt_metadata: IndexAttemptMetadata,
    db_session: Session,
    ignore_time_skip: bool = False,
) -> DocumentBatchPrepareContext | None:
    """Sets up the documents in the relational DB (source of truth) for permissions, metadata, etc.
    This precedes indexing them into the actual document index."""
    documents = sanitize_documents_for_postgres(documents)

    # Create a trimmed list of docs that don't have a newer updated at
    # Shortcuts the time-consuming flow on connector index retries
    document_ids: list[str] = [document.id for document in documents]
    db_docs: list[DBDocument] = get_documents_by_ids(
        db_session=db_session,
        document_ids=document_ids,
    )

    updatable_docs = (
        get_doc_ids_to_update(documents=documents, db_docs=db_docs)
        if not ignore_time_skip
        else documents
    )
    if len(updatable_docs) != len(documents):
        updatable_doc_ids = [doc.id for doc in updatable_docs]
        skipped_doc_ids = [
            doc.id for doc in documents if doc.id not in updatable_doc_ids
        ]
        logger.info(
            f"Skipping {len(skipped_doc_ids)} documents because they are up to date. "
            f"Skipped doc IDs: {skipped_doc_ids}"
        )

    # for all updatable docs, upsert into the DB
    # Does not include doc_updated_at which is also used to indicate a successful update
    if updatable_docs:
        _upsert_documents_in_db(
            documents=updatable_docs,
            index_attempt_metadata=index_attempt_metadata,
            db_session=db_session,
        )

        logger.info(
            f"Upserted {len(updatable_docs)} changed docs out of {len(documents)} total docs into the DB"
        )

    # for all docs, upsert the document to cc pair relationship
    upsert_document_by_connector_credential_pair(
        db_session,
        index_attempt_metadata.connector_id,
        index_attempt_metadata.credential_id,
        document_ids,
    )

    # Link hierarchy nodes to documents for sources where pages can be both
    # hierarchy nodes AND documents (e.g., Notion, Confluence).
    # This must happen after documents are upserted due to FK constraint.
    if documents:
        link_hierarchy_nodes_to_documents(
            db_session=db_session,
            document_ids=document_ids,
            source=documents[0].source,
            commit=False,  # We'll commit with the rest of the transaction
        )

    # No docs to process because the batch is empty or every doc was already indexed
    if not updatable_docs:
        return None

    id_to_boost_map = {doc.id: doc.boost for doc in db_docs}
    return DocumentBatchPrepareContext(
        updatable_docs=updatable_docs, id_to_boost_map=id_to_boost_map
    )


def filter_documents(document_batch: list[Document]) -> list[Document]:
    documents: list[Document] = []
    total_chars_in_batch = 0
    skipped_too_long = []

    for document in document_batch:
        empty_contents = not any(
            isinstance(section, TextSection)
            and section.text is not None
            and section.text.strip()
            for section in document.sections
        )
        if (
            (not document.title or not document.title.strip())
            and not document.semantic_identifier.strip()
            and empty_contents
        ):
            # Skip documents that have neither title nor content
            # If the document doesn't have either, then there is no useful information in it
            # This is verified again later in the pipeline after chunking, but by then there
            # should already be no empty documents.
            logger.warning(
                f"Skipping document with ID {document.id} as it has neither title nor content."
            )
            continue

        if document.title is not None and not document.title.strip() and empty_contents:
            # The title is explicitly empty ("" and not None) and the document is empty
            # so when building the chunk text representation, it will be empty and unusable
            logger.warning(
                f"Skipping document with ID {document.id} as the chunks will be empty."
            )
            continue

        section_chars = sum(
            (
                len(section.text)
                if isinstance(section, TextSection) and section.text is not None
                else 0
            )
            for section in document.sections
        )
        doc_total_chars = (
            len(document.title or document.semantic_identifier) + section_chars
        )

        if MAX_DOCUMENT_CHARS and doc_total_chars > MAX_DOCUMENT_CHARS:
            # Skip documents that are too long: later steps on the text are more memory-intensive
            # and the container can run out of memory and crash. Several other checks exist
            # upstream, but those are at the connector level, so a catch-all is still needed.
            # The assumption is that files this long are generated files, not the kind of
            # content users generally care about.
            logger.warning(
                f"Skipping document with ID {document.id} as it is too long "
                f"({doc_total_chars:,} chars, max={MAX_DOCUMENT_CHARS:,})"
            )
            skipped_too_long.append((document.id, doc_total_chars))
            continue

        total_chars_in_batch += doc_total_chars
        documents.append(document)

    # Log batch statistics for OOM debugging
    if documents:
        avg_chars = total_chars_in_batch / len(documents)
        # Get the source from the first document (all in batch should be same source)
        source = documents[0].source.value if documents[0].source else "unknown"
        logger.debug(
            f"Document batch filter [{source}]: {len(documents)} docs kept, {len(skipped_too_long)} skipped (too long). "
            f"Total chars: {total_chars_in_batch:,}, Avg: {avg_chars:,.0f} chars/doc"
        )
        if skipped_too_long:
            logger.warning(
                f"Skipped oversized documents [{source}]: {skipped_too_long[:5]}"
            )  # Log first 5

    return documents


def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
    """
    Process all sections in documents by:
    1. Converting both TextSection and ImageSection objects to base Section objects
    2. Processing ImageSections to generate text summaries using a vision-capable LLM
    3. Returning IndexingDocument objects with both original and processed sections

    Args:
        documents: List of documents with TextSection | ImageSection objects

    Returns:
        List of IndexingDocument objects with processed_sections as list[Section]
    """
    # Check if image extraction and analysis is enabled before trying to get a vision LLM
    if not get_image_extraction_and_analysis_enabled():
        llm = None
    else:
        # Only get the vision LLM if image processing is enabled
        llm = get_default_llm_with_vision()

    if not llm:
        if get_image_extraction_and_analysis_enabled():
            logger.warning(
                "Image analysis is enabled but no vision-capable LLM is "
                "available; images will not be summarized. Configure a "
                "vision model in the admin LLM settings."
            )
        # Even without LLM, we still convert to IndexingDocument with base Sections
        return [
            IndexingDocument(
                **document.model_dump(),
                processed_sections=[
                    Section(
                        text=section.text if isinstance(section, TextSection) else "",
                        link=section.link,
                        image_file_id=(
                            section.image_file_id
                            if isinstance(section, ImageSection)
                            else None
                        ),
                    )
                    for section in document.sections
                ],
            )
            for document in documents
        ]

    indexed_documents: list[IndexingDocument] = []

    for document in documents:
        processed_sections: list[Section] = []

        for section in document.sections:
            # For ImageSection, process and create base Section with both text and image_file_id
            if isinstance(section, ImageSection):
                # Default section with image path preserved - ensure text is always a string
                processed_section = Section(
                    link=section.link,
                    image_file_id=section.image_file_id,
                    text="",  # Initialize with empty string
                )

                # Try to get image summary
                try:
                    file_store = get_default_file_store()

                    file_record = file_store.read_file_record(
                        file_id=section.image_file_id
                    )
                    if not file_record:
                        logger.warning(
                            f"Image file {section.image_file_id} not found in FileStore"
                        )

                        processed_section.text = "[Image could not be processed]"
                    else:
                        # Get the image data
                        image_data_io = file_store.read_file(
                            file_id=section.image_file_id
                        )
                        image_data = image_data_io.read()
                        summary = summarize_image_with_error_handling(
                            llm=llm,
                            image_data=image_data,
                            context_name=file_record.display_name or "Image",
                        )

                        if summary:
                            processed_section.text = summary
                        else:
                            processed_section.text = "[Image could not be summarized]"
                except Exception as e:
                    logger.error(f"Error processing image section: {e}")
                    processed_section.text = "[Error processing image]"

                processed_sections.append(processed_section)

            # For TextSection, create a base Section with text and link
            elif isinstance(section, TextSection):
                processed_section = Section(
                    text=section.text or "",  # Ensure text is always a string, not None
                    link=section.link,
                    image_file_id=None,
                )
processed_sections.append(processed_section) # Create IndexingDocument with original sections and processed_sections indexed_document = IndexingDocument( **document.model_dump(), processed_sections=processed_sections ) indexed_documents.append(indexed_document) return indexed_documents def add_document_summaries( chunks_by_doc: list[DocAwareChunk], llm: LLM, tokenizer: BaseTokenizer, trunc_doc_tokens: int, ) -> list[int] | None: """ Adds a document summary to a list of chunks from the same document. Returns the number of tokens in the document. """ doc_tokens = [] # this is value is the same for each chunk in the document; 0 indicates # There is not enough space for contextual RAG (the chunk content # and possibly metadata took up too much space) if chunks_by_doc[0].contextual_rag_reserved_tokens == 0: return None doc_tokens = tokenizer.encode(chunks_by_doc[0].source_document.get_text_content()) doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_tokens, tokenizer) # Apply prompt caching: cache the static prompt, document content is the suffix # Note: For document summarization, there's no cacheable prefix since the document changes # So we just pass the full prompt without caching summary_prompt = DOCUMENT_SUMMARY_PROMPT.format(document=doc_content) prompt_msg = UserMessage(content=summary_prompt) response = llm.invoke(prompt_msg, max_tokens=MAX_CONTEXT_TOKENS) doc_summary = llm_response_to_string(response) for chunk in chunks_by_doc: chunk.doc_summary = doc_summary return doc_tokens def add_chunk_summaries( chunks_by_doc: list[DocAwareChunk], llm: LLM, tokenizer: BaseTokenizer, trunc_doc_chunk_tokens: int, doc_tokens: list[int] | None, ) -> None: """ Adds chunk summaries to the chunks grouped by document id. Chunk summaries look at the chunk as well as the entire document (or a summary, if the document is too long) and describe how the chunk relates to the document. 
""" # all chunks within a document have the same contextual_rag_reserved_tokens if chunks_by_doc[0].contextual_rag_reserved_tokens == 0: return # use values computed in above doc summary section if available doc_tokens = doc_tokens or tokenizer.encode( chunks_by_doc[0].source_document.get_text_content() ) doc_content = tokenizer_trim_middle(doc_tokens, trunc_doc_chunk_tokens, tokenizer) # only compute doc summary if needed doc_info = ( doc_content if len(doc_tokens) <= MAX_TOKENS_FOR_FULL_INCLUSION else chunks_by_doc[0].doc_summary ) if not doc_info: # This happens if the document is too long AND document summaries are turned off # In this case we compute a doc summary using the LLM fallback_prompt = UserMessage( content=DOCUMENT_SUMMARY_PROMPT.format(document=doc_content) ) response = llm.invoke(fallback_prompt, max_tokens=MAX_CONTEXT_TOKENS) doc_info = llm_response_to_string(response) from onyx.llm.prompt_cache.processor import process_with_prompt_cache context_prompt1 = CONTEXTUAL_RAG_PROMPT1.format(document=doc_info) def assign_context(chunk: DocAwareChunk) -> None: context_prompt2 = CONTEXTUAL_RAG_PROMPT2.format(chunk=chunk.content) try: # Apply prompt caching: cache the document context (prompt1), chunk content is the suffix # For string inputs with continuation=True, the result will be a concatenated string processed_prompt, _ = process_with_prompt_cache( llm_config=llm.config, cacheable_prefix=UserMessage(content=context_prompt1), suffix=UserMessage(content=context_prompt2), continuation=True, # Append chunk to the document context ) response = llm.invoke(processed_prompt, max_tokens=MAX_CONTEXT_TOKENS) chunk.chunk_context = llm_response_to_string(response) except LLMRateLimitError as e: # Erroring during chunker is undesirable, so we log the error and continue # TODO: for v2, add robust retry logic logger.exception(f"Rate limit adding chunk summary: {e}", exc_info=e) chunk.chunk_context = "" except Exception as e: logger.exception(f"Error adding chunk 
summary: {e}", exc_info=e) chunk.chunk_context = "" run_functions_tuples_in_parallel( [(assign_context, (chunk,)) for chunk in chunks_by_doc] ) def add_contextual_summaries( chunks: list[DocAwareChunk], llm: LLM, tokenizer: BaseTokenizer, chunk_token_limit: int, ) -> list[DocAwareChunk]: """ Adds Document summary and chunk-within-document context to the chunks based on which environment variables are set. """ doc2chunks = defaultdict(list) for chunk in chunks: doc2chunks[chunk.source_document.id].append(chunk) # The number of tokens allowed for the document when computing a document summary trunc_doc_summary_tokens = llm.config.max_input_tokens - len( tokenizer.encode(DOCUMENT_SUMMARY_PROMPT) ) prompt_tokens = len( tokenizer.encode(CONTEXTUAL_RAG_PROMPT1 + CONTEXTUAL_RAG_PROMPT2) ) # The number of tokens allowed for the document when computing a # "chunk in context of document" summary trunc_doc_chunk_tokens = ( llm.config.max_input_tokens - prompt_tokens - chunk_token_limit ) for chunks_by_doc in doc2chunks.values(): doc_tokens = None if USE_DOCUMENT_SUMMARY: doc_tokens = add_document_summaries( chunks_by_doc, llm, tokenizer, trunc_doc_summary_tokens ) if USE_CHUNK_SUMMARY: add_chunk_summaries( chunks_by_doc, llm, tokenizer, trunc_doc_chunk_tokens, doc_tokens ) return chunks def _verify_indexing_completeness( insertion_records: list[DocumentInsertionRecord], write_failures: list[ConnectorFailure], embedding_failed_doc_ids: set[str], updatable_ids: list[str], document_index_name: str, ) -> None: """Verify that every updatable document was either indexed or reported as failed.""" all_returned_doc_ids = ( {r.document_id for r in insertion_records} | {f.failed_document.document_id for f in write_failures if f.failed_document} | embedding_failed_doc_ids ) if all_returned_doc_ids != set(updatable_ids): raise RuntimeError( f"Some documents were not successfully indexed. " f"Updatable IDs: {updatable_ids}, " f"Returned IDs: {all_returned_doc_ids}. 
" f"This should never happen. " f"This occurred for document index {document_index_name}" ) def _apply_document_ingestion_hook( documents: list[Document], db_session: Session, ) -> list[Document]: """Apply the Document Ingestion hook to each document in the batch. - HookSkipped / HookSoftFailed → document passes through unchanged. - Response with sections=None (or an empty list) → document is dropped (logged). - Response with sections → document sections are replaced with the hook's output; if none of the returned sections are valid, the document is dropped. """ def _build_payload(doc: Document) -> DocumentIngestionPayload: return DocumentIngestionPayload( document_id=doc.id or "", title=doc.title, semantic_identifier=doc.semantic_identifier, source=doc.source.value if doc.source is not None else "", sections=[ DocumentIngestionSection( text=s.text if isinstance(s, TextSection) else None, link=s.link, image_file_id=( s.image_file_id if isinstance(s, ImageSection) else None ), ) for s in doc.sections ], metadata={ k: v if isinstance(v, list) else [v] for k, v in doc.metadata.items() }, doc_updated_at=( doc.doc_updated_at.isoformat() if doc.doc_updated_at else None ), primary_owners=( [ DocumentIngestionOwner( display_name=o.get_semantic_name() or None, email=o.email, ) for o in doc.primary_owners ] if doc.primary_owners else None ), secondary_owners=( [ DocumentIngestionOwner( display_name=o.get_semantic_name() or None, email=o.email, ) for o in doc.secondary_owners ] if doc.secondary_owners else None ), ) def _apply_result( doc: Document, hook_result: DocumentIngestionResponse | HookSkipped | HookSoftFailed, ) -> Document | None: """Return the modified doc, original doc (skip/soft-fail), or None (drop).""" if isinstance(hook_result, (HookSkipped, HookSoftFailed)): return doc if not hook_result.sections: reason = hook_result.rejection_reason or "Document rejected by hook" logger.info( f"Document ingestion hook dropped document doc_id={doc.id!r}: {reason}" ) 
return None new_sections: list[TextSection | ImageSection] = [] for s in hook_result.sections: if
s.image_file_id is not None: new_sections.append( ImageSection(image_file_id=s.image_file_id, link=s.link) ) elif s.text is not None: new_sections.append(TextSection(text=s.text, link=s.link)) else: logger.warning( f"Document ingestion hook returned a section with neither text nor " f"image_file_id for doc_id={doc.id!r} — skipping section." ) if not new_sections: logger.info( f"Document ingestion hook produced no valid sections for doc_id={doc.id!r} — dropping document." ) return None return doc.model_copy(update={"sections": new_sections}) if not documents: return documents # Run the hook for the first document. If it returns HookSkipped the hook # is not configured — skip the remaining N-1 DB lookups. first_doc = documents[0] first_payload = _build_payload(first_doc).model_dump() first_hook_result = execute_hook( db_session=db_session, hook_point=HookPoint.DOCUMENT_INGESTION, payload=first_payload, response_type=DocumentIngestionResponse, ) if isinstance(first_hook_result, HookSkipped): return documents result: list[Document] = [] first_applied = _apply_result(first_doc, first_hook_result) if first_applied is not None: result.append(first_applied) for doc in documents[1:]: payload = _build_payload(doc).model_dump() hook_result = execute_hook( db_session=db_session, hook_point=HookPoint.DOCUMENT_INGESTION, payload=payload, response_type=DocumentIngestionResponse, ) applied = _apply_result(doc, hook_result) if applied is not None: result.append(applied) return result @log_function_time(debug_only=True) def index_doc_batch( *, document_batch: list[Document], chunker: Chunker, embedder: IndexingEmbedder, document_indices: list[DocumentIndex], request_id: str | None, tenant_id: str, db_session: Session, adapter: IndexingBatchAdapter, enable_contextual_rag: bool = False, llm: LLM | None = None, ignore_time_skip: bool = False, filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents, ) -> IndexingPipelineResult: """End-to-end indexing for a pre-batched 
set of documents. Takes the different pieces of the indexing pipeline and applies them to a batch of documents. Note that the documents should already be batched at this point so that memory requirements are not inflated. Returns an IndexingPipelineResult with the number of new docs, the total number of docs, the total number of chunks, and any failures.""" # Log connector info for debugging OOM issues connector_id = getattr(adapter, "connector_id", None) credential_id = getattr(adapter, "credential_id", None) logger.debug( f"Starting index_doc_batch: connector_id={connector_id}, " f"credential_id={credential_id}, tenant_id={tenant_id}, " f"num_docs={len(document_batch)}" ) filtered_documents = filter_fnc(document_batch) filtered_documents = _apply_document_ingestion_hook(filtered_documents, db_session) context = adapter.prepare(filtered_documents, ignore_time_skip) if not context: return IndexingPipelineResult.empty(len(filtered_documents)) # Convert documents to IndexingDocument objects with processed sections # logger.debug("Processing image sections") context.indexable_docs = process_image_sections(context.updatable_docs) doc_descriptors = [ { "doc_id": doc.id, "doc_length": doc.get_total_char_length(), } for doc in context.indexable_docs ] logger.debug(f"Starting indexing process for documents: {doc_descriptors}") logger.debug("Starting chunking") # NOTE: no special handling for failures here, since the chunker is not # a common source of failure for the indexing pipeline chunks: list[DocAwareChunk] = chunker.chunk(context.indexable_docs) llm_tokenizer: BaseTokenizer | None = None # contextual RAG if enable_contextual_rag: assert llm is not None, "must provide an LLM for contextual RAG" llm_tokenizer = get_tokenizer( model_name=llm.config.model_name, provider_type=llm.config.model_provider, ) # Because the chunker's tokens are different from the LLM's tokens, # we add a fudge factor (2x the chunk token limit) to ensure we truncate prompts to the LLM's token limit chunks = 
add_contextual_summaries(
chunks=chunks, llm=llm, tokenizer=llm_tokenizer, chunk_token_limit=chunker.chunk_token_limit * 2, ) logger.debug("Starting embedding") with embed_and_stream(chunks, embedder, tenant_id, request_id) as ( embedding_result, chunk_store, ): updatable_ids = [doc.id for doc in context.updatable_docs] updatable_chunk_data = [ UpdatableChunkData( chunk_id=chunk_id, document_id=document_id, boost_score=1.0, ) for chunk_id, document_id in embedding_result.successful_chunk_ids ] embedding_failed_doc_ids = _get_failed_doc_ids( embedding_result.connector_failures ) # Filter to only successfully embedded chunks so # doc_id_to_new_chunk_cnt reflects what's actually written to Vespa. embedded_chunks = [ c for c in chunks if c.source_document.id not in embedding_failed_doc_ids ] # Acquires a lock on the documents so that no other process can modify # them. Not needed until here, since this is when the actual race # condition with vector db can occur. with adapter.lock_context(context.updatable_docs): enricher = adapter.prepare_enrichment( context=context, tenant_id=tenant_id, chunks=embedded_chunks, ) index_batch_params = IndexBatchParams( doc_id_to_previous_chunk_cnt=enricher.doc_id_to_previous_chunk_cnt, doc_id_to_new_chunk_cnt=enricher.doc_id_to_new_chunk_cnt, tenant_id=tenant_id, large_chunks_enabled=chunker.enable_large_chunks, ) primary_doc_idx_insertion_records: list[DocumentInsertionRecord] | None = ( None ) primary_doc_idx_vector_db_write_failures: list[ConnectorFailure] | None = ( None ) for document_index in document_indices: def _enriched_stream() -> Iterator[DocMetadataAwareIndexChunk]: for chunk in chunk_store.stream(): yield enricher.enrich_chunk(chunk, 1.0) insertion_records, write_failures = ( write_chunks_to_vector_db_with_backoff( document_index=document_index, make_chunks=_enriched_stream, index_batch_params=index_batch_params, ) ) _verify_indexing_completeness( insertion_records=insertion_records, write_failures=write_failures, 
embedding_failed_doc_ids=embedding_failed_doc_ids, updatable_ids=updatable_ids, document_index_name=document_index.__class__.__name__, ) # We treat the first document index we got as the primary one used # for reporting the state of indexing. if primary_doc_idx_insertion_records is None: primary_doc_idx_insertion_records = insertion_records if primary_doc_idx_vector_db_write_failures is None: primary_doc_idx_vector_db_write_failures = write_failures adapter.post_index( context=context, updatable_chunk_data=updatable_chunk_data, filtered_documents=filtered_documents, enrichment=enricher, ) assert primary_doc_idx_insertion_records is not None assert primary_doc_idx_vector_db_write_failures is not None return IndexingPipelineResult( new_docs=sum( 1 for r in primary_doc_idx_insertion_records if not r.already_existed ), total_docs=len(filtered_documents), total_chunks=len(embedding_result.successful_chunk_ids), failures=primary_doc_idx_vector_db_write_failures + embedding_result.connector_failures, ) def run_indexing_pipeline( *, document_batch: list[Document], request_id: str | None, embedder: IndexingEmbedder, document_indices: list[DocumentIndex], db_session: Session, tenant_id: str, adapter: IndexingBatchAdapter, chunker: Chunker | None = None, ignore_time_skip: bool = False, ) -> IndexingPipelineResult: """Builds a pipeline which takes in a list (batch) of docs and indexes them.""" all_search_settings = get_active_search_settings(db_session) if ( all_search_settings.secondary and all_search_settings.secondary.status == IndexModelStatus.FUTURE ): search_settings = all_search_settings.secondary else: search_settings = all_search_settings.primary multipass_config = get_multipass_config(search_settings) enable_contextual_rag = ( search_settings.enable_contextual_rag or ENABLE_CONTEXTUAL_RAG ) llm = None if enable_contextual_rag: llm = get_llm_for_contextual_rag( search_settings.contextual_rag_llm_name or DEFAULT_CONTEXTUAL_RAG_LLM_NAME, 
search_settings.contextual_rag_llm_provider or DEFAULT_CONTEXTUAL_RAG_LLM_PROVIDER, ) chunker = chunker or Chunker( tokenizer=embedder.embedding_model.tokenizer, enable_multipass=multipass_config.multipass_indexing, enable_large_chunks=multipass_config.enable_large_chunks, enable_contextual_rag=enable_contextual_rag, # after every doc, update status in case there are a bunch of really long docs ) return index_doc_batch_with_handler( chunker=chunker, embedder=embedder, document_indices=document_indices, document_batch=document_batch, request_id=request_id, tenant_id=tenant_id, db_session=db_session, adapter=adapter, enable_contextual_rag=enable_contextual_rag, llm=llm, ignore_time_skip=ignore_time_skip, ) ================================================ FILE: backend/onyx/indexing/models.py ================================================ import contextlib from collections.abc import Generator from typing import Optional from typing import Protocol from typing import TYPE_CHECKING from pydantic import BaseModel from pydantic import Field from onyx.access.models import DocumentAccess from onyx.connectors.models import Document from onyx.db.enums import EmbeddingPrecision from onyx.db.enums import SwitchoverType from onyx.utils.logger import setup_logger from onyx.utils.pydantic_util import shallow_model_dump from shared_configs.enums import EmbeddingProvider from shared_configs.model_server_models import Embedding if TYPE_CHECKING: from onyx.indexing.indexing_pipeline import DocumentBatchPrepareContext from sqlalchemy.engine.util import TransactionalContext if TYPE_CHECKING: from onyx.db.models import SearchSettings logger = setup_logger() class ChunkEmbedding(BaseModel): full_embedding: Embedding mini_chunk_embeddings: list[Embedding] class BaseChunk(BaseModel): chunk_id: int # The first sentence(s) of the first Section of the chunk blurb: str content: str # Holds the link and the offsets into the raw Chunk text source_links: dict[int, str] | None image_file_id: str 
| None # True if this Chunk's start is not at the start of a Section # TODO(andrei): This is deprecated as of the OpenSearch migration. Remove. # Do not use. section_continuation: bool class DocAwareChunk(BaseChunk): # During indexing flow, we have access to a complete "Document" # During inference we only have access to the document id and do not reconstruct the Document source_document: Document # This could be an empty string if the title is too long and taking up too much of the chunk # This does not necessarily mean that the document does not have a title title_prefix: str # During indexing we also (optionally) build a metadata string from the metadata dict # This is also indexed so that we can strip it out after indexing; this way it supports # multiple iterations of metadata representation for backwards compatibility metadata_suffix_semantic: str metadata_suffix_keyword: str # This is the number of tokens reserved for contextual RAG # in the chunk. doc_summary and chunk_context combined should # contain at most this many tokens.
contextual_rag_reserved_tokens: int # This is the summary for the document generated for contextual RAG doc_summary: str # This is the context for this chunk generated for contextual RAG chunk_context: str mini_chunk_texts: list[str] | None large_chunk_id: int | None large_chunk_reference_ids: list[int] = Field(default_factory=list) def to_short_descriptor(self) -> str: """Used when logging the identity of a chunk""" return f"{self.source_document.to_short_descriptor()} Chunk ID: {self.chunk_id}" def get_link(self) -> str | None: return ( self.source_document.sections[0].link if self.source_document.sections else None ) class IndexChunk(DocAwareChunk): embeddings: ChunkEmbedding title_embedding: Embedding | None # TODO(rkuo): currently, this extra metadata sent during indexing is just for speed, # but full consistency happens on background sync class DocMetadataAwareIndexChunk(IndexChunk): """An `IndexChunk` that contains all necessary metadata to be indexed. This includes the following: access: holds all information about which users should have access to the source document for this chunk. document_sets: all document sets the source document for this chunk is a part of. This is used for filtering / personas. boost: influences the ranking of this chunk at query time. Positive -> ranked higher, negative -> ranked lower. Not included in aggregated boost calculation for legacy reasons. aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content) """ tenant_id: str access: "DocumentAccess" document_sets: set[str] user_project: list[int] personas: list[int] boost: int aggregated_chunk_boost_factor: float # Full ancestor path from root hierarchy node to document's parent. # Stored as an integer array in OpenSearch for hierarchy-based filtering. # Empty list means no hierarchy info (document excluded from hierarchy searches). 
ancestor_hierarchy_node_ids: list[int] @classmethod def from_index_chunk( cls, index_chunk: IndexChunk, access: "DocumentAccess", document_sets: set[str], user_project: list[int], personas: list[int], boost: int, aggregated_chunk_boost_factor: float, tenant_id: str, ancestor_hierarchy_node_ids: list[int] | None = None, ) -> "DocMetadataAwareIndexChunk": return cls.model_construct( **shallow_model_dump(index_chunk), access=access, document_sets=document_sets, user_project=user_project, personas=personas, boost=boost, aggregated_chunk_boost_factor=aggregated_chunk_boost_factor, tenant_id=tenant_id, ancestor_hierarchy_node_ids=ancestor_hierarchy_node_ids or [], ) class EmbeddingModelDetail(BaseModel): id: int | None = None model_name: str normalize: bool query_prefix: str | None passage_prefix: str | None api_url: str | None = None provider_type: EmbeddingProvider | None = None api_key: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} @classmethod def from_db_model( cls, search_settings: "SearchSettings", ) -> "EmbeddingModelDetail": api_key = None if ( search_settings.cloud_provider is not None and search_settings.cloud_provider.api_key is not None ): api_key = search_settings.cloud_provider.api_key.get_value(apply_mask=True) return cls( id=search_settings.id, model_name=search_settings.model_name, normalize=search_settings.normalize, query_prefix=search_settings.query_prefix, passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, api_key=api_key, api_url=search_settings.api_url, ) # Additional info needed for indexing time class IndexingSetting(EmbeddingModelDetail): model_dim: int index_name: str | None multipass_indexing: bool embedding_precision: EmbeddingPrecision reduced_dimension: int | None = None switchover_type: SwitchoverType = SwitchoverType.REINDEX enable_contextual_rag: bool contextual_rag_llm_name: str | None = None 
contextual_rag_llm_provider: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} @property def final_embedding_dim(self) -> int: if self.reduced_dimension: return self.reduced_dimension return self.model_dim @classmethod def from_db_model(cls, search_settings: "SearchSettings") -> "IndexingSetting": return cls( model_name=search_settings.model_name, model_dim=search_settings.model_dim, normalize=search_settings.normalize, query_prefix=search_settings.query_prefix, passage_prefix=search_settings.passage_prefix, provider_type=search_settings.provider_type, index_name=search_settings.index_name, multipass_indexing=search_settings.multipass_indexing, embedding_precision=search_settings.embedding_precision, reduced_dimension=search_settings.reduced_dimension, switchover_type=search_settings.switchover_type, enable_contextual_rag=search_settings.enable_contextual_rag, ) class MultipassConfig(BaseModel): multipass_indexing: bool enable_large_chunks: bool class UpdatableChunkData(BaseModel): chunk_id: int document_id: str boost_score: float class ChunkEnrichmentContext(Protocol): """Returned by prepare_enrichment. Holds pre-computed metadata lookups and provides per-chunk enrichment.""" doc_id_to_previous_chunk_cnt: dict[str, int] doc_id_to_new_chunk_cnt: dict[str, int] def enrich_chunk( self, chunk: IndexChunk, score: float ) -> DocMetadataAwareIndexChunk: ... class IndexingBatchAdapter(Protocol): def prepare( self, documents: list[Document], ignore_time_skip: bool ) -> Optional["DocumentBatchPrepareContext"]: ... 
@contextlib.contextmanager def lock_context( self, documents: list[Document] ) -> Generator[TransactionalContext, None, None]: """Provide a transaction/row-lock context for critical updates.""" def prepare_enrichment( self, context: "DocumentBatchPrepareContext", tenant_id: str, chunks: list[DocAwareChunk], ) -> ChunkEnrichmentContext: """Prepare per-chunk enrichment data (access, document sets, boost, etc.). Precondition: ``chunks`` have already been through the embedding step (i.e. they are ``IndexChunk`` instances with populated embeddings, passed here as the base ``DocAwareChunk`` type). """ ... def post_index( self, context: "DocumentBatchPrepareContext", updatable_chunk_data: list[UpdatableChunkData], filtered_documents: list[Document], enrichment: ChunkEnrichmentContext, ) -> None: ... ================================================ FILE: backend/onyx/indexing/vector_db_insertion.py ================================================ import time from collections.abc import Callable from collections.abc import Iterable from http import HTTPStatus from itertools import chain from itertools import groupby import httpx from onyx.connectors.models import ConnectorFailure from onyx.connectors.models import DocumentFailure from onyx.document_index.interfaces import DocumentIndex from onyx.document_index.interfaces import DocumentInsertionRecord from onyx.document_index.interfaces import IndexBatchParams from onyx.indexing.models import DocMetadataAwareIndexChunk from onyx.utils.logger import setup_logger logger = setup_logger() def _log_insufficient_storage_error(e: Exception) -> None: if isinstance(e, httpx.HTTPStatusError): if e.response.status_code == HTTPStatus.INSUFFICIENT_STORAGE: logger.error( "NOTE: HTTP Status 507 Insufficient Storage indicates " "you need to allocate more memory or disk space to the " "Vespa/index container." 
) def write_chunks_to_vector_db_with_backoff( document_index: DocumentIndex, make_chunks: Callable[[], Iterable[DocMetadataAwareIndexChunk]], index_batch_params: IndexBatchParams, ) -> tuple[list[DocumentInsertionRecord], list[ConnectorFailure]]: """Tries to insert all chunks in one large batch. If that batch fails for any reason, goes document by document to isolate the failure(s). IMPORTANT: must pass in whole documents at a time, not individual chunks, since the vector DB interface assumes that all chunks for a single document are present. The chunks for each document must also be contiguous in the stream. `make_chunks` is called again for the per-document fallback, so it must return a fresh iterable on each call. """ # first try to write the chunks to the vector db try: return ( list( document_index.index( chunks=make_chunks(), index_batch_params=index_batch_params, ) ), [], ) except Exception as e: logger.exception( "Failed to write chunk batch to vector db. Trying individual docs." ) # give some specific logging on this common failure case. _log_insufficient_storage_error(e) # wait a couple seconds just to give the vector db a chance to recover time.sleep(2) insertion_records: list[DocumentInsertionRecord] = [] failures: list[ConnectorFailure] = [] def key(chunk: DocMetadataAwareIndexChunk) -> str: return chunk.source_document.id seen_doc_ids: set[str] = set() for doc_id, chunks_for_doc in groupby(make_chunks(), key=key): if doc_id in seen_doc_ids: raise RuntimeError( f"Doc chunks are not arriving in order. Current doc_id={doc_id}, seen_doc_ids={list(seen_doc_ids)}" ) seen_doc_ids.add(doc_id) first_chunk = next(chunks_for_doc) chunks_for_doc = chain([first_chunk], chunks_for_doc) try: insertion_records.extend( document_index.index( chunks=chunks_for_doc, index_batch_params=index_batch_params, ) ) except Exception as e: logger.exception( f"Failed to write document chunks for '{doc_id}' to vector db" ) # give some specific logging on this common failure case.
_log_insufficient_storage_error(e) failures.append( ConnectorFailure( failed_document=DocumentFailure( document_id=doc_id, document_link=first_chunk.get_link(), ), failure_message=str(e), exception=e, ) ) return insertion_records, failures ================================================ FILE: backend/onyx/key_value_store/__init__.py ================================================ ================================================ FILE: backend/onyx/key_value_store/factory.py ================================================ from onyx.key_value_store.interface import KeyValueStore from onyx.key_value_store.store import PgRedisKVStore from shared_configs.configs import DEFAULT_REDIS_PREFIX from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR def get_kv_store() -> KeyValueStore: # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in # It's read from the global thread level variable return PgRedisKVStore() def get_shared_kv_store() -> KeyValueStore: token = CURRENT_TENANT_ID_CONTEXTVAR.set(DEFAULT_REDIS_PREFIX) try: return get_kv_store() finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) ================================================ FILE: backend/onyx/key_value_store/interface.py ================================================ import abc from typing import cast from onyx.utils.special_types import JSON_ro class KvKeyNotFoundError(Exception): pass def unwrap_str(val: JSON_ro) -> str: """Unwrap a string stored as {"value": str} in the encrypted KV store. 
Also handles legacy plain-string values cached in Redis.""" if isinstance(val, dict): try: return cast(str, val["value"]) except KeyError: raise ValueError( f"Expected dict with 'value' key, got keys: {list(val.keys())}" ) return cast(str, val) class KeyValueStore: # In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in # It's read from the global thread level variable @abc.abstractmethod def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: raise NotImplementedError @abc.abstractmethod def load(self, key: str, refresh_cache: bool = False) -> JSON_ro: raise NotImplementedError @abc.abstractmethod def delete(self, key: str) -> None: raise NotImplementedError ================================================ FILE: backend/onyx/key_value_store/store.py ================================================ import json from typing import cast from onyx.cache.interface import CacheBackend from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import KVStore from onyx.key_value_store.interface import KeyValueStore from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.utils.logger import setup_logger from onyx.utils.special_types import JSON_ro logger = setup_logger() REDIS_KEY_PREFIX = "onyx_kv_store:" KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day class PgRedisKVStore(KeyValueStore): def __init__(self, cache: CacheBackend | None = None) -> None: self._cache = cache def _get_cache(self) -> CacheBackend: if self._cache is None: from onyx.cache.factory import get_cache_backend self._cache = get_cache_backend() return self._cache def store(self, key: str, val: JSON_ro, encrypt: bool = False) -> None: # Not encrypted in Cache backend (typically Redis), but encrypted in Postgres try: self._get_cache().set( REDIS_KEY_PREFIX + key, json.dumps(val), ex=KV_REDIS_KEY_EXPIRATION ) except Exception as e: # Fallback gracefully to Postgres if Cache backend fails 
logger.error( f"Failed to set value in Cache backend for key '{key}': {str(e)}" ) encrypted_val = val if encrypt else None plain_val = val if not encrypt else None with get_session_with_current_tenant() as db_session: obj = db_session.query(KVStore).filter_by(key=key).first() if obj: obj.value = plain_val obj.encrypted_value = encrypted_val # type: ignore[assignment] else: obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val) db_session.query(KVStore).filter_by(key=key).delete() # just in case db_session.add(obj) db_session.commit() def load(self, key: str, refresh_cache: bool = False) -> JSON_ro: if not refresh_cache: try: cached = self._get_cache().get(REDIS_KEY_PREFIX + key) if cached is not None: return json.loads(cached.decode("utf-8")) except Exception as e: logger.error( f"Failed to get value from cache for key '{key}': {str(e)}" ) with get_session_with_current_tenant() as db_session: obj = db_session.query(KVStore).filter_by(key=key).first() if not obj: raise KvKeyNotFoundError if obj.value is not None: value = obj.value elif obj.encrypted_value is not None: # Unwrap SensitiveValue - this is internal backend use value = obj.encrypted_value.get_value(apply_mask=False) else: value = None try: self._get_cache().set( REDIS_KEY_PREFIX + key, json.dumps(value), ex=KV_REDIS_KEY_EXPIRATION, ) except Exception as e: logger.error(f"Failed to set value in cache for key '{key}': {str(e)}") return cast(JSON_ro, value) def delete(self, key: str) -> None: try: self._get_cache().delete(REDIS_KEY_PREFIX + key) except Exception as e: logger.error(f"Failed to delete value from cache for key '{key}': {str(e)}") with get_session_with_current_tenant() as db_session: result = db_session.query(KVStore).filter_by(key=key).delete() if result == 0: raise KvKeyNotFoundError db_session.commit() ================================================ FILE: backend/onyx/kg/clustering/clustering.py ================================================ import time from collections.abc 
import Generator from typing import cast from rapidfuzz.fuzz import ratio from redis.lock import Lock as RedisLock from sqlalchemy import func from sqlalchemy import text from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.kg_configs import KG_CLUSTERING_RETRIEVE_THRESHOLD from onyx.configs.kg_configs import KG_CLUSTERING_THRESHOLD from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.entities import KGEntity from onyx.db.entities import KGEntityExtractionStaging from onyx.db.entities import merge_entities from onyx.db.entities import transfer_entity from onyx.db.kg_config import get_kg_config_settings from onyx.db.kg_config import validate_kg_settings from onyx.db.models import Document from onyx.db.models import KGEntityType from onyx.db.models import KGRelationshipExtractionStaging from onyx.db.models import KGRelationshipTypeExtractionStaging from onyx.db.relationships import transfer_relationship from onyx.db.relationships import transfer_relationship_type from onyx.db.relationships import upsert_relationship from onyx.db.relationships import upsert_relationship_type from onyx.document_index.vespa.kg_interactions import ( get_kg_vespa_info_update_requests_for_document, ) from onyx.document_index.vespa.kg_interactions import update_kg_chunks_vespa_info from onyx.kg.models import KGGroundingType from onyx.kg.utils.formatting_utils import make_relationship_id from onyx.kg.utils.lock_utils import extend_lock from onyx.utils.logger import setup_logger from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() def _get_batch_untransferred_grounded_entities( batch_size: int, ) -> Generator[list[KGEntityExtractionStaging], None, None]: while True: with get_session_with_current_tenant() as db_session: batch = ( db_session.query(KGEntityExtractionStaging) .join( KGEntityType, 
KGEntityExtractionStaging.entity_type_id_name == KGEntityType.id_name, ) .filter( KGEntityType.grounding == KGGroundingType.GROUNDED, KGEntityExtractionStaging.transferred_id_name.is_(None), ) .limit(batch_size) .all() ) if not batch: break yield batch def _get_batch_untransferred_relationship_types( batch_size: int, ) -> Generator[list[KGRelationshipTypeExtractionStaging], None, None]: while True: with get_session_with_current_tenant() as db_session: batch = ( db_session.query(KGRelationshipTypeExtractionStaging) .filter(KGRelationshipTypeExtractionStaging.transferred.is_(False)) .limit(batch_size) .all() ) if not batch: break yield batch def _get_batch_untransferred_relationships( batch_size: int, ) -> Generator[list[KGRelationshipExtractionStaging], None, None]: while True: with get_session_with_current_tenant() as db_session: batch = ( db_session.query(KGRelationshipExtractionStaging) .filter(KGRelationshipExtractionStaging.transferred.is_(False)) .limit(batch_size) .all() ) if not batch: break yield batch def _get_batch_entities_with_parent( batch_size: int, ) -> Generator[list[KGEntityExtractionStaging], None, None]: offset = 0 while True: with get_session_with_current_tenant() as db_session: batch = ( db_session.query(KGEntityExtractionStaging) .filter(KGEntityExtractionStaging.parent_key.isnot(None)) .order_by(KGEntityExtractionStaging.id_name) .offset(offset) .limit(batch_size) .all() ) if not batch: break # we can't filter out ""s earlier as it will mess up the pagination yield [entity for entity in batch if entity.parent_key != ""] offset += batch_size def _get_batch_kg_processed_documents( batch_size: int, ) -> Generator[list[Document], None, None]: offset = 0 while True: with get_session_with_current_tenant() as db_session: batch = ( db_session.query(Document) .join( KGEntityExtractionStaging, Document.id == KGEntityExtractionStaging.document_id, ) .filter( KGEntityExtractionStaging.transferred_id_name.is_not(None), ) .order_by(Document.id) 
                .offset(offset)
                .limit(batch_size)
                .all()
            )
            if not batch:
                break
            yield batch
            offset += batch_size


def _cluster_one_grounded_entity(
    entity: KGEntityExtractionStaging,
) -> tuple[KGEntity, bool]:
    """
    Cluster a single grounded entity.
    """
    with get_session_with_current_tenant() as db_session:
        # get entity name and filtering conditions
        if entity.document_id is not None:
            entity_name = cast(
                str,
                db_session.query(Document.semantic_id)
                .filter(Document.id == entity.document_id)
                .scalar(),
            ).lower()
            filtering = [KGEntity.document_id.is_(None)]
        else:
            entity_name = entity.name.lower()
            filtering = []

        # skip those with numbers so we don't cluster version1 and version2, etc.
        similar_entities: list[KGEntity] = []
        if not any(char.isdigit() for char in entity_name):
            # find similar entities, uses GIN index, very efficient
            db_session.execute(
                text(
                    "SET pg_trgm.similarity_threshold = "
                    + str(KG_CLUSTERING_RETRIEVE_THRESHOLD)
                )
            )
            similar_entities = (
                db_session.query(KGEntity)
                .filter(
                    # find entities of the same type with a similar name
                    *filtering,
                    KGEntity.entity_type_id_name == entity.entity_type_id_name,
                    getattr(func, POSTGRES_DEFAULT_SCHEMA).similarity_op(
                        KGEntity.name, entity_name
                    ),
                )
                .all()
            )

    # find best match
    best_score = -1.0
    best_entity = None
    for similar in similar_entities:
        # skip those with numbers so we don't cluster version1 and version2, etc.
        if any(char.isdigit() for char in similar.name):
            continue
        score = ratio(similar.name, entity_name)
        if score >= KG_CLUSTERING_THRESHOLD * 100 and score > best_score:
            best_score = score
            best_entity = similar

    # if there is a match, update the entity, otherwise create a new one
    with get_session_with_current_tenant() as db_session:
        if best_entity:
            logger.debug(f"Merged {entity.name} with {best_entity.name}")
            update_vespa = (
                best_entity.document_id is None and entity.document_id is not None
            )
            transferred_entity = merge_entities(
                db_session=db_session, parent=best_entity, child=entity
            )
        else:
            update_vespa = entity.document_id is not None
            transferred_entity = transfer_entity(db_session=db_session, entity=entity)

        db_session.commit()

    return transferred_entity, update_vespa


def _create_one_parent_child_relationship(entity: KGEntityExtractionStaging) -> None:
    """
    Creates a relationship between the entity and its parent, if it exists.
    Then, updates the entity's parent to the next ancestor.
    """
    with get_session_with_current_tenant() as db_session:
        # find the next ancestor
        parent = (
            db_session.query(KGEntity)
            .filter(KGEntity.entity_key == entity.parent_key)
            .first()
        )

        if parent is not None:
            # create parent child relationship and relationship type
            upsert_relationship_type(
                db_session=db_session,
                source_entity_type=parent.entity_type_id_name,
                relationship_type="has_subcomponent",
                target_entity_type=entity.entity_type_id_name,
            )
            relationship_id_name = make_relationship_id(
                parent.id_name,
                "has_subcomponent",
                cast(str, entity.transferred_id_name),
            )
            upsert_relationship(
                db_session=db_session,
                relationship_id_name=relationship_id_name,
                source_document_id=entity.document_id,
            )

            next_ancestor = parent.parent_key or ""
        else:
            next_ancestor = ""

        # set the staging entity's parent to the next ancestor
        # if there is no parent or next ancestor, set to "" to differentiate from None
        # None will mess up the pagination in _get_batch_entities_with_parent
        db_session.query(KGEntityExtractionStaging).filter(
            KGEntityExtractionStaging.id_name == entity.id_name
        ).update({"parent_key": next_ancestor})

        db_session.commit()


def _transfer_one_relationship(
    relationship: KGRelationshipExtractionStaging,
) -> None:
    with get_session_with_current_tenant() as db_session:
        # get the translations
        staging_entity_id_names = {
            relationship.source_node,
            relationship.target_node,
        }
        entity_translations: dict[str, str] = {
            entity.id_name: entity.transferred_id_name
            for entity in db_session.query(KGEntityExtractionStaging)
            .filter(KGEntityExtractionStaging.id_name.in_(staging_entity_id_names))
            .all()
            if entity.transferred_id_name is not None
        }
        if len(entity_translations) != len(staging_entity_id_names):
            logger.error(
                f"Missing entity translations for {staging_entity_id_names - entity_translations.keys()}"
            )
            return

        # transfer the relationship
        transfer_relationship(
            db_session=db_session,
            relationship=relationship,
            entity_translations=entity_translations,
        )
        db_session.commit()


def kg_clustering(
    tenant_id: str,
    index_name: str,
    lock: RedisLock,
    processing_chunk_batch_size: int = 16,
) -> None:
    """
    Here we will cluster the extractions based on their cluster frameworks.
    Initially, this will only focus on grounded entities with pre-determined
    relationships, so 'clustering' is actually not yet required.
    However, we may need to reconcile entities coming from different sources.

    The primary purpose of this function is to populate the actual KG tables
    from the temp_extraction tables.

    This will change with deep extraction, where grounded-sourceless entities
    can be extracted and then need to be clustered.
    """
    logger.info(f"Starting kg clustering for tenant {tenant_id}")

    kg_config_settings = get_kg_config_settings()
    validate_kg_settings(kg_config_settings)

    last_lock_time = time.monotonic()

    # Cluster and transfer grounded entities sequentially
    start_time = time.monotonic()
    i_batch = 0
    for i_batch, untransferred_grounded_entities in enumerate(
        _get_batch_untransferred_grounded_entities(
            batch_size=processing_chunk_batch_size
        )
    ):
        for entity in untransferred_grounded_entities:
            _cluster_one_grounded_entity(entity)

        last_lock_time = extend_lock(
            lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time
        )
        # logger.debug(f"Transferred entities batch {i}")
    # NOTE: we assume every entity is transferred, as we currently only have grounded entities
    time_delta = time.monotonic() - start_time
    logger.info(
        f"Finished transferring {i_batch + 1} entity batches in {time_delta:.2f}s"
    )

    # Create parent-child relationships in parallel
    for _ in range(kg_config_settings.KG_MAX_PARENT_RECURSION_DEPTH):
        for root_entities in _get_batch_entities_with_parent(
            batch_size=processing_chunk_batch_size
        ):
            run_functions_tuples_in_parallel(
                [
                    (_create_one_parent_child_relationship, (root_entity,))
                    for root_entity in root_entities
                ]
            )
            last_lock_time = extend_lock(
                lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time
            )
    logger.info("Finished creating all parent-child relationships")

    # Transfer the relationship types (no need to do in parallel as there's only a few)
    start_time = time.monotonic()
    i_batch = 0
    for i_batch, relationship_types in enumerate(
        _get_batch_untransferred_relationship_types(
            batch_size=processing_chunk_batch_size
        )
    ):
        with get_session_with_current_tenant() as db_session:
            for relationship_type in relationship_types:
                transfer_relationship_type(db_session, relationship_type)
            db_session.commit()

        last_lock_time = extend_lock(
            lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time
        )
        # logger.debug(f"Transferred relationship types batch {i}")
    time_delta = time.monotonic() - start_time
    logger.info(
        f"Finished transferring {i_batch + 1} relationship type batches in {time_delta:.2f}s"
    )

    # Transfer the relationships in parallel
    start_time = time.monotonic()
    i_batch = 0
    for i_batch, relationships in enumerate(
        _get_batch_untransferred_relationships(batch_size=processing_chunk_batch_size)
    ):
        run_functions_tuples_in_parallel(
            [
                (_transfer_one_relationship, (relationship,))
                for relationship in relationships
            ]
        )
        last_lock_time = extend_lock(
            lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time
        )
        # logger.debug(f"Transferred relationships batch {i}")
    time_delta = time.monotonic() - start_time
    logger.info(
        f"Finished transferring {i_batch + 1} relationship batches in {time_delta:.2f}s"
    )

    # Update vespa for each document
    start_time = time.monotonic()
    i_batch = 0
    for i_batch, documents in enumerate(
        _get_batch_kg_processed_documents(batch_size=processing_chunk_batch_size)
    ):
        batch_update_requests = run_functions_tuples_in_parallel(
            [
                (get_kg_vespa_info_update_requests_for_document, (document.id,))
                for document in documents
            ]
        )
        for update_requests, document in zip(batch_update_requests, documents):
            try:
                update_kg_chunks_vespa_info(update_requests, index_name, tenant_id)
            except Exception as e:
                logger.error(f"Error updating vespa for document {document.id}: {e}")
        last_lock_time = extend_lock(
            lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time
        )
        # logger.debug(f"Updated vespa for documents batch {i}")
    time_delta = time.monotonic() - start_time
    logger.info(
        f"Finished updating {i_batch + 1} document batches in {time_delta:.2f}s"
    )

    # Delete the transferred objects from the staging tables
    try:
        with get_session_with_current_tenant() as db_session:
            db_session.query(KGRelationshipExtractionStaging).filter(
                KGRelationshipExtractionStaging.transferred.is_(True)
            ).delete(synchronize_session=False)
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationships: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            db_session.query(KGRelationshipTypeExtractionStaging).filter(
                KGRelationshipTypeExtractionStaging.transferred.is_(True)
            ).delete(synchronize_session=False)
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting relationship types: {e}")

    try:
        with get_session_with_current_tenant() as db_session:
            db_session.query(KGEntityExtractionStaging).filter(
                KGEntityExtractionStaging.transferred_id_name.is_not(None)
            ).delete(synchronize_session=False)
            db_session.commit()
    except Exception as e:
        logger.error(f"Error deleting entities: {e}")

    logger.info("Finished deleting all transferred staging entries")


================================================
FILE: backend/onyx/kg/clustering/normalizations.py
================================================
import re
from collections import defaultdict
from typing import cast

import numpy as np
from rapidfuzz.distance.DamerauLevenshtein import normalized_similarity
from sqlalchemy import desc
from sqlalchemy import Float
from sqlalchemy import func
from sqlalchemy import MetaData
from sqlalchemy import select
from sqlalchemy import String
from sqlalchemy import Table
from sqlalchemy.dialects.postgresql import ARRAY

from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS
from onyx.configs.kg_configs import KG_NORMALIZATION_RERANK_THRESHOLD
from onyx.configs.kg_configs import KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import KGEntity
from onyx.db.relationships import get_relationships_for_entity_type_pairs
from onyx.kg.models import NormalizedEntities
from onyx.kg.models import NormalizedRelationships
from onyx.kg.utils.embeddings import encode_string_batch
from onyx.kg.utils.formatting_utils import format_entity_id_for_models
from onyx.kg.utils.formatting_utils import get_attributes
from onyx.kg.utils.formatting_utils import get_entity_type
from onyx.kg.utils.formatting_utils import make_entity_w_attributes
from onyx.kg.utils.formatting_utils import make_relationship_id
from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA

logger = setup_logger()

alphanum_regex = re.compile(r"[^a-z0-9]+")
rem_email_regex = re.compile(r"(?<=\S)@([a-z0-9-]+)\.([a-z]{2,6})$")


def _ngrams(sequence: str, n: int) -> list[tuple[str, ...]]:
    """Generate n-grams from a sequence."""
    return [tuple(sequence[i : i + n]) for i in range(len(sequence) - n + 1)]


def _clean_name(entity_name: str) -> str:
    """
    Clean an entity string by casefolding it, stripping a trailing email domain
    (e.g. '@example.com'), and removing all non-alphanumeric characters.
    If the name after cleaning is empty, return the casefolded original name.
    """
    cleaned_entity = entity_name.casefold()
    return (
        alphanum_regex.sub("", rem_email_regex.sub("", cleaned_entity))
        or cleaned_entity
    )


def _normalize_one_entity(
    entity: str,
    attributes: dict[str, str],
    allowed_docs_temp_view_name: str | None = None,
) -> str | None:
    """
    Matches a single entity to the best matching entity of the same type.
    """
    entity_type, entity_name = split_entity_id(entity)
    if entity_name == "*":
        return entity

    cleaned_entity = _clean_name(entity_name)

    # narrow filter to subtype if requested
    type_filters = [KGEntity.entity_type_id_name == entity_type]
    if "subtype" in attributes:
        type_filters.append(
            KGEntity.attributes.op("@>")({"subtype": attributes["subtype"]})
        )

    # step 1: find entities containing the entity_name or something similar
    with get_session_with_current_tenant() as db_session:
        # get allowed documents
        metadata = MetaData()
        if allowed_docs_temp_view_name is None:
            raise ValueError("allowed_docs_temp_view_name is not available")

        effective_schema_allowed_docs_temp_view_name = (
            allowed_docs_temp_view_name.split(".")[-1]
        )

        allowed_docs_temp_view = Table(
            effective_schema_allowed_docs_temp_view_name,
            metadata,
            autoload_with=db_session.get_bind(),
        )

        # generate trigrams of the queried entity Q
        query_trigrams = db_session.query(
            getattr(func, POSTGRES_DEFAULT_SCHEMA)
            .show_trgm(cleaned_entity)
            .cast(ARRAY(String(3)))
            .label("trigrams")
        ).cte("query")

        candidates = cast(
            list[tuple[str, str, float]],
            db_session.query(
                KGEntity.id_name,
                KGEntity.name,
                (
                    # for each entity E, compute score = | Q ∩ E | / min(|Q|, |E|)
                    func.cardinality(
                        func.array(
                            select(func.unnest(KGEntity.name_trigrams))
                            .correlate(KGEntity)
                            .intersect(
                                select(
                                    func.unnest(query_trigrams.c.trigrams)
                                ).correlate(query_trigrams)
                            )
                            .scalar_subquery()
                        )
                    ).cast(Float)
                    / func.least(
                        func.cardinality(query_trigrams.c.trigrams),
                        func.cardinality(KGEntity.name_trigrams),
                    )
                ).label("score"),
            )
            .select_from(KGEntity, query_trigrams)
            .outerjoin(
                allowed_docs_temp_view,
                KGEntity.document_id == allowed_docs_temp_view.c.allowed_doc_id,
            )
            .filter(
                *type_filters,
                KGEntity.name_trigrams.overlap(query_trigrams.c.trigrams),
                # Add filter for allowed docs - either document_id is NULL or it's in allowed_docs
                (
                    KGEntity.document_id.is_(None)
                    | allowed_docs_temp_view.c.allowed_doc_id.isnot(None)
                ),
            )
            .order_by(desc("score"))
            .limit(KG_NORMALIZATION_RETRIEVE_ENTITIES_LIMIT)
            .all(),
        )

    if not candidates:
        return None

    # step 2: do a weighted ngram analysis and damerau levenshtein distance to rerank
    n1, n2, n3 = (
        set(_ngrams(cleaned_entity, 1)),
        set(_ngrams(cleaned_entity, 2)),
        set(_ngrams(cleaned_entity, 3)),
    )
    for i, (candidate_id_name, candidate_name, _) in enumerate(candidates):
        cleaned_candidate = _clean_name(candidate_name)
        h_n1, h_n2, h_n3 = (
            set(_ngrams(cleaned_candidate, 1)),
            set(_ngrams(cleaned_candidate, 2)),
            set(_ngrams(cleaned_candidate, 3)),
        )

        # compute ngram overlap, renormalize scores if the names are too short for larger ngrams
        grams_used = min(2, len(cleaned_entity) - 1, len(cleaned_candidate) - 1)
        W_n1, W_n2, W_n3 = KG_NORMALIZATION_RERANK_NGRAM_WEIGHTS
        ngram_score = (
            # compute | Q ∩ E | / min(|Q|, |E|) for unigrams and bigrams (trigrams already computed)
            W_n1 * len(n1 & h_n1) / max(1, min(len(n1), len(h_n1)))
            + W_n2 * len(n2 & h_n2) / max(1, min(len(n2), len(h_n2)))
            + W_n3 * len(n3 & h_n3) / max(1, min(len(n3), len(h_n3)))
        ) / (W_n1, W_n1 + W_n2, 1.0)[grams_used]

        # compute damerau levenshtein distance to fuzzy match against typos
        W_leven = KG_NORMALIZATION_RERANK_LEVENSHTEIN_WEIGHT
        leven_score = normalized_similarity(cleaned_entity, cleaned_candidate)

        # combine scores
        score = (1.0 - W_leven) * ngram_score + W_leven * leven_score
        candidates[i] = (candidate_id_name, candidate_name, score)
    candidates = list(
        sorted(
            filter(lambda x: x[2] > KG_NORMALIZATION_RERANK_THRESHOLD, candidates),
            key=lambda x: x[2],
            reverse=True,
        )
    )
    if not candidates:
        return None

    return candidates[0][0]


def _get_existing_normalized_relationships(
    raw_relationships: list[str],
) -> dict[str, dict[str, list[str]]]:
    """
    Get existing normalized relationships from the database.
    """
    relationship_type_map: dict[str, dict[str, list[str]]] = defaultdict(
        lambda: defaultdict(list)
    )
    relationship_pairs = list(
        {
            (
                get_entity_type(split_relationship_id(relationship)[0]),
                get_entity_type(split_relationship_id(relationship)[2]),
            )
            for relationship in raw_relationships
        }
    )
    with get_session_with_current_tenant() as db_session:
        relationships = get_relationships_for_entity_type_pairs(
            db_session, relationship_pairs
        )
        for relationship in relationships:
            relationship_type_map[relationship.source_entity_type_id_name][
                relationship.target_entity_type_id_name
            ].append(relationship.id_name)
    return relationship_type_map


def normalize_entities(
    raw_entities: list[str],
    raw_entities_w_attributes: list[str],
    allowed_docs_temp_view_name: str | None = None,
) -> NormalizedEntities:
    """
    Match each entity against a list of normalized entities using fuzzy matching.
    Returns the best matching normalized entity for each input entity.

    Args:
        raw_entities: list of entity strings to normalize, w/o attributes
        raw_entities_w_attributes: list of entity strings to normalize, w/ attributes
        allowed_docs_temp_view_name: name of the temporary view listing the
            accessible documents (must not be None)

    Returns:
        NormalizedEntities holding the matched entities (with and without
        attributes) and a map from each raw entity to its normalized form
    """
    normalized_entities: list[str] = []
    normalized_entities_w_attributes: list[str] = []
    normalized_map: dict[str, str] = {}

    entity_attributes = [
        get_attributes(attr_entity) for attr_entity in raw_entities_w_attributes
    ]

    mapping: list[str | None] = run_functions_tuples_in_parallel(
        [
            (_normalize_one_entity, (entity, attributes, allowed_docs_temp_view_name))
            for entity, attributes in zip(raw_entities, entity_attributes)
        ]
    )
    for entity, attributes, normalized_entity in zip(
        raw_entities, entity_attributes, mapping
    ):
        if normalized_entity is not None:
            normalized_entities.append(normalized_entity)
            normalized_entities_w_attributes.append(
                make_entity_w_attributes(normalized_entity, attributes)
            )
            normalized_map[entity] = format_entity_id_for_models(normalized_entity)
        else:
            logger.warning(f"No normalized entity found for {entity}")
            normalized_map[entity] = format_entity_id_for_models(entity)

    return NormalizedEntities(
        entities=normalized_entities,
        entities_w_attributes=normalized_entities_w_attributes,
        entity_normalization_map=normalized_map,
    )


def normalize_relationships(
    raw_relationships: list[str], entity_normalization_map: dict[str, str]
) -> NormalizedRelationships:
    """
    Normalize relationships using entity mappings and relationship string matching.

    Args:
        raw_relationships: list of relationships in format "source__relation__target"
        entity_normalization_map: Mapping of raw entities to normalized ones (or None)

    Returns:
        NormalizedRelationships containing normalized relationships and mapping
    """
    # Placeholder for normalized relationship structure
    nor_relationships = _get_existing_normalized_relationships(raw_relationships)

    normalized_rels: list[str] = []
    normalization_map: dict[str, str] = {}

    for raw_rel in raw_relationships:
        # 1. Split and normalize entities
        try:
            source, rel_string, target = split_relationship_id(raw_rel)
        except ValueError:
            raise ValueError(f"Invalid relationship format: {raw_rel}")

        # Check if entities are in normalization map and not None
        norm_source = entity_normalization_map.get(source)
        norm_target = entity_normalization_map.get(target)

        if norm_source is None or norm_target is None:
            logger.warning(f"No normalized entities found for {raw_rel}")
            continue

        # 2. Find candidate normalized relationships
        candidate_rels = []
        norm_source_type = get_entity_type(format_entity_id_for_models(norm_source))
        norm_target_type = get_entity_type(format_entity_id_for_models(norm_target))
        if (
            norm_source_type in nor_relationships
            and norm_target_type in nor_relationships[norm_source_type]
        ):
            candidate_rels = [
                split_relationship_id(rel)[1]
                for rel in nor_relationships[norm_source_type][norm_target_type]
            ]

        if not candidate_rels:
            logger.warning(f"No candidate relationships found for {raw_rel}")
            continue

        # 3. Encode and find best match
        strings_to_encode = [rel_string] + candidate_rels
        vectors = encode_string_batch(strings_to_encode)

        # Get raw relation vector and candidate vectors
        raw_vector = vectors[0]
        candidate_vectors = vectors[1:]

        # Calculate dot products
        dot_products = np.dot(candidate_vectors, raw_vector)
        best_match_idx = np.argmax(dot_products)

        # Create normalized relationship
        norm_rel = make_relationship_id(
            norm_source, candidate_rels[best_match_idx], norm_target
        )
        normalized_rels.append(norm_rel)
        normalization_map[raw_rel] = norm_rel

    return NormalizedRelationships(
        relationships=normalized_rels, relationship_normalization_map=normalization_map
    )


================================================
FILE: backend/onyx/kg/extractions/extraction_processing.py
================================================
import time
from typing import Any

from redis.lock import Lock as RedisLock

from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.db.connector import get_kg_enabled_connectors
from onyx.db.document import get_document_updated_at
from onyx.db.document import get_skipped_kg_documents
from onyx.db.document import get_unprocessed_kg_document_batch_for_connector
from onyx.db.document import update_document_kg_info
from onyx.db.document import update_document_kg_stage
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.entities import delete_from_kg_entities__no_commit
from onyx.db.entities import upsert_staging_entity
from onyx.db.entity_type import get_entity_types
from onyx.db.kg_config import get_kg_config_settings
from onyx.db.kg_config import validate_kg_settings
from onyx.db.models import Document
from onyx.db.models import KGStage
from onyx.db.relationships import delete_from_kg_relationships__no_commit
from onyx.db.relationships import upsert_staging_relationship
from onyx.db.relationships import upsert_staging_relationship_type
from onyx.kg.models import KGClassificationInstructions
from onyx.kg.models import KGDocumentDeepExtractionResults
from onyx.kg.models import KGEnhancedDocumentMetadata
from onyx.kg.models import KGEntityTypeInstructions
from onyx.kg.models import KGExtractionInstructions
from onyx.kg.models import KGImpliedExtractionResults
from onyx.kg.utils.extraction_utils import EntityTypeMetadataTracker
from onyx.kg.utils.extraction_utils import (
    get_batch_documents_metadata,
)
from onyx.kg.utils.extraction_utils import kg_deep_extraction
from onyx.kg.utils.extraction_utils import (
    kg_implied_extraction,
)
from onyx.kg.utils.formatting_utils import extract_relationship_type_id
from onyx.kg.utils.formatting_utils import get_entity_type
from onyx.kg.utils.formatting_utils import split_entity_id
from onyx.kg.utils.formatting_utils import split_relationship_id
from onyx.kg.utils.lock_utils import extend_lock
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel

logger = setup_logger()


def _get_classification_extraction_instructions() -> (
    dict[str | None, dict[str, KGEntityTypeInstructions]]
):
    """
    Prepare the classification and extraction instructions for every active
    entity type, keyed by grounded source name and then by entity type id.
    """
    classification_instructions_dict: dict[
        str | None, dict[str, KGEntityTypeInstructions]
    ] = {}

    with get_session_with_current_tenant() as db_session:
        entity_types = get_entity_types(db_session, active=True)

        for entity_type in entity_types:
            grounded_source_name = entity_type.grounded_source_name

            if grounded_source_name not in classification_instructions_dict:
                classification_instructions_dict[grounded_source_name] = {}

            if grounded_source_name is None:
                continue

            attributes = entity_type.parsed_attributes
            classification_attributes = {
                option: info
                for option, info in attributes.classification_attributes.items()
                if info.extraction
            }
            classification_options = ", ".join(classification_attributes.keys())
            classification_enabled = (
                len(classification_options) > 0 and len(classification_attributes) > 0
            )

            classification_instructions_dict[grounded_source_name][entity_type.id_name] = (
                KGEntityTypeInstructions(
                    metadata_attribute_conversion=attributes.metadata_attribute_conversion,
                    classification_instructions=KGClassificationInstructions(
                        classification_enabled=classification_enabled,
                        classification_options=classification_options,
                        classification_class_definitions=classification_attributes,
                    ),
                    extraction_instructions=KGExtractionInstructions(
                        deep_extraction=entity_type.deep_extraction,
                        active=entity_type.active,
                    ),
                    entity_filter_attributes=attributes.entity_filter_attributes,
                )
            )

    return classification_instructions_dict


def _get_batch_documents_enhanced_metadata(
    unprocessed_document_batch: list[Document],
    source_type_classification_extraction_instructions: dict[
        str, KGEntityTypeInstructions
    ],
    connector_source: str,
) -> dict[str, KGEnhancedDocumentMetadata]:
    """
    Get the entity types for the given unprocessed documents.
    """
    kg_document_meta_data_dict: dict[str, KGEnhancedDocumentMetadata] = {
        document.id: KGEnhancedDocumentMetadata(
            entity_type=None,
            metadata_attribute_conversion=None,
            document_metadata=None,
            deep_extraction=False,
            classification_enabled=False,
            classification_instructions=None,
            skip=True,
        )
        for document in unprocessed_document_batch
    }

    batch_entity = None
    if len(source_type_classification_extraction_instructions) == 1:
        # if source only has one entity type, the document must be of that type
        batch_entity = list(source_type_classification_extraction_instructions.keys())[
            0
        ]

    # the documents can be of multiple entity types. We need to identify the entity type for each document
    batch_metadata = get_batch_documents_metadata(
        [
            unprocessed_document.id
            for unprocessed_document in unprocessed_document_batch
        ],
        connector_source,
    )

    for metadata in batch_metadata:
        document_id = metadata.document_id
        doc_entity = None

        if not isinstance(document_id, str):
            continue

        chunk_metadata = metadata.source_metadata

        if batch_entity:
            doc_entity = batch_entity
        else:
            # TODO: make this a helper function
            if not chunk_metadata:
                continue

            for (
                potential_entity_type
            ) in source_type_classification_extraction_instructions.keys():
                potential_entity_type_attribute_filters = (
                    source_type_classification_extraction_instructions[
                        potential_entity_type
                    ].entity_filter_attributes
                    or {}
                )

                if not potential_entity_type_attribute_filters:
                    continue

                if all(
                    chunk_metadata.get(attribute)
                    == potential_entity_type_attribute_filters.get(attribute)
                    for attribute in potential_entity_type_attribute_filters
                ):
                    doc_entity = potential_entity_type
                    break

        if doc_entity is None:
            continue

        entity_instructions = source_type_classification_extraction_instructions[
            doc_entity
        ]

        kg_document_meta_data_dict[document_id] = KGEnhancedDocumentMetadata(
            entity_type=doc_entity,
            metadata_attribute_conversion=(
                source_type_classification_extraction_instructions[
                    doc_entity
                ].metadata_attribute_conversion
            ),
            document_metadata=chunk_metadata,
            deep_extraction=entity_instructions.extraction_instructions.deep_extraction,
            classification_enabled=entity_instructions.classification_instructions.classification_enabled,
            classification_instructions=entity_instructions.classification_instructions,
            skip=False,
        )

    return kg_document_meta_data_dict


def kg_extraction(
    tenant_id: str,
    index_name: str,
    lock: RedisLock,
    processing_chunk_batch_size: int = 8,
) -> None:
    """
    This extraction will try to extract from all chunks that have not been kg-processed yet.

    Approach:
    - Get all connectors that are enabled for KG extraction
    - For each enabled connector:
        - Get unprocessed documents (using a generator)
        - For each batch of unprocessed documents:
            - Classify each document to select proper ones
            - Get and extract from chunks
            - Update chunks in Vespa
            - Update temporary KG extraction tables
            - Update document table to set kg_extracted = True
    """
    logger.info(f"Starting kg extraction for tenant {tenant_id}")

    kg_config_settings = get_kg_config_settings()
    validate_kg_settings(kg_config_settings)

    # get connector ids that are enabled for KG extraction
    with get_session_with_current_tenant() as db_session:
        kg_enabled_connectors = get_kg_enabled_connectors(db_session)

    document_classification_extraction_instructions = (
        _get_classification_extraction_instructions()
    )

    # get entity type info
    with get_session_with_current_tenant() as db_session:
        all_entity_types = get_entity_types(db_session)
        active_entity_types = {
            entity_type.id_name
            for entity_type in get_entity_types(db_session, active=True)
        }

        # entity_type: (metadata: conversion property)
        entity_metadata_conversion_instructions = {
            entity_type.id_name: entity_type.parsed_attributes.metadata_attribute_conversion
            for entity_type in all_entity_types
        }

    # Track which metadata attributes are possible for each entity type
    metadata_tracker = EntityTypeMetadataTracker()
    metadata_tracker.import_typeinfo()

    last_lock_time = time.monotonic()

    # Iterate over connectors that are enabled for KG extraction
for kg_enabled_connector in kg_enabled_connectors: connector_id = kg_enabled_connector.id connector_coverage_days = kg_enabled_connector.kg_coverage_days connector_source = kg_enabled_connector.source document_batch_counter = 0 # iterate over un-kg-processed documents in connector while True: # get a batch of unprocessed documents with get_session_with_current_tenant() as db_session: unprocessed_document_batch = ( get_unprocessed_kg_document_batch_for_connector( db_session, connector_id, kg_coverage_start=kg_config_settings.KG_COVERAGE_START_DATE, kg_max_coverage_days=connector_coverage_days or kg_config_settings.KG_MAX_COVERAGE_DAYS, batch_size=processing_chunk_batch_size, ) ) if len(unprocessed_document_batch) == 0: logger.info( f"No unprocessed documents found for connector {connector_id}. Processed {document_batch_counter} batches." ) break document_batch_counter += 1 last_lock_time = extend_lock( lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time ) logger.info(f"Processing document batch {document_batch_counter}") # Get the document attributes and entity types batch_metadata = _get_batch_documents_enhanced_metadata( unprocessed_document_batch, document_classification_extraction_instructions.get( connector_source, {} ), connector_source, ) # mark docs in unprocessed_document_batch as EXTRACTING for unprocessed_document in unprocessed_document_batch: if batch_metadata[unprocessed_document.id].entity_type is None: # info for after the connector has been processed kg_stage = KGStage.SKIPPED logger.debug( f"Document {unprocessed_document.id} is not of any entity type" ) elif batch_metadata[unprocessed_document.id].skip: # info for after the connector has been processed. 
But no message as there may be many # purposefully skipped documents kg_stage = KGStage.SKIPPED else: kg_stage = KGStage.EXTRACTING with get_session_with_current_tenant() as db_session: update_document_kg_stage( db_session, unprocessed_document.id, kg_stage, ) if kg_stage == KGStage.EXTRACTING: delete_from_kg_relationships__no_commit( db_session, [unprocessed_document.id] ) delete_from_kg_entities__no_commit( db_session, [unprocessed_document.id] ) db_session.commit() # Iterate over batches of unprocessed documents # For each document: # - extract implied entities and relationships # - if deep extraction is enabled, extract entities and relationships with LLM # - if deep extraction and classification are enabled, classify document # - update postgres with # - extracted entities (with classification) and relationships # - kg_stage of the processed document documents_to_process = [x.id for x in unprocessed_document_batch] batch_implied_extraction: dict[str, KGImpliedExtractionResults] = {} batch_deep_extraction_args: list[ tuple[str, KGEnhancedDocumentMetadata, KGImpliedExtractionResults] ] = [] for unprocessed_document in unprocessed_document_batch: if ( unprocessed_document.id not in documents_to_process or batch_metadata[unprocessed_document.id].entity_type is None or batch_metadata[unprocessed_document.id].skip ): with get_session_with_current_tenant() as db_session: update_document_kg_stage( db_session, unprocessed_document.id, KGStage.SKIPPED, ) db_session.commit() continue # 1. 
perform (implicit) KG 'extractions' on the documents that should be processed # This is really about assigning document meta-data to KG entities/relationships or KG entity attributes # General approach: # - vendor emails to Employee-type entities + relationship to current primary grounded entity # - external account emails to Account-type entities + relationship to current primary grounded entity # - non-email owners to KG current entity's attributes, no relationships # We also collect email addresses of vendors and external accounts to inform chunk processing batch_implied_extraction[unprocessed_document.id] = ( kg_implied_extraction( unprocessed_document, batch_metadata[unprocessed_document.id], active_entity_types, kg_config_settings, ) ) # 2. prepare inputs for deep extraction and classification if batch_metadata[unprocessed_document.id].deep_extraction: batch_deep_extraction_args.append( ( unprocessed_document.id, batch_metadata[unprocessed_document.id], batch_implied_extraction[unprocessed_document.id], ) ) # 2. 
perform deep extraction and classification in parallel batch_deep_extraction_func_calls = [ ( kg_deep_extraction, ( *arg, tenant_id, index_name, kg_config_settings, ), ) for arg in batch_deep_extraction_args ] # key results by the IDs of the documents actually submitted for deep extraction, # not by every document in the batch, so each result stays with its own document batch_deep_extractions: dict[str, KGDocumentDeepExtractionResults] = { document_id: result for document_id, result in zip( [arg[0] for arg in batch_deep_extraction_args], run_functions_tuples_in_parallel(batch_deep_extraction_func_calls), ) } # Collect entities and relationships to upsert batch_entities: list[tuple[str | None, str]] = [] batch_relationships: list[tuple[str, str]] = [] entity_classification: dict[str, str] = {} for document_id, implied_metadata in batch_implied_extraction.items(): batch_entities += [ (None, entity) for entity in implied_metadata.implied_entities ] batch_entities.append((document_id, implied_metadata.document_entity)) batch_relationships += [ (document_id, relationship) for relationship in implied_metadata.implied_relationships ] for document_id, deep_extraction_result in batch_deep_extractions.items(): batch_entities += [ (None, entity) for entity in deep_extraction_result.deep_extracted_entities ] for relationship in deep_extraction_result.deep_extracted_relationships: source_entity, _, target_entity = split_relationship_id( relationship ) # compare entity types (not full entity IDs) against the active entity types if ( get_entity_type(source_entity) in active_entity_types and get_entity_type(target_entity) in active_entity_types ): batch_relationships += [(document_id, relationship)] classification_result = deep_extraction_result.classification_result if not classification_result: continue entity_classification[classification_result.document_entity] = ( classification_result.classification_class ) # Populate the KG database with the extracted entities, 
relationships, and terms for potential_document_id, entity in batch_entities: # verify the entity is valid parts = split_entity_id(entity) if len(parts) != 2: logger.error( f"Invalid entity {entity} in aggregated_kg_extractions.entities" ) continue entity_type, entity_name = parts entity_type =
entity_type.upper() entity_name = entity_name.capitalize() if entity_type not in active_entity_types: continue try: with get_session_with_current_tenant() as db_session: entity_attributes: dict[str, Any] = {} if potential_document_id: entity_attributes = ( batch_metadata[potential_document_id].document_metadata or {} ) # only keep selected attributes (and translate the attribute names) metadata_attributes = entity_metadata_conversion_instructions[ entity_type ] keep_attributes = { metadata_attributes[attr_name].name: attr_val for attr_name, attr_val in entity_attributes.items() if ( attr_name in metadata_attributes and metadata_attributes[attr_name].keep ) } # add the classification result to the attributes if entity in entity_classification: keep_attributes["classification"] = entity_classification[ entity ] event_time = None if potential_document_id: event_time = get_document_updated_at( potential_document_id, db_session ) upserted_entity = upsert_staging_entity( db_session=db_session, name=entity_name, entity_type=entity_type, document_id=potential_document_id, occurrences=1, attributes=keep_attributes, event_time=event_time, ) metadata_tracker.track_metadata( entity_type, upserted_entity.attributes ) db_session.commit() except Exception as e: logger.error(f"Error adding entity {entity}. 
Error message: {e}") for document_id, relationship in batch_relationships: relationship_split = split_relationship_id(relationship) if len(relationship_split) != 3: logger.error( f"Invalid relationship {relationship} in aggregated_kg_extractions.relationships" ) continue source_entity, relationship_type, target_entity = relationship_split source_entity_type = get_entity_type(source_entity) target_entity_type = get_entity_type(target_entity) if ( source_entity_type not in active_entity_types or target_entity_type not in active_entity_types ): continue relationship_type_id_name = extract_relationship_type_id(relationship) with get_session_with_current_tenant() as db_session: try: upsert_staging_relationship_type( db_session=db_session, source_entity_type=source_entity_type.upper(), relationship_type=relationship_type, target_entity_type=target_entity_type.upper(), definition=False, extraction_count=1, ) db_session.commit() except Exception as e: logger.error( f"Error adding relationship type {relationship_type_id_name} to the database: {e}" ) with get_session_with_current_tenant() as db_session: try: upsert_staging_relationship( db_session=db_session, relationship_id_name=relationship, source_document_id=document_id, occurrences=1, ) db_session.commit() except Exception as e: logger.error( f"Error adding relationship {relationship} to the database: {e}" ) # Populate the Documents table with the kg information for the documents for processed_document in documents_to_process: with get_session_with_current_tenant() as db_session: update_document_kg_info( db_session, processed_document, KGStage.EXTRACTED, ) db_session.commit() # Reset the skipped documents back to Not Started with get_session_with_current_tenant() as db_session: skipped_documents = get_skipped_kg_documents(db_session) for document_id in skipped_documents: update_document_kg_stage( db_session, document_id, KGStage.NOT_STARTED, ) db_session.commit() metadata_tracker.export_typeinfo()
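A note on the deep-extraction step in the batch loop above: the parallel runner is called with one entry per document that has deep extraction enabled, and (as the loop already assumes) returns results in submission order. The results therefore have to be keyed by the IDs of the documents that were actually submitted, not by every document in the batch. The sketch below illustrates the difference under those assumptions; `run_in_parallel_ordered` and `fake_deep_extraction` are hypothetical stand-ins built on the standard library, not Onyx's `run_functions_tuples_in_parallel` or `kg_deep_extraction`.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Callable


def run_in_parallel_ordered(
    calls: list[tuple[Callable[..., Any], tuple[Any, ...]]],
) -> list[Any]:
    """Hypothetical stand-in for a parallel runner: executes each
    (func, args) pair in a thread pool and returns the results in the
    order the calls were submitted."""
    with ThreadPoolExecutor() as executor:
        futures = [executor.submit(func, *args) for func, args in calls]
        # Reading the futures in list order preserves submission order.
        return [future.result() for future in futures]


def fake_deep_extraction(doc_id: str) -> str:
    # Hypothetical placeholder for the real per-document deep extraction.
    return f"deep:{doc_id}"


batch = ["doc-a", "doc-b", "doc-c"]
# Only documents with deep extraction enabled are submitted (doc-a is not).
deep_only = ["doc-b", "doc-c"]

results = run_in_parallel_ordered(
    [(fake_deep_extraction, (doc_id,)) for doc_id in deep_only]
)

# Zipping with the whole batch puts every result on the wrong document
# and leaves doc-c with no entry at all.
misaligned = dict(zip(batch, results))
# Zipping with the submitted IDs keeps each result with its own document.
aligned = dict(zip(deep_only, results))

print(misaligned)  # {'doc-a': 'deep:doc-b', 'doc-b': 'deep:doc-c'}
print(aligned)  # {'doc-b': 'deep:doc-b', 'doc-c': 'deep:doc-c'}
```

In the extraction loop, the counterpart of `deep_only` is the document ID stored as the first element of each tuple in `batch_deep_extraction_args`.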
================================================ FILE: backend/onyx/kg/models.py ================================================ from datetime import datetime from enum import Enum from typing import Any from pydantic import BaseModel from onyx.configs.constants import DocumentSource from onyx.configs.kg_configs import KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH # Note: make sure to write a migration if adding a non-nullable field or removing a field class KGConfigSettings(BaseModel): KG_EXPOSED: bool = False KG_ENABLED: bool = False KG_VENDOR: str | None = None KG_VENDOR_DOMAINS: list[str] = [] KG_IGNORE_EMAIL_DOMAINS: list[str] = [] KG_COVERAGE_START: str = datetime(1970, 1, 1).strftime("%Y-%m-%d") KG_MAX_COVERAGE_DAYS: int = 10000 KG_MAX_PARENT_RECURSION_DEPTH: int = KG_DEFAULT_MAX_PARENT_RECURSION_DEPTH KG_BETA_PERSONA_ID: int | None = None @property def KG_COVERAGE_START_DATE(self) -> datetime: return datetime.strptime(self.KG_COVERAGE_START, "%Y-%m-%d") class KGGroundingType(str, Enum): UNGROUNDED = "ungrounded" GROUNDED = "grounded" class KGAttributeTrackType(str, Enum): VALUE = "value" LIST = "list" class KGAttributeTrackInfo(BaseModel): type: KGAttributeTrackType values: set[str] | None class KGAttributeEntityOption(str, Enum): FROM_EMAIL = "from_email" # use email to determine type (ACCOUNT or EMPLOYEE) class KGAttributeImplicationProperty(BaseModel): # type of implied entity to create # if str, will create an implied entity of that type # if KGAttributeEntityOption, will determine the type based on the option implied_entity_type: str | KGAttributeEntityOption # name of the implied relationship to create (from implied entity to this entity) implied_relationship_name: str class KGAttributeProperty(BaseModel): # name of attribute to map metadata to name: str # whether to keep this attribute in the entity keep: bool # properties for creating implied entities and relations from this metadata implication_property: KGAttributeImplicationProperty | None = None class 
KGEntityTypeClassificationInfo(BaseModel): extraction: bool description: str class KGEntityTypeAttributes(BaseModel): # information on how to use the metadata to extract attributes, implied entities, and relations metadata_attribute_conversion: dict[str, KGAttributeProperty] = {} # metadata key: value pairs to match on, used to differentiate entity types from the same source entity_filter_attributes: dict[str, Any] = {} # mapping of classification names to their corresponding classification info classification_attributes: dict[str, KGEntityTypeClassificationInfo] = {} # mapping of attribute names to their allowed values, populated during extraction attribute_values: dict[str, KGAttributeTrackInfo | None] = {} class KGEntityTypeDefinition(BaseModel): description: str grounding: KGGroundingType grounded_source_name: DocumentSource | None active: bool = False attributes: KGEntityTypeAttributes = KGEntityTypeAttributes() entity_values: list[str] = [] class KGChunkFormat(BaseModel): connector_id: int | None = None document_id: str chunk_id: int title: str content: str primary_owners: list[str] secondary_owners: list[str] source_type: str metadata: dict[str, str | list[str]] | None = None class KGPerson(BaseModel): name: str company: str employee: bool class NormalizedEntities(BaseModel): entities: list[str] entities_w_attributes: list[str] entity_normalization_map: dict[str, str] class NormalizedRelationships(BaseModel): relationships: list[str] relationship_normalization_map: dict[str, str] class KGMetadataContent(BaseModel): document_id: str source_type: str source_metadata: dict[str, Any] | None = None class KGClassificationInstructions(BaseModel): classification_enabled: bool classification_options: str classification_class_definitions: dict[str, KGEntityTypeClassificationInfo] class KGExtractionInstructions(BaseModel): deep_extraction: bool active: bool class KGEntityTypeInstructions(BaseModel): metadata_attribute_conversion: dict[str, KGAttributeProperty]
classification_instructions: KGClassificationInstructions extraction_instructions: KGExtractionInstructions entity_filter_attributes: dict[str, Any] | None = None class KGEnhancedDocumentMetadata(BaseModel): entity_type: str | None metadata_attribute_conversion: dict[str, KGAttributeProperty] | None document_metadata: dict[str, Any] | None deep_extraction: bool classification_enabled: bool classification_instructions: KGClassificationInstructions | None skip: bool class KGConnectorData(BaseModel): id: int source: str kg_coverage_days: int | None class KGStage(str, Enum): EXTRACTED = "extracted" NORMALIZED = "normalized" FAILED = "failed" SKIPPED = "skipped" NOT_STARTED = "not_started" EXTRACTING = "extracting" DO_NOT_EXTRACT = "do_not_extract" class KGClassificationResult(BaseModel): document_entity: str classification_class: str class KGImpliedExtractionResults(BaseModel): document_entity: str implied_entities: set[str] implied_relationships: set[str] company_participant_emails: set[str] account_participant_emails: set[str] class KGDocumentDeepExtractionResults(BaseModel): classification_result: KGClassificationResult | None deep_extracted_entities: set[str] deep_extracted_relationships: set[str] class KGException(Exception): pass ================================================ FILE: backend/onyx/kg/resets/reset_index.py ================================================ from sqlalchemy.orm import Session from onyx.db.document import reset_all_document_kg_stages from onyx.db.models import Connector from onyx.db.models import KGEntity from onyx.db.models import KGEntityExtractionStaging from onyx.db.models import KGEntityType from onyx.db.models import KGRelationship from onyx.db.models import KGRelationshipExtractionStaging from onyx.db.models import KGRelationshipType from onyx.db.models import KGRelationshipTypeExtractionStaging def reset_full_kg_index__commit(db_session: Session) -> None: """ Resets the knowledge graph index. 
""" db_session.query(KGRelationship).delete() db_session.query(KGRelationshipType).delete() db_session.query(KGEntity).delete() db_session.query(KGRelationshipExtractionStaging).delete() db_session.query(KGEntityExtractionStaging).delete() db_session.query(KGRelationshipTypeExtractionStaging).delete() # Update all connectors to disable KG processing db_session.query(Connector).update({"kg_processing_enabled": False}) # Only reset grounded entity types db_session.query(KGEntityType).filter( KGEntityType.grounded_source_name.isnot(None) ).update({"active": False}) reset_all_document_kg_stages(db_session) db_session.commit() ================================================ FILE: backend/onyx/kg/resets/reset_source.py ================================================ from redis.lock import Lock as RedisLock from sqlalchemy import or_ from onyx.configs.constants import DocumentSource from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import Connector from onyx.db.models import Document from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import KGEntity from onyx.db.models import KGEntityExtractionStaging from onyx.db.models import KGEntityType from onyx.db.models import KGRelationship from onyx.db.models import KGRelationshipExtractionStaging from onyx.db.models import KGRelationshipType from onyx.db.models import KGRelationshipTypeExtractionStaging from onyx.db.models import KGStage from onyx.kg.resets.reset_index import reset_full_kg_index__commit from onyx.kg.resets.reset_vespa import reset_vespa_kg_index def reset_source_kg_index( source_name: str | None, tenant_id: str, index_name: str, lock: RedisLock ) -> None: """ Resets the knowledge graph index and vespa for a source. 
""" # reset vespa for the source reset_vespa_kg_index(tenant_id, index_name, lock, source_name) with get_session_with_current_tenant() as db_session: if source_name is None: reset_full_kg_index__commit(db_session) return # get all the entity types for the given source entity_types = [ et.id_name for et in db_session.query(KGEntityType) .filter(KGEntityType.grounded_source_name == source_name) .all() ] if not entity_types: raise ValueError(f"There are no entity types for the source {source_name}") # delete the entity type from the knowledge graph for entity_type in entity_types: db_session.query(KGRelationship).filter( or_( KGRelationship.source_node_type == entity_type, KGRelationship.target_node_type == entity_type, ) ).delete() db_session.query(KGRelationshipType).filter( or_( KGRelationshipType.source_entity_type_id_name == entity_type, KGRelationshipType.target_entity_type_id_name == entity_type, ) ).delete() db_session.query(KGEntity).filter( KGEntity.entity_type_id_name == entity_type ).delete() db_session.query(KGRelationshipExtractionStaging).filter( or_( KGRelationshipExtractionStaging.source_node_type == entity_type, KGRelationshipExtractionStaging.target_node_type == entity_type, ) ).delete() db_session.query(KGEntityExtractionStaging).filter( KGEntityExtractionStaging.entity_type_id_name == entity_type ).delete() db_session.query(KGRelationshipTypeExtractionStaging).filter( or_( KGRelationshipTypeExtractionStaging.source_entity_type_id_name == entity_type, KGRelationshipTypeExtractionStaging.target_entity_type_id_name == entity_type, ) ).delete() db_session.commit() with get_session_with_current_tenant() as db_session: # get all the documents for the given source kg_connectors = [ connector.id for connector in db_session.query(Connector) .filter(Connector.source == DocumentSource(source_name)) .all() ] document_ids = [ cc_pair.id for cc_pair in db_session.query(DocumentByConnectorCredentialPair) 
.filter(DocumentByConnectorCredentialPair.connector_id.in_(kg_connectors)) .all() ] # reset the kg stage for the documents db_session.query(Document).filter(Document.id.in_(document_ids)).update( {"kg_stage": KGStage.NOT_STARTED} ) db_session.commit() ================================================ FILE: backend/onyx/kg/resets/reset_vespa.py ================================================ import time from typing import Any from redis.lock import Lock as RedisLock from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import DocumentSource from onyx.db.document import get_num_chunks_for_document from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import Connector from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import KGEntityType from onyx.document_index.document_index_utils import get_uuid_from_chunk_info from onyx.document_index.vespa.index import KGVespaChunkUpdateRequest from onyx.document_index.vespa.index import VespaIndex from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT from onyx.kg.utils.lock_utils import extend_lock from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT logger = setup_logger() def _reset_vespa_for_doc(document_id: str, tenant_id: str, index_name: str) -> None: vespa_index = VespaIndex( index_name=index_name, secondary_index_name=None, large_chunks_enabled=False, secondary_large_chunks_enabled=False, multitenant=MULTI_TENANT, httpx_client=None, ) reset_update_dict: dict[str, Any] = { "fields": { "kg_entities": {"assign": []}, "kg_relationships": {"assign": []}, "kg_terms": {"assign": []}, } } with get_session_with_current_tenant() as db_session: num_chunks = get_num_chunks_for_document(db_session, document_id) vespa_requests: list[KGVespaChunkUpdateRequest] = [] for chunk_num in range(num_chunks): doc_chunk_id = get_uuid_from_chunk_info( document_id=document_id, 
chunk_id=chunk_num, tenant_id=tenant_id, large_chunk_id=None, ) vespa_requests.append( KGVespaChunkUpdateRequest( document_id=document_id, chunk_id=chunk_num, url=f"{DOCUMENT_ID_ENDPOINT.format(index_name=vespa_index.index_name)}/{doc_chunk_id}", update_request=reset_update_dict, ) ) with vespa_index.httpx_client_context as httpx_client: vespa_index._apply_kg_chunk_updates_batched(vespa_requests, httpx_client) def reset_vespa_kg_index( tenant_id: str, index_name: str, lock: RedisLock, source_name: str | None = None ) -> None: """ Reset the kg info in vespa for all documents of a given source name, or all documents from kg grounded sources if source_name is None. """ logger.info( f"Resetting kg vespa index {index_name} for tenant {tenant_id}, source: {source_name if source_name else 'all'}" ) last_lock_time = time.monotonic() # Get all documents that need a vespa reset with get_session_with_current_tenant() as db_session: if source_name: # get all connectors of the given source name kg_connectors = [ connector.id for connector in db_session.query(Connector) .filter(Connector.source == DocumentSource(source_name)) .all() ] else: # get all connectors that have kg enabled kg_sources = [ DocumentSource(et.grounded_source_name) for et in db_session.query(KGEntityType) .filter( KGEntityType.grounded_source_name.is_not(None), KGEntityType.active.is_(True), ) .distinct() .all() ] kg_connectors = [ connector.id for connector in db_session.query(Connector) .filter(Connector.source.in_(kg_sources)) .all() ] # Get all the documents for the given connectors document_ids = [ cc_pair.id for cc_pair in db_session.query(DocumentByConnectorCredentialPair) .filter(DocumentByConnectorCredentialPair.connector_id.in_(kg_connectors)) .all() ] # Reset the kg fields for document_id in document_ids: _reset_vespa_for_doc(document_id, tenant_id, index_name) last_lock_time = extend_lock( lock, CELERY_GENERIC_BEAT_LOCK_TIMEOUT, last_lock_time ) logger.info( f"Finished resetting kg vespa index 
{index_name} for tenant {tenant_id}, source: {source_name if source_name else 'all'}" ) ================================================ FILE: backend/onyx/kg/setup/kg_default_entity_definitions.py ================================================ from typing import cast from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.db.entity_type import KGEntityType from onyx.db.kg_config import get_kg_config_settings from onyx.db.kg_config import validate_kg_settings from onyx.kg.models import KGAttributeEntityOption from onyx.kg.models import KGAttributeImplicationProperty from onyx.kg.models import KGAttributeProperty from onyx.kg.models import KGEntityTypeAttributes from onyx.kg.models import KGEntityTypeClassificationInfo from onyx.kg.models import KGEntityTypeDefinition from onyx.kg.models import KGGroundingType def get_default_entity_types(vendor_name: str) -> dict[str, KGEntityTypeDefinition]: return { "LINEAR": KGEntityTypeDefinition( description="A formal Linear ticket about a product issue or improvement request.", attributes=KGEntityTypeAttributes( metadata_attribute_conversion={ "team": KGAttributeProperty(name="team", keep=True), "state": KGAttributeProperty(name="state", keep=True), "priority": KGAttributeProperty(name="priority", keep=True), "estimate": KGAttributeProperty(name="estimate", keep=True), "created_at": KGAttributeProperty(name="created_at", keep=True), "started_at": KGAttributeProperty(name="started_at", keep=True), "completed_at": KGAttributeProperty(name="completed_at", keep=True), "due_date": KGAttributeProperty(name="due_date", keep=True), "creator": KGAttributeProperty( name="creator", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_creator_of", ), ), "assignee": KGAttributeProperty( name="assignee", keep=False, implication_property=KGAttributeImplicationProperty( 
implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_assignee_of", ), ), }, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.LINEAR, ), "JIRA": KGEntityTypeDefinition( description=( "A formal Jira ticket about a product issue or improvement request." ), attributes=KGEntityTypeAttributes( metadata_attribute_conversion={ "issuetype": KGAttributeProperty(name="subtype", keep=True), "status": KGAttributeProperty(name="status", keep=True), "priority": KGAttributeProperty(name="priority", keep=True), "project_name": KGAttributeProperty(name="project", keep=True), "created": KGAttributeProperty(name="created_at", keep=True), "updated": KGAttributeProperty(name="updated_at", keep=True), "resolution_date": KGAttributeProperty( name="completed_at", keep=True ), "duedate": KGAttributeProperty(name="due_date", keep=True), "reporter_email": KGAttributeProperty( name="creator", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_creator_of", ), ), "assignee_email": KGAttributeProperty( name="assignee", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_assignee_of", ), ), # not using implication property as that only captures 1 depth "key": KGAttributeProperty(name="key", keep=True), "parent": KGAttributeProperty(name="parent", keep=True), }, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.JIRA, ), "GITHUB_PR": KGEntityTypeDefinition( description="A formal engineering request to merge proposed changes into the codebase.", attributes=KGEntityTypeAttributes( metadata_attribute_conversion={ "repo": KGAttributeProperty(name="repository", keep=True), "state": KGAttributeProperty(name="state", keep=True), "num_commits": KGAttributeProperty(name="num_commits", keep=True), "num_files_changed": 
KGAttributeProperty( name="num_files_changed", keep=True ), "labels": KGAttributeProperty(name="labels", keep=True), "merged": KGAttributeProperty(name="merged", keep=True), "merged_at": KGAttributeProperty(name="merged_at", keep=True), "closed_at": KGAttributeProperty(name="closed_at", keep=True), "created_at": KGAttributeProperty(name="created_at", keep=True), "updated_at": KGAttributeProperty(name="updated_at", keep=True), "user": KGAttributeProperty( name="creator", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_creator_of", ), ), "assignees": KGAttributeProperty( name="assignees", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_assignee_of", ), ), }, entity_filter_attributes={"object_type": "PullRequest"}, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.GITHUB, ), "GITHUB_ISSUE": KGEntityTypeDefinition( description="A formal engineering ticket about an issue, idea, inquiry, or task.", attributes=KGEntityTypeAttributes( metadata_attribute_conversion={ "repo": KGAttributeProperty(name="repository", keep=True), "state": KGAttributeProperty(name="state", keep=True), "labels": KGAttributeProperty(name="labels", keep=True), "closed_at": KGAttributeProperty(name="closed_at", keep=True), "created_at": KGAttributeProperty(name="created_at", keep=True), "updated_at": KGAttributeProperty(name="updated_at", keep=True), "user": KGAttributeProperty( name="creator", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_creator_of", ), ), "assignees": KGAttributeProperty( name="assignees", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type=KGAttributeEntityOption.FROM_EMAIL, implied_relationship_name="is_assignee_of", 
), ), }, entity_filter_attributes={"object_type": "Issue"}, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.GITHUB, ), "FIREFLIES": KGEntityTypeDefinition( description=( f"A phone call transcript between us ({vendor_name}) and another account or individuals, or an internal meeting." ), attributes=KGEntityTypeAttributes( classification_attributes={ "customer": KGEntityTypeClassificationInfo( extraction=True, description="a call with representatives of one or more customers prospects", ), "internal": KGEntityTypeClassificationInfo( extraction=True, description="a call between employees of the vendor's company (a vendor-internal call)", ), "interview": KGEntityTypeClassificationInfo( extraction=True, description=( "a call with an individual who is interviewed or is discussing potential employment with the vendor" ), ), "other": KGEntityTypeClassificationInfo( extraction=True, description=( "a call with representatives of companies having a different reason for the call " "(investment, partnering, etc.)" ), ), }, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.FIREFLIES, ), "ACCOUNT": KGEntityTypeDefinition( description=( "A company that was, is, or potentially could be a customer of the vendor " f"('us, {vendor_name}'). Note that {vendor_name} can never be an ACCOUNT." 
), attributes=KGEntityTypeAttributes( entity_filter_attributes={"object_type": "Account"}, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.SALESFORCE, ), "OPPORTUNITY": KGEntityTypeDefinition( description="A sales opportunity.", attributes=KGEntityTypeAttributes( metadata_attribute_conversion={ "name": KGAttributeProperty(name="name", keep=True), "stage_name": KGAttributeProperty(name="stage", keep=True), "type": KGAttributeProperty(name="type", keep=True), "amount": KGAttributeProperty(name="amount", keep=True), "fiscal_year": KGAttributeProperty(name="fiscal_year", keep=True), "fiscal_quarter": KGAttributeProperty( name="fiscal_quarter", keep=True ), "is_closed": KGAttributeProperty(name="is_closed", keep=True), "close_date": KGAttributeProperty(name="close_date", keep=True), "probability": KGAttributeProperty( name="close_probability", keep=True ), "created_date": KGAttributeProperty(name="created_at", keep=True), "last_modified_date": KGAttributeProperty( name="updated_at", keep=True ), "account": KGAttributeProperty( name="account", keep=False, implication_property=KGAttributeImplicationProperty( implied_entity_type="ACCOUNT", implied_relationship_name="is_account_of", ), ), }, entity_filter_attributes={"object_type": "Opportunity"}, ), grounding=KGGroundingType.GROUNDED, grounded_source_name=DocumentSource.SALESFORCE, ), "VENDOR": KGEntityTypeDefinition( description=f"The Vendor {vendor_name}, 'us'", grounding=KGGroundingType.GROUNDED, active=True, grounded_source_name=None, ), "EMPLOYEE": KGEntityTypeDefinition( description=( f"A person who speaks on behalf of 'our' company (the VENDOR {vendor_name}), " "NOT of another account. Therefore, employees of other companies " "are NOT included here. If in doubt, do NOT extract." 
), grounding=KGGroundingType.GROUNDED, active=False, grounded_source_name=None, ), } def populate_missing_default_entity_types__commit(db_session: Session) -> None: """ Populates the database with the missing default entity types. """ kg_config_settings = get_kg_config_settings() validate_kg_settings(kg_config_settings) vendor_name = cast(str, kg_config_settings.KG_VENDOR) existing_entity_types = {et.id_name for et in db_session.query(KGEntityType).all()} default_entity_types = get_default_entity_types(vendor_name=vendor_name) for entity_type_id_name, entity_type_definition in default_entity_types.items(): if entity_type_id_name in existing_entity_types: continue grounded_source_name = ( entity_type_definition.grounded_source_name.value if entity_type_definition.grounded_source_name else None ) kg_entity_type = KGEntityType( id_name=entity_type_id_name, description=entity_type_definition.description, attributes=entity_type_definition.attributes.model_dump(), grounding=entity_type_definition.grounding, grounded_source_name=grounded_source_name, active=entity_type_definition.active, ) db_session.add(kg_entity_type) db_session.commit() ================================================ FILE: backend/onyx/kg/utils/embeddings.py ================================================ from typing import List import numpy as np from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.search_settings import get_current_search_settings from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.natural_language_processing.search_nlp_models import EmbedTextType from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT def encode_string_batch(strings: List[str]) -> np.ndarray: with get_session_with_current_tenant() as db_session: current_search_settings = get_current_search_settings(db_session) model = EmbeddingModel.from_db_model( search_settings=current_search_settings, 
server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, ) # Get embeddings while session is still open embedding = model.encode(strings, text_type=EmbedTextType.QUERY) return np.array(embedding) ================================================ FILE: backend/onyx/kg/utils/extraction_utils.py ================================================ import json from onyx.configs.constants import DocumentSource from onyx.configs.constants import OnyxCallTypes from onyx.configs.kg_configs import KG_METADATA_TRACKING_THRESHOLD from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.entities import get_kg_entity_by_document from onyx.db.entity_type import get_entity_types from onyx.db.kg_config import KGConfigSettings from onyx.db.models import Document from onyx.db.models import KGEntityType from onyx.db.models import KGRelationshipType from onyx.db.tag import get_structured_tags_for_document from onyx.kg.models import KGAttributeEntityOption from onyx.kg.models import KGAttributeTrackInfo from onyx.kg.models import KGAttributeTrackType from onyx.kg.models import KGChunkFormat from onyx.kg.models import KGClassificationInstructions from onyx.kg.models import KGClassificationResult from onyx.kg.models import KGDocumentDeepExtractionResults from onyx.kg.models import KGEnhancedDocumentMetadata from onyx.kg.models import KGImpliedExtractionResults from onyx.kg.models import KGMetadataContent from onyx.kg.utils.formatting_utils import extract_email from onyx.kg.utils.formatting_utils import get_entity_type from onyx.kg.utils.formatting_utils import kg_email_processing from onyx.kg.utils.formatting_utils import make_entity_id from onyx.kg.utils.formatting_utils import make_relationship_id from onyx.kg.utils.formatting_utils import make_relationship_type_id from onyx.kg.vespa.vespa_interactions import get_document_vespa_contents from onyx.llm.factory import get_default_llm from onyx.llm.models import UserMessage from onyx.llm.utils import 
llm_response_to_string from onyx.prompts.kg_prompts import CALL_CHUNK_PREPROCESSING_PROMPT from onyx.prompts.kg_prompts import CALL_DOCUMENT_CLASSIFICATION_PROMPT from onyx.prompts.kg_prompts import GENERAL_CHUNK_PREPROCESSING_PROMPT from onyx.prompts.kg_prompts import MASTER_EXTRACTION_PROMPT from onyx.tracing.llm_utils import llm_generation_span from onyx.tracing.llm_utils import record_llm_response from onyx.utils.logger import setup_logger logger = setup_logger() def get_entity_types_str(active: bool | None = None) -> str: """ Format the entity types into a string for the LLM. """ with get_session_with_current_tenant() as db_session: entity_types = get_entity_types(db_session, active) entity_types_list: list[str] = [] for entity_type in entity_types: if entity_type.description: entity_description = "\n - Description: " + entity_type.description else: entity_description = "" if entity_type.entity_values: allowed_values = "\n - Allowed Values: " + ", ".join( entity_type.entity_values ) else: allowed_values = "" attributes = entity_type.parsed_attributes entity_type_attribute_list: list[str] = [] for attribute, values in attributes.attribute_values.items(): entity_type_attribute_list.append( f"{attribute}: {trackinfo_to_str(values)}" ) if attributes.classification_attributes: entity_type_attribute_list.append( # TODO: restructure classification attribute to be a dict of attribute name to classification info # e.g., {scope: {internal: prompt, external: prompt}, sentiment: {positive: prompt, negative: prompt}} "classification: one of: " + ", ".join(attributes.classification_attributes.keys()) ) if entity_type_attribute_list: entity_attributes = "\n - Attributes:\n - " + "\n - ".join( entity_type_attribute_list ) else: entity_attributes = "" entity_types_list.append( entity_type.id_name + entity_description + allowed_values + entity_attributes ) return "\n".join(entity_types_list) def get_relationship_types_str(active: bool | None = None) -> str: """ Format the 
relationship types into a string for the LLM. """ with get_session_with_current_tenant() as db_session: active_filters = [] if active is not None: active_filters.append(KGRelationshipType.active == active) relationship_types = ( db_session.query(KGRelationshipType).filter(*active_filters).all() ) relationship_types_list = [] for rel_type in relationship_types: # Format as "source_type__relationship_type__target_type" formatted_type = make_relationship_type_id( rel_type.source_entity_type_id_name, rel_type.type, rel_type.target_entity_type_id_name, ) relationship_types_list.append(formatted_type) return "\n".join(relationship_types_list) def kg_process_owners( owner_emails: list[str], document_entity_id: str, relationship_type: str, kg_config_settings: KGConfigSettings, active_entity_types: set[str], ) -> tuple[set[str], set[str], set[str], set[str]]: owner_entities: set[str] = set() owner_relationships: set[str] = set() company_participant_emails: set[str] = set() account_participant_emails: set[str] = set() for owner_email in owner_emails: if extract_email(owner_email) is None: continue process_results = kg_process_person( owner_email, document_entity_id, relationship_type, kg_config_settings, active_entity_types, ) if process_results is None: continue ( owner_entity, owner_relationship, company_participant_email, account_participant_email, ) = process_results owner_entities.add(owner_entity) owner_relationships.add(owner_relationship) if company_participant_email: company_participant_emails.add(company_participant_email) if account_participant_email: account_participant_emails.add(account_participant_email) return ( owner_entities, owner_relationships, company_participant_emails, account_participant_emails, ) def kg_implied_extraction( document: Document, doc_metadata: KGEnhancedDocumentMetadata, active_entity_types: set[str], kg_config_settings: KGConfigSettings, ) -> KGImpliedExtractionResults: """ Generate entities, relationships, and attributes for a 
document. """ # Get document entity and metadata stuff from the KGEnhancedDocumentMetadata document_entity_type = doc_metadata.entity_type document_metadata = doc_metadata.document_metadata or {} metadata_attribute_conversion = doc_metadata.metadata_attribute_conversion if document_entity_type is None or metadata_attribute_conversion is None: raise ValueError("Entity type and metadata attributes are required") implied_entities: set[str] = set() implied_relationships: set[str] = set() # Quantity needed for call processing - participants from vendor company_participant_emails: set[str] = set() # Quantity needed for call processing - external participants account_participant_emails: set[str] = set() # Chunk treatment variables document_is_from_call = document_entity_type.lower() in ( call_type.value.lower() for call_type in OnyxCallTypes ) # Get core entity document_id = document.id primary_owners = document.primary_owners secondary_owners = document.secondary_owners with get_session_with_current_tenant() as db_session: document_entity = get_kg_entity_by_document(db_session, document_id) if document_entity: document_entity_id = document_entity.id_name else: document_entity_id = make_entity_id(document_entity_type, document_id) # Get implied entities and relationships from primary/secondary owners if document_is_from_call: ( implied_entities, implied_relationships, company_participant_emails, account_participant_emails, ) = kg_process_owners( owner_emails=(primary_owners or []) + (secondary_owners or []), document_entity_id=document_entity_id, relationship_type="participates_in", kg_config_settings=kg_config_settings, active_entity_types=active_entity_types, ) else: ( implied_entities, implied_relationships, company_participant_emails, account_participant_emails, ) = kg_process_owners( owner_emails=primary_owners or [], document_entity_id=document_entity_id, relationship_type="leads", kg_config_settings=kg_config_settings, active_entity_types=active_entity_types, ) ( 
participant_entities, participant_relationships, company_emails, account_emails, ) = kg_process_owners( owner_emails=secondary_owners or [], document_entity_id=document_entity_id, relationship_type="participates_in", kg_config_settings=kg_config_settings, active_entity_types=active_entity_types, ) implied_entities.update(participant_entities) implied_relationships.update(participant_relationships) company_participant_emails.update(company_emails) account_participant_emails.update(account_emails) # Get implied entities and relationships from document metadata for metadata, value in document_metadata.items(): # get implication property for this metadata if metadata not in metadata_attribute_conversion: continue if ( implication_property := metadata_attribute_conversion[ metadata ].implication_property ) is None: continue if not isinstance(value, str) and not isinstance(value, list): continue values: list[str] = [value] if isinstance(value, str) else value # create implied entities and relationships for item in values: if ( implication_property.implied_entity_type == KGAttributeEntityOption.FROM_EMAIL ): # determine entity type from email email = extract_email(item) if email is None: continue process_results = kg_process_person( email=email, document_entity_id=document_entity_id, relationship_type=implication_property.implied_relationship_name, kg_config_settings=kg_config_settings, active_entity_types=active_entity_types, ) if process_results is None: continue (implied_entity, implied_relationship, _, _) = process_results implied_entities.add(implied_entity) implied_relationships.add(implied_relationship) else: # use the given entity type entity_type = implication_property.implied_entity_type if entity_type not in active_entity_types: continue implied_entity = make_entity_id(entity_type, item) implied_entities.add(implied_entity) implied_relationships.add( make_relationship_id( implied_entity, implication_property.implied_relationship_name, document_entity_id, ) ) 
return KGImpliedExtractionResults( document_entity=document_entity_id, implied_entities=implied_entities, implied_relationships=implied_relationships, company_participant_emails=company_participant_emails, account_participant_emails=account_participant_emails, ) def kg_deep_extraction( document_id: str, metadata: KGEnhancedDocumentMetadata, implied_extraction: KGImpliedExtractionResults, tenant_id: str, index_name: str, kg_config_settings: KGConfigSettings, ) -> KGDocumentDeepExtractionResults: """ Perform deep extraction and classification on the document. """ result = KGDocumentDeepExtractionResults( classification_result=None, deep_extracted_entities=set(), deep_extracted_relationships=set(), ) entity_types_str = get_entity_types_str(active=True) relationship_types_str = get_relationship_types_str(active=True) for i, chunk_batch in enumerate( get_document_vespa_contents(document_id, index_name, tenant_id) ): # use first batch for classification if i == 0 and metadata.classification_enabled: if not metadata.classification_instructions: raise ValueError( "Classification is enabled but no instructions are provided" ) result.classification_result = kg_classify_document( document_entity=implied_extraction.document_entity, chunk_batch=chunk_batch, implied_extraction=implied_extraction, classification_instructions=metadata.classification_instructions, kg_config_settings=kg_config_settings, ) # deep extract from this chunk batch chunk_batch_results = kg_deep_extract_chunks( document_entity=implied_extraction.document_entity, chunk_batch=chunk_batch, implied_extraction=implied_extraction, kg_config_settings=kg_config_settings, entity_types_str=entity_types_str, relationship_types_str=relationship_types_str, ) if chunk_batch_results is not None: result.deep_extracted_entities.update( chunk_batch_results.deep_extracted_entities ) result.deep_extracted_relationships.update( chunk_batch_results.deep_extracted_relationships ) return result def kg_classify_document( 
document_entity: str, chunk_batch: list[KGChunkFormat], implied_extraction: KGImpliedExtractionResults, classification_instructions: KGClassificationInstructions, kg_config_settings: KGConfigSettings, ) -> KGClassificationResult | None: # currently, classification is only done for calls # TODO: add support (or use same prompt and format) for non-call documents entity_type = get_entity_type(document_entity) if entity_type not in (call_type.value for call_type in OnyxCallTypes): return None # prepare prompt company_participants = implied_extraction.company_participant_emails account_participants = implied_extraction.account_participant_emails content = ( f"Title: {chunk_batch[0].title}:\nVendor Participants:\n" + "".join(f" - {participant}\n" for participant in company_participants) + "Other Participants:\n" + "".join(f" - {participant}\n" for participant in account_participants) + "Call Content:\n" + "\n".join(chunk.content for chunk in chunk_batch) ) category_list = { cls: definition.description for cls, definition in classification_instructions.classification_class_definitions.items() } prompt = CALL_DOCUMENT_CLASSIFICATION_PROMPT.format( beginning_of_call_content=content, category_list=category_list, category_options=classification_instructions.classification_options, vendor=kg_config_settings.KG_VENDOR, ) # classify with LLM with Braintrust tracing llm = get_default_llm() try: prompt_msg = UserMessage(content=prompt) with llm_generation_span( llm=llm, flow="kg_document_classification", input_messages=[prompt_msg] ) as span_generation: response = llm.invoke(prompt_msg) record_llm_response(span_generation, response) raw_classification_result = llm_response_to_string(response) classification_result = ( raw_classification_result.replace("```json", "").replace("```", "").strip() ) # no json parsing here because of reasoning output classification_class = classification_result.split("CATEGORY:")[1].strip() if ( classification_class in
classification_instructions.classification_class_definitions ): return KGClassificationResult( document_entity=document_entity, classification_class=classification_class, ) except Exception as e: logger.error(f"Failed to classify document {document_entity}. Error: {str(e)}") return None def kg_deep_extract_chunks( document_entity: str, chunk_batch: list[KGChunkFormat], implied_extraction: KGImpliedExtractionResults, kg_config_settings: KGConfigSettings, entity_types_str: str, relationship_types_str: str, ) -> KGDocumentDeepExtractionResults | None: # currently, calls are treated differently # TODO: either treat some other documents differently too, or ideally all the same way entity_type = get_entity_type(document_entity) is_call = entity_type in (call_type.value for call_type in OnyxCallTypes) content = "\n".join(chunk.content for chunk in chunk_batch) # prepare prompt if is_call: company_participants_str = "".join( f" - {participant}\n" for participant in implied_extraction.company_participant_emails ) account_participants_str = "".join( f" - {participant}\n" for participant in implied_extraction.account_participant_emails ) llm_context = CALL_CHUNK_PREPROCESSING_PROMPT.format( participant_string=company_participants_str, account_participant_string=account_participants_str, vendor=kg_config_settings.KG_VENDOR, content=content, ) else: llm_context = GENERAL_CHUNK_PREPROCESSING_PROMPT.format( vendor=kg_config_settings.KG_VENDOR, content=content, ) prompt = MASTER_EXTRACTION_PROMPT.format( entity_types=entity_types_str, relationship_types=relationship_types_str, ).replace("---content---", llm_context) # extract with LLM with Braintrust tracing llm = get_default_llm() try: prompt_msg = UserMessage(content=prompt) with llm_generation_span( llm=llm, flow="kg_deep_extraction", input_messages=[prompt_msg] ) as span_generation: response = llm.invoke(prompt_msg) record_llm_response(span_generation, response) raw_extraction_result = llm_response_to_string(response) 
cleaned_response = ( raw_extraction_result.replace("{{", "{") .replace("}}", "}") .replace("```json\n", "") .replace("\n```", "") .replace("\n", "") ) first_bracket = cleaned_response.find("{") last_bracket = cleaned_response.rfind("}") cleaned_response = cleaned_response[first_bracket : last_bracket + 1] parsed_result = json.loads(cleaned_response) return KGDocumentDeepExtractionResults( classification_result=None, deep_extracted_entities=set(parsed_result.get("entities", [])), deep_extracted_relationships={ rel.replace(" ", "_") for rel in parsed_result.get("relationships", []) }, ) except Exception as e: failed_chunks = [chunk.chunk_id for chunk in chunk_batch] logger.error( f"Failed to process chunks {failed_chunks} from document {document_entity}. Error: {str(e)}" ) return None def kg_process_person( email: str, document_entity_id: str, relationship_type: str, kg_config_settings: KGConfigSettings, active_entity_types: set[str], ) -> tuple[str, str, str, str] | None: """ Create an employee or account entity from an email address, and a relationship to the entity from the document that the email is from. Returns: tuple containing (person_entity, person_relationship, company_participant_email, and account_participant_email), or None if the created entity is not of an active entity type or is from an ignored email domain. 
""" kg_person = kg_email_processing(email, kg_config_settings) if any( domain.lower() in kg_person.company.lower() for domain in kg_config_settings.KG_IGNORE_EMAIL_DOMAINS ): return None person_entity = None if kg_person.employee and "EMPLOYEE" in active_entity_types: person_entity = make_entity_id("EMPLOYEE", kg_person.name) elif not kg_person.employee and "ACCOUNT" in active_entity_types: person_entity = make_entity_id("ACCOUNT", kg_person.company) if person_entity: is_account = person_entity.startswith("ACCOUNT") participant_email = f"{kg_person.name} -- ({kg_person.company})" return ( person_entity, make_relationship_id(person_entity, relationship_type, document_entity_id), participant_email if not is_account else "", participant_email if is_account else "", ) return None def get_batch_documents_metadata( document_ids: list[str], connector_source: str ) -> list[KGMetadataContent]: """ Gets the metadata for a batch of documents. """ batch_metadata: list[KGMetadataContent] = [] source_type = DocumentSource(connector_source).value with get_session_with_current_tenant() as db_session: for document_id in document_ids: # get document metadata metadata = get_structured_tags_for_document(document_id, db_session) batch_metadata.append( KGMetadataContent( document_id=document_id, source_type=source_type, source_metadata=metadata, ) ) return batch_metadata def trackinfo_to_str(trackinfo: KGAttributeTrackInfo | None) -> str: """Convert trackinfo to an LLM friendly string""" if trackinfo is None: return "" if trackinfo.type == KGAttributeTrackType.LIST: if trackinfo.values is None: return "a list of any suitable values" return "a list with possible values: " + ", ".join(trackinfo.values) elif trackinfo.type == KGAttributeTrackType.VALUE: if trackinfo.values is None: return "any suitable value" return "one of: " + ", ".join(trackinfo.values) def trackinfo_to_dict(trackinfo: KGAttributeTrackInfo | None) -> dict | None: if trackinfo is None: return None return { "type": 
trackinfo.type, "values": (list(trackinfo.values) if trackinfo.values else None), } class EntityTypeMetadataTracker: def __init__(self) -> None: """ Tracks the possible values the metadata attributes can take for each entity type. """ # entity type -> attribute -> trackinfo self.entity_attr_info: dict[str, dict[str, KGAttributeTrackInfo | None]] = {} self.entity_allowed_attrs: dict[str, set[str]] = {} def import_typeinfo(self) -> None: """ Loads the metadata tracking information from the database. """ with get_session_with_current_tenant() as db_session: entity_types = db_session.query(KGEntityType).all() for entity_type in entity_types: self.entity_attr_info[entity_type.id_name] = ( entity_type.parsed_attributes.attribute_values ) self.entity_allowed_attrs[entity_type.id_name] = { attr.name for attr in entity_type.parsed_attributes.metadata_attribute_conversion.values() } def export_typeinfo(self) -> None: """ Exports the metadata tracking information to the database. """ with get_session_with_current_tenant() as db_session: for entity_type_id_name, attribute_values in self.entity_attr_info.items(): db_session.query(KGEntityType).filter( KGEntityType.id_name == entity_type_id_name ).update( { KGEntityType.attributes: KGEntityType.attributes.op("||")( { "attribute_values": { attr: trackinfo_to_dict(info) for attr, info in attribute_values.items() } } ) }, synchronize_session=False, ) db_session.commit() def track_metadata( self, entity_type: str, attributes: dict[str, str | list[str]] ) -> None: """ Tracks which values are possible for the given attributes. If the attribute value is a list, we track the values in the list rather than the list itself. If we see too many different values, we stop tracking the attribute.
""" for attribute, value in attributes.items(): # ignore types/metadata we are not tracking if entity_type not in self.entity_attr_info: continue if attribute not in self.entity_allowed_attrs[entity_type]: continue # determine if the attribute is a list or a value trackinfo = self.entity_attr_info[entity_type].get(attribute, None) if trackinfo is None: trackinfo = KGAttributeTrackInfo( type=( KGAttributeTrackType.VALUE if isinstance(value, str) else KGAttributeTrackType.LIST ), values=set(), ) self.entity_attr_info[entity_type][attribute] = trackinfo # None means marked as don't track if trackinfo.values is None: continue # track the value if isinstance(value, str): trackinfo.values.add(value) else: trackinfo.type = KGAttributeTrackType.LIST trackinfo.values.update(value) # if we see too many different values, we stop tracking if len(trackinfo.values) > KG_METADATA_TRACKING_THRESHOLD: trackinfo.values = None ================================================ FILE: backend/onyx/kg/utils/formatting_utils.py ================================================ import re from onyx.db.kg_config import KGConfigSettings from onyx.kg.models import KGPerson def format_entity_id(entity_id_name: str) -> str: return make_entity_id(*split_entity_id(entity_id_name)) def make_entity_id(entity_type: str, entity_name: str) -> str: return f"{entity_type.upper()}::{entity_name.lower()}" def split_entity_id(entity_id_name: str) -> list[str]: return entity_id_name.split("::") def get_entity_type(entity_id_name: str) -> str: return entity_id_name.split("::", 1)[0].upper() def format_entity_id_for_models(entity_id_name: str) -> str: entity_split = entity_id_name.split("::") if len(entity_split) == 2: entity_type, entity_name = entity_split separator = "::" elif len(entity_split) > 2: raise ValueError(f"Entity {entity_id_name} is not in the correct format") else: entity_name = entity_id_name separator = entity_type = "" formatted_entity_type = entity_type.strip().upper() formatted_entity_name =
entity_name.strip().replace('"', "").replace("'", "") return f"{formatted_entity_type}{separator}{formatted_entity_name}" def get_attributes(entity_w_attributes: str) -> dict[str, str]: """ Extract attributes from an entity string. E.g., "TYPE::Entity--[attr1: value1, attr2: value2]" -> {"attr1": "value1", "attr2": "value2"} """ attr_split = entity_w_attributes.split("--") if len(attr_split) != 2: raise ValueError(f"Invalid entity with attributes: {entity_w_attributes}") match = re.search(r"\[(.*)\]", attr_split[1]) if not match: return {} attr_list_str = match.group(1) return { attr_split[0].strip(): attr_split[1].strip() for attr in attr_list_str.split(",") if len(attr_split := attr.split(":", 1)) == 2 } def make_entity_w_attributes(entity: str, attributes: dict[str, str]) -> str: return f"{entity}--[{', '.join(f'{k}: {v}' for k, v in attributes.items())}]" def format_relationship_id(relationship_id_name: str) -> str: return make_relationship_id(*split_relationship_id(relationship_id_name)) def make_relationship_id( source_node: str, relationship_type: str, target_node: str ) -> str: return f"{format_entity_id(source_node)}__{relationship_type.lower()}__{format_entity_id(target_node)}" def split_relationship_id(relationship_id_name: str) -> list[str]: return relationship_id_name.split("__") def format_relationship_type_id(relationship_type_id_name: str) -> str: return make_relationship_type_id( *split_relationship_type_id(relationship_type_id_name) ) def make_relationship_type_id( source_node_type: str, relationship_type: str, target_node_type: str ) -> str: return f"{source_node_type.upper()}__{relationship_type.lower()}__{target_node_type.upper()}" def split_relationship_type_id(relationship_type_id_name: str) -> list[str]: return relationship_type_id_name.split("__") def extract_relationship_type_id(relationship_id_name: str) -> str: source_node, relationship_type, target_node = split_relationship_id( relationship_id_name ) return make_relationship_type_id( 
get_entity_type(source_node), relationship_type, get_entity_type(target_node) ) def extract_email(email: str) -> str | None: """ Extract an email from an arbitrary string (if any). Only the first email is returned. """ match = re.search(r"([A-Za-z0-9._+-]+@[A-Za-z0-9-]+(?:\.[A-Za-z0-9-]+)+)", email) return match.group(0) if match else None def kg_email_processing(email: str, kg_config_settings: KGConfigSettings) -> KGPerson: """ Convert an email address into a KGPerson. The local part becomes the name. If the domain contains one of KG_VENDOR_DOMAINS, the person is a vendor employee and the company is KG_VENDOR; otherwise the title-cased domain is used as the account company. """ name, company_domain = email.split("@") assert isinstance(company_domain, str) assert isinstance(kg_config_settings.KG_VENDOR_DOMAINS, list) assert isinstance(kg_config_settings.KG_VENDOR, str) employee = any( domain in company_domain for domain in kg_config_settings.KG_VENDOR_DOMAINS ) if employee: company = kg_config_settings.KG_VENDOR else: # TODO: maybe store a list of domains for each account and use that to match # right now, gmail and other random domains are being converted into accounts company = company_domain.title() return KGPerson(name=name, company=company, employee=employee) ================================================ FILE: backend/onyx/kg/utils/lock_utils.py ================================================ import time from redis.lock import Lock as RedisLock def extend_lock(lock: RedisLock, timeout: int, last_lock_time: float) -> float: current_time = time.monotonic() if current_time - last_lock_time >= (timeout / 4): lock.reacquire() last_lock_time = current_time return last_lock_time ================================================ FILE: backend/onyx/kg/vespa/vespa_interactions.py ================================================ import json from collections.abc import Generator from onyx.document_index.vespa.chunk_retrieval import get_chunks_via_visit_api from 
onyx.document_index.vespa.chunk_retrieval import VespaChunkRequest from onyx.document_index.vespa.index import IndexFilters from onyx.kg.models import KGChunkFormat from onyx.utils.logger import setup_logger logger = setup_logger() def
get_document_vespa_contents( document_id: str, index_name: str, tenant_id: str, batch_size: int = 8, ) -> Generator[list[KGChunkFormat], None, None]: """ Retrieves all chunks of the given document from Vespa and converts them to KGChunkFormat objects. Args: document_id (str): ID of the document to fetch chunks for index_name (str): Name of the Vespa index tenant_id (str): ID of the tenant batch_size (int): Number of chunks per yielded batch Yields: list[KGChunkFormat]: Batches of chunks ready for KG processing """ current_batch: list[KGChunkFormat] = [] # get all chunks for the document # TODO: revisit the visit function chunks = get_chunks_via_visit_api( chunk_request=VespaChunkRequest(document_id=document_id), index_name=index_name, filters=IndexFilters(access_control_list=None, tenant_id=tenant_id), field_names=[ "document_id", "chunk_id", "title", "content", "metadata", "primary_owners", "secondary_owners", "source_type", ], get_large_chunks=False, ) # Convert Vespa chunks to KGChunkFormat objects # kg_chunks: list[KGChunkFormat] = [] for i, chunk in enumerate(chunks): fields = chunk["fields"] if isinstance(fields.get("metadata", {}), str): fields["metadata"] = json.loads(fields["metadata"]) current_batch.append( KGChunkFormat( connector_id=None, # We may need to adjust this document_id=fields.get("document_id"), chunk_id=fields.get("chunk_id"), primary_owners=fields.get("primary_owners", []), secondary_owners=fields.get("secondary_owners", []), source_type=fields.get("source_type", ""), title=fields.get("title", ""), content=fields.get("content", ""), metadata=fields.get("metadata", {}), ) ) if len(current_batch) >= batch_size: yield current_batch current_batch = [] # Yield any remaining chunks if current_batch: yield current_batch ================================================ FILE: backend/onyx/llm/__init__.py ================================================ ================================================ FILE: 
backend/onyx/llm/constants.py
================================================ """ LLM Constants Centralized constants for LLM providers, vendors, and display names. """ from enum import Enum # Provider names class LlmProviderNames(str, Enum): """ Canonical string identifiers for LLM providers. """ OPENAI = "openai" ANTHROPIC = "anthropic" GOOGLE = "google" BEDROCK = "bedrock" BEDROCK_CONVERSE = "bedrock_converse" VERTEX_AI = "vertex_ai" OPENROUTER = "openrouter" AZURE = "azure" OLLAMA_CHAT = "ollama_chat" LM_STUDIO = "lm_studio" MISTRAL = "mistral" LITELLM_PROXY = "litellm_proxy" BIFROST = "bifrost" def __str__(self) -> str: """Needed so things like: f"{LlmProviderNames.OPENAI}/" gives back "openai/" instead of "LlmProviderNames.OPENAI/" """ return self.value WELL_KNOWN_PROVIDER_NAMES = [ LlmProviderNames.OPENAI, LlmProviderNames.ANTHROPIC, LlmProviderNames.VERTEX_AI, LlmProviderNames.BEDROCK, LlmProviderNames.OPENROUTER, LlmProviderNames.AZURE, LlmProviderNames.OLLAMA_CHAT, LlmProviderNames.LM_STUDIO, LlmProviderNames.LITELLM_PROXY, LlmProviderNames.BIFROST, ] # Proper capitalization for known providers and vendors PROVIDER_DISPLAY_NAMES: dict[str, str] = { LlmProviderNames.OPENAI: "OpenAI", LlmProviderNames.ANTHROPIC: "Anthropic", LlmProviderNames.GOOGLE: "Google", LlmProviderNames.BEDROCK: "Bedrock", LlmProviderNames.BEDROCK_CONVERSE: "Bedrock", LlmProviderNames.VERTEX_AI: "Vertex AI", LlmProviderNames.OPENROUTER: "OpenRouter", LlmProviderNames.AZURE: "Azure", "ollama": "Ollama", LlmProviderNames.OLLAMA_CHAT: "Ollama", LlmProviderNames.LM_STUDIO: "LM Studio", LlmProviderNames.LITELLM_PROXY: "LiteLLM Proxy", LlmProviderNames.BIFROST: "Bifrost", "groq": "Groq", "anyscale": "Anyscale", "deepseek": "DeepSeek", "xai": "xAI", LlmProviderNames.MISTRAL: "Mistral", "mistralai": "Mistral", # Alias used by some providers "cohere": "Cohere", "perplexity": "Perplexity", "amazon": "Amazon", "meta": "Meta", "meta-llama": "Meta", # Alias used by some providers "ai21": "AI21", "nvidia": "NVIDIA", 
"databricks": "Databricks", "alibaba": "Alibaba", "qwen": "Qwen", "microsoft": "Microsoft", "gemini": "Gemini", "stability": "Stability", "writer": "Writer", } # Map vendors to their brand names (used for provider_display_name generation) VENDOR_BRAND_NAMES: dict[str, str] = { "anthropic": "Claude", "openai": "GPT", "google": "Gemini", "amazon": "Nova", "meta": "Llama", "mistral": "Mistral", "cohere": "Command", "deepseek": "DeepSeek", "xai": "Grok", "perplexity": "Sonar", "ai21": "Jamba", "nvidia": "Nemotron", "qwen": "Qwen", "alibaba": "Qwen", "writer": "Palmyra", } # Aggregator providers that host models from multiple vendors AGGREGATOR_PROVIDERS: set[str] = { LlmProviderNames.BEDROCK, LlmProviderNames.BEDROCK_CONVERSE, LlmProviderNames.OPENROUTER, LlmProviderNames.OLLAMA_CHAT, LlmProviderNames.LM_STUDIO, LlmProviderNames.VERTEX_AI, LlmProviderNames.AZURE, LlmProviderNames.LITELLM_PROXY, LlmProviderNames.BIFROST, } # Model family name mappings for display name generation # Used by Bedrock display name generator BEDROCK_MODEL_NAME_MAPPINGS: dict[str, str] = { "claude": "Claude", "llama": "Llama", "mistral": "Mistral", "mixtral": "Mixtral", "titan": "Titan", "nova": "Nova", "jamba": "Jamba", "command": "Command", "deepseek": "DeepSeek", } # Used by Ollama display name generator OLLAMA_MODEL_NAME_MAPPINGS: dict[str, str] = { "llama": "Llama", "qwen": "Qwen", "mistral": "Mistral", "deepseek": "DeepSeek", "gemma": "Gemma", "phi": "Phi", "codellama": "Code Llama", "starcoder": "StarCoder", "wizardcoder": "WizardCoder", "vicuna": "Vicuna", "orca": "Orca", "dolphin": "Dolphin", "nous": "Nous", "neural": "Neural", "mixtral": "Mixtral", "falcon": "Falcon", "yi": "Yi", "command": "Command", "zephyr": "Zephyr", "openchat": "OpenChat", "solar": "Solar", } # Bedrock model token limits (AWS doesn't expose this via API) # Note: Many Bedrock model IDs include context length suffix (e.g., ":200k") # which is parsed first. This mapping is for models without suffixes. 
# Sources: # - LiteLLM model_prices_and_context_window.json # - AWS Bedrock documentation and announcement blogs BEDROCK_MODEL_TOKEN_LIMITS: dict[str, int] = { # Anthropic Claude models (new naming: claude-{tier}-{version}) "claude-opus-4": 200000, "claude-sonnet-4": 200000, "claude-haiku-4": 200000, # Anthropic Claude models (old naming: claude-{version}) "claude-4": 200000, "claude-3-7": 200000, "claude-3-5": 200000, "claude-3": 200000, "claude-v2": 100000, "claude-instant": 100000, # Amazon Nova models (from LiteLLM) "nova-premier": 1000000, "nova-pro": 300000, "nova-lite": 300000, "nova-2-lite": 1000000, # Nova 2 Lite has 1M context "nova-2-sonic": 128000, "nova-micro": 128000, # Amazon Titan models (from LiteLLM: all text models are 42K) "titan-text-premier": 42000, "titan-text-express": 42000, "titan-text-lite": 42000, "titan-tg1": 8000, # Meta Llama models (Llama 3 base = 8K, Llama 3.1+ = 128K) "llama4": 128000, "llama3-3": 128000, "llama3-2": 128000, "llama3-1": 128000, "llama3-8b": 8000, "llama3-70b": 8000, # Mistral models (Large 2+ = 128K, original Large/Small = 32K) "mistral-large-3": 128000, "mistral-large-2407": 128000, # Mistral Large 2 "mistral-large-2402": 32000, # Original Mistral Large "mistral-large": 128000, # Default to newer version "mistral-small": 32000, "mistral-7b": 32000, "mixtral-8x7b": 32000, "pixtral": 128000, "ministral": 128000, "magistral": 128000, "voxtral": 32000, # Cohere models "command-r-plus": 128000, "command-r": 128000, # DeepSeek models "deepseek": 64000, # Google Gemma models "gemma-3": 128000, "gemma-2": 8000, "gemma": 8000, # Qwen models "qwen3": 128000, "qwen2": 128000, # NVIDIA models "nemotron": 128000, # Writer Palmyra models "palmyra": 128000, # Moonshot Kimi "kimi": 128000, # Minimax "minimax": 128000, # OpenAI (via Bedrock) "gpt-oss": 128000, # AI21 models (from LiteLLM: Jamba 1.5 = 256K, Jamba Instruct = 70K) "jamba-1-5": 256000, "jamba-instruct": 70000, "jamba": 256000, # Default to newer version } # Models 
that should keep their hyphenated format in display names # These are model families where the hyphen is part of the brand name HYPHENATED_MODEL_NAMES: set[str] = { "gpt-oss", } # General model prefix to vendor mapping (used as fallback when enrichment data is missing) # This covers common model families across all providers MODEL_PREFIX_TO_VENDOR: dict[str, str] = { # Google "gemini": "google", "gemma": "google", "palm": "google", # Anthropic "claude": "anthropic", # OpenAI "gpt": "openai", "o1": "openai", "o3": "openai", "o4": "openai", "chatgpt": "openai", # Meta "llama": "meta", "codellama": "meta", # Mistral "mistral": "mistral", "mixtral": "mistral", "codestral": "mistral", "ministral": "mistral", "pixtral": "mistral", "magistral": "mistral", # Cohere "command": "cohere", "aya": "cohere", # Amazon "nova": "amazon", "titan": "amazon", # AI21 "jamba": "ai21", # DeepSeek "deepseek": "deepseek", # Alibaba/Qwen "qwen": "alibaba", "qwq": "alibaba", # Microsoft "phi": "microsoft", # NVIDIA "nemotron": "nvidia", # xAI "grok": "xai", } # Ollama model prefix to vendor mapping (for grouping models by vendor) OLLAMA_MODEL_TO_VENDOR: dict[str, str] = { "llama": "Meta", "codellama": "Meta", "qwen": "Alibaba", "qwq": "Alibaba", "mistral": "Mistral", "ministral": "Mistral", "mixtral": "Mistral", "deepseek": "DeepSeek", "gemma": "Google", "phi": "Microsoft", "command": "Cohere", "aya": "Cohere", "falcon": "TII", "yi": "01.AI", "starcoder": "BigCode", "wizardcoder": "WizardLM", "vicuna": "LMSYS", "openchat": "OpenChat", "solar": "Upstage", "orca": "Microsoft", "dolphin": "Cognitive Computations", "nous": "Nous Research", "neural": "Intel", "zephyr": "HuggingFace", "granite": "IBM", "nemotron": "NVIDIA", "smollm": "HuggingFace", } ================================================ FILE: backend/onyx/llm/cost.py ================================================ """LLM cost calculation utilities.""" from onyx.utils.logger import setup_logger logger = setup_logger() def 
calculate_llm_cost_cents( model_name: str, prompt_tokens: int, completion_tokens: int, ) -> float: """ Calculate the cost in cents for an LLM API call. Uses litellm's cost_per_token function to get current pricing. Returns 0 if the model is not found or on any error. """ try: import litellm # cost_per_token returns (prompt_cost, completion_cost) in USD prompt_cost_usd, completion_cost_usd = litellm.cost_per_token( model=model_name, prompt_tokens=prompt_tokens, completion_tokens=completion_tokens, ) # Convert to cents (multiply by 100) total_cost_cents = (prompt_cost_usd + completion_cost_usd) * 100 return total_cost_cents except Exception as e: # Log but don't fail - unknown models or errors shouldn't block usage logger.debug( f"Could not calculate cost for model {model_name}: {e}. Assuming cost is 0." ) return 0.0 ================================================ FILE: backend/onyx/llm/factory.py ================================================ from collections.abc import Callable from typing import Any from onyx.auth.schemas import UserRole from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import LLMModelFlowType from onyx.db.llm import can_user_access_llm_provider from onyx.db.llm import fetch_default_llm_model from onyx.db.llm import fetch_default_vision_model from onyx.db.llm import fetch_existing_llm_provider from onyx.db.llm import fetch_existing_models from onyx.db.llm import fetch_llm_provider_view from onyx.db.llm import fetch_user_group_ids from onyx.db.models import Persona from onyx.db.models import User from onyx.llm.constants import LlmProviderNames from onyx.llm.interfaces import LLM from onyx.llm.multi_llm import LitellmLLM from onyx.llm.override_models import LLMOverride from onyx.llm.utils import get_max_input_tokens_from_llm_provider from onyx.llm.utils import model_supports_image_input from onyx.llm.well_known_providers.constants import ( 
PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING, ) from onyx.natural_language_processing.utils import get_tokenizer from onyx.server.manage.llm.models import LLMProviderView from onyx.utils.headers import build_llm_extra_headers from onyx.utils.logger import setup_logger logger = setup_logger() def _build_provider_extra_headers( provider: str, custom_config: dict[str, str] | None ) -> dict[str, str]: if provider in PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING and custom_config: raw = custom_config.get(PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING[provider]) api_key = raw.strip() if raw else None if not api_key: return {} return { "Authorization": ( api_key if api_key.lower().startswith("bearer ") else f"Bearer {api_key}" ) } # Passing these will put Onyx on the OpenRouter leaderboard elif provider == LlmProviderNames.OPENROUTER: return { "HTTP-Referer": "https://onyx.app", "X-Title": "Onyx", } return {} def _get_model_configured_max_input_tokens( llm_provider: LLMProviderView, model_name: str, ) -> int | None: for model_configuration in llm_provider.model_configurations: if model_configuration.name == model_name: return model_configuration.max_input_tokens return None def _build_model_kwargs( provider: str, configured_max_input_tokens: int | None, ) -> dict[str, Any]: model_kwargs: dict[str, Any] = {} if ( provider == LlmProviderNames.OLLAMA_CHAT and configured_max_input_tokens and configured_max_input_tokens > 0 ): model_kwargs["num_ctx"] = configured_max_input_tokens return model_kwargs def get_llm_for_persona( persona: Persona | None, user: User, llm_override: LLMOverride | None = None, additional_headers: dict[str, str] | None = None, ) -> LLM: if persona is None: logger.warning("No persona provided, using default LLM") return get_default_llm() provider_name_override = llm_override.model_provider if llm_override else None model_version_override = llm_override.model_version if llm_override else None temperature_override = llm_override.temperature if llm_override else None 
provider_name = provider_name_override or persona.llm_model_provider_override if not provider_name: return get_default_llm( temperature=temperature_override or GEN_AI_TEMPERATURE, additional_headers=additional_headers, ) with get_session_with_current_tenant() as db_session: provider_model = fetch_existing_llm_provider(provider_name, db_session) if not provider_model: raise ValueError("No LLM provider found") # Fetch user group IDs for access control check user_group_ids = fetch_user_group_ids(db_session, user) if not can_user_access_llm_provider( provider_model, user_group_ids, persona, user.role == UserRole.ADMIN ): logger.warning( "User %s with persona %s cannot access provider %s. Falling back to default provider.", user.id, persona.id, provider_model.name, ) return get_default_llm( temperature=temperature_override or GEN_AI_TEMPERATURE, additional_headers=additional_headers, ) llm_provider = LLMProviderView.from_model(provider_model) model = model_version_override or persona.llm_model_version_override if not model: raise ValueError("No model name found") return llm_from_provider( model_name=model, llm_provider=llm_provider, temperature=temperature_override, additional_headers=additional_headers, ) def get_default_llm_with_vision( timeout: int | None = None, temperature: float | None = None, additional_headers: dict[str, str] | None = None, ) -> LLM | None: """Get an LLM that supports image input, with the following priority: 1. Use the designated default vision provider if it exists and supports image input 2. Fall back to the first LLM provider that supports image input Returns None if no providers exist or if no provider supports images. 
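The fallback ordering in step 2 relies on sorting by a tuple of booleans. The following is a minimal sketch of that ordering. It assumes a hypothetical CandidateModel class and uses plain strings in place of LLMModelFlowType values. The function below then takes the first sorted model that passes model_supports_image_input.

```python
from dataclasses import dataclass, field


@dataclass
class CandidateModel:
    # Hypothetical stand-in for a model row: just a name and its flow types.
    name: str
    flow_types: set[str] = field(default_factory=set)


def order_vision_candidates(models: list[CandidateModel]) -> list[CandidateModel]:
    # The key is a tuple of booleans. With reverse=True, True sorts before
    # False, so VISION-flagged models come first, then CHAT-only ones.
    # Python's sort stays stable under reverse=True, so ties keep input order.
    return sorted(
        models,
        key=lambda m: ("vision" in m.flow_types, "chat" in m.flow_types),
        reverse=True,
    )


candidates = [
    CandidateModel("a", {"chat"}),
    CandidateModel("b", {"vision"}),
    CandidateModel("c", {"vision", "chat"}),
    CandidateModel("d", {"chat"}),
]
print([m.name for m in order_vision_candidates(candidates)])  # ['c', 'b', 'a', 'd']
```

A model flagged for both VISION and CHAT outranks a VISION-only one, and CHAT-only models are still tried as a last resort.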
""" def create_vision_llm(provider: LLMProviderView, model: str) -> LLM: """Helper to create an LLM if the provider supports image input.""" return llm_from_provider( model_name=model, llm_provider=provider, timeout=timeout, temperature=temperature, additional_headers=additional_headers, ) provider_map = {} with get_session_with_current_tenant() as db_session: # Try the default vision provider first default_model = fetch_default_vision_model(db_session) if default_model: if model_supports_image_input( default_model.name, default_model.llm_provider.provider ): logger.info( "Using default vision model: %s (provider=%s)", default_model.name, default_model.llm_provider.provider, ) return create_vision_llm( LLMProviderView.from_model(default_model.llm_provider), default_model.name, ) else: logger.warning( "Default vision model %s (provider=%s) does not support " "image input — falling back to searching all providers", default_model.name, default_model.llm_provider.provider, ) # Fall back to searching all providers models = fetch_existing_models( db_session=db_session, flow_types=[LLMModelFlowType.VISION, LLMModelFlowType.CHAT], ) if not models: logger.warning( "No LLM models with VISION or CHAT flow type found — " "image summarization will be disabled" ) return None for model in models: if model.llm_provider_id not in provider_map: provider_map[model.llm_provider_id] = LLMProviderView.from_model( model.llm_provider ) # Search for viable vision model followed by chat models # Sort models from VISION to CHAT priority sorted_models = sorted( models, key=lambda x: ( LLMModelFlowType.VISION in x.llm_model_flow_types, LLMModelFlowType.CHAT in x.llm_model_flow_types, ), reverse=True, ) for model in sorted_models: if model_supports_image_input(model.name, model.llm_provider.provider): logger.info( "Using fallback vision model: %s (provider=%s)", model.name, model.llm_provider.provider, ) return create_vision_llm( provider_map[model.llm_provider_id], model.name, ) checked_models 
= [ f"{m.name} (provider={m.llm_provider.provider})" for m in sorted_models ] logger.warning( "No vision-capable model found among %d candidates: %s — " "image summarization will be disabled", len(sorted_models), ", ".join(checked_models), ) return None def llm_from_provider( model_name: str, llm_provider: LLMProviderView, timeout: int | None = None, temperature: float | None = None, additional_headers: dict[str, str] | None = None, ) -> LLM: configured_max_input_tokens = _get_model_configured_max_input_tokens( llm_provider=llm_provider, model_name=model_name ) model_kwargs = _build_model_kwargs( provider=llm_provider.provider, configured_max_input_tokens=configured_max_input_tokens, ) max_input_tokens = ( configured_max_input_tokens if configured_max_input_tokens else get_max_input_tokens_from_llm_provider( llm_provider=llm_provider, model_name=model_name ) ) return get_llm( provider=llm_provider.provider, model=model_name, deployment_name=llm_provider.deployment_name, api_key=llm_provider.api_key, api_base=llm_provider.api_base, api_version=llm_provider.api_version, custom_config=llm_provider.custom_config, timeout=timeout, temperature=temperature, additional_headers=additional_headers, max_input_tokens=max_input_tokens, model_kwargs=model_kwargs, ) def get_llm_for_contextual_rag(model_name: str, model_provider: str) -> LLM: with get_session_with_current_tenant() as db_session: llm_provider = fetch_llm_provider_view(db_session, model_provider) if not llm_provider: raise ValueError("No LLM provider with name {} found".format(model_provider)) return llm_from_provider( model_name=model_name, llm_provider=llm_provider, ) def get_default_llm( timeout: int | None = None, temperature: float | None = None, additional_headers: dict[str, str] | None = None, ) -> LLM: with get_session_with_current_tenant() as db_session: model = fetch_default_llm_model(db_session) if not model: raise ValueError("No default LLM model found") return llm_from_provider( model_name=model.name, 
llm_provider=LLMProviderView.from_model(model.llm_provider), timeout=timeout, temperature=temperature, additional_headers=additional_headers, ) def get_llm( provider: str, model: str, max_input_tokens: int, deployment_name: str | None, api_key: str | None = None, api_base: str | None = None, api_version: str | None = None, custom_config: dict[str, str] | None = None, temperature: float | None = None, timeout: int | None = None, additional_headers: dict[str, str] | None = None, model_kwargs: dict[str, Any] | None = None, ) -> LLM: if temperature is None: temperature = GEN_AI_TEMPERATURE extra_headers = build_llm_extra_headers(additional_headers) # NOTE: this is needed since Ollama API key is optional # User may access Ollama cloud via locally hosted instance (logged in) # or just via the cloud API (not logged in, using API key) provider_extra_headers = _build_provider_extra_headers(provider, custom_config) if provider_extra_headers: extra_headers.update(provider_extra_headers) return LitellmLLM( model_provider=provider, model_name=model, deployment_name=deployment_name, api_key=api_key, api_base=api_base, api_version=api_version, timeout=timeout, temperature=temperature, custom_config=custom_config, extra_headers=extra_headers, model_kwargs=model_kwargs or {}, max_input_tokens=max_input_tokens, ) def get_llm_tokenizer_encode_func(llm: LLM) -> Callable[[str], list[int]]: """Get the tokenizer encode function for an LLM. 
Args: llm: The LLM instance to get the tokenizer for Returns: A callable that encodes a string into a list of token IDs """ llm_provider = llm.config.model_provider llm_model_name = llm.config.model_name llm_tokenizer = get_tokenizer( model_name=llm_model_name, provider_type=llm_provider, ) return llm_tokenizer.encode def get_llm_token_counter(llm: LLM) -> Callable[[str], int]: tokenizer_encode_func = get_llm_tokenizer_encode_func(llm) return lambda text: len(tokenizer_encode_func(text)) ================================================ FILE: backend/onyx/llm/interfaces.py ================================================ import abc from collections.abc import Iterator from braintrust import traced from pydantic import BaseModel from onyx.llm.model_response import ModelResponse from onyx.llm.model_response import ModelResponseStream from onyx.llm.models import LanguageModelInput from onyx.llm.models import ReasoningEffort from onyx.llm.models import ToolChoiceOptions from onyx.utils.logger import setup_logger logger = setup_logger() class LLMUserIdentity(BaseModel): user_id: str | None = None session_id: str | None = None class LLMConfig(BaseModel): model_provider: str model_name: str temperature: float api_key: str | None = None api_base: str | None = None api_version: str | None = None deployment_name: str | None = None custom_config: dict[str, str] | None = None max_input_tokens: int # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} class LLM(abc.ABC): @property @abc.abstractmethod def config(self) -> LLMConfig: raise NotImplementedError @traced(name="invoke llm", type="llm") def invoke( self, prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, structured_response_format: dict | None = None, timeout_override: int | None = None, max_tokens: int | None = None, reasoning_effort: ReasoningEffort = ReasoningEffort.AUTO, user_identity: LLMUserIdentity | None = 
None, ) -> "ModelResponse": raise NotImplementedError def stream( self, prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, structured_response_format: dict | None = None, timeout_override: int | None = None, max_tokens: int | None = None, reasoning_effort: ReasoningEffort = ReasoningEffort.AUTO, user_identity: LLMUserIdentity | None = None, ) -> Iterator[ModelResponseStream]: raise NotImplementedError ================================================ FILE: backend/onyx/llm/litellm_singleton/__init__.py ================================================ """ Singleton module for litellm configuration. This ensures litellm is configured exactly once when first imported. All other modules should import litellm from here instead of directly. """ import litellm from .config import initialize_litellm from .monkey_patches import apply_monkey_patches initialize_litellm() apply_monkey_patches() # Export the configured litellm module and model __all__ = ["litellm"] ================================================ FILE: backend/onyx/llm/litellm_singleton/config.py ================================================ import json from pathlib import Path import litellm from onyx.utils.logger import setup_logger logger = setup_logger() def configure_litellm_settings() -> None: # If a user configures a different model and it doesn't support all the same # parameters like frequency and presence, just ignore them litellm.drop_params = True litellm.telemetry = False litellm.modify_params = True litellm.add_function_to_prompt = False litellm.suppress_debug_info = True # TODO: We might not need to register ollama_chat in addition to ollama but let's just do it for good measure for now. 
def register_ollama_models() -> None: litellm.register_model( model_cost={ # GPT-OSS models "ollama_chat/gpt-oss:120b-cloud": {"supports_function_calling": True}, "ollama_chat/gpt-oss:120b": {"supports_function_calling": True}, "ollama_chat/gpt-oss:20b-cloud": {"supports_function_calling": True}, "ollama_chat/gpt-oss:20b": {"supports_function_calling": True}, "ollama/gpt-oss:120b-cloud": {"supports_function_calling": True}, "ollama/gpt-oss:120b": {"supports_function_calling": True}, "ollama/gpt-oss:20b-cloud": {"supports_function_calling": True}, "ollama/gpt-oss:20b": {"supports_function_calling": True}, # DeepSeek models "ollama_chat/deepseek-r1:latest": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:1.5b": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:7b": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:8b": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:14b": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:32b": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:70b": {"supports_function_calling": True}, "ollama_chat/deepseek-r1:671b": {"supports_function_calling": True}, "ollama_chat/deepseek-v3.1:latest": {"supports_function_calling": True}, "ollama_chat/deepseek-v3.1:671b": {"supports_function_calling": True}, "ollama_chat/deepseek-v3.1:671b-cloud": {"supports_function_calling": True}, "ollama/deepseek-r1:latest": {"supports_function_calling": True}, "ollama/deepseek-r1:1.5b": {"supports_function_calling": True}, "ollama/deepseek-r1:7b": {"supports_function_calling": True}, "ollama/deepseek-r1:8b": {"supports_function_calling": True}, "ollama/deepseek-r1:14b": {"supports_function_calling": True}, "ollama/deepseek-r1:32b": {"supports_function_calling": True}, "ollama/deepseek-r1:70b": {"supports_function_calling": True}, "ollama/deepseek-r1:671b": {"supports_function_calling": True}, "ollama/deepseek-v3.1:latest": {"supports_function_calling": True}, 
"ollama/deepseek-v3.1:671b": {"supports_function_calling": True}, "ollama/deepseek-v3.1:671b-cloud": {"supports_function_calling": True}, # Gemma3 models "ollama_chat/gemma3:latest": {"supports_function_calling": True}, "ollama_chat/gemma3:270m": {"supports_function_calling": True}, "ollama_chat/gemma3:1b": {"supports_function_calling": True}, "ollama_chat/gemma3:4b": {"supports_function_calling": True}, "ollama_chat/gemma3:12b": {"supports_function_calling": True}, "ollama_chat/gemma3:27b": {"supports_function_calling": True}, "ollama/gemma3:latest": {"supports_function_calling": True}, "ollama/gemma3:270m": {"supports_function_calling": True}, "ollama/gemma3:1b": {"supports_function_calling": True}, "ollama/gemma3:4b": {"supports_function_calling": True}, "ollama/gemma3:12b": {"supports_function_calling": True}, "ollama/gemma3:27b": {"supports_function_calling": True}, # Qwen models "ollama_chat/qwen3-coder:latest": {"supports_function_calling": True}, "ollama_chat/qwen3-coder:30b": {"supports_function_calling": True}, "ollama_chat/qwen3-coder:480b": {"supports_function_calling": True}, "ollama_chat/qwen3-coder:480b-cloud": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:latest": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:2b": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:4b": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:8b": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:30b": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:32b": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:235b": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:235b-cloud": {"supports_function_calling": True}, "ollama_chat/qwen3-vl:235b-instruct-cloud": { "supports_function_calling": True }, "ollama/qwen3-coder:latest": {"supports_function_calling": True}, "ollama/qwen3-coder:30b": {"supports_function_calling": True}, "ollama/qwen3-coder:480b": {"supports_function_calling": True}, 
"ollama/qwen3-coder:480b-cloud": {"supports_function_calling": True}, "ollama/qwen3-vl:latest": {"supports_function_calling": True}, "ollama/qwen3-vl:2b": {"supports_function_calling": True}, "ollama/qwen3-vl:4b": {"supports_function_calling": True}, "ollama/qwen3-vl:8b": {"supports_function_calling": True}, "ollama/qwen3-vl:30b": {"supports_function_calling": True}, "ollama/qwen3-vl:32b": {"supports_function_calling": True}, "ollama/qwen3-vl:235b": {"supports_function_calling": True}, "ollama/qwen3-vl:235b-cloud": {"supports_function_calling": True}, "ollama/qwen3-vl:235b-instruct-cloud": {"supports_function_calling": True}, # Kimi "ollama_chat/kimi-k2:1t": {"supports_function_calling": True}, "ollama_chat/kimi-k2:1t-cloud": {"supports_function_calling": True}, "ollama/kimi-k2:1t": {"supports_function_calling": True}, "ollama/kimi-k2:1t-cloud": {"supports_function_calling": True}, # GLM "ollama_chat/glm-4.6:cloud": {"supports_function_calling": True}, "ollama_chat/glm-4.6": {"supports_function_calling": True}, "ollama/glm-4.6": {"supports_function_calling": True}, "ollama/glm-4.6-cloud": {"supports_function_calling": True}, } ) def load_model_metadata_enrichments() -> None: """ Load model metadata enrichments from JSON file and merge into litellm.model_cost. This adds model_vendor, display_name, and model_version fields to litellm's model_cost dict. These fields are used by the UI to display models grouped by vendor with human-friendly names. Once LiteLLM accepts our upstream PR to add these fields natively, this function and the JSON file can be removed. 
""" enrichments_path = Path(__file__).parent.parent / "model_metadata_enrichments.json" if not enrichments_path.exists(): logger.warning(f"Model metadata enrichments file not found: {enrichments_path}") return try: with open(enrichments_path) as f: enrichments = json.load(f) # Merge enrichments into litellm.model_cost for model_key, metadata in enrichments.items(): if model_key in litellm.model_cost: # Update existing entry with our metadata litellm.model_cost[model_key].update(metadata) else: # Model not in litellm.model_cost - add it with just our metadata litellm.model_cost[model_key] = metadata logger.info(f"Loaded model metadata enrichments for {len(enrichments)} models") # Clear the model name parser cache since enrichments are now loaded # This ensures any parsing done before enrichments were loaded gets refreshed try: from onyx.llm.model_name_parser import parse_litellm_model_name parse_litellm_model_name.cache_clear() except ImportError: pass # Parser not yet imported, no cache to clear except Exception as e: logger.error(f"Failed to load model metadata enrichments: {e}") def initialize_litellm() -> None: configure_litellm_settings() register_ollama_models() load_model_metadata_enrichments() ================================================ FILE: backend/onyx/llm/litellm_singleton/monkey_patches.py ================================================ """ LiteLLM Monkey Patches This module addresses the following issues in LiteLLM: Status checked against LiteLLM v1.81.6-nightly (2026-02-02): 1. Ollama Streaming Reasoning Content (_patch_ollama_chunk_parser): - LiteLLM's chunk_parser doesn't properly handle reasoning content in streaming responses from Ollama - Processes native "thinking" field from Ollama responses - Also handles <think>...</think>
tags in content for models that use that format - Tracks reasoning state to properly separate thinking from regular content STATUS: STILL NEEDED - LiteLLM has a bug where it only yields thinking content on the first two chunks, then stops (lines 504-510). Our patch correctly yields ALL thinking chunks. The upstream logic sets finished_reasoning_content=True on the second chunk instead of when regular content starts. 2. OpenAI Responses API Parallel Tool Calls (_patch_openai_responses_parallel_tool_calls): - LiteLLM's translate_responses_chunk_to_openai_stream hardcodes index=0 for all tool calls - This breaks parallel tool calls where multiple functions are called simultaneously - The OpenAI Responses API provides output_index in streaming events to track which tool call each event belongs to STATUS: STILL NEEDED - LiteLLM hardcodes index=0 in translate_responses_chunk_to_openai_stream for response.output_item.added (line 962), response.function_call_arguments.delta (line 989), and response.output_item.done (line 1033). Our patch uses output_index from the event to properly track parallel tool calls. 3. OpenAI Responses API Non-Streaming (_patch_openai_responses_transform_response): - LiteLLM's transform_response doesn't properly concatenate multiple reasoning summary parts in non-streaming responses - Multiple ReasoningSummaryItem objects should be joined with newlines STATUS: STILL NEEDED - LiteLLM's _convert_response_output_to_choices (lines 366-370) only keeps the LAST summary item text, discarding earlier parts. Our patch concatenates all summary texts with double newlines. 4. 
Azure Responses API Fake Streaming (_patch_azure_responses_should_fake_stream): - LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models not in its database, which buffers the entire response before yielding - This causes poor time-to-first-token for Azure custom model deployments - Azure's Responses API supports native streaming, so we force real streaming STATUS: STILL NEEDED - AzureOpenAIResponsesAPIConfig does NOT override should_fake_stream, so it inherits from OpenAIResponsesAPIConfig which returns True for models not in litellm.utils.supports_native_streaming(). Custom Azure deployments will still use fake streaming without this patch. # Note: 5 and 6 suppress a warning and may fix usage info, but are not strictly required for the app to run. 5. Responses API Usage Format Mismatch (_patch_responses_api_usage_format): - LiteLLM uses model_construct as a fallback in multiple places when ResponsesAPIResponse validation fails - This bypasses the usage validator, allowing chat completion format usage (completion_tokens, prompt_tokens) to be stored instead of Responses API format (input_tokens, output_tokens) - When model_dump() is later called, Pydantic emits a serialization warning STATUS: STILL NEEDED - Multiple files use model_construct which bypasses validation: openai/responses/transformation.py, chatgpt/responses/transformation.py, manus/responses/transformation.py, volcengine/responses/transformation.py, and handler.py. Our patch wraps ResponsesAPIResponse.model_construct itself to transform usage in all code paths. 6.
Logging Usage Transformation Warning (_patch_logging_assembled_streaming_response): - LiteLLM's _get_assembled_streaming_response in litellm_logging.py transforms ResponseAPIUsage to chat completion format and sets it as a dict on the ResponsesAPIResponse.usage field - This replaces the proper ResponseAPIUsage object with a dict, causing Pydantic to emit a serialization warning when model_dump() is called later STATUS: STILL NEEDED - litellm_core_utils/litellm_logging.py lines 3185-3199 set usage as a dict with chat completion format instead of keeping it as ResponseAPIUsage. Our patch creates a deep copy before modification. 7. Responses API metadata=None TypeError (_patch_responses_metadata_none): - LiteLLM's @client decorator wrapper in utils.py uses kwargs.get("metadata", {}) to check for router calls, but when metadata is explicitly None (key exists with value None), the default {} is not used - This causes "argument of type 'NoneType' is not iterable" TypeError which swallows the real exception (e.g. AuthenticationError for wrong API key) - Surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable STATUS: STILL NEEDED - litellm/utils.py wrapper function (line 1721) does not guard against metadata being explicitly None. Triggered when Responses API bridge passes **litellm_params containing metadata=None. 
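Issue 7 hinges on how dict.get treats a key that is present but set to None. The default argument is returned only when the key is absent. The following is a minimal sketch of the failing pattern and a guarded variant. The function names and the key router_flag are hypothetical; this is not LiteLLM's actual wrapper code.

```python
from typing import Any


def is_router_call_unguarded(kwargs: dict[str, Any]) -> bool:
    # The {} default applies only when "metadata" is ABSENT; an explicit
    # metadata=None is returned as-is, and `in None` raises TypeError.
    metadata = kwargs.get("metadata", {})
    return "router_flag" in metadata


def is_router_call_guarded(kwargs: dict[str, Any]) -> bool:
    # `or {}` also replaces an explicit None (or any falsy value) with {}.
    metadata = kwargs.get("metadata") or {}
    return "router_flag" in metadata


try:
    is_router_call_unguarded({"metadata": None})
except TypeError as exc:
    print(f"unguarded: TypeError: {exc}")
print(is_router_call_guarded({"metadata": None}))  # False
```

In the unguarded version, this TypeError fires before the real error (for example an AuthenticationError) can surface.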
""" import time import uuid from typing import Any from typing import cast from typing import List from typing import Optional from litellm.completion_extras.litellm_responses_transformation.transformation import ( LiteLLMResponsesTransformationHandler, ) from litellm.completion_extras.litellm_responses_transformation.transformation import ( OpenAiResponsesToChatCompletionStreamIterator, ) from litellm.llms.ollama.chat.transformation import OllamaChatCompletionResponseIterator from litellm.llms.ollama.common_utils import OllamaError from litellm.types.utils import ChatCompletionUsageBlock from litellm.types.utils import ModelResponseStream def _patch_ollama_chunk_parser() -> None: """ Patches OllamaChatCompletionResponseIterator.chunk_parser to properly handle reasoning content and content in streaming responses. """ if ( getattr(OllamaChatCompletionResponseIterator.chunk_parser, "__name__", "") == "_patched_chunk_parser" ): return def _patched_chunk_parser(self: Any, chunk: dict) -> ModelResponseStream: try: """ Expected chunk format: { "model": "llama3.1", "created_at": "2025-05-24T02:12:05.859654Z", "message": { "role": "assistant", "content": "", "tool_calls": [{ "function": { "name": "get_latest_album_ratings", "arguments": { "artist_name": "Taylor Swift" } } }] }, "done_reason": "stop", "done": true, ... 
} Need to: - convert 'message' to 'delta' - return finish_reason when done is true - return usage when done is true """ from litellm.types.utils import Delta from litellm.types.utils import StreamingChoices # process tool calls - if complete function arg - add id to tool call tool_calls = chunk["message"].get("tool_calls") if tool_calls is not None: for tool_call in tool_calls: function_args = tool_call.get("function").get("arguments") if function_args is not None and len(function_args) > 0: is_function_call_complete = self._is_function_call_complete( function_args ) if is_function_call_complete: tool_call["id"] = str(uuid.uuid4()) # PROCESS REASONING CONTENT reasoning_content: Optional[str] = None content: Optional[str] = None thinking_content = chunk["message"].get("thinking") if thinking_content: # Truthy check: skips None and empty string "" reasoning_content = thinking_content if self.started_reasoning_content is False: self.started_reasoning_content = True if chunk["message"].get("content") is not None: message_content = chunk["message"].get("content") # Track whether we are inside <think>...</think> tagged content. in_think_tag_block = bool(getattr(self, "_in_think_tag_block", False)) if "<think>" in message_content: message_content = message_content.replace("<think>", "") self.started_reasoning_content = True self.finished_reasoning_content = False in_think_tag_block = True if "</think>" in message_content and self.started_reasoning_content: message_content = message_content.replace("</think>", "") self.finished_reasoning_content = True in_think_tag_block = False # For native Ollama "thinking" streams, content without active # think tags indicates a transition into regular assistant output.
if ( self.started_reasoning_content and not self.finished_reasoning_content and not in_think_tag_block and not thinking_content ): self.finished_reasoning_content = True self._in_think_tag_block = in_think_tag_block # When Ollama returns both "thinking" and "content" in the same # chunk, preserve both instead of classifying content as reasoning. if thinking_content and not in_think_tag_block: content = message_content elif ( self.started_reasoning_content and not self.finished_reasoning_content ): reasoning_content = message_content else: content = message_content delta = Delta( content=content, reasoning_content=reasoning_content, tool_calls=tool_calls, ) if chunk["done"] is True: finish_reason = chunk.get("done_reason", "stop") choices = [ StreamingChoices( delta=delta, finish_reason=finish_reason, ) ] else: choices = [ StreamingChoices( delta=delta, ) ] usage = ChatCompletionUsageBlock( prompt_tokens=chunk.get("prompt_eval_count", 0), completion_tokens=chunk.get("eval_count", 0), total_tokens=chunk.get("prompt_eval_count", 0) + chunk.get("eval_count", 0), ) return ModelResponseStream( id=str(uuid.uuid4()), object="chat.completion.chunk", created=int(time.time()), # ollama created_at is in UTC usage=usage, model=chunk["model"], choices=choices, ) except KeyError as e: raise OllamaError( message=f"KeyError: {e}, Got unexpected response from Ollama: {chunk}", status_code=400, headers={"Content-Type": "application/json"}, ) except Exception as e: raise e OllamaChatCompletionResponseIterator.chunk_parser = _patched_chunk_parser # type: ignore[method-assign] def _patch_openai_responses_parallel_tool_calls() -> None: """ Patches OpenAiResponsesToChatCompletionStreamIterator to properly handle: 1. Parallel tool calls by using output_index from streaming events 2. Reasoning summary sections by inserting newlines between different summary indices LiteLLM's implementation hardcodes index=0 for all tool calls, breaking parallel tool calls. 
The OpenAI Responses API provides output_index in each event to track which tool call the event belongs to. STATUS: STILL NEEDED - LiteLLM hardcodes index=0 in translate_responses_chunk_to_openai_stream for response.output_item.added (line 962), response.function_call_arguments.delta (line 989), and response.output_item.done (line 1033). Our patch uses output_index from the event to properly track parallel tool calls. """ if ( getattr( OpenAiResponsesToChatCompletionStreamIterator.chunk_parser, "__name__", "", ) == "_patched_responses_chunk_parser" ): return def _patched_responses_chunk_parser( self: Any, chunk: dict ) -> "ModelResponseStream": from pydantic import BaseModel from litellm.types.llms.openai import ( ChatCompletionToolCallFunctionChunk, ResponsesAPIStreamEvents, ) from litellm.types.utils import ( ChatCompletionToolCallChunk, Delta, ModelResponseStream, StreamingChoices, ) parsed_chunk = chunk if not parsed_chunk: raise ValueError("Chat provider: Empty parsed_chunk") if isinstance(parsed_chunk, BaseModel): parsed_chunk = parsed_chunk.model_dump() if not isinstance(parsed_chunk, dict): raise ValueError(f"Chat provider: Invalid chunk type {type(parsed_chunk)}") event_type = parsed_chunk.get("type") if isinstance(event_type, ResponsesAPIStreamEvents): event_type = event_type.value # Get the output_index for proper parallel tool call tracking output_index = parsed_chunk.get("output_index", 0) if event_type == "response.output_item.added": output_item = parsed_chunk.get("item", {}) if output_item.get("type") == "function_call": provider_specific_fields = output_item.get("provider_specific_fields") if provider_specific_fields and not isinstance( provider_specific_fields, dict ): provider_specific_fields = ( dict(provider_specific_fields) if hasattr(provider_specific_fields, "__dict__") else {} ) function_chunk = ChatCompletionToolCallFunctionChunk( name=output_item.get("name", None), arguments=parsed_chunk.get("arguments", ""), ) if 
provider_specific_fields: function_chunk["provider_specific_fields"] = ( provider_specific_fields ) tool_call_chunk = ChatCompletionToolCallChunk( id=output_item.get("call_id"), index=output_index, # Use output_index for parallel tool calls type="function", function=function_chunk, ) if provider_specific_fields: tool_call_chunk.provider_specific_fields = provider_specific_fields # type: ignore return ModelResponseStream( choices=[ StreamingChoices( index=0, delta=Delta(tool_calls=[tool_call_chunk]), finish_reason=None, ) ] ) elif event_type == "response.function_call_arguments.delta": content_part: Optional[str] = parsed_chunk.get("delta", None) if content_part: return ModelResponseStream( choices=[ StreamingChoices( index=0, delta=Delta( tool_calls=[ ChatCompletionToolCallChunk( id=None, index=output_index, # Use output_index for parallel tool calls type="function", function=ChatCompletionToolCallFunctionChunk( name=None, arguments=content_part ), ) ] ), finish_reason=None, ) ] ) else: raise ValueError( f"Chat provider: Invalid function argument delta {parsed_chunk}" ) elif event_type == "response.output_item.done": output_item = parsed_chunk.get("item", {}) if output_item.get("type") == "function_call": provider_specific_fields = output_item.get("provider_specific_fields") if provider_specific_fields and not isinstance( provider_specific_fields, dict ): provider_specific_fields = ( dict(provider_specific_fields) if hasattr(provider_specific_fields, "__dict__") else {} ) function_chunk = ChatCompletionToolCallFunctionChunk( name=output_item.get("name", None), arguments="", # responses API sends everything again, we don't need it ) if provider_specific_fields: function_chunk["provider_specific_fields"] = ( provider_specific_fields ) tool_call_chunk = ChatCompletionToolCallChunk( id=output_item.get("call_id"), index=output_index, # Use output_index for parallel tool calls type="function", function=function_chunk, ) if provider_specific_fields: 
tool_call_chunk.provider_specific_fields = provider_specific_fields # type: ignore return ModelResponseStream( choices=[ StreamingChoices( index=0, delta=Delta(tool_calls=[tool_call_chunk]), finish_reason="tool_calls", ) ] ) elif event_type == "response.reasoning_summary_text.delta": # Handle reasoning summary with newlines between sections content_part = parsed_chunk.get("delta", None) if content_part: summary_index = parsed_chunk.get("summary_index", 0) # Track the last summary index to insert newlines between parts last_summary_index = getattr( self, "_last_reasoning_summary_index", None ) if ( last_summary_index is not None and summary_index != last_summary_index ): # New summary part started, prepend newlines to separate them content_part = "\n\n" + content_part self._last_reasoning_summary_index = summary_index return ModelResponseStream( choices=[ StreamingChoices( index=cast(int, summary_index), delta=Delta(reasoning_content=content_part), ) ] ) # For all other event types, use the original static method return OpenAiResponsesToChatCompletionStreamIterator.translate_responses_chunk_to_openai_stream( parsed_chunk ) _patched_responses_chunk_parser.__name__ = "_patched_responses_chunk_parser" OpenAiResponsesToChatCompletionStreamIterator.chunk_parser = _patched_responses_chunk_parser # type: ignore[method-assign] def _patch_openai_responses_transform_response() -> None: """ Patches LiteLLMResponsesTransformationHandler.transform_response to properly concatenate multiple reasoning summary parts with newlines in non-streaming responses. 
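
The blank-line separator used here and in the streaming reasoning_summary_text.delta branch can be sketched in plain Python. This is an illustration under stated assumptions, not LiteLLM code. SummaryPart is a hypothetical stand-in for entries of ResponseReasoningItem.summary. Streamed events are modeled as (summary_index, delta) tuples rather than real event payloads.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class SummaryPart:
    # Hypothetical stand-in for one entry of ResponseReasoningItem.summary.
    text: str


def combine_reasoning_summary(parts: List[SummaryPart]) -> Optional[str]:
    # Non-streaming: drop empty parts and join the rest with a blank line.
    texts = [part.text for part in parts if part.text]
    return "\n\n".join(texts) if texts else None


def join_streamed_summary(events: List[Tuple[int, str]]) -> str:
    # Streaming: prepend a blank line whenever summary_index changes.
    pieces: List[str] = []
    last_index: Optional[int] = None
    for summary_index, delta in events:
        if last_index is not None and summary_index != last_index:
            delta = "\n\n" + delta
        last_index = summary_index
        pieces.append(delta)
    return "".join(pieces)


parts = [SummaryPart("Plan the query."), SummaryPart(""), SummaryPart("Check the results.")]
events = [(0, "Plan "), (0, "the query."), (1, "Check the results.")]
print(repr(combine_reasoning_summary(parts)))
print(repr(join_streamed_summary(events)))
```

Both helpers yield the same text. Streaming and non-streaming callers should therefore see identically separated reasoning sections.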
""" # Store the original method original_transform_response = ( LiteLLMResponsesTransformationHandler.transform_response ) if ( getattr( original_transform_response, "__name__", "", ) == "_patched_transform_response" ): return def _patched_transform_response( self: Any, model: str, raw_response: Any, model_response: Any, logging_obj: Any, request_data: dict, messages: List[Any], optional_params: dict, litellm_params: dict, encoding: Any, api_key: Optional[str] = None, json_mode: Optional[bool] = None, ) -> Any: """ Patched transform_response that properly concatenates reasoning summary parts with newlines. """ from openai.types.responses.response import Response as ResponsesAPIResponse from openai.types.responses.response_reasoning_item import ResponseReasoningItem # Check if raw_response has reasoning items that need concatenation if isinstance(raw_response, ResponsesAPIResponse) and raw_response.output: for item in raw_response.output: if isinstance(item, ResponseReasoningItem) and item.summary: # Concatenate summary texts with double newlines summary_texts = [] for summary_item in item.summary: text = getattr(summary_item, "text", "") if text: summary_texts.append(text) if len(summary_texts) > 1: # Modify the first summary item to contain all concatenated text combined_text = "\n\n".join(summary_texts) if hasattr(item.summary[0], "text"): # Create a modified copy of the response with concatenated text # Since OpenAI types are typically frozen, we need to work around this # by modifying the object after the fact or using the result pass # The fix is applied in the result processing below # Call the original method result = original_transform_response( self, model, raw_response, model_response, logging_obj, request_data, messages, optional_params, litellm_params, encoding, api_key, json_mode, ) # Post-process: If there are multiple summary items, fix the reasoning_content if isinstance(raw_response, ResponsesAPIResponse) and raw_response.output: for item in 
raw_response.output: if isinstance(item, ResponseReasoningItem) and item.summary: if len(item.summary) > 1: # Concatenate all summary texts with double newlines summary_texts = [] for summary_item in item.summary: text = getattr(summary_item, "text", "") if text: summary_texts.append(text) if summary_texts: combined_text = "\n\n".join(summary_texts) # Update the reasoning_content in the result choices if hasattr(result, "choices"): for choice in result.choices: if hasattr(choice, "message") and hasattr( choice.message, "reasoning_content" ): choice.message.reasoning_content = combined_text break # Only process the first reasoning item return result _patched_transform_response.__name__ = "_patched_transform_response" LiteLLMResponsesTransformationHandler.transform_response = _patched_transform_response # type: ignore[method-assign] def _patch_azure_responses_should_fake_stream() -> None: """ Patches AzureOpenAIResponsesAPIConfig.should_fake_stream to always return False. By default, LiteLLM uses "fake streaming" (MockResponsesAPIStreamingIterator) for models not in its database. This causes Azure custom model deployments to buffer the entire response before yielding, resulting in poor time-to-first-token. Azure's Responses API supports native streaming, so we override this to always use real streaming (SyncResponsesAPIStreamingIterator). 
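
The effect on time-to-first-token can be shown with a toy model. This is a sketch, not LiteLLM code. produce_chunks, fake_stream and native_stream are hypothetical stand-ins for a provider, for buffering in the style of MockResponsesAPIStreamingIterator, and for forwarding in the style of SyncResponsesAPIStreamingIterator.

```python
from typing import Callable, Iterator, List


def produce_chunks(log: List[str]) -> Iterator[str]:
    # Simulated provider: records each token at the moment it is produced.
    for token in ["Hello", ",", " world"]:
        log.append(token)
        yield token


def fake_stream(source: Iterator[str]) -> Iterator[str]:
    # "Fake" streaming: buffer the entire response, then replay it.
    buffered = list(source)
    yield from buffered


def native_stream(source: Iterator[str]) -> Iterator[str]:
    # Native streaming: forward each chunk as soon as it arrives.
    yield from source


def produced_before_first_chunk(
    streamer: Callable[[Iterator[str]], Iterator[str]],
) -> int:
    # How many tokens the provider had to generate before the caller saw one.
    log: List[str] = []
    next(streamer(produce_chunks(log)))
    return len(log)


print(produced_before_first_chunk(fake_stream))  # 3
print(produced_before_first_chunk(native_stream))  # 1
```

With buffering, the first visible chunk arrives only after the whole response has been generated. This is why should_fake_stream is forced to return False for Azure.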
""" from litellm.llms.azure.responses.transformation import ( AzureOpenAIResponsesAPIConfig, ) if ( getattr(AzureOpenAIResponsesAPIConfig.should_fake_stream, "__name__", "") == "_patched_should_fake_stream" ): return def _patched_should_fake_stream( self: Any, # noqa: ARG001 model: Optional[str], # noqa: ARG001 stream: Optional[bool], # noqa: ARG001 custom_llm_provider: Optional[str] = None, # noqa: ARG001 ) -> bool: # Azure Responses API supports native streaming - never fake it return False _patched_should_fake_stream.__name__ = "_patched_should_fake_stream" AzureOpenAIResponsesAPIConfig.should_fake_stream = _patched_should_fake_stream # type: ignore[method-assign] def _patch_responses_api_usage_format() -> None: """ Patches ResponsesAPIResponse.model_construct to properly transform usage data from chat completion format to Responses API format. LiteLLM uses model_construct as a fallback in multiple places when ResponsesAPIResponse validation fails. This bypasses the usage validator, allowing usage data in chat completion format (completion_tokens, prompt_tokens) to be stored instead of Responses API format (input_tokens, output_tokens), causing Pydantic serialization warnings. This patch wraps model_construct to transform usage before construction, ensuring the correct type regardless of which code path calls model_construct. 
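
The conversion performed by the patched model_construct can be sketched as a pure function. ResponsesUsage below is a hypothetical dataclass standing in for litellm's ResponseAPIUsage. Only the field mapping (prompt_tokens to input_tokens, completion_tokens to output_tokens, total_tokens unchanged) and the pass-through of unrecognized values are taken from the patch.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class ResponsesUsage:
    # Hypothetical stand-in for litellm's ResponseAPIUsage.
    input_tokens: int
    output_tokens: int
    total_tokens: int


def normalize_usage(usage: Any) -> Any:
    # Leave None, already-typed objects, and non-dicts untouched.
    if usage is None or isinstance(usage, ResponsesUsage) or not isinstance(usage, dict):
        return usage
    if "prompt_tokens" in usage or "completion_tokens" in usage:
        # Chat Completions field names -> Responses API field names.
        return ResponsesUsage(
            input_tokens=usage.get("prompt_tokens", 0),
            output_tokens=usage.get("completion_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
        )
    if "input_tokens" in usage or "output_tokens" in usage:
        # Already Responses API names; only the container type changes.
        return ResponsesUsage(
            input_tokens=usage.get("input_tokens", 0),
            output_tokens=usage.get("output_tokens", 0),
            total_tokens=usage.get("total_tokens", 0),
        )
    return usage


print(normalize_usage({"prompt_tokens": 12, "completion_tokens": 5, "total_tokens": 17}))
```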
Affected locations in LiteLLM: - litellm/llms/openai/responses/transformation.py (lines 183, 563) - litellm/llms/chatgpt/responses/transformation.py (line 153) - litellm/llms/manus/responses/transformation.py (lines 243, 334) - litellm/llms/volcengine/responses/transformation.py (line 280) - litellm/completion_extras/litellm_responses_transformation/handler.py (line 51) """ from litellm.types.llms.openai import ResponseAPIUsage, ResponsesAPIResponse original_model_construct = ResponsesAPIResponse.model_construct if getattr(original_model_construct, "_is_patched", False): return @classmethod # type: ignore[misc] def _patched_model_construct( cls: Any, _fields_set: Optional[set[str]] = None, **values: Any, ) -> "ResponsesAPIResponse": """ Patched model_construct that ensures usage is a ResponseAPIUsage object. """ # Transform usage if present and not already the correct type if "usage" in values and values["usage"] is not None: usage = values["usage"] if not isinstance(usage, ResponseAPIUsage): if isinstance(usage, dict): values = dict(values) # Don't mutate original # Check if it's in chat completion format if "prompt_tokens" in usage or "completion_tokens" in usage: # Transform from chat completion format values["usage"] = ResponseAPIUsage( input_tokens=usage.get("prompt_tokens", 0), output_tokens=usage.get("completion_tokens", 0), total_tokens=usage.get("total_tokens", 0), ) elif "input_tokens" in usage or "output_tokens" in usage: # Already in Responses API format, just convert to proper type values["usage"] = ResponseAPIUsage( input_tokens=usage.get("input_tokens", 0), output_tokens=usage.get("output_tokens", 0), total_tokens=usage.get("total_tokens", 0), ) # Call original model_construct (need to call it as unbound method) return original_model_construct.__func__(cls, _fields_set, **values) # type: ignore[attr-defined] _patched_model_construct._is_patched = True # type: ignore[attr-defined] ResponsesAPIResponse.model_construct = _patched_model_construct # type: 
ignore[method-assign, assignment] def _patch_logging_assembled_streaming_response() -> None: """ Patches LiteLLMLoggingObj._get_assembled_streaming_response to create a deep copy of the ResponsesAPIResponse before modifying its usage field. The original code transforms usage to chat completion format and sets it as a dict directly on the ResponsesAPIResponse.usage field. This mutates the original object, causing Pydantic serialization warnings when model_dump() is called later because the usage field contains a dict instead of the expected ResponseAPIUsage type. This patch creates a copy of the response before modification, preserving the original object with its proper ResponseAPIUsage type. """ from litellm import LiteLLMLoggingObj from litellm.responses.utils import ResponseAPILoggingUtils from litellm.types.llms.openai import ( ResponseAPIUsage, ResponseCompletedEvent, ResponsesAPIResponse, ) from litellm.types.utils import ModelResponse, TextCompletionResponse original_method = LiteLLMLoggingObj._get_assembled_streaming_response if getattr(original_method, "_is_patched", False): return def _patched_get_assembled_streaming_response( self: Any, # noqa: ARG001 result: Any, start_time: Any, # noqa: ARG001 end_time: Any, # noqa: ARG001 is_async: bool, # noqa: ARG001 streaming_chunks: List[Any], # noqa: ARG001 ) -> Any: """ Patched version that creates a copy before modifying usage. The original LiteLLM code transforms usage to chat completion format and sets it directly as a dict, which causes Pydantic serialization warnings. This patch uses model_construct to rebuild the response with the transformed usage, ensuring proper typing. 
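
The copy-instead-of-mutate idea can be sketched with a frozen dataclass. Response and to_chat_usage are hypothetical stand-ins, not LiteLLM types. dataclasses.replace plays the role of the model_dump()/model_construct() rebuild, and the sketch omits the conversion back to ResponseAPIUsage that the patched model_construct performs.

```python
from dataclasses import dataclass, field, replace
from typing import Dict


@dataclass(frozen=True)
class Response:
    # Hypothetical stand-in for ResponsesAPIResponse.
    id: str
    usage: Dict[str, int] = field(default_factory=dict)


def to_chat_usage(usage: Dict[str, int]) -> Dict[str, int]:
    # Responses API field names -> Chat Completions field names.
    return {
        "prompt_tokens": usage["input_tokens"],
        "completion_tokens": usage["output_tokens"],
        "total_tokens": usage["total_tokens"],
    }


def assemble_for_logging(original: Response) -> Response:
    # Build a new object with the transformed usage instead of assigning
    # to original.usage, so the caller's response keeps its original shape.
    return replace(original, usage=to_chat_usage(original.usage))


original = Response("resp_1", {"input_tokens": 3, "output_tokens": 4, "total_tokens": 7})
logged = assemble_for_logging(original)
print(original.usage)
print(logged.usage)
```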
""" if isinstance(result, ModelResponse): return result elif isinstance(result, TextCompletionResponse): return result elif isinstance(result, ResponseCompletedEvent): # Get the original response data original_response = result.response response_data = original_response.model_dump() # Transform usage if present if isinstance(original_response.usage, ResponseAPIUsage): transformed_usage = ( ResponseAPILoggingUtils._transform_response_api_usage_to_chat_usage( original_response.usage ) ) # Put the transformed usage (in chat completion format) into response_data # Our patched model_construct will convert it back to ResponseAPIUsage response_data["usage"] = ( transformed_usage.model_dump() if hasattr(transformed_usage, "model_dump") else dict(transformed_usage) ) # Rebuild using model_construct - our patch ensures usage is properly typed response_copy = ResponsesAPIResponse.model_construct(**response_data) # Copy hidden params if hasattr(original_response, "_hidden_params"): response_copy._hidden_params = dict(original_response._hidden_params) return response_copy else: return None _patched_get_assembled_streaming_response._is_patched = True # type: ignore[attr-defined] LiteLLMLoggingObj._get_assembled_streaming_response = _patched_get_assembled_streaming_response # type: ignore[method-assign] def _patch_responses_metadata_none() -> None: """ Patches litellm.responses to normalize metadata=None to metadata={} in kwargs. LiteLLM's @client decorator wrapper in utils.py (line 1721) does: _is_litellm_router_call = "model_group" in kwargs.get("metadata", {}) When metadata is explicitly None in kwargs, kwargs.get("metadata", {}) returns None (the key exists, so the default is not used), causing: TypeError: argument of type 'NoneType' is not iterable This swallows the real exception (e.g. 
AuthenticationError) and surfaces as: APIConnectionError: OpenAIException - argument of type 'NoneType' is not iterable This happens when the Responses API bridge calls litellm.responses() with **litellm_params which may contain metadata=None. STATUS: STILL NEEDED - litellm/utils.py wrapper function uses kwargs.get("metadata", {}) which does not guard against metadata being explicitly None. Same pattern exists on line 1407 for async path. """ import litellm as _litellm from functools import wraps original_responses = _litellm.responses if getattr(original_responses, "_metadata_patched", False): return @wraps(original_responses) def _patched_responses(*args: Any, **kwargs: Any) -> Any: if kwargs.get("metadata") is None: kwargs["metadata"] = {} return original_responses(*args, **kwargs) _patched_responses._metadata_patched = True # type: ignore[attr-defined] _litellm.responses = _patched_responses def apply_monkey_patches() -> None: """ Apply all necessary monkey patches to LiteLLM for compatibility. 
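
Each patch is written to be idempotent: it checks a marker (a __name__ or an _is_patched / _metadata_patched attribute) before wrapping. The sketch below shows that pattern together with the metadata=None normalization. responses and patch_metadata_none are hypothetical stand-ins. The membership test mirrors the kwargs.get("metadata", {}) pitfall described for litellm.responses.

```python
from functools import wraps
from typing import Any, Callable


def responses(**kwargs: Any) -> bool:
    # Hypothetical library entry point. When metadata=None is passed
    # explicitly, kwargs.get("metadata", {}) returns None and "in" raises.
    return "model_group" in kwargs.get("metadata", {})


def patch_metadata_none(fn: Callable[..., Any]) -> Callable[..., Any]:
    if getattr(fn, "_metadata_patched", False):
        return fn  # already wrapped: applying the patch twice is a no-op

    @wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        # Replace a missing or explicit None metadata with an empty dict.
        if kwargs.get("metadata") is None:
            kwargs["metadata"] = {}
        return fn(*args, **kwargs)

    wrapper._metadata_patched = True  # type: ignore[attr-defined]
    return wrapper


patched = patch_metadata_none(responses)
print(patched(metadata=None))
print(patch_metadata_none(patched) is patched)
```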
This includes: - Patching OllamaChatCompletionResponseIterator.chunk_parser for streaming content - Patching OpenAiResponsesToChatCompletionStreamIterator.chunk_parser for parallel tool calls and reasoning summary separators - Patching LiteLLMResponsesTransformationHandler.transform_response to join reasoning summary parts in non-streaming responses - Patching AzureOpenAIResponsesAPIConfig.should_fake_stream to enable native streaming - Patching ResponsesAPIResponse.model_construct to fix usage format in all code paths - Patching LiteLLMLoggingObj._get_assembled_streaming_response to avoid mutating original response - Patching litellm.responses to fix metadata=None causing TypeError in error handling """ _patch_ollama_chunk_parser() _patch_openai_responses_parallel_tool_calls() _patch_openai_responses_transform_response() _patch_azure_responses_should_fake_stream() _patch_responses_api_usage_format() _patch_logging_assembled_streaming_response() _patch_responses_metadata_none() ================================================ FILE: backend/onyx/llm/model_metadata_enrichments.json ================================================ { "ai21.j2-mid-v1": { "display_name": "J2 Mid", "model_vendor": "ai21", "model_version": "v1" }, "ai21.j2-ultra-v1": { "display_name": "J2 Ultra", "model_vendor": "ai21", "model_version": "v1" }, "ai21.jamba-1-5-large-v1:0": { "display_name": "Jamba 1.5 Large", "model_vendor": "ai21", "model_version": "v1:0" }, "ai21.jamba-1-5-mini-v1:0": { "display_name": "Jamba 1.5 Mini", "model_vendor": "ai21", "model_version": "v1:0" }, "ai21.jamba-instruct-v1:0": { "display_name": "Jamba Instruct", "model_vendor": "ai21", "model_version": "v1:0" }, "amazon.nova-lite-v1:0": { "display_name": "Nova Lite", "model_vendor": "amazon", "model_version": "v1:0" }, "amazon.nova-micro-v1:0": { "display_name": "Nova Micro", "model_vendor": "amazon", "model_version": "v1:0" }, "amazon.nova-pro-v1:0": { "display_name": "Nova Pro", "model_vendor": 
"amazon", "model_version": "v1:0" }, "amazon.titan-text-express-v1": { "display_name": "Titan Text
Express", "model_vendor": "amazon", "model_version": "v1" }, "amazon.titan-text-lite-v1": { "display_name": "Titan Text Lite", "model_vendor": "amazon", "model_version": "v1" }, "amazon.titan-text-premier-v1:0": { "display_name": "Titan Text Premier", "model_vendor": "amazon", "model_version": "v1:0" }, "anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620-v1:0" }, "anthropic.claude-3-5-sonnet-20241022-v2:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20241022-v2:0" }, "anthropic.claude-3-sonnet-20240229-v1:0": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic", "model_version": "20240229-v1:0" }, "anthropic.claude-haiku-4-5-20251001-v1:0": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001-v1:0" }, "anthropic.claude-haiku-4-5@20251001": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001" }, "anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "anthropic.claude-opus-4-1-20250805-v1:0": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic", "model_version": "20250805-v1:0" }, "anthropic.claude-opus-4-20250514-v1:0": { "display_name": "Claude Opus 4", "model_vendor": "anthropic", "model_version": "20250514-v1:0" }, "anthropic.claude-opus-4-5-20251101-v1:0": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic", "model_version": "20251101-v1:0" }, "anthropic.claude-sonnet-4-20250514-v1:0": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514-v1:0" }, "anthropic.claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929-v1:0" }, "anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, 
"anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "apac.amazon.nova-lite-v1:0": { "display_name": "Nova Lite", "model_vendor": "amazon", "model_version": "v1:0" }, "apac.amazon.nova-micro-v1:0": { "display_name": "Nova Micro", "model_vendor": "amazon", "model_version": "v1:0" }, "apac.amazon.nova-pro-v1:0": { "display_name": "Nova Pro", "model_vendor": "amazon", "model_version": "v1:0" }, "apac.anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620-v1:0" }, "apac.anthropic.claude-3-5-sonnet-20241022-v2:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20241022-v2:0" }, "apac.anthropic.claude-3-sonnet-20240229-v1:0": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic", "model_version": "20240229-v1:0" }, "apac.anthropic.claude-haiku-4-5-20251001-v1:0": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001-v1:0" }, "apac.anthropic.claude-sonnet-4-20250514-v1:0": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514-v1:0" }, "au.anthropic.claude-haiku-4-5-20251001-v1:0": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001-v1:0" }, "au.anthropic.claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929-v1:0" }, "azure/claude-haiku-4-5": { "display_name": "Claude Haiku", "model_vendor": "anthropic" }, "azure/claude-opus-4-1": { "display_name": "Claude Opus", "model_vendor": "anthropic" }, "azure/claude-sonnet-4-5": { "display_name": "Claude Sonnet", "model_vendor": "anthropic" }, "azure/codex-mini": { "display_name": "Codex Mini", "model_vendor": "openai" }, "azure/command-r-plus": { "display_name": "Command R Plus", "model_vendor": "cohere", "model_version": "latest" }, 
"azure/computer-use-preview": { "display_name": "Computer Use Preview", "model_vendor": "openai", "model_version": "preview" }, "azure/container": { "display_name": "Container", "model_vendor": "azure", "model_version": "latest" }, "azure/eu/gpt-4o-2024-08-06": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-08-06" }, "azure/eu/gpt-4o-2024-11-20": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-11-20" }, "azure/eu/gpt-4o-mini-2024-07-18": { "display_name": "GPT-4o Mini", "model_vendor": "openai", "model_version": "2024-07-18" }, "azure/eu/gpt-4o-mini-realtime-preview-2024-12-17": { "display_name": "GPT-4o Mini Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/eu/gpt-4o-realtime-preview-2024-10-01": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-10-01" }, "azure/eu/gpt-4o-realtime-preview-2024-12-17": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/eu/gpt-5-2025-08-07": { "display_name": "GPT-5", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/eu/gpt-5-mini-2025-08-07": { "display_name": "GPT-5 Mini", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/eu/gpt-5-nano-2025-08-07": { "display_name": "GPT 5 Nano", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/eu/gpt-5.1": { "display_name": "GPT-5.1", "model_vendor": "openai" }, "azure/eu/gpt-5.1-chat": { "display_name": "GPT-5.1 Chat", "model_vendor": "openai" }, "azure/eu/gpt-5.1-codex": { "display_name": "GPT-5.1 Codex", "model_vendor": "openai" }, "azure/eu/gpt-5.1-codex-mini": { "display_name": "GPT-5.1 Codex Mini", "model_vendor": "openai" }, "azure/eu/o1-2024-12-17": { "display_name": "o1", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/eu/o1-mini-2024-09-12": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version":
"2024-09-12" }, "azure/eu/o1-preview-2024-09-12": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "2024-09-12" }, "azure/eu/o3-mini-2025-01-31": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "2025-01-31" }, "azure/global-standard/gpt-4o-2024-08-06": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-08-06" }, "azure/global-standard/gpt-4o-2024-11-20": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-11-20" }, "azure/global-standard/gpt-4o-mini": { "display_name": "GPT-4o Mini", "model_vendor": "openai" }, "azure/global/gpt-4o-2024-08-06": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-08-06" }, "azure/global/gpt-4o-2024-11-20": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-11-20" }, "azure/global/gpt-5.1": { "display_name": "GPT-5.1", "model_vendor": "openai" }, "azure/global/gpt-5.1-chat": { "display_name": "GPT-5.1 Chat", "model_vendor": "openai" }, "azure/global/gpt-5.1-codex": { "display_name": "GPT-5.1 Codex", "model_vendor": "openai" }, "azure/global/gpt-5.1-codex-mini": { "display_name": "GPT-5.1 Codex Mini", "model_vendor": "openai" }, "azure/gpt-3.5-turbo": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai" }, "azure/gpt-3.5-turbo-0125": { "display_name": "GPT 3.5 Turbo 0125", "model_vendor": "openai", "model_version": "0125" }, "azure/gpt-3.5-turbo-instruct-0914": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai", "model_version": "0914" }, "azure/gpt-35-turbo": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai" }, "azure/gpt-35-turbo-0125": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai", "model_version": "0125" }, "azure/gpt-35-turbo-0301": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai", "model_version": "0301" }, "azure/gpt-35-turbo-0613": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai", "model_version": "0613" }, 
"azure/gpt-35-turbo-1106": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai", "model_version": "1106" }, "azure/gpt-35-turbo-16k": { "display_name": "GPT-3.5 Turbo 16K", "model_vendor": "openai" }, "azure/gpt-35-turbo-16k-0613": { "display_name": "GPT-3.5 Turbo 16K", "model_vendor": "openai", "model_version": "0613" }, "azure/gpt-35-turbo-instruct": { "display_name": "GPT-3.5 Turbo Instruct", "model_vendor": "openai" }, "azure/gpt-35-turbo-instruct-0914": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai", "model_version": "0914" }, "azure/gpt-4": { "display_name": "GPT-4", "model_vendor": "openai" }, "azure/gpt-4-0125-preview": { "display_name": "GPT 4 0125 Preview", "model_vendor": "openai", "model_version": "0125" }, "azure/gpt-4-0613": { "display_name": "GPT 4 0613", "model_vendor": "openai", "model_version": "0613" }, "azure/gpt-4-1106-preview": { "display_name": "GPT 4 1106 Preview", "model_vendor": "openai", "model_version": "1106" }, "azure/gpt-4-32k": { "display_name": "GPT-4 32K", "model_vendor": "openai" }, "azure/gpt-4-32k-0613": { "display_name": "GPT 4 32k 0613", "model_vendor": "openai", "model_version": "0613" }, "azure/gpt-4-turbo": { "display_name": "GPT-4 Turbo", "model_vendor": "openai" }, "azure/gpt-4-turbo-2024-04-09": { "display_name": "GPT-4 Turbo", "model_vendor": "openai", "model_version": "2024-04-09" }, "azure/gpt-4-turbo-vision-preview": { "display_name": "GPT-4 Turbo Vision Preview", "model_vendor": "openai" }, "azure/gpt-4.1": { "display_name": "GPT-4.1", "model_vendor": "openai" }, "azure/gpt-4.1-2025-04-14": { "display_name": "GPT-4.1", "model_vendor": "openai", "model_version": "2025-04-14" }, "azure/gpt-4.1-mini": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai" }, "azure/gpt-4.1-mini-2025-04-14": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai", "model_version": "2025-04-14" }, "azure/gpt-4.1-nano": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai" }, 
"azure/gpt-4.1-nano-2025-04-14": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai", "model_version": "2025-04-14" }, "azure/gpt-4.5-preview": { "display_name": "GPT-4.5 Preview", "model_vendor": "openai" }, "azure/gpt-4o": { "display_name": "GPT-4o", "model_vendor": "openai" }, "azure/gpt-4o-2024-05-13": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-05-13" }, "azure/gpt-4o-2024-08-06": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-08-06" }, "azure/gpt-4o-2024-11-20": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-11-20" }, "azure/gpt-4o-audio-preview-2024-12-17": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/gpt-4o-mini": { "display_name": "GPT-4o Mini", "model_vendor": "openai" }, "azure/gpt-4o-mini-2024-07-18": { "display_name": "GPT-4o Mini", "model_vendor": "openai", "model_version": "2024-07-18" }, "azure/gpt-4o-mini-audio-preview-2024-12-17": { "display_name": "GPT-4o Mini", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/gpt-4o-mini-realtime-preview-2024-12-17": { "display_name": "GPT-4o Mini Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/gpt-4o-mini-transcribe": { "display_name": "GPT-4o Mini Transcribe", "model_vendor": "openai" }, "azure/gpt-4o-mini-tts": { "display_name": "GPT-4o Mini TTS", "model_vendor": "openai" }, "azure/gpt-4o-realtime-preview-2024-10-01": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-10-01" }, "azure/gpt-4o-realtime-preview-2024-12-17": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/gpt-4o-transcribe": { "display_name": "GPT-4o Transcribe", "model_vendor": "openai" }, "azure/gpt-4o-transcribe-diarize": { "display_name": "GPT-4o Transcribe Diarize", "model_vendor": "openai" }, "azure/gpt-5": { "display_name": "GPT-5", 
"model_vendor": "openai" }, "azure/gpt-5-2025-08-07": { "display_name": "GPT-5", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/gpt-5-chat": { "display_name": "GPT-5 Chat", "model_vendor": "openai" }, "azure/gpt-5-chat-latest": { "display_name": "GPT 5 Chat", "model_vendor": "openai", "model_version": "latest" }, "azure/gpt-5-codex": { "display_name": "GPT-5 Codex", "model_vendor": "openai" }, "azure/gpt-5-mini": { "display_name": "GPT-5 Mini", "model_vendor": "openai" }, "azure/gpt-5-mini-2025-08-07": { "display_name": "GPT-5 Mini", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/gpt-5-nano": { "display_name": "GPT-5 Nano", "model_vendor": "openai" }, "azure/gpt-5-nano-2025-08-07": { "display_name": "GPT 5 Nano", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/gpt-5-pro": { "display_name": "GPT-5 Pro", "model_vendor": "openai" }, "azure/gpt-5.1": { "display_name": "GPT-5.1", "model_vendor": "openai" }, "azure/gpt-5.1-2025-11-13": { "display_name": "GPT 5.1", "model_vendor": "openai", "model_version": "2025-11-13" }, "azure/gpt-5.1-chat": { "display_name": "GPT-5.1 Chat", "model_vendor": "openai" }, "azure/gpt-5.1-chat-2025-11-13": { "display_name": "GPT 5.1 Chat", "model_vendor": "openai", "model_version": "2025-11-13" }, "azure/gpt-5.1-codex": { "display_name": "GPT-5.1 Codex", "model_vendor": "openai" }, "azure/gpt-5.1-codex-2025-11-13": { "display_name": "GPT-5.1 Codex", "model_vendor": "openai", "model_version": "2025-11-13" }, "azure/gpt-5.1-codex-mini": { "display_name": "GPT-5.1 Codex Mini", "model_vendor": "openai" }, "azure/gpt-5.1-codex-mini-2025-11-13": { "display_name": "GPT-5.1 Codex Mini", "model_vendor": "openai", "model_version": "2025-11-13" }, "azure/gpt-audio-2025-08-28": { "display_name": "GPT Audio", "model_vendor": "openai", "model_version": "2025-08-28" }, "azure/gpt-audio-mini-2025-10-06": { "display_name": "GPT Audio Mini", "model_vendor": "openai", "model_version": "2025-10-06" }, 
"azure/gpt-realtime-2025-08-28": { "display_name": "GPT Realtime", "model_vendor": "openai", "model_version": "2025-08-28" }, "azure/gpt-realtime-mini-2025-10-06": { "display_name": "GPT Realtime Mini", "model_vendor": "openai", "model_version": "2025-10-06" }, "azure/mistral-large-2402": { "display_name": "Mistral Large 24.02", "model_vendor": "mistral", "model_version": "2402" }, "azure/mistral-large-latest": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "azure/o1": { "display_name": "o1", "model_vendor": "openai", "model_version": "latest" }, "azure/o1-2024-12-17": { "display_name": "o1", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/o1-mini": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "latest" }, "azure/o1-mini-2024-09-12": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "2024-09-12" }, "azure/o1-preview": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "latest" }, "azure/o1-preview-2024-09-12": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "2024-09-12" }, "azure/o3": { "display_name": "o3", "model_vendor": "openai", "model_version": "latest" }, "azure/o3-2025-04-16": { "display_name": "o3", "model_vendor": "openai", "model_version": "2025-04-16" }, "azure/o3-deep-research": { "display_name": "o3 Deep Research", "model_vendor": "openai", "model_version": "latest" }, "azure/o3-mini": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "latest" }, "azure/o3-mini-2025-01-31": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "2025-01-31" }, "azure/o3-pro": { "display_name": "o3 Pro", "model_vendor": "openai", "model_version": "latest" }, "azure/o3-pro-2025-06-10": { "display_name": "o3 Pro", "model_vendor": "openai", "model_version": "2025-06-10" }, "azure/o4-mini": { "display_name": "o4 Mini", "model_vendor": "openai", "model_version": "latest" },
"azure/o4-mini-2025-04-16": { "display_name": "o4 Mini", "model_vendor": "openai", "model_version": "2025-04-16" }, "azure/us/gpt-4.1-2025-04-14": { "display_name": "GPT-4.1", "model_vendor": "openai", "model_version": "2025-04-14" }, "azure/us/gpt-4.1-mini-2025-04-14": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai", "model_version": "2025-04-14" }, "azure/us/gpt-4.1-nano-2025-04-14": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai", "model_version": "2025-04-14" }, "azure/us/gpt-4o-2024-08-06": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-08-06" }, "azure/us/gpt-4o-2024-11-20": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-11-20" }, "azure/us/gpt-4o-mini-2024-07-18": { "display_name": "GPT-4o Mini", "model_vendor": "openai", "model_version": "2024-07-18" }, "azure/us/gpt-4o-mini-realtime-preview-2024-12-17": { "display_name": "GPT-4o Mini Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/us/gpt-4o-realtime-preview-2024-10-01": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-10-01" }, "azure/us/gpt-4o-realtime-preview-2024-12-17": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/us/gpt-5-2025-08-07": { "display_name": "GPT-5", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/us/gpt-5-mini-2025-08-07": { "display_name": "GPT-5 Mini", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/us/gpt-5-nano-2025-08-07": { "display_name": "GPT 5 Nano", "model_vendor": "openai", "model_version": "2025-08-07" }, "azure/us/gpt-5.1": { "display_name": "GPT-5.1", "model_vendor": "openai" }, "azure/us/gpt-5.1-chat": { "display_name": "GPT-5.1 Chat", "model_vendor": "openai" }, "azure/us/gpt-5.1-codex": { "display_name": "GPT-5.1 Codex", "model_vendor": "openai" }, "azure/us/gpt-5.1-codex-mini": { "display_name": "GPT-5.1 Codex 
Mini", "model_vendor": "openai" }, "azure/us/o1-2024-12-17": { "display_name": "o1", "model_vendor": "openai", "model_version": "2024-12-17" }, "azure/us/o1-mini-2024-09-12": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "2024-09-12" }, "azure/us/o1-preview-2024-09-12": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "2024-09-12" }, "azure/us/o3-2025-04-16": { "display_name": "o3", "model_vendor": "openai", "model_version": "2025-04-16" }, "azure/us/o3-mini-2025-01-31": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "2025-01-31" }, "azure/us/o4-mini-2025-04-16": { "display_name": "o4 Mini", "model_vendor": "openai", "model_version": "2025-04-16" }, "azure_ai/Llama-3.2-11B-Vision-Instruct": { "display_name": "Llama 3.2 11B Vision Instruct", "model_vendor": "meta" }, "azure_ai/Llama-3.2-90B-Vision-Instruct": { "display_name": "Llama 3.2 90B Vision Instruct", "model_vendor": "meta" }, "azure_ai/Llama-3.3-70B-Instruct": { "display_name": "Llama 3.3 70B Instruct", "model_vendor": "meta" }, "azure_ai/Llama-4-Maverick-17B-128E-Instruct-FP8": { "display_name": "Llama 4 Maverick 17B 128E Instruct FP8", "model_vendor": "meta" }, "azure_ai/Llama-4-Scout-17B-16E-Instruct": { "display_name": "Llama 4 Scout 17B 16E Instruct", "model_vendor": "meta" }, "azure_ai/MAI-DS-R1": { "display_name": "MAI-DS-R1", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Meta-Llama-3-70B-Instruct": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta" }, "azure_ai/Meta-Llama-3.1-405B-Instruct": { "display_name": "Llama 3.1 405B Instruct", "model_vendor": "meta" }, "azure_ai/Meta-Llama-3.1-70B-Instruct": { "display_name": "Llama 3.1 70B Instruct", "model_vendor": "meta" }, "azure_ai/Meta-Llama-3.1-8B-Instruct": { "display_name": "Llama 3.1 8B Instruct", "model_vendor": "meta" }, "azure_ai/Phi-3-medium-128k-instruct": { "display_name": "Phi 3 Medium 128k Instruct", "model_vendor": "microsoft", 
"model_version": "latest" }, "azure_ai/Phi-3-medium-4k-instruct": { "display_name": "Phi 3 Medium 4k Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3-mini-128k-instruct": { "display_name": "Phi 3 Mini 128k Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3-mini-4k-instruct": { "display_name": "Phi 3 Mini 4k Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3-small-128k-instruct": { "display_name": "Phi 3 Small 128k Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3-small-8k-instruct": { "display_name": "Phi 3 Small 8k Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3.5-MoE-instruct": { "display_name": "Phi 3.5 MoE Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3.5-mini-instruct": { "display_name": "Phi 3.5 Mini Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-3.5-vision-instruct": { "display_name": "Phi 3.5 Vision Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-4": { "display_name": "Phi 4", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-4-mini-instruct": { "display_name": "Phi 4 Mini Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-4-mini-reasoning": { "display_name": "Phi 4 Mini Reasoning", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-4-multimodal-instruct": { "display_name": "Phi 4 Multimodal Instruct", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/Phi-4-reasoning": { "display_name": "Phi 4 Reasoning", "model_vendor": "microsoft", "model_version": "latest" }, "azure_ai/deepseek-r1": { "display_name": "DeepSeek R1", "model_vendor": "deepseek", "model_version": "latest" }, "azure_ai/deepseek-v3": { "display_name": "DeepSeek V3", "model_vendor": "deepseek", "model_version": "v3" }, 
"azure_ai/deepseek-v3-0324": { "display_name": "DeepSeek V3 0324", "model_vendor": "deepseek", "model_version": "0324" }, "azure_ai/global/grok-3": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/global/grok-3-mini": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/grok-3": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/grok-3-mini": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/grok-4": { "display_name": "Grok 4", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/grok-4-fast-non-reasoning": { "display_name": "Grok 4 Fast Non Reasoning", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/grok-4-fast-reasoning": { "display_name": "Grok 4 Fast Reasoning", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/grok-code-fast-1": { "display_name": "Grok Code Fast 1", "model_vendor": "xai", "model_version": "latest" }, "azure_ai/jais-30b-chat": { "display_name": "Jais 30B Chat", "model_vendor": "g42", "model_version": "latest" }, "azure_ai/jamba-instruct": { "display_name": "Jamba Instruct", "model_vendor": "ai21", "model_version": "latest" }, "azure_ai/ministral-3b": { "display_name": "Ministral 3B", "model_vendor": "mistral", "model_version": "latest" }, "azure_ai/mistral-large": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "azure_ai/mistral-large-2407": { "display_name": "Mistral Large 24.07", "model_vendor": "mistral", "model_version": "2407" }, "azure_ai/mistral-large-latest": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "azure_ai/mistral-medium-2505": { "display_name": "Mistral Medium 2505", "model_vendor": "mistral", "model_version": "2505" }, "azure_ai/mistral-nemo": { "display_name": "Mistral Nemo", "model_vendor": "mistral", "model_version": "latest" }, 
"azure_ai/mistral-small": { "display_name": "Mistral Small", "model_vendor": "mistral", "model_version": "latest" }, "azure_ai/mistral-small-2503": { "display_name": "Mistral Small 2503", "model_vendor": "mistral", "model_version": "2503" }, "bedrock/*/1-month-commitment/cohere.command-light-text-v14": { "display_name": "Command Light Text", "model_vendor": "cohere", "model_version": "v14" }, "bedrock/*/1-month-commitment/cohere.command-text-v14": { "display_name": "Command Text", "model_vendor": "cohere", "model_version": "v14" }, "bedrock/*/6-month-commitment/cohere.command-light-text-v14": { "display_name": "Command Light Text", "model_vendor": "cohere", "model_version": "v14" }, "bedrock/*/6-month-commitment/cohere.command-text-v14": { "display_name": "Command Text", "model_vendor": "cohere", "model_version": "v14" }, "bedrock/ap-northeast-1/1-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/ap-northeast-1/1-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/ap-northeast-1/1-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/ap-northeast-1/6-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/ap-northeast-1/6-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/ap-northeast-1/6-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/ap-northeast-1/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/ap-northeast-1/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": 
"anthropic", "model_version": "v1" }, "bedrock/ap-northeast-1/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/ap-south-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/ap-south-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/ca-central-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/ca-central-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/eu-central-1/1-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/eu-central-1/1-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/eu-central-1/1-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/eu-central-1/6-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/eu-central-1/6-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/eu-central-1/6-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/eu-central-1/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/eu-central-1/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/eu-central-1/anthropic.claude-v2:1": { "display_name": 
"Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/eu-west-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/eu-west-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/eu-west-2/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/eu-west-2/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/eu-west-3/mistral.mistral-7b-instruct-v0:2": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "v0:2" }, "bedrock/eu-west-3/mistral.mistral-large-2402-v1:0": { "display_name": "Mistral Large 24.02", "model_vendor": "mistral", "model_version": "2402-v1:0" }, "bedrock/eu-west-3/mistral.mixtral-8x7b-instruct-v0:1": { "display_name": "Mixtral 8x7B Instruct", "model_vendor": "mistral", "model_version": "v0:1" }, "bedrock/invoke/anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620-v1:0" }, "bedrock/sa-east-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/sa-east-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-east-1/1-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-east-1/1-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-east-1/1-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, 
"bedrock/us-east-1/6-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-east-1/6-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-east-1/6-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/us-east-1/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-east-1/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-east-1/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/us-east-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-east-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-east-1/mistral.mistral-7b-instruct-v0:2": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "v0:2" }, "bedrock/us-east-1/mistral.mistral-large-2402-v1:0": { "display_name": "Mistral Large 24.02", "model_vendor": "mistral", "model_version": "2402-v1:0" }, "bedrock/us-east-1/mistral.mixtral-8x7b-instruct-v0:1": { "display_name": "Mixtral 8x7B Instruct", "model_vendor": "mistral", "model_version": "v0:1" }, "bedrock/us-gov-east-1/amazon.nova-pro-v1:0": { "display_name": "Nova Pro", "model_vendor": "amazon", "model_version": "v1:0" }, "bedrock/us-gov-east-1/amazon.titan-text-express-v1": { "display_name": "Titan Text Express", "model_vendor": "amazon", "model_version": "v1" }, "bedrock/us-gov-east-1/amazon.titan-text-lite-v1": { "display_name": "Titan Text Lite", "model_vendor": "amazon", "model_version": "v1" }, 
"bedrock/us-gov-east-1/amazon.titan-text-premier-v1:0": { "display_name": "Titan Text Premier", "model_vendor": "amazon", "model_version": "v1:0" }, "bedrock/us-gov-east-1/anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620-v1:0" }, "bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929-v1:0" }, "bedrock/us-gov-east-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-gov-east-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-gov-west-1/amazon.nova-pro-v1:0": { "display_name": "Nova Pro", "model_vendor": "amazon", "model_version": "v1:0" }, "bedrock/us-gov-west-1/amazon.titan-text-express-v1": { "display_name": "Titan Text Express", "model_vendor": "amazon", "model_version": "v1" }, "bedrock/us-gov-west-1/amazon.titan-text-lite-v1": { "display_name": "Titan Text Lite", "model_vendor": "amazon", "model_version": "v1" }, "bedrock/us-gov-west-1/amazon.titan-text-premier-v1:0": { "display_name": "Titan Text Premier", "model_vendor": "amazon", "model_version": "v1:0" }, "bedrock/us-gov-west-1/anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620-v1:0" }, "bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929-v1:0" }, "bedrock/us-gov-west-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-gov-west-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, 
"bedrock/us-west-1/meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-west-1/meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "bedrock/us-west-2/1-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-west-2/1-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-west-2/1-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/us-west-2/6-month-commitment/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-west-2/6-month-commitment/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-west-2/6-month-commitment/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/us-west-2/anthropic.claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-west-2/anthropic.claude-v1": { "display_name": "Claude", "model_vendor": "anthropic", "model_version": "v1" }, "bedrock/us-west-2/anthropic.claude-v2:1": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "v2:1" }, "bedrock/us-west-2/mistral.mistral-7b-instruct-v0:2": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "v0:2" }, "bedrock/us-west-2/mistral.mistral-large-2402-v1:0": { "display_name": "Mistral Large 24.02", "model_vendor": "mistral", "model_version": "2402-v1:0" }, "bedrock/us-west-2/mistral.mixtral-8x7b-instruct-v0:1": { "display_name": "Mixtral 8x7B Instruct", "model_vendor": 
"mistral", "model_version": "v0:1" }, "chat-bison": { "display_name": "Chat Bison", "model_vendor": "google", "model_version": "latest" }, "chat-bison-32k": { "display_name": "Chat Bison 32k", "model_vendor": "google", "model_version": "latest" }, "chat-bison-32k@002": { "display_name": "Chat Bison 32k", "model_vendor": "google", "model_version": "002" }, "chat-bison@001": { "display_name": "Chat Bison", "model_vendor": "google", "model_version": "001" }, "chat-bison@002": { "display_name": "Chat Bison", "model_vendor": "google", "model_version": "002" }, "chatgpt-4o-latest": { "display_name": "ChatGPT 4o", "model_vendor": "openai", "model_version": "latest" }, "claude-3-5-sonnet-20240620": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620" }, "claude-3-5-sonnet-20241022": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20241022" }, "claude-3-5-sonnet-latest": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "latest" }, "claude-4-opus-20250514": { "display_name": "Claude Opus 4", "model_vendor": "anthropic", "model_version": "20250514" }, "claude-4-sonnet-20250514": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514" }, "claude-haiku-4-5": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic" }, "claude-haiku-4-5-20251001": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001" }, "claude-opus-4-1": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic" }, "claude-opus-4-1-20250805": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic", "model_version": "20250805" }, "claude-opus-4-1@20250805": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic", "model_version": "20250805" }, "claude-opus-4-20250514": { "display_name": "Claude Opus 4", "model_vendor": "anthropic", "model_version": "20250514" }, "claude-opus-4-5": { 
"display_name": "Claude Opus 4.5", "model_vendor": "anthropic" }, "claude-opus-4-6": { "display_name": "Claude Opus 4.6", "model_vendor": "anthropic" }, "claude-opus-4-5-20251101": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic", "model_version": "20251101" }, "claude-sonnet-4-20250514": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514" }, "claude-sonnet-4-5": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic" }, "claude-sonnet-4-6": { "display_name": "Claude Sonnet 4.6", "model_vendor": "anthropic" }, "claude-sonnet-4-5-20250929": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929" }, "claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929-v1:0" }, "codechat-bison": { "display_name": "Codechat Bison", "model_vendor": "google", "model_version": "latest" }, "codechat-bison-32k": { "display_name": "Codechat Bison 32k", "model_vendor": "google", "model_version": "latest" }, "codechat-bison-32k@002": { "display_name": "Codechat Bison 32k", "model_vendor": "google", "model_version": "002" }, "codechat-bison@001": { "display_name": "Codechat Bison", "model_vendor": "google", "model_version": "001" }, "codechat-bison@002": { "display_name": "Codechat Bison", "model_vendor": "google", "model_version": "002" }, "codechat-bison@latest": { "display_name": "Codechat Bison", "model_vendor": "google", "model_version": "latest" }, "codex-mini-latest": { "display_name": "Codex Mini", "model_vendor": "openai" }, "cohere.command-light-text-v14": { "display_name": "Command Light Text", "model_vendor": "cohere", "model_version": "v14" }, "cohere.command-r-plus-v1:0": { "display_name": "Command R Plus", "model_vendor": "cohere", "model_version": "v1:0" }, "cohere.command-r-v1:0": { "display_name": "Command R", "model_vendor": "cohere", "model_version": "v1:0" }, "cohere.command-text-v14": { 
"display_name": "Command Text", "model_vendor": "cohere", "model_version": "v14" }, "computer-use-preview": { "display_name": "Computer Use Preview", "model_vendor": "openai", "model_version": "preview" }, "deepseek.v3-v1:0": { "display_name": "DeepSeek V3", "model_vendor": "deepseek", "model_version": "v1:0" }, "deepseek/deepseek-chat": { "display_name": "DeepSeek Chat", "model_vendor": "deepseek", "model_version": "latest" }, "deepseek/deepseek-coder": { "display_name": "DeepSeek Coder", "model_vendor": "deepseek", "model_version": "latest" }, "deepseek/deepseek-r1": { "display_name": "DeepSeek R1", "model_vendor": "deepseek", "model_version": "latest" }, "deepseek/deepseek-reasoner": { "display_name": "DeepSeek Reasoner", "model_vendor": "deepseek", "model_version": "latest" }, "deepseek/deepseek-v3": { "display_name": "DeepSeek V3", "model_vendor": "deepseek", "model_version": "v3" }, "eu.amazon.nova-lite-v1:0": { "display_name": "Nova Lite", "model_vendor": "amazon", "model_version": "v1:0" }, "eu.amazon.nova-micro-v1:0": { "display_name": "Nova Micro", "model_vendor": "amazon", "model_version": "v1:0" }, "eu.amazon.nova-pro-v1:0": { "display_name": "Nova Pro", "model_vendor": "amazon", "model_version": "v1:0" }, "eu.anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620-v1:0" }, "eu.anthropic.claude-3-5-sonnet-20241022-v2:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20241022-v2:0" }, "eu.anthropic.claude-3-sonnet-20240229-v1:0": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic", "model_version": "20240229-v1:0" }, "eu.anthropic.claude-haiku-4-5-20251001-v1:0": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001-v1:0" }, "eu.anthropic.claude-opus-4-1-20250805-v1:0": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic", "model_version": "20250805-v1:0" }, 
"eu.anthropic.claude-opus-4-20250514-v1:0": { "display_name": "Claude Opus 4", "model_vendor": "anthropic", "model_version": "20250514-v1:0" }, "eu.anthropic.claude-sonnet-4-20250514-v1:0": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514-v1:0" }, "eu.anthropic.claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929-v1:0" }, "eu.meta.llama3-2-1b-instruct-v1:0": { "display_name": "Llama 3.2 1B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "eu.meta.llama3-2-3b-instruct-v1:0": { "display_name": "Llama 3.2 3B Instruct", "model_vendor": "meta", "model_version": "v1:0" }, "eu.mistral.pixtral-large-2502-v1:0": { "display_name": "Pixtral Large 25.02", "model_vendor": "mistral", "model_version": "2502-v1:0" }, "eu.twelvelabs.pegasus-1-2-v1:0": { "display_name": "Pegasus 1.2", "model_vendor": "twelvelabs", "model_version": "1.2-v1:0" }, "ft:gpt-3.5-turbo": { "display_name": "GPT-3.5 Turbo (Fine-tuned)", "model_vendor": "openai" }, "ft:gpt-3.5-turbo-0125": { "display_name": "GPT-3.5 Turbo (Fine-tuned)", "model_vendor": "openai", "model_version": "0125" }, "ft:gpt-3.5-turbo-0613": { "display_name": "GPT-3.5 Turbo (Fine-tuned)", "model_vendor": "openai", "model_version": "0613" }, "ft:gpt-3.5-turbo-1106": { "display_name": "GPT-3.5 Turbo (Fine-tuned)", "model_vendor": "openai", "model_version": "1106" }, "ft:gpt-4-0613": { "display_name": "GPT-4 (Fine-tuned)", "model_vendor": "openai", "model_version": "0613" }, "ft:gpt-4o-2024-08-06": { "display_name": "GPT-4o (Fine-tuned)", "model_vendor": "openai", "model_version": "2024-08-06" }, "ft:gpt-4o-2024-11-20": { "display_name": "GPT-4o (Fine-tuned)", "model_vendor": "openai", "model_version": "2024-11-20" }, "ft:gpt-4o-mini-2024-07-18": { "display_name": "GPT-4o Mini (Fine-tuned)", "model_vendor": "openai", "model_version": "2024-07-18" }, "gemini-1.0-pro": { "display_name": "Gemini 1.0 Pro", "model_vendor": 
"google" }, "gemini-1.0-pro-001": { "display_name": "Gemini 1.0 Pro 001", "model_vendor": "google", "model_version": "001" }, "gemini-1.0-pro-002": { "display_name": "Gemini 1.0 Pro 002", "model_vendor": "google", "model_version": "002" }, "gemini-1.0-ultra": { "display_name": "Gemini 1.0 Ultra", "model_vendor": "google" }, "gemini-1.0-ultra-001": { "display_name": "Gemini 1.0 Ultra 001", "model_vendor": "google", "model_version": "001" }, "gemini-1.5-flash": { "display_name": "Gemini 1.5 Flash", "model_vendor": "google" }, "gemini-1.5-flash-001": { "display_name": "Gemini 1.5 Flash 001", "model_vendor": "google", "model_version": "001" }, "gemini-1.5-flash-002": { "display_name": "Gemini 1.5 Flash 002", "model_vendor": "google", "model_version": "002" }, "gemini-1.5-flash-exp-0827": { "display_name": "Gemini 1.5 Flash Exp 0827", "model_vendor": "google", "model_version": "0827" }, "gemini-1.5-flash-preview-0514": { "display_name": "Gemini 1.5 Flash Preview 0514", "model_vendor": "google", "model_version": "0514" }, "gemini-1.5-pro": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google" }, "gemini-1.5-pro-001": { "display_name": "Gemini 1.5 Pro 001", "model_vendor": "google", "model_version": "001" }, "gemini-1.5-pro-002": { "display_name": "Gemini 1.5 Pro 002", "model_vendor": "google", "model_version": "002" }, "gemini-1.5-pro-preview-0215": { "display_name": "Gemini 1.5 Pro Preview 0215", "model_vendor": "google", "model_version": "0215" }, "gemini-1.5-pro-preview-0409": { "display_name": "Gemini 1.5 Pro Preview 0409", "model_vendor": "google", "model_version": "0409" }, "gemini-1.5-pro-preview-0514": { "display_name": "Gemini 1.5 Pro Preview 0514", "model_vendor": "google", "model_version": "0514" }, "gemini-2.0-flash": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini-2.0-flash-001": { "display_name": "Gemini 2.0 Flash 001", "model_vendor": "google", "model_version": "001" }, "gemini-2.0-flash-exp": { "display_name": "Gemini 2.0 
Flash", "model_vendor": "google" }, "gemini-2.0-flash-lite": { "display_name": "Gemini 2.0 Flash Lite", "model_vendor": "google" }, "gemini-2.0-flash-lite-001": { "display_name": "Gemini 2.0 Flash Lite 001", "model_vendor": "google", "model_version": "001" }, "gemini-2.0-flash-live-preview-04-09": { "display_name": "Gemini 2.0 Flash Live Preview 04 09", "model_vendor": "google" }, "gemini-2.0-flash-thinking-exp": { "display_name": "Gemini 2.0 Flash Thinking", "model_vendor": "google" }, "gemini-2.0-flash-thinking-exp-01-21": { "display_name": "Gemini 2.0 Flash Thinking Exp 01 21", "model_vendor": "google" }, "gemini-2.0-pro-exp-02-05": { "display_name": "Gemini 2.0 Pro Exp 02 05", "model_vendor": "google" }, "gemini-2.5-flash": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini-2.5-flash-lite": { "display_name": "Gemini 2.5 Flash Lite", "model_vendor": "google" }, "gemini-2.5-flash-lite-preview-06-17": { "display_name": "Gemini 2.5 Flash Lite Preview 06 17", "model_vendor": "google" }, "gemini-2.5-flash-lite-preview-09-2025": { "display_name": "Gemini 2.5 Flash Lite Preview 09 2025", "model_vendor": "google", "model_version": "2025" }, "gemini-2.5-flash-preview-04-17": { "display_name": "Gemini 2.5 Flash Preview 04 17", "model_vendor": "google" }, "gemini-2.5-flash-preview-05-20": { "display_name": "Gemini 2.5 Flash Preview 05 20", "model_vendor": "google" }, "gemini-2.5-flash-preview-09-2025": { "display_name": "Gemini 2.5 Flash Preview 09 2025", "model_vendor": "google", "model_version": "2025" }, "gemini-2.5-pro": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini-2.5-pro-exp-03-25": { "display_name": "Gemini 2.5 Pro Exp 03 25", "model_vendor": "google" }, "gemini-2.5-pro-preview-03-25": { "display_name": "Gemini 2.5 Pro Preview 03 25", "model_vendor": "google" }, "gemini-2.5-pro-preview-05-06": { "display_name": "Gemini 2.5 Pro Preview 05 06", "model_vendor": "google" }, "gemini-2.5-pro-preview-06-05": { 
"display_name": "Gemini 2.5 Pro Preview 06 05", "model_vendor": "google" }, "gemini-3-pro-preview": { "display_name": "Gemini 3 Pro Preview", "model_vendor": "google", "model_version": "preview" }, "gemini-3-flash-preview": { "display_name": "Gemini 3 Flash Preview", "model_vendor": "google", "model_version": "preview" }, "gemini-flash-experimental": { "display_name": "Gemini Flash Experimental", "model_vendor": "google", "model_version": "experimental" }, "gemini-pro": { "display_name": "Gemini Pro", "model_vendor": "google" }, "gemini-pro-experimental": { "display_name": "Gemini Pro Experimental", "model_vendor": "google" }, "gemini/gemini-1.5-flash": { "display_name": "Gemini 1.5 Flash", "model_vendor": "google" }, "gemini/gemini-1.5-flash-001": { "display_name": "Gemini 1.5 Flash", "model_vendor": "google", "model_version": "001" }, "gemini/gemini-1.5-flash-002": { "display_name": "Gemini 1.5 Flash", "model_vendor": "google", "model_version": "002" }, "gemini/gemini-1.5-flash-8b": { "display_name": "Gemini 1.5 Flash 8B", "model_vendor": "google" }, "gemini/gemini-1.5-flash-8b-exp-0827": { "display_name": "Gemini 1.5 Flash 8B", "model_vendor": "google", "model_version": "0827" }, "gemini/gemini-1.5-flash-8b-exp-0924": { "display_name": "Gemini 1.5 Flash 8B", "model_vendor": "google", "model_version": "0924" }, "gemini/gemini-1.5-flash-exp-0827": { "display_name": "Gemini 1.5 Flash", "model_vendor": "google", "model_version": "0827" }, "gemini/gemini-1.5-flash-latest": { "display_name": "Gemini 1.5 Flash", "model_vendor": "google" }, "gemini/gemini-1.5-pro": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google" }, "gemini/gemini-1.5-pro-001": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google", "model_version": "001" }, "gemini/gemini-1.5-pro-002": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google", "model_version": "002" }, "gemini/gemini-1.5-pro-exp-0801": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google", "model_version": 
"0801" }, "gemini/gemini-1.5-pro-exp-0827": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google", "model_version": "0827" }, "gemini/gemini-1.5-pro-latest": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google" }, "gemini/gemini-2.0-flash": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-flash-001": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google", "model_version": "001" }, "gemini/gemini-2.0-flash-exp": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-flash-lite": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-flash-lite-preview-02-05": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-flash-live-001": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google", "model_version": "001" }, "gemini/gemini-2.0-flash-preview-image-generation": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-flash-thinking-exp": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-flash-thinking-exp-01-21": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google" }, "gemini/gemini-2.0-pro-exp-02-05": { "display_name": "Gemini 2.0", "model_vendor": "google" }, "gemini/gemini-2.5-flash": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-image": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-image-preview": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-lite": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-lite-preview-06-17": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-lite-preview-09-2025": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google", "model_version": "2025" }, "gemini/gemini-2.5-flash-preview-04-17": { 
"display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-preview-05-20": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-flash-preview-09-2025": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google", "model_version": "2025" }, "gemini/gemini-2.5-flash-preview-tts": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "gemini/gemini-2.5-pro": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini/gemini-2.5-pro-exp-03-25": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini/gemini-2.5-pro-preview-03-25": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini/gemini-2.5-pro-preview-05-06": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini/gemini-2.5-pro-preview-06-05": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini/gemini-2.5-pro-preview-tts": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "gemini/gemini-3-pro-image-preview": { "display_name": "Gemini 3 Pro Image Preview", "model_vendor": "google", "model_version": "preview" }, "gemini/gemini-3-pro-preview": { "display_name": "Gemini 3 Pro Preview", "model_vendor": "google", "model_version": "preview" }, "gemini/gemini-embedding-001": { "display_name": "Gemini", "model_vendor": "google", "model_version": "001" }, "gemini/gemini-exp-1114": { "display_name": "Gemini", "model_vendor": "google", "model_version": "experimental" }, "gemini/gemini-exp-1206": { "display_name": "Gemini", "model_vendor": "google", "model_version": "experimental" }, "gemini/gemini-flash-latest": { "display_name": "Gemini", "model_vendor": "google", "model_version": "latest" }, "gemini/gemini-flash-lite-latest": { "display_name": "Gemini", "model_vendor": "google", "model_version": "latest" }, "gemini/gemini-gemma-2-27b-it": { "display_name": "Gemini", "model_vendor": "google", "model_version": "latest" }, "gemini/gemini-gemma-2-9b-it": { "display_name": 
"Gemini", "model_vendor": "google", "model_version": "latest" }, "gemini/gemini-live-2.5-flash-preview-native-audio-09-2025": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google", "model_version": "preview" }, "gemini/gemini-pro": { "display_name": "Gemini 1.0 Pro", "model_vendor": "google" }, "gemini/gemini-pro-vision": { "display_name": "Gemini 1.0 Pro", "model_vendor": "google" }, "gemini/gemma-3-27b-it": { "display_name": "Gemini", "model_vendor": "google", "model_version": "latest" }, "gemini/imagen-3.0-fast-generate-001": { "display_name": "Gemini", "model_vendor": "google", "model_version": "001" }, "gemini/imagen-3.0-generate-001": { "display_name": "Gemini", "model_vendor": "google", "model_version": "001" }, "gemini/imagen-3.0-generate-002": { "display_name": "Gemini", "model_vendor": "google", "model_version": "002" }, "gemini/imagen-4.0-fast-generate-001": { "display_name": "Gemini", "model_vendor": "google", "model_version": "001" }, "gemini/imagen-4.0-generate-001": { "display_name": "Gemini", "model_vendor": "google", "model_version": "001" }, "gemini/imagen-4.0-ultra-generate-001": { "display_name": "Gemini", "model_vendor": "google", "model_version": "001" }, "gemini/learnlm-1.5-pro-experimental": { "display_name": "Gemini 1.5 Pro", "model_vendor": "google", "model_version": "experimental" }, "gemini/veo-2.0-generate-001": { "display_name": "Gemini 2.0", "model_vendor": "google", "model_version": "001" }, "gemini/veo-3.0-fast-generate-preview": { "display_name": "Gemini", "model_vendor": "google", "model_version": "preview" }, "gemini/veo-3.0-generate-preview": { "display_name": "Gemini", "model_vendor": "google", "model_version": "preview" }, "gemini/veo-3.1-fast-generate-preview": { "display_name": "Gemini", "model_vendor": "google", "model_version": "preview" }, "gemini/veo-3.1-generate-preview": { "display_name": "Gemini", "model_vendor": "google", "model_version": "preview" }, "global.anthropic.claude-haiku-4-5-20251001-v1:0": { 
"display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001" }, "global.anthropic.claude-sonnet-4-20250514-v1:0": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514" }, "global.anthropic.claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929" }, "gpt-3.5-turbo": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai" }, "gpt-3.5-turbo-0125": { "display_name": "GPT 3.5 Turbo 0125", "model_vendor": "openai", "model_version": "0125" }, "gpt-3.5-turbo-0301": { "display_name": "GPT 3.5 Turbo 0301", "model_vendor": "openai", "model_version": "0301" }, "gpt-3.5-turbo-0613": { "display_name": "GPT 3.5 Turbo 0613", "model_vendor": "openai", "model_version": "0613" }, "gpt-3.5-turbo-1106": { "display_name": "GPT 3.5 Turbo 1106", "model_vendor": "openai", "model_version": "1106" }, "gpt-3.5-turbo-16k": { "display_name": "GPT-3.5 Turbo 16K", "model_vendor": "openai" }, "gpt-3.5-turbo-16k-0613": { "display_name": "GPT 3.5 Turbo 16k 0613", "model_vendor": "openai", "model_version": "0613" }, "gpt-4": { "display_name": "GPT-4", "model_vendor": "openai" }, "gpt-4-0125-preview": { "display_name": "GPT-4 Preview", "model_vendor": "openai", "model_version": "0125" }, "gpt-4-0314": { "display_name": "GPT-4", "model_vendor": "openai", "model_version": "0314" }, "gpt-4-0613": { "display_name": "GPT-4", "model_vendor": "openai", "model_version": "0613" }, "gpt-4-1106-preview": { "display_name": "GPT-4 Preview", "model_vendor": "openai", "model_version": "1106" }, "gpt-4-1106-vision-preview": { "display_name": "GPT-4 Vision Preview", "model_vendor": "openai", "model_version": "1106" }, "gpt-4-32k": { "display_name": "GPT-4 32K", "model_vendor": "openai" }, "gpt-4-32k-0314": { "display_name": "GPT-4 32K", "model_vendor": "openai", "model_version": "0314" }, "gpt-4-32k-0613": { "display_name": "GPT-4 32K", "model_vendor": "openai", 
"model_version": "0613" }, "gpt-4-turbo": { "display_name": "GPT-4 Turbo", "model_vendor": "openai" }, "gpt-4-turbo-2024-04-09": { "display_name": "GPT-4 Turbo", "model_vendor": "openai", "model_version": "2024-04-09" }, "gpt-4-turbo-preview": { "display_name": "GPT-4 Turbo Preview", "model_vendor": "openai" }, "gpt-4-vision-preview": { "display_name": "GPT-4 Vision Preview", "model_vendor": "openai" }, "gpt-4.1": { "display_name": "GPT-4.1", "model_vendor": "openai" }, "gpt-4.1-2025-04-14": { "display_name": "GPT-4.1", "model_vendor": "openai", "model_version": "2025-04-14" }, "gpt-4.1-mini": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai" }, "gpt-4.1-mini-2025-04-14": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai", "model_version": "2025-04-14" }, "gpt-4.1-nano": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai" }, "gpt-4.1-nano-2025-04-14": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai", "model_version": "2025-04-14" }, "gpt-4.5-preview": { "display_name": "GPT-4.5 Preview", "model_vendor": "openai" }, "gpt-4.5-preview-2025-02-27": { "display_name": "GPT-4.5 Preview", "model_vendor": "openai", "model_version": "2025-02-27" }, "gpt-4o": { "display_name": "GPT-4o", "model_vendor": "openai" }, "gpt-4o-2024-05-13": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-05-13" }, "gpt-4o-2024-08-06": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-08-06" }, "gpt-4o-2024-11-20": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-11-20" }, "gpt-4o-audio-preview": { "display_name": "GPT-4o Audio Preview", "model_vendor": "openai" }, "gpt-4o-audio-preview-2024-10-01": { "display_name": "GPT-4o Audio Preview", "model_vendor": "openai", "model_version": "2024-10-01" }, "gpt-4o-audio-preview-2024-12-17": { "display_name": "GPT-4o Audio Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "gpt-4o-audio-preview-2025-06-03": { "display_name": 
"GPT-4o Audio Preview", "model_vendor": "openai", "model_version": "2025-06-03" }, "gpt-4o-mini": { "display_name": "GPT-4o Mini", "model_vendor": "openai" }, "gpt-4o-mini-2024-07-18": { "display_name": "GPT-4o Mini", "model_vendor": "openai", "model_version": "2024-07-18" }, "gpt-4o-mini-audio-preview": { "display_name": "GPT-4o Mini Audio Preview", "model_vendor": "openai" }, "gpt-4o-mini-audio-preview-2024-12-17": { "display_name": "GPT-4o Mini Audio Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "gpt-4o-mini-realtime-preview": { "display_name": "GPT-4o Mini Realtime Preview", "model_vendor": "openai" }, "gpt-4o-mini-realtime-preview-2024-12-17": { "display_name": "GPT-4o Mini Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "gpt-4o-mini-search-preview": { "display_name": "GPT 4o Mini Search Preview", "model_vendor": "openai" }, "gpt-4o-mini-search-preview-2025-03-11": { "display_name": "GPT 4o Mini Search Preview", "model_vendor": "openai", "model_version": "2025-03-11" }, "gpt-4o-realtime-preview": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai" }, "gpt-4o-realtime-preview-2024-10-01": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-10-01" }, "gpt-4o-realtime-preview-2024-12-17": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2024-12-17" }, "gpt-4o-realtime-preview-2025-06-03": { "display_name": "GPT-4o Realtime Preview", "model_vendor": "openai", "model_version": "2025-06-03" }, "gpt-4o-search-preview": { "display_name": "GPT 4o Search Preview", "model_vendor": "openai" }, "gpt-4o-search-preview-2025-03-11": { "display_name": "GPT 4o Search Preview", "model_vendor": "openai", "model_version": "2025-03-11" }, "gpt-5": { "display_name": "GPT-5", "model_vendor": "openai" }, "gpt-5-2025-08-07": { "display_name": "GPT-5", "model_vendor": "openai", "model_version": "2025-08-07" }, "gpt-5-chat": { 
"display_name": "GPT 5 Chat", "model_vendor": "openai" }, "gpt-5-chat-latest": { "display_name": "GPT 5 Chat", "model_vendor": "openai" }, "gpt-5-codex": { "display_name": "GPT-5 Codex", "model_vendor": "openai" }, "gpt-5-mini": { "display_name": "GPT-5 Mini", "model_vendor": "openai" }, "gpt-5-mini-2025-08-07": { "display_name": "GPT-5 Mini", "model_vendor": "openai", "model_version": "2025-08-07" }, "gpt-5-nano": { "display_name": "GPT 5 Nano", "model_vendor": "openai" }, "gpt-5-nano-2025-08-07": { "display_name": "GPT 5 Nano", "model_vendor": "openai", "model_version": "2025-08-07" }, "gpt-5-pro": { "display_name": "GPT-5 Pro", "model_vendor": "openai" }, "gpt-5-pro-2025-10-06": { "display_name": "GPT-5 Pro", "model_vendor": "openai", "model_version": "2025-10-06" }, "gpt-5.4": { "display_name": "GPT-5.4", "model_vendor": "openai" }, "gpt-5.2-pro-2025-12-11": { "display_name": "GPT-5.2 Pro", "model_vendor": "openai", "model_version": "2025-12-11" }, "gpt-5.2-pro": { "display_name": "GPT-5.2 Pro", "model_vendor": "openai" }, "gpt-5.2-chat-latest": { "display_name": "GPT 5.2 Chat", "model_vendor": "openai" }, "gpt-5.2-2025-12-11": { "display_name": "GPT 5.2", "model_vendor": "openai", "model_version": "2025-12-11" }, "gpt-5.2": { "display_name": "GPT 5.2", "model_vendor": "openai" }, "gpt-5.1": { "display_name": "GPT 5.1", "model_vendor": "openai" }, "gpt-5.1-2025-11-13": { "display_name": "GPT 5.1", "model_vendor": "openai", "model_version": "2025-11-13" }, "gpt-5.1-chat-latest": { "display_name": "GPT 5.1 Chat", "model_vendor": "openai" }, "gpt-5.1-codex": { "display_name": "GPT-5.1 Codex", "model_vendor": "openai" }, "gpt-5.1-codex-mini": { "display_name": "GPT-5.1 Codex Mini", "model_vendor": "openai" }, "gpt-image-1-mini": { "display_name": "GPT Image 1 Mini", "model_vendor": "openai" }, "gpt-realtime": { "display_name": "GPT Realtime", "model_vendor": "openai", "model_version": "latest" }, "gpt-realtime-2025-08-28": { "display_name": "GPT Realtime", 
"model_vendor": "openai", "model_version": "2025-08-28" }, "gpt-realtime-mini": { "display_name": "GPT Realtime Mini", "model_vendor": "openai", "model_version": "latest" }, "jp.anthropic.claude-haiku-4-5-20251001-v1:0": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001" }, "jp.anthropic.claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929" }, "medlm-large": { "display_name": "MedLM Large", "model_vendor": "google", "model_version": "latest" }, "medlm-medium": { "display_name": "MedLM Medium", "model_vendor": "google", "model_version": "latest" }, "meta.llama2-13b-chat-v1": { "display_name": "Llama 2 13B Chat", "model_vendor": "meta", "model_version": "v1" }, "meta.llama2-70b-chat-v1": { "display_name": "Llama 2 70B Chat", "model_vendor": "meta", "model_version": "v1" }, "meta.llama3-1-405b-instruct-v1:0": { "display_name": "Llama 3.1 405B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-1-70b-instruct-v1:0": { "display_name": "Llama 3.1 70B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-1-8b-instruct-v1:0": { "display_name": "Llama 3.1 8B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-2-11b-instruct-v1:0": { "display_name": "Llama 3.2 11B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-2-1b-instruct-v1:0": { "display_name": "Llama 3.2 1B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-2-3b-instruct-v1:0": { "display_name": "Llama 3.2 3B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-2-90b-instruct-v1:0": { "display_name": "Llama 3.2 90B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-3-70b-instruct-v1:0": { "display_name": "Llama 3.3 70B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-70b-instruct-v1:0": { "display_name": "Llama 3 70B 
Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama3-8b-instruct-v1:0": { "display_name": "Llama 3 8B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama4-maverick-17b-instruct-v1:0": { "display_name": "Llama 4 Maverick 17B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "meta.llama4-scout-17b-instruct-v1:0": { "display_name": "Llama 4 Scout 17B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "mistral.mistral-7b-instruct-v0:2": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "v0:2" }, "mistral.mistral-large-2402-v1:0": { "display_name": "Mistral Large 24.02", "model_vendor": "mistral", "model_version": "1:0" }, "mistral.mistral-large-2407-v1:0": { "display_name": "Mistral Large 24.07", "model_vendor": "mistral", "model_version": "1:0" }, "mistral.mistral-small-2402-v1:0": { "display_name": "Mistral Small 24.02", "model_vendor": "mistral", "model_version": "1:0" }, "mistral.mixtral-8x7b-instruct-v0:1": { "display_name": "Mixtral 8x7B Instruct", "model_vendor": "mistral", "model_version": "0:1" }, "mistral/codestral-2405": { "display_name": "Codestral", "model_vendor": "mistral", "model_version": "latest" }, "mistral/codestral-embed": { "display_name": "Codestral Embed", "model_vendor": "mistral", "model_version": "latest" }, "mistral/codestral-embed-2505": { "display_name": "Codestral Embed", "model_vendor": "mistral", "model_version": "latest" }, "mistral/codestral-latest": { "display_name": "Codestral", "model_vendor": "mistral", "model_version": "latest" }, "mistral/codestral-mamba-latest": { "display_name": "Codestral Mamba", "model_vendor": "mistral", "model_version": "latest" }, "mistral/devstral-medium-2507": { "display_name": "Devstral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/devstral-small-2505": { "display_name": "Devstral Small", "model_vendor": "mistral", "model_version": "latest" }, "mistral/devstral-small-2507": { 
"display_name": "Devstral Small", "model_vendor": "mistral", "model_version": "latest" }, "mistral/magistral-medium-2506": { "display_name": "Magistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/magistral-medium-2509": { "display_name": "Magistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/magistral-medium-latest": { "display_name": "Magistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/magistral-small-2506": { "display_name": "Magistral Small", "model_vendor": "mistral", "model_version": "latest" }, "mistral/magistral-small-latest": { "display_name": "Magistral Small", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-embed": { "display_name": "Mistral Embed", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-large-2402": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-large-2407": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-large-2411": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-large-latest": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-medium": { "display_name": "Mistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-medium-2312": { "display_name": "Mistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-medium-2505": { "display_name": "Mistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-medium-latest": { "display_name": "Mistral Medium", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-ocr-2505-completion": { "display_name": "Mistral", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-ocr-latest": { "display_name": "Mistral", 
"model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-small": { "display_name": "Mistral Small", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-small-latest": { "display_name": "Mistral Small", "model_vendor": "mistral", "model_version": "latest" }, "mistral/mistral-tiny": { "display_name": "Mistral Tiny", "model_vendor": "mistral", "model_version": "latest" }, "mistral/open-codestral-mamba": { "display_name": "Codestral Mamba", "model_vendor": "mistral", "model_version": "latest" }, "mistral/open-mistral-7b": { "display_name": "Open Mistral", "model_vendor": "mistral", "model_version": "latest" }, "mistral/open-mistral-nemo": { "display_name": "Open Mistral Nemo", "model_vendor": "mistral", "model_version": "latest" }, "mistral/open-mistral-nemo-2407": { "display_name": "Open Mistral Nemo", "model_vendor": "mistral", "model_version": "latest" }, "mistral/open-mixtral-8x22b": { "display_name": "Open Mixtral 8x22B", "model_vendor": "mistral", "model_version": "latest" }, "mistral/open-mixtral-8x7b": { "display_name": "Open Mixtral 8x7B", "model_vendor": "mistral", "model_version": "latest" }, "mistral/pixtral-12b-2409": { "display_name": "Pixtral", "model_vendor": "mistral", "model_version": "latest" }, "mistral/pixtral-large-2411": { "display_name": "Pixtral Large", "model_vendor": "mistral", "model_version": "latest" }, "mistral/pixtral-large-latest": { "display_name": "Pixtral Large", "model_vendor": "mistral", "model_version": "latest" }, "o1": { "display_name": "o1", "model_vendor": "openai", "model_version": "latest" }, "o1-2024-12-17": { "display_name": "o1", "model_vendor": "openai", "model_version": "2024-12-17" }, "o1-mini": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "latest" }, "o1-mini-2024-09-12": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "latest" }, "o1-preview": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "latest" }, 
"o1-preview-2024-09-12": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "latest" }, "o1-pro": { "display_name": "o1 Pro", "model_vendor": "openai" }, "o1-pro-2025-03-19": { "display_name": "o1 Pro", "model_vendor": "openai", "model_version": "2025-03-19" }, "o3": { "display_name": "o3", "model_vendor": "openai", "model_version": "latest" }, "o3-2025-04-16": { "display_name": "o3", "model_vendor": "openai", "model_version": "latest" }, "o3-deep-research": { "display_name": "o3 Deep Research", "model_vendor": "openai" }, "o3-deep-research-2025-06-26": { "display_name": "o3 Deep Research", "model_vendor": "openai", "model_version": "2025-06-26" }, "o3-mini": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "latest" }, "o3-mini-2025-01-31": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "latest" }, "o3-pro": { "display_name": "o3 Pro", "model_vendor": "openai" }, "o3-pro-2025-06-10": { "display_name": "o3 Pro", "model_vendor": "openai", "model_version": "2025-06-10" }, "o4-mini": { "display_name": "o4 Mini", "model_vendor": "openai", "model_version": "latest" }, "o4-mini-2025-04-16": { "display_name": "o4 Mini", "model_vendor": "openai", "model_version": "latest" }, "o4-mini-deep-research": { "display_name": "o4 Mini Deep Research", "model_vendor": "openai" }, "o4-mini-deep-research-2025-06-26": { "display_name": "o4 Mini Deep Research", "model_vendor": "openai", "model_version": "2025-06-26" }, "ollama/codegeex4": { "display_name": "CodeGeeX4", "model_vendor": "zhipu", "model_version": "latest" }, "ollama/codegemma": { "display_name": "Codegemma", "model_vendor": "google", "model_version": "latest" }, "ollama/codellama": { "display_name": "CodeLlama", "model_vendor": "meta", "model_version": "latest" }, "ollama/deepseek-coder-v2-base": { "display_name": "DeepSeek Coder v2 Base", "model_vendor": "deepseek", "model_version": "latest" }, "ollama/deepseek-coder-v2-instruct": { "display_name": 
"DeepSeek Coder v2 Instruct", "model_vendor": "deepseek", "model_version": "latest" }, "ollama/deepseek-coder-v2-lite-base": { "display_name": "DeepSeek Coder v2 Lite Base", "model_vendor": "deepseek", "model_version": "latest" }, "ollama/deepseek-coder-v2-lite-instruct": { "display_name": "DeepSeek Coder v2 Lite Instruct", "model_vendor": "deepseek", "model_version": "latest" }, "ollama/deepseek-v3.1:671b-cloud": { "display_name": "DeepSeek V3.1:671B Cloud", "model_vendor": "deepseek", "model_version": "latest" }, "ollama/gpt-oss:120b-cloud": { "display_name": "GPT Open-Source 120B", "model_vendor": "openai", "model_version": "latest" }, "ollama/gpt-oss:20b-cloud": { "display_name": "GPT Open-Source 20B", "model_vendor": "openai", "model_version": "latest" }, "ollama/internlm2_5-20b-chat": { "display_name": "InternLM 2.5 20B Chat", "model_vendor": "shanghai-ai-lab", "model_version": "latest" }, "ollama/llama2": { "display_name": "Llama 2", "model_vendor": "meta" }, "ollama/llama2-uncensored": { "display_name": "Llama 2 Uncensored", "model_vendor": "meta" }, "ollama/llama2:13b": { "display_name": "Llama 2:13B", "model_vendor": "meta" }, "ollama/llama2:70b": { "display_name": "Llama 2:70B", "model_vendor": "meta" }, "ollama/llama2:7b": { "display_name": "Llama 2:7B", "model_vendor": "meta" }, "ollama/llama3": { "display_name": "Llama 3", "model_vendor": "meta" }, "ollama/llama3.1": { "display_name": "Llama 3.1", "model_vendor": "meta" }, "ollama/llama3:70b": { "display_name": "Llama 3:70B", "model_vendor": "meta" }, "ollama/llama3:8b": { "display_name": "Llama 3:8B", "model_vendor": "meta" }, "ollama/mistral": { "display_name": "Mistral", "model_vendor": "mistral", "model_version": "latest" }, "ollama/mistral-7B-Instruct-v0.1": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "v0.1" }, "ollama/mistral-7B-Instruct-v0.2": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "v0.2" }, 
"ollama/mistral-large-instruct-2407": { "display_name": "Mistral Large Instruct 24.07", "model_vendor": "mistral", "model_version": "latest" }, "ollama/mixtral-8x22B-Instruct-v0.1": { "display_name": "Mixtral 8x22B Instruct V0.1", "model_vendor": "mistral", "model_version": "latest" }, "ollama/mixtral-8x7B-Instruct-v0.1": { "display_name": "Mixtral 8x7B Instruct V0.1", "model_vendor": "mistral", "model_version": "latest" }, "ollama/orca-mini": { "display_name": "Orca Mini", "model_vendor": "microsoft", "model_version": "latest" }, "ollama/qwen3-coder:480b-cloud": { "display_name": "Qwen3 Coder:480B Cloud", "model_vendor": "alibaba", "model_version": "latest" }, "ollama/vicuna": { "display_name": "Vicuna", "model_vendor": "lmsys", "model_version": "latest" }, "openai.gpt-oss-120b-1:0": { "display_name": "GPT Open-Source 120B", "model_vendor": "openai", "model_version": "v1:0" }, "openai.gpt-oss-20b-1:0": { "display_name": "GPT Open-Source 20B", "model_vendor": "openai", "model_version": "v1:0" }, "openai/container": { "display_name": "Container", "model_vendor": "openai", "model_version": "latest" }, "openrouter/agentica-org/deepcoder-14b-preview": { "display_name": "DeepCoder 14B Preview", "model_vendor": "agentica" }, "openrouter/ai21/jamba-1-5-large": { "display_name": "Jamba 1.5 Large", "model_vendor": "ai21" }, "openrouter/ai21/jamba-1-5-mini": { "display_name": "Jamba 1.5 Mini", "model_vendor": "ai21" }, "openrouter/ai21/jamba-large-1.7": { "display_name": "Jamba Large 1.7", "model_vendor": "ai21" }, "openrouter/aion-labs/aion-1.0": { "display_name": "AION 1.0", "model_vendor": "aion-labs" }, "openrouter/alibaba/qwen-2.5-72b-instruct": { "display_name": "Qwen 2.5 72B Instruct", "model_vendor": "alibaba" }, "openrouter/alibaba/qwen-2.5-coder-32b-instruct": { "display_name": "Qwen 2.5 Coder 32B", "model_vendor": "alibaba" }, "openrouter/alibaba/tongyi-deepresearch-30b-a3b": { "display_name": "Tongyi DeepResearch 30B", "model_vendor": "alibaba" }, 
"openrouter/alibaba/tongyi-deepresearch-30b-a3b:free": { "display_name": "Tongyi DeepResearch 30B (Free)", "model_vendor": "alibaba" }, "openrouter/anthropic/claude-2": { "display_name": "Claude 2", "model_vendor": "anthropic", "model_version": "latest" }, "openrouter/anthropic/claude-3-sonnet": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic" }, "openrouter/anthropic/claude-3.5-sonnet": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "latest" }, "openrouter/anthropic/claude-3.5-sonnet:beta": { "display_name": "Claude Sonnet 3.5:beta", "model_vendor": "anthropic", "model_version": "latest" }, "openrouter/anthropic/claude-haiku-4.5": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "latest" }, "openrouter/anthropic/claude-instant-v1": { "display_name": "Claude Instant", "model_vendor": "anthropic", "model_version": "v1" }, "openrouter/anthropic/claude-opus-4": { "display_name": "Claude Opus 4", "model_vendor": "anthropic" }, "openrouter/anthropic/claude-opus-4.1": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic" }, "openrouter/anthropic/claude-opus-4.5": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic" }, "openrouter/anthropic/claude-sonnet-4": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic" }, "openrouter/anthropic/claude-sonnet-4.5": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic" }, "openrouter/baidu/ernie-4.5-300b-a47b": { "display_name": "ERNIE 4.5 300B", "model_vendor": "baidu" }, "openrouter/baidu/ernie-4.5-vl-28b-a3b": { "display_name": "ERNIE 4.5 VL 28B", "model_vendor": "baidu" }, "openrouter/bytedance/ui-tars-1.5-7b": { "display_name": "UI-TARS 1.5 7B", "model_vendor": "bytedance", "model_version": "latest" }, "openrouter/cognitivecomputations/dolphin-mixtral-8x7b": { "display_name": "Dolphin Mixtral 8x7B", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/cohere/command-a": { 
"display_name": "Command A", "model_vendor": "cohere" }, "openrouter/cohere/command-r": { "display_name": "Command R", "model_vendor": "cohere" }, "openrouter/cohere/command-r-08-2024": { "display_name": "Command R", "model_vendor": "cohere", "model_version": "08-2024" }, "openrouter/cohere/command-r-plus": { "display_name": "Command R Plus", "model_vendor": "cohere" }, "openrouter/cohere/command-r-plus-08-2024": { "display_name": "Command R Plus", "model_vendor": "cohere", "model_version": "08-2024" }, "openrouter/databricks/dbrx-instruct": { "display_name": "DBRX Instruct", "model_vendor": "databricks", "model_version": "latest" }, "openrouter/deepcogito/cogito-v2-preview-deepseek-671b": { "display_name": "Cogito V2 Preview DeepSeek 671B", "model_vendor": "deepcogito" }, "openrouter/deepcogito/cogito-v2-preview-llama-109b-moe": { "display_name": "Cogito V2 Preview Llama 109B MoE", "model_vendor": "deepcogito" }, "openrouter/deepseek/deepseek-chat": { "display_name": "DeepSeek Chat", "model_vendor": "deepseek", "model_version": "latest" }, "openrouter/deepseek/deepseek-chat-v3-0324": { "display_name": "DeepSeek Chat v3 0324", "model_vendor": "deepseek", "model_version": "latest" }, "openrouter/deepseek/deepseek-chat-v3.1": { "display_name": "DeepSeek Chat V3.1", "model_vendor": "deepseek", "model_version": "latest" }, "openrouter/deepseek/deepseek-coder": { "display_name": "DeepSeek Coder", "model_vendor": "deepseek", "model_version": "latest" }, "openrouter/deepseek/deepseek-r1": { "display_name": "DeepSeek R1", "model_vendor": "deepseek", "model_version": "latest" }, "openrouter/deepseek/deepseek-r1-0528": { "display_name": "DeepSeek R1 0528", "model_vendor": "deepseek", "model_version": "latest" }, "openrouter/deepseek/deepseek-v3.2-exp": { "display_name": "DeepSeek V3.2", "model_vendor": "deepseek", "model_version": "experimental" }, "openrouter/fireworks/firellava-13b": { "display_name": "FireLLaVA 13B", "model_vendor": "fireworks", "model_version": "latest" 
}, "openrouter/google/gemini-2.0-flash-001": { "display_name": "Gemini 2.0 Flash", "model_vendor": "google", "model_version": "001" }, "openrouter/google/gemini-2.5-flash": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "openrouter/google/gemini-2.5-pro": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "openrouter/google/gemini-3-pro-preview": { "display_name": "Gemini 3 Pro Preview", "model_vendor": "google", "model_version": "preview" }, "openrouter/google/gemini-pro-1.5": { "display_name": "Gemini Pro 1.5", "model_vendor": "google" }, "openrouter/google/gemini-pro-vision": { "display_name": "Gemini Pro Vision", "model_vendor": "google" }, "openrouter/google/gemma-2-27b-it": { "display_name": "Gemma 2 27B", "model_vendor": "google" }, "openrouter/google/gemma-2-9b-it": { "display_name": "Gemma 2 9B", "model_vendor": "google" }, "openrouter/google/gemma-2-9b-it:free": { "display_name": "Gemma 2 9B (Free)", "model_vendor": "google" }, "openrouter/google/gemma-3n-e4b-it": { "display_name": "Gemma 3N E4B", "model_vendor": "google" }, "openrouter/google/gemma-3n-e4b-it:free": { "display_name": "Gemma 3N E4B (Free)", "model_vendor": "google" }, "openrouter/google/palm-2-chat-bison": { "display_name": "PaLM 2 Chat Bison", "model_vendor": "google", "model_version": "latest" }, "openrouter/google/palm-2-codechat-bison": { "display_name": "PaLM 2 Codechat Bison", "model_vendor": "google", "model_version": "latest" }, "openrouter/gryphe/mythomax-l2-13b": { "display_name": "MythoMax L2 13B", "model_vendor": "gryphe", "model_version": "latest" }, "openrouter/inclusionai/ring-1t": { "display_name": "Ring 1T", "model_vendor": "inclusionai" }, "openrouter/jondurbin/airoboros-l2-70b-2.1": { "display_name": "Airoboros L2 70B", "model_vendor": "jondurbin" }, "openrouter/mancer/weaver": { "display_name": "Weaver", "model_vendor": "mancer", "model_version": "latest" }, "openrouter/meta-llama/codellama-34b-instruct": { "display_name": "CodeLlama 34B 
Instruct", "model_vendor": "meta" }, "openrouter/meta-llama/llama-2-13b-chat": { "display_name": "Llama 2 13B Chat", "model_vendor": "meta" }, "openrouter/meta-llama/llama-2-70b-chat": { "display_name": "Llama 2 70B Chat", "model_vendor": "meta" }, "openrouter/meta-llama/llama-3-70b-instruct": { "display_name": "Llama 3 70B Instruct", "model_vendor": "meta" }, "openrouter/meta-llama/llama-3-70b-instruct:nitro": { "display_name": "Llama 3 70B Instruct:nitro", "model_vendor": "meta" }, "openrouter/meta-llama/llama-3-8b-instruct:extended": { "display_name": "Llama 3 8B Instruct:extended", "model_vendor": "meta" }, "openrouter/meta-llama/llama-3-8b-instruct:free": { "display_name": "Llama 3 8B Instruct:free", "model_vendor": "meta" }, "openrouter/microsoft/wizardlm-2-8x22b:nitro": { "display_name": "WizardLM 2 8x22B", "model_vendor": "microsoft", "model_version": "latest" }, "openrouter/minimax/minimax-m2": { "display_name": "MiniMax M2", "model_vendor": "minimax", "model_version": "latest" }, "openrouter/mistralai/mistral-7b-instruct": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/mistralai/mistral-7b-instruct:free": { "display_name": "Mistral 7B Instruct", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/mistralai/mistral-large": { "display_name": "Mistral Large", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/mistralai/mistral-small-3.1-24b-instruct": { "display_name": "Mistral Small 3.1 24B Instruct", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/mistralai/mistral-small-3.2-24b-instruct": { "display_name": "Mistral Small 3.2 24B Instruct", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/mistralai/mixtral-8x22b-instruct": { "display_name": "Mixtral 8x22B Instruct", "model_vendor": "mistral", "model_version": "latest" }, "openrouter/nousresearch/nous-hermes-llama2-13b": { "display_name": "Nous Hermes Llama 2 13B", 
"model_vendor": "meta" }, "openrouter/openai/gpt-3.5-turbo": { "display_name": "GPT-3.5 Turbo", "model_vendor": "openai" }, "openrouter/openai/gpt-3.5-turbo-16k": { "display_name": "GPT-3.5 Turbo 16K", "model_vendor": "openai" }, "openrouter/openai/gpt-4": { "display_name": "GPT-4", "model_vendor": "openai" }, "openrouter/openai/gpt-4-vision-preview": { "display_name": "GPT-4 Vision Preview", "model_vendor": "openai" }, "openrouter/openai/gpt-4.1": { "display_name": "GPT-4.1", "model_vendor": "openai" }, "openrouter/openai/gpt-4.1-2025-04-14": { "display_name": "GPT-4.1", "model_vendor": "openai", "model_version": "2025-04-14" }, "openrouter/openai/gpt-4.1-mini": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai" }, "openrouter/openai/gpt-4.1-mini-2025-04-14": { "display_name": "GPT-4.1 Mini", "model_vendor": "openai", "model_version": "2025-04-14" }, "openrouter/openai/gpt-4.1-nano": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai" }, "openrouter/openai/gpt-4.1-nano-2025-04-14": { "display_name": "GPT-4.1 Nano", "model_vendor": "openai", "model_version": "2025-04-14" }, "openrouter/openai/gpt-4o": { "display_name": "GPT-4o", "model_vendor": "openai" }, "openrouter/openai/gpt-4o-2024-05-13": { "display_name": "GPT-4o", "model_vendor": "openai", "model_version": "2024-05-13" }, "openrouter/openai/gpt-5": { "display_name": "GPT-5", "model_vendor": "openai" }, "openrouter/openai/gpt-5-chat": { "display_name": "GPT-5 Chat", "model_vendor": "openai" }, "openrouter/openai/gpt-5-codex": { "display_name": "GPT-5 Codex", "model_vendor": "openai" }, "openrouter/openai/gpt-5-mini": { "display_name": "GPT-5 Mini", "model_vendor": "openai" }, "openrouter/openai/gpt-5-nano": { "display_name": "GPT-5 Nano", "model_vendor": "openai" }, "openrouter/openai/gpt-oss-120b": { "display_name": "GPT Open-Source 120B", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/gpt-oss-20b": { "display_name": "GPT Open-Source 20B", "model_vendor": "openai", 
"model_version": "latest" }, "openrouter/openai/o1": { "display_name": "o1", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/o1-mini": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/o1-mini-2024-09-12": { "display_name": "o1 Mini", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/o1-preview": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/o1-preview-2024-09-12": { "display_name": "o1 Preview", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/o3-mini": { "display_name": "o3 Mini", "model_vendor": "openai", "model_version": "latest" }, "openrouter/openai/o3-mini-high": { "display_name": "o3 Mini High", "model_vendor": "openai", "model_version": "latest" }, "openrouter/pygmalionai/mythalion-13b": { "display_name": "Mythalion 13B", "model_vendor": "pygmalionai", "model_version": "latest" }, "openrouter/qwen/qwen-2.5-coder-32b-instruct": { "display_name": "Qwen 2.5 Coder 32B Instruct", "model_vendor": "alibaba", "model_version": "latest" }, "openrouter/qwen/qwen-vl-plus": { "display_name": "Qwen VL Plus", "model_vendor": "alibaba", "model_version": "latest" }, "openrouter/qwen/qwen3-coder": { "display_name": "Qwen3 Coder", "model_vendor": "alibaba", "model_version": "latest" }, "openrouter/switchpoint/router": { "display_name": "SwitchPoint Router", "model_vendor": "switchpoint", "model_version": "latest" }, "openrouter/undi95/remm-slerp-l2-13b": { "display_name": "ReMM SLERP L2 13B", "model_vendor": "undi95", "model_version": "latest" }, "openrouter/x-ai/grok-4": { "display_name": "Grok 4", "model_vendor": "xai", "model_version": "latest" }, "openrouter/x-ai/grok-4-fast:free": { "display_name": "Grok 4 Fast:free", "model_vendor": "xai", "model_version": "latest" }, "openrouter/z-ai/glm-4.6": { "display_name": "GLM 4.6", "model_vendor": "zhipu", "model_version": "latest" }, 
"openrouter/z-ai/glm-4.6:exacto": { "display_name": "GLM 4.6 Exacto", "model_vendor": "zhipu", "model_version": "latest" }, "qwen.qwen3-235b-a22b-2507-v1:0": { "display_name": "Qwen.qwen3 235B A22b 2507", "model_vendor": "alibaba", "model_version": "1:0" }, "qwen.qwen3-32b-v1:0": { "display_name": "Qwen.qwen3 32B", "model_vendor": "alibaba", "model_version": "1:0" }, "qwen.qwen3-coder-30b-a3b-v1:0": { "display_name": "Qwen.qwen3 Coder 30B A3b", "model_vendor": "alibaba", "model_version": "1:0" }, "qwen.qwen3-coder-480b-a35b-v1:0": { "display_name": "Qwen.qwen3 Coder 480B A35b", "model_vendor": "alibaba", "model_version": "1:0" }, "twelvelabs.pegasus-1-2-v1:0": { "display_name": "Pegasus 1.2", "model_vendor": "twelvelabs", "model_version": "v1:0" }, "us.amazon.nova-lite-v1:0": { "display_name": "Nova Lite", "model_vendor": "amazon", "model_version": "1:0" }, "us.amazon.nova-micro-v1:0": { "display_name": "Nova Micro", "model_vendor": "amazon", "model_version": "1:0" }, "us.amazon.nova-premier-v1:0": { "display_name": "Nova Premier", "model_vendor": "amazon", "model_version": "v1:0" }, "us.amazon.nova-pro-v1:0": { "display_name": "Nova Pro", "model_vendor": "amazon", "model_version": "1:0" }, "us.anthropic.claude-3-5-sonnet-20240620-v1:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620" }, "us.anthropic.claude-3-5-sonnet-20241022-v2:0": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20241022" }, "us.anthropic.claude-3-sonnet-20240229-v1:0": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic", "model_version": "20240229" }, "us.anthropic.claude-haiku-4-5-20251001-v1:0": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001" }, "us.anthropic.claude-opus-4-1-20250805-v1:0": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic", "model_version": "20250805" }, "us.anthropic.claude-opus-4-20250514-v1:0": { "display_name": 
"Claude Opus 4", "model_vendor": "anthropic", "model_version": "20250514" }, "us.anthropic.claude-opus-4-5-20251101-v1:0": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic", "model_version": "20251101" }, "us.anthropic.claude-sonnet-4-20250514-v1:0": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514" }, "us.anthropic.claude-sonnet-4-5-20250929-v1:0": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929" }, "us.deepseek.r1-v1:0": { "display_name": "DeepSeek R1", "model_vendor": "deepseek", "model_version": "v1:0" }, "us.meta.llama3-1-405b-instruct-v1:0": { "display_name": "Llama 3.1 405B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-1-70b-instruct-v1:0": { "display_name": "Llama 3.1 70B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-1-8b-instruct-v1:0": { "display_name": "Llama 3.1 8B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-2-11b-instruct-v1:0": { "display_name": "Llama 3.2 11B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-2-1b-instruct-v1:0": { "display_name": "Llama 3.2 1B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-2-3b-instruct-v1:0": { "display_name": "Llama 3.2 3B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-2-90b-instruct-v1:0": { "display_name": "Llama 3.2 90B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama3-3-70b-instruct-v1:0": { "display_name": "Llama 3.3 70B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama4-maverick-17b-instruct-v1:0": { "display_name": "Llama 4 Maverick 17B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.meta.llama4-scout-17b-instruct-v1:0": { "display_name": "Llama 4 Scout 17B Instruct", "model_vendor": "meta", "model_version": "1:0" }, "us.mistral.pixtral-large-2502-v1:0": { 
"display_name": "Pixtral Large 25.02", "model_vendor": "mistral", "model_version": "1:0" }, "us.twelvelabs.pegasus-1-2-v1:0": { "display_name": "Pegasus 1.2", "model_vendor": "twelvelabs", "model_version": "v1:0" }, "vertex_ai/claude-3-5-sonnet": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic" }, "vertex_ai/claude-3-5-sonnet@20240620": { "display_name": "Claude Sonnet 3.5", "model_vendor": "anthropic", "model_version": "20240620" }, "vertex_ai/claude-3-sonnet": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic" }, "vertex_ai/claude-3-sonnet@20240229": { "display_name": "Claude Sonnet 3", "model_vendor": "anthropic", "model_version": "20240229" }, "vertex_ai/claude-haiku-4-5": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic" }, "vertex_ai/claude-haiku-4-5@20251001": { "display_name": "Claude Haiku 4.5", "model_vendor": "anthropic", "model_version": "20251001" }, "vertex_ai/claude-opus-4": { "display_name": "Claude Opus 4", "model_vendor": "anthropic" }, "vertex_ai/claude-opus-4-1": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic" }, "vertex_ai/claude-opus-4-1@20250805": { "display_name": "Claude Opus 4.1", "model_vendor": "anthropic", "model_version": "20250805" }, "vertex_ai/claude-opus-4-5": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic" }, "vertex_ai/claude-opus-4-5@20251101": { "display_name": "Claude Opus 4.5", "model_vendor": "anthropic", "model_version": "20251101" }, "vertex_ai/claude-opus-4@20250514": { "display_name": "Claude Opus 4", "model_vendor": "anthropic", "model_version": "20250514" }, "vertex_ai/claude-sonnet-4": { "display_name": "Claude Sonnet 4", "model_vendor": "anthropic" }, "vertex_ai/claude-sonnet-4-5": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic" }, "vertex_ai/claude-sonnet-4-5@20250929": { "display_name": "Claude Sonnet 4.5", "model_vendor": "anthropic", "model_version": "20250929" }, "vertex_ai/claude-sonnet-4@20250514": { 
"display_name": "Claude Sonnet 4", "model_vendor": "anthropic", "model_version": "20250514" }, "vertex_ai/codestral-2": { "display_name": "Codestral 2", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/codestral-2501": { "display_name": "Codestral 25.01", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/codestral-2@001": { "display_name": "Codestral 2@001", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/codestral@2405": { "display_name": "Codestral@2405", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/codestral@latest": { "display_name": "Codestral@latest", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/deepseek-ai/deepseek-r1-0528-maas": { "display_name": "DeepSeek R1 0528 Maas", "model_vendor": "deepseek", "model_version": "latest" }, "vertex_ai/deepseek-ai/deepseek-v3.1-maas": { "display_name": "DeepSeek V3.1 Maas", "model_vendor": "deepseek", "model_version": "latest" }, "vertex_ai/gemini-2.5-flash": { "display_name": "Gemini 2.5 Flash", "model_vendor": "google" }, "vertex_ai/gemini-2.5-flash-lite": { "display_name": "Gemini 2.5 Flash Lite", "model_vendor": "google" }, "vertex_ai/gemini-2.5-pro": { "display_name": "Gemini 2.5 Pro", "model_vendor": "google" }, "vertex_ai/gemini-3-pro-preview": { "display_name": "Gemini 3 Pro Preview", "model_vendor": "google", "model_version": "preview" }, "vertex_ai/gemini-3-flash-preview": { "display_name": "Gemini 3 Flash Preview", "model_vendor": "google", "model_version": "preview" }, "vertex_ai/jamba-1.5": { "display_name": "Jamba 1.5", "model_vendor": "ai21", "model_version": "latest" }, "vertex_ai/jamba-1.5-large": { "display_name": "Jamba 1.5 Large", "model_vendor": "ai21", "model_version": "latest" }, "vertex_ai/jamba-1.5-large@001": { "display_name": "Jamba 1.5 Large@001", "model_vendor": "ai21", "model_version": "latest" }, "vertex_ai/jamba-1.5-mini": { "display_name": "Jamba 1.5 Mini", "model_vendor": "ai21", 
"model_version": "latest" }, "vertex_ai/jamba-1.5-mini@001": { "display_name": "Jamba 1.5 Mini@001", "model_vendor": "ai21", "model_version": "latest" }, "vertex_ai/meta/llama-3.1-405b-instruct-maas": { "display_name": "Llama 3.1 405B Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-3.1-70b-instruct-maas": { "display_name": "Llama 3.1 70B Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-3.1-8b-instruct-maas": { "display_name": "Llama 3.1 8B Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-3.2-90b-vision-instruct-maas": { "display_name": "Llama 3.2 90B Vision Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-4-maverick-17b-128e-instruct-maas": { "display_name": "Llama 4 Maverick 17B 128e Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-4-maverick-17b-16e-instruct-maas": { "display_name": "Llama 4 Maverick 17B 16e Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-4-scout-17b-128e-instruct-maas": { "display_name": "Llama 4 Scout 17B 128e Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama-4-scout-17b-16e-instruct-maas": { "display_name": "Llama 4 Scout 17B 16e Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama3-405b-instruct-maas": { "display_name": "Llama 3 405B Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama3-70b-instruct-maas": { "display_name": "Llama 3 70B Instruct Maas", "model_vendor": "meta" }, "vertex_ai/meta/llama3-8b-instruct-maas": { "display_name": "Llama 3 8B Instruct Maas", "model_vendor": "meta" }, "vertex_ai/minimaxai/minimax-m2-maas": { "display_name": "MiniMax M2", "model_vendor": "minimax", "model_version": "latest" }, "vertex_ai/mistral-large-2411": { "display_name": "Mistral Large 24.11", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-large@2407": { "display_name": "Mistral Large@24.07", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-large@2411-001": { 
"display_name": "Mistral Large@24.11 001", "model_vendor": "mistral", "model_version": "001" }, "vertex_ai/mistral-large@latest": { "display_name": "Mistral Large@latest", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-medium-3": { "display_name": "Mistral Medium 3", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-medium-3@001": { "display_name": "Mistral Medium 3@001", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-nemo@2407": { "display_name": "Mistral Nemo@24.07", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-nemo@latest": { "display_name": "Mistral Nemo@latest", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-small-2503": { "display_name": "Mistral Small 2503", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistral-small-2503@001": { "display_name": "Mistral Small 2503@001", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistralai/codestral-2": { "display_name": "Codestral 2", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistralai/codestral-2@001": { "display_name": "Codestral 2@001", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistralai/mistral-medium-3": { "display_name": "Mistral Medium 3", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/mistralai/mistral-medium-3@001": { "display_name": "Mistral Medium 3@001", "model_vendor": "mistral", "model_version": "latest" }, "vertex_ai/moonshotai/kimi-k2-thinking-maas": { "display_name": "Kimi K2 Thinking", "model_vendor": "moonshot", "model_version": "latest" }, "vertex_ai/openai/gpt-oss-120b-maas": { "display_name": "GPT Open-Source 120B", "model_vendor": "openai", "model_version": "latest" }, "vertex_ai/openai/gpt-oss-20b-maas": { "display_name": "GPT Open-Source 20B", "model_vendor": "openai", "model_version": "latest" }, "vertex_ai/qwen/qwen3-235b-a22b-instruct-2507-maas": 
{ "display_name": "Qwen3 235B A22b Instruct 2507 Maas", "model_vendor": "alibaba", "model_version": "latest" }, "vertex_ai/qwen/qwen3-coder-480b-a35b-instruct-maas": { "display_name": "Qwen3 Coder 480B A35b Instruct Maas", "model_vendor": "alibaba", "model_version": "latest" }, "vertex_ai/qwen/qwen3-next-80b-a3b-instruct-maas": { "display_name": "Qwen3 Next 80B A3b Instruct Maas", "model_vendor": "alibaba", "model_version": "latest" }, "vertex_ai/qwen/qwen3-next-80b-a3b-thinking-maas": { "display_name": "Qwen3 Next 80B A3b Thinking Maas", "model_vendor": "alibaba", "model_version": "latest" }, "xai/grok-2": { "display_name": "Grok 2", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-2-1212": { "display_name": "Grok 2", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-2-latest": { "display_name": "Grok 2", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-2-vision": { "display_name": "Grok 2 Vision", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-2-vision-1212": { "display_name": "Grok 2 Vision", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-2-vision-latest": { "display_name": "Grok 2 Vision", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-beta": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-fast-beta": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-fast-latest": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-latest": { "display_name": "Grok 3", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-mini": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-mini-beta": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-mini-fast": { "display_name": "Grok 3 
Mini", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-mini-fast-beta": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-mini-fast-latest": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-3-mini-latest": { "display_name": "Grok 3 Mini", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-0709": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-1-fast": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-1-fast-non-reasoning": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-1-fast-non-reasoning-latest": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-1-fast-reasoning": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-1-fast-reasoning-latest": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-fast-non-reasoning": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-fast-reasoning": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-4-latest": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-beta": { "display_name": "Grok Beta", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-code-fast": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-code-fast-1": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-code-fast-1-0825": { "display_name": "Grok", "model_vendor": "xai", "model_version": "latest" }, "xai/grok-vision-beta": { "display_name": "Grok Vision", "model_vendor": "xai", "model_version": "latest" } } 
================================================ FILE: backend/onyx/llm/model_name_parser.py ================================================ """ LiteLLM Model Name Parser Parses LiteLLM model strings and returns structured metadata for UI display. All metadata comes from litellm's model_cost dictionary. Until this upstream patch to LiteLLM is merged (https://github.com/BerriAI/litellm/pull/17330), we use the model_metadata_enrichments.json to add these fields at server startup. Enrichment fields: - display_name: Human-friendly name (e.g., "Claude 3.5 Sonnet") - model_vendor: The company that made the model (anthropic, openai, meta, etc.) - model_version: Version string (e.g., "20241022-v2:0", "v1:0") The parser only extracts provider and region from the model key - everything else comes from enrichment. """ import re from functools import lru_cache from pydantic import BaseModel from onyx.llm.constants import AGGREGATOR_PROVIDERS from onyx.llm.constants import HYPHENATED_MODEL_NAMES from onyx.llm.constants import LlmProviderNames from onyx.llm.constants import MODEL_PREFIX_TO_VENDOR from onyx.llm.constants import PROVIDER_DISPLAY_NAMES from onyx.llm.constants import VENDOR_BRAND_NAMES class ParsedModelName(BaseModel): """Structured representation of a parsed LiteLLM model name.""" raw_name: str # Original: "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0" provider: str # "bedrock", "azure", "openai", etc. (the API route) vendor: str | None = None # From enrichment: "anthropic", "openai", "meta", etc. version: str | None = None # From enrichment: "20241022-v2:0", "v1:0", etc. 
region: str | None = None # Extracted: "us", "eu", or None display_name: str # From enrichment: "Claude 3.5 Sonnet" provider_display_name: str # Generated: "Claude (Bedrock - Anthropic)" def _get_model_info(model_key: str) -> dict: """Get model info from litellm.model_cost.""" from onyx.llm.litellm_singleton import litellm # Try exact key first info = litellm.model_cost.get(model_key) if info: return info # Try without provider prefix (e.g., "bedrock/anthropic.claude-..." -> "anthropic.claude-...") if "/" in model_key: return litellm.model_cost.get(model_key.split("/", 1)[-1], {}) return {} def _extract_provider(model_key: str) -> str: """Extract provider from model key prefix.""" from onyx.llm.litellm_singleton import litellm if "/" in model_key: return model_key.split("/")[0] # No prefix - try to get from litellm.model_cost info = litellm.model_cost.get(model_key, {}) litellm_provider = info.get("litellm_provider", "") if litellm_provider: # Normalize vertex_ai variants if litellm_provider.startswith(LlmProviderNames.VERTEX_AI): return LlmProviderNames.VERTEX_AI return litellm_provider return "unknown" def _extract_region(model_key: str) -> str | None: """Extract region from model key (e.g., us., eu., apac. prefix).""" base = model_key.split("/")[-1].lower() for prefix in ["us.", "eu.", "apac.", "global.", "us-gov."]: if base.startswith(prefix): return prefix.rstrip(".") return None def _format_name(name: str | None) -> str: """Format provider or vendor name with proper capitalization.""" if not name: return "Unknown" return PROVIDER_DISPLAY_NAMES.get(name.lower(), name.replace("_", " ").title()) def _infer_vendor_from_model_name(model_name: str) -> str | None: """ Infer vendor from model name patterns when enrichment data is missing. Uses MODEL_PREFIX_TO_VENDOR mapping to match model name prefixes. Returns lowercase vendor name for consistency with enrichment data. 
Examples: "gemini-3-flash-preview" → "google" "claude-3-5-sonnet" → "anthropic" "llama-3.1-70b" → "meta" """ try: # Get the base model name (remove provider prefix if present) base_name = model_name.split("/")[-1].lower() # Try to match against known prefixes (sorted by length to match longest first) for prefix in sorted(MODEL_PREFIX_TO_VENDOR.keys(), key=len, reverse=True): if base_name.startswith(prefix): return MODEL_PREFIX_TO_VENDOR[prefix] except Exception: pass return None def _generate_display_name_from_model(model_name: str) -> str: """ Generate a human-friendly display name from a model identifier. Used as fallback when the model is not in enrichment data. Cleans up the raw model name by removing provider prefixes and formatting version numbers nicely. Examples: "vertex_ai/gemini-3-flash-preview" → "Gemini 3 Flash Preview" "gemini-2.5-pro-exp-03-25" → "Gemini 2.5 Pro" "claude-3-5-sonnet-20241022" → "Claude 3.5 Sonnet" "gpt-oss:120b" → "GPT-OSS 120B" (hyphenated exception) """ try: # Remove provider prefix if present base_name = model_name.split("/")[-1] # Remove tag suffix (e.g., :14b, :latest) - handle separately size_suffix = "" if ":" in base_name: base_name, tag = base_name.rsplit(":", 1) # Keep size tags like "14b", "70b", "120b" if re.match(r"^\d+[bBmM]$", tag): size_suffix = f" {tag.upper()}" # Check if this is a hyphenated model that should keep its format base_name_lower = base_name.lower() for hyphenated in HYPHENATED_MODEL_NAMES: if base_name_lower.startswith(hyphenated): # Keep the hyphenated prefix, uppercase it return hyphenated.upper() + size_suffix # Remove common suffixes: date stamps, version numbers cleaned = base_name # Remove date stamps like -20241022, @20250219, -2024-08-06 cleaned = re.sub(r"[-@]\d{4}-?\d{2}-?\d{2}", "", cleaned) # Remove experimental/preview date suffixes like -exp-03-25 cleaned = re.sub(r"-exp-\d{2}-\d{2}", "", cleaned) # Remove version suffixes like -v1, -v2 cleaned = re.sub(r"-v\d+$", "", cleaned) # Convert 
separators to spaces cleaned = cleaned.replace("-", " ").replace("_", " ") # Clean up version numbers: "3 5" → "3.5", "2 5" → "2.5" # But only for single digits that look like version numbers cleaned = re.sub(r"(\d) (\d)(?!\d)", r"\1.\2", cleaned) # Title case each word, preserving version numbers words = cleaned.split() result_words = [] for word in words: if word.isdigit() or re.match(r"^\d+\.?\d*$", word): # Keep numbers as-is result_words.append(word) elif word.lower() in ("pro", "lite", "mini", "flash", "preview", "ultra"): # Common suffixes get title case result_words.append(word.title()) else: # Title case other words result_words.append(word.title()) return " ".join(result_words) + size_suffix except Exception: return model_name def _generate_provider_display_name(provider: str, vendor: str | None) -> str: """ Generate provider display name with model brand and vendor info. Examples: - Direct OpenAI: "GPT (OpenAI)" - Bedrock via Anthropic: "Claude (Bedrock - Anthropic)" - Vertex AI via Google: "Gemini (Vertex AI - Google)" """ provider_nice = _format_name(provider) vendor_nice = _format_name(vendor) if vendor else None brand = VENDOR_BRAND_NAMES.get(vendor.lower()) if vendor else None # For aggregator providers, show: Brand (Provider - Vendor) if provider.lower() in AGGREGATOR_PROVIDERS: if brand and vendor_nice: return f"{brand} ({provider_nice} - {vendor_nice})" elif vendor_nice: return f"{provider_nice} - {vendor_nice}" return provider_nice # For direct providers, show: Brand (Provider) if brand: return f"{brand} ({provider_nice})" return provider_nice @lru_cache(maxsize=1024) def parse_litellm_model_name(raw_name: str) -> ParsedModelName: """ Parse a LiteLLM model string into structured data. Metadata comes from enrichment when available, with fallback logic for models not in the enrichment data. 
Args: raw_name: The LiteLLM model string Returns: ParsedModelName with all components from enrichment or fallback """ model_info = _get_model_info(raw_name) # Extract from key (not in enrichment) provider = _extract_provider(raw_name) region = _extract_region(raw_name) # Get from enrichment, with fallbacks for unenriched models vendor = model_info.get("model_vendor") or _infer_vendor_from_model_name(raw_name) version = model_info.get("model_version") display_name = model_info.get("display_name") or _generate_display_name_from_model( raw_name ) # Generate provider display name provider_display_name = _generate_provider_display_name(provider, vendor) return ParsedModelName( raw_name=raw_name, provider=provider, vendor=vendor, version=version, region=region, display_name=display_name, provider_display_name=provider_display_name, ) ================================================ FILE: backend/onyx/llm/model_response.py ================================================ from __future__ import annotations from typing import Any from typing import List from typing import TYPE_CHECKING from pydantic import BaseModel from pydantic import Field class FunctionCall(BaseModel): arguments: str | None = None name: str | None = None class ChatCompletionMessageToolCall(BaseModel): id: str type: str = "function" function: FunctionCall class ChatCompletionDeltaToolCall(BaseModel): id: str | None = None index: int = 0 type: str = "function" function: FunctionCall | None = None class Delta(BaseModel): content: str | None = None reasoning_content: str | None = None tool_calls: List[ChatCompletionDeltaToolCall] = Field(default_factory=list) class StreamingChoice(BaseModel): finish_reason: str | None = None index: int = 0 delta: Delta = Field(default_factory=Delta) class Usage(BaseModel): completion_tokens: int prompt_tokens: int total_tokens: int cache_creation_input_tokens: int cache_read_input_tokens: int class ModelResponseStream(BaseModel): id: str created: str choice: StreamingChoice 
usage: Usage | None = None if TYPE_CHECKING: from litellm.types.utils import ModelResponseStream as LiteLLMModelResponseStream class Message(BaseModel): content: str | None = None role: str = "assistant" tool_calls: List[ChatCompletionMessageToolCall] | None = None reasoning_content: str | None = None class Choice(BaseModel): finish_reason: str | None = None index: int = 0 message: Message = Field(default_factory=Message) class ModelResponse(BaseModel): id: str created: str choice: Choice usage: Usage | None = None if TYPE_CHECKING: from litellm.types.utils import ( ModelResponse as LiteLLMModelResponse, ModelResponseStream as LiteLLMModelResponseStream, ) def _parse_function_call( function_payload: dict[str, Any] | None, ) -> FunctionCall | None: """Parse a function call payload into a FunctionCall object.""" if not function_payload or not isinstance(function_payload, dict): return None return FunctionCall( arguments=function_payload.get("arguments"), name=function_payload.get("name"), ) def _parse_delta_tool_calls( tool_calls: list[dict[str, Any]] | None, ) -> list[ChatCompletionDeltaToolCall]: """Parse tool calls for streaming responses (delta format).""" if not tool_calls: return [] parsed_tool_calls: list[ChatCompletionDeltaToolCall] = [] for tool_call in tool_calls: parsed_tool_calls.append( ChatCompletionDeltaToolCall( id=tool_call.get("id"), index=tool_call.get("index", 0), type=tool_call.get("type", "function"), function=_parse_function_call(tool_call.get("function")), ) ) return parsed_tool_calls def _parse_message_tool_calls( tool_calls: list[dict[str, Any]] | None, ) -> list[ChatCompletionMessageToolCall]: """Parse tool calls for non-streaming responses (message format).""" if not tool_calls: return [] parsed_tool_calls: list[ChatCompletionMessageToolCall] = [] for tool_call in tool_calls: function_call = _parse_function_call(tool_call.get("function")) if not function_call: continue parsed_tool_calls.append( ChatCompletionMessageToolCall( 
id=tool_call.get("id", ""), type=tool_call.get("type", "function"), function=function_call, ) ) return parsed_tool_calls def _validate_and_extract_base_fields( response_data: dict[str, Any], error_prefix: str ) -> tuple[str, str, dict[str, Any]]: """ Validate and extract common fields (id, created, first choice) from a LiteLLM response. Returns: Tuple of (id, created, choice_data) """ response_id = response_data.get("id") created = response_data.get("created") if response_id is None or created is None: raise ValueError(f"{error_prefix} must include 'id' and 'created'.") choices: list[dict[str, Any]] = response_data.get("choices") or [] if not choices: raise ValueError(f"{error_prefix} must include at least one choice.") return str(response_id), str(created), choices[0] or {} def _usage_from_usage_data(usage_data: dict[str, Any]) -> Usage: # NOTE: sometimes the usage data dictionary has these keys and the values are None, # hence the "or" fallbacks instead of .get() default values (which only apply when a key is missing) return Usage( completion_tokens=usage_data.get("completion_tokens") or 0, prompt_tokens=usage_data.get("prompt_tokens") or 0, total_tokens=usage_data.get("total_tokens") or 0, cache_creation_input_tokens=usage_data.get("cache_creation_input_tokens") or 0, cache_read_input_tokens=( usage_data.get("cache_read_input_tokens") or (usage_data.get("prompt_tokens_details") or {}).get("cached_tokens") or 0 ), ) def from_litellm_model_response_stream( response: "LiteLLMModelResponseStream", ) -> ModelResponseStream: """ Convert a LiteLLM ModelResponseStream into the simplified Onyx representation.
""" response_data = response.model_dump() response_id, created, choice_data = _validate_and_extract_base_fields( response_data, "LiteLLM response stream" ) delta_data: dict[str, Any] = choice_data.get("delta") or {} parsed_delta = Delta( content=delta_data.get("content"), reasoning_content=delta_data.get("reasoning_content"), tool_calls=_parse_delta_tool_calls(delta_data.get("tool_calls")), ) streaming_choice = StreamingChoice( finish_reason=choice_data.get("finish_reason"), index=choice_data.get("index", 0), delta=parsed_delta, ) usage_data = response_data.get("usage") return ModelResponseStream( id=response_id, created=created, choice=streaming_choice, usage=(_usage_from_usage_data(usage_data) if usage_data else None), ) def from_litellm_model_response( response: "LiteLLMModelResponse", ) -> ModelResponse: """ Convert a LiteLLM ModelResponse into the simplified Onyx representation. """ response_data = response.model_dump() response_id, created, choice_data = _validate_and_extract_base_fields( response_data, "LiteLLM response" ) message_data: dict[str, Any] = choice_data.get("message") or {} parsed_tool_calls = _parse_message_tool_calls(message_data.get("tool_calls")) message = Message( content=message_data.get("content"), role=message_data.get("role", "assistant"), tool_calls=parsed_tool_calls if parsed_tool_calls else None, reasoning_content=message_data.get("reasoning_content"), ) choice = Choice( finish_reason=choice_data.get("finish_reason"), index=choice_data.get("index", 0), message=message, ) usage_data = response_data.get("usage") return ModelResponse( id=response_id, created=created, choice=choice, usage=(_usage_from_usage_data(usage_data) if usage_data else None), ) ================================================ FILE: backend/onyx/llm/models.py ================================================ from enum import Enum from typing import Literal from pydantic import BaseModel class ToolChoiceOptions(str, Enum): REQUIRED = "required" AUTO = "auto" NONE = 
"none" class ReasoningEffort(str, Enum): """Reasoning effort levels for models that support extended thinking. Different providers map these values differently: - OpenAI: Uses "low", "medium", "high" directly for reasoning_effort. Newer GPT-5 series models also accept "none", which behaves like "minimal" - Claude: Uses budget_tokens with different values for each level - Gemini: Uses "none", "low", "medium", "high" for thinking_budget (via litellm mapping) """ AUTO = "auto" OFF = "off" LOW = "low" MEDIUM = "medium" HIGH = "high" # OpenAI reasoning effort mapping # Note: OpenAI API does not support "auto" - valid values are: none, minimal, low, medium, high, xhigh OPENAI_REASONING_EFFORT: dict[ReasoningEffort, str] = { ReasoningEffort.AUTO: "medium", # Default to medium when auto is requested ReasoningEffort.OFF: "none", ReasoningEffort.LOW: "low", ReasoningEffort.MEDIUM: "medium", ReasoningEffort.HIGH: "high", } # Anthropic reasoning effort to budget tokens mapping # Loosely based on LiteLLM's budgets, but defined here so that a LiteLLM version bump cannot silently change them. ANTHROPIC_REASONING_EFFORT_BUDGET: dict[ReasoningEffort, int] = { ReasoningEffort.AUTO: 2048, ReasoningEffort.LOW: 1024, ReasoningEffort.MEDIUM: 2048, ReasoningEffort.HIGH: 4096, } # Content part structures for multimodal messages # The classes in this file mirror the OpenAI Chat Completions message types and work well with routers like LiteLLM class TextContentPart(BaseModel): type: Literal["text"] = "text" text: str # Some providers (e.g. Anthropic/Gemini) support prompt caching controls on content blocks.
cache_control: dict | None = None class ImageUrlDetail(BaseModel): url: str detail: Literal["auto", "low", "high"] | None = None class ImageContentPart(BaseModel): type: Literal["image_url"] = "image_url" image_url: ImageUrlDetail ContentPart = TextContentPart | ImageContentPart # Tool call structures class FunctionCall(BaseModel): name: str arguments: str class ToolCall(BaseModel): type: Literal["function"] = "function" id: str function: FunctionCall # Message types # Base class for all cacheable messages class CacheableMessage(BaseModel): # Some providers support prompt caching controls at the message level (passed through via LiteLLM). cache_control: dict | None = None class SystemMessage(CacheableMessage): role: Literal["system"] = "system" content: str class UserMessage(CacheableMessage): role: Literal["user"] = "user" content: str | list[ContentPart] class AssistantMessage(CacheableMessage): role: Literal["assistant"] = "assistant" content: str | None = None tool_calls: list[ToolCall] | None = None class ToolMessage(CacheableMessage): role: Literal["tool"] = "tool" content: str tool_call_id: str # Union type for all OpenAI Chat Completions messages ChatCompletionMessage = SystemMessage | UserMessage | AssistantMessage | ToolMessage # Allows for passing in a string directly. This is provided for convenience and is wrapped as a UserMessage. 
LanguageModelInput = list[ChatCompletionMessage] | ChatCompletionMessage ================================================ FILE: backend/onyx/llm/multi_llm.py ================================================ import os import threading from collections.abc import Iterator from contextlib import contextmanager from contextlib import nullcontext from typing import Any from typing import cast from typing import TYPE_CHECKING from typing import Union from onyx.configs.app_configs import MOCK_LLM_RESPONSE from onyx.configs.chat_configs import LLM_SOCKET_READ_TIMEOUT from onyx.configs.model_configs import GEN_AI_TEMPERATURE from onyx.configs.model_configs import LITELLM_EXTRA_BODY from onyx.llm.constants import LlmProviderNames from onyx.llm.cost import calculate_llm_cost_cents from onyx.llm.interfaces import LanguageModelInput from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMConfig from onyx.llm.interfaces import LLMUserIdentity from onyx.llm.interfaces import ReasoningEffort from onyx.llm.interfaces import ToolChoiceOptions from onyx.llm.model_response import ModelResponse from onyx.llm.model_response import ModelResponseStream from onyx.llm.model_response import Usage from onyx.llm.models import ANTHROPIC_REASONING_EFFORT_BUDGET from onyx.llm.models import OPENAI_REASONING_EFFORT from onyx.llm.request_context import get_llm_mock_response from onyx.llm.utils import build_litellm_passthrough_kwargs from onyx.llm.utils import is_true_openai_model from onyx.llm.utils import model_is_reasoning_model from onyx.llm.well_known_providers.constants import AWS_ACCESS_KEY_ID_KWARG from onyx.llm.well_known_providers.constants import ( AWS_ACCESS_KEY_ID_KWARG_ENV_VAR_FORMAT, ) from onyx.llm.well_known_providers.constants import ( AWS_BEARER_TOKEN_BEDROCK_KWARG_ENV_VAR_FORMAT, ) from onyx.llm.well_known_providers.constants import AWS_REGION_NAME_KWARG from onyx.llm.well_known_providers.constants import AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT from 
onyx.llm.well_known_providers.constants import AWS_SECRET_ACCESS_KEY_KWARG from onyx.llm.well_known_providers.constants import ( AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT, ) from onyx.llm.well_known_providers.constants import LM_STUDIO_API_KEY_CONFIG_KEY from onyx.llm.well_known_providers.constants import OLLAMA_API_KEY_CONFIG_KEY from onyx.llm.well_known_providers.constants import VERTEX_CREDENTIALS_FILE_KWARG from onyx.llm.well_known_providers.constants import ( VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT, ) from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG from onyx.utils.encryption import mask_string from onyx.utils.logger import setup_logger logger = setup_logger() _env_lock = threading.Lock() if TYPE_CHECKING: from litellm import CustomStreamWrapper from litellm import HTTPHandler _LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt" LEGACY_MAX_TOKENS_KWARG = "max_tokens" STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens" _VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG = ( "claude-opus-4-5", "claude-opus-4-6", ) class LLMTimeoutError(Exception): """ Exception raised when an LLM call times out. """ class LLMRateLimitError(Exception): """ Exception raised when an LLM call is rate limited. """ def _prompt_to_dicts(prompt: LanguageModelInput) -> list[dict[str, Any]]: """Convert Pydantic message models to dictionaries for LiteLLM. LiteLLM expects messages to be dictionaries (with .get() method), not Pydantic models. This function serializes the messages. """ if isinstance(prompt, list): return [msg.model_dump(exclude_none=True) for msg in prompt] return [prompt.model_dump(exclude_none=True)] def _normalize_content(raw: Any) -> str: """Normalize a message content field to a plain string. Content can be a string, None, or a list of content-block dicts (e.g. [{"type": "text", "text": "..."}]). 
""" if raw is None: return "" if isinstance(raw, str): return raw if isinstance(raw, list): return "\n".join( block.get("text", "") if isinstance(block, dict) else str(block) for block in raw ) return str(raw) def _strip_tool_content_from_messages( messages: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Convert tool-related messages to plain text. Bedrock's Converse API requires toolConfig when messages contain toolUse/toolResult content blocks. When no tools are provided for the current request, we must convert any tool-related history into plain text to avoid the "toolConfig field must be defined" error. This is the same approach used by _OllamaHistoryMessageFormatter. """ result: list[dict[str, Any]] = [] for msg in messages: role = msg.get("role") tool_calls = msg.get("tool_calls") if role == "assistant" and tool_calls: # Convert structured tool calls to text representation tool_call_lines = [] for tc in tool_calls: func = tc.get("function", {}) name = func.get("name", "unknown") args = func.get("arguments", "{}") tc_id = tc.get("id", "") tool_call_lines.append( f"[Tool Call] name={name} id={tc_id} args={args}" ) existing_content = _normalize_content(msg.get("content")) parts = ( [existing_content] + tool_call_lines if existing_content else tool_call_lines ) new_msg = { "role": "assistant", "content": "\n".join(parts), } result.append(new_msg) elif role == "tool": # Convert tool response to user message with text content tool_call_id = msg.get("tool_call_id", "") content = _normalize_content(msg.get("content")) tool_result_text = f"[Tool Result] id={tool_call_id}\n{content}" # Merge into previous user message if it is also a converted # tool result to avoid consecutive user messages (Bedrock requires # strict user/assistant alternation). 
if ( result and result[-1]["role"] == "user" and "[Tool Result]" in result[-1].get("content", "") ): result[-1]["content"] += "\n\n" + tool_result_text else: result.append({"role": "user", "content": tool_result_text}) else: result.append(msg) return result def _fix_tool_user_message_ordering( messages: list[dict[str, Any]], ) -> list[dict[str, Any]]: """Insert a synthetic assistant message between tool and user messages. Some models (e.g. Mistral on Azure) require strict message ordering where a user message cannot immediately follow a tool message. This function inserts a minimal assistant message to bridge the gap. """ if len(messages) < 2: return messages result: list[dict[str, Any]] = [messages[0]] for msg in messages[1:]: prev_role = result[-1].get("role") curr_role = msg.get("role") if prev_role == "tool" and curr_role == "user": result.append({"role": "assistant", "content": "Noted. Continuing."}) result.append(msg) return result def _messages_contain_tool_content(messages: list[dict[str, Any]]) -> bool: """Check if any messages contain tool-related content blocks.""" for msg in messages: if msg.get("role") == "tool": return True if msg.get("role") == "assistant" and msg.get("tool_calls"): return True return False def _prompt_contains_tool_call_history(prompt: LanguageModelInput) -> bool: """Check if the prompt contains any assistant messages with tool_calls. When Anthropic's extended thinking is enabled, the API requires every assistant message to start with a thinking block before any tool_use blocks. Since we don't preserve thinking_blocks (they carry cryptographic signatures that can't be reconstructed), we must skip the thinking param whenever history contains prior tool-calling turns. 
""" from onyx.llm.models import AssistantMessage msgs = prompt if isinstance(prompt, list) else [prompt] return any(isinstance(msg, AssistantMessage) and msg.tool_calls for msg in msgs) def _is_vertex_model_rejecting_output_config(model_name: str) -> bool: normalized_model_name = model_name.lower() return any( blocked_model in normalized_model_name for blocked_model in _VERTEX_ANTHROPIC_MODELS_REJECTING_OUTPUT_CONFIG ) class LitellmLLM(LLM): """Uses Litellm library to allow easy configuration to use a multitude of LLMs See https://python.langchain.com/docs/integrations/chat/litellm""" def __init__( self, api_key: str | None, model_provider: str, model_name: str, max_input_tokens: int, timeout: int | None = None, api_base: str | None = None, api_version: str | None = None, deployment_name: str | None = None, custom_llm_provider: str | None = None, temperature: float | None = None, custom_config: dict[str, str] | None = None, extra_headers: dict[str, str] | None = None, extra_body: dict | None = LITELLM_EXTRA_BODY, model_kwargs: dict[str, Any] | None = None, ): # Timeout in seconds for each socket read operation (i.e., max time between # receiving data chunks/tokens). This is NOT a total request timeout - a # request can run indefinitely as long as data keeps arriving within this # window. If the LLM pauses for longer than this timeout between chunks, # a ReadTimeout is raised. 
self._timeout = timeout if timeout is None: self._timeout = LLM_SOCKET_READ_TIMEOUT self._temperature = GEN_AI_TEMPERATURE if temperature is None else temperature self._model_provider = model_provider self._model_version = model_name self._api_key = api_key self._deployment_name = deployment_name self._api_base = api_base self._api_version = api_version self._custom_llm_provider = custom_llm_provider self._max_input_tokens = max_input_tokens self._custom_config = custom_config # Create a dictionary for model-specific arguments if it's None model_kwargs = model_kwargs or {} if custom_config: for k, v in custom_config.items(): if model_provider == LlmProviderNames.VERTEX_AI: if k == VERTEX_CREDENTIALS_FILE_KWARG: model_kwargs[k] = v elif k == VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT: model_kwargs[VERTEX_CREDENTIALS_FILE_KWARG] = v elif k == VERTEX_LOCATION_KWARG: model_kwargs[k] = v elif model_provider == LlmProviderNames.OLLAMA_CHAT: if k == OLLAMA_API_KEY_CONFIG_KEY: model_kwargs["api_key"] = v elif model_provider == LlmProviderNames.LM_STUDIO: if k == LM_STUDIO_API_KEY_CONFIG_KEY: model_kwargs["api_key"] = v elif model_provider == LlmProviderNames.BEDROCK: if k == AWS_REGION_NAME_KWARG: model_kwargs[k] = v elif k == AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT: model_kwargs[AWS_REGION_NAME_KWARG] = v elif k == AWS_BEARER_TOKEN_BEDROCK_KWARG_ENV_VAR_FORMAT: model_kwargs["api_key"] = v elif k == AWS_ACCESS_KEY_ID_KWARG: model_kwargs[k] = v elif k == AWS_ACCESS_KEY_ID_KWARG_ENV_VAR_FORMAT: model_kwargs[AWS_ACCESS_KEY_ID_KWARG] = v elif k == AWS_SECRET_ACCESS_KEY_KWARG: model_kwargs[k] = v elif k == AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT: model_kwargs[AWS_SECRET_ACCESS_KEY_KWARG] = v # LM Studio: LiteLLM defaults to "fake-api-key" when no key is provided, # which LM Studio rejects. Ensure we always pass an explicit key (or empty # string) to prevent LiteLLM from injecting its fake default. 
if model_provider == LlmProviderNames.LM_STUDIO: model_kwargs.setdefault("api_key", "") # Users provide the server root (e.g. http://localhost:1234) but LiteLLM # needs /v1 for OpenAI-compatible calls. if self._api_base is not None: base = self._api_base.rstrip("/") self._api_base = base if base.endswith("/v1") else f"{base}/v1" model_kwargs["api_base"] = self._api_base # Default vertex_location to "global" if not provided for Vertex AI # Latest gemini models are only available through the global region if ( model_provider == LlmProviderNames.VERTEX_AI and VERTEX_LOCATION_KWARG not in model_kwargs ): model_kwargs[VERTEX_LOCATION_KWARG] = "global" # Bifrost: OpenAI-compatible proxy that expects model names in # provider/model format (e.g. "anthropic/claude-sonnet-4-6"). # We route through LiteLLM's openai provider with the Bifrost base URL, # and ensure /v1 is appended. if model_provider == LlmProviderNames.BIFROST: self._custom_llm_provider = "openai" if self._api_base is not None: base = self._api_base.rstrip("/") self._api_base = base if base.endswith("/v1") else f"{base}/v1" model_kwargs["api_base"] = self._api_base # This is needed for Ollama to do proper function calling if model_provider == LlmProviderNames.OLLAMA_CHAT and api_base is not None: model_kwargs["api_base"] = api_base if extra_headers: model_kwargs.update({"extra_headers": extra_headers}) if extra_body: model_kwargs.update({"extra_body": extra_body}) self._model_kwargs = model_kwargs def _safe_model_config(self) -> dict: dump = self.config.model_dump() dump["api_key"] = mask_string(dump.get("api_key") or "") custom_config = dump.get("custom_config") if isinstance(custom_config, dict): # Mask sensitive values in custom_config masked_config = {} for k, v in custom_config.items(): masked_config[k] = mask_string(v) if v else v dump["custom_config"] = masked_config return dump def _track_llm_cost(self, usage: Usage) -> None: """ Track LLM usage cost for Onyx-managed API keys. 
This is called after every LLM call completes (streaming or non-streaming). Cost is only tracked if: 1. Usage limits are enabled for this deployment 2. The API key is one of Onyx's managed default keys """ from onyx.server.usage_limits import is_usage_limits_enabled if not is_usage_limits_enabled(): return from onyx.server.usage_limits import is_onyx_managed_api_key if not is_onyx_managed_api_key(self._api_key): return # Import here to avoid circular imports from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.usage import increment_usage from onyx.db.usage import UsageType # Calculate cost in cents cost_cents = calculate_llm_cost_cents( model_name=self._model_version, prompt_tokens=usage.prompt_tokens, completion_tokens=usage.completion_tokens, ) if cost_cents <= 0: return try: with get_session_with_current_tenant() as db_session: increment_usage(db_session, UsageType.LLM_COST, cost_cents) db_session.commit() except Exception as e: # Log but don't fail the LLM call if tracking fails logger.warning(f"Failed to track LLM cost: {e}") def _completion( self, prompt: LanguageModelInput, tools: list[dict] | None, tool_choice: ToolChoiceOptions | None, stream: bool, parallel_tool_calls: bool, reasoning_effort: ReasoningEffort = ReasoningEffort.AUTO, structured_response_format: dict | None = None, timeout_override: int | None = None, max_tokens: int | None = None, user_identity: LLMUserIdentity | None = None, client: "HTTPHandler | None" = None, ) -> Union["ModelResponse", "CustomStreamWrapper"]: # Lazy loading to avoid memory bloat for non-inference flows from onyx.llm.litellm_singleton import litellm from litellm.exceptions import Timeout, RateLimitError ######################### # Flags that modify the final arguments ######################### is_claude_model = "claude" in self.config.model_name.lower() is_reasoning = model_is_reasoning_model( self.config.model_name, self.config.model_provider ) # All OpenAI models will use responses API 
for consistency # Responses API is needed to get reasoning packets from OpenAI models is_openai_model = is_true_openai_model( self.config.model_provider, self.config.model_name ) is_ollama = self._model_provider == LlmProviderNames.OLLAMA_CHAT is_mistral = self._model_provider == LlmProviderNames.MISTRAL is_vertex_ai = self._model_provider == LlmProviderNames.VERTEX_AI # Some Vertex Anthropic models reject output_config. # Keep this guard until LiteLLM/Vertex accept the field for these models. is_vertex_model_rejecting_output_config = ( is_vertex_ai and _is_vertex_model_rejecting_output_config(self.config.model_name) ) ######################### # Build arguments ######################### # Optional kwargs - should only be passed to LiteLLM under certain conditions optional_kwargs: dict[str, Any] = {} # Model name is_bifrost = self._model_provider == LlmProviderNames.BIFROST model_provider = ( f"{self.config.model_provider}/responses" if is_openai_model # Uses litellm's completions -> responses bridge else self.config.model_provider ) if is_bifrost: # Bifrost expects model names in provider/model format # (e.g. "anthropic/claude-sonnet-4-6") sent directly to its # OpenAI-compatible endpoint. We use custom_llm_provider="openai" # so LiteLLM doesn't try to route based on the provider prefix. 
model = self.config.deployment_name or self.config.model_name else: model = f"{model_provider}/{self.config.deployment_name or self.config.model_name}" # Tool choice if is_claude_model and tool_choice == ToolChoiceOptions.REQUIRED: # Claude models do not use reasoning when tool_choice is required, # so let them choose tools automatically to keep reasoning enabled tool_choice = ToolChoiceOptions.AUTO # If no tools are provided, tool_choice should be None if not tools: tool_choice = None # Temperature temperature = 1 if is_reasoning else self._temperature if stream and not is_vertex_model_rejecting_output_config: optional_kwargs["stream_options"] = {"include_usage": True} # Note: LiteLLM has a reasoning_effort parameter, but it is unreliable and does not work for any # of the major providers. Leaving it unset means reasoning is OFF. if ( is_reasoning # Surprisingly, leaving this parameter unset is not equivalent to AUTO; it is equivalent to OFF and reasoning_effort != ReasoningEffort.OFF and not is_vertex_model_rejecting_output_config ): if is_openai_model: # OpenAI API does not accept reasoning params for GPT 5 chat models # (neither reasoning nor reasoning_effort are accepted) # even though they are reasoning models (bug in OpenAI) if "-chat" not in model: optional_kwargs["reasoning"] = { "effort": OPENAI_REASONING_EFFORT[reasoning_effort], "summary": "auto", } elif is_claude_model: budget_tokens: int | None = ANTHROPIC_REASONING_EFFORT_BUDGET.get( reasoning_effort ) # Anthropic requires every assistant message with tool_use # blocks to start with a thinking block that carries a # cryptographic signature. We don't preserve those blocks # across turns, so skip thinking when the history already # contains tool-calling assistant messages. LiteLLM's # modify_params workaround doesn't cover all providers # (notably Bedrock).
can_enable_thinking = ( budget_tokens is not None and not _prompt_contains_tool_call_history(prompt) ) if can_enable_thinking: assert budget_tokens is not None # mypy if max_tokens is not None: # When thinking is enabled, Anthropic requires max_tokens to be strictly greater than # budget_tokens, and budget_tokens must be at least 1024. # Overwriting a developer-set max_tokens is not ideal, but it is the best we can do for now: # it is better to give the LLM room for more reasoning tokens, even if that leaves little room # for the tool call itself, than to shrink the reasoning budget. max_tokens = max(budget_tokens + 1, max_tokens) optional_kwargs["thinking"] = { "type": "enabled", "budget_tokens": budget_tokens, } # LiteLLM just does some mapping like this anyway but is incomplete for Anthropic optional_kwargs.pop("reasoning_effort", None) else: # Hope for the best from LiteLLM if reasoning_effort in [ ReasoningEffort.LOW, ReasoningEffort.MEDIUM, ReasoningEffort.HIGH, ]: optional_kwargs["reasoning_effort"] = reasoning_effort.value else: optional_kwargs["reasoning_effort"] = ReasoningEffort.MEDIUM.value if tools: # OpenAI will error if parallel_tool_calls is True and tools are not specified optional_kwargs["parallel_tool_calls"] = parallel_tool_calls if structured_response_format: optional_kwargs["response_format"] = structured_response_format if not (is_claude_model or is_ollama or is_mistral) or is_bifrost: # Litellm bug: tool_choice is dropped silently if not specified here for OpenAI # However, this param breaks Anthropic and Mistral models, # so it must be conditionally included unless the request is # routed through Bifrost's OpenAI-compatible endpoint. # Additionally, tool_choice is not supported by Ollama and causes warnings if included.
# See also, https://github.com/ollama/ollama/issues/11171 optional_kwargs["allowed_openai_params"] = ["tool_choice"] # Passthrough kwargs passthrough_kwargs = build_litellm_passthrough_kwargs( model_kwargs=self._model_kwargs, user_identity=user_identity, ) try: # NOTE: must pass in None instead of empty strings otherwise litellm # can have some issues with bedrock. # NOTE: Sometimes _model_kwargs may have an "api_key" kwarg # depending on what the caller passes in for custom_config. If it # does we allow it to clobber _api_key. if "api_key" not in passthrough_kwargs: passthrough_kwargs["api_key"] = self._api_key or None # We only need to set environment variables if custom config is set env_ctx = ( temporary_env_and_lock(self._custom_config) if self._custom_config else nullcontext() ) with env_ctx: messages = _prompt_to_dicts(prompt) # Bedrock's Converse API requires toolConfig when messages # contain toolUse/toolResult content blocks. When no tools are # provided for this request but the history contains tool # content from previous turns, strip it to plain text. is_bedrock = self._model_provider in { LlmProviderNames.BEDROCK, LlmProviderNames.BEDROCK_CONVERSE, } if ( is_bedrock and not tools and _messages_contain_tool_content(messages) ): messages = _strip_tool_content_from_messages(messages) # Some models (e.g. Mistral) reject a user message # immediately after a tool message. Insert a synthetic # assistant bridge message to satisfy the ordering # constraint. Check both the provider and the deployment/ # model name to catch Mistral hosted on Azure. model_or_deployment = ( self._deployment_name or self._model_version or "" ).lower() is_mistral_model = is_mistral or "mistral" in model_or_deployment if is_mistral_model: messages = _fix_tool_user_message_ordering(messages) # Only pass tool_choice when tools are present — some providers (e.g. Fireworks) # reject requests where tool_choice is explicitly null. 
if tools and tool_choice is not None: optional_kwargs["tool_choice"] = tool_choice response = litellm.completion( mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE, model=model, base_url=self._api_base or None, api_version=self._api_version or None, custom_llm_provider=self._custom_llm_provider or None, messages=messages, tools=tools, stream=stream, temperature=temperature, timeout=timeout_override or self._timeout, max_tokens=max_tokens, client=client, **optional_kwargs, **passthrough_kwargs, ) return response except Exception as e: # for break pointing if isinstance(e, Timeout): raise LLMTimeoutError(e) elif isinstance(e, RateLimitError): raise LLMRateLimitError(e) raise e @property def config(self) -> LLMConfig: return LLMConfig( model_provider=self._model_provider, model_name=self._model_version, temperature=self._temperature, api_key=self._api_key, api_base=self._api_base, api_version=self._api_version, deployment_name=self._deployment_name, custom_config=self._custom_config, max_input_tokens=self._max_input_tokens, ) def invoke( self, prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, structured_response_format: dict | None = None, timeout_override: int | None = None, max_tokens: int | None = None, reasoning_effort: ReasoningEffort = ReasoningEffort.AUTO, user_identity: LLMUserIdentity | None = None, ) -> ModelResponse: from litellm import HTTPHandler from litellm import ModelResponse as LiteLLMModelResponse from onyx.llm.model_response import from_litellm_model_response # HTTPHandler Threading & Connection Pool Notes: # ============================================= # We create an isolated HTTPHandler ONLY for true OpenAI models (not OpenAI-compatible # providers like glm-4.7, DeepSeek, etc.). This distinction is critical: # # 1. 
WHY ONLY TRUE OPENAI MODELS: # - True OpenAI models use litellm's "responses API" path which expects HTTPHandler # - OpenAI-compatible providers (model_provider="openai" with non-OpenAI models) # use the standard completion path which expects OpenAI SDK client objects # - Passing HTTPHandler to OpenAI-compatible providers causes: # AttributeError: 'HTTPHandler' object has no attribute 'api_key' # (because _get_openai_client() calls openai_client.api_key on line ~929) # # 2. WHY ISOLATED HTTPHandler FOR OPENAI: # - Prevents "Bad file descriptor" errors when multiple threads stream concurrently # - Shared connection pools can have stale connections or abandoned streams that # corrupt the pool state for other threads # - Each request gets its own fresh httpx.Client via HTTPHandler # # 3. WHY OTHER PROVIDERS DON'T NEED THIS: # - Other providers (Anthropic, Bedrock, etc.) use litellm.module_level_client # which handles concurrency appropriately # - httpx.Client itself IS thread-safe for concurrent requests # - The issue is specific to OpenAI's responses API path and connection reuse # # 4. PITFALL - is_true_openai_model() CHECK: # - Must use is_true_openai_model() NOT just check model_provider == "openai" # - Many OpenAI-compatible providers set model_provider="openai" but are NOT true # OpenAI models (glm-4.7, DeepSeek, local proxies, etc.) # - is_true_openai_model() checks both provider AND model name patterns # # This note may not be entirely accurate as there is a lot of complexity in the LiteLLM codebase around this # and not every model path was traced thoroughly. It is also possible that in future versions of LiteLLM # they will realize that their OpenAI handling is not threadsafe. Hope they will just fix it. client = None if is_true_openai_model(self.config.model_provider, self.config.model_name): client = HTTPHandler(timeout=timeout_override or self._timeout) try: # When custom_config is set, env vars are temporarily injected # under a global lock. 
Using stream=True here means the lock is # only held during connection setup (not the full inference). # The chunks are then collected outside the lock and reassembled # into a single ModelResponse via stream_chunk_builder. from litellm import stream_chunk_builder from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper stream_response = cast( LiteLLMCustomStreamWrapper, self._completion( prompt=prompt, tools=tools, tool_choice=tool_choice, stream=True, structured_response_format=structured_response_format, timeout_override=timeout_override, max_tokens=max_tokens, parallel_tool_calls=True, reasoning_effort=reasoning_effort, user_identity=user_identity, client=client, ), ) chunks = list(stream_response) response = cast( LiteLLMModelResponse, stream_chunk_builder(chunks), ) model_response = from_litellm_model_response(response) # Track LLM cost for Onyx-managed API keys if model_response.usage: self._track_llm_cost(model_response.usage) return model_response finally: if client is not None: client.close() def stream( self, prompt: LanguageModelInput, tools: list[dict] | None = None, tool_choice: ToolChoiceOptions | None = None, structured_response_format: dict | None = None, timeout_override: int | None = None, max_tokens: int | None = None, reasoning_effort: ReasoningEffort = ReasoningEffort.AUTO, user_identity: LLMUserIdentity | None = None, ) -> Iterator[ModelResponseStream]: from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper from litellm import HTTPHandler from onyx.llm.model_response import from_litellm_model_response_stream # HTTPHandler Threading & Connection Pool Notes: # ============================================= # See invoke() method for full explanation. Key points for streaming: # # 1. SAME RESTRICTIONS APPLY: # - HTTPHandler ONLY for true OpenAI models (use is_true_openai_model()) # - OpenAI-compatible providers will fail with AttributeError on api_key # # 2. 
STREAMING-SPECIFIC CONCERNS: # - "Bad file descriptor" errors are MORE common during streaming because: # a) Streams hold connections open longer, increasing conflict window # b) Multiple concurrent streams (e.g., deep research) share the pool # c) Abandoned/interrupted streams can leave connections in bad state # # 3. ABANDONED STREAM PITFALL: # - If callers abandon this generator without fully consuming it (e.g., # early return, exception, or break), the finally block won't execute # until the generator is garbage collected # - This is acceptable because: # a) CPython's refcounting typically finalizes generators promptly # b) Each HTTPHandler has its own isolated connection pool # c) httpx has built-in connection timeouts as a fallback # - If abandoned streams become problematic, consider using contextlib # or explicit stream.close() at call sites # # 4. WHY NOT USE SHARED HTTPHandler: # - litellm's InMemoryCache (used for client caching) is NOT thread-safe # - Shared pools can have connections corrupted by other threads # - Per-request HTTPHandler eliminates cross-thread interference client = None if is_true_openai_model(self.config.model_provider, self.config.model_name): client = HTTPHandler(timeout=timeout_override or self._timeout) try: response = cast( LiteLLMCustomStreamWrapper, self._completion( prompt=prompt, tools=tools, tool_choice=tool_choice, stream=True, structured_response_format=structured_response_format, timeout_override=timeout_override, max_tokens=max_tokens, parallel_tool_calls=True, reasoning_effort=reasoning_effort, user_identity=user_identity, client=client, ), ) for chunk in response: model_response = from_litellm_model_response_stream(chunk) # Track LLM cost when usage info is available (typically in the last chunk) if model_response.usage: self._track_llm_cost(model_response.usage) yield model_response finally: if client is not None: client.close() @contextmanager def temporary_env_and_lock(env_variables: dict[str, str]) -> 
Iterator[None]: """ Temporarily sets the environment variables to the given values. Code path is locked while the environment variables are set. Then cleans up the environment and frees the lock. """ with _env_lock: logger.debug("Acquired lock in temporary_env_and_lock") # Store original values (None if key didn't exist) original_values: dict[str, str | None] = { key: os.environ.get(key) for key in env_variables } try: os.environ.update(env_variables) yield finally: for key, original_value in original_values.items(): if original_value is None: os.environ.pop(key, None) # Remove if it didn't exist before else: os.environ[key] = original_value # Restore original value logger.debug("Released lock in temporary_env_and_lock") ================================================ FILE: backend/onyx/llm/override_models.py ================================================ """Overrides sent over the wire / stored in the DB NOTE: these models are used in many places, so they have to be kept in a separate file to avoid circular imports. """ from pydantic import BaseModel class LLMOverride(BaseModel): """Per-request LLM settings that override persona defaults. All fields are optional; only the fields that differ from the persona's configured LLM need to be supplied. Used both over the wire (API requests) and for multi-model comparison, where one override is supplied per model. Attributes: model_provider: LLM provider slug (e.g. ``"openai"``, ``"anthropic"``). When ``None``, the persona's default provider is used. model_version: Specific model version string (e.g. ``"gpt-4o"``). When ``None``, the persona's default model is used. temperature: Sampling temperature in ``[0, 2]``. When ``None``, the persona's default temperature is used. display_name: Human-readable label shown in the UI for this model, e.g. ``"GPT-4 Turbo"``. Optional; falls back to ``model_version`` when not set.
""" model_provider: str | None = None model_version: str | None = None temperature: float | None = None display_name: str | None = None # This disables the "model_" protected namespace for pydantic model_config = {"protected_namespaces": ()} class PromptOverride(BaseModel): system_prompt: str | None = None task_prompt: str | None = None ================================================ FILE: backend/onyx/llm/prompt_cache/README.md ================================================ # Prompt Caching Framework A comprehensive prompt-caching mechanism for enabling cost savings across multiple LLM providers by leveraging provider-side prompt token caching. ## Overview The prompt caching framework provides a unified interface for enabling prompt caching across different LLM providers. It supports both **implicit caching** (automatic provider-side caching) and **explicit caching** (with cache control parameters). ## Features - **Provider Support**: OpenAI (implicit), Anthropic (explicit), Vertex AI (explicit) - **Flexible Input**: Supports both `str` and `Sequence[ChatCompletionMessage]` inputs - **Continuation Handling**: Smart merging of cacheable prefix and suffix messages - **Best-Effort**: Gracefully degrades if caching fails - **Tenant-Aware**: Automatic tenant isolation for multi-tenant deployments - **Configurable**: Enable/disable via environment variable ## Quick Start ### Basic Usage ```python from onyx.llm.prompt_cache import process_with_prompt_cache from onyx.llm.models import SystemMessage, UserMessage # Assume you have an LLM instance with a config property # llm = get_your_llm_instance() # Define cacheable prefix (static context) using Pydantic message models cacheable_prefix = [ SystemMessage(role="system", content="You are a helpful assistant."), UserMessage(role="user", content="Context: ...") # Static context ] # Define suffix (dynamic user input) suffix = [UserMessage(role="user", content="What is the weather?")] # Process with caching - pass 
llm_config, not the llm instance processed_prompt, cache_metadata = process_with_prompt_cache( llm_config=llm.config, cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=False, ) # Make LLM call with processed prompt response = llm.invoke(processed_prompt) ``` ### Using String Inputs ```python # Both prefix and suffix can be strings cacheable_prefix = "You are a helpful assistant. Context: ..." suffix = "What is the weather?" processed_prompt, cache_metadata = process_with_prompt_cache( llm_config=llm.config, cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=False, ) response = llm.invoke(processed_prompt) ``` ### Continuation Flag When `continuation=True`, the suffix is appended to the last message of the cacheable prefix: ```python # Without continuation (default) # Result: [system_msg, prefix_user_msg, suffix_user_msg] # With continuation=True # Result: [system_msg, prefix_user_msg + suffix_user_msg] processed_prompt, _ = process_with_prompt_cache( llm_config=llm.config, cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=True, # Merge suffix into last prefix message ) ``` **Note**: If `cacheable_prefix` is a string, it remains in its own content block even when `continuation=True`. ## Provider-Specific Behavior ### OpenAI - **Caching Type**: Implicit (automatic) - **Behavior**: No special parameters needed. Provider automatically caches prefixes >1024 tokens. - **Cache Lifetime**: Up to 1 hour - **Cost Savings**: 50% discount on cached tokens ### Anthropic - **Caching Type**: Explicit (requires `cache_control` parameter) - **Behavior**: Automatically adds `cache_control={"type": "ephemeral"}` to the **last message** of the cacheable prefix - **Cache Lifetime**: 5 minutes (default) - **Limitations**: Supports up to 4 cache breakpoints ### Vertex AI - **Caching Type**: Explicit (with `cache_control` parameter) - **Behavior**: Adds `cache_control={"type": "ephemeral"}` to **all content blocks** in cacheable messages. 
String content is converted to array format with the cache control attached. - **Cache Lifetime**: 5 minutes - **Future**: Full context caching with block number management (deferred to future PR) ## Configuration ### Environment Variables - `ENABLE_PROMPT_CACHING`: Enable/disable prompt caching (default: `true`) ```bash export ENABLE_PROMPT_CACHING=false # Disable caching ``` ## Architecture ### Core Components 1. **`processor.py`**: Main entry point (`process_with_prompt_cache`) 2. **`cache_manager.py`**: Cache metadata storage and retrieval 3. **`models.py`**: Pydantic models for cache metadata (`CacheMetadata`) 4. **`providers/`**: Provider-specific adapters 5. **`utils.py`**: Shared utility functions ### Provider Adapters Each provider has its own adapter in `providers/`: | File | Class | Description | |------|-------|-------------| | `base.py` | `PromptCacheProvider` | Abstract base class for all providers | | `openai.py` | `OpenAIPromptCacheProvider` | Implicit caching (no transformation) | | `anthropic.py` | `AnthropicPromptCacheProvider` | Explicit caching with `cache_control` on last message | | `vertex.py` | `VertexAIPromptCacheProvider` | Explicit caching with `cache_control` on all content blocks | | `noop.py` | `NoOpPromptCacheProvider` | Fallback for unsupported providers | Each adapter implements: - `supports_caching()`: Whether caching is supported - `prepare_messages_for_caching()`: Transform messages for caching - `extract_cache_metadata()`: Extract metadata from responses - `get_cache_ttl_seconds()`: Cache TTL ## Best Practices 1. **Cache Static Content**: Use cacheable prefix for system prompts, static context, and instructions that don't change between requests. 2. **Keep Dynamic Content in Suffix**: User queries, search results, and other dynamic content should be in the suffix. 3. **Monitor Cache Effectiveness**: Check logs for cache hits/misses and adjust your caching strategy accordingly. 4. 
**Provider Selection**: Different providers have different caching characteristics; choose based on your use case. ## Error Handling The framework is **best-effort**: if caching fails, it gracefully falls back to non-cached behavior: - Cache lookup failures: Logged; processing continues without caching - Provider adapter failures: Fall back to the no-op adapter - Cache storage failures: Logged; processing continues (caching is best-effort) - Invalid cache metadata: Cleared; processing continues without the cache ## Future Enhancements - **Explicit Caching for Vertex AI**: Full block number tracking and management - **Cache Analytics**: Detailed metrics on cache effectiveness and cost savings - **Advanced Strategies**: More sophisticated cache key generation and invalidation - **Distributed Caching**: Shared caches across instances ## Examples See `backend/tests/external_dependency_unit/llm/test_prompt_caching.py` for detailed integration test examples. ================================================ FILE: backend/onyx/llm/prompt_cache/__init__.py ================================================ """Prompt caching framework for LLM providers. This module provides a framework for enabling prompt caching across different LLM providers. It supports both implicit caching (automatic provider-side caching) and explicit caching (with cache metadata management).
""" from onyx.llm.prompt_cache.cache_manager import CacheManager from onyx.llm.prompt_cache.cache_manager import generate_cache_key_hash from onyx.llm.prompt_cache.models import CacheMetadata from onyx.llm.prompt_cache.processor import process_with_prompt_cache from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.providers.factory import get_provider_adapter from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider from onyx.llm.prompt_cache.utils import combine_messages_with_continuation from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform __all__ = [ "AnthropicPromptCacheProvider", "CacheManager", "CacheMetadata", "combine_messages_with_continuation", "generate_cache_key_hash", "get_provider_adapter", "NoOpPromptCacheProvider", "OpenAIPromptCacheProvider", "prepare_messages_with_cacheable_transform", "process_with_prompt_cache", "PromptCacheProvider", "VertexAIPromptCacheProvider", ] ================================================ FILE: backend/onyx/llm/prompt_cache/cache_manager.py ================================================ """Cache manager for storing and retrieving prompt cache metadata.""" import hashlib import json from datetime import datetime from datetime import timezone from onyx.configs.model_configs import PROMPT_CACHE_REDIS_TTL_MULTIPLIER from onyx.key_value_store.store import PgRedisKVStore from onyx.llm.interfaces import LanguageModelInput from onyx.llm.prompt_cache.models import CacheMetadata from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() REDIS_KEY_PREFIX = "prompt_cache:" # Cache TTL multiplier - store caches slightly longer than provider 
TTL # This allows for some clock skew and ensures we don't lose cache metadata prematurely # Value is configurable via PROMPT_CACHE_REDIS_TTL_MULTIPLIER env var (default: 1.2) CACHE_TTL_MULTIPLIER = PROMPT_CACHE_REDIS_TTL_MULTIPLIER class CacheManager: """Manages storage and retrieval of prompt cache metadata.""" def __init__(self, kv_store: PgRedisKVStore | None = None) -> None: """Initialize the cache manager. Args: kv_store: Optional key-value store. If None, creates a new PgRedisKVStore. """ self._kv_store = kv_store or PgRedisKVStore() def _build_cache_key( self, provider: str, model_name: str, cache_key_hash: str, tenant_id: str | None = None, ) -> str: """Build a Redis/PostgreSQL key for cache metadata. Args: provider: LLM provider name (e.g., "openai", "anthropic") model_name: Model name cache_key_hash: Hash of the cacheable prefix content tenant_id: Tenant ID. If None, uses current tenant from context. Returns: Cache key string """ if tenant_id is None: tenant_id = get_current_tenant_id() return f"{REDIS_KEY_PREFIX}{tenant_id}:{provider}:{model_name}:{cache_key_hash}" def store_cache_metadata( self, metadata: CacheMetadata, ) -> None: """Store cache metadata. Sets ``metadata.last_accessed`` to the current UTC time before persisting. Args: metadata: Cache metadata to store """ try: cache_key = self._build_cache_key( metadata.provider, metadata.model_name, metadata.cache_key, metadata.tenant_id, ) # Update last_accessed timestamp metadata.last_accessed = datetime.now(timezone.utc) # Serialize metadata metadata_dict = metadata.model_dump(mode="json") # Store in key-value store # Note: PgRedisKVStore doesn't support TTL directly, but Redis will # handle expiration. For PostgreSQL persistence, we rely on cleanup # based on last_accessed timestamp.
self._kv_store.store(cache_key, metadata_dict, encrypt=False) logger.debug( f"Stored cache metadata: provider={metadata.provider}, " f"model={metadata.model_name}, cache_key={metadata.cache_key[:16]}..., " f"tenant_id={metadata.tenant_id}" ) except Exception as e: # Best-effort: log and continue logger.warning(f"Failed to store cache metadata: {str(e)}") def retrieve_cache_metadata( self, provider: str, model_name: str, cache_key_hash: str, tenant_id: str | None = None, ) -> CacheMetadata | None: """Retrieve cache metadata. Args: provider: LLM provider name model_name: Model name cache_key_hash: Hash of the cacheable prefix content tenant_id: Tenant ID. If None, uses current tenant from context. Returns: CacheMetadata if found, None otherwise """ try: cache_key = self._build_cache_key( provider, model_name, cache_key_hash, tenant_id ) metadata_dict = self._kv_store.load(cache_key, refresh_cache=False) # Deserialize metadata metadata = CacheMetadata.model_validate(metadata_dict) # Update last_accessed timestamp metadata.last_accessed = datetime.now(timezone.utc) self.store_cache_metadata(metadata) logger.debug( f"Retrieved cache metadata: provider={provider}, " f"model={model_name}, cache_key={cache_key_hash[:16]}..., " f"tenant_id={tenant_id}" ) return metadata except Exception as e: # Best-effort: log and continue logger.debug(f"Cache metadata not found or error retrieving: {str(e)}") return None def delete_cache_metadata( self, provider: str, model_name: str, cache_key_hash: str, tenant_id: str | None = None, ) -> None: """Delete cache metadata. Args: provider: LLM provider name model_name: Model name cache_key_hash: Hash of the cacheable prefix content tenant_id: Tenant ID. If None, uses current tenant from context. """ try: cache_key = self._build_cache_key( provider, model_name, cache_key_hash, tenant_id ) self._kv_store.delete(cache_key) logger.debug( f"Deleted cache metadata for provider={provider}, model={model_name}, cache_key={cache_key_hash[:16]}..." 
) except Exception as e: # Best-effort: log and continue logger.warning(f"Failed to delete cache metadata: {str(e)}") def _make_json_serializable(obj: object) -> object: """Recursively convert objects to JSON-serializable types. Handles Pydantic models, dicts, lists, and other common types. """ if hasattr(obj, "model_dump"): # Pydantic v2 model return obj.model_dump(mode="json") elif hasattr(obj, "dict"): # Pydantic v1 model or similar return _make_json_serializable(obj.dict()) elif isinstance(obj, dict): return {k: _make_json_serializable(v) for k, v in obj.items()} elif isinstance(obj, (list, tuple)): return [_make_json_serializable(item) for item in obj] elif isinstance(obj, (str, int, float, bool, type(None))): return obj else: # Fallback: convert to string representation return str(obj) def generate_cache_key_hash( cacheable_prefix: LanguageModelInput, provider: str, model_name: str, tenant_id: str, ) -> str: """Generate a deterministic cache key hash from cacheable prefix. Args: cacheable_prefix: Single message or list of messages to hash provider: LLM provider name model_name: Model name tenant_id: Tenant ID Returns: SHA256 hash as hex string """ # Normalize to list for consistent hashing; _make_json_serializable handles Pydantic models messages = ( cacheable_prefix if isinstance(cacheable_prefix, list) else [cacheable_prefix] ) messages_dict = [_make_json_serializable(msg) for msg in messages] # Serialize messages in a deterministic way # Include only content, roles, and order - exclude timestamps or dynamic fields serialized = json.dumps( { "messages": messages_dict, "provider": provider, "model": model_name, "tenant_id": tenant_id, }, sort_keys=True, separators=(",", ":"), ) return hashlib.sha256(serialized.encode("utf-8")).hexdigest() ================================================ FILE: backend/onyx/llm/prompt_cache/models.py ================================================ """Interfaces and data structures for prompt caching.""" from datetime import 
datetime from pydantic import BaseModel class CacheMetadata(BaseModel): """Metadata for cached prompt prefixes.""" cache_key: str provider: str model_name: str tenant_id: str created_at: datetime last_accessed: datetime # Provider-specific metadata # TODO: Add explicit caching support in future PR # vertex_block_numbers: dict[str, str] | None = None # message_hash -> block_number # anthropic_cache_id: str | None = None ================================================ FILE: backend/onyx/llm/prompt_cache/processor.py ================================================ """Main processor for prompt caching.""" from datetime import datetime from datetime import timezone from onyx.configs.model_configs import ENABLE_PROMPT_CACHING from onyx.llm.interfaces import LLMConfig from onyx.llm.models import LanguageModelInput from onyx.llm.prompt_cache.cache_manager import generate_cache_key_hash from onyx.llm.prompt_cache.models import CacheMetadata from onyx.llm.prompt_cache.providers.factory import get_provider_adapter from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() # TODO: test with a history containing images def process_with_prompt_cache( llm_config: LLMConfig, cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool = False, ) -> tuple[LanguageModelInput, CacheMetadata | None]: """Process prompt with caching support. This function takes a cacheable prefix and suffix, processes them according to the LLM provider's caching capabilities, and returns the combined messages ready for LLM API calls along with optional cache metadata. Args: llm_config: Configuration of the target LLM (its provider and model name select the caching adapter) cacheable_prefix: Optional cacheable prefix. If None, no caching is attempted.
suffix: The non-cacheable suffix to append continuation: If True, suffix is appended to the last message of cacheable_prefix rather than added as separate messages Returns: Tuple of (processed_prompt, cache_metadata_to_store) - processed_prompt: Combined and transformed messages ready for LLM API call - cache_metadata_to_store: CacheMetadata for the cacheable prefix when the provider supports caching and message preparation succeeds (including implicit caching); None when caching is disabled, no prefix is given, the provider does not support caching, or preparation fails. Explicit cache lookups will be added in a future PR. """ # Check if prompt caching is enabled if not ENABLE_PROMPT_CACHING: logger.debug("Prompt caching is disabled via configuration") # Fall back to no-op behavior from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider noop_adapter = NoOpPromptCacheProvider() combined = noop_adapter.prepare_messages_for_caching( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, cache_metadata=None, ) return combined, None # If no cacheable prefix, return suffix unchanged if cacheable_prefix is None: logger.debug("No cacheable prefix provided, skipping caching") return suffix, None # Get provider adapter provider_adapter = get_provider_adapter(llm_config) # If provider doesn't support caching, combine messages without transformation if not provider_adapter.supports_caching(): logger.debug( f"Provider {llm_config.model_provider} does not support caching, combining messages without caching" ) # Use no-op adapter to combine messages from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider noop_adapter = NoOpPromptCacheProvider() combined = noop_adapter.prepare_messages_for_caching( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, cache_metadata=None, ) return combined, None # Generate cache key for cacheable prefix tenant_id = get_current_tenant_id() 
cache_key_hash = generate_cache_key_hash( cacheable_prefix=cacheable_prefix, provider=llm_config.model_provider, model_name=llm_config.model_name, tenant_id=tenant_id, ) # For implicit caching: Skip cache
lookup (providers handle caching automatically) # TODO (explicit caching - future PR): Look up cache metadata in CacheManager cache_metadata: CacheMetadata | None = None # Use provider adapter to prepare messages with caching try: processed_prompt = provider_adapter.prepare_messages_for_caching( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, cache_metadata=cache_metadata, ) logger.debug( f"Processed prompt with caching: provider={llm_config.model_provider}, " f"model={llm_config.model_name}, cache_key={cache_key_hash[:16]}..., " f"continuation={continuation}" ) # Create cache metadata for tracking (even for implicit caching) # This allows us to track cache usage and effectiveness cache_metadata = CacheMetadata( cache_key=cache_key_hash, provider=llm_config.model_provider, model_name=llm_config.model_name, tenant_id=tenant_id, created_at=datetime.now(timezone.utc), last_accessed=datetime.now(timezone.utc), ) return processed_prompt, cache_metadata except Exception as e: # Best-effort: log error and fall back to no-op behavior logger.warning( f"Error processing prompt with caching for provider={llm_config.model_provider}: {str(e)}. " "Falling back to non-cached behavior." 
) # Fall back to no-op adapter from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider noop_adapter = NoOpPromptCacheProvider() combined = noop_adapter.prepare_messages_for_caching( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, cache_metadata=None, ) return combined, None ================================================ FILE: backend/onyx/llm/prompt_cache/providers/__init__.py ================================================ """Provider adapters for prompt caching.""" from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.providers.factory import get_provider_adapter from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider __all__ = [ "AnthropicPromptCacheProvider", "get_provider_adapter", "NoOpPromptCacheProvider", "OpenAIPromptCacheProvider", "PromptCacheProvider", "VertexAIPromptCacheProvider", ] ================================================ FILE: backend/onyx/llm/prompt_cache/providers/anthropic.py ================================================ """Anthropic provider adapter for prompt caching.""" from collections.abc import Sequence from onyx.llm.interfaces import LanguageModelInput from onyx.llm.models import ChatCompletionMessage from onyx.llm.prompt_cache.models import CacheMetadata from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform from onyx.llm.prompt_cache.utils import revalidate_message_from_original def _add_anthropic_cache_control( messages: Sequence[ChatCompletionMessage], ) -> Sequence[ChatCompletionMessage]: """Add cache_control parameter to messages for Anthropic caching. 
Args: messages: Messages to transform Returns: Messages with cache_control added to the last message """ last_message_dict = dict(messages[-1]) last_message_dict["cache_control"] = {"type": "ephemeral"} last_message = revalidate_message_from_original( original=messages[-1], mutated=last_message_dict ) return list(messages[:-1]) + [last_message] class AnthropicPromptCacheProvider(PromptCacheProvider): """Anthropic adapter for prompt caching (explicit caching with cache_control). Implicit caching: the caller only needs to send byte-identical prefixes; the provider detects and reuses them automatically. Explicit caching: the caller must mark content to enable provider-side caching. Anthropic supports explicit caching via the cache_control parameter: https://platform.claude.com/docs/en/build-with-claude/prompt-caching """ def supports_caching(self) -> bool: """Anthropic supports explicit prompt caching.""" return True def prepare_messages_for_caching( self, cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool, cache_metadata: CacheMetadata | None, # noqa: ARG002 ) -> LanguageModelInput: """Prepare messages for Anthropic caching. Anthropic's explicit caching requires a cache_control marker on the cacheable content. We add cache_control={"type": "ephemeral"} to the last cacheable prefix message, which marks a cache breakpoint covering the whole prefix up to and including that message. Args: cacheable_prefix: Optional cacheable prefix suffix: Non-cacheable suffix continuation: Whether to append suffix to last prefix message cache_metadata: Cache metadata (for future explicit caching support) Returns: Combined messages with cache_control on the last cacheable message """ return prepare_messages_with_cacheable_transform( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, transform_cacheable=_add_anthropic_cache_control, ) def extract_cache_metadata( self, response: dict, # noqa: ARG002 cache_key: str, # noqa: ARG002 ) -> CacheMetadata | 
None: """Extract cache metadata from Anthropic response. Anthropic may return cache identifiers in the response.
For now, we don't extract detailed metadata (future explicit caching support). Args: response: Anthropic API response dictionary cache_key: Cache key used for this request Returns: CacheMetadata if extractable, None otherwise """ # TODO: Extract cache identifiers from response when implementing explicit caching return None def get_cache_ttl_seconds(self) -> int: """Get cache TTL for Anthropic (5 minutes default).""" return 300 ================================================ FILE: backend/onyx/llm/prompt_cache/providers/base.py ================================================ """Base interface for provider-specific prompt caching adapters.""" from abc import ABC from abc import abstractmethod from onyx.llm.interfaces import LanguageModelInput from onyx.llm.prompt_cache.models import CacheMetadata class PromptCacheProvider(ABC): """Abstract base class for provider-specific prompt caching logic.""" @abstractmethod def supports_caching(self) -> bool: """Whether this provider supports prompt caching. Returns: True if caching is supported, False otherwise """ raise NotImplementedError @abstractmethod def prepare_messages_for_caching( self, cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool, cache_metadata: CacheMetadata | None, ) -> LanguageModelInput: """Transform messages to enable caching. Args: cacheable_prefix: Optional cacheable prefix (can be str or Sequence[ChatCompletionMessage]) suffix: Non-cacheable suffix (can be str or Sequence[ChatCompletionMessage]) continuation: If True, suffix should be appended to the last message of cacheable_prefix rather than being separate messages. Note: When cacheable_prefix is a string, it should remain in its own content block even if continuation=True. 
cache_metadata: Optional cache metadata from previous requests Returns: Combined and transformed messages ready for LLM API call """ raise NotImplementedError @abstractmethod def extract_cache_metadata( self, response: dict, # Provider-specific response object cache_key: str, ) -> CacheMetadata | None: """Extract cache metadata from API response. Args: response: Provider-specific response dictionary cache_key: Cache key used for this request Returns: CacheMetadata if extractable, None otherwise """ raise NotImplementedError @abstractmethod def get_cache_ttl_seconds(self) -> int: """Get cache TTL in seconds for this provider. Returns: TTL in seconds """ raise NotImplementedError ================================================ FILE: backend/onyx/llm/prompt_cache/providers/factory.py ================================================ """Factory for creating provider-specific prompt cache adapters.""" from onyx.llm.constants import LlmProviderNames from onyx.llm.interfaces import LLMConfig from onyx.llm.prompt_cache.providers.anthropic import AnthropicPromptCacheProvider from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.providers.noop import NoOpPromptCacheProvider from onyx.llm.prompt_cache.providers.openai import OpenAIPromptCacheProvider from onyx.llm.prompt_cache.providers.vertex import VertexAIPromptCacheProvider ANTHROPIC_BEDROCK_TAG = "anthropic." def get_provider_adapter(llm_config: LLMConfig) -> PromptCacheProvider: """Get the appropriate prompt cache provider adapter for a given provider. 
Args: llm_config: LLM configuration. ``model_provider`` selects the adapter; for Bedrock, ``model_name`` must also contain the ``"anthropic."`` tag to use the Anthropic adapter. Returns: PromptCacheProvider instance for the configured provider (NoOpPromptCacheProvider if the provider has no caching support) """ if llm_config.model_provider == LlmProviderNames.OPENAI: return OpenAIPromptCacheProvider() elif llm_config.model_provider == LlmProviderNames.ANTHROPIC or ( llm_config.model_provider == LlmProviderNames.BEDROCK and ANTHROPIC_BEDROCK_TAG in llm_config.model_name ): return AnthropicPromptCacheProvider() elif llm_config.model_provider == LlmProviderNames.VERTEX_AI: return VertexAIPromptCacheProvider() else: # Default to no-op for providers without caching support return NoOpPromptCacheProvider() ================================================ FILE: backend/onyx/llm/prompt_cache/providers/noop.py ================================================ """No-op provider adapter for providers without caching support.""" from onyx.llm.models import LanguageModelInput from onyx.llm.prompt_cache.models import CacheMetadata from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform class NoOpPromptCacheProvider(PromptCacheProvider): """No-op adapter for providers that don't support prompt caching.""" def supports_caching(self) -> bool: """No-op providers don't support caching.""" return False def prepare_messages_for_caching( self, cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool, cache_metadata: CacheMetadata | None, # noqa: ARG002 ) -> LanguageModelInput: """Combine prefix and suffix without any caching transformation.
cache_metadata: Cache metadata (ignored) Returns: Combined messages (prefix + suffix) """ # No transformation needed for no-op provider return prepare_messages_with_cacheable_transform( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, transform_cacheable=None, ) def extract_cache_metadata( self, response: dict, # noqa: ARG002 cache_key: str, # noqa: ARG002 ) -> CacheMetadata | None: """No cache metadata to extract.""" return None def get_cache_ttl_seconds(self) -> int: """Return default TTL (not used for no-op).""" return 0 ================================================ FILE: backend/onyx/llm/prompt_cache/providers/openai.py ================================================ """OpenAI provider adapter for prompt caching.""" from onyx.llm.interfaces import LanguageModelInput from onyx.llm.prompt_cache.models import CacheMetadata from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform class OpenAIPromptCacheProvider(PromptCacheProvider): """OpenAI adapter for prompt caching (implicit caching).""" def supports_caching(self) -> bool: """OpenAI supports automatic prompt caching.""" return True def prepare_messages_for_caching( self, cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool, cache_metadata: CacheMetadata | None, # noqa: ARG002 ) -> LanguageModelInput: """Prepare messages for OpenAI caching. OpenAI handles caching automatically, so we just normalize and combine the messages. The provider will automatically cache prefixes >1024 tokens. 
Args: cacheable_prefix: Optional cacheable prefix suffix: Non-cacheable suffix continuation: Whether to append suffix to last prefix message cache_metadata: Cache metadata (ignored for implicit caching) Returns: Combined messages ready for LLM API call """ # No transformation needed for OpenAI (implicit caching) return prepare_messages_with_cacheable_transform( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, transform_cacheable=None, ) def extract_cache_metadata( self, response: dict, # noqa: ARG002 cache_key: str, # noqa: ARG002 ) -> CacheMetadata | None: """Extract cache metadata from OpenAI response. OpenAI responses may include cached_tokens in the usage field. For implicit caching, we don't need to store much metadata. Args: response: OpenAI API response dictionary cache_key: Cache key used for this request Returns: CacheMetadata if extractable, None otherwise """ # For implicit caching, OpenAI handles everything automatically # We could extract cached_tokens from response.get("usage", {}).get("cached_tokens") # but for now, we don't need to store metadata for implicit caching return None def get_cache_ttl_seconds(self) -> int: """Get cache TTL for OpenAI (1 hour max).""" return 3600 ================================================ FILE: backend/onyx/llm/prompt_cache/providers/vertex.py ================================================ """Vertex AI provider adapter for prompt caching.""" from collections.abc import Sequence from onyx.llm.interfaces import LanguageModelInput from onyx.llm.models import ChatCompletionMessage from onyx.llm.prompt_cache.models import CacheMetadata from onyx.llm.prompt_cache.providers.base import PromptCacheProvider from onyx.llm.prompt_cache.utils import prepare_messages_with_cacheable_transform from onyx.llm.prompt_cache.utils import revalidate_message_from_original class VertexAIPromptCacheProvider(PromptCacheProvider): """Vertex AI adapter for prompt caching (implicit caching for this PR).""" def 
supports_caching(self) -> bool: """Vertex AI supports prompt caching (implicit and explicit).""" return True def prepare_messages_for_caching( self, cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool, cache_metadata: CacheMetadata | None, # noqa: ARG002 ) -> LanguageModelInput: """Prepare messages for Vertex AI caching. For implicit caching, messages are combined without transformation (no cache_control is attached); Vertex/Gemini reuses repeated prefixes automatically. Explicit context caching (with cache blocks) will be added in a future PR. Args: cacheable_prefix: Optional cacheable prefix suffix: Non-cacheable suffix continuation: Whether to append suffix to last prefix message cache_metadata: Cache metadata (for future explicit caching support) Returns: Combined messages ready for LLM API call """ # For implicit caching, no transformation needed (Vertex handles caching automatically) # TODO (explicit caching - future PR): # - Check cache_metadata for vertex_block_numbers # - Create transform function that replaces messages with cache_block_id if available # - Or adds cache_control parameter if not using cached blocks return prepare_messages_with_cacheable_transform( cacheable_prefix=cacheable_prefix, suffix=suffix, continuation=continuation, transform_cacheable=None, # TODO: support explicit caching ) def extract_cache_metadata( self, response: dict, # noqa: ARG002 cache_key: str, # noqa: ARG002 ) -> CacheMetadata | None: """Extract cache metadata from Vertex AI response. For this PR (implicit caching): no metadata is extracted, so this always returns None. TODO (explicit caching - future PR): Extract block numbers from response and store in metadata.
Args: response: Vertex AI API response dictionary cache_key: Cache key used for this request Returns: CacheMetadata if extractable, None otherwise """ # For implicit caching, Vertex handles everything automatically # TODO (explicit caching - future PR): # - Extract cache block numbers from response # - Store in cache_metadata.vertex_block_numbers return None def get_cache_ttl_seconds(self) -> int: """Get cache TTL for Vertex AI (5 minutes).""" return 300 def _add_vertex_cache_control( messages: Sequence[ChatCompletionMessage], ) -> Sequence[ChatCompletionMessage]: """Add cache_control inside content blocks for Vertex AI/Gemini caching. Gemini requires cache_control to be on a content block within the content array, not at the message level. This function converts string content to the array format and adds cache_control to the last content block in each cacheable message. """ # NOTE: unfortunately we need a much more sophisticated mechanism to support # explicit caching with Vertex in the presence of tools and system messages # (since they're supposed to be stripped out when setting cache_control) # so we're deferring this to a future PR.
updated: list[ChatCompletionMessage] = [] for message in messages: mutated = dict(message) content = mutated.get("content") if isinstance(content, str): # Convert string content to array format with cache_control mutated["content"] = [ { "type": "text", "text": content, "cache_control": {"type": "ephemeral"}, } ] elif isinstance(content, list) and content: # Content is already an array - add cache_control to last block new_content = [] for i, block in enumerate(content): if isinstance(block, dict): block_copy = dict(block) # Add cache_control to the last content block if i == len(content) - 1: block_copy["cache_control"] = {"type": "ephemeral"} new_content.append(block_copy) else: new_content.append(block) mutated["content"] = new_content updated.append(revalidate_message_from_original(message, mutated)) return updated ================================================ FILE: backend/onyx/llm/prompt_cache/utils.py ================================================ # pyright: reportMissingTypeStubs=false """Utility functions for prompt caching.""" import json from collections.abc import Callable from collections.abc import Sequence from typing import Any from onyx.llm.models import ChatCompletionMessage from onyx.llm.models import LanguageModelInput from onyx.utils.logger import setup_logger logger = setup_logger() def combine_messages_with_continuation( prefix_msgs: Sequence[ChatCompletionMessage], suffix_msgs: Sequence[ChatCompletionMessage], continuation: bool, ) -> list[ChatCompletionMessage]: """Combine prefix and suffix messages, handling continuation flag. 
Args: prefix_msgs: Normalized cacheable prefix messages suffix_msgs: Normalized suffix messages continuation: If True, merge the content of the first suffix message into the last prefix message; any remaining suffix messages are appended after it Returns: Combined messages """ if not continuation or not prefix_msgs: return list(prefix_msgs) + list(suffix_msgs) # Append suffix content to last message of prefix result = list(prefix_msgs) last_msg = dict(result[-1]) suffix_first = dict(suffix_msgs[0]) if suffix_msgs else {} # Combine content if "content" in last_msg and "content" in suffix_first: if isinstance(last_msg["content"], str) and isinstance( suffix_first["content"], str ): last_msg["content"] = last_msg["content"] + suffix_first["content"] else: # Handle list content (multimodal) prefix_content = ( last_msg["content"] if isinstance(last_msg["content"], list) else [{"type": "text", "text": last_msg["content"]}] ) suffix_content = ( suffix_first["content"] if isinstance(suffix_first["content"], list) else [{"type": "text", "text": suffix_first["content"]}] ) last_msg["content"] = prefix_content + suffix_content result[-1] = revalidate_message_from_original(original=result[-1], mutated=last_msg) result.extend(suffix_msgs[1:]) return result def revalidate_message_from_original( original: ChatCompletionMessage, mutated: dict[str, Any], ) -> ChatCompletionMessage: """Rebuild a mutated message using the original BaseModel type. Some providers need to add cache metadata to messages. Re-run validation against the original message's Pydantic class so union discrimination (by role) stays intact.
""" cls = original.__class__ try: return cls.model_validate_json(json.dumps(mutated)) except Exception: return cls.model_validate(mutated) def prepare_messages_with_cacheable_transform( cacheable_prefix: LanguageModelInput | None, suffix: LanguageModelInput, continuation: bool, transform_cacheable: ( Callable[[Sequence[ChatCompletionMessage]], Sequence[ChatCompletionMessage]] | None ) = None, ) -> LanguageModelInput: """Prepare messages for caching with optional transformation of cacheable prefix. This is a shared utility that handles the common flow: 1. Normalize inputs 2. Optionally transform cacheable messages 3. Combine with continuation handling Args: cacheable_prefix: Optional cacheable prefix suffix: Non-cacheable suffix continuation: Whether to append suffix to last prefix message transform_cacheable: Optional function to transform cacheable messages (e.g., add cache_control parameter). If None, messages are used as-is. Returns: Combined messages ready for LLM API call """ if cacheable_prefix is None: return suffix prefix_msgs = ( cacheable_prefix if isinstance(cacheable_prefix, list) else [cacheable_prefix] ) suffix_msgs = suffix if isinstance(suffix, list) else [suffix] # Apply transformation to cacheable messages if provided if transform_cacheable is not None: prefix_msgs = list(transform_cacheable(prefix_msgs)) return combine_messages_with_continuation( prefix_msgs=prefix_msgs, suffix_msgs=suffix_msgs, continuation=continuation ) ================================================ FILE: backend/onyx/llm/request_context.py ================================================ import contextvars _LLM_MOCK_RESPONSE_CONTEXTVAR: contextvars.ContextVar[str | None] = ( contextvars.ContextVar("llm_mock_response", default=None) ) def get_llm_mock_response() -> str | None: return _LLM_MOCK_RESPONSE_CONTEXTVAR.get() def set_llm_mock_response(mock_response: str | None) -> contextvars.Token[str | None]: return _LLM_MOCK_RESPONSE_CONTEXTVAR.set(mock_response) def 
reset_llm_mock_response(token: contextvars.Token[str | None]) -> None: try: _LLM_MOCK_RESPONSE_CONTEXTVAR.reset(token) except ValueError: # Streaming requests can cross execution contexts. # Best effort clear to avoid crashing request teardown in integration mode. _LLM_MOCK_RESPONSE_CONTEXTVAR.set(None) ================================================ FILE: backend/onyx/llm/utils.py ================================================ import copy import re from collections.abc import Callable from functools import lru_cache from typing import Any from typing import cast from typing import TYPE_CHECKING from sqlalchemy import select from onyx.configs.app_configs import LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS from onyx.configs.app_configs import MAX_TOKENS_FOR_FULL_INCLUSION from onyx.configs.app_configs import SEND_USER_METADATA_TO_LLM_PROVIDER from onyx.configs.app_configs import USE_CHUNK_SUMMARY from onyx.configs.app_configs import USE_DOCUMENT_SUMMARY from onyx.configs.model_configs import GEN_AI_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS from onyx.configs.model_configs import GEN_AI_NUM_RESERVED_OUTPUT_TOKENS from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import LLMModelFlowType from onyx.db.models import LLMProvider from onyx.db.models import ModelConfiguration from onyx.llm.constants import LlmProviderNames from onyx.llm.interfaces import LLM from onyx.llm.interfaces import LLMUserIdentity from onyx.llm.model_response import ModelResponse from onyx.llm.models import UserMessage from onyx.prompts.contextual_retrieval import CONTEXTUAL_RAG_TOKEN_ESTIMATE from onyx.prompts.contextual_retrieval import DOCUMENT_SUMMARY_TOKEN_ESTIMATE from onyx.utils.logger import setup_logger from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE if TYPE_CHECKING: from onyx.server.manage.llm.models import LLMProviderView logger = setup_logger() MAX_CONTEXT_TOKENS = 100 ONE_MILLION = 1_000_000 
CHUNKS_PER_DOC_ESTIMATE = 5 MAX_LITELLM_USER_ID_LENGTH = 64 _TWELVE_LABS_PEGASUS_MODEL_NAMES = [ "us.twelvelabs.pegasus-1-2-v1:0", "us.twelvelabs.pegasus-1-2-v1", "twelvelabs/us.twelvelabs.pegasus-1-2-v1:0", "twelvelabs/us.twelvelabs.pegasus-1-2-v1", ] _TWELVE_LABS_PEGASUS_OUTPUT_TOKENS = max(512, GEN_AI_MODEL_FALLBACK_MAX_TOKENS // 4) CUSTOM_LITELLM_MODEL_OVERRIDES: dict[str, dict[str, Any]] = { model_name: { "max_input_tokens": GEN_AI_MODEL_FALLBACK_MAX_TOKENS, "max_output_tokens": _TWELVE_LABS_PEGASUS_OUTPUT_TOKENS, "max_tokens": GEN_AI_MODEL_FALLBACK_MAX_TOKENS, "supports_reasoning": False, "supports_vision": False, } for model_name in _TWELVE_LABS_PEGASUS_MODEL_NAMES } def truncate_litellm_user_id(user_id: str) -> str: """Truncate a user ID to the maximum length allowed for the LiteLLM `user` field.""" if len(user_id) <= MAX_LITELLM_USER_ID_LENGTH: return user_id logger.warning( "User's ID exceeds %d chars (len=%d); truncating for LiteLLM logging compatibility.", MAX_LITELLM_USER_ID_LENGTH, len(user_id), ) return user_id[:MAX_LITELLM_USER_ID_LENGTH] def build_litellm_passthrough_kwargs( model_kwargs: dict[str, Any], user_identity: LLMUserIdentity | None, ) -> dict[str, Any]: """Build kwargs passed through directly to LiteLLM. Returns `model_kwargs` unchanged unless we need to add user/session metadata, in which case a copy is returned to avoid cross-request mutation.
""" if not (SEND_USER_METADATA_TO_LLM_PROVIDER and user_identity): return model_kwargs passthrough_kwargs = copy.deepcopy(model_kwargs) if user_identity.user_id: passthrough_kwargs["user"] = truncate_litellm_user_id(user_identity.user_id) if user_identity.session_id: existing_metadata = passthrough_kwargs.get("metadata") metadata: dict[str, Any] | None if existing_metadata is None: metadata = {} elif isinstance(existing_metadata, dict): metadata = copy.deepcopy(existing_metadata) else: metadata = None if metadata is not None: metadata["session_id"] = user_identity.session_id passthrough_kwargs["metadata"] = metadata return passthrough_kwargs def _unwrap_nested_exception(error: Exception) -> Exception: """ Traverse common exception wrappers to surface the underlying LiteLLM error. """ visited: set[int] = set() current = error for _ in range(100): visited.add(id(current)) candidate: Exception | None = None cause = getattr(current, "__cause__", None) if isinstance(cause, Exception): candidate = cause elif ( hasattr(current, "args") and len(getattr(current, "args")) == 1 and isinstance(current.args[0], Exception) ): candidate = current.args[0] if candidate is None or id(candidate) in visited: break current = candidate return current def litellm_exception_to_error_msg( e: Exception, llm: LLM, fallback_to_error_msg: bool = False, custom_error_msg_mappings: ( dict[str, str] | None ) = LITELLM_CUSTOM_ERROR_MESSAGE_MAPPINGS, ) -> tuple[str, str, bool]: """Convert a LiteLLM exception to a user-friendly error message with classification. 
Returns: tuple: (error_message, error_code, is_retryable) - error_message: User-friendly error description - error_code: Categorized error code for frontend display - is_retryable: Whether the user should try again """ from litellm.exceptions import BadRequestError from litellm.exceptions import AuthenticationError from litellm.exceptions import PermissionDeniedError from litellm.exceptions import NotFoundError from litellm.exceptions import UnprocessableEntityError from litellm.exceptions import RateLimitError from litellm.exceptions import ContextWindowExceededError from litellm.exceptions import APIConnectionError from litellm.exceptions import APIError from litellm.exceptions import Timeout from litellm.exceptions import ContentPolicyViolationError from litellm.exceptions import BudgetExceededError from litellm.exceptions import ServiceUnavailableError core_exception = _unwrap_nested_exception(e) error_msg = str(core_exception) error_code = "UNKNOWN_ERROR" is_retryable = True if custom_error_msg_mappings: for error_msg_pattern, custom_error_msg in custom_error_msg_mappings.items(): if error_msg_pattern in error_msg: return custom_error_msg, "CUSTOM_ERROR", True if isinstance(core_exception, BadRequestError): error_msg = "Bad request: The server couldn't process your request. Please check your input." error_code = "BAD_REQUEST" is_retryable = True elif isinstance(core_exception, AuthenticationError): error_msg = "Authentication failed: Please check your API key and credentials." error_code = "AUTH_ERROR" is_retryable = False elif isinstance(core_exception, PermissionDeniedError): error_msg = ( "Permission denied: You don't have the necessary permissions for this operation. " "Ensure you have access to this model." ) error_code = "PERMISSION_DENIED" is_retryable = False elif isinstance(core_exception, NotFoundError): error_msg = "Resource not found: The requested resource doesn't exist." 
error_code = "NOT_FOUND" is_retryable = False elif isinstance(core_exception, UnprocessableEntityError): error_msg = "Unprocessable entity: The server couldn't process your request due to semantic errors." error_code = "UNPROCESSABLE_ENTITY" is_retryable = True elif isinstance(core_exception, RateLimitError): provider_name = ( llm.config.model_provider if llm is not None and llm.config.model_provider else "The LLM provider" ) upstream_detail: str | None = None message_attr = getattr(core_exception, "message", None) if message_attr: upstream_detail = str(message_attr) elif hasattr(core_exception, "api_error"): api_error = core_exception.api_error if isinstance(api_error, dict): upstream_detail = ( api_error.get("message") or api_error.get("detail") or api_error.get("error") ) if not upstream_detail: upstream_detail = str(core_exception) upstream_detail = str(upstream_detail).strip() if ":" in upstream_detail and upstream_detail.lower().startswith( "ratelimiterror" ): upstream_detail = upstream_detail.split(":", 1)[1].strip() upstream_detail_lower = upstream_detail.lower() if ( "insufficient_quota" in upstream_detail_lower or "exceeded your current quota" in upstream_detail_lower ): error_msg = ( f"{provider_name} quota exceeded: {upstream_detail}" if upstream_detail else f"{provider_name} quota exceeded: Verify billing and quota for this API key." ) error_code = "BUDGET_EXCEEDED" is_retryable = False else: error_msg = ( f"{provider_name} rate limit: {upstream_detail}" if upstream_detail else f"{provider_name} rate limit exceeded: Please slow down your requests and try again later." 
) error_code = "RATE_LIMIT" is_retryable = True elif isinstance(core_exception, ServiceUnavailableError): provider_name = ( llm.config.model_provider if llm is not None and llm.config.model_provider else "The LLM provider" ) # Check if this is specifically the Bedrock "Too many connections" error if "Too many connections" in error_msg or "BedrockException" in error_msg: error_msg = ( f"{provider_name} is experiencing high connection volume and cannot process your request right now. " "This typically happens when there are too many simultaneous requests to the AI model. " "Please wait a moment and try again. If this persists, contact your system administrator " "to review connection limits and retry configurations." ) else: # Generic 503 Service Unavailable error_msg = f"{provider_name} service error: {str(core_exception)}" error_code = "SERVICE_UNAVAILABLE" is_retryable = True elif isinstance(core_exception, ContextWindowExceededError): error_msg = ( "Context window exceeded: Your input is too long for the model to process." ) if llm is not None: try: max_context = get_max_input_tokens( model_name=llm.config.model_name, model_provider=llm.config.model_provider, ) error_msg += f" Your invoked model ({llm.config.model_name}) has a maximum context size of {max_context}." except Exception: logger.warning( "Unable to get maximum input token for LiteLLM exception handling" ) error_code = "CONTEXT_TOO_LONG" is_retryable = False elif isinstance(core_exception, ContentPolicyViolationError): error_msg = "Content policy violation: Your request violates the content policy. Please revise your input." error_code = "CONTENT_POLICY" is_retryable = False elif isinstance(core_exception, APIConnectionError): error_msg = "API connection error: Failed to connect to the API. Please check your internet connection." 
error_code = "CONNECTION_ERROR" is_retryable = True elif isinstance(core_exception, BudgetExceededError): error_msg = ( "Budget exceeded: You've exceeded your allocated budget for API usage." ) error_code = "BUDGET_EXCEEDED" is_retryable = False elif isinstance(core_exception, Timeout): error_msg = "Request timed out: The operation took too long to complete. Please try again." error_code = "CONNECTION_ERROR" is_retryable = True elif isinstance(core_exception, APIError): error_msg = f"API error: An error occurred while communicating with the API. Details: {str(core_exception)}" error_code = "API_ERROR" is_retryable = True elif not fallback_to_error_msg: error_msg = "An unexpected error occurred while processing your request. Please try again later." error_code = "UNKNOWN_ERROR" is_retryable = True return error_msg, error_code, is_retryable def llm_response_to_string(message: ModelResponse) -> str: if not isinstance(message.choice.message.content, str): raise RuntimeError("LLM message not in expected format.") return message.choice.message.content def check_number_of_tokens( text: str, encode_fn: Callable[[str], list] | None = None ) -> int: """Gets the number of tokens in the provided text, using the provided encoding function. If none is provided, default to the tiktoken encoder used by GPT-3.5 and GPT-4. """ import tiktoken if encode_fn is None: encode_fn = tiktoken.get_encoding("cl100k_base").encode return len(encode_fn(text)) def test_llm(llm: LLM) -> str | None: # try for up to 2 timeouts (e.g. 
10 seconds in total) error_msg = None for _ in range(2): try: llm.invoke(UserMessage(content="Do not respond"), max_tokens=50) return None except Exception as e: error_msg = str(e) logger.warning(f"Failed to call LLM with the following error: {error_msg}") return error_msg @lru_cache(maxsize=1) # the copy.deepcopy is expensive, so we cache the result def get_model_map() -> dict: import litellm DIVIDER = "/" original_map = cast(dict[str, dict], litellm.model_cost) starting_map = copy.deepcopy(original_map) for key in original_map: if DIVIDER in key: truncated_key = key.split(DIVIDER)[-1] # make sure not to overwrite an original key if truncated_key in original_map: continue # if there are multiple possible matches, choose the most "detailed" # one as a heuristic. "detailed" = the description of the model # has the most filled out fields. existing_truncated_value = starting_map.get(truncated_key) potential_truncated_value = original_map[key] if not existing_truncated_value or len(potential_truncated_value) > len( existing_truncated_value ): starting_map[truncated_key] = potential_truncated_value for model_name, model_metadata in CUSTOM_LITELLM_MODEL_OVERRIDES.items(): if model_name in starting_map: continue starting_map[model_name] = copy.deepcopy(model_metadata) # NOTE: outside of the explicit CUSTOM_LITELLM_MODEL_OVERRIDES, # we avoid hard-coding additional models here. Ollama, for example, # allows the user to specify their desired max context window, and it's # unlikely to be standard across users even for the same model # (it heavily depends on their hardware). For those cases, we rely on # GEN_AI_MODEL_FALLBACK_MAX_TOKENS to cover this. 
# for model_name in [ # "llama3.2", # "llama3.2:1b", # "llama3.2:3b", # "llama3.2:11b", # "llama3.2:90b", # ]: # starting_map[f"ollama/{model_name}"] = { # "max_tokens": 128000, # "max_input_tokens": 128000, # "max_output_tokens": 128000, # } return starting_map def _strip_extra_provider_from_model_name(model_name: str) -> str: return model_name.split("/")[1] if "/" in model_name else model_name def _strip_colon_from_model_name(model_name: str) -> str: return ":".join(model_name.split(":")[:-1]) if ":" in model_name else model_name def find_model_obj(model_map: dict, provider: str, model_name: str) -> dict | None: stripped_model_name = _strip_extra_provider_from_model_name(model_name) model_names = [ model_name, # Remove a leading extra provider prefix. Usually needed when the user has a # custom model proxy that prepends another prefix. stripped_model_name, # Remove :XXXX from the end, if present. Needed for Ollama. _strip_colon_from_model_name(model_name), _strip_colon_from_model_name(stripped_model_name), ] # Filter out empty names and deduplicate while preserving lookup order filtered_model_names = list(dict.fromkeys(name for name in model_names if name)) # First try all model names with provider prefix for model_name in filtered_model_names: model_obj = model_map.get(f"{provider}/{model_name}") if model_obj: return model_obj # Then try all model names without provider prefix for model_name in filtered_model_names: model_obj = model_map.get(model_name) if model_obj: return model_obj return None def get_llm_contextual_cost( llm: LLM, ) -> float: """ Approximate the cost of using the given LLM for indexing with Contextual RAG. We use a precomputed estimate for the number of tokens in the contextualizing prompts, and we assume that every chunk is maximized in terms of content and context. We also assume that every document is maximized in terms of content, as currently if a document is longer than a certain length, its summary is used instead of the full content.
We expect that the first assumption will overestimate more than the second one underestimates, so this should be a fairly conservative price estimate. Also, this does not account for the cost of documents that fit within a single chunk which do not get contextualized. """ import litellm # calculate input costs num_tokens = ONE_MILLION num_input_chunks = num_tokens // DOC_EMBEDDING_CONTEXT_SIZE # We assume that the documents are MAX_TOKENS_FOR_FULL_INCLUSION tokens long # on average. num_docs = num_tokens // MAX_TOKENS_FOR_FULL_INCLUSION num_input_tokens = 0 num_output_tokens = 0 if not USE_CHUNK_SUMMARY and not USE_DOCUMENT_SUMMARY: return 0 if USE_CHUNK_SUMMARY: # Each per-chunk prompt includes: # - The prompt tokens # - the document tokens # - the chunk tokens # for each chunk, we prompt the LLM with the contextual RAG prompt # and the full document content (or the doc summary, so this is an overestimate) num_input_tokens += num_input_chunks * ( CONTEXTUAL_RAG_TOKEN_ESTIMATE + MAX_TOKENS_FOR_FULL_INCLUSION ) # in aggregate, each chunk content is used as a prompt input once # so the full input size is covered num_input_tokens += num_tokens # A single MAX_CONTEXT_TOKENS worth of output is generated per chunk num_output_tokens += num_input_chunks * MAX_CONTEXT_TOKENS # going over each doc once means all the tokens, plus the prompt tokens for # the summary prompt. This CAN happen even when USE_DOCUMENT_SUMMARY is false, # since doc summaries are used for longer documents when USE_CHUNK_SUMMARY is true. # So, we include this unconditionally to overestimate. 
num_input_tokens += num_tokens + num_docs * DOCUMENT_SUMMARY_TOKEN_ESTIMATE num_output_tokens += num_docs * MAX_CONTEXT_TOKENS try: usd_per_prompt, usd_per_completion = litellm.cost_per_token( model=llm.config.model_name, prompt_tokens=num_input_tokens, completion_tokens=num_output_tokens, ) except Exception: logger.exception( "An unexpected error occurred while calculating cost for model " f"{llm.config.model_name} (potentially due to malformed name). " "Assuming cost is 0." ) return 0 # cost_per_token returns total USD costs for the given token counts; since the # estimate covers ONE_MILLION document tokens, this is the approximate USD cost # per million indexed tokens return usd_per_prompt + usd_per_completion def llm_max_input_tokens( model_map: dict, model_name: str, model_provider: str, ) -> int: """Best effort attempt to get the max input tokens for the LLM.""" if GEN_AI_MAX_TOKENS: # This is an override, so always return this logger.info(f"Using override GEN_AI_MAX_TOKENS: {GEN_AI_MAX_TOKENS}") return GEN_AI_MAX_TOKENS model_obj = find_model_obj( model_map, model_provider, model_name, ) if not model_obj: logger.warning( f"Model '{model_name}' not found in LiteLLM. Falling back to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS} tokens." ) return GEN_AI_MODEL_FALLBACK_MAX_TOKENS if "max_input_tokens" in model_obj: return model_obj["max_input_tokens"] if "max_tokens" in model_obj: return model_obj["max_tokens"] logger.warning( f"No max tokens found for '{model_name}'. Falling back to {GEN_AI_MODEL_FALLBACK_MAX_TOKENS} tokens." ) return GEN_AI_MODEL_FALLBACK_MAX_TOKENS def get_llm_max_output_tokens( model_map: dict, model_name: str, model_provider: str, ) -> int: """Best effort attempt to get the max output tokens for the LLM.""" default_output_tokens = int(GEN_AI_MODEL_FALLBACK_MAX_TOKENS) model_obj = model_map.get(f"{model_provider}/{model_name}") if not model_obj: model_obj = model_map.get(model_name) if not model_obj: logger.warning( f"Model '{model_name}' not found in LiteLLM. Falling back to {default_output_tokens} output tokens."
) return default_output_tokens if "max_output_tokens" in model_obj: return model_obj["max_output_tokens"] # Fallback to a fraction of max_tokens if max_output_tokens is not specified if "max_tokens" in model_obj: return int(model_obj["max_tokens"] * 0.1) logger.warning( f"No max output tokens found for '{model_name}'. Falling back to {default_output_tokens} output tokens." ) return default_output_tokens def get_max_input_tokens( model_name: str, model_provider: str, output_tokens: int = GEN_AI_NUM_RESERVED_OUTPUT_TOKENS, ) -> int: # NOTE: we previously used `litellm.get_max_tokens()`, but despite the name, this actually # returns the max OUTPUT tokens. Under the hood, this uses the `litellm.model_cost` dict, # and there is no other interface to get what we want. This should be okay though, since the # `model_cost` dict is a named public interface: # https://litellm.vercel.app/docs/completion/token_usage#7-model_cost # model_map is litellm.model_cost litellm_model_map = get_model_map() input_toks = ( llm_max_input_tokens( model_name=model_name, model_provider=model_provider, model_map=litellm_model_map, ) - output_tokens ) if input_toks <= 0: return GEN_AI_MODEL_FALLBACK_MAX_TOKENS return input_toks def get_max_input_tokens_from_llm_provider( llm_provider: "LLMProviderView", model_name: str, ) -> int: """Get max input tokens for a model, with fallback chain. Fallback order: 1. Use max_input_tokens from model_configuration (populated from source APIs like OpenRouter, Ollama, or our Bedrock mapping) 2. Look up in litellm.model_cost dictionary 3. Fall back to GEN_AI_MODEL_FALLBACK_MAX_TOKENS (32000) Most dynamic providers (OpenRouter, Ollama) provide context_length via their APIs. Bedrock doesn't expose this, so we parse from model ID suffix (:200k) or use BEDROCK_MODEL_TOKEN_LIMITS mapping. The 32000 fallback is only hit for unknown models not in any of these sources. 
""" max_input_tokens = None for model_configuration in llm_provider.model_configurations: if model_configuration.name == model_name: max_input_tokens = model_configuration.max_input_tokens return ( max_input_tokens if max_input_tokens else get_max_input_tokens( model_provider=llm_provider.name, model_name=model_name, ) ) def get_bedrock_token_limit(model_id: str) -> int: """Look up token limit for a Bedrock model. AWS Bedrock API doesn't expose token limits directly. This function attempts to determine the limit from multiple sources. Lookup order: 1. Parse from model ID suffix (e.g., ":200k" → 200000) 2. Check LiteLLM's model_cost dictionary 3. Fall back to our hardcoded BEDROCK_MODEL_TOKEN_LIMITS mapping 4. Default to 32000 if not found anywhere """ from onyx.llm.constants import BEDROCK_MODEL_TOKEN_LIMITS model_id_lower = model_id.lower() # 1. Try to parse context length from model ID suffix # Format: "model-name:version:NNNk" where NNN is the context length in thousands # Examples: ":200k", ":128k", ":1000k", ":8k", ":4k" context_match = re.search(r":(\d+)k\b", model_id_lower) if context_match: return int(context_match.group(1)) * 1000 # 2. Check LiteLLM's model_cost dictionary try: model_map = get_model_map() # Try with bedrock/ prefix first, then without for key in [f"bedrock/{model_id}", model_id]: if key in model_map: model_info = model_map[key] if "max_input_tokens" in model_info: return model_info["max_input_tokens"] if "max_tokens" in model_info: return model_info["max_tokens"] except Exception: pass # Fall through to mapping # 3. Try our hardcoded mapping (longest match first) for pattern, limit in sorted( BEDROCK_MODEL_TOKEN_LIMITS.items(), key=lambda x: -len(x[0]) ): if pattern in model_id_lower: return limit # 4. 
Default fallback return GEN_AI_MODEL_FALLBACK_MAX_TOKENS def model_supports_image_input(model_name: str, model_provider: str) -> bool: # First, try to read an explicit configuration from the model_configuration table try: with get_session_with_current_tenant() as db_session: model_config = db_session.scalar( select(ModelConfiguration) .join( LLMProvider, ModelConfiguration.llm_provider_id == LLMProvider.id, ) .where( ModelConfiguration.name == model_name, LLMProvider.provider == model_provider, ) ) if ( model_config and LLMModelFlowType.VISION in model_config.llm_model_flow_types ): return True except Exception as e: logger.warning( f"Failed to query database for {model_provider} model {model_name} image support: {e}" ) # Fallback to looking up the model in the litellm model_cost dict return litellm_thinks_model_supports_image_input(model_name, model_provider) def litellm_thinks_model_supports_image_input( model_name: str, model_provider: str ) -> bool: """Check LiteLLM's model_cost map for vision support. Callers should generally use `model_supports_image_input` instead; call this directly only when the DB has no explicit setting for the model or the DB query must be avoided for performance.""" try: model_obj = find_model_obj(get_model_map(), model_provider, model_name) if not model_obj: logger.warning( f"No litellm entry found for {model_provider}/{model_name}, this model may or may not support image input."
) return False # The or False here is because sometimes the dict contains the key but the value is None return model_obj.get("supports_vision", False) or False except Exception: logger.exception( f"Failed to get model object for {model_provider}/{model_name}" ) return False def model_is_reasoning_model(model_name: str, model_provider: str) -> bool: import litellm model_map = get_model_map() try: model_obj = find_model_obj( model_map, model_provider, model_name, ) if model_obj and "supports_reasoning" in model_obj: return model_obj["supports_reasoning"] # Fallback: try using litellm.supports_reasoning() for newer models try: # logger.debug("Falling back to `litellm.supports_reasoning`") full_model_name = ( f"{model_provider}/{model_name}" if model_provider not in model_name else model_name ) return litellm.supports_reasoning(model=full_model_name) except Exception: logger.exception( f"Failed to check if {model_provider}/{model_name} supports reasoning" ) return False except Exception: logger.exception( f"Failed to get model object for {model_provider}/{model_name}" ) return False def is_true_openai_model(model_provider: str, model_name: str) -> bool: """ Determines if a model is a true OpenAI model or just using OpenAI-compatible API. LiteLLM uses the "openai" provider for any OpenAI-compatible server (e.g. vLLM, LiteLLM proxy), but this function checks if the model is actually from OpenAI's model registry. This function is used primarily to determine if we should use the responses API. OpenAI models from OpenAI and Azure should use responses. 
""" if model_provider not in { LlmProviderNames.OPENAI, LlmProviderNames.LITELLM_PROXY, LlmProviderNames.AZURE, }: return False model_map = get_model_map() def _check_if_model_name_is_openai_provider(model_name: str) -> bool: if model_name not in model_map: return False return model_map[model_name].get("litellm_provider") == LlmProviderNames.OPENAI try: # Check if any model exists in litellm's registry with openai prefix # If it's registered as "openai/model-name", it's a real OpenAI model if f"{LlmProviderNames.OPENAI}/{model_name}" in model_map: return True if _check_if_model_name_is_openai_provider(model_name): return True if model_name.startswith(f"{LlmProviderNames.AZURE}/"): model_name_with_azure_removed = "/".join(model_name.split("/")[1:]) if _check_if_model_name_is_openai_provider(model_name_with_azure_removed): return True return False except Exception: logger.exception( f"Failed to determine if {model_provider}/{model_name} is a true OpenAI model" ) return False def model_needs_formatting_reenabled(model_name: str) -> bool: # See https://simonwillison.net/tags/markdown/ for context on why this is needed # for OpenAI reasoning models to have correct markdown generation # Models that need formatting re-enabled model_names = ["gpt-5.1", "gpt-5", "o3", "o1"] # Pattern matches if any of these model names appear with word boundaries # Word boundaries include: start/end of string, space, hyphen, or forward slash pattern = ( r"(?:^|[\s\-/])(" + "|".join(re.escape(name) for name in model_names) + r")(?:$|[\s\-/])" ) if re.search(pattern, model_name): return True return False ================================================ FILE: backend/onyx/llm/well_known_providers/auto_update_models.py ================================================ """Pydantic models for GitHub-hosted Auto LLM configuration.""" from datetime import datetime from typing import Any from pydantic import BaseModel from pydantic import field_validator from onyx.llm.well_known_providers.models 
import SimpleKnownModel class LLMProviderRecommendation(BaseModel): """Configuration for a single provider in the GitHub config. Fields: - default_model: The default model config (can be string or object with name) - additional_visible_models: List of additional visible model configs """ default_model: SimpleKnownModel additional_visible_models: list[SimpleKnownModel] = [] @field_validator("default_model", mode="before") @classmethod def normalize_default_model(cls, v: Any) -> dict[str, Any]: """Allow default_model to be a string (model name) or object.""" if isinstance(v, str): return {"name": v} return v class LLMRecommendations(BaseModel): """Root configuration object fetched from GitHub.""" version: str updated_at: datetime providers: dict[str, LLMProviderRecommendation] def get_visible_models(self, provider_name: str) -> list[SimpleKnownModel]: """Get the list of models that should be visible by default for a provider.""" if provider_name in self.providers: provider_config = self.providers[provider_name] return [provider_config.default_model] + list( provider_config.additional_visible_models ) return [] def get_default_model(self, provider_name: str) -> SimpleKnownModel | None: """Get the default model for a provider.""" if provider_name in self.providers: provider_config = self.providers[provider_name] return provider_config.default_model return None ================================================ FILE: backend/onyx/llm/well_known_providers/auto_update_service.py ================================================ """Service for fetching and syncing LLM model configurations from GitHub. This service manages Auto mode LLM providers, where models and configuration are managed centrally via a GitHub-hosted JSON file.
In Auto mode: - Model list is controlled by GitHub config - Model visibility is controlled by GitHub config - Default model is controlled by GitHub config - Admin only needs to provide API credentials """ from datetime import datetime import httpx from sqlalchemy.orm import Session from onyx.cache.factory import get_cache_backend from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL from onyx.db.llm import fetch_auto_mode_providers from onyx.db.llm import sync_auto_mode_models from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations from onyx.utils.logger import setup_logger logger = setup_logger() _CACHE_KEY_LAST_UPDATED_AT = "auto_llm_update:last_updated_at" _CACHE_TTL_SECONDS = 60 * 60 * 24 # 24 hours def _get_cached_last_updated_at() -> datetime | None: try: value = get_cache_backend().get(_CACHE_KEY_LAST_UPDATED_AT) if value is not None: return datetime.fromisoformat(value.decode("utf-8")) except Exception as e: logger.warning(f"Failed to get cached last_updated_at: {e}") return None def _set_cached_last_updated_at(updated_at: datetime) -> None: try: get_cache_backend().set( _CACHE_KEY_LAST_UPDATED_AT, updated_at.isoformat(), ex=_CACHE_TTL_SECONDS, ) except Exception as e: logger.warning(f"Failed to set cached last_updated_at: {e}") def fetch_llm_recommendations_from_github( timeout: float = 30.0, ) -> LLMRecommendations | None: """Fetch LLM configuration from GitHub. Returns: LLMRecommendations if successful; None if AUTO_LLM_CONFIG_URL is not set or on error.
""" if not AUTO_LLM_CONFIG_URL: logger.debug("AUTO_LLM_CONFIG_URL not configured, skipping fetch") return None try: with httpx.Client(timeout=timeout) as client: response = client.get(AUTO_LLM_CONFIG_URL) response.raise_for_status() data = response.json() return LLMRecommendations.model_validate(data) except httpx.HTTPError as e: logger.error(f"Failed to fetch LLM config from GitHub: {e}") return None except Exception as e: logger.error(f"Error parsing LLM config: {e}") return None def sync_llm_models_from_github( db_session: Session, force: bool = False, ) -> dict[str, int]: """Sync models from GitHub config to database for all Auto mode providers. In Auto mode, EVERYTHING is controlled by GitHub config: - Model list - Model visibility (is_visible) - Default model - Fast default model Args: db_session: Database session force: If True, skip the updated_at check and force sync Returns: Dict of provider_name -> number of changes made. """ results: dict[str, int] = {} # Get all providers in Auto mode auto_providers = fetch_auto_mode_providers(db_session) if not auto_providers: logger.debug("No providers in Auto mode found") return {} # Fetch config from GitHub config = fetch_llm_recommendations_from_github() if not config: logger.warning("Failed to fetch GitHub config") return {} # Skip if we've already processed this version (unless forced) last_updated_at = _get_cached_last_updated_at() if not force and last_updated_at and config.updated_at <= last_updated_at: logger.debug("GitHub config unchanged, skipping sync") _set_cached_last_updated_at(config.updated_at) return {} for provider in auto_providers: provider_type = provider.provider # e.g., "openai", "anthropic" if provider_type not in config.providers: logger.debug( f"No config for provider type '{provider_type}' in GitHub config" ) continue # Sync models - this replaces the model list entirely for Auto mode changes = sync_auto_mode_models( db_session=db_session,
provider=provider, llm_recommendations=config, ) if changes > 0: results[provider.name] = changes logger.info( f"Applied {changes} model changes to provider '{provider.name}'" ) _set_cached_last_updated_at(config.updated_at) return results def reset_cache() -> None: """Reset the cache timestamp. Useful for testing.""" try: get_cache_backend().delete(_CACHE_KEY_LAST_UPDATED_AT) except Exception as e: logger.warning(f"Failed to reset cache: {e}") ================================================ FILE: backend/onyx/llm/well_known_providers/constants.py ================================================ from onyx.llm.constants import LlmProviderNames OPENAI_PROVIDER_NAME = "openai" BEDROCK_PROVIDER_NAME = "bedrock" OLLAMA_PROVIDER_NAME = "ollama_chat" OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY" LM_STUDIO_PROVIDER_NAME = "lm_studio" LM_STUDIO_API_KEY_CONFIG_KEY = "LM_STUDIO_API_KEY" LITELLM_PROXY_PROVIDER_NAME = "litellm_proxy" BIFROST_PROVIDER_NAME = "bifrost" # Providers that use optional Bearer auth from custom_config PROVIDERS_WITH_SPECIAL_API_KEY_HANDLING: dict[str, str] = { LlmProviderNames.OLLAMA_CHAT: OLLAMA_API_KEY_CONFIG_KEY, LlmProviderNames.LM_STUDIO: LM_STUDIO_API_KEY_CONFIG_KEY, } # OpenRouter OPENROUTER_PROVIDER_NAME = "openrouter" ANTHROPIC_PROVIDER_NAME = "anthropic" AZURE_PROVIDER_NAME = "azure" VERTEXAI_PROVIDER_NAME = "vertex_ai" VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials" VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT = "CREDENTIALS_FILE" VERTEX_LOCATION_KWARG = "vertex_location" AWS_REGION_NAME_KWARG = "aws_region_name" AWS_REGION_NAME_KWARG_ENV_VAR_FORMAT = "AWS_REGION_NAME" AWS_BEARER_TOKEN_BEDROCK_KWARG_ENV_VAR_FORMAT = "AWS_BEARER_TOKEN_BEDROCK" AWS_ACCESS_KEY_ID_KWARG = "aws_access_key_id" AWS_ACCESS_KEY_ID_KWARG_ENV_VAR_FORMAT = "AWS_ACCESS_KEY_ID" AWS_SECRET_ACCESS_KEY_KWARG = "aws_secret_access_key" AWS_SECRET_ACCESS_KEY_KWARG_ENV_VAR_FORMAT = "AWS_SECRET_ACCESS_KEY" ================================================ FILE: 
backend/onyx/llm/well_known_providers/llm_provider_options.py ================================================ import json import pathlib import threading import time from onyx.llm.constants import LlmProviderNames from onyx.llm.constants import PROVIDER_DISPLAY_NAMES from onyx.llm.constants import WELL_KNOWN_PROVIDER_NAMES from onyx.llm.utils import get_max_input_tokens from onyx.llm.utils import model_supports_image_input from onyx.llm.well_known_providers.auto_update_models import LLMRecommendations from onyx.llm.well_known_providers.auto_update_service import ( fetch_llm_recommendations_from_github, ) from onyx.llm.well_known_providers.constants import ANTHROPIC_PROVIDER_NAME from onyx.llm.well_known_providers.constants import AZURE_PROVIDER_NAME from onyx.llm.well_known_providers.constants import BEDROCK_PROVIDER_NAME from onyx.llm.well_known_providers.constants import BIFROST_PROVIDER_NAME from onyx.llm.well_known_providers.constants import LITELLM_PROXY_PROVIDER_NAME from onyx.llm.well_known_providers.constants import LM_STUDIO_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OLLAMA_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OPENAI_PROVIDER_NAME from onyx.llm.well_known_providers.constants import OPENROUTER_PROVIDER_NAME from onyx.llm.well_known_providers.constants import VERTEXAI_PROVIDER_NAME from onyx.llm.well_known_providers.models import WellKnownLLMProviderDescriptor from onyx.server.manage.llm.models import ModelConfigurationView from onyx.utils.logger import setup_logger logger = setup_logger() _RECOMMENDATIONS_CACHE_TTL_SECONDS = 300 _recommendations_cache_lock = threading.Lock() _cached_recommendations: LLMRecommendations | None = None _cached_recommendations_time: float = 0.0 def _get_provider_to_models_map() -> dict[str, list[str]]: """Lazy-load provider model mappings to avoid importing litellm at module level. 
Dynamic providers (Bedrock, Ollama, OpenRouter) return empty lists here because their models are fetched directly from the source API, which is more up-to-date than LiteLLM's static lists. """ return { OPENAI_PROVIDER_NAME: get_openai_model_names(), BEDROCK_PROVIDER_NAME: [], # Dynamic - fetched from AWS API ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(), VERTEXAI_PROVIDER_NAME: get_vertexai_model_names(), OLLAMA_PROVIDER_NAME: [], # Dynamic - fetched from Ollama API LM_STUDIO_PROVIDER_NAME: [], # Dynamic - fetched from LM Studio API OPENROUTER_PROVIDER_NAME: [], # Dynamic - fetched from OpenRouter API LITELLM_PROXY_PROVIDER_NAME: [], # Dynamic - fetched from LiteLLM proxy API BIFROST_PROVIDER_NAME: [], # Dynamic - fetched from Bifrost API } def _load_bundled_recommendations() -> LLMRecommendations: json_path = pathlib.Path(__file__).parent / "recommended-models.json" with open(json_path, "r") as f: json_config = json.load(f) return LLMRecommendations.model_validate(json_config) def get_recommendations() -> LLMRecommendations: """Get the recommendations, with an in-memory cache to avoid hitting GitHub on every API request.""" global _cached_recommendations, _cached_recommendations_time now = time.monotonic() if ( _cached_recommendations is not None and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS ): return _cached_recommendations with _recommendations_cache_lock: # Double-check after acquiring lock if ( _cached_recommendations is not None and (time.monotonic() - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS ): return _cached_recommendations recommendations_from_github = fetch_llm_recommendations_from_github() result = recommendations_from_github or _load_bundled_recommendations() _cached_recommendations = result _cached_recommendations_time = time.monotonic() return result def is_obsolete_model(model_name: str, provider: str) -> bool: """Check if a model is obsolete and should be filtered out. 
Filters models that are 2+ major versions behind or deprecated. This is the single source of truth for obsolete model detection. """ model_lower = model_name.lower() # OpenAI obsolete models if provider == LlmProviderNames.OPENAI: # GPT-3 models are obsolete if "gpt-3" in model_lower: return True # Legacy models deprecated = { "text-davinci-003", "text-davinci-002", "text-curie-001", "text-babbage-001", "text-ada-001", "davinci", "curie", "babbage", "ada", } if model_lower in deprecated: return True # Anthropic obsolete models if provider == LlmProviderNames.ANTHROPIC: if "claude-2" in model_lower or "claude-instant" in model_lower: return True # Vertex AI obsolete models if provider == LlmProviderNames.VERTEX_AI: if "gemini-1.0" in model_lower: return True if "palm" in model_lower or "bison" in model_lower: return True return False def get_openai_model_names() -> list[str]: """Get OpenAI model names dynamically from litellm.""" import re import litellm # TODO: remove these lists once we have a comprehensive model configuration page # The ideal flow should be: fetch all available models --> filter by type # --> allow user to modify filters and select models based on current context non_chat_model_terms = { "embed", "audio", "tts", "whisper", "dall-e", "image", "moderation", "sora", "container", } deprecated_model_terms = {"babbage", "davinci", "gpt-3.5", "gpt-4-"} excluded_terms = non_chat_model_terms | deprecated_model_terms # NOTE: We are explicitly excluding all "timestamped" models # because they are mostly just noise in the admin configuration panel # e.g. gpt-4o-2025-07-16, gpt-3.5-turbo-0613, etc. 
date_pattern = re.compile(r"-\d{4}") def is_valid_model(model: str) -> bool: model_lower = model.lower() return not any( ex in model_lower for ex in excluded_terms ) and not date_pattern.search(model) return sorted( ( model.removeprefix("openai/") for model in litellm.open_ai_chat_completion_models if is_valid_model(model) ), reverse=True, ) def get_anthropic_model_names() -> list[str]: """Get Anthropic model names dynamically from litellm.""" import litellm # Models to exclude from Anthropic's model list (deprecated or duplicates) _IGNORABLE_ANTHROPIC_MODELS = { "claude-2", "claude-instant-1", "anthropic/claude-3-5-sonnet-20241022", } return sorted( [ model for model in litellm.anthropic_models if model not in _IGNORABLE_ANTHROPIC_MODELS and not is_obsolete_model(model, LlmProviderNames.ANTHROPIC) ], reverse=True, ) def get_vertexai_model_names() -> list[str]: """Get Vertex AI model names dynamically from litellm model_cost.""" import litellm # Combine all vertex model sets vertex_models: set[str] = set() vertex_model_sets = [ "vertex_chat_models", "vertex_language_models", "vertex_anthropic_models", "vertex_llama3_models", "vertex_mistral_models", "vertex_ai_ai21_models", "vertex_deepseek_models", ] for attr in vertex_model_sets: if hasattr(litellm, attr): vertex_models.update(getattr(litellm, attr)) # Also extract from model_cost for any models not in the sets for key in litellm.model_cost.keys(): if key.startswith("vertex_ai/"): model_name = key.replace("vertex_ai/", "") vertex_models.add(model_name) return sorted( [ model for model in vertex_models if "embed" not in model.lower() and "image" not in model.lower() and "video" not in model.lower() and "code" not in model.lower() and "veo" not in model.lower() # video generation and "live" not in model.lower() # live/streaming models and "tts" not in model.lower() # text-to-speech and "native-audio" not in model.lower() # audio models and "/" not in model # filter out prefixed models like openai/gpt-oss and 
"search_api" not in model.lower() # not a model and "-maas" not in model.lower() # marketplace models and not is_obsolete_model(model, LlmProviderNames.VERTEX_AI) ], reverse=True, ) def model_configurations_for_provider( provider_name: str, llm_recommendations: LLMRecommendations ) -> list[ModelConfigurationView]: recommended_visible_models = llm_recommendations.get_visible_models(provider_name) recommended_visible_models_names = [m.name for m in recommended_visible_models] # Preserve provider-defined ordering while de-duplicating. model_names: list[str] = [] seen_model_names: set[str] = set() for model_name in ( fetch_models_for_provider(provider_name) + recommended_visible_models_names ): if model_name in seen_model_names: continue seen_model_names.add(model_name) model_names.append(model_name) # Vertex model list can be large and mixed-vendor; alphabetical ordering # makes model discovery easier in admin selection UIs. if provider_name == VERTEXAI_PROVIDER_NAME: model_names = sorted(model_names, key=str.lower) return [ ModelConfigurationView( name=model_name, is_visible=model_name in recommended_visible_models_names, max_input_tokens=get_max_input_tokens(model_name, provider_name), supports_image_input=model_supports_image_input(model_name, provider_name), ) for model_name in model_names ] def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]: llm_recommendations = get_recommendations() well_known_llms = [] for provider_name in WELL_KNOWN_PROVIDER_NAMES: model_configurations = model_configurations_for_provider( provider_name, llm_recommendations ) well_known_llms.append( WellKnownLLMProviderDescriptor( name=provider_name, known_models=model_configurations, recommended_default_model=llm_recommendations.get_default_model( provider_name ), ) ) return well_known_llms def fetch_models_for_provider(provider_name: str) -> list[str]: return _get_provider_to_models_map().get(provider_name, []) def 
fetch_model_names_for_provider_as_set(provider_name: str) -> set[str] | None: model_names = fetch_models_for_provider(provider_name) return set(model_names) if model_names else None def fetch_visible_model_names_for_provider_as_set( provider_name: str, ) -> set[str] | None: """Get visible model names for a provider. Note: Since we no longer maintain separate visible model lists, this returns all models (same as fetch_model_names_for_provider_as_set). Kept for backwards compatibility with alembic migrations. """ return fetch_model_names_for_provider_as_set(provider_name) def get_provider_display_name(provider_name: str) -> str: """Get human-friendly display name for an Onyx-supported provider. First checks Onyx-specific display names, then falls back to PROVIDER_DISPLAY_NAMES from constants. """ # Display names for Onyx-supported LLM providers (used in admin UI provider selection). # These override PROVIDER_DISPLAY_NAMES for Onyx-specific branding. _ONYX_PROVIDER_DISPLAY_NAMES: dict[str, str] = { OPENAI_PROVIDER_NAME: "ChatGPT (OpenAI)", OLLAMA_PROVIDER_NAME: "Ollama", LM_STUDIO_PROVIDER_NAME: "LM Studio", ANTHROPIC_PROVIDER_NAME: "Claude (Anthropic)", AZURE_PROVIDER_NAME: "Azure OpenAI", BEDROCK_PROVIDER_NAME: "Amazon Bedrock", VERTEXAI_PROVIDER_NAME: "Google Vertex AI", OPENROUTER_PROVIDER_NAME: "OpenRouter", LITELLM_PROXY_PROVIDER_NAME: "LiteLLM Proxy", } if provider_name in _ONYX_PROVIDER_DISPLAY_NAMES: return _ONYX_PROVIDER_DISPLAY_NAMES[provider_name] return PROVIDER_DISPLAY_NAMES.get( provider_name.lower(), provider_name.replace("_", " ").title() ) def fetch_default_model_for_provider(provider_name: str) -> str | None: """Fetch the default model for a provider. First checks the GitHub-hosted recommended-models.json config (via get_recommendations), then falls back to the bundled recommended-models.json if GitHub is unavailable. Returns None if the provider has no configured default.
""" llm_recommendations = get_recommendations() default_model = llm_recommendations.get_default_model(provider_name) return default_model.name if default_model else None ================================================ FILE: backend/onyx/llm/well_known_providers/models.py ================================================ from enum import Enum from pydantic import BaseModel from pydantic import Field from onyx.server.manage.llm.models import ModelConfigurationView class CustomConfigKeyType(str, Enum): # used for configuration values that require manual input # i.e., textual API keys (e.g., "abcd1234") TEXT_INPUT = "text_input" # used for configuration values that require a file to be selected/drag-and-dropped # i.e., file based credentials (e.g., "/path/to/credentials/file.json") FILE_INPUT = "file_input" # used for configuration values that require a selection from predefined options SELECT = "select" class SimpleKnownModel(BaseModel): name: str display_name: str | None = None class WellKnownLLMProviderDescriptor(BaseModel): name: str # NOTE: the recommended visible models are encoded in the known_models list known_models: list[ModelConfigurationView] = Field(default_factory=list) recommended_default_model: SimpleKnownModel | None = None ================================================ FILE: backend/onyx/llm/well_known_providers/recommended-models.json ================================================ { "version": "1.1", "updated_at": "2026-03-05T00:00:00Z", "providers": { "openai": { "default_model": { "name": "gpt-5.4" }, "additional_visible_models": [ { "name": "gpt-5.4" }, { "name": "gpt-5.2" } ] }, "anthropic": { "default_model": "claude-opus-4-6", "additional_visible_models": [ { "name": "claude-opus-4-6", "display_name": "Claude Opus 4.6" }, { "name": "claude-sonnet-4-6", "display_name": "Claude Sonnet 4.6" }, { "name": "claude-opus-4-5", "display_name": "Claude Opus 4.5" }, { "name": "claude-sonnet-4-5", "display_name": "Claude Sonnet 4.5" } ] }, "vertex_ai": 
{ "default_model": "gemini-3-pro-preview", "additional_visible_models": [ { "name": "gemini-3-pro-preview", "display_name": "Gemini 3 Pro" }, { "name": "gemini-3-flash-preview", "display_name": "Gemini 3 Flash" } ] }, "openrouter": { "default_model": "z-ai/glm-4.7", "additional_visible_models": [ { "name": "z-ai/glm-4.7", "display_name": "GLM 4.7" }, { "name": "deepseek/deepseek-v3.2", "display_name": "DeepSeek V3.2" }, { "name": "qwen/qwen3-235b-a22b-2507", "display_name": "Qwen3 235B A22B Instruct 2507" }, { "name": "moonshotai/kimi-k2-0905", "display_name": "Kimi K2 0905" } ] } } } ================================================ FILE: backend/onyx/main.py ================================================ import logging import sys import traceback import warnings from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from typing import Any from typing import cast import sentry_sdk import uvicorn from fastapi import APIRouter from fastapi import FastAPI from fastapi import HTTPException from fastapi import Request from fastapi import status from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.routing import APIRoute from httpx_oauth.clients.google import GoogleOAuth2 from httpx_oauth.clients.openid import BASE_SCOPES from httpx_oauth.clients.openid import OpenID from sentry_sdk.integrations.fastapi import FastApiIntegration from sentry_sdk.integrations.starlette import StarletteIntegration from starlette.types import Lifespan from onyx import __version__ from onyx.auth.schemas import UserCreate from onyx.auth.schemas import UserRead from onyx.auth.schemas import UserUpdate from onyx.auth.users import auth_backend from onyx.auth.users import create_onyx_oauth_router from onyx.auth.users import fastapi_users from onyx.cache.interface import CacheBackendType from onyx.configs.app_configs import APP_API_PREFIX from 
onyx.configs.app_configs import APP_HOST from onyx.configs.app_configs import APP_PORT from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import CACHE_BACKEND from onyx.configs.app_configs import DISABLE_VECTOR_DB from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY from onyx.configs.app_configs import OAUTH_CLIENT_ID from onyx.configs.app_configs import OAUTH_CLIENT_SECRET from onyx.configs.app_configs import OAUTH_ENABLED from onyx.configs.app_configs import OIDC_PKCE_ENABLED from onyx.configs.app_configs import OIDC_SCOPE_OVERRIDE from onyx.configs.app_configs import OPENID_CONFIG_URL from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW from onyx.configs.app_configs import POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE from onyx.configs.app_configs import SYSTEM_RECURSION_LIMIT from onyx.configs.app_configs import USER_AUTH_SECRET from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import AuthType from onyx.configs.constants import POSTGRES_WEB_APP_NAME from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine from onyx.db.engine.connection_warmup import warm_up_connections from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import SqlEngine from onyx.error_handling.exceptions import register_onyx_exception_handlers from onyx.file_store.file_store import get_default_file_store from onyx.hooks.registry import validate_registry from onyx.server.api_key.api import router as api_key_router from onyx.server.auth_check import check_router_auth from onyx.server.documents.cc_pair import router as cc_pair_router from onyx.server.documents.connector import router as connector_router from onyx.server.documents.credential import router as 
credential_router from onyx.server.documents.document import router as document_router from onyx.server.documents.standard_oauth import router as standard_oauth_router from onyx.server.features.build.api.api import public_build_router from onyx.server.features.build.api.api import router as build_router from onyx.server.features.default_assistant.api import ( router as default_assistant_router, ) from onyx.server.features.document_set.api import router as document_set_router from onyx.server.features.hierarchy.api import router as hierarchy_router from onyx.server.features.input_prompt.api import ( admin_router as admin_input_prompt_router, ) from onyx.server.features.input_prompt.api import ( basic_router as input_prompt_router, ) from onyx.server.features.mcp.api import admin_router as mcp_admin_router from onyx.server.features.mcp.api import router as mcp_router from onyx.server.features.notifications.api import router as notification_router from onyx.server.features.oauth_config.api import ( admin_router as admin_oauth_config_router, ) from onyx.server.features.oauth_config.api import router as oauth_config_router from onyx.server.features.password.api import router as password_router from onyx.server.features.persona.api import admin_agents_router from onyx.server.features.persona.api import admin_router as admin_persona_router from onyx.server.features.persona.api import agents_router from onyx.server.features.persona.api import basic_router as persona_router from onyx.server.features.projects.api import router as projects_router from onyx.server.features.tool.api import admin_router as admin_tool_router from onyx.server.features.tool.api import router as tool_router from onyx.server.features.user_oauth_token.api import router as user_oauth_token_router from onyx.server.features.web_search.api import router as web_search_router from onyx.server.federated.api import router as federated_router from onyx.server.kg.api import admin_router as kg_admin_router from 
onyx.server.manage.administrative import router as admin_router from onyx.server.manage.code_interpreter.api import ( admin_router as code_interpreter_admin_router, ) from onyx.server.manage.discord_bot.api import router as discord_bot_router from onyx.server.manage.embedding.api import admin_router as embedding_admin_router from onyx.server.manage.embedding.api import basic_router as embedding_router from onyx.server.manage.get_state import router as state_router from onyx.server.manage.image_generation.api import ( admin_router as image_generation_admin_router, ) from onyx.server.manage.llm.api import admin_router as llm_admin_router from onyx.server.manage.llm.api import basic_router as llm_router from onyx.server.manage.opensearch_migration.api import ( admin_router as opensearch_migration_admin_router, ) from onyx.server.manage.search_settings import router as search_settings_router from onyx.server.manage.slack_bot import router as slack_bot_management_router from onyx.server.manage.users import router as user_router from onyx.server.manage.voice.api import admin_router as voice_admin_router from onyx.server.manage.voice.user_api import router as voice_router from onyx.server.manage.voice.websocket_api import router as voice_websocket_router from onyx.server.manage.web_search.api import ( admin_router as web_search_admin_router, ) from onyx.server.metrics.postgres_connection_pool import ( setup_postgres_connection_pool_metrics, ) from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics from onyx.server.middleware.latency_logging import add_latency_logging_middleware from onyx.server.middleware.rate_limiting import close_auth_limiter from onyx.server.middleware.rate_limiting import get_auth_rate_limiters from onyx.server.middleware.rate_limiting import setup_auth_limiter from onyx.server.onyx_api.ingestion import router as onyx_api_router from onyx.server.pat.api import router as pat_router from onyx.server.query_and_chat.chat_backend import 
router as chat_router from onyx.server.query_and_chat.query_backend import ( admin_router as admin_query_router, ) from onyx.server.query_and_chat.query_backend import basic_router as query_router from onyx.server.saml import router as saml_router from onyx.server.settings.api import admin_router as settings_admin_router from onyx.server.settings.api import basic_router as settings_router from onyx.server.token_rate_limits.api import ( router as token_rate_limit_settings_router, ) from onyx.server.utils import BasicAuthenticationError from onyx.setup import setup_multitenant_onyx from onyx.setup import setup_onyx from onyx.tracing.setup import setup_tracing from onyx.utils.logger import setup_logger from onyx.utils.logger import setup_uvicorn_logger from onyx.utils.middleware import add_endpoint_context_middleware from onyx.utils.middleware import add_onyx_request_id_middleware from onyx.utils.telemetry import get_or_generate_uuid from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType from onyx.utils.variable_functionality import fetch_versioned_implementation from onyx.utils.variable_functionality import global_version from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable from shared_configs.configs import CORS_ALLOWED_ORIGIN from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import SENTRY_DSN from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR warnings.filterwarnings( "ignore", category=ResourceWarning, message=r"Unclosed client session" ) warnings.filterwarnings( "ignore", category=ResourceWarning, message=r"Unclosed connector" ) logger = setup_logger() file_handlers = [ h for h in logger.logger.handlers if isinstance(h, logging.FileHandler) ] setup_uvicorn_logger(shared_file_handlers=file_handlers) def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse: if not 
isinstance(exc, RequestValidationError): logger.error( f"Unexpected exception type in validation_exception_handler - {type(exc)}" ) raise exc exc_str = f"{exc}".replace("\n", " ").replace(" ", " ") logger.exception(f"{request}: {exc_str}") content = {"status_code": 422, "message": exc_str, "data": None} return JSONResponse(content=content, status_code=422) def value_error_handler(_: Request, exc: Exception) -> JSONResponse: if not isinstance(exc, ValueError): logger.error(f"Unexpected exception type in value_error_handler - {type(exc)}") raise exc try: raise (exc) except Exception: # log stacktrace logger.exception("ValueError") return JSONResponse( status_code=400, content={"message": str(exc)}, ) def use_route_function_names_as_operation_ids(app: FastAPI) -> None: """ OpenAPI generation defaults to naming the operation with the function + route + HTTP method, which usually looks very redundant. This function changes the operation IDs to be just the function name. Should be called only after all routes have been added. 
""" for route in app.routes: if isinstance(route, APIRoute): route.operation_id = route.name def include_router_with_global_prefix_prepended( application: FastAPI, router: APIRouter, **kwargs: Any ) -> None: """Adds the global prefix to all routes in the router.""" processed_global_prefix = f"/{APP_API_PREFIX.strip('/')}" if APP_API_PREFIX else "" passed_in_prefix = cast(str | None, kwargs.get("prefix")) if passed_in_prefix: final_prefix = f"{processed_global_prefix}/{passed_in_prefix.strip('/')}" else: final_prefix = f"{processed_global_prefix}" final_kwargs: dict[str, Any] = { **kwargs, "prefix": final_prefix, } application.include_router(router, **final_kwargs) def include_auth_router_with_prefix( application: FastAPI, router: APIRouter, prefix: str | None = None, tags: list[str] | None = None, ) -> None: """Wrapper function to include an 'auth' router with prefix + rate-limiting dependencies.""" final_tags = tags or ["auth"] include_router_with_global_prefix_prepended( application, router, prefix=prefix, tags=final_tags, dependencies=get_auth_rate_limiters(), ) def validate_cache_backend_settings() -> None: """Validate that CACHE_BACKEND=postgres is only used with DISABLE_VECTOR_DB. The Postgres cache backend eliminates the Redis dependency, but only works when Celery is not running (which requires DISABLE_VECTOR_DB=true). """ if CACHE_BACKEND == CacheBackendType.POSTGRES and not DISABLE_VECTOR_DB: raise RuntimeError( "CACHE_BACKEND=postgres requires DISABLE_VECTOR_DB=true. " "The Postgres cache backend is only supported in no-vector-DB " "deployments where Celery is replaced by the in-process task runner." ) def validate_no_vector_db_settings() -> None: """Validate that DISABLE_VECTOR_DB is not combined with incompatible settings. Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT, since these modes require infrastructure that is removed in no-vector-DB deployments. 
""" if not DISABLE_VECTOR_DB: return if MULTI_TENANT: raise RuntimeError( "DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. " "Multi-tenant deployments require the vector database for " "per-tenant document indexing and search. Run in single-tenant " "mode when disabling the vector database." ) from onyx.server.features.build.configs import ENABLE_CRAFT if ENABLE_CRAFT: raise RuntimeError( "DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. " "Onyx Craft requires background workers for sandbox lifecycle " "management, which are removed in no-vector-DB deployments. " "Disable Craft (ENABLE_CRAFT=false) when disabling the vector database." ) @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001 validate_no_vector_db_settings() validate_cache_backend_settings() validate_registry() # Set recursion limit if SYSTEM_RECURSION_LIMIT is not None: sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT) logger.notice(f"System recursion limit set to {SYSTEM_RECURSION_LIMIT}") SqlEngine.set_app_name(POSTGRES_WEB_APP_NAME) SqlEngine.init_engine( pool_size=POSTGRES_API_SERVER_POOL_SIZE, max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW, ) SqlEngine.get_engine() SqlEngine.init_readonly_engine( pool_size=POSTGRES_API_SERVER_READ_ONLY_POOL_SIZE, max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW, ) # Register pool metrics now that engines are created. # HTTP instrumentation is set up earlier in get_application() since it # adds middleware (which Starlette forbids after the app has started). 
setup_postgres_connection_pool_metrics( engines={ "sync": SqlEngine.get_engine(), "async": get_sqlalchemy_async_engine(), "readonly": SqlEngine.get_readonly_engine(), }, ) verify_auth = fetch_versioned_implementation( "onyx.auth.users", "verify_auth_setting" ) # Will throw exception if an issue is found verify_auth() if OAUTH_CLIENT_ID and OAUTH_CLIENT_SECRET: logger.notice("Both OAuth Client ID and Secret are configured.") # Initialize tracing if credentials are provided setup_tracing() # fill up Postgres connection pools await warm_up_connections() if not MULTI_TENANT: # We cache this at the beginning so there is no delay in the first telemetry CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA) get_or_generate_uuid() # If we are multi-tenant, we need to only set up initial public tables with get_session_with_current_tenant() as db_session: setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA) # set up the file store (e.g. create bucket if needed). On multi-tenant, # this is done via IaC get_default_file_store().initialize() else: setup_multitenant_onyx() if not MULTI_TENANT: # don't emit a metric for every pod rollover/restart optional_telemetry( record_type=RecordType.VERSION, data={"version": __version__} ) if AUTH_RATE_LIMITING_ENABLED: await setup_auth_limiter() if DISABLE_VECTOR_DB: from onyx.background.periodic_poller import recover_stuck_user_files from onyx.background.periodic_poller import start_periodic_poller recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA) start_periodic_poller(POSTGRES_DEFAULT_SCHEMA) yield if DISABLE_VECTOR_DB: from onyx.background.periodic_poller import stop_periodic_poller stop_periodic_poller() SqlEngine.reset_engine() if AUTH_RATE_LIMITING_ENABLED: await close_auth_limiter() def log_http_error(request: Request, exc: Exception) -> JSONResponse: status_code = getattr(exc, "status_code", 500) if isinstance(exc, BasicAuthenticationError): # For BasicAuthenticationError, just log a brief message without stack trace # (almost always 
spammy) logger.debug(f"Authentication failed: {str(exc)}") elif status_code == 404 and request.url.path == "/metrics": # Log 404 errors for the /metrics endpoint with debug level logger.debug(f"404 error for /metrics endpoint: {str(exc)}") elif status_code >= 400: error_msg = f"{str(exc)}\n" error_msg += "".join(traceback.format_tb(exc.__traceback__)) logger.error(error_msg) detail = exc.detail if isinstance(exc, HTTPException) else str(exc) return JSONResponse( status_code=status_code, content={"detail": detail}, ) def get_application(lifespan_override: Lifespan | None = None) -> FastAPI: application = FastAPI( title="Onyx Backend", version=__version__, description="Onyx API for AI-powered chat with search, document indexing, agents, actions, and more", servers=[ {"url": f"{WEB_DOMAIN.rstrip('/')}/api", "description": "Onyx API Server"} ], lifespan=lifespan_override or lifespan, ) if SENTRY_DSN: sentry_sdk.init( dsn=SENTRY_DSN, integrations=[StarletteIntegration(), FastApiIntegration()], traces_sample_rate=0.1, release=__version__, ) logger.info("Sentry initialized") else: logger.debug("Sentry DSN not provided, skipping Sentry initialization") application.add_exception_handler(status.HTTP_400_BAD_REQUEST, log_http_error) application.add_exception_handler(status.HTTP_401_UNAUTHORIZED, log_http_error) application.add_exception_handler(status.HTTP_403_FORBIDDEN, log_http_error) application.add_exception_handler(status.HTTP_404_NOT_FOUND, log_http_error) application.add_exception_handler( status.HTTP_500_INTERNAL_SERVER_ERROR, log_http_error ) register_onyx_exception_handlers(application) include_router_with_global_prefix_prepended(application, password_router) include_router_with_global_prefix_prepended(application, chat_router) include_router_with_global_prefix_prepended(application, query_router) include_router_with_global_prefix_prepended(application, document_router) include_router_with_global_prefix_prepended(application, user_router) 
include_router_with_global_prefix_prepended(application, admin_query_router) include_router_with_global_prefix_prepended(application, admin_router) include_router_with_global_prefix_prepended(application, connector_router) include_router_with_global_prefix_prepended(application, credential_router) include_router_with_global_prefix_prepended(application, input_prompt_router) include_router_with_global_prefix_prepended(application, admin_input_prompt_router) include_router_with_global_prefix_prepended(application, cc_pair_router) include_router_with_global_prefix_prepended(application, projects_router) include_router_with_global_prefix_prepended(application, public_build_router) include_router_with_global_prefix_prepended(application, build_router) include_router_with_global_prefix_prepended(application, document_set_router) include_router_with_global_prefix_prepended(application, hierarchy_router) include_router_with_global_prefix_prepended(application, search_settings_router) include_router_with_global_prefix_prepended( application, slack_bot_management_router ) include_router_with_global_prefix_prepended(application, discord_bot_router) include_router_with_global_prefix_prepended(application, persona_router) include_router_with_global_prefix_prepended(application, admin_persona_router) include_router_with_global_prefix_prepended(application, agents_router) include_router_with_global_prefix_prepended(application, admin_agents_router) include_router_with_global_prefix_prepended(application, default_assistant_router) include_router_with_global_prefix_prepended(application, notification_router) include_router_with_global_prefix_prepended(application, tool_router) include_router_with_global_prefix_prepended(application, admin_tool_router) include_router_with_global_prefix_prepended(application, oauth_config_router) include_router_with_global_prefix_prepended(application, admin_oauth_config_router) include_router_with_global_prefix_prepended(application, 
user_oauth_token_router) include_router_with_global_prefix_prepended(application, state_router) include_router_with_global_prefix_prepended(application, onyx_api_router) include_router_with_global_prefix_prepended(application, settings_router) include_router_with_global_prefix_prepended(application, settings_admin_router) include_router_with_global_prefix_prepended(application, llm_admin_router) include_router_with_global_prefix_prepended(application, kg_admin_router) include_router_with_global_prefix_prepended(application, llm_router) include_router_with_global_prefix_prepended( application, code_interpreter_admin_router ) include_router_with_global_prefix_prepended( application, image_generation_admin_router ) include_router_with_global_prefix_prepended(application, embedding_admin_router) include_router_with_global_prefix_prepended(application, embedding_router) include_router_with_global_prefix_prepended(application, web_search_router) include_router_with_global_prefix_prepended(application, web_search_admin_router) include_router_with_global_prefix_prepended(application, voice_admin_router) include_router_with_global_prefix_prepended(application, voice_router) include_router_with_global_prefix_prepended(application, voice_websocket_router) include_router_with_global_prefix_prepended( application, opensearch_migration_admin_router ) include_router_with_global_prefix_prepended( application, token_rate_limit_settings_router ) include_router_with_global_prefix_prepended(application, api_key_router) include_router_with_global_prefix_prepended(application, standard_oauth_router) include_router_with_global_prefix_prepended(application, federated_router) include_router_with_global_prefix_prepended(application, mcp_router) include_router_with_global_prefix_prepended(application, mcp_admin_router) include_router_with_global_prefix_prepended(application, pat_router) if AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.CLOUD: include_auth_router_with_prefix( 
application, fastapi_users.get_auth_router(auth_backend), prefix="/auth", ) include_auth_router_with_prefix( application, fastapi_users.get_register_router(UserRead, UserCreate), prefix="/auth", ) include_auth_router_with_prefix( application, fastapi_users.get_reset_password_router(), prefix="/auth", ) include_auth_router_with_prefix( application, fastapi_users.get_verify_router(UserRead), prefix="/auth", ) include_auth_router_with_prefix( application, fastapi_users.get_users_router(UserRead, UserUpdate), prefix="/users", ) # Register Google OAuth when AUTH_TYPE is GOOGLE_OAUTH, or when # AUTH_TYPE is BASIC and OAuth credentials are configured if AUTH_TYPE == AuthType.GOOGLE_OAUTH or ( AUTH_TYPE == AuthType.BASIC and OAUTH_ENABLED ): oauth_client = GoogleOAuth2( OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, scopes=["openid", "email", "profile"], ) include_auth_router_with_prefix( application, create_onyx_oauth_router( oauth_client, auth_backend, USER_AUTH_SECRET, associate_by_email=True, is_verified_by_default=True, redirect_url=f"{WEB_DOMAIN}/auth/oauth/callback", ), prefix="/auth/oauth", ) # Need logout router for GOOGLE_OAUTH only (BASIC already has it from above) if AUTH_TYPE == AuthType.GOOGLE_OAUTH: include_auth_router_with_prefix( application, fastapi_users.get_logout_router(auth_backend), prefix="/auth", ) if AUTH_TYPE == AuthType.OIDC: # Ensure we request offline_access for refresh tokens try: oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES) if "offline_access" not in oidc_scopes: oidc_scopes.append("offline_access") except Exception as e: logger.warning(f"Error configuring OIDC scopes: {e}") # Fall back to default scopes if there's an error oidc_scopes = BASE_SCOPES include_auth_router_with_prefix( application, create_onyx_oauth_router( OpenID( OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET, OPENID_CONFIG_URL, # Use the configured scopes base_scopes=oidc_scopes, ), auth_backend, USER_AUTH_SECRET, associate_by_email=True, is_verified_by_default=True, 
redirect_url=f"{WEB_DOMAIN}/auth/oidc/callback", enable_pkce=OIDC_PKCE_ENABLED, ), prefix="/auth/oidc", ) # need basic auth router for `logout` endpoint include_auth_router_with_prefix( application, fastapi_users.get_auth_router(auth_backend), prefix="/auth", ) elif AUTH_TYPE == AuthType.SAML: include_auth_router_with_prefix( application, saml_router, ) if ( AUTH_TYPE == AuthType.CLOUD or AUTH_TYPE == AuthType.BASIC or AUTH_TYPE == AuthType.GOOGLE_OAUTH or AUTH_TYPE == AuthType.OIDC ): # Add refresh token endpoint for OAuth as well include_auth_router_with_prefix( application, fastapi_users.get_refresh_router(auth_backend), prefix="/auth", ) application.add_exception_handler( RequestValidationError, validation_exception_handler ) application.add_exception_handler(ValueError, value_error_handler) application.add_middleware( CORSMiddleware, allow_origins=CORS_ALLOWED_ORIGIN, # Configurable via environment variable allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) if LOG_ENDPOINT_LATENCY: add_latency_logging_middleware(application, logger) add_onyx_request_id_middleware(application, "API", logger) # Set endpoint context for per-endpoint DB pool attribution metrics. # Must be registered after all routes are added. add_endpoint_context_middleware(application) # HTTP request metrics (latency histograms, in-progress gauge, slow request # counter). Must be called here — before the app starts — because the # instrumentator adds middleware via app.add_middleware(). 
setup_prometheus_metrics(application) # Ensure all routes have auth enabled or are explicitly marked as public check_router_auth(application) use_route_function_names_as_operation_ids(application) return application # NOTE: needs to be outside of the `if __name__ == "__main__"` block so that the # app is exportable set_is_ee_based_on_env_variable() app = fetch_versioned_implementation(module="onyx.main", attribute="get_application") if __name__ == "__main__": logger.notice( f"Starting Onyx Backend version {__version__} on http://{APP_HOST}:{str(APP_PORT)}/" ) if global_version.is_ee_version(): logger.notice("Running Enterprise Edition") uvicorn.run(app, host=APP_HOST, port=APP_PORT) ================================================ FILE: backend/onyx/mcp_server/README.md ================================================ # Onyx MCP Server ## Overview The Onyx MCP server allows LLMs to connect to your Onyx instance and access its knowledge base and search capabilities through the [Model Context Protocol (MCP)](https://modelcontextprotocol.io/). With the Onyx MCP Server, you can search your knowledge base, give your LLMs web search access, and upload and manage documents in Onyx. All access controls are managed within the main Onyx application. ### Authentication Provide an Onyx Personal Access Token or API Key in the `Authorization` header as a Bearer token. On every request, the MCP server validates the token by calling the API server's `/me` endpoint, then forwards the same token on its downstream API calls. Depending on usage, the MCP Server may support OAuth and stdio transport in the future.
### Default Configuration - **Transport**: HTTP POST (MCP over HTTP) - **Port**: 8090 (shares domain with API server) - **Framework**: FastMCP with FastAPI wrapper - **Database**: None (all work delegates to the API server) ### Architecture The MCP server is built on [FastMCP](https://github.com/jlowin/fastmcp) and runs alongside the main Onyx API server: ``` ┌─────────────────┐ │ LLM Client │ │ (Claude, etc) │ └────────┬────────┘ │ MCP over HTTP │ (POST with bearer) ▼ ┌─────────────────┐ │ MCP Server │ │ Port 8090 │ │ ├─ Auth │ │ ├─ Tools │ │ └─ Resources │ └────────┬────────┘ │ Internal HTTP │ (authenticated) ▼ ┌─────────────────┐ │ API Server │ │ Port 8080 │ │ ├─ /me (auth) │ │ ├─ Search APIs │ │ └─ ACL checks │ └─────────────────┘ ``` ## Configuring MCP Clients ### Claude Desktop Add to your Claude Desktop configuration (`~/Library/Application Support/Claude/claude_desktop_config.json` on macOS): ```json { "mcpServers": { "onyx": { "url": "https://[YOUR_ONYX_DOMAIN]:8090/", "transport": "http", "headers": { "Authorization": "Bearer YOUR_ONYX_TOKEN_HERE" } } } } ``` ### Other MCP Clients Most MCP clients support HTTP transport with custom headers. Refer to your client's documentation for configuration details. ## Capabilities ### Tools The server provides three tools for searching and retrieving information: 1. `search_indexed_documents` Search the user's private knowledge base indexed in Onyx. Returns ranked documents with content snippets, scores, and metadata. 2. `search_web` Search the public internet for current events and general knowledge. Returns web search results with titles, URLs, and snippets. 3. `open_urls` Retrieve the complete text content from specific web URLs. Useful for fetching full page content after finding relevant URLs via `search_web`. ### Resources 1. `indexed_sources` Lists all document sources currently indexed in the tenant (e.g., `"confluence"`, `"github"`). Use these values to filter results when calling `search_indexed_documents`. 
## Local Development ### Running the MCP Server The MCP Server launches automatically with the `Run All Onyx Services` task from the default `launch.json`. You can also launch it on its own via the VS Code debugger. ### Testing with MCP Inspector The [MCP Inspector](https://github.com/modelcontextprotocol/inspector) is a debugging tool for MCP servers: ```bash npx @modelcontextprotocol/inspector http://localhost:8090/ ``` **Setup in Inspector:** 1. Ignore the OAuth configuration menus 2. Open the **Authentication** tab 3. Select **Bearer Token** authentication 4. Paste your Onyx bearer token 5. Click **Connect** Once connected, you can: - Browse available tools - Test tool calls with different parameters - View request/response payloads - Debug authentication issues ### Health Check Verify the server is running: ```bash curl http://localhost:8090/health ``` Expected response: ```json { "status": "healthy", "service": "mcp_server" } ``` ### Environment Variables **MCP Server Configuration:** - `MCP_SERVER_ENABLED`: Enable the MCP server (set to "true" to enable; default: disabled) - `MCP_SERVER_PORT`: Port for the MCP server (default: 8090) - `MCP_SERVER_CORS_ORIGINS`: Comma-separated CORS origins (optional) **API Server Connection:** - `API_SERVER_PROTOCOL`: Protocol for the API server connection (default: "http") - `API_SERVER_HOST`: Hostname for the API server connection (default: "127.0.0.1") - `API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS`: Optional override URL. If set, it takes precedence over the protocol/host variables. Use it when self-hosting the MCP server with Onyx Cloud as the backend.
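The variables and precedence rules above can be summarized as a small settings loader. This is a hypothetical sketch: `McpServerSettings` and `load_settings` are not Onyx APIs. It makes two assumptions:

- `MCP_SERVER_ENABLED` is compared case-insensitively to "true".
- The API server listens on port 8080, as shown in the architecture diagram.

Onyx's own `build_api_server_url_for_http_requests` helper may assemble the URL differently.

```python
from collections.abc import Mapping
from dataclasses import dataclass


@dataclass(frozen=True)
class McpServerSettings:
    enabled: bool
    port: int
    cors_origins: tuple[str, ...]
    api_base_url: str


def load_settings(env: Mapping[str, str], api_port: int = 8080) -> McpServerSettings:
    """Hypothetical loader mirroring the documented defaults and precedence."""
    # Disabled unless explicitly set to "true" (case-insensitive here by assumption).
    enabled = env.get("MCP_SERVER_ENABLED", "").strip().lower() == "true"
    # An empty value falls back to the documented default port.
    port = int(env.get("MCP_SERVER_PORT") or "8090")
    # Comma-separated list; blank entries are ignored.
    cors_origins = tuple(
        origin.strip()
        for origin in env.get("MCP_SERVER_CORS_ORIGINS", "").split(",")
        if origin.strip()
    )
    override = env.get("API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS")
    if override:
        # The override wins over protocol/host (e.g. MCP server -> Onyx Cloud).
        api_base_url = override.rstrip("/")
    else:
        protocol = env.get("API_SERVER_PROTOCOL", "http")
        host = env.get("API_SERVER_HOST", "127.0.0.1")
        api_base_url = f"{protocol}://{host}:{api_port}"
    return McpServerSettings(enabled, port, cors_origins, api_base_url)
```

Passing the environment as a mapping, for example `load_settings(os.environ)`, keeps the precedence logic easy to test without mutating the process environment.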
================================================ FILE: backend/onyx/mcp_server/api.py ================================================ """MCP server with FastAPI wrapper.""" from collections.abc import AsyncGenerator from contextlib import asynccontextmanager from fastapi import FastAPI from fastapi import Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from fastapi.responses import Response from fastmcp import FastMCP from starlette.datastructures import MutableHeaders from starlette.middleware.base import RequestResponseEndpoint from starlette.types import Receive from starlette.types import Scope from starlette.types import Send from onyx.configs.app_configs import MCP_SERVER_CORS_ORIGINS from onyx.mcp_server.auth import OnyxTokenVerifier from onyx.mcp_server.utils import shutdown_http_client from onyx.utils.logger import setup_logger logger = setup_logger() logger.info("Creating Onyx MCP Server...") mcp_server = FastMCP( name="Onyx MCP Server", version="1.0.0", auth=OnyxTokenVerifier(), ) # Import tools and resources AFTER mcp_server is created to avoid circular imports # Components register themselves via decorators on the shared mcp_server instance from onyx.mcp_server.tools import search # noqa: E402, F401 from onyx.mcp_server.resources import indexed_sources # noqa: E402, F401 logger.info("MCP server instance created") def create_mcp_fastapi_app() -> FastAPI: """Create FastAPI app wrapping MCP server with auth and shared client lifecycle.""" mcp_asgi_app = mcp_server.http_app(path="/") async def _ensure_streamable_accept_header( scope: Scope, receive: Receive, send: Send ) -> None: """Ensure Accept header includes types required by FastMCP streamable HTTP.""" if scope.get("type") == "http": headers = MutableHeaders(scope=scope) accept = headers.get("accept", "") accept_lower = accept.lower() if ( not accept or accept == "*/*" or "application/json" not in accept_lower or "text/event-stream" not in 
accept_lower ): headers["accept"] = "application/json, text/event-stream" await mcp_asgi_app(scope, receive, send) @asynccontextmanager async def combined_lifespan(app: FastAPI) -> AsyncGenerator[None, None]: """Initializes MCP session manager.""" logger.info("MCP server starting up") try: async with mcp_asgi_app.lifespan(app): yield finally: logger.info("MCP server shutting down") await shutdown_http_client() app = FastAPI( title="Onyx MCP Server", description="HTTP POST transport with bearer auth delegated to API /me", version="1.0.0", lifespan=combined_lifespan, ) # Public health check endpoint (bypasses MCP auth) @app.middleware("http") async def health_check( request: Request, call_next: RequestResponseEndpoint ) -> Response: if request.url.path.rstrip("/") == "/health": return JSONResponse({"status": "healthy", "service": "mcp_server"}) return await call_next(request) # Authentication is handled by FastMCP's OnyxTokenVerifier (see auth.py) if MCP_SERVER_CORS_ORIGINS: logger.info(f"CORS origins: {MCP_SERVER_CORS_ORIGINS}") app.add_middleware( CORSMiddleware, allow_origins=MCP_SERVER_CORS_ORIGINS, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) app.mount("/", _ensure_streamable_accept_header) return app mcp_app = create_mcp_fastapi_app() ================================================ FILE: backend/onyx/mcp_server/auth.py ================================================ """Authentication helpers for the Onyx MCP server.""" from typing import Optional from fastmcp.server.auth.auth import AccessToken from fastmcp.server.auth.auth import TokenVerifier from onyx.mcp_server.utils import get_http_client from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import build_api_server_url_for_http_requests logger = setup_logger() class OnyxTokenVerifier(TokenVerifier): """Validates bearer tokens by delegating to the API server.""" async def verify_token(self, token: str) -> Optional[AccessToken]: """Call API /me to verify 
the token, return minimal AccessToken on success.""" try: response = await get_http_client().get( f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/me", headers={"Authorization": f"Bearer {token}"}, ) except Exception as exc: logger.error( "MCP server failed to reach API /me for authentication: %s", exc, exc_info=True, ) return None if response.status_code != 200: logger.warning( "API server rejected MCP auth token with status %s", response.status_code, ) return None return AccessToken( token=token, client_id="mcp", scopes=["mcp:use"], expires_at=None, resource=None, claims={}, ) ================================================ FILE: backend/onyx/mcp_server/mcp.json.template ================================================ { "mcpServers": { "Onyx": { "url": "https://cloud.onyx.app/mcp", "headers": { "Authorization": "Bearer [YOUR PAT OR API KEY HERE]" } } } } ================================================ FILE: backend/onyx/mcp_server/resources/__init__.py ================================================ """Resource registrations for the Onyx MCP server.""" # Import resource modules so decorators execute when the package loads. from onyx.mcp_server.resources import indexed_sources # noqa: F401 ================================================ FILE: backend/onyx/mcp_server/resources/indexed_sources.py ================================================ """Resources that expose metadata for the Onyx MCP server.""" from __future__ import annotations from typing import Any from onyx.mcp_server.api import mcp_server from onyx.mcp_server.utils import get_indexed_sources from onyx.mcp_server.utils import require_access_token from onyx.utils.logger import setup_logger logger = setup_logger() @mcp_server.resource( "resource://indexed_sources", name="indexed_sources", description=( "Enumerate the user's document sources that are currently indexed in Onyx. " "This can be used to discover filters for the `search_indexed_documents` tool."
), mime_type="application/json", ) async def indexed_sources_resource() -> dict[str, Any]: """Return the list of indexed source types for search filtering.""" access_token = require_access_token() sources = await get_indexed_sources(access_token) logger.info( "Onyx MCP Server: indexed_sources resource returning %s entries", len(sources), ) return { "indexed_sources": sorted(sources), } ================================================ FILE: backend/onyx/mcp_server/tools/__init__.py ================================================ """Tool registrations for the Onyx MCP server.""" # Import tool modules so decorators execute when the package is imported. from onyx.mcp_server.tools import search # noqa: F401 ================================================ FILE: backend/onyx/mcp_server/tools/search.py ================================================ """Search tools for MCP server - document and web search.""" from datetime import datetime from typing import Any from onyx.configs.constants import DocumentSource from onyx.mcp_server.api import mcp_server from onyx.mcp_server.utils import get_http_client from onyx.mcp_server.utils import get_indexed_sources from onyx.mcp_server.utils import require_access_token from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import build_api_server_url_for_http_requests from onyx.utils.variable_functionality import global_version logger = setup_logger() @mcp_server.tool() async def search_indexed_documents( query: str, source_types: list[str] | None = None, time_cutoff: str | None = None, limit: int = 10, ) -> dict[str, Any]: """ Search the user's knowledge base indexed in Onyx. Use this tool for information that is not public knowledge and specific to the user, their team, their work, or their organization/company. Note: In CE mode, this tool uses the chat endpoint internally which invokes an LLM on every call, consuming tokens and adding latency. 
Additionally, CE callers receive a truncated snippet (blurb) instead of a full document chunk, but this should still be sufficient for most use cases. CE mode should switch to a dedicated CE search endpoint once one is implemented. In EE mode, the dedicated search endpoint is used instead. To find a list of available sources, use the `indexed_sources` resource. Returns chunks of text as search results with snippets, scores, and metadata. Example usage: ``` { "query": "What is the latest status of PROJ-1234 and what is the next development item?", "source_types": ["jira", "google_drive", "github"], "time_cutoff": "2025-11-24T00:00:00Z", "limit": 10 } ``` """ logger.info( f"Onyx MCP Server: document search: query='{query}', sources={source_types}, limit={limit}" ) # Parse time_cutoff string to datetime if provided time_cutoff_dt: datetime | None = None if time_cutoff: try: time_cutoff_dt = datetime.fromisoformat(time_cutoff.replace("Z", "+00:00")) except ValueError as e: logger.warning( f"Onyx MCP Server: Invalid time_cutoff format '{time_cutoff}': {e}. Continuing without time filter." ) # Continue with no time_cutoff instead of returning an error time_cutoff_dt = None # Initialize source_type_enums early to avoid UnboundLocalError source_type_enums: list[DocumentSource] | None = None # Get authenticated user from FastMCP's access token access_token = require_access_token() try: sources = await get_indexed_sources(access_token) except Exception as e: # Error fetching sources (network error, API failure, etc.) logger.error( "Onyx MCP Server: Error checking indexed sources: %s", e, exc_info=True, ) return { "documents": [], "total_results": 0, "query": query, "error": f"Failed to check indexed sources: {e}", } if not sources: logger.info("Onyx MCP Server: No indexed sources available for tenant") return { "documents": [], "total_results": 0, "query": query, "message": ( "No document sources are indexed yet.
Add connectors or upload data " "through Onyx before calling search_indexed_documents." ), } # Convert source_types strings to DocumentSource enums if provided # Invalid values are skipped here with a warning and are never sent to the API server if source_types is not None: source_type_enums = [] for src in source_types: try: source_type_enums.append(DocumentSource(src.lower())) except ValueError: logger.warning( f"Onyx MCP Server: Invalid source type '{src}' - skipping" ) # Build filters dict only with non-None values filters: dict[str, Any] | None = None if source_type_enums or time_cutoff_dt: filters = {} if source_type_enums: filters["source_type"] = [src.value for src in source_type_enums] if time_cutoff_dt: filters["time_cutoff"] = time_cutoff_dt.isoformat() is_ee = global_version.is_ee_version() base_url = build_api_server_url_for_http_requests(respect_env_override_if_set=True) auth_headers = {"Authorization": f"Bearer {access_token.token}"} search_request: dict[str, Any] if is_ee: # EE: use the dedicated search endpoint (no LLM invocation) search_request = { "search_query": query, "filters": filters, "num_docs_fed_to_llm_selection": limit, "run_query_expansion": False, "include_content": True, "stream": False, } endpoint = f"{base_url}/search/send-search-message" error_key = "error" docs_key = "search_docs" content_field = "content" else: # CE: fall back to the chat endpoint (invokes LLM, consumes tokens) search_request = { "message": query, "stream": False, "chat_session_info": {}, } if filters: search_request["internal_search_filters"] = filters endpoint = f"{base_url}/chat/send-chat-message" error_key = "error_msg" docs_key = "top_documents" content_field = "blurb" try: response = await get_http_client().post( endpoint, json=search_request, headers=auth_headers, ) response.raise_for_status() result = response.json() # Check for error in response if result.get(error_key): return { "documents": [], "total_results": 0, "query": query, "error": result.get(error_key), } documents
= [ { "semantic_identifier": doc.get("semantic_identifier"), "content": doc.get(content_field), "source_type": doc.get("source_type"), "link": doc.get("link"), "score": doc.get("score"), } for doc in result.get(docs_key, []) ] # NOTE: search depth is controlled by the backend persona defaults, not `limit`. # `limit` only caps the returned list; fewer results may be returned if the # backend retrieves fewer documents than requested. documents = documents[:limit] logger.info( f"Onyx MCP Server: Internal search returned {len(documents)} results" ) return { "documents": documents, "total_results": len(documents), "query": query, } except Exception as e: logger.error(f"Onyx MCP Server: Document search error: {e}", exc_info=True) return { "error": f"Document search failed: {str(e)}", "documents": [], "query": query, } @mcp_server.tool() async def search_web( query: str, limit: int = 5, ) -> dict[str, Any]: """ Search the public internet for general knowledge, current events, and publicly available information. Use this tool for information that is publicly available on the web, such as news, documentation, general facts, or when the user's private knowledge base doesn't contain relevant information. Returns web search results with titles, URLs, and snippets (NOT full content). Use `open_urls` to fetch full page content. 
Example usage: ``` { "query": "React 19 migration guide to use react compiler", "limit": 5 } ``` """ logger.info(f"Onyx MCP Server: Web search: query='{query}', limit={limit}") access_token = require_access_token() try: request_payload = {"queries": [query], "max_results": limit} response = await get_http_client().post( f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/search-lite", json=request_payload, headers={"Authorization": f"Bearer {access_token.token}"}, ) response.raise_for_status() response_payload = response.json() results = response_payload.get("results", []) return { "results": results, "query": query, } except Exception as e: logger.error(f"Onyx MCP Server: Web search error: {e}", exc_info=True) return { "error": f"Web search failed: {str(e)}", "results": [], "query": query, } @mcp_server.tool() async def open_urls( urls: list[str], ) -> dict[str, Any]: """ Retrieve the complete text content from specific web URLs. Use this tool when you need to access full content from known URLs, such as documentation pages or articles returned by the `search_web` tool. Useful for following up on web search results when snippets do not provide enough information. Returns the full text content of each URL along with metadata like title and content type. 
Example usage: ``` { "urls": ["https://react.dev/versions", "https://react.dev/learn/react-compiler","https://react.dev/learn/react-compiler/introduction"] } ``` """ logger.info(f"Onyx MCP Server: Open URL: fetching {len(urls)} URLs") access_token = require_access_token() try: response = await get_http_client().post( f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/web-search/open-urls", json={"urls": urls}, headers={"Authorization": f"Bearer {access_token.token}"}, ) response.raise_for_status() response_payload = response.json() results = response_payload.get("results", []) return { "results": results, } except Exception as e: logger.error(f"Onyx MCP Server: URL fetch error: {e}", exc_info=True) return { "error": f"URL fetch failed: {str(e)}", "results": [], } ================================================ FILE: backend/onyx/mcp_server/utils.py ================================================ """Utility helpers for the Onyx MCP server.""" from __future__ import annotations import httpx from fastmcp.server.auth.auth import AccessToken from fastmcp.server.dependencies import get_access_token from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import build_api_server_url_for_http_requests logger = setup_logger() # Shared HTTP client reused across requests _http_client: httpx.AsyncClient | None = None def require_access_token() -> AccessToken: """ Get and validate the access token from the current request. Raises: ValueError: If no access token is present in the request. Returns: AccessToken: The validated access token. 
""" access_token = get_access_token() if not access_token: raise ValueError( "MCP Server requires an Onyx access token to authenticate your request" ) return access_token def get_http_client() -> httpx.AsyncClient: """Return a shared async HTTP client.""" global _http_client if _http_client is None: _http_client = httpx.AsyncClient(timeout=60.0) return _http_client async def shutdown_http_client() -> None: """Close the shared HTTP client when the server shuts down.""" global _http_client if _http_client is not None: await _http_client.aclose() _http_client = None async def get_indexed_sources( access_token: AccessToken, ) -> list[str]: """ Fetch indexed document sources for the current user/tenant. Returns: List of indexed source strings. Empty list if no sources are indexed. """ headers = {"Authorization": f"Bearer {access_token.token}"} try: response = await get_http_client().get( f"{build_api_server_url_for_http_requests(respect_env_override_if_set=True)}/manage/indexed-sources", headers=headers, ) response.raise_for_status() payload = response.json() sources = payload.get("sources", []) if not isinstance(sources, list): raise ValueError("Unexpected response shape for indexed sources") return [str(source) for source in sources] except (httpx.HTTPStatusError, httpx.RequestError, ValueError): # Re-raise known exception types (httpx errors and validation errors) logger.error( "Onyx MCP Server: Failed to fetch indexed sources", exc_info=True, ) raise except Exception as exc: # Wrap unexpected exceptions logger.error( "Onyx MCP Server: Unexpected error fetching indexed sources", exc_info=True, ) raise RuntimeError(f"Failed to fetch indexed sources: {exc}") from exc ================================================ FILE: backend/onyx/mcp_server_main.py ================================================ """Entry point for MCP server - HTTP POST transport with API key auth.""" import uvicorn from onyx.configs.app_configs import MCP_SERVER_ENABLED from 
onyx.configs.app_configs import MCP_SERVER_HOST from onyx.configs.app_configs import MCP_SERVER_PORT from onyx.utils.logger import setup_logger logger = setup_logger() def main() -> None: """Run the MCP server.""" if not MCP_SERVER_ENABLED: logger.info("MCP server is disabled (MCP_SERVER_ENABLED=false)") return logger.info(f"Starting MCP server on {MCP_SERVER_HOST}:{MCP_SERVER_PORT}") from onyx.mcp_server.api import mcp_app uvicorn.run( mcp_app, host=MCP_SERVER_HOST, port=MCP_SERVER_PORT, log_config=None, ) if __name__ == "__main__": main() ================================================ FILE: backend/onyx/natural_language_processing/__init__.py ================================================ ================================================ FILE: backend/onyx/natural_language_processing/constants.py ================================================ """ Constants for natural language processing, including embedding and reranking models. This file contains constants moved from model_server to support the gradual migration of API-based calls to bypass the model server. 
""" from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType # Default model names for different providers DEFAULT_OPENAI_MODEL = "text-embedding-3-small" DEFAULT_COHERE_MODEL = "embed-english-light-v3.0" DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct" DEFAULT_VERTEX_MODEL = "text-embedding-005" class EmbeddingModelTextType: """Mapping of Onyx text types to provider-specific text types.""" PROVIDER_TEXT_TYPE_MAP = { EmbeddingProvider.COHERE: { EmbedTextType.QUERY: "search_query", EmbedTextType.PASSAGE: "search_document", }, EmbeddingProvider.VOYAGE: { EmbedTextType.QUERY: "query", EmbedTextType.PASSAGE: "document", }, EmbeddingProvider.GOOGLE: { EmbedTextType.QUERY: "RETRIEVAL_QUERY", EmbedTextType.PASSAGE: "RETRIEVAL_DOCUMENT", }, } @staticmethod def get_type(provider: EmbeddingProvider, text_type: EmbedTextType) -> str: """Get provider-specific text type string.""" return EmbeddingModelTextType.PROVIDER_TEXT_TYPE_MAP[provider][text_type] ================================================ FILE: backend/onyx/natural_language_processing/english_stopwords.py ================================================ import re ENGLISH_STOPWORDS = [ "a", "about", "above", "after", "again", "against", "ain", "all", "am", "an", "and", "any", "are", "aren", "aren't", "as", "at", "be", "because", "been", "before", "being", "below", "between", "both", "but", "by", "can", "couldn", "couldn't", "d", "did", "didn", "didn't", "do", "does", "doesn", "doesn't", "doing", "don", "don't", "down", "during", "each", "few", "for", "from", "further", "had", "hadn", "hadn't", "has", "hasn", "hasn't", "have", "haven", "haven't", "having", "he", "he'd", "he'll", "he's", "her", "here", "hers", "herself", "him", "himself", "his", "how", "i", "i'd", "i'll", "i'm", "i've", "if", "in", "into", "is", "isn", "isn't", "it", "it'd", "it'll", "it's", "its", "itself", "just", "ll", "m", "ma", "me", "mightn", "mightn't", "more", "most", "mustn", "mustn't", "my", "myself", 
"needn", "needn't", "no", "nor", "not", "now", "o", "of", "off", "on", "once", "only", "or", "other", "our", "ours", "ourselves", "out", "over", "own", "re", "s", "same", "shan", "shan't", "she", "she'd", "she'll", "she's", "should", "should've", "shouldn", "shouldn't", "so", "some", "such", "t", "than", "that", "that'll", "the", "their", "theirs", "them", "themselves", "then", "there", "these", "they", "they'd", "they'll", "they're", "they've", "this", "those", "through", "to", "too", "under", "until", "up", "ve", "very", "was", "wasn", "wasn't", "we", "we'd", "we'll", "we're", "we've", "were", "weren", "weren't", "what", "when", "where", "which", "while", "who", "whom", "why", "will", "with", "won", "won't", "wouldn", "wouldn't", "y", "you", "you'd", "you'll", "you're", "you've", "your", "yours", "yourself", "yourselves", ] ENGLISH_STOPWORDS_SET = frozenset(ENGLISH_STOPWORDS) def strip_stopwords(text: str) -> list[str]: """Remove English stopwords from text. Matching is case-insensitive and ignores leading/trailing punctuation on each word. Internal punctuation (like apostrophes in contractions) is preserved for matching, so "you're" matches the stopword "you're" but "youre" would not. """ words = text.split() result = [] for word in words: # Strip leading/trailing punctuation to get the core word for comparison # This preserves internal punctuation like apostrophes core = re.sub(r"^[^\w']+|[^\w']+$", "", word) if core.lower() not in ENGLISH_STOPWORDS_SET: result.append(word) return result ================================================ FILE: backend/onyx/natural_language_processing/exceptions.py ================================================ class ModelServerRateLimitError(Exception): """ Exception raised for rate limiting errors from the model server. """ class CohereBillingLimitError(Exception): """ Raised when Cohere rejects requests because the billing cap is reached. 
""" ================================================ FILE: backend/onyx/natural_language_processing/search_nlp_models.py ================================================ import asyncio import json import os import threading import time from collections.abc import Callable from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor from functools import partial from functools import wraps from types import TracebackType from typing import Any from typing import cast import aioboto3 # type: ignore import httpx import requests import voyageai # type: ignore[import-untyped] from cohere import AsyncClient as CohereAsyncClient from cohere.core.api_error import ApiError from google.oauth2 import service_account from httpx import HTTPError from requests import JSONDecodeError from requests import RequestException from requests import Response from retry import retry from onyx.configs.app_configs import INDEXING_EMBEDDING_MODEL_NUM_THREADS from onyx.configs.app_configs import LARGE_CHUNK_RATIO from onyx.configs.model_configs import BATCH_SIZE_ENCODE_CHUNKS from onyx.configs.model_configs import ( BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES, ) from onyx.connectors.models import ConnectorStopSignal from onyx.db.models import SearchSettings from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface from onyx.natural_language_processing.constants import DEFAULT_COHERE_MODEL from onyx.natural_language_processing.constants import DEFAULT_OPENAI_MODEL from onyx.natural_language_processing.constants import DEFAULT_VERTEX_MODEL from onyx.natural_language_processing.constants import DEFAULT_VOYAGE_MODEL from onyx.natural_language_processing.constants import EmbeddingModelTextType from onyx.natural_language_processing.exceptions import CohereBillingLimitError from onyx.natural_language_processing.exceptions import ModelServerRateLimitError from onyx.natural_language_processing.utils import get_tokenizer from 
onyx.natural_language_processing.utils import tokenizer_trim_content from onyx.utils.logger import setup_logger from onyx.utils.search_nlp_models_utils import pass_aws_key from onyx.utils.text_processing import remove_invalid_unicode_chars from onyx.utils.timing import log_function_time from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE from shared_configs.configs import INDEXING_ONLY from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT from shared_configs.configs import SKIP_WARM_UP from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE from shared_configs.enums import EmbeddingProvider from shared_configs.enums import EmbedTextType from shared_configs.enums import RerankerProvider from shared_configs.model_server_models import Embedding from shared_configs.model_server_models import EmbedRequest from shared_configs.model_server_models import EmbedResponse from shared_configs.model_server_models import IntentRequest from shared_configs.model_server_models import IntentResponse from shared_configs.model_server_models import RerankRequest from shared_configs.model_server_models import RerankResponse from shared_configs.utils import batch_list logger = setup_logger() # Outside of indexing-only mode, don't retry for very long _RETRY_DELAY = 10 if INDEXING_ONLY else 0.1 _RETRY_TRIES = 10 if INDEXING_ONLY else 2 # OpenAI only allows 2048 embeddings to be computed at once _OPENAI_MAX_INPUT_LEN = 2048 # Cohere allows up to 96 texts in a single embedding call _COHERE_MAX_INPUT_LEN = 96 # Authentication error string constants _AUTH_ERROR_401 = "401" _AUTH_ERROR_UNAUTHORIZED = "unauthorized" _AUTH_ERROR_INVALID_API_KEY = "invalid api key" _AUTH_ERROR_PERMISSION = "permission" # Thread-local storage for event loops # This prevents creating thousands of event loops during
batch processing, # which was causing severe memory leaks with API-based embedding providers _thread_local = threading.local() def _get_or_create_event_loop() -> asyncio.AbstractEventLoop: """Get or create a thread-local event loop for API embedding calls. This prevents creating a new event loop for every batch during embedding, which was causing memory leaks. Instead, each thread reuses the same loop. Returns: asyncio.AbstractEventLoop: The thread-local event loop """ if ( not hasattr(_thread_local, "loop") or _thread_local.loop is None or _thread_local.loop.is_closed() ): _thread_local.loop = asyncio.new_event_loop() asyncio.set_event_loop(_thread_local.loop) return _thread_local.loop def cleanup_embedding_thread_locals() -> None: """Clean up thread-local event loops to prevent memory leaks. This should be called after each task completes to ensure that event loops and their associated resources are properly released. Thread-local storage persists across Celery tasks when using the thread pool, so explicit cleanup is necessary. NOTE: This must be called from the SAME thread that created the event loop. For ThreadPoolExecutor-based embedding, this cleanup happens automatically via the _cleanup_thread_local wrapper. 
""" if hasattr(_thread_local, "loop") and _thread_local.loop is not None: loop = _thread_local.loop if not loop.is_closed(): # Cancel all pending tasks in the event loop try: # Ensure loop is set as current event loop before accessing tasks asyncio.set_event_loop(loop) pending = asyncio.all_tasks(loop) if pending: logger.debug( f"Cleaning up event loop with {len(pending)} pending tasks in thread {threading.current_thread().name}" ) for task in pending: task.cancel() # Run the loop briefly to allow cancelled tasks to complete loop.run_until_complete( asyncio.gather(*pending, return_exceptions=True) ) except Exception as e: # If gathering tasks fails, just close the loop logger.debug(f"Error gathering tasks during cleanup: {e}") # Close the event loop loop.close() logger.debug( f"Closed event loop in thread {threading.current_thread().name}" ) # Clear the thread-local reference _thread_local.loop = None def _cleanup_thread_local(func: Callable) -> Callable: """Decorator to ensure thread-local cleanup after function execution. This wraps functions that run in ThreadPoolExecutor threads to ensure that thread-local event loops are cleaned up after each execution, preventing memory leaks from persistent thread-local storage. 
""" @wraps(func) def wrapper(*args: Any, **kwargs: Any) -> Any: try: return func(*args, **kwargs) finally: # Clean up thread-local event loop after this thread's work is done cleanup_embedding_thread_locals() return wrapper WARM_UP_STRINGS = [ "Onyx is amazing!", "Check out our easy deployment guide at", "https://docs.onyx.app/deployment/getting_started/quickstart", ] def clean_model_name(model_str: str) -> str: return model_str.replace("/", "_").replace("-", "_").replace(".", "_") def build_model_server_url( model_server_host: str, model_server_port: int, ) -> str: model_server_url = f"{model_server_host}:{model_server_port}" # use protocol if provided if "http" in model_server_url: return model_server_url # otherwise default to http return f"http://{model_server_url}" def is_authentication_error(error: Exception) -> bool: """Check if an exception is related to authentication issues. Args: error: The exception to check Returns: bool: True if the error appears to be authentication-related """ error_str = str(error).lower() return ( _AUTH_ERROR_401 in error_str or _AUTH_ERROR_UNAUTHORIZED in error_str or _AUTH_ERROR_INVALID_API_KEY in error_str or _AUTH_ERROR_PERMISSION in error_str ) def format_embedding_error( error: Exception, service_name: str, model: str | None, provider: EmbeddingProvider, sanitized_api_key: str | None = None, status_code: int | None = None, ) -> str: """ Format a standardized error string for embedding errors. 
""" detail = f"Status {status_code}" if status_code else f"{type(error)}" return ( f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: " f"Model: {model} " f"Provider: {provider} " f"API Key: {sanitized_api_key} " f"Exception: {error}" ) # Custom exception for authentication errors class AuthenticationError(Exception): """Raised when authentication fails with a provider.""" def __init__(self, provider: str, message: str = "API key is invalid or expired"): self.provider = provider self.message = message super().__init__(f"{provider} authentication failed: {message}") class CloudEmbedding: def __init__( self, api_key: str, provider: EmbeddingProvider, api_url: str | None = None, api_version: str | None = None, timeout: int = API_BASED_EMBEDDING_TIMEOUT, ) -> None: self.provider = provider self.api_key = api_key self.api_url = api_url self.api_version = api_version self.timeout = timeout self.http_client = httpx.AsyncClient(timeout=timeout) self._closed = False self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:] async def _embed_openai( self, texts: list[str], model: str | None, reduced_dimension: int | None ) -> list[Embedding]: if not model: model = DEFAULT_OPENAI_MODEL import openai # Use the OpenAI specific timeout for this one client = openai.AsyncOpenAI( api_key=self.api_key, timeout=OPENAI_EMBEDDING_TIMEOUT ) final_embeddings: list[Embedding] = [] for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN): response = await client.embeddings.create( input=text_batch, model=model, dimensions=reduced_dimension or openai.omit, ) final_embeddings.extend( [embedding.embedding for embedding in response.data] ) return final_embeddings async def _embed_cohere( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_COHERE_MODEL client = CohereAsyncClient(api_key=self.api_key) final_embeddings: list[Embedding] = [] for text_batch in batch_list(texts, 
_COHERE_MAX_INPUT_LEN): # Does not use the same tokenizer as the Onyx API server but it's approximately the same # empirically it's only off by a very few tokens so it's not a big deal response = await client.embed( texts=text_batch, model=model, input_type=embedding_type, truncate="END", ) final_embeddings.extend(cast(list[Embedding], response.embeddings)) return final_embeddings async def _embed_voyage( self, texts: list[str], model: str | None, embedding_type: str ) -> list[Embedding]: if not model: model = DEFAULT_VOYAGE_MODEL client = voyageai.AsyncClient( api_key=self.api_key, timeout=API_BASED_EMBEDDING_TIMEOUT ) response = await client.embed( texts=texts, model=model, input_type=embedding_type, truncation=True, ) return response.embeddings async def _embed_azure( self, texts: list[str], model: str | None ) -> list[Embedding]: from litellm import aembedding response = await aembedding( model=model, input=texts, timeout=API_BASED_EMBEDDING_TIMEOUT, api_key=self.api_key, api_base=self.api_url, api_version=self.api_version, ) embeddings = [embedding["embedding"] for embedding in response.data] return embeddings async def _embed_vertex( self, texts: list[str], model: str | None, embedding_type: str, reduced_dimension: int | None, ) -> list[Embedding]: from google import genai from google.genai import types as genai_types if not model: model = DEFAULT_VERTEX_MODEL service_account_info = json.loads(self.api_key) credentials = service_account.Credentials.from_service_account_info( service_account_info, scopes=["https://www.googleapis.com/auth/cloud-platform"], ) project_id = service_account_info["project_id"] location = ( service_account_info.get("location") or os.environ.get("GOOGLE_CLOUD_LOCATION") or "us-central1" ) client = genai.Client( vertexai=True, project=project_id, location=location, credentials=credentials, ) embed_config = genai_types.EmbedContentConfig( task_type=embedding_type, output_dimensionality=reduced_dimension, auto_truncate=True, ) async def 
_embed_batch(batch_texts: list[str]) -> list[Embedding]: content_requests: list[Any] = [ genai_types.Content(parts=[genai_types.Part(text=text)]) for text in batch_texts ] response = await client.aio.models.embed_content( model=model, contents=content_requests, config=embed_config, ) if not response.embeddings: raise RuntimeError("Received empty embeddings from Google GenAI.") embeddings: list[Embedding] = [] for idx, embedding in enumerate(response.embeddings): if embedding.values is None: raise RuntimeError( f"Missing embedding values for input at index {idx}." ) embeddings.append(embedding.values) return embeddings # Process VertexAI batches sequentially to avoid additional intra-task fanout. # The higher-level thread pool already provides concurrency; running these # requests in parallel here was causing excessive memory usage. batches = [ texts[i : i + VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE] for i in range(0, len(texts), VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE) ] all_embeddings: list[Embedding] = [] logger.debug( f"VertexAI embedding: processing {len(texts)} texts in {len(batches)} batches " f"(batch_size={VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE})" ) try: for batch_idx, batch in enumerate(batches): batch_embeddings = await _embed_batch(batch) all_embeddings.extend(batch_embeddings) # Log progress for large batches to track memory usage patterns if batch_idx % 10 == 0 and batch_idx > 0: logger.debug( f"VertexAI embedding progress: batch {batch_idx}/{len(batches)}, total_embeddings={len(all_embeddings)}" ) logger.debug( f"VertexAI embedding completed: {len(all_embeddings)} embeddings generated" ) return all_embeddings finally: # Ensure client is closed with a timeout to prevent hanging on stuck sessions try: await asyncio.wait_for(client.aio.aclose(), timeout=5.0) except asyncio.TimeoutError: logger.warning("Google GenAI client aclose() timed out after 5s") except Exception as e: logger.warning(f"Error closing Google GenAI client: {e}") async def _embed_litellm_proxy( 
self, texts: list[str], model_name: str | None ) -> list[Embedding]: if not model_name: raise ValueError("Model name is required for LiteLLM proxy embedding.") if not self.api_url: raise ValueError("API URL is required for LiteLLM proxy embedding.") headers = ( {} if not self.api_key else {"Authorization": f"Bearer {self.api_key}"} ) response = await self.http_client.post( self.api_url, json={ "model": model_name, "input": texts, }, headers=headers, ) response.raise_for_status() result = response.json() return [embedding["embedding"] for embedding in result["data"]] @retry(tries=_RETRY_TRIES, delay=_RETRY_DELAY) async def embed( self, *, texts: list[str], text_type: EmbedTextType, model_name: str | None = None, deployment_name: str | None = None, reduced_dimension: int | None = None, ) -> list[Embedding]: import openai try: if self.provider == EmbeddingProvider.OPENAI: return await self._embed_openai(texts, model_name, reduced_dimension) elif self.provider == EmbeddingProvider.AZURE: return await self._embed_azure(texts, f"azure/{deployment_name}") elif self.provider == EmbeddingProvider.LITELLM: return await self._embed_litellm_proxy(texts, model_name) embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type) if self.provider == EmbeddingProvider.COHERE: return await self._embed_cohere(texts, model_name, embedding_type) elif self.provider == EmbeddingProvider.VOYAGE: return await self._embed_voyage(texts, model_name, embedding_type) elif self.provider == EmbeddingProvider.GOOGLE: return await self._embed_vertex( texts, model_name, embedding_type, reduced_dimension ) else: raise ValueError(f"Unsupported provider: {self.provider}") except openai.AuthenticationError: raise AuthenticationError(provider="OpenAI") except httpx.HTTPStatusError as e: if e.response.status_code == 401: raise AuthenticationError(provider=str(self.provider)) error_string = format_embedding_error( e, str(self.provider), model_name or deployment_name, self.provider, 
sanitized_api_key=self.sanitized_api_key, status_code=e.response.status_code, ) logger.error(error_string) logger.debug(f"Exception texts: {texts}") raise RuntimeError(error_string) except Exception as e: if is_authentication_error(e): raise AuthenticationError(provider=str(self.provider)) error_string = format_embedding_error( e, str(self.provider), model_name or deployment_name, self.provider, sanitized_api_key=self.sanitized_api_key, ) logger.error(error_string) logger.debug(f"Exception texts: {texts}") raise RuntimeError(error_string) @staticmethod def create( api_key: str, provider: EmbeddingProvider, api_url: str | None = None, api_version: str | None = None, ) -> "CloudEmbedding": logger.debug(f"Creating Embedding instance for provider: {provider}") return CloudEmbedding(api_key, provider, api_url, api_version) async def aclose(self) -> None: """Explicitly close the client.""" if not self._closed: await self.http_client.aclose() self._closed = True async def __aenter__(self) -> "CloudEmbedding": return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: await self.aclose() def __del__(self) -> None: """Finalizer to warn about unclosed clients.""" if not self._closed: logger.warning( "CloudEmbedding was not properly closed. Use 'async with' or call aclose()" ) # API-based reranking functions (moved from model server) async def cohere_rerank_api( query: str, docs: list[str], model_name: str, api_key: str ) -> list[float]: cohere_client = CohereAsyncClient(api_key=api_key) try: response = await cohere_client.rerank( query=query, documents=docs, model=model_name ) except ApiError as err: if err.status_code == 402: logger.warning( "Cohere rerank request rejected due to billing cap. Falling back to retrieval ordering until billing resets." 
            )
            raise CohereBillingLimitError(
                "Cohere billing limit reached for reranking"
            ) from err
        raise

    results = response.results
    sorted_results = sorted(results, key=lambda item: item.index)
    return [result.relevance_score for result in sorted_results]


async def cohere_rerank_aws(
    query: str,
    docs: list[str],
    model_name: str,
    region_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
) -> list[float]:
    session = aioboto3.Session(
        aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
    )
    async with session.client(
        "bedrock-runtime", region_name=region_name
    ) as bedrock_client:
        body = json.dumps(
            {
                "query": query,
                "documents": docs,
                "api_version": 2,
            }
        )
        # Invoke the Bedrock model asynchronously
        response = await bedrock_client.invoke_model(
            modelId=model_name,
            accept="application/json",
            contentType="application/json",
            body=body,
        )

        # Read the response asynchronously
        response_body = json.loads(await response["body"].read())

        # Extract and sort the results
        results = response_body.get("results", [])
        sorted_results = sorted(results, key=lambda item: item["index"])

        return [result["relevance_score"] for result in sorted_results]


async def litellm_rerank(
    query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
    headers = {} if not api_key else {"Authorization": f"Bearer {api_key}"}
    async with httpx.AsyncClient() as client:
        response = await client.post(
            api_url,
            json={
                "model": model_name,
                "query": query,
                "documents": docs,
            },
            headers=headers,
        )
        response.raise_for_status()
        result = response.json()
        return [
            item["relevance_score"]
            for item in sorted(result["results"], key=lambda x: x["index"])
        ]


class EmbeddingModel:
    def __init__(
        self,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        model_name: str | None,
        normalize: bool,
        query_prefix: str | None,
        passage_prefix: str | None,
        api_key: str | None,
        api_url: str | None,
        provider_type: EmbeddingProvider | None,
        retrim_content: bool = False,
        callback: IndexingHeartbeatInterface | None = None,
        api_version: str | None = None,
        deployment_name: str | None = None,
        reduced_dimension: int | None = None,
    ) -> None:
        self.api_key = api_key
        self.provider_type = provider_type
        self.query_prefix = query_prefix
        self.passage_prefix = passage_prefix
        self.normalize = normalize
        self.model_name = model_name
        self.retrim_content = retrim_content
        self.api_url = api_url
        self.api_version = api_version
        self.deployment_name = deployment_name
        self.reduced_dimension = reduced_dimension
        self.tokenizer = get_tokenizer(
            model_name=model_name, provider_type=provider_type
        )
        self.callback = callback

        # Only build model server endpoint for local models
        if self.provider_type is None:
            model_server_url = build_model_server_url(server_host, server_port)
            self.embed_server_endpoint: str | None = (
                f"{model_server_url}/encoder/bi-encoder-embed"
            )
        else:
            # API providers don't need model server endpoint
            self.embed_server_endpoint = None

    async def _make_direct_api_call(
        self,
        embed_request: EmbedRequest,
        tenant_id: str | None = None,  # noqa: ARG002
        request_id: str | None = None,  # noqa: ARG002
    ) -> EmbedResponse:
        """Make direct API call to cloud provider, bypassing model server."""
        if self.provider_type is None:
            raise ValueError("Provider type is required for direct API calls")

        if self.api_key is None:
            logger.error("API key not provided for cloud model")
            raise RuntimeError("API key not provided for cloud model")

        # Check for prefix usage with cloud models
        if embed_request.manual_query_prefix or embed_request.manual_passage_prefix:
            logger.warning("Prefix provided for cloud model, which is not supported")
            raise ValueError(
                "Prefix string is not valid for cloud models. Cloud models take an explicit text type instead."
            )

        if not all(embed_request.texts):
            logger.error("Empty strings provided for embedding")
            raise ValueError("Empty strings are not allowed for embedding.")

        if not embed_request.texts:
            logger.error("No texts provided for embedding")
            raise ValueError("No texts provided for embedding.")

        start_time = time.monotonic()
        total_chars = sum(len(text) for text in embed_request.texts)

        logger.info(
            f"Embedding {len(embed_request.texts)} texts with {total_chars} total characters with provider: {self.provider_type}"
        )

        async with CloudEmbedding(
            api_key=self.api_key,
            provider=self.provider_type,
            api_url=self.api_url,
            api_version=self.api_version,
        ) as cloud_model:
            embeddings = await cloud_model.embed(
                texts=embed_request.texts,
                model_name=embed_request.model_name,
                deployment_name=embed_request.deployment_name,
                text_type=embed_request.text_type,
                reduced_dimension=embed_request.reduced_dimension,
            )

        if any(embedding is None for embedding in embeddings):
            error_message = "Embeddings contain None values\n"
            error_message += "Corresponding texts:\n"
            error_message += "\n".join(embed_request.texts)
            logger.error(error_message)
            raise ValueError(error_message)

        elapsed = time.monotonic() - start_time
        logger.info(
            f"event=embedding_provider "
            f"texts={len(embed_request.texts)} "
            f"chars={total_chars} "
            f"provider={self.provider_type} "
            f"elapsed={elapsed:.2f}"
        )

        return EmbedResponse(embeddings=embeddings)

    def _make_model_server_request(
        self,
        embed_request: EmbedRequest,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> EmbedResponse:
        if self.embed_server_endpoint is None:
            raise ValueError("Model server endpoint is not configured for local models")

        # Store the endpoint in a local variable to help mypy understand it's not None
        endpoint = self.embed_server_endpoint

        def _make_request() -> Response:
            headers = {}
            if tenant_id:
                headers["X-Onyx-Tenant-ID"] = tenant_id

            if request_id:
                headers["X-Onyx-Request-ID"] = request_id

            response = requests.post(
                endpoint,
                headers=headers,
                json=embed_request.model_dump(),
            )

            # signify that this is a rate limit error
            if response.status_code == 429:
                raise ModelServerRateLimitError(response.text)

            response.raise_for_status()
            return response

        final_make_request_func = _make_request

        # if the text type is a passage, add some default
        # retries + handling for rate limiting
        if embed_request.text_type == EmbedTextType.PASSAGE:
            final_make_request_func = retry(
                tries=3,
                delay=5,
                exceptions=(RequestException, ValueError, JSONDecodeError),
            )(final_make_request_func)
            # use 10 second delay as per Azure suggestion
            final_make_request_func = retry(
                tries=10, delay=10, exceptions=ModelServerRateLimitError
            )(final_make_request_func)

        response: Response | None = None
        try:
            response = final_make_request_func()
            return EmbedResponse(**response.json())
        except requests.HTTPError as e:
            # raise_for_status() fires inside _make_request, so `response` above is
            # never assigned on this path; read the failed response off the exception.
            error_response = e.response
            if error_response is None:
                raise HTTPError("HTTP error occurred - response is None.") from e

            try:
                error_detail = error_response.json().get("detail", str(e))
            except Exception:
                error_detail = error_response.text
            raise HTTPError(f"HTTP error occurred: {error_detail}") from e
        except requests.RequestException as e:
            raise HTTPError(f"Request failed: {str(e)}") from e

    def _batch_encode_texts(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        batch_size: int,
        max_seq_length: int,
        num_threads: int = INDEXING_EMBEDDING_MODEL_NUM_THREADS,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> list[Embedding]:
        text_batches = batch_list(texts, batch_size)

        logger.debug(f"Encoding {len(texts)} texts in {len(text_batches)} batches")

        embeddings: list[Embedding] = []

        @_cleanup_thread_local
        def process_batch(
            batch_idx: int,
            batch_len: int,
            text_batch: list[str],
            tenant_id: str | None = None,
            request_id: str | None = None,
        ) -> tuple[int, list[Embedding]]:
            if self.callback:
                if self.callback.should_stop():
                    raise ConnectorStopSignal(
                        "_batch_encode_texts detected stop signal"
                    )

            embed_request = EmbedRequest(
                model_name=self.model_name,
                texts=text_batch,
                api_version=self.api_version,
                deployment_name=self.deployment_name,
                max_context_length=max_seq_length,
                normalize_embeddings=self.normalize,
                api_key=self.api_key,
                provider_type=self.provider_type,
                text_type=text_type,
                manual_query_prefix=self.query_prefix,
                manual_passage_prefix=self.passage_prefix,
                api_url=self.api_url,
                reduced_dimension=self.reduced_dimension,
            )

            start_time = time.monotonic()

            # Route between direct API calls and model server calls
            if self.provider_type is not None:
                # For API providers, make direct API call
                # Use thread-local event loop to prevent memory leaks from creating
                # thousands of event loops during batch processing
                loop = _get_or_create_event_loop()
                response = loop.run_until_complete(
                    self._make_direct_api_call(
                        embed_request, tenant_id=tenant_id, request_id=request_id
                    )
                )
            else:
                # For local models, use model server
                response = self._make_model_server_request(
                    embed_request, tenant_id=tenant_id, request_id=request_id
                )

            end_time = time.monotonic()

            processing_time = end_time - start_time
            logger.debug(
                f"EmbeddingModel.process_batch: Batch {batch_idx}/{batch_len} processing time: {processing_time:.2f} seconds"
            )

            return batch_idx, response.embeddings

        # Only multithread if:
        # 1. num_threads is at least 1
        # 2. we are using an API-based embedding model (provider_type is not None)
        # 3. there is more than one batch (no point in threading a single batch)
        if num_threads >= 1 and self.provider_type and len(text_batches) > 1:
            with ThreadPoolExecutor(max_workers=num_threads) as executor:
                future_to_batch = {
                    executor.submit(
                        partial(
                            process_batch,
                            idx,
                            len(text_batches),
                            batch,
                            tenant_id=tenant_id,
                            request_id=request_id,
                        )
                    ): idx
                    for idx, batch in enumerate(text_batches, start=1)
                }

                # Collect results as they complete (they are sorted by batch index below)
                batch_results: list[tuple[int, list[Embedding]]] = []
                for future in as_completed(future_to_batch):
                    try:
                        result = future.result()
                        batch_results.append(result)
                    except Exception as e:
                        logger.exception("Embedding model failed to process batch")
                        raise e

                # Sort by batch index and extend embeddings
                batch_results.sort(key=lambda x: x[0])
                for _, batch_embeddings in batch_results:
                    embeddings.extend(batch_embeddings)
        else:
            # Original sequential processing
            for idx, text_batch in enumerate(text_batches, start=1):
                _, batch_embeddings = process_batch(
                    idx,
                    len(text_batches),
                    text_batch,
                    tenant_id=tenant_id,
                    request_id=request_id,
                )
                embeddings.extend(batch_embeddings)

        return embeddings

    @log_function_time(print_only=True, debug_only=True)
    def encode(
        self,
        texts: list[str],
        text_type: EmbedTextType,
        large_chunks_present: bool = False,
        local_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS,
        api_embedding_batch_size: int = BATCH_SIZE_ENCODE_CHUNKS_FOR_API_EMBEDDING_SERVICES,
        max_seq_length: int = DOC_EMBEDDING_CONTEXT_SIZE,
        tenant_id: str | None = None,
        request_id: str | None = None,
    ) -> list[Embedding]:
        if not texts or not all(texts):
            raise ValueError(f"Empty or missing text for embedding: {texts}")

        if large_chunks_present:
            max_seq_length *= LARGE_CHUNK_RATIO

        if self.retrim_content:
            # This is applied during indexing as a catch-all for overly long titles (or other uncapped fields)
            # Note that this uses just the default tokenizer which may also lead to very minor miscounting
            # However this slight miscounting is very unlikely to have any material impact.
            texts = [
                tokenizer_trim_content(
                    content=text,
                    desired_length=max_seq_length,
                    tokenizer=self.tokenizer,
                )
                for text in texts
            ]

        # Remove invalid Unicode characters (e.g., unpaired surrogates from malformed documents)
        # that would cause UTF-8 encoding errors when sent to embedding providers
        texts = [remove_invalid_unicode_chars(text) or "<>" for text in texts]

        batch_size = (
            api_embedding_batch_size
            if self.provider_type
            else local_embedding_batch_size
        )

        return self._batch_encode_texts(
            texts=texts,
            text_type=text_type,
            batch_size=batch_size,
            max_seq_length=max_seq_length,
            tenant_id=tenant_id,
            request_id=request_id,
        )

    @classmethod
    def from_db_model(
        cls,
        search_settings: SearchSettings,
        server_host: str,  # Changes depending on indexing or inference
        server_port: int,
        retrim_content: bool = False,
    ) -> "EmbeddingModel":
        return cls(
            server_host=server_host,
            server_port=server_port,
            model_name=search_settings.model_name,
            normalize=search_settings.normalize,
            query_prefix=search_settings.query_prefix,
            passage_prefix=search_settings.passage_prefix,
            api_key=search_settings.api_key,
            provider_type=search_settings.provider_type,
            api_url=search_settings.api_url,
            retrim_content=retrim_content,
            api_version=search_settings.api_version,
            deployment_name=search_settings.deployment_name,
            reduced_dimension=search_settings.reduced_dimension,
        )


class RerankingModel:
    def __init__(
        self,
        model_name: str,
        provider_type: RerankerProvider | None,
        api_key: str | None,
        api_url: str | None,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
    ) -> None:
        self.model_name = model_name
        self.provider_type = provider_type
        self.api_key = api_key
        self.api_url = api_url

        # Only build model server endpoint for local models
        if self.provider_type is None:
            model_server_url = build_model_server_url(
                model_server_host, model_server_port
            )
            self.rerank_server_endpoint: str | None = (
                model_server_url + "/encoder/cross-encoder-scores"
            )
        else:
            # API providers don't need model server endpoint
            self.rerank_server_endpoint = None

    async def _make_direct_rerank_call(
        self, query: str, passages: list[str]
    ) -> list[float]:
        """Make direct API call to cloud provider, bypassing model server."""
        if self.provider_type is None:
            raise ValueError("Provider type is required for direct API calls")

        if self.api_key is None:
            raise ValueError("API key is required for cloud provider")

        if self.provider_type == RerankerProvider.COHERE:
            return await cohere_rerank_api(
                query, passages, self.model_name, self.api_key
            )
        elif self.provider_type == RerankerProvider.BEDROCK:
            aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
                self.api_key
            )
            return await cohere_rerank_aws(
                query,
                passages,
                self.model_name,
                aws_region,
                aws_access_key_id,
                aws_secret_access_key,
            )
        elif self.provider_type == RerankerProvider.LITELLM:
            if self.api_url is None:
                raise ValueError("API URL is required for LiteLLM reranking.")

            return await litellm_rerank(
                query, passages, self.api_url, self.model_name, self.api_key
            )
        else:
            raise ValueError(f"Unsupported reranking provider: {self.provider_type}")

    def predict(self, query: str, passages: list[str]) -> list[float]:
        # Route between direct API calls and model server calls
        if self.provider_type is not None:
            # For API providers, make direct API call
            loop = asyncio.new_event_loop()
            try:
                asyncio.set_event_loop(loop)
                return loop.run_until_complete(
                    self._make_direct_rerank_call(query, passages)
                )
            finally:
                loop.close()
        else:
            # For local models, use model server
            if self.rerank_server_endpoint is None:
                raise ValueError(
                    "Rerank server endpoint is not configured for local models"
                )

            rerank_request = RerankRequest(
                query=query,
                documents=passages,
                model_name=self.model_name,
                provider_type=self.provider_type,
                api_key=self.api_key,
                api_url=self.api_url,
            )

            response = requests.post(
                self.rerank_server_endpoint, json=rerank_request.model_dump()
            )
            response.raise_for_status()

            return RerankResponse(**response.json()).scores


class QueryAnalysisModel:
    def __init__(
        self,
        model_server_host: str = MODEL_SERVER_HOST,
        model_server_port: int = MODEL_SERVER_PORT,
        # Lean heavily towards not throwing out keywords
        keyword_percent_threshold: float = 0.1,
        # Lean towards semantic which is the default
        semantic_percent_threshold: float = 0.4,
    ) -> None:
        model_server_url = build_model_server_url(model_server_host, model_server_port)
        self.intent_server_endpoint = model_server_url + "/custom/query-analysis"
        self.keyword_percent_threshold = keyword_percent_threshold
        self.semantic_percent_threshold = semantic_percent_threshold

    def predict(
        self,
        query: str,
    ) -> tuple[bool, list[str]]:
        intent_request = IntentRequest(
            query=query,
            keyword_percent_threshold=self.keyword_percent_threshold,
            semantic_percent_threshold=self.semantic_percent_threshold,
        )

        response = requests.post(
            self.intent_server_endpoint, json=intent_request.model_dump()
        )
        response.raise_for_status()

        response_model = IntentResponse(**response.json())

        return response_model.is_keyword, response_model.keywords


def warm_up_retry(
    func: Callable[..., Any],
    tries: int = 20,
    delay: int = 5,
    *args: Any,  # noqa: ARG001
    **kwargs: Any,  # noqa: ARG001
) -> Callable[..., Any]:
    @wraps(func)
    def wrapper(*args: Any, **kwargs: Any) -> Any:
        exceptions = []
        for attempt in range(tries):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                exceptions.append(e)
                logger.info(
                    f"Attempt {attempt + 1}/{tries} failed; retrying in {delay} seconds..."
                )
                time.sleep(delay)
        raise Exception(f"All retries failed: {exceptions}")

    return wrapper


def warm_up_bi_encoder(
    embedding_model: EmbeddingModel,
    non_blocking: bool = False,
) -> None:
    if SKIP_WARM_UP:
        return

    warm_up_str = " ".join(WARM_UP_STRINGS)

    logger.debug(f"Warming up encoder model: {embedding_model.model_name}")
    get_tokenizer(
        model_name=embedding_model.model_name,
        provider_type=embedding_model.provider_type,
    ).encode(warm_up_str)

    def _warm_up() -> None:
        try:
            embedding_model.encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)
            logger.debug(
                f"Warm-up complete for encoder model: {embedding_model.model_name}"
            )
        except Exception as e:
            logger.warning(
                f"Warm-up request failed for encoder model {embedding_model.model_name}: {e}"
            )

    if non_blocking:
        threading.Thread(target=_warm_up, daemon=True).start()
        logger.debug(
            f"Started non-blocking warm-up for encoder model: {embedding_model.model_name}"
        )
    else:
        retry_encode = warm_up_retry(embedding_model.encode)
        retry_encode(texts=[warm_up_str], text_type=EmbedTextType.QUERY)


# No longer used
def warm_up_cross_encoder(
    rerank_model_name: str,
    non_blocking: bool = False,
) -> None:
    if SKIP_WARM_UP:
        return

    logger.debug(f"Warming up reranking model: {rerank_model_name}")

    reranking_model = RerankingModel(
        model_name=rerank_model_name,
        provider_type=None,
        api_url=None,
        api_key=None,
    )

    def _warm_up() -> None:
        try:
            reranking_model.predict(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])
            logger.debug(f"Warm-up complete for reranking model: {rerank_model_name}")
        except Exception as e:
            logger.warning(
                f"Warm-up request failed for reranking model {rerank_model_name}: {e}"
            )

    if non_blocking:
        threading.Thread(target=_warm_up, daemon=True).start()
        logger.debug(
            f"Started non-blocking warm-up for reranking model: {rerank_model_name}"
        )
    else:
        retry_rerank = warm_up_retry(reranking_model.predict)
        retry_rerank(WARM_UP_STRINGS[0], WARM_UP_STRINGS[1:])

================================================
FILE: backend/onyx/natural_language_processing/utils.py
================================================
import os
from abc import ABC
from abc import abstractmethod
from copy import copy

from tokenizers import Encoding  # type: ignore[import-untyped]
from tokenizers import Tokenizer

from onyx.configs.model_configs import DOCUMENT_ENCODER_MODEL
from onyx.context.search.models import InferenceChunk
from onyx.utils.logger import setup_logger
from shared_configs.configs import DOC_EMBEDDING_CONTEXT_SIZE
from shared_configs.enums import EmbeddingProvider


TRIM_SEP_PAT = "\n... {n} tokens removed...\n"

logger = setup_logger()
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["HF_HUB_DISABLE_TELEMETRY"] = "1"


class BaseTokenizer(ABC):
    @abstractmethod
    def encode(self, string: str) -> list[int]:
        pass

    @abstractmethod
    def tokenize(self, string: str) -> list[str]:
        pass

    @abstractmethod
    def decode(self, tokens: list[int]) -> str:
        pass


class TiktokenTokenizer(BaseTokenizer):
    _instances: dict[str, "TiktokenTokenizer"] = {}

    def __new__(cls, model_name: str) -> "TiktokenTokenizer":
        if model_name not in cls._instances:
            cls._instances[model_name] = super(TiktokenTokenizer, cls).__new__(cls)
        return cls._instances[model_name]

    def __init__(self, model_name: str):
        if not hasattr(self, "encoder"):
            import tiktoken

            self.encoder = tiktoken.encoding_for_model(model_name)

    def encode(self, string: str) -> list[int]:
        # this ignores special tokens that the model is trained on, see encode_ordinary for details
        return self.encoder.encode_ordinary(string)

    def tokenize(self, string: str) -> list[str]:
        encoded = self.encode(string)
        decoded = [self.encoder.decode([token]) for token in encoded]

        if len(decoded) != len(encoded):
            logger.warning(
                f"OpenAI tokenized length {len(decoded)} does not match encoded length {len(encoded)} for string: {string}"
            )

        return decoded

    def decode(self, tokens: list[int]) -> str:
        return self.encoder.decode(tokens)


class HuggingFaceTokenizer(BaseTokenizer):
    def __init__(self, model_name: str):
        self.encoder: Tokenizer = Tokenizer.from_pretrained(model_name)

    def _safer_encode(self, string: str) -> Encoding:
        """
        Encode a string using the HuggingFaceTokenizer, but if it fails,
        encode the string as ASCII and decode it back to a string. This helps
        in cases where the string has weird characters like \udeb4.
        """
        try:
            return self.encoder.encode(string, add_special_tokens=False)
        except Exception:
            return self.encoder.encode(
                string.encode("ascii", "ignore").decode(), add_special_tokens=False
            )

    def encode(self, string: str) -> list[int]:
        # this returns no special tokens
        return self._safer_encode(string).ids

    def tokenize(self, string: str) -> list[str]:
        return self._safer_encode(string).tokens

    def decode(self, tokens: list[int]) -> str:
        return self.encoder.decode(tokens)


_TOKENIZER_CACHE: dict[tuple[EmbeddingProvider | None, str | None], BaseTokenizer] = {}


def _check_tokenizer_cache(
    model_provider: EmbeddingProvider | None, model_name: str | None
) -> BaseTokenizer:
    global _TOKENIZER_CACHE
    id_tuple = (model_provider, model_name)

    if id_tuple not in _TOKENIZER_CACHE:
        tokenizer = None

        if model_name:
            tokenizer = _try_initialize_tokenizer(model_name, model_provider)

        if not tokenizer:
            logger.info(
                f"Falling back to default embedding model tokenizer: {DOCUMENT_ENCODER_MODEL}"
            )
            tokenizer = _get_default_tokenizer()

        _TOKENIZER_CACHE[id_tuple] = tokenizer

    return _TOKENIZER_CACHE[id_tuple]


def _try_initialize_tokenizer(
    model_name: str, model_provider: EmbeddingProvider | None
) -> BaseTokenizer | None:
    tokenizer: BaseTokenizer | None = None
    if model_provider is not None:
        # If a provider is set, try TiktokenTokenizer
        try:
            tokenizer = TiktokenTokenizer(model_name)
            logger.info(f"Initialized TiktokenTokenizer for: {model_name}")
            return tokenizer
        except Exception as tiktoken_error:
            logger.debug(
                f"TiktokenTokenizer not available for model {model_name}: {tiktoken_error}"
            )
    else:
        # If no provider specified, try HuggingFaceTokenizer
        try:
            tokenizer = HuggingFaceTokenizer(model_name)
            logger.info(f"Initialized HuggingFaceTokenizer for: {model_name}")
            return tokenizer
        except Exception as hf_error:
            logger.warning(
                f"Failed to initialize HuggingFaceTokenizer for {model_name}: {hf_error}"
            )

    # If the initialization attempt fails, return None (the caller falls back to the default tokenizer)
    return None


_DEFAULT_TOKENIZER: BaseTokenizer | None = None


def _get_default_tokenizer() -> BaseTokenizer:
    """Lazy-load the default tokenizer to avoid loading it at module import time."""
    global _DEFAULT_TOKENIZER
    if _DEFAULT_TOKENIZER is None:
        _DEFAULT_TOKENIZER = HuggingFaceTokenizer(DOCUMENT_ENCODER_MODEL)
    return _DEFAULT_TOKENIZER


def get_tokenizer(
    model_name: str | None, provider_type: EmbeddingProvider | str | None
) -> BaseTokenizer:
    if isinstance(provider_type, str):
        try:
            provider_type = EmbeddingProvider(provider_type)
        except ValueError:
            logger.debug(
                f"Invalid provider_type '{provider_type}'. Falling back to default tokenizer."
            )
            return _get_default_tokenizer()
    return _check_tokenizer_cache(provider_type, model_name)


# Max characters per encode() call.
_ENCODE_CHUNK_SIZE = 500_000


def count_tokens(
    text: str,
    tokenizer: BaseTokenizer,
    token_limit: int | None = None,
) -> int:
    """Count tokens, chunking the input to avoid tiktoken stack overflow.

    If token_limit is provided and the text is large enough to require
    multiple chunks (> 500k chars), stops early once the count exceeds it.
    When early-exiting, the returned value exceeds token_limit but may be
    less than the true full token count.
""" if len(text) <= _ENCODE_CHUNK_SIZE: return len(tokenizer.encode(text)) total = 0 for start in range(0, len(text), _ENCODE_CHUNK_SIZE): total += len(tokenizer.encode(text[start : start + _ENCODE_CHUNK_SIZE])) if token_limit is not None and total > token_limit: return total # Already over — skip remaining chunks return total def tokenizer_trim_content( content: str, desired_length: int, tokenizer: BaseTokenizer ) -> str: tokens = tokenizer.encode(content) if len(tokens) <= desired_length: return content return tokenizer.decode(tokens[:desired_length]) def tokenizer_trim_middle( tokens: list[int], desired_length: int, tokenizer: BaseTokenizer ) -> str: if len(tokens) <= desired_length: return tokenizer.decode(tokens) sep_str = TRIM_SEP_PAT.format(n=len(tokens) - desired_length) sep_tokens = tokenizer.encode(sep_str) slice_size = (desired_length - len(sep_tokens)) // 2 assert slice_size > 0, "Slice size is not positive, desired length is too short" return ( tokenizer.decode(tokens[:slice_size]) + sep_str + tokenizer.decode(tokens[-slice_size:]) ) def tokenizer_trim_chunks( chunks: list[InferenceChunk], tokenizer: BaseTokenizer, max_chunk_toks: int = DOC_EMBEDDING_CONTEXT_SIZE, ) -> list[InferenceChunk]: new_chunks = copy(chunks) for ind, chunk in enumerate(new_chunks): new_content = tokenizer_trim_content(chunk.content, max_chunk_toks, tokenizer) if len(new_content) != len(chunk.content): new_chunk = copy(chunk) new_chunk.content = new_content new_chunks[ind] = new_chunk return new_chunks ================================================ FILE: backend/onyx/onyxbot/discord/DISCORD_MULTITENANT_README.md ================================================ # Discord Bot Multitenant Architecture This document analyzes how the Discord cache manager and API client coordinate to handle multitenant API keys from a single Discord client. 
## Overview

The Discord bot uses a **single-client, multi-tenant** architecture where one `OnyxDiscordClient` instance serves multiple tenants (organizations) simultaneously. Tenant isolation is achieved through:

- **Cache Manager**: Maps Discord guilds to tenants and stores per-tenant API keys
- **API Client**: Stateless HTTP client that accepts dynamic API keys per request

```
┌─────────────────────────────────────────────────────────────────────┐
│  OnyxDiscordClient                                                  │
│                                                                     │
│  ┌─────────────────────────┐    ┌─────────────────────────────┐     │
│  │  DiscordCacheManager    │    │  OnyxAPIClient              │     │
│  │                         │    │                             │     │
│  │  guild_id → tenant_id   │───▶│  send_chat_message(         │     │
│  │  tenant_id → api_key    │    │      message,               │     │
│  │                         │    │      api_key=,              │     │
│  └─────────────────────────┘    │      persona_id=...         │     │
│                                 │  )                          │     │
│                                 └─────────────────────────────┘     │
└─────────────────────────────────────────────────────────────────────┘
```

---

## Component Details

### 1. Cache Manager (`backend/onyx/onyxbot/discord/cache.py`)

The `DiscordCacheManager` maintains two critical in-memory mappings:

```python
class DiscordCacheManager:
    _guild_tenants: dict[int, str]  # guild_id → tenant_id
    _api_keys: dict[str, str]       # tenant_id → api_key
    _lock: asyncio.Lock             # Concurrency control
```

#### Key Responsibilities

| Function | Purpose |
|----------|---------|
| `get_tenant(guild_id)` | O(1) lookup: guild → tenant |
| `get_api_key(tenant_id)` | O(1) lookup: tenant → API key |
| `refresh_all()` | Full cache rebuild from database |
| `refresh_guild()` | Incremental update for a single guild |

#### API Key Provisioning Strategy

API keys are **lazily provisioned**, i.e. only created when first needed:

```python
async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]:
    needs_key = tenant_id not in self._api_keys

    with get_session_with_tenant(tenant_id) as db:
        # Load guild configs
        configs = get_discord_bot_configs(db)
        guild_ids = [c.guild_id for c in configs if c.enabled]

        # Only provision API key if not already cached
        api_key = None
        if needs_key:
            api_key = get_or_create_discord_service_api_key(db, tenant_id)

    return guild_ids, api_key
```

This optimization avoids repeated database calls for API key generation.

#### Concurrency Control

All write operations acquire an async lock to prevent race conditions:

```python
async def refresh_all(self) -> None:
    async with self._lock:
        # Safe to modify _guild_tenants and _api_keys
        for tenant_id in get_all_tenant_ids():
            guild_ids, api_key = await self._load_tenant_data(tenant_id)
            # Update mappings...
```

Read operations (`get_tenant`, `get_api_key`) are lock-free since Python dict lookups are atomic.

---

### 2. API Client (`backend/onyx/onyxbot/discord/api_client.py`)

The `OnyxAPIClient` is a **stateless async HTTP client** that communicates with Onyx API pods.

#### Key Design: Per-Request API Key Injection

```python
class OnyxAPIClient:
    async def send_chat_message(
        self,
        message: str,
        api_key: str,  # Injected per request
        persona_id: int | None,
        # ... other parameters elided
    ) -> ChatFullResponse:
        headers = {
            "Content-Type": "application/json",
            "Authorization": f"Bearer {api_key}",  # Tenant-specific auth
        }
        # Make request...
```

The client accepts `api_key` as a parameter to each method, enabling **dynamic tenant selection at request time**. This design allows a single client instance to serve multiple tenants:

```python
# Same client, different tenants
await api_client.send_chat_message(msg, api_key=key_for_tenant_1, ...)
await api_client.send_chat_message(msg, api_key=key_for_tenant_2, ...)
``` --- ## Coordination Flow ### Message Processing Pipeline When a Discord message arrives, the client coordinates cache and API client: ```python async def on_message(self, message: Message) -> None: guild_id = message.guild.id # Step 1: Cache lookup - guild → tenant tenant_id = self.cache.get_tenant(guild_id) if not tenant_id: return # Guild not registered # Step 2: Cache lookup - tenant → API key api_key = self.cache.get_api_key(tenant_id) if not api_key: logger.warning(f"No API key for tenant {tenant_id}") return # Step 3: API call with tenant-specific credentials await process_chat_message( message=message, api_key=api_key, # Tenant-specific persona_id=persona_id, # Tenant-specific api_client=self.api_client, ) ``` ### Startup Sequence ```python async def setup_hook(self) -> None: # 1. Initialize API client (create aiohttp session) await self.api_client.initialize() # 2. Populate cache with all tenants await self.cache.refresh_all() # 3. Start background refresh task self._cache_refresh_task = self.loop.create_task( self._periodic_cache_refresh() # Every 60 seconds ) ``` ### Shutdown Sequence ```python async def close(self) -> None: # 1. Cancel background refresh if self._cache_refresh_task: self._cache_refresh_task.cancel() # 2. Close Discord connection await super().close() # 3. Close API client session await self.api_client.close() # 4. Clear cache self.cache.clear() ``` --- ## Tenant Isolation Mechanisms ### 1. 
Per-Tenant API Keys Each tenant has a dedicated service API key: ```python # backend/onyx/db/discord_bot.py def get_or_create_discord_service_api_key(db_session: Session, tenant_id: str) -> str: existing = get_discord_service_api_key(db_session) if existing: return regenerate_key(existing) # Create LIMITED role key (chat-only permissions) return insert_api_key( db_session=db_session, api_key_args=APIKeyArgs( name=DISCORD_SERVICE_API_KEY_NAME, role=UserRole.LIMITED, # Minimal permissions ), user_id=None, # Service account (system-owned) ).api_key ``` ### 2. Database Context Variables The cache uses context variables for proper tenant-scoped DB sessions: ```python context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: with get_session_with_tenant(tenant_id) as db: # All DB operations scoped to this tenant ... finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token) ``` ### 3. Enterprise Gating Support Gated tenants are filtered during cache refresh: ```python gated_tenants = fetch_ee_implementation_or_noop( "onyx.server.tenants.product_gating", "get_gated_tenants", set(), )() for tenant_id in get_all_tenant_ids(): if tenant_id in gated_tenants: continue # Skip gated tenants ``` --- ## Cache Refresh Strategy | Trigger | Method | Scope | |---------|--------|-------| | Startup | `refresh_all()` | All tenants | | Periodic (60s) | `refresh_all()` | All tenants | | Guild registration | `refresh_guild()` | Single tenant | ### Error Handling - **Tenant-level errors**: Logged and skipped (doesn't stop other tenants) - **Missing API key**: Bot silently ignores messages from that guild - **Network errors**: Logged, cache continues with stale data until next refresh --- ## Key Design Insights 1. **Single Client, Multiple Tenants**: One `OnyxAPIClient` and one `DiscordCacheManager` instance serves all tenants via dynamic API key injection. 2. 
**Cache-First Architecture**: Guild lookups are O(1) in-memory dict reads; API keys are cached after first provisioning to avoid repeated DB calls.
3. **Graceful Degradation**: If an API key is missing or stale, the bot simply doesn't respond (no crash or error propagation).
4. **Concurrency Safety Without Blocking Reads**: An `asyncio.Lock` serializes `refresh_all()` and `refresh_guild()` so refreshes cannot interleave. Lookups (`get_tenant()`, `get_api_key()`) read the dicts without taking the lock. `refresh_all()` swaps in newly built dicts with no `await` between the assignments, so readers see either the old or the new mappings.
5. **Lazy Provisioning**: A tenant's API key is created only when the tenant has at least one enabled, registered guild and no key is cached yet; it is then reused from the cache.
6. **Stateless API Client**: The HTTP client holds no tenant state. All tenant context is injected per request via the `api_key` parameter.

---

## File References

| Component | Path |
|-----------|------|
| Cache Manager | `backend/onyx/onyxbot/discord/cache.py` |
| API Client | `backend/onyx/onyxbot/discord/api_client.py` |
| Discord Client | `backend/onyx/onyxbot/discord/client.py` |
| API Key DB Operations | `backend/onyx/db/discord_bot.py` |
| Cache Manager Tests | `backend/tests/unit/onyx/onyxbot/discord/test_cache_manager.py` |
| API Client Tests | `backend/tests/unit/onyx/onyxbot/discord/test_api_client.py` |

================================================
FILE: backend/onyx/onyxbot/discord/api_client.py
================================================
"""Async HTTP client for communicating with Onyx API pods."""

import aiohttp

from onyx.chat.models import ChatFullResponse
from onyx.onyxbot.discord.constants import API_REQUEST_TIMEOUT
from onyx.onyxbot.discord.exceptions import APIConnectionError
from onyx.onyxbot.discord.exceptions import APIResponseError
from onyx.onyxbot.discord.exceptions import APITimeoutError
from onyx.server.query_and_chat.models import ChatSessionCreationRequest
from onyx.server.query_and_chat.models import MessageOrigin
from onyx.server.query_and_chat.models import SendMessageRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import build_api_server_url_for_http_requests

logger = setup_logger()


class
OnyxAPIClient: """Async HTTP client for sending chat requests to Onyx API pods. This client manages an aiohttp session for making non-blocking HTTP requests to the Onyx API server. It handles authentication with per-tenant API keys and multi-tenant routing. Usage: client = OnyxAPIClient() await client.initialize() try: response = await client.send_chat_message( message="What is our deployment process?", api_key="dn_xxx...", persona_id=1, ) print(response.answer) finally: await client.close() """ def __init__( self, timeout: int = API_REQUEST_TIMEOUT, ) -> None: """Initialize the API client. Args: timeout: Request timeout in seconds. """ # Helm chart uses API_SERVER_URL_OVERRIDE_FOR_HTTP_REQUESTS to set the base URL # TODO: Ideally, this override is only used when someone is launching an Onyx service independently self._base_url = build_api_server_url_for_http_requests( respect_env_override_if_set=True ).rstrip("/") self._timeout = timeout self._session: aiohttp.ClientSession | None = None async def initialize(self) -> None: """Create the aiohttp session. Must be called before making any requests. The session is created with a total timeout and connection timeout. """ if self._session is not None: logger.warning("API client session already initialized") return timeout = aiohttp.ClientTimeout( total=self._timeout, connect=30, # 30 seconds to establish connection ) self._session = aiohttp.ClientSession(timeout=timeout) logger.info(f"API client initialized with base URL: {self._base_url}") async def close(self) -> None: """Close the aiohttp session. Should be called when shutting down the bot to properly release resources.
""" if self._session is not None: await self._session.close() self._session = None logger.info("API client session closed") @property def is_initialized(self) -> bool: """Check if the session is initialized.""" return self._session is not None async def send_chat_message( self, message: str, api_key: str, persona_id: int | None = None, ) -> ChatFullResponse: """Send a chat message to the Onyx API server and get a response. This method sends a non-streaming chat request to the API server. The response contains the complete answer with any citations and metadata. Args: message: The user's message to process. api_key: The API key for authentication. persona_id: Optional persona ID to use for the response. Returns: ChatFullResponse containing the answer, citations, and metadata. Raises: APIConnectionError: If unable to connect to the API. APITimeoutError: If the request times out. APIResponseError: If the API returns an error response. """ if self._session is None: raise APIConnectionError( "API client not initialized. Call initialize() first." 
) url = f"{self._base_url}/chat/send-chat-message" # Build request payload request = SendMessageRequest( message=message, stream=False, origin=MessageOrigin.DISCORDBOT, chat_session_info=ChatSessionCreationRequest( persona_id=persona_id if persona_id is not None else 0, ), ) # Build headers headers = { "Content-Type": "application/json", "Authorization": f"Bearer {api_key}", } try: async with self._session.post( url, json=request.model_dump(mode="json"), headers=headers, ) as response: if response.status == 401: raise APIResponseError( "Authentication failed - invalid API key", status_code=401, ) elif response.status == 403: raise APIResponseError( "Access denied - insufficient permissions", status_code=403, ) elif response.status == 404: raise APIResponseError( "API endpoint not found", status_code=404, ) elif response.status >= 500: error_text = await response.text() raise APIResponseError( f"Server error: {error_text}", status_code=response.status, ) elif response.status >= 400: error_text = await response.text() raise APIResponseError( f"Request error: {error_text}", status_code=response.status, ) # Parse successful response data = await response.json() response_obj = ChatFullResponse.model_validate(data) if response_obj.error_msg: logger.warning(f"Chat API returned error: {response_obj.error_msg}") return response_obj except aiohttp.ClientConnectorError as e: logger.error(f"Failed to connect to API: {e}") raise APIConnectionError( f"Failed to connect to API at {self._base_url}: {e}" ) from e except TimeoutError as e: logger.error(f"API request timed out after {self._timeout}s") raise APITimeoutError( f"Request timed out after {self._timeout} seconds" ) from e except aiohttp.ClientError as e: logger.error(f"HTTP client error: {e}") raise APIConnectionError(f"HTTP client error: {e}") from e async def health_check(self) -> bool: """Check if the API server is healthy. Returns: True if the API server is reachable and healthy, False otherwise. 
""" if self._session is None: logger.warning("API client not initialized. Call initialize() first.") return False try: url = f"{self._base_url}/health" async with self._session.get( url, timeout=aiohttp.ClientTimeout(total=10) ) as response: return response.status == 200 except Exception as e: logger.warning(f"API server health check failed: {e}") return False ================================================ FILE: backend/onyx/onyxbot/discord/cache.py ================================================ """Multi-tenant cache for Discord bot guild-tenant mappings and API keys.""" import asyncio from onyx.db.discord_bot import get_guild_configs from onyx.db.discord_bot import get_or_create_discord_service_api_key from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.engine.tenant_utils import get_all_tenant_ids from onyx.onyxbot.discord.exceptions import CacheError from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() class DiscordCacheManager: """Caches guild->tenant mappings and tenant->API key mappings. Refreshed on startup, periodically (every 60s), and when guilds register. 
""" def __init__(self) -> None: self._guild_tenants: dict[int, str] = {} # guild_id -> tenant_id self._api_keys: dict[str, str] = {} # tenant_id -> api_key self._lock = asyncio.Lock() self._initialized = False @property def is_initialized(self) -> bool: return self._initialized async def refresh_all(self) -> None: """Full cache refresh from all tenants.""" async with self._lock: logger.info("Starting Discord cache refresh") new_guild_tenants: dict[int, str] = {} new_api_keys: dict[str, str] = {} try: gated = fetch_ee_implementation_or_noop( "onyx.server.tenants.product_gating", "get_gated_tenants", set(), )() tenant_ids = await asyncio.to_thread(get_all_tenant_ids) for tenant_id in tenant_ids: if tenant_id in gated: continue context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: guild_ids, api_key = await self._load_tenant_data(tenant_id) if not guild_ids: logger.debug(f"No guilds found for tenant {tenant_id}") continue if not api_key: logger.warning( "Discord service API key missing for tenant that has registered guilds. " f"{tenant_id} will not be handled in this refresh cycle." 
) continue for guild_id in guild_ids: new_guild_tenants[guild_id] = tenant_id new_api_keys[tenant_id] = api_key except Exception as e: logger.warning(f"Failed to refresh tenant {tenant_id}: {e}") finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token) self._guild_tenants = new_guild_tenants self._api_keys = new_api_keys self._initialized = True logger.info( f"Cache refresh complete: {len(new_guild_tenants)} guilds, {len(new_api_keys)} tenants" ) except Exception as e: logger.error(f"Cache refresh failed: {e}") raise CacheError(f"Failed to refresh cache: {e}") from e async def refresh_guild(self, guild_id: int, tenant_id: str) -> None: """Add a single guild to cache after registration.""" async with self._lock: logger.info(f"Refreshing cache for guild {guild_id} (tenant: {tenant_id})") guild_ids, api_key = await self._load_tenant_data(tenant_id) if guild_id in guild_ids: self._guild_tenants[guild_id] = tenant_id if api_key: self._api_keys[tenant_id] = api_key logger.info(f"Cache updated for guild {guild_id}") else: logger.warning(f"Guild {guild_id} not found or disabled") async def _load_tenant_data(self, tenant_id: str) -> tuple[list[int], str | None]: """Load guild IDs and provision API key if needed. Returns: (active_guild_ids, api_key) - api_key is the cached key if available, otherwise a newly created key. Returns ([], None) if the tenant has no enabled, registered guilds.
""" cached_key = self._api_keys.get(tenant_id) def _sync() -> tuple[list[int], str | None]: with get_session_with_tenant(tenant_id=tenant_id) as db: configs = get_guild_configs(db) guild_ids = [ config.guild_id for config in configs if config.enabled and config.guild_id is not None ] if not guild_ids: return [], None if not cached_key: new_key = get_or_create_discord_service_api_key(db, tenant_id) db.commit() return guild_ids, new_key return guild_ids, cached_key return await asyncio.to_thread(_sync) def get_tenant(self, guild_id: int) -> str | None: """Get tenant ID for a guild.""" return self._guild_tenants.get(guild_id) def get_api_key(self, tenant_id: str) -> str | None: """Get API key for a tenant.""" return self._api_keys.get(tenant_id) def remove_guild(self, guild_id: int) -> None: """Remove a guild from cache.""" self._guild_tenants.pop(guild_id, None) def get_all_guild_ids(self) -> list[int]: """Get all cached guild IDs.""" return list(self._guild_tenants.keys()) def clear(self) -> None: """Clear all caches.""" self._guild_tenants.clear() self._api_keys.clear() self._initialized = False ================================================ FILE: backend/onyx/onyxbot/discord/client.py ================================================ """Discord bot client with integrated message handling.""" import asyncio import time import discord from discord.ext import commands from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR from onyx.onyxbot.discord.api_client import OnyxAPIClient from onyx.onyxbot.discord.cache import DiscordCacheManager from onyx.onyxbot.discord.constants import CACHE_REFRESH_INTERVAL from onyx.onyxbot.discord.handle_commands import handle_dm from onyx.onyxbot.discord.handle_commands import handle_registration_command from onyx.onyxbot.discord.handle_commands import handle_sync_channels_command from onyx.onyxbot.discord.handle_message import process_chat_message from onyx.onyxbot.discord.handle_message import should_respond from 
onyx.onyxbot.discord.utils import get_bot_token from onyx.utils.logger import setup_logger logger = setup_logger() class OnyxDiscordClient(commands.Bot): """Discord bot client with integrated cache, API client, and message handling. This client handles: - Guild registration via !register command - Message processing with persona-based responses - Thread context for conversation continuity - Multi-tenant support via cached API keys """ def __init__(self, command_prefix: str = DISCORD_BOT_INVOKE_CHAR) -> None: intents = discord.Intents.default() intents.message_content = True intents.members = True super().__init__(command_prefix=command_prefix, intents=intents) self.ready = False self.cache = DiscordCacheManager() self.api_client = OnyxAPIClient() self._cache_refresh_task: asyncio.Task | None = None # ------------------------------------------------------------------------- # Lifecycle Methods # ------------------------------------------------------------------------- async def setup_hook(self) -> None: """Called before on_ready. 
Initialize components.""" logger.info("Initializing Discord bot components...") # Initialize API client await self.api_client.initialize() # Initial cache load await self.cache.refresh_all() # Start periodic cache refresh self._cache_refresh_task = self.loop.create_task(self._periodic_cache_refresh()) logger.info("Discord bot components initialized") async def _periodic_cache_refresh(self) -> None: """Background task to refresh cache periodically.""" while not self.is_closed(): await asyncio.sleep(CACHE_REFRESH_INTERVAL) try: await self.cache.refresh_all() except Exception as e: logger.error(f"Cache refresh failed: {e}") async def on_ready(self) -> None: """Bot connected and ready.""" if self.ready: return if not self.user: raise RuntimeError("Critical error: Discord Bot user not found") logger.info(f"Discord Bot connected as {self.user} (ID: {self.user.id})") logger.info(f"Connected to {len(self.guilds)} guild(s)") logger.info(f"Cached {len(self.cache.get_all_guild_ids())} registered guild(s)") self.ready = True async def close(self) -> None: """Graceful shutdown.""" logger.info("Shutting down Discord bot...") # Cancel cache refresh task if self._cache_refresh_task: self._cache_refresh_task.cancel() try: await self._cache_refresh_task except asyncio.CancelledError: pass # Close Discord connection first - stops new commands from triggering cache ops if not self.is_closed(): await super().close() # Close API client await self.api_client.close() # Clear cache (safe now - no concurrent operations possible) self.cache.clear() self.ready = False logger.info("Discord bot shutdown complete") # ------------------------------------------------------------------------- # Message Handling # ------------------------------------------------------------------------- async def on_message(self, message: discord.Message) -> None: """Main message handler.""" # mypy if not self.user: raise RuntimeError("Critical error: Discord Bot user not found") try: # Ignore bot messages if 
message.author.bot: return # Ignore thread starter messages (empty reference nodes that don't contain content) if message.type == discord.MessageType.thread_starter_message: return # Handle DMs if isinstance(message.channel, discord.DMChannel): await handle_dm(message) return # Must have a guild if not message.guild or not message.guild.id: return guild_id = message.guild.id # Check for registration command first if await handle_registration_command(message, self.cache): return # Look up guild in cache tenant_id = self.cache.get_tenant(guild_id) # Check for sync-channels command (requires registered guild) if await handle_sync_channels_command(message, tenant_id, self): return if not tenant_id: # Guild not registered, ignore return # Get API key api_key = self.cache.get_api_key(tenant_id) if not api_key: logger.warning(f"No API key cached for tenant {tenant_id}") return # Check if bot should respond should_respond_context = await should_respond(message, tenant_id, self.user) if not should_respond_context.should_respond: return logger.debug( f"Processing message: '{message.content[:50]}' in " f"#{getattr(message.channel, 'name', 'unknown')} ({message.guild.name}), " f"persona_id={should_respond_context.persona_id}" ) # Process the message await process_chat_message( message=message, api_key=api_key, persona_id=should_respond_context.persona_id, thread_only_mode=should_respond_context.thread_only_mode, api_client=self.api_client, bot_user=self.user, ) except Exception as e: logger.exception(f"Error processing message: {e}") # ----------------------------------------------------------------------------- # Entry Point # ----------------------------------------------------------------------------- def main() -> None: """Main entry point for Discord bot.""" from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable logger.info("Starting Onyx Discord Bot...") # Initialize the database engine (required 
before any DB operations) SqlEngine.init_engine(pool_size=20, max_overflow=5) # Initialize EE features based on environment set_is_ee_based_on_env_variable() counter = 0 while True: token = get_bot_token() if not token: if counter % 180 == 0: logger.info( "Discord bot is dormant. Waiting for token configuration..." ) counter += 1 time.sleep(5) continue counter = 0 bot = OnyxDiscordClient() try: # bot.run() handles SIGINT/SIGTERM and calls close() automatically bot.run(token) except Exception: logger.exception("Fatal error in Discord bot") raise if __name__ == "__main__": main() ================================================ FILE: backend/onyx/onyxbot/discord/constants.py ================================================ """Discord bot constants.""" # API settings API_REQUEST_TIMEOUT: int = 3 * 60 # 3 minutes # Cache settings CACHE_REFRESH_INTERVAL: int = 60 # 1 minute # Message settings MAX_MESSAGE_LENGTH: int = 2000 # Discord's character limit MAX_CONTEXT_MESSAGES: int = 10 # Max messages to include in conversation context # Note: Discord.py's add_reaction() requires unicode emoji, not :name: format THINKING_EMOJI: str = "🤔" # U+1F914 - Thinking Face SUCCESS_EMOJI: str = "✅" # U+2705 - White Heavy Check Mark ERROR_EMOJI: str = "❌" # U+274C - Cross Mark # Command prefix REGISTER_COMMAND: str = "register" SYNC_CHANNELS_COMMAND: str = "sync-channels" ================================================ FILE: backend/onyx/onyxbot/discord/exceptions.py ================================================ """Custom exception classes for Discord bot.""" class DiscordBotError(Exception): """Base exception for Discord bot errors.""" class RegistrationError(DiscordBotError): """Error during guild registration.""" class SyncChannelsError(DiscordBotError): """Error during channel sync.""" class APIError(DiscordBotError): """Base API error.""" class CacheError(DiscordBotError): """Error during cache operations.""" class APIConnectionError(APIError): """Failed to connect to API.""" 
class APITimeoutError(APIError): """Request timed out.""" class APIResponseError(APIError): """API returned an error response.""" def __init__(self, message: str, status_code: int | None = None): super().__init__(message) self.status_code = status_code ================================================ FILE: backend/onyx/onyxbot/discord/handle_commands.py ================================================ """Discord bot command handlers for registration and channel sync.""" import asyncio from datetime import datetime from datetime import timezone import discord from onyx.configs.app_configs import DISCORD_BOT_INVOKE_CHAR from onyx.configs.constants import ONYX_DISCORD_URL from onyx.db.discord_bot import bulk_create_channel_configs from onyx.db.discord_bot import get_guild_config_by_discord_id from onyx.db.discord_bot import get_guild_config_by_internal_id from onyx.db.discord_bot import get_guild_config_by_registration_key from onyx.db.discord_bot import sync_channel_configs from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.utils import DiscordChannelView from onyx.onyxbot.discord.cache import DiscordCacheManager from onyx.onyxbot.discord.constants import REGISTER_COMMAND from onyx.onyxbot.discord.constants import SYNC_CHANNELS_COMMAND from onyx.onyxbot.discord.exceptions import RegistrationError from onyx.onyxbot.discord.exceptions import SyncChannelsError from onyx.server.manage.discord_bot.utils import parse_discord_registration_key from onyx.utils.logger import setup_logger from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() async def handle_dm(message: discord.Message) -> None: """Handle direct messages.""" dm_response = ( "**I can't respond to DMs** :sweat:\n\n" f"Please chat with me in a server channel, or join the official " f"[Onyx Discord]({ONYX_DISCORD_URL}) for help!" 
) await message.channel.send(dm_response) # ------------------------------------------------------------------------- # Helper functions for error handling # ------------------------------------------------------------------------- async def _try_dm_author(message: discord.Message, content: str) -> bool: """Attempt to DM the message author. Returns True if successful.""" logger.debug(f"Responding in Discord DM with {content}") try: await message.author.send(content) return True except (discord.Forbidden, discord.HTTPException) as e: # User has DMs disabled or other error logger.warning(f"Failed to DM author {message.author.id}: {e}") except Exception as e: logger.exception(f"Unexpected error DMing author {message.author.id}: {e}") return False async def _try_delete_message(message: discord.Message) -> bool: """Attempt to delete a message. Returns True if successful.""" logger.debug(f"Deleting potentially sensitive message {message.id}") try: await message.delete() return True except (discord.Forbidden, discord.HTTPException) as e: # Bot lacks permission or other error logger.warning(f"Failed to delete message {message.id}: {e}") except Exception as e: logger.exception(f"Unexpected error deleting message {message.id}: {e}") return False async def _try_react_x(message: discord.Message) -> bool: """Attempt to react to a message with ❌. 
Returns True if successful.""" try: await message.add_reaction("❌") return True except (discord.Forbidden, discord.HTTPException) as e: # Bot lacks permission or other error logger.warning(f"Failed to react to message {message.id}: {e}") except Exception as e: logger.exception(f"Unexpected error reacting to message {message.id}: {e}") return False # ------------------------------------------------------------------------- # Registration # ------------------------------------------------------------------------- async def handle_registration_command( message: discord.Message, cache: DiscordCacheManager, ) -> bool: """Handle !register command. Returns True if command was handled.""" content = message.content.strip() # Check for !register command if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{REGISTER_COMMAND}"): return False # Must be in a server if not message.guild: await _try_dm_author( message, "This command can only be used in a server channel." ) return True guild_name = message.guild.name logger.info(f"Registration command received: {guild_name}") try: # Parse the registration key parts = content.split(maxsplit=1) if len(parts) < 2: raise RegistrationError( "Invalid registration key format. Please check the key and try again." ) registration_key = parts[1].strip() if not message.author or not isinstance(message.author, discord.Member): raise RegistrationError( "You need to be a server administrator to register the bot." ) # Check permissions - require admin or manage_guild if not message.author.guild_permissions.administrator: if not message.author.guild_permissions.manage_guild: raise RegistrationError( "You need **Administrator** or **Manage Server** permissions to register this bot." ) await _register_guild(message, registration_key, cache) logger.info(f"Registration successful: {guild_name}") await message.reply( ":white_check_mark: **Successfully registered!**\n\n" "This server is now connected to Onyx. 
" "I'll respond to messages based on your server and channel settings set in Onyx." ) except RegistrationError as e: logger.debug(f"Registration failed: {guild_name}, error={e}") await _try_dm_author(message, f":x: **Registration failed.**\n\n{e}") await _try_delete_message(message) except Exception: logger.exception(f"Registration failed unexpectedly: {guild_name}") await _try_dm_author( message, ":x: **Registration failed.**\n\nAn unexpected error occurred. Please try again later.", ) await _try_delete_message(message) return True async def _register_guild( message: discord.Message, registration_key: str, cache: DiscordCacheManager, ) -> None: """Register a guild with a registration key.""" if not message.guild: # mypy, even though we already know that message.guild is not None raise RegistrationError("This command can only be used in a server.") logger.info(f"Guild '{message.guild.name}' attempting to register Discord bot") registration_key = registration_key.strip() # Parse tenant_id from registration key parsed = parse_discord_registration_key(registration_key) if parsed is None: raise RegistrationError( "Invalid registration key format. Please check the key and try again." ) tenant_id = parsed logger.info(f"Parsed tenant_id {tenant_id} from registration key") # Check if this guild is already registered to any tenant guild_id = message.guild.id existing_tenant = cache.get_tenant(guild_id) if existing_tenant is not None: logger.warning( f"Guild {guild_id} is already registered to tenant {existing_tenant}" ) raise RegistrationError( "This server is already registered.\n\nOnyxBot can only connect one Discord server to one Onyx workspace." 
) context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: guild = message.guild guild_name = guild.name # Collect all text channels from the guild channels = get_text_channels(guild) logger.info(f"Found {len(channels)} text channels in guild '{guild_name}'") # Validate and update in database def _sync_register() -> int: with get_session_with_tenant(tenant_id=tenant_id) as db: # Find the guild config by registration key config = get_guild_config_by_registration_key(db, registration_key) if not config: raise RegistrationError( "Registration key not found.\n\n" "The key may have expired or been deleted. " "Please generate a new one from the Onyx admin panel." ) # Check if already used if config.guild_id is not None: raise RegistrationError( "This registration key has already been used.\n\n" "Each key can only be used once. " "Please generate a new key from the Onyx admin panel." ) # Update the guild config config.guild_id = guild_id config.guild_name = guild_name config.registered_at = datetime.now(timezone.utc) # Create channel configs for all text channels bulk_create_channel_configs(db, config.id, channels) db.commit() return config.id await asyncio.to_thread(_sync_register) # Refresh cache for this guild await cache.refresh_guild(guild_id, tenant_id) logger.info( f"Guild '{guild_name}' registered with {len(channels)} channel configs" ) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token) def get_text_channels(guild: discord.Guild) -> list[DiscordChannelView]: """Get all text channels from a guild as DiscordChannelView objects.""" channels: list[DiscordChannelView] = [] for channel in guild.channels: # Include text channels and forum channels (where threads can be created) if isinstance(channel, (discord.TextChannel, discord.ForumChannel)): # Check if channel is private (not visible to @everyone) everyone_perms = channel.permissions_for(guild.default_role) is_private = not everyone_perms.view_channel logger.debug( f"Found channel: #{channel.name}, 
type={channel.type.name}, is_private={is_private}" ) channels.append( DiscordChannelView( channel_id=channel.id, channel_name=channel.name, channel_type=channel.type.name, # "text" or "forum" is_private=is_private, ) ) logger.debug(f"Retrieved {len(channels)} channels from guild '{guild.name}'") return channels # ------------------------------------------------------------------------- # Sync Channels # ------------------------------------------------------------------------- async def handle_sync_channels_command( message: discord.Message, tenant_id: str | None, bot: discord.Client, ) -> bool: """Handle !sync-channels command. Returns True if command was handled.""" content = message.content.strip() # Check for !sync-channels command if not content.startswith(f"{DISCORD_BOT_INVOKE_CHAR}{SYNC_CHANNELS_COMMAND}"): return False # Must be in a server if not message.guild: await _try_dm_author( message, "This command can only be used in a server channel." ) return True guild_name = message.guild.name logger.info(f"Sync-channels command received: {guild_name}") try: # Must be registered if not tenant_id: raise SyncChannelsError( "This server is not registered. Please register it first." ) # Check permissions - require admin or manage_guild if not message.author or not isinstance(message.author, discord.Member): raise SyncChannelsError( "You need to be a server administrator to sync channels." ) if not message.author.guild_permissions.administrator: if not message.author.guild_permissions.manage_guild: raise SyncChannelsError( "You need **Administrator** or **Manage Server** permissions to sync channels." ) # Get guild config ID def _get_guild_config_id() -> int | None: with get_session_with_tenant(tenant_id=tenant_id) as db: if not message.guild: raise SyncChannelsError( "Server not found. This shouldn't happen. Please contact Onyx support." 
) config = get_guild_config_by_discord_id(db, message.guild.id) return config.id if config else None guild_config_id = await asyncio.to_thread(_get_guild_config_id) if not guild_config_id: raise SyncChannelsError( "Server config not found. This shouldn't happen. Please contact Onyx support." ) # Perform the sync added, removed, updated = await sync_guild_channels( guild_config_id, tenant_id, bot ) logger.info( f"Sync-channels successful: {guild_name}, added={added}, removed={removed}, updated={updated}" ) await message.reply( f":white_check_mark: **Channel sync complete!**\n\n" f"* **{added}** new channel(s) added\n" f"* **{removed}** deleted channel(s) removed\n" f"* **{updated}** channel name(s) updated\n\n" "New channels are disabled by default. Enable them in the Onyx admin panel." ) except SyncChannelsError as e: logger.debug(f"Sync-channels failed: {guild_name}, error={e}") await _try_dm_author(message, f":x: **Channel sync failed.**\n\n{e}") await _try_react_x(message) except Exception: logger.exception(f"Sync-channels failed unexpectedly: {guild_name}") await _try_dm_author( message, ":x: **Channel sync failed.**\n\nAn unexpected error occurred. Please try again later.", ) await _try_react_x(message) return True async def sync_guild_channels( guild_config_id: int, tenant_id: str, bot: discord.Client, ) -> tuple[int, int, int]: """Sync channel configs with current Discord channels for a guild. 
Fetches current channels from Discord and syncs them with the database: - Creates configs for new channels (disabled by default) - Removes configs for deleted channels - Updates names for existing channels if changed Args: guild_config_id: Internal ID of the guild config tenant_id: Tenant ID for database access bot: Discord bot client Returns: (added_count, removed_count, updated_count) Raises: ValueError: If the guild config is not found, the guild is not registered, or the guild is not in the bot's Discord cache """ context_token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id) try: # Get guild_id from config def _get_guild_id() -> int | None: with get_session_with_tenant(tenant_id=tenant_id) as db: config = get_guild_config_by_internal_id(db, guild_config_id) if not config: return None return config.guild_id guild_id = await asyncio.to_thread(_get_guild_id) if guild_id is None: raise ValueError( f"Guild config {guild_config_id} not found or not registered" ) # Get the guild from Discord guild = bot.get_guild(guild_id) if not guild: raise ValueError(f"Guild {guild_id} not found in Discord cache") # Get current channels from Discord channels = get_text_channels(guild) logger.info(f"Syncing {len(channels)} channels for guild '{guild.name}'") # Sync with database def _sync() -> tuple[int, int, int]: with get_session_with_tenant(tenant_id=tenant_id) as db: added, removed, updated = sync_channel_configs( db, guild_config_id, channels ) db.commit() return added, removed, updated added, removed, updated = await asyncio.to_thread(_sync) logger.info( f"Channel sync complete for guild '{guild.name}': added={added}, removed={removed}, updated={updated}" ) return added, removed, updated finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(context_token) ================================================ FILE: backend/onyx/onyxbot/discord/handle_message.py ================================================ """Discord bot message handling and response logic.""" import asyncio import discord from pydantic import BaseModel from onyx.chat.models import
ChatFullResponse from onyx.db.discord_bot import get_channel_config_by_discord_ids from onyx.db.discord_bot import get_guild_config_by_discord_id from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.models import DiscordChannelConfig from onyx.db.models import DiscordGuildConfig from onyx.onyxbot.discord.api_client import OnyxAPIClient from onyx.onyxbot.discord.constants import MAX_CONTEXT_MESSAGES from onyx.onyxbot.discord.constants import MAX_MESSAGE_LENGTH from onyx.onyxbot.discord.constants import THINKING_EMOJI from onyx.onyxbot.discord.exceptions import APIError from onyx.utils.logger import setup_logger logger = setup_logger() # Message types with actual content (excludes system notifications like "user joined") CONTENT_MESSAGE_TYPES = ( discord.MessageType.default, discord.MessageType.reply, discord.MessageType.thread_starter_message, ) class ShouldRespondContext(BaseModel): """Context for whether the bot should respond to a message.""" should_respond: bool persona_id: int | None thread_only_mode: bool # ------------------------------------------------------------------------- # Response Logic # ------------------------------------------------------------------------- async def should_respond( message: discord.Message, tenant_id: str, bot_user: discord.ClientUser, ) -> ShouldRespondContext: """Determine if bot should respond and which persona to use.""" if not message.guild: logger.warning("Received a message that isn't in a server.") return ShouldRespondContext( should_respond=False, persona_id=None, thread_only_mode=False ) guild_id = message.guild.id channel_id = message.channel.id bot_mentioned = bot_user in message.mentions def _get_configs() -> tuple[DiscordGuildConfig | None, DiscordChannelConfig | None]: with get_session_with_tenant(tenant_id=tenant_id) as db: guild_config = get_guild_config_by_discord_id(db, guild_id) if not guild_config or not guild_config.enabled: return None, None # For threads, use parent channel ID 
actual_channel_id = channel_id if isinstance(message.channel, discord.Thread) and message.channel.parent: actual_channel_id = message.channel.parent.id channel_config = get_channel_config_by_discord_ids( db, guild_id, actual_channel_id ) return guild_config, channel_config guild_config, channel_config = await asyncio.to_thread(_get_configs) if not guild_config or not channel_config or not channel_config.enabled: return ShouldRespondContext( should_respond=False, persona_id=None, thread_only_mode=False ) # Determine persona (channel override or guild default) persona_id = channel_config.persona_override_id or guild_config.default_persona_id # Check mention requirement (with exceptions for implicit invocation) if channel_config.require_bot_invocation and not bot_mentioned: if not await check_implicit_invocation(message, bot_user): return ShouldRespondContext( should_respond=False, persona_id=None, thread_only_mode=False ) return ShouldRespondContext( should_respond=True, persona_id=persona_id, thread_only_mode=channel_config.thread_only_mode, ) async def check_implicit_invocation( message: discord.Message, bot_user: discord.ClientUser, ) -> bool: """Check if the bot should respond without explicit mention. Returns True if: 1. User is replying to a bot message 2. User is in a thread owned by the bot 3. User is in a thread created from a bot message """ # Check if replying to a bot message if message.reference and message.reference.message_id: try: referenced_msg = await message.channel.fetch_message( message.reference.message_id ) if referenced_msg.author.id == bot_user.id: logger.debug( f"Implicit invocation via reply: '{message.content[:50]}...'" ) return True except (discord.NotFound, discord.HTTPException): pass # Check thread-related conditions if isinstance(message.channel, discord.Thread): thread = message.channel # Bot owns the thread if thread.owner_id == bot_user.id: logger.debug( f"Implicit invocation via bot-owned thread: '{message.content[:50]}...' 
in #{thread.name}" ) return True # Thread was created from a bot message if thread.parent and not isinstance(thread.parent, discord.ForumChannel): try: starter = await thread.parent.fetch_message(thread.id) if starter.author.id == bot_user.id: logger.debug( f"Implicit invocation via bot-started thread: '{message.content[:50]}...' in #{thread.name}" ) return True except (discord.NotFound, discord.HTTPException): pass return False # ------------------------------------------------------------------------- # Message Processing # ------------------------------------------------------------------------- async def process_chat_message( message: discord.Message, api_key: str, persona_id: int | None, thread_only_mode: bool, api_client: OnyxAPIClient, bot_user: discord.ClientUser, ) -> None: """Process a message and send response.""" try: await message.add_reaction(THINKING_EMOJI) except discord.DiscordException: logger.warning( f"Failed to add thinking reaction to message: '{message.content[:50]}...'" ) try: # Build conversation context context = await _build_conversation_context(message, bot_user) # Prepare full message content parts = [] if context: parts.append(context) if isinstance(message.channel, discord.Thread): if isinstance(message.channel.parent, discord.ForumChannel): parts.append(f"Forum post title: {message.channel.name}") parts.append( f"Current message from @{message.author.display_name}: {format_message_content(message)}" ) # Send to API response = await api_client.send_chat_message( message="\n\n".join(parts), api_key=api_key, persona_id=persona_id, ) # Format response with citations answer = response.answer or "I couldn't generate a response." 
answer = _append_citations(answer, response) await send_response(message, answer, thread_only_mode) try: await message.remove_reaction(THINKING_EMOJI, bot_user) except discord.DiscordException: pass except APIError as e: logger.error(f"API error processing message: {e}") await send_error_response(message, bot_user) except Exception as e: logger.exception(f"Error processing chat message: {e}") await send_error_response(message, bot_user) async def _build_conversation_context( message: discord.Message, bot_user: discord.ClientUser, ) -> str | None: """Build conversation context from thread history or reply chain.""" if isinstance(message.channel, discord.Thread): return await _build_thread_context(message, bot_user) elif message.reference: return await _build_reply_chain_context(message, bot_user) return None def _append_citations(answer: str, response: ChatFullResponse) -> str: """Append citation sources to the answer if present.""" if not response.citation_info or not response.top_documents: return answer cited_docs: list[tuple[int, str, str | None]] = [] for citation in response.citation_info: doc = next( ( d for d in response.top_documents if d.document_id == citation.document_id ), None, ) if doc: cited_docs.append( ( citation.citation_number, doc.semantic_identifier or "Source", doc.link, ) ) if not cited_docs: return answer cited_docs.sort(key=lambda x: x[0]) citations = "\n\n**Sources:**\n" for num, name, link in cited_docs[:5]: if link: citations += f"{num}. [{name}](<{link}>)\n" else: citations += f"{num}. 
{name}\n" return answer + citations # ------------------------------------------------------------------------- # Context Building # ------------------------------------------------------------------------- async def _build_reply_chain_context( message: discord.Message, bot_user: discord.ClientUser, ) -> str | None: """Build context by following the reply chain backwards.""" if not message.reference or not message.reference.message_id: return None try: messages: list[discord.Message] = [] current = message # Follow reply chain backwards up to MAX_CONTEXT_MESSAGES while ( current.reference and current.reference.message_id and len(messages) < MAX_CONTEXT_MESSAGES ): try: parent = await message.channel.fetch_message( current.reference.message_id ) messages.append(parent) current = parent except (discord.NotFound, discord.HTTPException): break if not messages: return None messages.reverse() # Chronological order logger.debug( f"Built reply chain context: {len(messages)} messages in #{getattr(message.channel, 'name', 'unknown')}" ) return _format_messages_as_context(messages, bot_user) except Exception as e: logger.warning(f"Failed to build reply chain context: {e}") return None async def _build_thread_context( message: discord.Message, bot_user: discord.ClientUser, ) -> str | None: """Build context from thread message history.""" if not isinstance(message.channel, discord.Thread): return None try: thread = message.channel messages: list[discord.Message] = [] # Fetch recent messages (excluding current) async for msg in thread.history(limit=MAX_CONTEXT_MESSAGES, oldest_first=False): if msg.id != message.id: messages.append(msg) # Include thread starter message and its reply chain if not already present if thread.parent and not isinstance(thread.parent, discord.ForumChannel): try: starter = await thread.parent.fetch_message(thread.id) if starter.id != message.id and not any( m.id == starter.id for m in messages ): messages.append(starter) # Trace back through the 
starter's reply chain for more context current = starter while ( current.reference and current.reference.message_id and len(messages) < MAX_CONTEXT_MESSAGES ): try: parent = await thread.parent.fetch_message( current.reference.message_id ) if not any(m.id == parent.id for m in messages): messages.append(parent) current = parent except (discord.NotFound, discord.HTTPException): break except (discord.NotFound, discord.HTTPException): pass if not messages: return None messages.sort(key=lambda m: m.id) # Chronological order logger.debug( f"Built thread context: {len(messages)} messages in #{thread.name}" ) return _format_messages_as_context(messages, bot_user) except Exception as e: logger.warning(f"Failed to build thread context: {e}") return None def _format_messages_as_context( messages: list[discord.Message], bot_user: discord.ClientUser, ) -> str | None: """Format a list of messages into a conversation context string.""" formatted = [] for msg in messages: if msg.type not in CONTENT_MESSAGE_TYPES: continue sender = ( "OnyxBot" if msg.author.id == bot_user.id else f"@{msg.author.display_name}" ) formatted.append(f"{sender}: {format_message_content(msg)}") if not formatted: return None return ( "You are a Discord bot named OnyxBot.\n" 'Always assume that [user] is the same as the "Current message" author.' 
"\nConversation history:\n" "---\n" + "\n".join(formatted) + "\n---" ) # ------------------------------------------------------------------------- # Message Formatting # ------------------------------------------------------------------------- def format_message_content(message: discord.Message) -> str: """Format message content with readable mentions.""" content = message.content for user in message.mentions: content = content.replace(f"<@{user.id}>", f"@{user.display_name}") content = content.replace(f"<@!{user.id}>", f"@{user.display_name}") for role in message.role_mentions: content = content.replace(f"<@&{role.id}>", f"@{role.name}") for channel in message.channel_mentions: content = content.replace(f"<#{channel.id}>", f"#{channel.name}") return content # ------------------------------------------------------------------------- # Response Sending # ------------------------------------------------------------------------- async def send_response( message: discord.Message, content: str, thread_only_mode: bool, ) -> None: """Send response based on thread_only_mode setting.""" chunks = _split_message(content) if isinstance(message.channel, discord.Thread): for chunk in chunks: await message.channel.send(chunk) elif thread_only_mode: thread_name = f"OnyxBot <> {message.author.display_name}"[:100] thread = await message.create_thread(name=thread_name) for chunk in chunks: await thread.send(chunk) else: for i, chunk in enumerate(chunks): if i == 0: await message.reply(chunk) else: await message.channel.send(chunk) def _split_message(content: str) -> list[str]: """Split content into chunks that fit Discord's message limit.""" chunks = [] while content: if len(content) <= MAX_MESSAGE_LENGTH: chunks.append(content) break # Find a good split point split_at = MAX_MESSAGE_LENGTH for sep in ["\n\n", "\n", ". 
", " "]: idx = content.rfind(sep, 0, MAX_MESSAGE_LENGTH) if idx > MAX_MESSAGE_LENGTH // 2: split_at = idx + len(sep) break chunks.append(content[:split_at]) content = content[split_at:] return chunks async def send_error_response( message: discord.Message, bot_user: discord.ClientUser, ) -> None: """Send error response and clean up reaction.""" try: await message.remove_reaction(THINKING_EMOJI, bot_user) except discord.DiscordException: pass error_msg = "Sorry, I encountered an error processing your message. You may want to contact Onyx for support :sweat_smile:" try: if isinstance(message.channel, discord.Thread): await message.channel.send(error_msg) else: thread = await message.create_thread( name=f"Response to {message.author.display_name}"[:100] ) await thread.send(error_msg) except discord.DiscordException: pass ================================================ FILE: backend/onyx/onyxbot/discord/utils.py ================================================ from onyx.configs.app_configs import AUTH_TYPE from onyx.configs.app_configs import DISCORD_BOT_TOKEN from onyx.configs.constants import AuthType from onyx.db.discord_bot import get_discord_bot_config from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.utils.logger import setup_logger from onyx.utils.sensitive import SensitiveValue from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA logger = setup_logger() def get_bot_token() -> str | None: """Get Discord bot token from env var or database. Priority: 1. DISCORD_BOT_TOKEN env var (always takes precedence) 2. For self-hosted: DiscordBotConfig in database (default tenant) 3. For Cloud: should always have env var set Returns: Bot token string, or None if not configured. 
""" # Environment variable takes precedence if DISCORD_BOT_TOKEN: return DISCORD_BOT_TOKEN # Cloud should always have env var; if not, return None if AUTH_TYPE == AuthType.CLOUD: logger.warning("Cloud deployment missing DISCORD_BOT_TOKEN env var") return None # Self-hosted: check database for bot config try: with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db: config = get_discord_bot_config(db) except Exception as e: logger.error(f"Failed to get bot token from database: {e}") return None if config and config.bot_token: if isinstance(config.bot_token, SensitiveValue): return config.bot_token.get_value(apply_mask=False) return config.bot_token return None ================================================ FILE: backend/onyx/onyxbot/slack/blocks.py ================================================ from datetime import datetime from typing import cast import pytz import timeago # type: ignore from slack_sdk.models.blocks import ActionsBlock from slack_sdk.models.blocks import Block from slack_sdk.models.blocks import ButtonElement from slack_sdk.models.blocks import ContextBlock from slack_sdk.models.blocks import DividerBlock from slack_sdk.models.blocks import HeaderBlock from slack_sdk.models.blocks import Option from slack_sdk.models.blocks import RadioButtonsElement from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.blocks.basic_components import MarkdownTextObject from slack_sdk.models.blocks.block_elements import ImageElement from onyx.chat.models import ChatBasicResponse from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import DocumentSource from onyx.configs.constants import SearchFeedbackType from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_DOCS_TO_DISPLAY from onyx.context.search.models import SearchDoc from onyx.db.chat import get_chat_session_by_message_id from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import ChannelConfig from 
onyx.onyxbot.slack.constants import CONTINUE_IN_WEB_UI_ACTION_ID from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID from onyx.onyxbot.slack.formatting import format_slack_message from onyx.onyxbot.slack.icons import source_to_github_img_link from onyx.onyxbot.slack.models import ActionValuesEphemeralMessage from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageChannelConfig from onyx.onyxbot.slack.models import ActionValuesEphemeralMessageMessageInfo from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.utils import build_continue_in_web_ui_id from onyx.onyxbot.slack.utils import build_feedback_id from onyx.onyxbot.slack.utils import build_publish_ephemeral_message_id from onyx.onyxbot.slack.utils import remove_slack_text_interactions from onyx.onyxbot.slack.utils import translate_vespa_highlight_to_slack from onyx.utils.text_processing import decode_escapes _MAX_BLURB_LEN = 45 def _format_doc_updated_at(updated_at: datetime | None) -> str | None: """Convert document timestamps to a human friendly relative string.""" if updated_at is None: return None if updated_at.tzinfo is None or updated_at.tzinfo.utcoffset(updated_at) is None: aware_updated_at = updated_at.replace(tzinfo=pytz.utc) else: aware_updated_at = updated_at.astimezone(pytz.utc) return timeago.format(aware_updated_at, datetime.now(pytz.utc)) def get_feedback_reminder_blocks(thread_link: str, include_followup: bool) -> Block: text = ( f"Please provide feedback on 
<{thread_link}|this answer>. " "This is essential to help us to improve the quality of the answers. " "Please rate it by clicking the `Helpful` or `Not helpful` button. " ) if include_followup: text += "\n\nIf you need more help, click the `I need more help from a human!` button. " text += "\n\nThanks!" return SectionBlock(text=text) def _split_text(text: str, limit: int = 3000) -> list[str]: if len(text) <= limit: return [text] chunks = [] while text: if len(text) <= limit: chunks.append(text) break # Find the nearest space before the limit to avoid splitting a word split_at = text.rfind(" ", 0, limit) if split_at == -1: # No spaces found, force split split_at = limit chunk = text[:split_at] chunks.append(chunk) text = text[split_at:].lstrip() # Remove leading spaces from the next chunk return chunks def _clean_markdown_link_text(text: str) -> str: # Remove any newlines within the text return format_slack_message(text).replace("\n", " ").strip() def _build_qa_feedback_block( message_id: int, feedback_reminder_id: str | None = None ) -> Block: return ActionsBlock( block_id=build_feedback_id(message_id), elements=[ ButtonElement( action_id=LIKE_BLOCK_ACTION_ID, text="👍 Helpful", value=feedback_reminder_id, ), ButtonElement( action_id=DISLIKE_BLOCK_ACTION_ID, text="👎 Not helpful", value=feedback_reminder_id, ), ], ) def _build_ephemeral_publication_block( channel_id: str, # noqa: ARG001 chat_message_id: int, message_info: SlackMessageInfo, original_question_ts: str, channel_conf: ChannelConfig, feedback_reminder_id: str | None = None, ) -> Block: # check whether the message is in a thread if ( message_info is not None and message_info.msg_to_respond is not None and message_info.thread_to_respond is not None and (message_info.msg_to_respond == message_info.thread_to_respond) ): respond_ts = None else: respond_ts = original_question_ts action_values_ephemeral_message_channel_config = ( ActionValuesEphemeralMessageChannelConfig( 
channel_name=channel_conf.get("channel_name"), respond_tag_only=channel_conf.get("respond_tag_only"), respond_to_bots=channel_conf.get("respond_to_bots"), is_ephemeral=channel_conf.get("is_ephemeral", False), respond_member_group_list=channel_conf.get("respond_member_group_list"), answer_filters=channel_conf.get("answer_filters"), follow_up_tags=channel_conf.get("follow_up_tags"), show_continue_in_web_ui=channel_conf.get("show_continue_in_web_ui", False), ) ) action_values_ephemeral_message_message_info = ( ActionValuesEphemeralMessageMessageInfo( bypass_filters=message_info.bypass_filters, channel_to_respond=message_info.channel_to_respond, msg_to_respond=message_info.msg_to_respond, email=message_info.email, sender_id=message_info.sender_id, thread_messages=[], is_slash_command=message_info.is_slash_command, is_bot_dm=message_info.is_bot_dm, thread_to_respond=respond_ts, ) ) action_values_ephemeral_message = ActionValuesEphemeralMessage( original_question_ts=original_question_ts, feedback_reminder_id=feedback_reminder_id, chat_message_id=chat_message_id, message_info=action_values_ephemeral_message_message_info, channel_conf=action_values_ephemeral_message_channel_config, ) return ActionsBlock( block_id=build_publish_ephemeral_message_id(original_question_ts), elements=[ ButtonElement( action_id=SHOW_EVERYONE_ACTION_ID, text="📢 Share with Everyone", value=action_values_ephemeral_message.model_dump_json(), ), ButtonElement( action_id=KEEP_TO_YOURSELF_ACTION_ID, text="🤫 Keep to Yourself", value=action_values_ephemeral_message.model_dump_json(), ), ], ) def get_document_feedback_blocks() -> Block: return SectionBlock( text=( "- 'Up-Boost' if this document is a good source of information and should be " "shown more often.\n" "- 'Down-boost' if this document is a poor source of information and should be " "shown less often.\n" "- 'Hide' if this document is deprecated and should never be shown anymore." 
), accessory=RadioButtonsElement( options=[ Option( text=":thumbsup: Up-Boost", value=SearchFeedbackType.ENDORSE.value, ), Option( text=":thumbsdown: Down-Boost", value=SearchFeedbackType.REJECT.value, ), Option( text=":x: Hide", value=SearchFeedbackType.HIDE.value, ), ] ), ) def _build_doc_feedback_block( message_id: int, document_id: str, document_rank: int, ) -> ButtonElement: feedback_id = build_feedback_id(message_id, document_id, document_rank) return ButtonElement( action_id=FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID, value=feedback_id, text="Give Feedback", ) def get_restate_blocks( msg: str, is_slash_command: bool, ) -> list[Block]: # Only the slash command needs this context because the user doesn't see their own input if not is_slash_command: return [] return [ HeaderBlock(text="Responding to the Query"), SectionBlock(text=f"```{msg}```"), ] def _build_documents_blocks( documents: list[SearchDoc], message_id: int | None, num_docs_to_display: int = ONYX_BOT_NUM_DOCS_TO_DISPLAY, ) -> list[Block]: header_text = "Reference Documents" seen_docs_identifiers = set() section_blocks: list[Block] = [HeaderBlock(text=header_text)] included_docs = 0 for rank, d in enumerate(documents): if d.document_id in seen_docs_identifiers: continue seen_docs_identifiers.add(d.document_id) # Strip newlines from the semantic identifier for Slackbot formatting doc_sem_id = d.semantic_identifier.replace("\n", " ") if d.source_type == DocumentSource.SLACK.value: doc_sem_id = "#" + doc_sem_id used_chars = len(doc_sem_id) + 3 match_str = translate_vespa_highlight_to_slack(d.match_highlights, used_chars) included_docs += 1 header_line = f"{doc_sem_id}\n" if d.link: header_line = f"<{d.link}|{doc_sem_id}>\n" updated_at_line = "" updated_at_str = _format_doc_updated_at(d.updated_at) if updated_at_str: updated_at_line = f"_Updated {updated_at_str}_\n" body_text = f">{remove_slack_text_interactions(match_str)}" block_text = header_line + updated_at_line + body_text feedback: ButtonElement | dict 
= {} if message_id is not None: feedback = _build_doc_feedback_block( message_id=message_id, document_id=d.document_id, document_rank=rank, ) section_blocks.append( SectionBlock(text=block_text, accessory=feedback), ) section_blocks.append(DividerBlock()) if included_docs >= num_docs_to_display: break return section_blocks def _build_sources_blocks( cited_documents: list[tuple[int, SearchDoc]], num_docs_to_display: int = ONYX_BOT_NUM_DOCS_TO_DISPLAY, ) -> list[Block]: if not cited_documents: return [ SectionBlock( text="*Warning*: no sources were cited for this answer, so it may be unreliable 😔" ) ] seen_docs_identifiers = set() section_blocks: list[Block] = [SectionBlock(text="*Sources:*")] included_docs = 0 for citation_num, d in cited_documents: if d.document_id in seen_docs_identifiers: continue seen_docs_identifiers.add(d.document_id) doc_sem_id = d.semantic_identifier if d.source_type == DocumentSource.SLACK.value: # for legacy reasons, before the switch to how Slack semantic identifiers are constructed if "#" not in doc_sem_id: doc_sem_id = "#" + doc_sem_id # this is needed to try and prevent the line from overflowing # if it does overflow, the image gets placed above the title and it # looks bad doc_sem_id = ( doc_sem_id[:_MAX_BLURB_LEN] + "..." 
if len(doc_sem_id) > _MAX_BLURB_LEN else doc_sem_id ) owner_str = f"By {d.primary_owners[0]}" if d.primary_owners else None days_ago_str = _format_doc_updated_at(d.updated_at) final_metadata_str = " | ".join( ([owner_str] if owner_str else []) + ([days_ago_str] if days_ago_str else []) ) document_title = _clean_markdown_link_text(doc_sem_id) img_link = source_to_github_img_link(d.source_type) section_blocks.append( ContextBlock( elements=( [ ImageElement( image_url=img_link, alt_text=f"{d.source_type.value} logo", ) ] if img_link else [] ) + [ ( MarkdownTextObject(text=f"{document_title}") if not d.link else MarkdownTextObject( text=f"*<{d.link}|[{citation_num}] {document_title}>*\n{final_metadata_str}" ) ), ] ) ) included_docs += 1 if included_docs >= num_docs_to_display: break return section_blocks def _priority_ordered_documents_blocks( answer: ChatBasicResponse, ) -> list[Block]: top_docs = answer.top_documents if answer.top_documents else None if not top_docs: return [] document_blocks = _build_documents_blocks( documents=top_docs, message_id=answer.message_id, ) if document_blocks: document_blocks = [DividerBlock()] + document_blocks return document_blocks def _build_citations_blocks( answer: ChatBasicResponse, ) -> list[Block]: top_docs = answer.top_documents or [] citations = answer.citation_info or [] cited_docs: list[tuple[int, SearchDoc]] = [] for citation_info in citations: matching_doc = next( (d for d in top_docs if d.document_id == citation_info.document_id), None, ) if matching_doc: cited_docs.append((citation_info.citation_number, matching_doc)) cited_docs.sort(key=lambda x: x[0]) citations_block = _build_sources_blocks(cited_documents=cited_docs) return citations_block def _build_main_response_blocks( answer: ChatBasicResponse, ) -> list[Block]: # TODO: add back in later when auto-filtering is implemented # if ( # retrieval_info.applied_time_cutoff # or retrieval_info.recency_bias_multiplier > 1 # or 
retrieval_info.applied_source_filters # ): # filter_text = "Filters: " # if
retrieval_info.applied_source_filters: # sources_str = ", ".join( # [s.value for s in retrieval_info.applied_source_filters] # ) # filter_text += f"`Sources in [{sources_str}]`" # if ( # retrieval_info.applied_time_cutoff # or retrieval_info.recency_bias_multiplier > 1 # ): # filter_text += " and " # if retrieval_info.applied_time_cutoff is not None: # time_str = retrieval_info.applied_time_cutoff.strftime("%b %d, %Y") # filter_text += f"`Docs Updated >= {time_str}` " # if retrieval_info.recency_bias_multiplier > 1: # if retrieval_info.applied_time_cutoff is not None: # filter_text += "+ " # filter_text += "`Prioritize Recently Updated Docs`" # filter_block = SectionBlock(text=f"_{filter_text}_") # replaces markdown links with slack format links formatted_answer = format_slack_message(answer.answer) answer_processed = decode_escapes(remove_slack_text_interactions(formatted_answer)) answer_blocks = [SectionBlock(text=text) for text in _split_text(answer_processed)] return cast(list[Block], answer_blocks) def _build_continue_in_web_ui_block( message_id: int | None, ) -> Block: if message_id is None: raise ValueError("No message id provided to build continue in web ui block") with get_session_with_current_tenant() as db_session: chat_session = get_chat_session_by_message_id( db_session=db_session, message_id=message_id, ) return ActionsBlock( block_id=build_continue_in_web_ui_id(message_id), elements=[ ButtonElement( action_id=CONTINUE_IN_WEB_UI_ACTION_ID, text="Continue Chat in Onyx!", style="primary", url=f"{WEB_DOMAIN}/chat?slackChatId={chat_session.id}", ), ], ) def _build_follow_up_block(message_id: int | None) -> ActionsBlock: return ActionsBlock( block_id=build_feedback_id(message_id) if message_id is not None else None, elements=[ ButtonElement( action_id=IMMEDIATE_RESOLVED_BUTTON_ACTION_ID, style="primary", text="I'm all set!", ), ButtonElement( action_id=FOLLOWUP_BUTTON_ACTION_ID, style="danger", text="I need more help from a human!", ), ], ) def 
build_follow_up_resolved_blocks( tag_ids: list[str], group_ids: list[str] ) -> list[Block]: tag_str = " ".join([f"<@{tag}>" for tag in tag_ids]) if tag_str: tag_str += " " group_str = " ".join([f"" for group_id in group_ids]) if group_str: group_str += " " text = ( tag_str + group_str + "Someone has requested more help.\n\n:point_down:Please mark this resolved after answering!" ) text_block = SectionBlock(text=text) button_block = ActionsBlock( elements=[ ButtonElement( action_id=FOLLOWUP_BUTTON_RESOLVED_ACTION_ID, style="primary", text="Mark Resolved", ) ] ) return [text_block, button_block] def build_slack_response_blocks( answer: ChatBasicResponse, message_info: SlackMessageInfo, channel_conf: ChannelConfig | None, feedback_reminder_id: str | None, skip_ai_feedback: bool = False, offer_ephemeral_publication: bool = False, skip_restated_question: bool = False, ) -> list[Block]: """ This function is a top level function that builds all the blocks for the Slack response. It also handles combining all the blocks together. 
""" # If called with the OnyxBot slash command, the question is lost so we have to reshow it if not skip_restated_question: restate_question_block = get_restate_blocks( message_info.thread_messages[-1].message, message_info.is_slash_command ) else: restate_question_block = [] answer_blocks = _build_main_response_blocks(answer) web_follow_up_block = [] if channel_conf and channel_conf.get("show_continue_in_web_ui"): web_follow_up_block.append( _build_continue_in_web_ui_block( message_id=answer.message_id, ) ) follow_up_block = [] if ( channel_conf and channel_conf.get("follow_up_tags") is not None and not channel_conf.get("is_ephemeral", False) ): follow_up_block.append(_build_follow_up_block(message_id=answer.message_id)) publish_ephemeral_message_block = [] if ( offer_ephemeral_publication and answer.message_id is not None and message_info.msg_to_respond is not None and channel_conf is not None ): publish_ephemeral_message_block.append( _build_ephemeral_publication_block( channel_id=message_info.channel_to_respond, chat_message_id=answer.message_id, original_question_ts=message_info.msg_to_respond, message_info=message_info, channel_conf=channel_conf, feedback_reminder_id=feedback_reminder_id, ) ) ai_feedback_block: list[Block] = [] if answer.message_id is not None and not skip_ai_feedback: ai_feedback_block.append( _build_qa_feedback_block( message_id=answer.message_id, feedback_reminder_id=feedback_reminder_id, ) ) citations_blocks = [] if answer.citation_info: citations_blocks = _build_citations_blocks(answer) citations_divider = [DividerBlock()] if citations_blocks else [] buttons_divider = [DividerBlock()] if web_follow_up_block or follow_up_block else [] all_blocks = ( restate_question_block + answer_blocks + publish_ephemeral_message_block + ai_feedback_block + citations_divider + citations_blocks + buttons_divider + web_follow_up_block + follow_up_block ) return all_blocks ================================================ FILE: 
backend/onyx/onyxbot/slack/config.py ================================================ import os from sqlalchemy.orm import Session from onyx.db.models import SlackChannelConfig from onyx.db.slack_channel_config import ( fetch_slack_channel_config_for_channel_or_default, ) from onyx.db.slack_channel_config import fetch_slack_channel_configs VALID_SLACK_FILTERS = [ "answerable_prefilter", "well_answered_postfilter", "questionmark_prefilter", ] def get_slack_channel_config_for_bot_and_channel( db_session: Session, slack_bot_id: int, channel_name: str | None, ) -> SlackChannelConfig: slack_bot_config = fetch_slack_channel_config_for_channel_or_default( db_session=db_session, slack_bot_id=slack_bot_id, channel_name=channel_name ) if not slack_bot_config: raise ValueError( "No default configuration has been set for this Slack bot. This should not be possible." ) return slack_bot_config def validate_channel_name( db_session: Session, current_slack_bot_id: int, channel_name: str, current_slack_channel_config_id: int | None, ) -> str: """Make sure that this channel_name does not exist in other Slack channel configs. Returns a cleaned up channel name (e.g. 
'#' removed if present)""" slack_bot_configs = fetch_slack_channel_configs( db_session=db_session, slack_bot_id=current_slack_bot_id, ) cleaned_channel_name = channel_name.lstrip("#").lower() for slack_channel_config in slack_bot_configs: if slack_channel_config.id == current_slack_channel_config_id: continue if cleaned_channel_name == slack_channel_config.channel_config["channel_name"]: raise ValueError( f"Channel name '{channel_name}' already exists in " "another Slack channel config within Slack Bot with name: " f"{slack_channel_config.slack_bot.name}" ) return cleaned_channel_name # Scaling configurations for multi-tenant Slack channel handling TENANT_LOCK_EXPIRATION = 1800 # How long a pod can hold exclusive access to a tenant before other pods can acquire it TENANT_HEARTBEAT_INTERVAL = ( 15 # How often pods send heartbeats to indicate they are still processing a tenant ) TENANT_HEARTBEAT_EXPIRATION = ( 60 # How long before a tenant's heartbeat expires, allowing other pods to take over ) TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and check for new tokens MAX_TENANTS_PER_POD = int(os.getenv("MAX_TENANTS_PER_POD", 50)) ================================================ FILE: backend/onyx/onyxbot/slack/constants.py ================================================ import re from enum import Enum # Matches Slack channel references like <#C097NBWMY8Y> or <#C097NBWMY8Y|channel-name> SLACK_CHANNEL_REF_PATTERN = re.compile(r"<#([A-Z0-9]+)(?:\|([^>]+))?>") LIKE_BLOCK_ACTION_ID = "feedback-like" DISLIKE_BLOCK_ACTION_ID = "feedback-dislike" SHOW_EVERYONE_ACTION_ID = "show-everyone" KEEP_TO_YOURSELF_ACTION_ID = "keep-to-yourself" CONTINUE_IN_WEB_UI_ACTION_ID = "continue-in-web-ui" FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID = "feedback-doc-button" IMMEDIATE_RESOLVED_BUTTON_ACTION_ID = "immediate-resolved-button" FOLLOWUP_BUTTON_ACTION_ID = "followup-button" FOLLOWUP_BUTTON_RESOLVED_ACTION_ID = "followup-resolved-button"
VIEW_DOC_FEEDBACK_ID = "view-doc-feedback" GENERATE_ANSWER_BUTTON_ACTION_ID = "generate-answer-button" class FeedbackVisibility(str, Enum): PRIVATE = "private" ANONYMOUS = "anonymous" PUBLIC = "public" ================================================ FILE: backend/onyx/onyxbot/slack/formatting.py ================================================ import re from collections.abc import Callable from typing import Any from mistune import create_markdown from mistune import HTMLRenderer # Tags that should be replaced with a newline (line-break and block-level elements) _HTML_NEWLINE_TAG_PATTERN = re.compile( r"|", re.IGNORECASE, ) # Strips HTML tags but excludes autolinks like <https://...> and <mailto:...> _HTML_TAG_PATTERN = re.compile( r"<(?!https?://|mailto:)/?[a-zA-Z][^>]*>", ) # Matches fenced code blocks (``` ... ```) so we can skip sanitization inside them _FENCED_CODE_BLOCK_PATTERN = re.compile(r"```[\s\S]*?```") # Matches the start of any markdown link: [text]( or [[n]]( # The inner group handles nested brackets for citation links like [[1]](. _MARKDOWN_LINK_PATTERN = re.compile(r"\[(?:[^\[\]]|\[[^\]]*\])*\]\(") # Matches Slack-style <url|text> links that LLMs sometimes output directly. # Mistune doesn't recognise this syntax, so text() would escape the angle # brackets and Slack would render them as literal text instead of links. _SLACK_LINK_PATTERN = re.compile(r"<(https?://[^|>]+)\|([^>]+)>") def _sanitize_html(text: str) -> str: """Strip HTML tags from a text fragment. Block-level closing tags and <br>
    are converted to newlines. All other HTML tags are removed. Autolinks (<https://...>, <mailto:...>) are preserved. """ text = _HTML_NEWLINE_TAG_PATTERN.sub("\n", text) text = _HTML_TAG_PATTERN.sub("", text) return text def _transform_outside_code_blocks( message: str, transform: Callable[[str], str] ) -> str: """Apply *transform* only to text outside fenced code blocks.""" parts = _FENCED_CODE_BLOCK_PATTERN.split(message) code_blocks = _FENCED_CODE_BLOCK_PATTERN.findall(message) result: list[str] = [] for i, part in enumerate(parts): result.append(transform(part)) if i < len(code_blocks): result.append(code_blocks[i]) return "".join(result) def _extract_link_destination(message: str, start_idx: int) -> tuple[str, int | None]: """Extract markdown link destination, allowing nested parentheses in the URL.""" depth = 0 i = start_idx while i < len(message): curr = message[i] if curr == "\\": i += 2 continue if curr == "(": depth += 1 elif curr == ")": if depth == 0: return message[start_idx:i], i depth -= 1 i += 1 return message[start_idx:], None def _normalize_link_destinations(message: str) -> str: """Wrap markdown link URLs in angle brackets so the parser handles special chars safely. Markdown link syntax [text](url) breaks when the URL contains unescaped parentheses, spaces, or other special characters. Wrapping the URL in angle brackets, as in [text](<url>), tells the parser to treat everything inside as a literal URL. This applies to all links, not just citations.
""" if "](" not in message: return message normalized_parts: list[str] = [] cursor = 0 while match := _MARKDOWN_LINK_PATTERN.search(message, cursor): normalized_parts.append(message[cursor : match.end()]) destination_start = match.end() destination, end_idx = _extract_link_destination(message, destination_start) if end_idx is None: normalized_parts.append(message[destination_start:]) return "".join(normalized_parts) already_wrapped = destination.startswith("<") and destination.endswith(">") if destination and not already_wrapped: destination = f"<{destination}>" normalized_parts.append(destination) normalized_parts.append(")") cursor = end_idx + 1 normalized_parts.append(message[cursor:]) return "".join(normalized_parts) def _convert_slack_links_to_markdown(message: str) -> str: """Convert Slack-style links to standard markdown [text](url). LLMs sometimes emit Slack mrkdwn link syntax directly. Mistune doesn't recognise it, so the angle brackets would be escaped by text() and Slack would render the link as literal text instead of a clickable link. """ return _transform_outside_code_blocks( message, lambda text: _SLACK_LINK_PATTERN.sub(r"[\2](\1)", text) ) def format_slack_message(message: str | None) -> str: if message is None: return "" message = _transform_outside_code_blocks(message, _sanitize_html) message = _convert_slack_links_to_markdown(message) normalized_message = _normalize_link_destinations(message) md = create_markdown(renderer=SlackRenderer(), plugins=["strikethrough", "table"]) result = md(normalized_message) # With HTMLRenderer, result is always str (not AST list) assert isinstance(result, str) return result.rstrip("\n") class SlackRenderer(HTMLRenderer): """Renders markdown as Slack mrkdwn format instead of HTML. Overrides all HTMLRenderer methods that produce HTML tags to ensure no raw HTML ever appears in Slack messages. 
""" SPECIALS: dict[str, str] = {"&": "&amp;", "<": "&lt;", ">": "&gt;"} def __init__(self) -> None: super().__init__() self._table_headers: list[str] = [] self._current_row_cells: list[str] = [] def escape_special(self, text: str) -> str: for special, replacement in self.SPECIALS.items(): text = text.replace(special, replacement) return text def heading(self, text: str, level: int, **attrs: Any) -> str: # noqa: ARG002 return f"*{text}*\n\n" def emphasis(self, text: str) -> str: return f"_{text}_" def strong(self, text: str) -> str: return f"*{text}*" def strikethrough(self, text: str) -> str: return f"~{text}~" def list(self, text: str, ordered: bool, **attrs: Any) -> str: # noqa: ARG002 lines = text.split("\n") count = 0 for i, line in enumerate(lines): if line.startswith("li: "): count += 1 prefix = f"{count}. " if ordered else "• " lines[i] = f"{prefix}{line[4:]}" return "\n".join(lines) + "\n" def list_item(self, text: str) -> str: return f"li: {text}\n" def link(self, text: str, url: str, title: str | None = None) -> str: escaped_url = self.escape_special(url) if text: return f"<{escaped_url}|{text}>" if title: return f"<{escaped_url}|{title}>" return f"<{escaped_url}>" def image(self, text: str, url: str, title: str | None = None) -> str: escaped_url = self.escape_special(url) display_text = title or text return f"<{escaped_url}|{display_text}>" if display_text else f"<{escaped_url}>" def codespan(self, text: str) -> str: return f"`{text}`" def block_code(self, code: str, info: str | None = None) -> str: # noqa: ARG002 return f"```\n{code.rstrip(chr(10))}\n```\n\n" def linebreak(self) -> str: return "\n" def thematic_break(self) -> str: return "---\n\n" def block_quote(self, text: str) -> str: lines = text.strip().split("\n") quoted = "\n".join(f">{line}" for line in lines) return quoted + "\n\n" def block_html(self, html: str) -> str: return _sanitize_html(html) + "\n\n" def block_error(self, text: str) -> str: return f"```\n{text}\n```\n\n" def text(self, text: str)
-> str: # Only escape the three entities Slack recognizes: & < > # HTMLRenderer.text() also escapes " to &quot; which Slack renders # as literal &quot; text since Slack doesn't recognize that entity. return self.escape_special(text) # -- Table rendering (converts markdown tables to vertical cards) -- def table_cell( self, text: str, align: str | None = None, # noqa: ARG002 head: bool = False, # noqa: ARG002 ) -> str: if head: self._table_headers.append(text.strip()) else: self._current_row_cells.append(text.strip()) return "" def table_head(self, text: str) -> str: # noqa: ARG002 self._current_row_cells = [] return "" def table_row(self, text: str) -> str: # noqa: ARG002 cells = self._current_row_cells self._current_row_cells = [] # First column becomes the bold title, remaining columns are bulleted fields lines: list[str] = [] if cells: title = cells[0] if title: # Avoid double-wrapping if cell already contains bold markup if title.startswith("*") and title.endswith("*") and len(title) > 1: lines.append(title) else: lines.append(f"*{title}*") for i, cell in enumerate(cells[1:], start=1): if i < len(self._table_headers): lines.append(f" • {self._table_headers[i]}: {cell}") else: lines.append(f" • {cell}") return "\n".join(lines) + "\n\n" def table_body(self, text: str) -> str: return text def table(self, text: str) -> str: self._table_headers = [] self._current_row_cells = [] return text + "\n" def paragraph(self, text: str) -> str: return f"{text}\n\n" ================================================ FILE: backend/onyx/onyxbot/slack/handlers/__init__.py ================================================ ================================================ FILE: backend/onyx/onyxbot/slack/handlers/handle_buttons.py ================================================ import json from typing import Any from typing import cast from slack_sdk import WebClient from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.views import View from slack_sdk.socket_mode.request import
SocketModeRequest from slack_sdk.webhook import WebhookClient from onyx.chat.models import ChatBasicResponse from onyx.chat.process_message import remove_answer_citations from onyx.configs.constants import MessageType from onyx.configs.constants import SearchFeedbackType from onyx.configs.onyxbot_configs import ONYX_BOT_FOLLOWUP_EMOJI from onyx.connectors.slack.utils import expert_info_from_slack_id from onyx.context.search.models import SavedSearchDoc from onyx.context.search.models import SearchDoc from onyx.db.chat import get_chat_message from onyx.db.chat import translate_db_message_to_chat_message_detail from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.feedback import create_chat_message_feedback from onyx.db.feedback import create_doc_retrieval_feedback from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.blocks import build_follow_up_resolved_blocks from onyx.onyxbot.slack.blocks import build_slack_response_blocks from onyx.onyxbot.slack.blocks import get_document_feedback_blocks from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FeedbackVisibility from onyx.onyxbot.slack.constants import KEEP_TO_YOURSELF_ACTION_ID from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID from onyx.onyxbot.slack.handlers.handle_message import ( remove_scheduled_feedback_reminder, ) from onyx.onyxbot.slack.handlers.handle_regular_answer import ( handle_regular_answer, ) from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.utils import build_feedback_id from onyx.onyxbot.slack.utils import decompose_action_id from onyx.onyxbot.slack.utils import fetch_group_ids_from_names from onyx.onyxbot.slack.utils import 
fetch_slack_user_ids_from_emails from onyx.onyxbot.slack.utils import get_channel_name_from_id from onyx.onyxbot.slack.utils import get_feedback_visibility from onyx.onyxbot.slack.utils import read_slack_thread from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import TenantSocketModeClient from onyx.onyxbot.slack.utils import update_emote_react from onyx.server.query_and_chat.models import ChatMessageDetail from onyx.server.query_and_chat.streaming_models import CitationInfo from onyx.utils.logger import setup_logger logger = setup_logger() def _convert_document_ids_to_citation_info( citation_dict: dict[int, str], top_documents: list[SavedSearchDoc] ) -> list[CitationInfo]: citation_list_with_document_id = [] # Build a set of valid document_ids from top_documents for validation valid_document_ids = {doc.document_id for doc in top_documents} for citation_num, document_id in citation_dict.items(): if document_id is not None and document_id in valid_document_ids: citation_list_with_document_id.append( CitationInfo( citation_number=citation_num, document_id=document_id, ) ) return citation_list_with_document_id def _build_citation_list(chat_message_detail: ChatMessageDetail) -> list[CitationInfo]: citation_dict = chat_message_detail.citations if citation_dict is None: return [] else: top_documents = ( chat_message_detail.context_docs if chat_message_detail.context_docs else [] ) citation_list = _convert_document_ids_to_citation_info( citation_dict, top_documents ) return citation_list def handle_doc_feedback_button( req: SocketModeRequest, client: TenantSocketModeClient, ) -> None: if not (actions := req.payload.get("actions")): logger.error("Missing actions. 
Unable to build the source feedback view") return # Extracts the feedback_id coming from the 'source feedback' button # and generates a new one for the View, to keep track of the doc info query_event_id, doc_id, doc_rank = decompose_action_id(actions[0].get("value")) external_id = build_feedback_id(query_event_id, doc_id, doc_rank) channel_id = req.payload["container"]["channel_id"] thread_ts = req.payload["container"].get("thread_ts", None) data = View( type="modal", callback_id=VIEW_DOC_FEEDBACK_ID, external_id=external_id, # We use the private metadata to keep track of the channel id and thread ts private_metadata=f"{channel_id}_{thread_ts}", title="Give Feedback", blocks=[get_document_feedback_blocks()], submit="send", close="cancel", ) client.web_client.views_open( trigger_id=req.payload["trigger_id"], view=data.to_dict() ) def handle_generate_answer_button( req: SocketModeRequest, client: TenantSocketModeClient, ) -> None: channel_id = req.payload["channel"]["id"] channel_name = req.payload["channel"]["name"] message_ts = req.payload["message"]["ts"] thread_ts = req.payload["container"].get("thread_ts", None) user_id = req.payload["user"]["id"] expert_info = expert_info_from_slack_id(user_id, client.web_client, user_cache={}) email = expert_info.email if expert_info else None if not thread_ts: raise ValueError("Missing thread_ts in the payload") thread_messages = read_slack_thread( tenant_id=client._tenant_id, channel=channel_id, thread=thread_ts, client=client.web_client, ) # remove all assistant messages till we get to the last user message # we want the new answer to be generated off of the last "question" in # the thread for i in range(len(thread_messages) - 1, -1, -1): if thread_messages[i].role == MessageType.USER: break if thread_messages[i].role == MessageType.ASSISTANT: thread_messages.pop(i) # tell the user that we're working on it # Send an ephemeral message to the user that we're generating the answer respond_in_thread_or_channel( 
client=client.web_client, channel=channel_id, receiver_ids=[user_id], text="I'm working on generating a full answer for you. This may take a moment...", thread_ts=thread_ts, ) with get_session_with_current_tenant() as db_session: slack_channel_config = get_slack_channel_config_for_bot_and_channel( db_session=db_session, slack_bot_id=client.slack_bot_id, channel_name=channel_name, ) handle_regular_answer( message_info=SlackMessageInfo( thread_messages=thread_messages, channel_to_respond=channel_id, msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), sender_id=user_id or None, email=email or None, bypass_filters=True, is_slash_command=False, is_bot_dm=False, ), slack_channel_config=slack_channel_config, receiver_ids=None, client=client.web_client, channel=channel_id, logger=logger, feedback_reminder_id=None, ) def handle_publish_ephemeral_message_button( req: SocketModeRequest, client: TenantSocketModeClient, action_id: str, ) -> None: """ This function handles the Share with Everyone/Keep for Yourself buttons for ephemeral messages. """ channel_id = req.payload["channel"]["id"] ephemeral_message_ts = req.payload["container"]["message_ts"] slack_sender_id = req.payload["user"]["id"] response_url = req.payload["response_url"] webhook = WebhookClient(url=response_url) # The additional data required that was added to buttons. # Specifically, this contains the message_info, channel_conf information # and some additional attributes. 
value_dict = json.loads(req.payload["actions"][0]["value"]) original_question_ts = value_dict.get("original_question_ts") if not original_question_ts: raise ValueError("Missing original_question_ts in the payload") if not ephemeral_message_ts: raise ValueError("Missing ephemeral_message_ts in the payload") feedback_reminder_id = value_dict.get("feedback_reminder_id") slack_message_info = SlackMessageInfo(**value_dict["message_info"]) channel_conf = value_dict.get("channel_conf") user_email = value_dict.get("message_info", {}).get("email") chat_message_id = value_dict.get("chat_message_id") # Obtain onyx_user and chat_message information if not chat_message_id: raise ValueError("Missing chat_message_id in the payload") with get_session_with_current_tenant() as db_session: onyx_user = get_user_by_email(user_email, db_session) if not onyx_user: raise ValueError("Cannot determine onyx_user_id from email in payload") try: chat_message = get_chat_message(chat_message_id, onyx_user.id, db_session) except ValueError: chat_message = get_chat_message( chat_message_id, None, db_session ) # is this good idea? except Exception as e: logger.error(f"Failed to get chat message: {e}") raise e chat_message_detail = translate_db_message_to_chat_message_detail(chat_message) # construct the proper citation format and then the answer in the suitable format # we need to construct the blocks. 
citation_list = _build_citation_list(chat_message_detail) if chat_message_detail.context_docs: top_documents: list[SearchDoc] = [ SearchDoc.from_saved_search_doc(doc) for doc in chat_message_detail.context_docs ] else: top_documents = [] onyx_bot_answer = ChatBasicResponse( answer=chat_message_detail.message, answer_citationless=remove_answer_citations(chat_message_detail.message), top_documents=top_documents, message_id=chat_message_id, error_msg=None, citation_info=citation_list, ) # Note: we need to use the webhook and the response_url to update/delete ephemeral messages if action_id == SHOW_EVERYONE_ACTION_ID: # Convert to non-ephemeral message in thread try: webhook.send( response_type="ephemeral", text="", blocks=[], replace_original=True, delete_original=True, ) except Exception as e: logger.error(f"Failed to send webhook: {e}") # remove handling of ephemeral block and add AI feedback. all_blocks = build_slack_response_blocks( answer=onyx_bot_answer, message_info=slack_message_info, channel_conf=channel_conf, feedback_reminder_id=feedback_reminder_id, skip_ai_feedback=False, offer_ephemeral_publication=False, skip_restated_question=True, ) try: # Post in thread as non-ephemeral message respond_in_thread_or_channel( client=client.web_client, channel=channel_id, receiver_ids=None, # If respond_member_group_list is set, send to them. TODO: check! text="Hello!
Onyx has some results for you!", blocks=all_blocks, thread_ts=original_question_ts, # don't unfurl, since otherwise we will have 5+ previews which makes the message very long unfurl=False, send_as_ephemeral=False, ) except Exception as e: logger.error(f"Failed to publish ephemeral message: {e}") raise e elif action_id == KEEP_TO_YOURSELF_ACTION_ID: # Keep as ephemeral message in channel or thread, but remove the publish button and add feedback button changed_blocks = build_slack_response_blocks( answer=onyx_bot_answer, message_info=slack_message_info, channel_conf=channel_conf, feedback_reminder_id=feedback_reminder_id, skip_ai_feedback=False, offer_ephemeral_publication=False, skip_restated_question=True, ) try: if slack_message_info.thread_to_respond is not None: # There seems to be a bug in Slack where an update within the thread # actually causes the update to be posted in the channel. Therefore, # for now we delete the original ephemeral message and post a new one # if the ephemeral message is in a thread.
webhook.send( response_type="ephemeral", text="", blocks=[], replace_original=True, delete_original=True, ) respond_in_thread_or_channel( client=client.web_client, channel=channel_id, receiver_ids=[slack_sender_id], text="Your personal response, sent as an ephemeral message.", blocks=changed_blocks, thread_ts=original_question_ts, # don't unfurl, since otherwise we will have 5+ previews which makes the message very long unfurl=False, send_as_ephemeral=True, ) else: # This works fine if the ephemeral message is in the channel webhook.send( response_type="ephemeral", text="Your personal response, sent as an ephemeral message.", blocks=changed_blocks, replace_original=True, delete_original=False, ) except Exception as e: logger.error(f"Failed to send webhook: {e}") def handle_slack_feedback( feedback_id: str, feedback_type: str, feedback_msg_reminder: str, client: WebClient, user_id_to_post_confirmation: str, channel_id_to_post_confirmation: str, thread_ts_to_post_confirmation: str, ) -> None: message_id, doc_id, doc_rank = decompose_action_id(feedback_id) # Get Onyx user from Slack ID expert_info = expert_info_from_slack_id( user_id_to_post_confirmation, client, user_cache={} ) email = expert_info.email if expert_info else None with get_session_with_current_tenant() as db_session: onyx_user = get_user_by_email(email, db_session) if email else None if feedback_type in [LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID]: create_chat_message_feedback( is_positive=feedback_type == LIKE_BLOCK_ACTION_ID, feedback_text="", chat_message_id=message_id, user_id=onyx_user.id if onyx_user else None, db_session=db_session, ) remove_scheduled_feedback_reminder( client=client, channel=user_id_to_post_confirmation, msg_id=feedback_msg_reminder, ) elif feedback_type in [ SearchFeedbackType.ENDORSE.value, SearchFeedbackType.REJECT.value, SearchFeedbackType.HIDE.value, ]: if doc_id is None or doc_rank is None: raise ValueError("Missing information for Document Feedback") if feedback_type 
== SearchFeedbackType.ENDORSE.value: feedback = SearchFeedbackType.ENDORSE elif feedback_type == SearchFeedbackType.REJECT.value: feedback = SearchFeedbackType.REJECT else: feedback = SearchFeedbackType.HIDE create_doc_retrieval_feedback( message_id=message_id, document_id=doc_id, document_rank=doc_rank, db_session=db_session, clicked=False, # Not tracking this for Slack feedback=feedback, ) else: logger.error(f"Feedback type '{feedback_type}' not supported") if get_feedback_visibility() == FeedbackVisibility.PRIVATE or feedback_type not in [ LIKE_BLOCK_ACTION_ID, DISLIKE_BLOCK_ACTION_ID, ]: client.chat_postEphemeral( channel=channel_id_to_post_confirmation, user=user_id_to_post_confirmation, thread_ts=thread_ts_to_post_confirmation, text="Thanks for your feedback!", ) else: feedback_response_txt = ( "liked" if feedback_type == LIKE_BLOCK_ACTION_ID else "disliked" ) if get_feedback_visibility() == FeedbackVisibility.ANONYMOUS: msg = f"A user has {feedback_response_txt} the AI Answer" else: msg = f"<@{user_id_to_post_confirmation}> has {feedback_response_txt} the AI Answer" respond_in_thread_or_channel( client=client, channel=channel_id_to_post_confirmation, text=msg, thread_ts=thread_ts_to_post_confirmation, unfurl=False, ) def handle_followup_button( req: SocketModeRequest, client: TenantSocketModeClient, ) -> None: action_id = None if actions := req.payload.get("actions"): action = cast(dict[str, Any], actions[0]) action_id = cast(str, action.get("block_id")) channel_id = req.payload["container"]["channel_id"] thread_ts = req.payload["container"].get("thread_ts", None) update_emote_react( emoji=ONYX_BOT_FOLLOWUP_EMOJI, channel=channel_id, message_ts=thread_ts, remove=False, client=client.web_client, ) tag_ids: list[str] = [] group_ids: list[str] = [] with get_session_with_current_tenant() as db_session: channel_name, is_dm = get_channel_name_from_id( client=client.web_client, channel_id=channel_id ) slack_channel_config = 
get_slack_channel_config_for_bot_and_channel( db_session=db_session, slack_bot_id=client.slack_bot_id, channel_name=channel_name, ) if slack_channel_config: tag_names = slack_channel_config.channel_config.get("follow_up_tags") remaining = None if tag_names: tag_ids, remaining = fetch_slack_user_ids_from_emails( tag_names, client.web_client ) if remaining: group_ids, _ = fetch_group_ids_from_names(remaining, client.web_client) blocks = build_follow_up_resolved_blocks(tag_ids=tag_ids, group_ids=group_ids) respond_in_thread_or_channel( client=client.web_client, channel=channel_id, text="Received your request for more help", blocks=blocks, thread_ts=thread_ts, unfurl=False, ) if action_id is not None: message_id, _, _ = decompose_action_id(action_id) create_chat_message_feedback( is_positive=None, feedback_text="", chat_message_id=message_id, user_id=None, # no "user" for Slack bot for now db_session=db_session, required_followup=True, ) def get_clicker_name( req: SocketModeRequest, client: TenantSocketModeClient, ) -> str: clicker_name = req.payload.get("user", {}).get("name", "Someone") clicker_real_name = None try: clicker = client.web_client.users_info(user=req.payload["user"]["id"]) clicker_real_name = ( cast(dict, clicker.data).get("user", {}).get("profile", {}).get("real_name") ) except Exception: # Likely a scope issue pass if clicker_real_name: clicker_name = clicker_real_name return clicker_name def handle_followup_resolved_button( req: SocketModeRequest, client: TenantSocketModeClient, immediate: bool = False, ) -> None: channel_id = req.payload["container"]["channel_id"] message_ts = req.payload["container"]["message_ts"] thread_ts = req.payload["container"].get("thread_ts", None) clicker_name = get_clicker_name(req, client) update_emote_react( emoji=ONYX_BOT_FOLLOWUP_EMOJI, channel=channel_id, message_ts=thread_ts, remove=True, client=client.web_client, ) # Delete the message with the option to mark resolved if not immediate: response = 
client.web_client.chat_delete( channel=channel_id, ts=message_ts, ) if not response.get("ok"): logger.error("Unable to delete message for resolved") if immediate: msg_text = f"{clicker_name} has marked this question as resolved!" else: msg_text = ( f"{clicker_name} has marked this question as resolved! " f'\n\n You can always click the "I need more help" button to let the team ' f"know that your problem still needs attention." ) resolved_block = SectionBlock(text=msg_text) respond_in_thread_or_channel( client=client.web_client, channel=channel_id, text="Your request for help has been addressed!", blocks=[resolved_block], thread_ts=thread_ts, unfurl=False, ) ================================================ FILE: backend/onyx/onyxbot/slack/handlers/handle_message.py ================================================ import datetime from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import AccountType from onyx.db.models import SlackChannelConfig from onyx.db.user_preferences import activate_user from onyx.db.users import add_slack_user_if_not_exists from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.blocks import get_feedback_reminder_blocks from onyx.onyxbot.slack.handlers.handle_regular_answer import ( handle_regular_answer, ) from onyx.onyxbot.slack.handlers.handle_standard_answers import ( handle_standard_answers, ) from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.utils import fetch_slack_user_ids_from_emails from onyx.onyxbot.slack.utils import fetch_user_ids_from_groups from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import slack_usage_report from onyx.onyxbot.slack.utils import update_emote_react from onyx.utils.logger import
setup_logger from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from shared_configs.configs import SLACK_CHANNEL_ID logger_base = setup_logger() def send_msg_ack_to_user(details: SlackMessageInfo, client: WebClient) -> None: if details.is_slash_command and details.sender_id: respond_in_thread_or_channel( client=client, channel=details.channel_to_respond, thread_ts=details.msg_to_respond, receiver_ids=[details.sender_id], text="Hi, we're evaluating your query :face_with_monocle:", ) return update_emote_react( emoji=ONYX_BOT_REACT_EMOJI, channel=details.channel_to_respond, message_ts=details.msg_to_respond, remove=False, client=client, ) def schedule_feedback_reminder( details: SlackMessageInfo, include_followup: bool, client: WebClient ) -> str | None: logger = setup_logger(extra={SLACK_CHANNEL_ID: details.channel_to_respond}) if not ONYX_BOT_FEEDBACK_REMINDER: logger.info("Scheduled feedback reminder disabled...") return None try: permalink = client.chat_getPermalink( channel=details.channel_to_respond, message_ts=details.msg_to_respond, # type:ignore ) except SlackApiError as e: logger.error(f"Unable to generate the feedback reminder permalink: {e}") return None now = datetime.datetime.now() future = now + datetime.timedelta(minutes=ONYX_BOT_FEEDBACK_REMINDER) try: response = client.chat_scheduleMessage( channel=details.sender_id, # type:ignore post_at=int(future.timestamp()), blocks=[ get_feedback_reminder_blocks( thread_link=permalink.data["permalink"], # type:ignore include_followup=include_followup, ) ], text="", ) logger.info("Scheduled feedback reminder configured") return response.data["scheduled_message_id"] # type:ignore except SlackApiError as e: logger.error(f"Unable to generate the feedback reminder message: {e}") return None def remove_scheduled_feedback_reminder( client: WebClient, channel: str | None, msg_id: str ) -> None: logger = setup_logger(extra={SLACK_CHANNEL_ID: channel}) try: client.chat_deleteScheduledMessage( 
channel=channel, # type:ignore scheduled_message_id=msg_id, ) logger.info("Scheduled feedback reminder deleted") except SlackApiError as e: if e.response["error"] == "invalid_scheduled_message_id": logger.info( "Unable to delete the scheduled message. It must have already been posted" ) def handle_message( message_info: SlackMessageInfo, slack_channel_config: SlackChannelConfig, client: WebClient, feedback_reminder_id: str | None, ) -> bool: """Potentially respond to the user message, depending on filters and whether an answer was generated. Returns True if an additional message needs to be sent to the user(s) after this function finishes; True indicates an unexpected failure that needs to be communicated. A query thrown out by filters due to config does not count as a failure that should be notified. Onyx failing to answer or retrieve docs does count and should be notified. """ channel = message_info.channel_to_respond logger = setup_logger(extra={SLACK_CHANNEL_ID: channel}) messages = message_info.thread_messages sender_id = message_info.sender_id bypass_filters = message_info.bypass_filters is_slash_command = message_info.is_slash_command is_bot_dm = message_info.is_bot_dm action = "slack_message" if is_slash_command: action = "slack_slash_message" elif bypass_filters: action = "slack_tag_message" elif is_bot_dm: action = "slack_dm_message" slack_usage_report(action=action, sender_id=sender_id, client=client) document_set_names: list[str] | None = None persona = slack_channel_config.persona if slack_channel_config else None if persona: document_set_names = [ document_set.name for document_set in persona.document_sets ] respond_tag_only = False respond_member_group_list = None channel_conf = None if slack_channel_config and slack_channel_config.channel_config: channel_conf = slack_channel_config.channel_config if not bypass_filters and "answer_filters" in channel_conf: if ( "questionmark_prefilter" in channel_conf["answer_filters"] and "?"
not in messages[-1].message ): logger.info( "Skipping message since it does not contain a question mark" ) return False logger.info( "Found slack bot config for channel. Restricting bot to use document " f"sets: {document_set_names}, " f"validity checks enabled: {channel_conf.get('answer_filters', 'NA')}" ) respond_tag_only = channel_conf.get("respond_tag_only") or False respond_member_group_list = channel_conf.get("respond_member_group_list", None) # Only default config can be disabled. # If channel config is disabled, bot should not respond to this message (including DMs) if slack_channel_config.channel_config.get("disabled"): logger.info("Skipping message: OnyxBot is disabled for this channel") return False # If the bot should only respond to tags and is neither tagged nor in a DM, skip the message if respond_tag_only and not bypass_filters and not is_bot_dm: logger.info("Skipping message: OnyxBot only responds to tags in this channel") return False # List of user IDs to send the message to; if None, send to everyone in the channel send_to: list[str] | None = None missing_users: list[str] | None = None if respond_member_group_list: send_to, missing_ids = fetch_slack_user_ids_from_emails( respond_member_group_list, client ) user_ids, missing_users = fetch_user_ids_from_groups(missing_ids, client) send_to = list(set(send_to + user_ids)) if send_to else user_ids if missing_users: logger.warning(f"Failed to find these users/groups: {missing_users}") # If configured to respond to team members only, the /OnyxBot slash command cannot be used, # since it would just respond to the sender if send_to and is_slash_command: if sender_id: respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=[sender_id], text="The OnyxBot slash command is not enabled for this channel", thread_ts=None, ) try: send_msg_ack_to_user(message_info, client) except SlackApiError as e: logger.error(f"Was not able to react to user message due to: {e}") with get_session_with_current_tenant() as
db_session: if message_info.email: existing_user = get_user_by_email(message_info.email, db_session) if existing_user is None: # New user — check seat availability before creating check_seat_fn = fetch_ee_implementation_or_noop( "onyx.db.license", "check_seat_availability", None, ) # noop returns None when called; real function returns SeatAvailabilityResult seat_result = check_seat_fn(db_session=db_session) if seat_result is not None and not seat_result.available: logger.info( f"Blocked new Slack user {message_info.email}: {seat_result.error_message}" ) respond_in_thread_or_channel( client=client, channel=channel, thread_ts=message_info.msg_to_respond, text=( "We weren't able to respond because your organization " "has reached its user seat limit. Since this is your " "first time interacting with the bot, a new account " "could not be created for you. Please contact your " "Onyx administrator to add more seats." ), ) return False elif ( not existing_user.is_active and existing_user.account_type == AccountType.BOT ): check_seat_fn = fetch_ee_implementation_or_noop( "onyx.db.license", "check_seat_availability", None, ) seat_result = check_seat_fn(db_session=db_session) if seat_result is not None and not seat_result.available: logger.info( f"Blocked inactive Slack user {message_info.email}: {seat_result.error_message}" ) respond_in_thread_or_channel( client=client, channel=channel, thread_ts=message_info.msg_to_respond, text=( "We weren't able to respond because your organization " "has reached its user seat limit. Your account is " "currently deactivated and cannot be reactivated " "until more seats are available. Please contact " "your Onyx administrator." 
), ) return False activate_user(existing_user, db_session) invalidate_license_cache_fn = fetch_ee_implementation_or_noop( "onyx.db.license", "invalidate_license_cache", None, ) invalidate_license_cache_fn() logger.info(f"Reactivated inactive Slack user {message_info.email}") add_slack_user_if_not_exists(db_session, message_info.email) # first check if we need to respond with a standard answer # standard answers should be published in a thread used_standard_answer = handle_standard_answers( message_info=message_info, receiver_ids=send_to, slack_channel_config=slack_channel_config, logger=logger, client=client, db_session=db_session, ) if used_standard_answer: return False # if no standard answer applies, try a regular answer issue_with_regular_answer = handle_regular_answer( message_info=message_info, slack_channel_config=slack_channel_config, receiver_ids=send_to, client=client, channel=channel, logger=logger, feedback_reminder_id=feedback_reminder_id, ) return issue_with_regular_answer ================================================ FILE: backend/onyx/onyxbot/slack/handlers/handle_regular_answer.py ================================================ import functools from collections.abc import Callable from typing import Any from typing import Optional from typing import TypeVar from retry import retry from slack_sdk import WebClient from onyx.auth.users import get_anonymous_user from onyx.chat.models import ChatBasicResponse from onyx.chat.process_message import gather_stream from onyx.chat.process_message import handle_stream_message_objects from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.configs.constants import MessageType from onyx.configs.onyxbot_configs import ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER from onyx.configs.onyxbot_configs import ONYX_BOT_DISPLAY_ERROR_MSGS from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI from onyx.context.search.models import BaseFilters from 
onyx.context.search.models import Tag from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import SlackChannelConfig from onyx.db.models import User from onyx.db.persona import get_persona_by_id from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.blocks import build_slack_response_blocks from onyx.onyxbot.slack.constants import SLACK_CHANNEL_REF_PATTERN from onyx.onyxbot.slack.handlers.utils import send_team_member_message from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.models import ThreadMessage from onyx.onyxbot.slack.utils import get_channel_from_id from onyx.onyxbot.slack.utils import get_channel_name_from_id from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import SlackRateLimiter from onyx.onyxbot.slack.utils import update_emote_react from onyx.server.query_and_chat.models import ChatSessionCreationRequest from onyx.server.query_and_chat.models import MessageOrigin from onyx.server.query_and_chat.models import SendMessageRequest from onyx.utils.logger import OnyxLoggingAdapter srl = SlackRateLimiter() RT = TypeVar("RT") # return type def resolve_channel_references( message: str, client: WebClient, logger: OnyxLoggingAdapter, ) -> tuple[str, list[Tag]]: """Parse Slack channel references from a message, resolve IDs to names, replace the raw markup with readable #channel-name, and return channel tags for search filtering.""" tags: list[Tag] = [] channel_matches = SLACK_CHANNEL_REF_PATTERN.findall(message) seen_channel_ids: set[str] = set() for channel_id, channel_name_from_markup in channel_matches: if channel_id in seen_channel_ids: continue seen_channel_ids.add(channel_id) channel_name = channel_name_from_markup or None if not channel_name: try: channel_info = get_channel_from_id(client=client, channel_id=channel_id) channel_name = channel_info.get("name") or None except Exception: logger.warning(f"Failed to resolve channel name 
for ID: {channel_id}") if not channel_name: continue # Replace raw Slack markup with readable channel name if channel_name_from_markup: message = message.replace( f"<#{channel_id}|{channel_name_from_markup}>", f"#{channel_name}", ) else: message = message.replace( f"<#{channel_id}>", f"#{channel_name}", ) tags.append(Tag(tag_key="Channel", tag_value=channel_name)) return message, tags def rate_limits( client: WebClient, channel: str, thread_ts: Optional[str] ) -> Callable[[Callable[..., RT]], Callable[..., RT]]: def decorator(func: Callable[..., RT]) -> Callable[..., RT]: @functools.wraps(func) def wrapper(*args: Any, **kwargs: Any) -> RT: if not srl.is_available(): func_randid, position = srl.init_waiter() srl.notify(client, channel, position, thread_ts) while not srl.is_available(): srl.waiter(func_randid) srl.acquire_slot() return func(*args, **kwargs) return wrapper return decorator def build_slack_context_str( messages: list[ThreadMessage], channel_name: str | None ) -> str | None: if not messages: return None if channel_name: slack_context_str = f"The following is a thread in Slack in channel {channel_name}:\n====================\n" else: slack_context_str = ( "The following is a thread from Slack:\n====================\n" ) message_strs: list[str] = [] for message in messages: if message.role == MessageType.USER: message_text = f"{message.sender or 'Unknown User'}:\n{message.message}" elif message.role == MessageType.ASSISTANT: message_text = f"AI:\n{message.message}" else: message_text = f"{message.role.value.upper()}:\n{message.message}" message_strs.append(message_text) return slack_context_str + "\n\n".join(message_strs) def handle_regular_answer( message_info: SlackMessageInfo, slack_channel_config: SlackChannelConfig, receiver_ids: list[str] | None, client: WebClient, channel: str, logger: OnyxLoggingAdapter, feedback_reminder_id: str | None, num_retries: int = ONYX_BOT_NUM_RETRIES, should_respond_with_error_msgs: bool = ONYX_BOT_DISPLAY_ERROR_MSGS, 
disable_docs_only_answer: bool = ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER, ) -> bool: channel_conf = slack_channel_config.channel_config messages = message_info.thread_messages message_ts_to_respond_to = message_info.msg_to_respond is_slash_command = message_info.is_slash_command # Capture whether response mode for channel is ephemeral. Even if the channel is set # to respond with an ephemeral message, we still send as non-ephemeral if # the message is a dm with the Onyx bot. send_as_ephemeral = ( slack_channel_config.channel_config.get("is_ephemeral", False) or message_info.is_slash_command ) and not message_info.is_bot_dm # If the channel is configured to respond with an ephemeral message, # or the message is a dm to the Onyx bot, we should use the proper onyx user from the email. # This will make documents privately accessible to the user available to Onyx Bot answers. # Otherwise - if not ephemeral or DM to Onyx Bot - we use anonymous user to restrict # to public docs. if message_info.email: with get_session_with_current_tenant() as db_session: found_user = get_user_by_email(message_info.email, db_session) user = found_user if found_user else get_anonymous_user() else: user = get_anonymous_user() target_thread_ts = ( None if send_as_ephemeral and len(message_info.thread_messages) < 2 else message_ts_to_respond_to ) target_receiver_ids = ( [message_info.sender_id] if message_info.sender_id and send_as_ephemeral else receiver_ids ) document_set_names: list[str] | None = None # If no persona is specified, use the default search based persona # This way slack flow always has a persona persona = slack_channel_config.persona if not persona: logger.warning("No persona found for channel config, using default persona") with get_session_with_current_tenant() as db_session: persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session) document_set_names = [ document_set.name for document_set in persona.document_sets ] else: logger.info(f"Using persona {persona.name} for 
channel config") document_set_names = [ document_set.name for document_set in persona.document_sets ] user_message = messages[-1] history_messages = messages[:-1] # Resolve any <#CHANNEL_ID> references in the user message to readable # channel names and extract channel tags for search filtering resolved_message, channel_tags = resolve_channel_references( message=user_message.message, client=client, logger=logger, ) user_message = ThreadMessage( message=resolved_message, sender=user_message.sender, role=user_message.role, ) channel_name, _ = get_channel_name_from_id( client=client, channel_id=channel, ) # NOTE: only the message history will contain the person asking. This is likely # fine since the most common use case for this info is when referring to a user # who previously posted in the thread. slack_context_str = build_slack_context_str(history_messages, channel_name) if not message_ts_to_respond_to and not is_slash_command: # if the message is not a "/onyx" command, then it should have a message ts to respond to raise RuntimeError( "No message timestamp to respond to in `handle_message`. This should never happen."
) @retry( tries=num_retries, delay=0.25, backoff=2, ) @rate_limits(client=client, channel=channel, thread_ts=message_ts_to_respond_to) def _get_slack_answer( new_message_request: SendMessageRequest, slack_context_str: str | None, onyx_user: User, ) -> ChatBasicResponse: with get_session_with_current_tenant() as db_session: packets = handle_stream_message_objects( new_msg_req=new_message_request, user=onyx_user, db_session=db_session, bypass_acl=False, additional_context=slack_context_str, slack_context=message_info.slack_context, ) answer = gather_stream(packets) if answer.error_msg: raise RuntimeError(answer.error_msg) return answer try: # By leaving time_cutoff as None, the Slack flow can extract # filters from the user query filters = BaseFilters( source_type=None, document_set=document_set_names, time_cutoff=None, tags=channel_tags if channel_tags else None, ) new_message_request = SendMessageRequest( message=user_message.message, allowed_tool_ids=None, forced_tool_id=None, file_descriptors=[], internal_search_filters=filters, deep_research=False, origin=MessageOrigin.SLACKBOT, chat_session_info=ChatSessionCreationRequest( persona_id=persona.id, ), ) # if it's a DM or ephemeral message, answer based on private documents. # otherwise, answer based on public documents ONLY so as not to leak information.
can_search_over_private_docs = message_info.is_bot_dm or send_as_ephemeral answer = _get_slack_answer( new_message_request=new_message_request, onyx_user=user if can_search_over_private_docs else get_anonymous_user(), slack_context_str=slack_context_str, ) # If a channel filter was applied but no results were found, override # the LLM response to avoid hallucinated answers about unindexed channels if channel_tags and not answer.citation_info and not answer.top_documents: channel_names = ", ".join(f"#{tag.tag_value}" for tag in channel_tags) answer.answer = ( f"No indexed data found for {channel_names}. " "This channel may not be indexed, or there may be no messages " "matching your query within it." ) except Exception as e: logger.exception( f"Unable to process message - did not successfully answer in {num_retries} attempts" ) # Optionally, respond in thread with the error message. Used primarily # for debugging purposes if should_respond_with_error_msgs: respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=target_receiver_ids, text=f"Encountered exception when trying to answer: \n\n```{e}```", thread_ts=target_thread_ts, send_as_ephemeral=send_as_ephemeral, ) # In case of failures, don't keep the reaction there permanently update_emote_react( emoji=ONYX_BOT_REACT_EMOJI, channel=message_info.channel_to_respond, message_ts=message_info.msg_to_respond, remove=True, client=client, ) return True # Got an answer at this point, can remove reaction and give results if not is_slash_command: # Slash commands don't have reactions update_emote_react( emoji=ONYX_BOT_REACT_EMOJI, channel=message_info.channel_to_respond, message_ts=message_info.msg_to_respond, remove=True, client=client, ) if not answer.answer and disable_docs_only_answer: logger.notice( "Unable to find answer - not responding since the `ONYX_BOT_DISABLE_DOCS_ONLY_ANSWER` env variable is set" ) return True only_respond_if_citations = ( channel_conf and "well_answered_postfilter" in
channel_conf.get("answer_filters", []) ) if ( only_respond_if_citations and not answer.citation_info and not message_info.bypass_filters and not channel_tags ): logger.error( f"Unable to find citations to answer: '{answer.answer}' - not answering!" ) # Optionally, respond in thread with the error message # Used primarily for debugging purposes if should_respond_with_error_msgs: respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=target_receiver_ids, text="Found no citations or quotes when trying to answer.", thread_ts=target_thread_ts, send_as_ephemeral=send_as_ephemeral, ) return True if ( send_as_ephemeral and target_receiver_ids is not None and len(target_receiver_ids) == 1 ): offer_ephemeral_publication = True skip_ai_feedback = True else: offer_ephemeral_publication = False skip_ai_feedback = False all_blocks = build_slack_response_blocks( message_info=message_info, answer=answer, channel_conf=channel_conf, feedback_reminder_id=feedback_reminder_id, offer_ephemeral_publication=offer_ephemeral_publication, skip_ai_feedback=skip_ai_feedback, ) # NOTE(rkuo): Slack has a maximum block list size of 50. # We should modify build_slack_response_blocks to respect the max, # but enforcing the hard limit here is the last resort. all_blocks = all_blocks[:50] try: respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=target_receiver_ids, text="Hello! Onyx has some results for you!", blocks=all_blocks, thread_ts=target_thread_ts, # don't unfurl, since otherwise we will have 5+ previews which makes the message very long unfurl=False, send_as_ephemeral=send_as_ephemeral, ) # For a DM (ephemeral message), we need to create a thread via a normal message so the user can see # the ephemeral message. This will also give the user a notification, which an ephemeral message does not.
# if there is no message_ts_to_respond_to, and we have made it this far, then this is a /onyx message # so we shouldn't send_team_member_message if ( target_receiver_ids and message_ts_to_respond_to is not None and not send_as_ephemeral and target_thread_ts is not None ): send_team_member_message( client=client, channel=channel, thread_ts=target_thread_ts, receiver_ids=target_receiver_ids, send_as_ephemeral=send_as_ephemeral, ) return False except Exception: logger.exception( f"Unable to process message - could not respond in slack in {num_retries} attempts" ) return True ================================================ FILE: backend/onyx/onyxbot/slack/handlers/handle_standard_answers.py ================================================ from slack_sdk import WebClient from sqlalchemy.orm import Session from onyx.db.models import SlackChannelConfig from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.utils.logger import OnyxLoggingAdapter from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_versioned_implementation logger = setup_logger() def handle_standard_answers( message_info: SlackMessageInfo, receiver_ids: list[str] | None, slack_channel_config: SlackChannelConfig, logger: OnyxLoggingAdapter, client: WebClient, db_session: Session, ) -> bool: """Returns whether one or more Standard Answer message blocks were emitted by the Slack bot""" versioned_handle_standard_answers = fetch_versioned_implementation( "onyx.onyxbot.slack.handlers.handle_standard_answers", "_handle_standard_answers", ) return versioned_handle_standard_answers( message_info=message_info, receiver_ids=receiver_ids, slack_channel_config=slack_channel_config, logger=logger, client=client, db_session=db_session, ) def _handle_standard_answers( message_info: SlackMessageInfo, # noqa: ARG001 receiver_ids: list[str] | None, # noqa: ARG001 slack_channel_config: SlackChannelConfig, # noqa: ARG001 logger: OnyxLoggingAdapter, # noqa: ARG001 client: 
WebClient, # noqa: ARG001 db_session: Session, # noqa: ARG001 ) -> bool: """ Standard Answers are a paid Enterprise Edition feature. This is the fallback function handling the case where EE features are not enabled. Always returns False: since EE features are not enabled, we NEVER create any Slack message blocks. """ return False ================================================ FILE: backend/onyx/onyxbot/slack/handlers/utils.py ================================================ from slack_sdk import WebClient from onyx.onyxbot.slack.utils import respond_in_thread_or_channel def send_team_member_message( client: WebClient, channel: str, thread_ts: str, receiver_ids: list[str] | None = None, # noqa: ARG001 send_as_ephemeral: bool = False, ) -> None: respond_in_thread_or_channel( client=client, channel=channel, text=( "👋 Hi, we've just gathered and forwarded the relevant " + "information to the team. They'll get back to you shortly!" ), thread_ts=thread_ts, receiver_ids=None, send_as_ephemeral=send_as_ephemeral, ) ================================================ FILE: backend/onyx/onyxbot/slack/icons.py ================================================ from onyx.configs.constants import DocumentSource def source_to_github_img_link(source: DocumentSource) -> str | None: # TODO: store these images somewhere better if source == DocumentSource.WEB.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Web.png" if source == DocumentSource.FILE.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png" if source == DocumentSource.GOOGLE_SITES.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleSites.png" if source == DocumentSource.SLACK.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Slack.png" if source == DocumentSource.GMAIL.value: return
"https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gmail.png" if source == DocumentSource.GOOGLE_DRIVE.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/GoogleDrive.png" if source == DocumentSource.GITHUB.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Github.png" if source == DocumentSource.GITLAB.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gitlab.png" if source == DocumentSource.CONFLUENCE.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Confluence.png" if source == DocumentSource.JIRA.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Jira.png" if source == DocumentSource.NOTION.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Notion.png" if source == DocumentSource.ZENDESK.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Zendesk.png" if source == DocumentSource.GONG.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Gong.png" if source == DocumentSource.LINEAR.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Linear.png" if source == DocumentSource.PRODUCTBOARD.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Productboard.webp" if source == DocumentSource.SLAB.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/SlabLogo.png" if source == DocumentSource.ZULIP.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Zulip.png" if source == DocumentSource.GURU.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/Guru.png" if source == DocumentSource.HUBSPOT.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/HubSpot.png" if source == 
DocumentSource.DOCUMENT360.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Document360.png" if source == DocumentSource.BOOKSTACK.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Bookstack.png" if source == DocumentSource.OUTLINE.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Outline.png" if source == DocumentSource.LOOPIO.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Loopio.png" if source == DocumentSource.SHAREPOINT.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/web/public/Sharepoint.png" if source == DocumentSource.REQUESTTRACKER.value: # just use file icon for now return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png" if source == DocumentSource.INGESTION_API.value: return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png" return "https://raw.githubusercontent.com/onyx-dot-app/onyx/main/backend/slackbot_images/File.png" ================================================ FILE: backend/onyx/onyxbot/slack/listener.py ================================================ import os import signal import sys import threading import time from collections.abc import Callable from contextvars import Token from threading import Event from types import FrameType from typing import Any from typing import cast from typing import Dict import psycopg2.errors from prometheus_client import Gauge from prometheus_client import start_http_server from redis.lock import Lock from redis.lock import Lock as RedisLock from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.http_retry import ConnectionErrorRetryHandler from slack_sdk.http_retry import RateLimitErrorRetryHandler from slack_sdk.http_retry import RetryHandler from slack_sdk.socket_mode.request import SocketModeRequest from 
slack_sdk.socket_mode.response import SocketModeResponse from sqlalchemy.orm import Session from onyx.configs.app_configs import DEV_MODE from onyx.configs.app_configs import POD_NAME from onyx.configs.app_configs import POD_NAMESPACE from onyx.configs.constants import MessageType from onyx.configs.constants import OnyxRedisLocks from onyx.configs.onyxbot_configs import NOTIFY_SLACKBOT_NO_ANSWER from onyx.connectors.slack.utils import expert_info_from_slack_id from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.engine.sql_engine import SqlEngine from onyx.db.engine.tenant_utils import get_all_tenant_ids from onyx.db.models import SlackBot from onyx.db.search_settings import get_current_search_settings from onyx.db.slack_bot import fetch_slack_bot from onyx.db.slack_bot import fetch_slack_bots from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.natural_language_processing.search_nlp_models import EmbeddingModel from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder from onyx.onyxbot.slack.config import get_slack_channel_config_for_bot_and_channel from onyx.onyxbot.slack.config import MAX_TENANTS_PER_POD from onyx.onyxbot.slack.config import TENANT_ACQUISITION_INTERVAL from onyx.onyxbot.slack.config import TENANT_HEARTBEAT_EXPIRATION from onyx.onyxbot.slack.config import TENANT_HEARTBEAT_INTERVAL from onyx.onyxbot.slack.config import TENANT_LOCK_EXPIRATION from onyx.onyxbot.slack.constants import DISLIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import FOLLOWUP_BUTTON_RESOLVED_ACTION_ID from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import IMMEDIATE_RESOLVED_BUTTON_ACTION_ID from onyx.onyxbot.slack.constants import 
KEEP_TO_YOURSELF_ACTION_ID from onyx.onyxbot.slack.constants import LIKE_BLOCK_ACTION_ID from onyx.onyxbot.slack.constants import SHOW_EVERYONE_ACTION_ID from onyx.onyxbot.slack.constants import VIEW_DOC_FEEDBACK_ID from onyx.onyxbot.slack.handlers.handle_buttons import handle_doc_feedback_button from onyx.onyxbot.slack.handlers.handle_buttons import handle_followup_button from onyx.onyxbot.slack.handlers.handle_buttons import ( handle_followup_resolved_button, ) from onyx.onyxbot.slack.handlers.handle_buttons import ( handle_generate_answer_button, ) from onyx.onyxbot.slack.handlers.handle_buttons import ( handle_publish_ephemeral_message_button, ) from onyx.onyxbot.slack.handlers.handle_buttons import handle_slack_feedback from onyx.onyxbot.slack.handlers.handle_message import handle_message from onyx.onyxbot.slack.handlers.handle_message import ( remove_scheduled_feedback_reminder, ) from onyx.onyxbot.slack.handlers.handle_message import schedule_feedback_reminder from onyx.onyxbot.slack.models import SlackContext from onyx.onyxbot.slack.models import SlackMessageInfo from onyx.onyxbot.slack.models import ThreadMessage from onyx.onyxbot.slack.utils import check_message_limit from onyx.onyxbot.slack.utils import decompose_action_id from onyx.onyxbot.slack.utils import get_channel_name_from_id from onyx.onyxbot.slack.utils import get_channel_type_from_id from onyx.onyxbot.slack.utils import get_onyx_bot_auth_ids from onyx.onyxbot.slack.utils import read_slack_thread from onyx.onyxbot.slack.utils import remove_onyx_bot_tag from onyx.onyxbot.slack.utils import respond_in_thread_or_channel from onyx.onyxbot.slack.utils import TenantSocketModeClient from onyx.redis.redis_pool import get_redis_client from onyx.server.manage.models import SlackBotTokens from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable from 
shared_configs.configs import DISALLOWED_SLACK_BOT_TENANT_LIST from shared_configs.configs import MODEL_SERVER_HOST from shared_configs.configs import MODEL_SERVER_PORT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA from shared_configs.configs import SLACK_CHANNEL_ID from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() # Prometheus metric for HPA active_tenants_gauge = Gauge( "active_tenants", "Number of active tenants handled by this pod", ["namespace", "pod"], ) # In rare cases, some users have been experiencing a massive number of trivial messages coming through # to the Slack Bot. Ignoring these greetings avoids exploding LLM costs while we track down # the cause. _SLACK_GREETINGS_TO_IGNORE = { "Welcome back!", "It's going to be a great day.", "Salutations!", "Greetings!", "Feeling great!", "Hi there", ":wave:", } # This is always (currently) the user ID of Slack's official Slackbot _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT" # Fields to exclude from Slack payload logging # The intention is to avoid logging Slack message content _EXCLUDED_SLACK_PAYLOAD_FIELDS = {"text", "blocks"} class SlackbotHandler: def __init__(self) -> None: logger.info("Initializing SlackbotHandler") self.tenant_ids: set[str] = set() # The keys for these dictionaries are tuples of (tenant_id, slack_bot_id) self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {} self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {} # Store Redis lock objects here so we can release them properly self.redis_locks: Dict[str, Lock] = {} self.running = True self.pod_id = os.environ.get("HOSTNAME", "unknown_pod") self._shutdown_event = Event() self._lock = threading.Lock() logger.info(f"Pod ID: {self.pod_id}") # Set up signal handlers for graceful shutdown signal.signal(signal.SIGTERM, self.shutdown) signal.signal(signal.SIGINT, self.shutdown) logger.info("Signal handlers
registered") # Start the Prometheus metrics server logger.info("Starting Prometheus metrics server") start_http_server(8000) logger.info("Prometheus metrics server started") # Start background threads logger.info("Starting background threads") self.acquire_thread = threading.Thread( target=self.acquire_tenants_loop, daemon=True ) self.heartbeat_thread = threading.Thread( target=self.heartbeat_loop, daemon=True ) self.acquire_thread.start() self.heartbeat_thread.start() logger.info("Background threads started") def acquire_tenants_loop(self) -> None: while not self._shutdown_event.is_set(): try: self.acquire_tenants() # After we finish acquiring and managing Slack bots, # set the gauge to the number of active tenants (those with Slack bots). active_tenants_gauge.labels(namespace=POD_NAMESPACE, pod=POD_NAME).set( len(self.tenant_ids) ) logger.debug( f"Current active tenants with Slack bots: {len(self.tenant_ids)}" ) except Exception as e: logger.exception(f"Error in Slack acquisition: {e}") self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL) def heartbeat_loop(self) -> None: """Periodically sends heartbeats to Redis for every active tenant. NOTE(rkuo): this is not thread-safe with acquire_tenants_loop, which mutates self.tenant_ids without holding self._lock, so copying the set here will occasionally raise an exception. Fix it! """ while not self._shutdown_event.is_set(): try: with self._lock: tenant_ids = self.tenant_ids.copy() SlackbotHandler.send_heartbeats(self.pod_id, tenant_ids) logger.debug(f"Sent heartbeats for {len(tenant_ids)} active tenants") except Exception as e: logger.exception(f"Error in heartbeat loop: {e}") self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL) def _manage_clients_per_tenant( self, db_session: Session, tenant_id: str, bot: SlackBot ) -> None: """ - If the tokens are missing or empty, close the socket client and remove them. - If the tokens have changed, close the existing socket client and reconnect. - If the tokens are new, warm up the model and start a new socket client.
""" tenant_bot_pair = (tenant_id, bot.id) # If the tokens are missing or empty, close the socket client and remove them. if not bot.bot_token or not bot.app_token: logger.debug( f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}" ) if tenant_bot_pair in self.socket_clients: self.socket_clients[tenant_bot_pair].close() del self.socket_clients[tenant_bot_pair] del self.slack_bot_tokens[tenant_bot_pair] return slack_bot_tokens = SlackBotTokens( bot_token=bot.bot_token.get_value(apply_mask=False), app_token=bot.app_token.get_value(apply_mask=False), ) tokens_exist = tenant_bot_pair in self.slack_bot_tokens tokens_changed = ( tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair] ) if not tokens_exist or tokens_changed: if tokens_exist: logger.info( f"Slack Bot tokens changed for tenant={tenant_id}, bot {bot.id}; reconnecting" ) else: # Warm up the model if needed search_settings = get_current_search_settings(db_session) embedding_model = EmbeddingModel.from_db_model( search_settings=search_settings, server_host=MODEL_SERVER_HOST, server_port=MODEL_SERVER_PORT, ) warm_up_bi_encoder(embedding_model=embedding_model) self.slack_bot_tokens[tenant_bot_pair] = slack_bot_tokens # Close any existing connection first if tenant_bot_pair in self.socket_clients: self.socket_clients[tenant_bot_pair].close() socket_client = self.start_socket_client( bot.id, tenant_id, slack_bot_tokens ) if socket_client: # Ensure tenant is tracked as active self.socket_clients[tenant_id, bot.id] = socket_client logger.info( f"Started SocketModeClient: {tenant_id=} {socket_client.bot_name=} {bot.id=}" ) self.tenant_ids.add(tenant_id) def acquire_tenants(self) -> None: """ - Attempt to acquire a Redis lock for each tenant. - If acquired, check if that tenant actually has Slack bots. - If yes, store them in self.tenant_ids and manage the socket connections. - If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope). 
""" token: Token[str | None] # tenants that are disabled (e.g. their trial is over and haven't subscribed) # for non-cloud, this will return an empty set gated_tenants = fetch_ee_implementation_or_noop( "onyx.server.tenants.product_gating", "get_gated_tenants", set(), )() all_active_tenants = [ tenant_id for tenant_id in get_all_tenant_ids() if tenant_id not in gated_tenants ] # 1) Try to acquire locks for new tenants for tenant_id in all_active_tenants: if ( DISALLOWED_SLACK_BOT_TENANT_LIST is not None and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST ): logger.debug(f"Tenant {tenant_id} is disallowed; skipping.") continue # Already acquired in a previous loop iteration? if tenant_id in self.tenant_ids: continue # Respect max tenant limit per pod if len(self.tenant_ids) >= MAX_TENANTS_PER_POD: logger.info( f"Max tenants per pod reached, not acquiring more: {MAX_TENANTS_PER_POD=}" ) break redis_client = get_redis_client(tenant_id=tenant_id) # Acquire a Redis lock (non-blocking) # thread_local=False because the shutdown event is handled # on an arbitrary thread rlock: RedisLock = redis_client.lock( OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION, thread_local=False, ) lock_acquired = rlock.acquire(blocking=False) if not lock_acquired and not DEV_MODE: logger.debug( f"Another pod holds the lock for tenant {tenant_id}, skipping." ) continue if lock_acquired: logger.debug(f"Acquired lock for tenant {tenant_id}.") self.redis_locks[tenant_id] = rlock else: # DEV_MODE will skip the lock acquisition guard logger.debug( f"Running in DEV_MODE. Not enforcing lock for {tenant_id}." 
) # Now check if this tenant actually has Slack bots token = CURRENT_TENANT_ID_CONTEXTVAR.set( tenant_id or POSTGRES_DEFAULT_SCHEMA ) try: with get_session_with_tenant(tenant_id=tenant_id) as db_session: bots: list[SlackBot] = [] try: bots = list(fetch_slack_bots(db_session=db_session)) except KvKeyNotFoundError: # No Slackbot tokens, pass pass except psycopg2.errors.UndefinedTable: logger.error( "Undefined table error in fetch_slack_bots. Tenant schema may need fixing." ) except Exception as e: logger.exception( f"Error fetching Slack bots for tenant {tenant_id}: {e}" ) if bots: # Mark as active tenant self.tenant_ids.add(tenant_id) for bot in bots: self._manage_clients_per_tenant( db_session=db_session, tenant_id=tenant_id, bot=bot, ) else: # If no Slack bots, release lock immediately (unless in DEV_MODE) if lock_acquired and not DEV_MODE: rlock.release() del self.redis_locks[tenant_id] logger.debug( f"No Slack bots for tenant {tenant_id}; lock released (if held)." ) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) # 2) Make sure tenants we're handling still have Slack bots # and haven't been suspended (gated) for tenant_id in list(self.tenant_ids): if tenant_id in gated_tenants: logger.info( f"Tenant {tenant_id} is now gated (suspended). Disconnecting." 
) self._remove_tenant(tenant_id) if tenant_id in self.redis_locks and not DEV_MODE: try: self.redis_locks[tenant_id].release() del self.redis_locks[tenant_id] except Exception as e: logger.error( f"Error releasing lock for gated tenant {tenant_id}: {e}" ) continue token = CURRENT_TENANT_ID_CONTEXTVAR.set( tenant_id or POSTGRES_DEFAULT_SCHEMA ) redis_client = get_redis_client(tenant_id=tenant_id) try: with get_session_with_current_tenant() as db_session: # Attempt to fetch Slack bots try: bots = list(fetch_slack_bots(db_session=db_session)) except KvKeyNotFoundError: # No Slackbot tokens, pass (and remove below) bots = [] except Exception as e: logger.exception(f"Error handling tenant {tenant_id}: {e}") bots = [] if not bots: logger.info( f"Tenant {tenant_id} no longer has Slack bots. Removing." ) self._remove_tenant(tenant_id) # NOTE: We release the lock here (in the same scope it was acquired) if tenant_id in self.redis_locks and not DEV_MODE: try: self.redis_locks[tenant_id].release() del self.redis_locks[tenant_id] logger.info(f"Released lock for tenant {tenant_id}") except Exception as e: logger.error( f"Error releasing lock for tenant {tenant_id}: {e}" ) else: # Manage or reconnect Slack bot sockets for bot in bots: self._manage_clients_per_tenant( db_session=db_session, tenant_id=tenant_id, bot=bot, ) finally: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def _remove_tenant(self, tenant_id: str) -> None: """ Helper to remove a tenant from `self.tenant_ids` and close any socket clients. (Lock release now happens in `acquire_tenants()`, not here.) 
""" socket_client_list = list(self.socket_clients.items()) # Close all socket clients for this tenant for (t_id, slack_bot_id), client in socket_client_list: if t_id == tenant_id: client.close() del self.socket_clients[(t_id, slack_bot_id)] del self.slack_bot_tokens[(t_id, slack_bot_id)] logger.info( f"Stopped SocketModeClient for tenant: {t_id}, app: {slack_bot_id}" ) # Remove from active set if tenant_id in self.tenant_ids: self.tenant_ids.remove(tenant_id) @staticmethod def send_heartbeats(pod_id: str, tenant_ids: set[str]) -> None: current_time = int(time.time()) logger.debug(f"Sending heartbeats for {len(tenant_ids)} active tenants") for tenant_id in tenant_ids: redis_client = get_redis_client(tenant_id=tenant_id) heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{pod_id}" redis_client.set( heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION ) @staticmethod def start_socket_client( slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens ) -> TenantSocketModeClient | None: """Returns the connected socket client, or None if Slack rejects the credentials (not_authed)""" socket_client: TenantSocketModeClient = _get_socket_client( slack_bot_tokens, tenant_id, slack_bot_id ) try: bot_info = socket_client.web_client.auth_test() if bot_info["ok"]: bot_user_id = bot_info["user_id"] user_info = socket_client.web_client.users_info(user=bot_user_id) if user_info["ok"]: bot_name = ( user_info["user"]["real_name"] or user_info["user"]["name"] ) socket_client.bot_name = bot_name # logger.info( # f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})" # ) except SlackApiError as e: # Only error out if we get a not_authed error if "not_authed" in str(e): logger.error( f"Authentication error - Invalid or expired credentials: {tenant_id=} {slack_bot_id=}.
Error: {e}" ) return None # Log other Slack API errors but continue logger.error( f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}" ) except Exception as e: # Log other exceptions but continue logger.error( f"Error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}" ) # Append the event handler process_slack_event = create_process_slack_event() socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore # Establish a WebSocket connection to the Socket Mode servers # logger.debug( # f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}" # ) socket_client.connect() # logger.info( # f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}" # ) return socket_client @staticmethod def stop_socket_clients( pod_id: str, socket_clients: Dict[tuple[str, int], TenantSocketModeClient] ) -> None: socket_client_list = list(socket_clients.items()) length = len(socket_client_list) x = 0 for (tenant_id, slack_bot_id), client in socket_client_list: x += 1 client.close() logger.info( f"Stopped SocketModeClient {x}/{length}: {pod_id=} {tenant_id=} {slack_bot_id=}" ) def shutdown( self, signum: int | None, # noqa: ARG002 frame: FrameType | None, # noqa: ARG002 ) -> None: if not self.running: return logger.info("Shutting down gracefully") self.running = False self._shutdown_event.set() # set the shutdown event # wait for threads to detect the event and exit self.acquire_thread.join(timeout=60.0) self.heartbeat_thread.join(timeout=60.0) # Stop all socket clients logger.info(f"Stopping {len(self.socket_clients)} socket clients") SlackbotHandler.stop_socket_clients(self.pod_id, self.socket_clients) # Release locks for all tenants we currently hold logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants") for tenant_id in list(self.tenant_ids): if tenant_id in self.redis_locks: try: self.redis_locks[tenant_id].release() logger.info(f"Released lock for tenant 
{tenant_id}") except Exception as e: logger.error(f"Error releasing lock for tenant {tenant_id}: {e}") finally: del self.redis_locks[tenant_id] # Wait for background threads to finish (with a timeout) logger.info("Waiting for background threads to finish...") self.acquire_thread.join(timeout=5) self.heartbeat_thread.join(timeout=5) logger.info("Shutdown complete") sys.exit(0) def sanitize_slack_payload(payload: dict) -> dict: """Remove message content from Slack payload for logging""" sanitized = { k: v for k, v in payload.items() if k not in _EXCLUDED_SLACK_PAYLOAD_FIELDS } if "event" in sanitized and isinstance(sanitized["event"], dict): sanitized["event"] = { k: v for k, v in sanitized["event"].items() if k not in _EXCLUDED_SLACK_PAYLOAD_FIELDS } return sanitized def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -> bool: """True to keep going, False to ignore this Slack request""" # skip cases where the bot is disabled in the web UI tenant_id = get_current_tenant_id() bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids( tenant_id, client.web_client ) logger.info(f"prefilter_requests: {bot_token_user_id=} {bot_token_bot_id=}") with get_session_with_current_tenant() as db_session: slack_bot = fetch_slack_bot( db_session=db_session, slack_bot_id=client.slack_bot_id ) if not slack_bot: logger.error( f"Slack bot with ID '{client.slack_bot_id}' not found. Skipping request." ) return False if not slack_bot.enabled: logger.info( f"Slack bot with ID '{client.slack_bot_id}' is disabled. Skipping request." 
) return False if req.type == "events_api": # Verify channel is valid event = cast(dict[str, Any], req.payload.get("event", {})) msg = cast(str | None, event.get("text")) channel = cast(str | None, event.get("channel")) channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel}) # This should never happen, but we can't continue without a channel since # we can't send a response without it if not channel: channel_specific_logger.warning("Found message without channel - skipping") return False if not msg: channel_specific_logger.warning( "Cannot respond to empty message - skipping" ) return False if ( req.payload.setdefault("event", {}).get("user", "") == _OFFICIAL_SLACKBOT_USER_ID ): channel_specific_logger.info( "Ignoring messages from Slack's official Slackbot" ) return False if ( msg in _SLACK_GREETINGS_TO_IGNORE or remove_onyx_bot_tag(tenant_id, msg, client=client.web_client) in _SLACK_GREETINGS_TO_IGNORE ): channel_specific_logger.error( f"Ignoring weird Slack greeting message: '{msg}'" ) channel_specific_logger.error( f"Weird Slack greeting message payload: '{req.payload}'" ) return False # Ensure that the message is a new message of expected type event_type = event.get("type") if event_type not in ["app_mention", "message"]: return False # bot_token_user_id and bot_token_bot_id were already fetched at the top of this function if event_type == "message": is_onyx_bot_msg = False is_tagged = False event_user = event.get("user", "") event_bot_id = event.get("bot_id", "") is_dm = event.get("channel_type") == "im" if bot_token_user_id and f"<@{bot_token_user_id}>" in msg: is_tagged = True if bot_token_user_id and bot_token_user_id in event_user: is_onyx_bot_msg = True if bot_token_bot_id and bot_token_bot_id in event_bot_id: is_onyx_bot_msg = True # OnyxBot should never respond to itself if is_onyx_bot_msg: logger.info("Ignoring message from OnyxBot (self-message)") return False # DMs with the bot don't pick up the @OnyxBot
so we have to keep the # caught events_api if is_tagged and not is_dm: # Let the tag flow handle this case, don't reply twice return False # Check if this is a bot message (either via bot_profile or bot_message subtype) is_bot_message = bool( event.get("bot_profile") or event.get("subtype") == "bot_message" ) if is_bot_message: channel_name, _ = get_channel_name_from_id( client=client.web_client, channel_id=channel ) with get_session_with_current_tenant() as db_session: slack_channel_config = get_slack_channel_config_for_bot_and_channel( db_session=db_session, slack_bot_id=client.slack_bot_id, channel_name=channel_name, ) # If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message if (not bot_token_user_id or bot_token_user_id not in msg) and ( not slack_channel_config or not slack_channel_config.channel_config.get("respond_to_bots") ): channel_specific_logger.info( "Ignoring message from bot since respond_to_bots is disabled" ) return False # Ignore things like channel_join, channel_leave, etc. 
# NOTE: "file_share" is just a message with a file attachment, so we # should not ignore it message_subtype = event.get("subtype") if message_subtype not in [None, "file_share", "bot_message"]: channel_specific_logger.info( f"Ignoring message with subtype '{message_subtype}' since it is a special message type" ) return False message_ts = event.get("ts") thread_ts = event.get("thread_ts") # Pick the root of the thread (if a thread exists) # Can respond in thread if it's an "im" directly to Onyx or @OnyxBot is tagged if ( thread_ts and message_ts != thread_ts and event_type != "app_mention" and event.get("channel_type") != "im" ): channel_specific_logger.debug( "Skipping message since it is not the root of a thread" ) return False msg = cast(str, event.get("text", "")) if not msg: channel_specific_logger.error("Unable to process empty message") return False if req.type == "slash_commands": # Verify that there's an associated channel channel = req.payload.get("channel_id") channel_specific_logger = setup_logger(extra={SLACK_CHANNEL_ID: channel}) if not channel: channel_specific_logger.error( "Received OnyxBot command without channel - skipping" ) return False sender = req.payload.get("user_id") if not sender: channel_specific_logger.error( "Cannot respond to OnyxBot command without sender to respond to." 
) return False if not check_message_limit(): return False # Don't log Slack message content logger.debug( f"Handling Slack request: {client.bot_name=} '{sanitize_slack_payload(req.payload)=}'" ) return True def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) -> None: if actions := req.payload.get("actions"): action = cast(dict[str, Any], actions[0]) feedback_type = cast(str, action.get("action_id")) feedback_msg_reminder = cast(str, action.get("value")) feedback_id = cast(str, action.get("block_id")) channel_id = cast(str, req.payload["container"]["channel_id"]) thread_ts = cast( str, req.payload["container"].get("thread_ts") or req.payload["container"].get("message_ts"), ) else: logger.error("Unable to process feedback. Action not found") return user_id = cast(str, req.payload["user"]["id"]) handle_slack_feedback( feedback_id=feedback_id, feedback_type=feedback_type, feedback_msg_reminder=feedback_msg_reminder, client=client.web_client, user_id_to_post_confirmation=user_id, channel_id_to_post_confirmation=channel_id, thread_ts_to_post_confirmation=thread_ts, ) query_event_id, _, _ = decompose_action_id(feedback_id) logger.info(f"Successfully handled QA feedback for event: {query_event_id}") def build_request_details( req: SocketModeRequest, client: TenantSocketModeClient ) -> SlackMessageInfo: tagged: bool = False tenant_id = get_current_tenant_id() if req.type == "events_api": event = cast(dict[str, Any], req.payload["event"]) msg = cast(str, event["text"]) channel = cast(str, event["channel"]) # Check for both app_mention events and messages containing bot tag bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, client.web_client) message_ts = event.get("ts") thread_ts = event.get("thread_ts") sender_id = event.get("user") or None expert_info = expert_info_from_slack_id( sender_id, client.web_client, user_cache={} ) email = expert_info.email if expert_info else None msg = remove_onyx_bot_tag(tenant_id, msg, client=client.web_client) 
logger.info(f"Received Slack message: {msg}") event_type = event.get("type") if event_type == "app_mention": tagged = True if event_type == "message": if bot_token_user_id: if f"<@{bot_token_user_id}>" in msg: tagged = True if tagged: logger.debug("User tagged OnyxBot") # Build Slack context for federated search # Get proper channel type from Slack API instead of relying on event.channel_type channel_type = get_channel_type_from_id(client.web_client, channel) slack_context = SlackContext( channel_type=channel_type, channel_id=channel, user_id=sender_id or "unknown", message_ts=message_ts, ) logger.info( f"build_request_details: Capturing Slack context: " f"channel_type={channel_type} channel_id={channel} message_ts={message_ts}" ) if thread_ts != message_ts and thread_ts is not None: thread_messages: list[ThreadMessage] = read_slack_thread( tenant_id=tenant_id, channel=channel, thread=thread_ts, client=client.web_client, ) else: sender_display_name = None if expert_info: sender_display_name = expert_info.display_name if sender_display_name is None: sender_display_name = ( f"{expert_info.first_name} {expert_info.last_name}" if expert_info.last_name else expert_info.first_name ) if sender_display_name is None: sender_display_name = expert_info.email thread_messages = [ ThreadMessage( message=msg, sender=sender_display_name, role=MessageType.USER ) ] return SlackMessageInfo( thread_messages=thread_messages, channel_to_respond=channel, msg_to_respond=cast(str, message_ts or thread_ts), thread_to_respond=cast(str, thread_ts or message_ts), sender_id=sender_id, email=email, bypass_filters=tagged, is_slash_command=False, is_bot_dm=event.get("channel_type") == "im", slack_context=slack_context, # Add Slack context for federated search ) elif req.type == "slash_commands": channel = req.payload["channel_id"] channel_name = req.payload["channel_name"] msg = req.payload["text"] sender = req.payload["user_id"] expert_info = expert_info_from_slack_id( sender, client.web_client, 
user_cache={} ) email = expert_info.email if expert_info else None # Get proper channel type for slash commands too channel_type = get_channel_type_from_id(client.web_client, channel) slack_context = SlackContext( channel_type=channel_type, channel_id=channel, user_id=sender, message_ts=None, # Slash commands don't have a message timestamp ) logger.info( f"build_request_details: Capturing Slack context for slash command: channel_type={channel_type} channel_id={channel}" ) single_msg = ThreadMessage(message=msg, sender=None, role=MessageType.USER) return SlackMessageInfo( thread_messages=[single_msg], channel_to_respond=channel, msg_to_respond=None, thread_to_respond=None, sender_id=sender, email=email, bypass_filters=True, is_slash_command=True, is_bot_dm=channel_name == "directmessage", slack_context=slack_context, # Add Slack context for federated search ) raise RuntimeError("Programming fault, this should never happen.") def apologize_for_fail( details: SlackMessageInfo, client: TenantSocketModeClient, ) -> None: respond_in_thread_or_channel( client=client.web_client, channel=details.channel_to_respond, thread_ts=details.msg_to_respond, text="Sorry, we weren't able to find anything relevant :cold_sweat:", ) def process_message( req: SocketModeRequest, client: TenantSocketModeClient, notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER, ) -> None: tenant_id = get_current_tenant_id() if req.type == "events_api": event = cast(dict[str, Any], req.payload["event"]) event_type = event.get("type") logger.info( f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=} {event_type=}" ) else: logger.info( f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=}" ) # Throw out requests that can't or shouldn't be handled if not prefilter_requests(req, client): logger.info( f"process_message prefiltered: {tenant_id=} {req.type=} {req.envelope_id=}" ) return details = build_request_details(req, client) channel = details.channel_to_respond channel_name, 
is_dm = get_channel_name_from_id( client=client.web_client, channel_id=channel ) with get_session_with_current_tenant() as db_session: slack_channel_config = get_slack_channel_config_for_bot_and_channel( db_session=db_session, slack_bot_id=client.slack_bot_id, channel_name=channel_name, ) follow_up = bool( slack_channel_config.channel_config and slack_channel_config.channel_config.get("follow_up_tags") is not None ) feedback_reminder_id = schedule_feedback_reminder( details=details, client=client.web_client, include_followup=follow_up ) failed = handle_message( message_info=details, slack_channel_config=slack_channel_config, client=client.web_client, feedback_reminder_id=feedback_reminder_id, ) if failed: if feedback_reminder_id: remove_scheduled_feedback_reminder( client=client.web_client, channel=details.sender_id, msg_id=feedback_reminder_id, ) # Skipping answering due to pre-filtering is not considered a failure if notify_no_answer: apologize_for_fail(details, client) logger.info( f"process_message finished: success={not failed} {tenant_id=} {req.type=} {req.envelope_id=}" ) def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None: response = SocketModeResponse(envelope_id=req.envelope_id) client.send_socket_mode_response(response) def action_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None: if actions := req.payload.get("actions"): action = cast(dict[str, Any], actions[0]) if action["action_id"] in [DISLIKE_BLOCK_ACTION_ID, LIKE_BLOCK_ACTION_ID]: # AI Answer feedback return process_feedback(req, client) elif action["action_id"] in [ SHOW_EVERYONE_ACTION_ID, KEEP_TO_YOURSELF_ACTION_ID, ]: # Publish ephemeral message or keep hidden in main channel return handle_publish_ephemeral_message_button( req, client, action["action_id"] ) elif action["action_id"] == FEEDBACK_DOC_BUTTON_BLOCK_ACTION_ID: # Activation of the "source feedback" button return handle_doc_feedback_button(req, client) elif action["action_id"] 
== FOLLOWUP_BUTTON_ACTION_ID: return handle_followup_button(req, client) elif action["action_id"] == IMMEDIATE_RESOLVED_BUTTON_ACTION_ID: return handle_followup_resolved_button(req, client, immediate=True) elif action["action_id"] == FOLLOWUP_BUTTON_RESOLVED_ACTION_ID: return handle_followup_resolved_button(req, client, immediate=False) elif action["action_id"] == GENERATE_ANSWER_BUTTON_ACTION_ID: return handle_generate_answer_button(req, client) def view_routing(req: SocketModeRequest, client: TenantSocketModeClient) -> None: if view := req.payload.get("view"): if view["callback_id"] == VIEW_DOC_FEEDBACK_ID: return process_feedback(req, client) def _extract_channel_from_request(req: SocketModeRequest) -> str | None: """Best-effort channel extraction from any Slack request type.""" if req.type == "events_api": return cast(dict[str, Any], req.payload.get("event", {})).get("channel") elif req.type == "slash_commands": return req.payload.get("channel_id") elif req.type == "interactive": container = req.payload.get("container", {}) return container.get("channel_id") or req.payload.get("channel", {}).get("id") return None def _check_tenant_gated(client: TenantSocketModeClient, req: SocketModeRequest) -> bool: """Check if the current tenant is gated (suspended or license expired). Multi-tenant: checks the gated tenants Redis set (populated by control plane). Self-hosted: checks the cached license metadata for expiry. Returns True if blocked. 
""" from onyx.server.settings.models import ApplicationStatus # Multi-tenant path: control plane marks gated tenants in Redis is_gated: bool = fetch_ee_implementation_or_noop( "onyx.server.tenants.product_gating", "is_tenant_gated", False, )(get_current_tenant_id()) # Self-hosted path: check license metadata cache if not is_gated: get_cached_metadata = fetch_ee_implementation_or_noop( "onyx.db.license", "get_cached_license_metadata", None, ) metadata = get_cached_metadata() if metadata is not None: if metadata.status == ApplicationStatus.GATED_ACCESS: is_gated = True if not is_gated: return False # Only notify once per user action: # - Skip bot messages (avoids feedback loop from our own response) # - Skip app_mention events (Slack fires both app_mention AND message # for @mentions; we respond on the message event only) event = req.payload.get("event", {}) if req.type == "events_api" else {} is_bot_event = bool( event.get("bot_id") or event.get("bot_profile") or event.get("subtype") == "bot_message" ) is_duplicate_mention = event.get("type") == "app_mention" if not is_bot_event and not is_duplicate_mention: channel = _extract_channel_from_request(req) thread_ts = event.get("thread_ts") or event.get("ts") if channel: respond_in_thread_or_channel( client=client.web_client, channel=channel, thread_ts=thread_ts, text=( "Your organization's subscription has expired. Please contact your Onyx administrator to restore access." ), ) logger.info(f"Blocked Slack request for gated tenant {get_current_tenant_id()}") return True def create_process_slack_event() -> ( Callable[[TenantSocketModeClient, SocketModeRequest], None] ): def process_slack_event( client: TenantSocketModeClient, req: SocketModeRequest ) -> None: # Always respond right away, if Slack doesn't receive these frequently enough # it will assume the Bot is DEAD!!! 
:( acknowledge_message(req, client) if _check_tenant_gated(client, req): return try: if req.type == "interactive": if req.payload.get("type") == "block_actions": return action_routing(req, client) elif req.payload.get("type") == "view_submission": return view_routing(req, client) elif req.type == "events_api" or req.type == "slash_commands": return process_message(req, client) except Exception: logger.exception("Failed to process slack event") return process_slack_event def _get_socket_client( slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int ) -> TenantSocketModeClient: # For more info on how to set this up, check out the docs: # https://docs.onyx.app/admins/getting_started/slack_bot_setup # Use the retry handlers built into the Slack SDK connection_error_retry_handler = ConnectionErrorRetryHandler() rate_limit_error_retry_handler = RateLimitErrorRetryHandler(max_retry_count=7) slack_retry_handlers: list[RetryHandler] = [ connection_error_retry_handler, rate_limit_error_retry_handler, ] return TenantSocketModeClient( # This app-level token will be used only for establishing a connection app_token=slack_bot_tokens.app_token, web_client=WebClient( token=slack_bot_tokens.bot_token, retry_handlers=slack_retry_handlers ), tenant_id=tenant_id, slack_bot_id=slack_bot_id, ) if __name__ == "__main__": # Initialize the SqlEngine SqlEngine.init_engine(pool_size=20, max_overflow=5) # Initialize the tenant handler which will manage tenant connections logger.info("Starting SlackbotHandler") tenant_handler = SlackbotHandler() set_is_ee_based_on_env_variable() try: # Keep the main thread alive while tenant_handler.running: time.sleep(1) except Exception: logger.exception("Fatal error in main thread") tenant_handler.shutdown(None, None) ================================================ FILE: backend/onyx/onyxbot/slack/models.py ================================================ from enum import Enum from typing import Literal from pydantic import BaseModel from
onyx.configs.constants import MessageType class ChannelType(str, Enum): """Slack channel types.""" IM = "im" # Direct message MPIM = "mpim" # Multi-person direct message PRIVATE_CHANNEL = "private_channel" # Private channel PUBLIC_CHANNEL = "public_channel" # Public channel UNKNOWN = "unknown" # Unknown channel type class SlackContext(BaseModel): """Context information for Slack bot interactions.""" channel_type: ChannelType channel_id: str user_id: str message_ts: str | None = None # Used as request ID for log correlation class ThreadMessage(BaseModel): message: str sender: str | None = None role: MessageType = MessageType.USER class SlackMessageInfo(BaseModel): thread_messages: list[ThreadMessage] channel_to_respond: str msg_to_respond: str | None thread_to_respond: str | None sender_id: str | None email: str | None bypass_filters: bool # User has tagged @OnyxBot is_slash_command: bool # User is using /OnyxBot is_bot_dm: bool # User is direct messaging to OnyxBot slack_context: SlackContext | None = None # Models used to encode the relevant data for the ephemeral message actions class ActionValuesEphemeralMessageMessageInfo(BaseModel): bypass_filters: bool | None channel_to_respond: str | None msg_to_respond: str | None email: str | None sender_id: str | None thread_messages: list[ThreadMessage] | None is_slash_command: bool | None is_bot_dm: bool | None thread_to_respond: str | None class ActionValuesEphemeralMessageChannelConfig(BaseModel): channel_name: str | None respond_tag_only: bool | None respond_to_bots: bool | None is_ephemeral: bool respond_member_group_list: list[str] | None answer_filters: ( list[Literal["well_answered_postfilter", "questionmark_prefilter"]] | None ) follow_up_tags: list[str] | None show_continue_in_web_ui: bool class ActionValuesEphemeralMessage(BaseModel): original_question_ts: str | None feedback_reminder_id: str | None chat_message_id: int message_info: ActionValuesEphemeralMessageMessageInfo channel_conf: 
ActionValuesEphemeralMessageChannelConfig ================================================ FILE: backend/onyx/onyxbot/slack/utils.py ================================================ import logging import random import re import string import threading import time import uuid from collections.abc import Generator from contextlib import contextmanager from typing import Any from typing import cast from retry import retry from slack_sdk import WebClient from slack_sdk.errors import SlackApiError from slack_sdk.models.blocks import Block from slack_sdk.models.blocks import SectionBlock from slack_sdk.models.metadata import Metadata from slack_sdk.socket_mode import SocketModeClient from onyx.configs.app_configs import DISABLE_TELEMETRY from onyx.configs.constants import ID_SEPARATOR from onyx.configs.constants import MessageType from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_VISIBILITY from onyx.configs.onyxbot_configs import ONYX_BOT_MAX_QPM from onyx.configs.onyxbot_configs import ONYX_BOT_MAX_WAIT_TIME from onyx.configs.onyxbot_configs import ONYX_BOT_NUM_RETRIES from onyx.configs.onyxbot_configs import ( ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD, ) from onyx.configs.onyxbot_configs import ( ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS, ) from onyx.connectors.slack.utils import SlackTextCleaner from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.users import get_user_by_email from onyx.onyxbot.slack.constants import FeedbackVisibility from onyx.onyxbot.slack.models import ChannelType from onyx.onyxbot.slack.models import ThreadMessage from onyx.utils.logger import setup_logger from onyx.utils.telemetry import optional_telemetry from onyx.utils.telemetry import RecordType from onyx.utils.text_processing import replace_whitespaces_w_space from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR logger = setup_logger() slack_token_user_ids: dict[str, str | None] = {} slack_token_bot_ids: dict[str, str | None] = {} 
slack_token_lock = threading.Lock() _ONYX_BOT_MESSAGE_COUNT: int = 0 _ONYX_BOT_COUNT_START_TIME: float = time.time() def get_onyx_bot_auth_ids( tenant_id: str, web_client: WebClient ) -> tuple[str | None, str | None]: """Returns a tuple of user_id and bot_id.""" user_id: str | None bot_id: str | None global slack_token_user_ids global slack_token_bot_ids with slack_token_lock: user_id = slack_token_user_ids.get(tenant_id) bot_id = slack_token_bot_ids.get(tenant_id) if user_id is None or bot_id is None: response = web_client.auth_test() user_id = response.get("user_id") bot_id = response.get("bot_id") with slack_token_lock: slack_token_user_ids[tenant_id] = user_id slack_token_bot_ids[tenant_id] = bot_id return user_id, bot_id def get_channel_type_from_id(web_client: WebClient, channel_id: str) -> ChannelType: """ Get the channel type from a channel ID using Slack API. Returns: ChannelType enum value """ try: channel_info = web_client.conversations_info(channel=channel_id) if channel_info.get("ok") and channel_info.get("channel"): channel: dict[str, Any] = channel_info.get("channel", {}) if channel.get("is_im"): return ChannelType.IM # Direct message elif channel.get("is_mpim"): return ChannelType.MPIM # Multi-person direct message elif channel.get("is_private"): return ChannelType.PRIVATE_CHANNEL # Private channel elif channel.get("is_channel"): return ChannelType.PUBLIC_CHANNEL # Public channel else: logger.warning( f"Could not determine channel type for {channel_id}, defaulting to unknown" ) return ChannelType.UNKNOWN else: logger.warning(f"Invalid channel info response for {channel_id}") return ChannelType.UNKNOWN except Exception as e: logger.warning( f"Error getting channel info for {channel_id}, defaulting to unknown: {e}" ) return ChannelType.UNKNOWN def check_message_limit() -> bool: """ This isn't a perfect solution: high traffic at the end of one period and the start of the next could cause the limit to be exceeded.
""" if ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD <= 0: return True global _ONYX_BOT_MESSAGE_COUNT global _ONYX_BOT_COUNT_START_TIME time_since_start = time.time() - _ONYX_BOT_COUNT_START_TIME if time_since_start > ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS: _ONYX_BOT_MESSAGE_COUNT = 0 _ONYX_BOT_COUNT_START_TIME = time.time() if (_ONYX_BOT_MESSAGE_COUNT + 1) > ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD: logger.error( f"OnyxBot has reached the message limit {ONYX_BOT_RESPONSE_LIMIT_PER_TIME_PERIOD}" f" for the time period {ONYX_BOT_RESPONSE_LIMIT_TIME_PERIOD_SECONDS} seconds." " These limits are configurable in backend/onyx/configs/onyxbot_configs.py" ) return False _ONYX_BOT_MESSAGE_COUNT += 1 return True def update_emote_react( emoji: str, channel: str, message_ts: str | None, remove: bool, client: WebClient, ) -> None: if not message_ts: action = "remove" if remove else "add" logger.error(f"update_emote_react - no message specified: {channel=} {action=}") return if remove: try: client.reactions_remove( name=emoji, channel=channel, timestamp=message_ts, ) except SlackApiError as e: logger.error(f"Failed to remove Reaction due to: {e}") return try: client.reactions_add( name=emoji, channel=channel, timestamp=message_ts, ) except SlackApiError as e: logger.error(f"Was not able to react to user message due to: {e}") return def remove_onyx_bot_tag(tenant_id: str, message_str: str, client: WebClient) -> str: bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, web_client=client) return re.sub(rf"<@{bot_token_user_id}>\s*", "", message_str) def _check_for_url_in_block(block: Block) -> bool: """ Check if the block has a key that contains "url" in it """ block_dict = block.to_dict() def check_dict_for_url(d: dict) -> bool: for key, value in d.items(): if "url" in key.lower(): return True if isinstance(value, dict): if check_dict_for_url(value): return True elif isinstance(value, list): for item in value: if isinstance(item, dict) and check_dict_for_url(item): return True 
return False return check_dict_for_url(block_dict) def _build_error_block(error_message: str) -> Block: """ Build an error block to display in slack so that the user can see the error without completely breaking """ display_text = ( "There was an error displaying all of the Onyx answers." f" Please let an admin or an onyx developer know. Error: {error_message}" ) return SectionBlock(text=display_text) @retry( tries=ONYX_BOT_NUM_RETRIES, delay=0.25, backoff=2, logger=cast(logging.Logger, logger), ) def respond_in_thread_or_channel( client: WebClient, channel: str, thread_ts: str | None, text: str | None = None, blocks: list[Block] | None = None, receiver_ids: list[str] | None = None, metadata: Metadata | None = None, unfurl: bool = True, send_as_ephemeral: bool | None = True, # noqa: ARG001 ) -> list[str]: if not text and not blocks: raise ValueError("One of `text` or `blocks` must be provided") message_ids: list[str] = [] if not receiver_ids: try: response = client.chat_postMessage( channel=channel, text=text, blocks=blocks, thread_ts=thread_ts, metadata=metadata, unfurl_links=unfurl, unfurl_media=unfurl, ) except Exception as e: blocks_str = str(blocks)[:1024] # truncate block logging logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}") logger.warning("Trying again without blocks that have urls") if not blocks: raise e blocks_without_urls = [ block for block in blocks if not _check_for_url_in_block(block) ] blocks_without_urls.append(_build_error_block(str(e))) # Try again without blocks containing URLs response = client.chat_postMessage( channel=channel, text=text, blocks=blocks_without_urls, thread_ts=thread_ts, metadata=metadata, unfurl_links=unfurl, unfurl_media=unfurl, ) message_ids.append(response["message_ts"]) else: for receiver in receiver_ids: try: response = client.chat_postEphemeral( channel=channel, user=receiver, text=text, blocks=blocks, thread_ts=thread_ts, metadata=metadata, unfurl_links=unfurl, unfurl_media=unfurl, ) except
Exception as e: blocks_str = str(blocks)[:1024] # truncate block logging logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}") logger.warning("Trying again without blocks that have urls") if not blocks: raise e blocks_without_urls = [ block for block in blocks if not _check_for_url_in_block(block) ] blocks_without_urls.append(_build_error_block(str(e))) # Try again without blocks containing URLs response = client.chat_postEphemeral( channel=channel, user=receiver, text=text, blocks=blocks_without_urls, thread_ts=thread_ts, metadata=metadata, unfurl_links=unfurl, unfurl_media=unfurl, ) message_ids.append(response["message_ts"]) return message_ids def build_feedback_id( message_id: int, document_id: str | None = None, document_rank: int | None = None, ) -> str: unique_prefix = "".join(random.choice(string.ascii_letters) for _ in range(10)) if document_id is not None: if not document_id or document_rank is None: raise ValueError("Invalid document, missing information") if ID_SEPARATOR in document_id: raise ValueError( "Separator pattern should not already exist in document id" ) feedback_id = ID_SEPARATOR.join( [str(message_id), document_id, str(document_rank)] ) else: feedback_id = str(message_id) return unique_prefix + ID_SEPARATOR + feedback_id def build_publish_ephemeral_message_id( original_question_ts: str, ) -> str: return "publish_ephemeral_message__" + original_question_ts def build_continue_in_web_ui_id( message_id: int, ) -> str: unique_prefix = str(uuid.uuid4())[:10] return unique_prefix + ID_SEPARATOR + str(message_id) def decompose_action_id(feedback_id: str) -> tuple[int, str | None, int | None]: """Decompose an ID built by build_feedback_id into (message_id, document_id, document_rank)""" try: components = feedback_id.split(ID_SEPARATOR) if len(components) != 2 and len(components) != 4: raise ValueError("Feedback ID does not contain right number of elements") if len(components) == 2: return int(components[-1]), None, None return int(components[1]),
components[2], int(components[3]) except Exception as e: logger.error(e) raise ValueError("Received invalid Feedback Identifier") def get_view_values(state_values: dict[str, Any]) -> dict[str, str]: """Extract view values Args: state_values (dict): The Slack view-submission values Returns: dict: keys/values of the view state content """ view_values = {} for _, view_data in state_values.items(): for k, v in view_data.items(): if ( "selected_option" in v and isinstance(v["selected_option"], dict) and "value" in v["selected_option"] ): view_values[k] = v["selected_option"]["value"] elif "selected_options" in v and isinstance(v["selected_options"], list): view_values[k] = [ x["value"] for x in v["selected_options"] if "value" in x ] elif "selected_date" in v: view_values[k] = v["selected_date"] elif "value" in v: view_values[k] = v["value"] return view_values def translate_vespa_highlight_to_slack(match_strs: list[str], used_chars: int) -> str: def _replace_highlight(s: str) -> str: s = re.sub(r"(?<=[^\s])(.*?)", r"\1", s) s = s.replace("", "*").replace("", "*") return s final_matches = [ replace_whitespaces_w_space(_replace_highlight(match_str)).strip() for match_str in match_strs if match_str ] combined = "... ".join(final_matches) # Slack introduces "Show More" after 300 on desktop which is ugly # But don't trim the message if there is still a highlight after 300 chars remaining = 300 - used_chars if len(combined) > remaining and "*" not in combined[remaining:]: combined = combined[: remaining - 3] + "..." 
return combined def remove_slack_text_interactions(slack_str: str) -> str: slack_str = SlackTextCleaner.replace_tags_basic(slack_str) slack_str = SlackTextCleaner.replace_channels_basic(slack_str) slack_str = SlackTextCleaner.replace_special_mentions(slack_str) slack_str = SlackTextCleaner.replace_special_catchall(slack_str) slack_str = SlackTextCleaner.add_zero_width_whitespace_after_tag(slack_str) return slack_str def get_channel_from_id(client: WebClient, channel_id: str) -> dict[str, Any]: response = client.conversations_info(channel=channel_id) response.validate() return response["channel"] def get_channel_name_from_id( client: WebClient, channel_id: str ) -> tuple[str | None, bool]: try: channel_info = get_channel_from_id(client, channel_id) name = channel_info.get("name") is_dm = any([channel_info.get("is_im"), channel_info.get("is_mpim")]) return name, is_dm except SlackApiError as e: logger.exception(f"Couldn't fetch channel name from id: {channel_id}") raise e def fetch_slack_user_ids_from_emails( user_emails: list[str], client: WebClient ) -> tuple[list[str], list[str]]: user_ids: list[str] = [] failed_to_find: list[str] = [] for email in user_emails: try: user = client.users_lookupByEmail(email=email) user_ids.append(user.data["user"]["id"]) # type: ignore except Exception: logger.error(f"Was not able to find slack user by email: {email}") failed_to_find.append(email) return user_ids, failed_to_find def fetch_user_ids_from_groups( given_names: list[str], client: WebClient ) -> tuple[list[str], list[str]]: user_ids: list[str] = [] failed_to_find: list[str] = [] try: response = client.usergroups_list() if not isinstance(response.data, dict): logger.error("Error fetching user groups") return user_ids, given_names all_group_data = response.data.get("usergroups", []) name_id_map = {d["name"]: d["id"] for d in all_group_data} handle_id_map = {d["handle"]: d["id"] for d in all_group_data} for given_name in given_names: group_id = name_id_map.get(given_name) or 
handle_id_map.get( given_name.lstrip("@") ) if not group_id: failed_to_find.append(given_name) continue try: response = client.usergroups_users_list(usergroup=group_id) if isinstance(response.data, dict): user_ids.extend(response.data.get("users", [])) else: failed_to_find.append(given_name) except Exception as e: logger.error(f"Error fetching user group ids: {str(e)}") failed_to_find.append(given_name) except Exception as e: logger.error(f"Error fetching user groups: {str(e)}") failed_to_find = given_names return user_ids, failed_to_find def fetch_group_ids_from_names( given_names: list[str], client: WebClient ) -> tuple[list[str], list[str]]: group_data: list[str] = [] failed_to_find: list[str] = [] try: response = client.usergroups_list() if not isinstance(response.data, dict): logger.error("Error fetching user groups") return group_data, given_names all_group_data = response.data.get("usergroups", []) name_id_map = {d["name"]: d["id"] for d in all_group_data} handle_id_map = {d["handle"]: d["id"] for d in all_group_data} for given_name in given_names: id = handle_id_map.get(given_name.lstrip("@")) id = id or name_id_map.get(given_name) if id: group_data.append(id) else: failed_to_find.append(given_name) except Exception as e: failed_to_find = given_names logger.error(f"Error fetching user groups: {str(e)}") return group_data, failed_to_find def fetch_user_semantic_id_from_id( user_id: str | None, client: WebClient ) -> str | None: if not user_id: return None response = client.users_info(user=user_id) if not response["ok"]: return None user: dict = cast(dict[Any, dict], response.data).get("user", {}) return ( user.get("real_name") or user.get("name") or user.get("profile", {}).get("email") ) def read_slack_thread( tenant_id: str, channel: str, thread: str, client: WebClient ) -> list[ThreadMessage]: thread_messages: list[ThreadMessage] = [] response = client.conversations_replies(channel=channel, ts=thread) replies = cast(dict, response.data).get("messages", []) 
for reply in replies: if "user" in reply and "bot_id" not in reply: message = reply["text"] user_sem_id = ( fetch_user_semantic_id_from_id(reply.get("user"), client) or "Unknown User" ) message_type = MessageType.USER else: blocks: Any is_onyx_bot_response = False reply_user = reply.get("user") reply_bot_id = reply.get("bot_id") self_slack_bot_user_id, self_slack_bot_bot_id = get_onyx_bot_auth_ids( tenant_id, client ) if reply_user is not None and reply_user == self_slack_bot_user_id: is_onyx_bot_response = True if reply_bot_id is not None and reply_bot_id == self_slack_bot_bot_id: is_onyx_bot_response = True if is_onyx_bot_response: # OnyxBot response message_type = MessageType.ASSISTANT user_sem_id = "Assistant" # OnyxBot responses have both text and blocks # The useful content is in the blocks, specifically the first block unless there are # auto-detected filters blocks = reply.get("blocks") if not blocks: logger.warning(f"OnyxBot response has no blocks: {reply}") continue message = blocks[0].get("text", {}).get("text") # If auto-detected filters are on, use the second block for the actual answer # The first block is the auto-detected filters if message is not None and message.startswith("_Filters"): if len(blocks) < 2: logger.warning(f"Only filter blocks found: {reply}") continue # This is the OnyxBot answer format, if there is a change to how we respond, # this will need to be updated to get the correct "answer" portion message = reply["blocks"][1].get("text", {}).get("text") else: # Other bots are not counted as the LLM response which only comes from Onyx message_type = MessageType.USER bot_user_name = fetch_user_semantic_id_from_id( reply.get("user"), client ) user_sem_id = bot_user_name or "Unknown Bot" # For other bots, just use the text, since we have no way of knowing which # portion of their blocks is useful message = reply.get("text") if not message: # Fall back to this reply's own first block; `blocks` is otherwise unbound # here (or stale from an earlier reply) blocks = reply.get("blocks") if blocks: message = blocks[0].get("text", {}).get("text") if not message: logger.warning("Skipping Slack thread message, no
text found") continue message = remove_onyx_bot_tag(tenant_id, message, client=client) thread_messages.append( ThreadMessage(message=message, sender=user_sem_id, role=message_type) ) return thread_messages def slack_usage_report(action: str, sender_id: str | None, client: WebClient) -> None: if DISABLE_TELEMETRY: return onyx_user = None sender_email = None try: sender_email = client.users_info(user=sender_id).data["user"]["profile"]["email"] # type: ignore except Exception: logger.warning("Unable to find sender email") if sender_email is not None: with get_session_with_current_tenant() as db_session: onyx_user = get_user_by_email(email=sender_email, db_session=db_session) optional_telemetry( record_type=RecordType.USAGE, data={"action": action}, user_id=str(onyx_user.id) if onyx_user else "Non-Onyx-Or-No-Auth-User", ) class SlackRateLimiter: def __init__(self) -> None: self.max_qpm: int | None = ONYX_BOT_MAX_QPM self.max_wait_time = ONYX_BOT_MAX_WAIT_TIME self.active_question = 0 self.last_reset_time = time.time() self.waiting_questions: list[int] = [] def refill(self) -> None: # If elapsed time is greater than the period, reset the active question count if (time.time() - self.last_reset_time) > 60: self.active_question = 0 self.last_reset_time = time.time() def notify( self, client: WebClient, channel: str, position: int, thread_ts: str | None ) -> None: respond_in_thread_or_channel( client=client, channel=channel, receiver_ids=None, text=f"Your question has been queued. 
You are in position {position}.\nPlease wait a moment :hourglass_flowing_sand:", thread_ts=thread_ts, ) def is_available(self) -> bool: if self.max_qpm is None: return True self.refill() return self.active_question < self.max_qpm def acquire_slot(self) -> None: self.active_question += 1 def init_waiter(self) -> tuple[int, int]: func_randid = random.getrandbits(128) self.waiting_questions.append(func_randid) position = self.waiting_questions.index(func_randid) + 1 return func_randid, position def waiter(self, func_randid: int) -> None: if self.max_qpm is None: return wait_time = 0 while ( self.active_question >= self.max_qpm or self.waiting_questions[0] != func_randid ): if wait_time > self.max_wait_time: raise TimeoutError time.sleep(2) wait_time += 2 self.refill() del self.waiting_questions[0] def get_feedback_visibility() -> FeedbackVisibility: try: return FeedbackVisibility(ONYX_BOT_FEEDBACK_VISIBILITY.lower()) except ValueError: return FeedbackVisibility.PRIVATE class TenantSocketModeClient(SocketModeClient): def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any): super().__init__(*args, **kwargs) self._tenant_id = tenant_id self.slack_bot_id = slack_bot_id self.bot_name: str = "Unnamed" @contextmanager def _set_tenant_context(self) -> Generator[None, None, None]: token = None try: if self._tenant_id: token = CURRENT_TENANT_ID_CONTEXTVAR.set(self._tenant_id) yield finally: if token: CURRENT_TENANT_ID_CONTEXTVAR.reset(token) def enqueue_message(self, message: str) -> None: with self._set_tenant_context(): super().enqueue_message(message) def process_message(self) -> None: with self._set_tenant_context(): super().process_message() def run_message_listeners(self, message: dict, raw_message: str) -> None: with self._set_tenant_context(): super().run_message_listeners(message, raw_message) ================================================ FILE: backend/onyx/prompts/__init__.py ================================================ 
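`check_message_limit()` in `backend/onyx/onyxbot/slack/utils.py` implements a fixed-window counter, and its docstring warns that bursts straddling a window boundary can exceed the limit. Below is a minimal standalone sketch of the same counting logic, not the Onyx implementation: the class name `FixedWindowLimiter` and the injectable `clock` parameter are hypothetical, added so the boundary effect can be reproduced deterministically.

```python
import time
from collections.abc import Callable


class FixedWindowLimiter:
    """Hypothetical sketch of the fixed-window counting in check_message_limit().

    Allows at most `limit` calls per `period_seconds`; a limit <= 0 disables limiting.
    """

    def __init__(
        self,
        limit: int,
        period_seconds: float,
        clock: Callable[[], float] = time.time,
    ) -> None:
        self.limit = limit
        self.period_seconds = period_seconds
        self.clock = clock  # injectable so the window can be tested deterministically
        self.count = 0
        self.window_start = clock()

    def allow(self) -> bool:
        if self.limit <= 0:
            return True  # mirrors the "<= 0 means unlimited" early return
        now = self.clock()
        if now - self.window_start > self.period_seconds:
            # The window has expired: start a new one and reset the counter
            self.count = 0
            self.window_start = now
        if self.count + 1 > self.limit:
            return False
        self.count += 1
        return True


# Simulated clock reproducing the boundary caveat from the docstring
fake_now = [0.0]
limiter = FixedWindowLimiter(limit=2, period_seconds=60, clock=lambda: fake_now[0])
fake_now[0] = 59.0  # end of the first window
burst_end = [limiter.allow(), limiter.allow(), limiter.allow()]
fake_now[0] = 61.0  # just after the window resets
burst_start = [limiter.allow(), limiter.allow()]
print(burst_end, burst_start)
```

With a limit of 2 per 60 seconds, four calls between t=59 and t=61 are accepted, which is exactly the boundary effect the docstring describes. A sliding-window log or token bucket would avoid it at the cost of extra state.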
================================================ FILE: backend/onyx/prompts/basic_memory.py ================================================ # ruff: noqa: E501, W605 start # Note that the user_basic_information is only included if we have at least 1 of the following: user_name, user_email, user_role # This is included because sometimes we need to know the user's name or basic info to best generate the memory. FULL_MEMORY_UPDATE_PROMPT = """ You are a memory update agent that helps the user add or update memories. You are given a list of existing memories and a new memory to add. \ Just as context, you are also given the last few user messages from the conversation which generated the new memory. You must determine if the memory is brand new or if it is related to an existing memory. \ If the new memory is an update to an existing memory or contradicts an existing memory, it should be treated as an update and you should reference the existing memory by memory_id (see below). \ The memory should omit the user's name and direct reference to the user - for example, a memory like "Yuhong prefers dark mode." should be modified to "Prefers dark mode." (if the user's name is Yuhong). 
# Truncated chat history {chat_history}{user_basic_information} # User's existing memories {existing_memories} # New memory the user wants to insert {new_memory} # Response Style You MUST respond in a json which follows the following format and keys: ```json {{ "operation": "add or update", "memory_id": "if the operation is update, the id of the memory to update, otherwise null", "memory_text": "the text of the memory to add or update" }} ``` """.strip() # ruff: noqa: E501, W605 end MEMORY_USER_BASIC_INFORMATION_PROMPT = """ # User Basic Information User name: {user_name} User email: {user_email} User role: {user_role} """ ================================================ FILE: backend/onyx/prompts/chat_prompts.py ================================================ # ruff: noqa: E501, W605 start from onyx.prompts.constants import REMINDER_TAG_NO_HEADER DATETIME_REPLACEMENT_PAT = "{{CURRENT_DATETIME}}" CITATION_GUIDANCE_REPLACEMENT_PAT = "{{CITATION_GUIDANCE}}" REMINDER_TAG_REPLACEMENT_PAT = "{{REMINDER_TAG_DESCRIPTION}}" # Note this uses a string pattern replacement so the user can also include it in their custom prompts. Keeps the replacement logic simple # This is editable by the user in the admin UI. # The first line is intended to help guide the general feel/behavior of the system. DEFAULT_SYSTEM_PROMPT = f""" You are an expert assistant who is truthful, nuanced, insightful, and efficient. \ Your goal is to deeply understand the user's intent, think step-by-step through complex problems, provide clear and accurate answers, and proactively anticipate helpful follow-up information. \ Whenever there is any ambiguity around the user's query (or more information would be helpful), you use available tools (if any) to get more context. 
The current date is {DATETIME_REPLACEMENT_PAT}.{CITATION_GUIDANCE_REPLACEMENT_PAT} # Response Style You use different text styles, bolding, emojis (sparingly), block quotes, and other formatting to make your responses more readable and engaging. You use proper Markdown and LaTeX to format your responses for math, scientific, and chemical formulas, symbols, etc.: '$$\\n[expression]\\n$$' for standalone cases and '\\( [expression] \\)' when inline. For code you prefer to use Markdown and specify the language. You can use horizontal rules (---) to separate sections of your responses. You can use Markdown tables to format your responses for data, lists, and other structured information. {REMINDER_TAG_REPLACEMENT_PAT} """.lstrip() COMPANY_NAME_BLOCK = """ The user is at an organization called `{company_name}`. """ COMPANY_DESCRIPTION_BLOCK = """ Organization description: {company_description} """ # This is added to the system prompt prior to the tools section and is applied only if search tools have been run REQUIRE_CITATION_GUIDANCE = """ CRITICAL: If referencing knowledge from searches, cite relevant statements INLINE using the format [1], [2], [3], etc. to reference the "document" field. \ DO NOT provide any links following the citations. Cite inline as opposed to leaving all citations until the very end of the response. """ # Reminder message if any search tool has been run anytime in the chat turn CITATION_REMINDER = """ Remember to provide inline citations in the format [1], [2], [3], etc. based on the "document" field of the documents. """.strip() LAST_CYCLE_CITATION_REMINDER = """ You are on your last cycle and no longer have any tool calls available. You must answer the query now to the best of your ability. 
""".strip() # Reminder message that replaces the usual reminder if web_search was the last tool call OPEN_URL_REMINDER = """ Remember that after using web_search, you are encouraged to open some pages to get more context unless the query is completely answered by the snippets. Open the pages that look the most promising and high quality by calling the open_url tool with an array of URLs. Open as many as you want. If you do have enough to answer, remember to provide INLINE citations using the "document" field in the format [1], [2], [3], etc. """.strip() IMAGE_GEN_REMINDER = """ Very briefly describe the image(s) generated. Do not include any links or attachments. """.strip() FILE_REMINDER = """ Your code execution generated file(s) with download links. If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result. """.strip() # Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling CODE_BLOCK_MARKDOWN = "Formatting re-enabled. " # This is just for Slack context today ADDITIONAL_CONTEXT_PROMPT = """ Here is some additional context which may be relevant to the user query: {additional_context} """.strip() TOOL_CALL_RESPONSE_CROSS_MESSAGE = """ This tool call completed but the results are no longer accessible. """.strip() # This is used to add the current date and time to the prompt in the case where the Agent should be aware of the current # date and time but the replacement pattern is not present in the prompt. ADDITIONAL_INFO = "\n\nAdditional Information:\n\t- {datetime_info}." CHAT_NAMING_SYSTEM_PROMPT = f""" Given the conversation history, provide a SHORT name for the conversation. Focus the name on the important keywords to convey the topic of the conversation. \ Make sure the name is in the same language as the user's first message. {REMINDER_TAG_NO_HEADER} IMPORTANT: DO NOT OUTPUT ANYTHING ASIDE FROM THE NAME. 
MAKE IT AS CONCISE AS POSSIBLE. NEVER USE MORE THAN 5 WORDS, LESS IS FINE. """.strip() CHAT_NAMING_REMINDER = """ Provide a short name for the conversation. Refer to other messages in the conversation (not including this one) to determine the language of the name. IMPORTANT: DO NOT OUTPUT ANYTHING ASIDE FROM THE NAME. MAKE IT AS CONCISE AS POSSIBLE. NEVER USE MORE THAN 5 WORDS, LESS IS FINE. """.strip() # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/chat_tools.py ================================================ # These prompts are to support tool calling. Currently not used in the main flow or via any configs # The current generation of LLM is too unreliable for this task. # Onyx retrieval call as a tool option DANSWER_TOOL_NAME = "Current Search" DANSWER_TOOL_DESCRIPTION = "A search tool that can find information on any topic including up to date and proprietary knowledge." # Tool calling format inspired from LangChain TOOL_TEMPLATE = """ TOOLS ------ You can use tools to look up information that may be helpful in answering the user's \ original question. The available tools are: {tool_overviews} RESPONSE FORMAT INSTRUCTIONS ---------------------------- When responding to me, please output a response in one of two formats: **Option 1:** Use this if you want to use a tool. Markdown code snippet formatted in the following schema: ```json {{ "action": string, \\ The action to take. {tool_names} "action_input": string \\ The input to the action }} ``` **Option #2:** Use this if you want to respond directly to the user. 
Markdown code snippet formatted in the following schema: ```json {{ "action": "Final Answer", "action_input": string \\ You should put what you want to return to use here }} ``` """ # For the case where the user has not configured any tools to call, but still using the tool-flow # expected format TOOL_LESS_PROMPT = """ Respond with a markdown code snippet in the following schema: ```json {{ "action": "Final Answer", "action_input": string \\ You should put what you want to return to use here }} ``` """ # Second part of the prompt to include the user query USER_INPUT = """ USER'S INPUT -------------------- Here is the user's input \ (remember to respond with a markdown code snippet of a json blob with a single action, and NOTHING else): {user_input} """ # After the tool call, this is the following message to get a final answer # Tools are not chained currently, the system must provide an answer after calling a tool TOOL_FOLLOWUP = """ TOOL RESPONSE: --------------------- {tool_output} USER'S INPUT -------------------- Okay, so what is the response to my last comment? If using information obtained from the tools you must \ mention it explicitly without mentioning the tool names - I have forgotten all TOOL RESPONSES! If the tool response is not useful, ignore it completely. {optional_reminder}{hint} IMPORTANT! You MUST respond with a markdown code snippet of a json blob with a single action, and NOTHING else. """ # If no tools were used, but retrieval is enabled, then follow up with this message to get the final answer TOOL_LESS_FOLLOWUP = """ Refer to the following documents when responding to my final query. Ignore any documents that are not relevant. 
CONTEXT DOCUMENTS: --------------------- {context_str} FINAL QUERY: -------------------- {user_query} {hint_text} """ ================================================ FILE: backend/onyx/prompts/compression_prompts.py ================================================ # Prompts for chat history compression via summarization. # ruff: noqa: E501, W605 start # Cutoff marker helps the LLM focus on summarizing only messages before this point. # This improves "needle in haystack" accuracy by explicitly marking where to stop with an exact pattern which is also placed in locations easily attended to by the LLM (last user message and system prompt). CONTEXT_CUTOFF_START_MARKER = "" CONTEXT_CUTOFF_END_MARKER = "" SUMMARIZATION_CUTOFF_MARKER = f"{CONTEXT_CUTOFF_START_MARKER} Stop summarizing the rest of the conversation past this point. {CONTEXT_CUTOFF_END_MARKER}" SUMMARIZATION_PROMPT = f""" You are a summarization system. Your task is to produce a detailed and accurate summary of a chat conversation up to a specified cutoff message. The cutoff will be marked by the string {CONTEXT_CUTOFF_START_MARKER}. \ IMPORTANT: Do not explicitly mention anything about the cutoff in your response. Do not situate the summary with respect to the cutoff. The context cutoff is only a system injected marker. # Guidelines - Only consider messages that occur at or before the cutoff point. Use the messages after it purely as context without including any of it in the summary. - Preserve factual correctness and intent; do not infer or speculate. - The summary should be information dense and detailed. - The summary should be in paragraph format and long enough to capture all of the most prominent details. # Focus on - Key topics discussed. - Decisions made, tools used, and conclusions reached. - Open questions or unresolved items. - Important constraints, preferences, or assumptions stated. - Omit small talk, repetition, and stylistic filler unless it affects meaning. 
""".strip() PROGRESSIVE_SUMMARY_SYSTEM_PROMPT_BLOCK = """ # Existing summary There is a previous summary of the conversation. Build on top of this when constructing the new overall summary of the conversation: {previous_summary} """.rstrip() USER_REMINDER = f"Help summarize the conversation up to the cutoff point (do not mention anything related to the cutoff directly in your response). It should be a long form summary of the conversation up to the cutoff point as marked by {CONTEXT_CUTOFF_START_MARKER}. Be thorough." PROGRESSIVE_USER_REMINDER = f"Update the existing summary by incorporating the new messages up to the cutoff point as marked by {CONTEXT_CUTOFF_START_MARKER} (do not mention anything related to the cutoff directly in your response). Be thorough and maintain the long form summary format." # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/constants.py ================================================ # ruff: noqa: E501, W605 start CODE_BLOCK_PAT = "```\n{}\n```" TRIPLE_BACKTICK = "```" SYSTEM_REMINDER_TAG_OPEN = "" SYSTEM_REMINDER_TAG_CLOSE = "" # Tags format inspired by Anthropic and OpenCode REMINDER_TAG_NO_HEADER = f""" User messages may include {SYSTEM_REMINDER_TAG_OPEN} and {SYSTEM_REMINDER_TAG_CLOSE} tags. These {SYSTEM_REMINDER_TAG_OPEN} tags contain useful information and reminders. \ They are automatically added by the system and are not actual user inputs. Behave in accordance with these instructions if relevant, and continue normally if they are not.
""".strip() REMINDER_TAG_DESCRIPTION = f""" # System Reminders {REMINDER_TAG_NO_HEADER} """.strip() # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/contextual_retrieval.py ================================================ # NOTE: the prompt separation is partially done for efficiency; previously I tried # to do it all in one prompt with sequential format() calls but this will cause a backend # error when the document contains any {} as python will expect the {} to be filled by # format() arguments # ruff: noqa: E501, W605 start CONTEXTUAL_RAG_PROMPT1 = """ {document} Here is the chunk we want to situate within the whole document""" CONTEXTUAL_RAG_PROMPT2 = """ {chunk} Please give a short succinct context to situate this chunk within the overall document for the purposes of improving search retrieval of the chunk. Answer only with the succinct context and nothing else. """.rstrip() CONTEXTUAL_RAG_TOKEN_ESTIMATE = 64 # 19 + 45 DOCUMENT_SUMMARY_PROMPT = """ {document} Please give a short succinct summary of the entire document. Answer only with the succinct summary and nothing else. """.rstrip() DOCUMENT_SUMMARY_TOKEN_ESTIMATE = 50 # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/deep_research/__init__.py ================================================ ================================================ FILE: backend/onyx/prompts/deep_research/dr_tool_prompts.py ================================================ GENERATE_PLAN_TOOL_NAME = "generate_plan" GENERATE_REPORT_TOOL_NAME = "generate_report" RESEARCH_AGENT_TOOL_NAME = "research_agent" # This is to ensure that even the non-reasoning models can have an ok time with this more complex flow. 
THINK_TOOL_NAME = "think_tool" # ruff: noqa: E501, W605 start # Limit web_search to 3 queries per call, since it is hard for the open_url tool to be called on a large number of search results all at once WEB_SEARCH_TOOL_DESCRIPTION = """ ## web_search Use the `web_search` tool to get search results from the web. You should use this tool to get context for your research. These should be optimized for search engines like Google. \ Use concise and specific queries and avoid merging multiple queries into one. You can call web_search with multiple queries at once (3 max) but generally only do this when there is a clear opportunity for parallel searching. \ If you use multiple queries, ensure that the queries are related in topic but not similar such that the results would be redundant. """ # This one is mostly similar to the one for the main flow but there won't be any user specified URLs to open. OPEN_URLS_TOOL_DESCRIPTION = f""" ## open_urls Use the `open_urls` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your searches. \ You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \ You should almost always use open_urls after a web_search call and sometimes after reasoning with the {THINK_TOOL_NAME} tool. """ OPEN_URLS_TOOL_DESCRIPTION_REASONING = """ ## open_urls Use the `open_urls` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your searches. \ You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \ You should almost always use open_urls after a web_search call. """ # NOTE: Internal search tool uses the same description as the default flow, not duplicating here.
# ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/deep_research/orchestration_layer.py ================================================ from onyx.prompts.deep_research.dr_tool_prompts import GENERATE_PLAN_TOOL_NAME from onyx.prompts.deep_research.dr_tool_prompts import GENERATE_REPORT_TOOL_NAME from onyx.prompts.deep_research.dr_tool_prompts import RESEARCH_AGENT_TOOL_NAME from onyx.prompts.deep_research.dr_tool_prompts import THINK_TOOL_NAME # ruff: noqa: E501, W605 start CLARIFICATION_PROMPT = f""" You are a clarification agent that runs prior to deep research. Assess whether you need to ask clarifying questions, or if the user has already provided enough information for you to start research. \ CRITICAL - Never directly answer the user's query, you must only ask clarifying questions or call the `{GENERATE_PLAN_TOOL_NAME}` tool. If the user query is already very detailed or lengthy (more than 3 sentences), do not ask for clarification and instead call the `{GENERATE_PLAN_TOOL_NAME}` tool. For context, the date is {{current_datetime}}. Be conversational and friendly, prefer saying "could you" rather than "I need" etc. If you need to ask questions, follow these guidelines: - Be concise and do not ask more than 5 questions. - If there are ambiguous terms or questions, ask the user to clarify. - Your questions should be a numbered list for clarity. - Respond in the same language as the user's query. - Make sure to gather all the information needed to carry out the research task in a concise, well-structured manner.{{internal_search_clarification_guidance}} - Wrap up with a quick sentence on what the clarification will help with, it's ok to reference the user query closely here. """.strip() INTERNAL_SEARCH_CLARIFICATION_GUIDANCE = """ - The deep research system is connected with organization internal document search and web search capabilities. 
In cases where it is unclear which source is more appropriate, ask the user to clarify. """ # This counteracts model behavior that, after alignment, may be overly cautious about data access and feasibility. # Sometimes the model will just apologize and claim the task is not possible, hence the long section following CRITICAL. RESEARCH_PLAN_PROMPT = """ You are a research planner agent that generates the high level approach for deep research on a user query. Analyze the query carefully and break it down into main concepts and areas of exploration. \ Stick closely to the user query and stay on topic but be curious and avoid duplicate or overlapping exploration directions. \ Be sure to take into account the time sensitive aspects of the research topic and make sure to emphasize up to date information where appropriate. \ Focus on providing thorough research of the user's query over being helpful. CRITICAL - You MUST only output the research plan for the deep research flow and nothing else, you are not responding to the user. \ Do not worry about the feasibility of the plan or access to data or tools, a different deep research flow will handle that. For context, the date is {current_datetime}. The research plan should be formatted as a numbered list of steps and have 6 or fewer individual steps. Each step should be a standalone exploration question or topic that can be researched independently but may build on previous steps. The plan should be in the same language as the user's query. Output only the numbered list of steps with no additional prefix or suffix. """.strip() # Some models struggle not to simply answer the user when there are questions about internal knowledge. # A reminder (specifically the fact that it's also a User type message) helps to prevent this. RESEARCH_PLAN_REMINDER = """ Remember to only output the research plan and nothing else.
Do not worry about the feasibility of the plan or data access. Your response must only be a numbered list of steps with no additional prefix or suffix. """.strip() ORCHESTRATOR_PROMPT = f""" You are an orchestrator agent for deep research. Your job is to conduct research by calling the {RESEARCH_AGENT_TOOL_NAME} tool with high level research tasks. \ This delegates the lower level research work to the {RESEARCH_AGENT_TOOL_NAME} which will provide back the results of the research. For context, the date is {{current_datetime}}. Before calling {GENERATE_REPORT_TOOL_NAME}, reason to double check that all aspects of the user's query have been well researched and that all key topics around the plan have been researched. \ There are cases where new discoveries from research may lead to a deviation from the original research plan. In these cases, ensure that the new directions are thoroughly investigated prior to calling {GENERATE_REPORT_TOOL_NAME}. NEVER output normal response tokens, you must only call tools. # Tools You have currently used {{current_cycle_count}} of {{max_cycles}} max research cycles. You do not need to use all cycles. ## {RESEARCH_AGENT_TOOL_NAME} The research task provided to the {RESEARCH_AGENT_TOOL_NAME} should be reasonably high level with a clear direction for investigation. \ It should not be a single short query, rather it should be 1 (or 2 if necessary) descriptive sentences that outline the direction of the investigation. \ The research task should be in the same language as the overall research plan. CRITICAL - the {RESEARCH_AGENT_TOOL_NAME} only receives the task and has no additional context about the user's query, research plan, other research agents, or message history. 
\ You absolutely must provide all of the context needed to complete the task in the argument to the {RESEARCH_AGENT_TOOL_NAME}.{{internal_search_research_task_guidance}} You should call the {RESEARCH_AGENT_TOOL_NAME} MANY times before completing with the {GENERATE_REPORT_TOOL_NAME} tool. You are encouraged to call the {RESEARCH_AGENT_TOOL_NAME} in parallel if the research tasks are not dependent on each other, which is typically the case. NEVER call more than 3 {RESEARCH_AGENT_TOOL_NAME} calls in parallel. ## {GENERATE_REPORT_TOOL_NAME} You should call the {GENERATE_REPORT_TOOL_NAME} tool if any of the following conditions are met: - You have researched all of the relevant topics of the research plan. - You have shifted away from the original research plan and believe that you are done. - You have all of the information needed to thoroughly answer all aspects of the user's query. - The last research cycle yielded minimal new information and future cycles are unlikely to yield more information. ## {THINK_TOOL_NAME} CRITICAL - use the {THINK_TOOL_NAME} to reason between every call to the {RESEARCH_AGENT_TOOL_NAME} and before calling {GENERATE_REPORT_TOOL_NAME}. You should treat this as chain-of-thought reasoning to think deeply on what to do next. \ Be curious, identify knowledge gaps and consider new potential directions of research. Use paragraph format, do not use bullet points or lists. NEVER use the {THINK_TOOL_NAME} in parallel with other {RESEARCH_AGENT_TOOL_NAME} or {GENERATE_REPORT_TOOL_NAME}. Before calling {GENERATE_REPORT_TOOL_NAME}, double check that all aspects of the user's query have been researched and that all key topics around the plan have been researched (unless you have gone in a different direction). # Research Plan {{research_plan}} """.strip() INTERNAL_SEARCH_RESEARCH_TASK_GUIDANCE = """ If necessary, clarify if the research agent should focus mostly on organization internal searches, web searches, or a combination of both. 
If the task doesn't require a clear priority, don't add sourcing guidance. """.strip( "\n" ) USER_ORCHESTRATOR_PROMPT = """ Remember to refer to the system prompt and follow how to use the tools. Call the {THINK_TOOL_NAME} between every call to the {RESEARCH_AGENT_TOOL_NAME} and before calling {GENERATE_REPORT_TOOL_NAME}. Never run more than 3 {RESEARCH_AGENT_TOOL_NAME} calls in parallel. Don't mention this reminder or underlying details about the system. """.strip() FINAL_REPORT_PROMPT = """ You are the final answer generator for a deep research task. Your job is to produce a thorough, balanced, and comprehensive answer on the research question provided by the user. \ You have access to high-quality, diverse sources collected by secondary research agents as well as their analysis of the sources. IMPORTANT - You get straight to the point, never providing a title and avoiding lengthy introductions/preambles. For context, the date is {current_datetime}. Users have explicitly selected the deep research mode and will expect a long and detailed answer. It is ok and encouraged that your response is several pages long. \ Structure your response logically into relevant sections. You may find it helpful to reference the research plan to help structure your response but do not limit yourself to what is contained in the plan. You use different text styles and formatting to make the response easier to read. You may use markdown rarely when necessary to make the response more digestible. Provide inline citations in the format [1], [2], [3], etc. based on the citations included by the research agents. """.strip() USER_FINAL_REPORT_QUERY = f""" The original research plan is included below (use it as a helpful reference but do not limit yourself to this): ``` {{research_plan}} ``` Based on all of the context provided in the research history, provide a comprehensive, well structured, and insightful answer to the user's previous query. 
\ CRITICAL: be extremely thorough in your response and address all relevant aspects of the query. Ignore the format styles of the intermediate {RESEARCH_AGENT_TOOL_NAME} reports, those are not end user facing and different from your task. Provide inline citations in the format [1], [2], [3], etc. based on the citations included by the research agents. The citations should be just a number in a bracket, nothing additional. """.strip() # Reasoning Model Variants of the prompts ORCHESTRATOR_PROMPT_REASONING = f""" You are an orchestrator agent for deep research. Your job is to conduct research by calling the {RESEARCH_AGENT_TOOL_NAME} tool with high level research tasks. \ This delegates the lower level research work to the {RESEARCH_AGENT_TOOL_NAME} which will provide back the results of the research. For context, the date is {{current_datetime}}. Before calling {GENERATE_REPORT_TOOL_NAME}, reason to double check that all aspects of the user's query have been well researched and that all key topics around the plan have been researched. There are cases where new discoveries from research may lead to a deviation from the original research plan. In these cases, ensure that the new directions are thoroughly investigated prior to calling {GENERATE_REPORT_TOOL_NAME}. Between calls, think deeply on what to do next. Be curious, identify knowledge gaps and consider new potential directions of research. Use paragraph format for your reasoning, do not use bullet points or lists. NEVER output normal response tokens, you must only call tools. # Tools You have currently used {{current_cycle_count}} of {{max_cycles}} max research cycles. You do not need to use all cycles. ## {RESEARCH_AGENT_TOOL_NAME} The research task provided to the {RESEARCH_AGENT_TOOL_NAME} should be reasonably high level with a clear direction for investigation. 
\ It should not be a single short query, rather it should be 1 (or 2 if necessary) descriptive sentences that outline the direction of the investigation. \ The research task should be in the same language as the overall research plan. CRITICAL - the {RESEARCH_AGENT_TOOL_NAME} only receives the task and has no additional context about the user's query, research plan, or message history. \ You absolutely must provide all of the context needed to complete the task in the argument to the {RESEARCH_AGENT_TOOL_NAME}.{{internal_search_research_task_guidance}} You should call the {RESEARCH_AGENT_TOOL_NAME} MANY times before completing with the {GENERATE_REPORT_TOOL_NAME} tool. You are encouraged to call the {RESEARCH_AGENT_TOOL_NAME} in parallel if the research tasks are not dependent on each other, which is typically the case. NEVER call more than 3 {RESEARCH_AGENT_TOOL_NAME} calls in parallel. ## {GENERATE_REPORT_TOOL_NAME} You should call the {GENERATE_REPORT_TOOL_NAME} tool if any of the following conditions are met: - You have researched all of the relevant topics of the research plan. - You have shifted away from the original research plan and believe that you are done. - You have all of the information needed to thoroughly answer all aspects of the user's query. - The last research cycle yielded minimal new information and future cycles are unlikely to yield more information. # Research Plan {{research_plan}} """.strip() USER_ORCHESTRATOR_PROMPT_REASONING = """ Remember to refer to the system prompt and follow how to use the tools. \ You are encouraged to call the {RESEARCH_AGENT_TOOL_NAME} in parallel when the research tasks are not dependent on each other, but never call more than 3 {RESEARCH_AGENT_TOOL_NAME} calls in parallel. Don't mention this reminder or underlying details about the system. """.strip() # Only for the first cycle, we encourage the model to research more, since it is unlikely that it has already addressed all parts of the plan at this point. 
FIRST_CYCLE_REMINDER_TOKENS = 100 FIRST_CYCLE_REMINDER = """ Make sure all parts of the user question and the plan have been thoroughly explored before calling generate_report. If new interesting angles have been revealed from the research, you may deviate from the plan to research new directions. """.strip() # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/deep_research/research_agent.py ================================================ from onyx.prompts.deep_research.dr_tool_prompts import GENERATE_REPORT_TOOL_NAME from onyx.prompts.deep_research.dr_tool_prompts import THINK_TOOL_NAME MAX_RESEARCH_CYCLES = 8 # ruff: noqa: E501, W605 start RESEARCH_AGENT_PROMPT = f""" You are a highly capable, thoughtful, and precise research agent that conducts research on a specific topic. Prefer being thorough in research over being helpful. Be curious but stay strictly on topic. \ You iteratively call the tools available to you including {{available_tools}} until you have completed your research at which point you call the {GENERATE_REPORT_TOOL_NAME} tool. NEVER output normal response tokens, you must only call tools. For context, the date is {{current_datetime}}. # Tools You have a limited number of cycles to complete your research and you do not have to use all cycles. You are on cycle {{current_cycle_count}} of {MAX_RESEARCH_CYCLES}.\ {{optional_internal_search_tool_description}}\ {{optional_web_search_tool_description}}\ {{optional_open_url_tool_description}} ## {THINK_TOOL_NAME} CRITICAL - use the think tool after every set of searches and reads (so search, read some pages, then think and repeat). \ You MUST use the {THINK_TOOL_NAME} before calling the web_search tool for all calls to web_search except for the first call. \ Use the {THINK_TOOL_NAME} before calling the {GENERATE_REPORT_TOOL_NAME} tool. After a set of searches + reads, use the {THINK_TOOL_NAME} to analyze the results and plan the next steps. 
- Reflect on the key information found with relation to the task. - Reason thoroughly about what could be missing, the knowledge gaps, and what queries might address them, \ or why there is enough information to answer the research task comprehensively. ## {GENERATE_REPORT_TOOL_NAME} Once you have completed your research, call the `{GENERATE_REPORT_TOOL_NAME}` tool. \ You should only call this tool after you have fully researched the topic. \ Consider other potential areas of research and weigh that against the materials already gathered before calling this tool. """.strip() RESEARCH_REPORT_PROMPT = """ You are a highly capable and precise research sub-agent that has conducted research on a specific topic. \ Your job is now to organize the findings to return a comprehensive report that preserves all relevant statements and information that has been gathered in the existing messages. \ The report will be seen by another agent instead of a user so keep it free of formatting or commentary and instead focus on the facts only. \ Do not give it a title, do not break it down into sections, and do not provide any of your own conclusions/analysis. You may see a list of tool calls in the history but you do not have access to tools anymore. You should only use the information in the history to create the report. CRITICAL - This report should be as long as necessary to return ALL of the information that the researcher has gathered. It should be several pages long so as to capture as much detail as possible from the research. \ It cannot be stressed enough that this report must be EXTREMELY THOROUGH and COMPREHENSIVE. Only this report is going to be returned, so it's CRUCIAL that you don't lose any details from the raw messages. Remove any obviously irrelevant or duplicative information. If a statement seems not trustworthy or is contradictory to other statements, it is important to flag it. Write the report in the same language as the provided task. 
Cite all sources INLINE using the format [1], [2], [3], etc. based on the `document` field of the source. \ Cite inline as opposed to leaving all citations until the very end of the response. """ USER_REPORT_QUERY = """ Please write me a comprehensive report on the research topic given the context above. As a reminder, the original topic was: {research_topic} Remember to include AS MUCH INFORMATION AS POSSIBLE and as faithful to the original sources as possible. \ Keep it free of formatting and focus on the facts only. Be sure to include all context for each fact to avoid misinterpretation or misattribution. \ Respond in the same language as the topic provided above. Cite every fact INLINE using the format [1], [2], [3], etc. based on the `document` field of the source. CRITICAL - BE EXTREMELY THOROUGH AND COMPREHENSIVE, YOUR RESPONSE SHOULD BE SEVERAL PAGES LONG. """ # Reasoning Model Variants of the prompts RESEARCH_AGENT_PROMPT_REASONING = f""" You are a highly capable, thoughtful, and precise research agent that conducts research on a specific topic. Prefer being thorough in research over being helpful. Be curious but stay strictly on topic. \ You iteratively call the tools available to you including {{available_tools}} until you have completed your research at which point you call the {GENERATE_REPORT_TOOL_NAME} tool. Between calls, think about the results of the previous tool call and plan the next steps. \ Reason thoroughly about what could be missing, identify knowledge gaps, and what queries might address them. Or consider why there is enough information to answer the research task comprehensively. Once you have completed your research, call the `{GENERATE_REPORT_TOOL_NAME}` tool. NEVER output normal response tokens, you must only call tools. For context, the date is {{current_datetime}}. # Tools You have a limited number of cycles to complete your research and you do not have to use all cycles. 
You are on cycle {{current_cycle_count}} of {MAX_RESEARCH_CYCLES}.\ {{optional_internal_search_tool_description}}\ {{optional_web_search_tool_description}}\ {{optional_open_url_tool_description}} ## {GENERATE_REPORT_TOOL_NAME} Once you have completed your research, call the `{GENERATE_REPORT_TOOL_NAME}` tool. You should only call this tool after you have fully researched the topic. """.strip() OPEN_URL_REMINDER_RESEARCH_AGENT = """ Remember that after using web_search, you are encouraged to open some pages to get more context unless the query is completely answered by the snippets. Open the pages that look the most promising and high quality by calling the open_url tool with an array of URLs. """.strip() # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/federated_search.py ================================================ from onyx.configs.app_configs import MAX_SLACK_QUERY_EXPANSIONS SLACK_QUERY_EXPANSION_PROMPT = f""" Rewrite the user's query into at most {MAX_SLACK_QUERY_EXPANSIONS} keyword-only queries for Slack's keyword search. Slack search behavior: - Pure keyword AND search (no semantics) - More words = fewer matches, so keep queries concise (1-3 words) ALWAYS include: - Person names (e.g., "Sarah Chen", "Mike Johnson") - people search for messages from/about specific people - Project/product names, technical terms, proper nouns - Actual content words: "performance", "bug", "deployment", "API", "error" DO NOT include: - Meta-words: "topics", "conversations", "discussed", "summary", "messages" - Temporal: "today", "yesterday", "week", "month", "recent", "last" - Channel names: "general", "eng-general", "random" Examples: Query: "what are the big topics in eng-general this week?" Output: Query: "messages with Sarah about the deployment" Output: Sarah deployment Sarah deployment Query: "what did Mike say about the budget?" 
Output: Mike budget Mike budget Query: "performance issues in eng-general" Output: performance issues performance issues Query: "what did we discuss about the API migration?" Output: API migration API migration Now process this query: {{query}} Output (keywords only, one per line, NO explanations or commentary): """ SLACK_DATE_EXTRACTION_PROMPT = """ Extract the date range from the user's query and return it in a structured format. Current date context: - Today: {today} - Current time: {current_time} Guidelines: 1. Return a JSON object with "days_back" (integer) indicating how many days back to search 2. If no date/time is mentioned, return {{"days_back": null}} 3. Interpret relative dates accurately: - "today" or "today's" = 0 days back - "yesterday" = 1 day back - "last week" = 7 days back - "last month" = 30 days back - "last X days" = X days back - "past X days" = X days back - "this week" = 7 days back - "this month" = 30 days back 4. For creative expressions, interpret intent: - "recent" = 7 days back - "recently" = 7 days back - "lately" = 14 days back 5. Always be conservative - if uncertain, use a longer time range User query: {query} Return ONLY a valid JSON object in this format: {{"days_back": }} Nothing else. """ ================================================ FILE: backend/onyx/prompts/filter_extration.py ================================================ # The following prompts are used for extracting filters to apply along with the query in the # document index. For example, a filter for dates or a filter by source type such as GitHub # or Slack SOURCES_KEY = "sources" # Smaller followup prompts in time_filter.py TIME_FILTER_PROMPT = """ You are a tool to identify time filters to apply to a user query for a downstream search \ application. The downstream application is able to use a recency bias or apply a hard cutoff to \ remove all documents before the cutoff. Identify the correct filters to apply for the user query. 
The current day and time is {current_day_time_str}. Always answer with ONLY a json which contains the keys "filter_type", "filter_value", \ "value_multiple" and "date". The valid values for "filter_type" are "hard cutoff", "favors recent", or "not time sensitive". The valid values for "filter_value" are "day", "week", "month", "quarter", "half", or "year". The valid value for "value_multiple" is any number. The valid value for "date" is a date in the format MM/DD/YYYY, ALWAYS follow this format. """.strip() # Smaller followup prompts in source_filter.py # Known issue: LLMs like GPT-3.5 try to generalize. If the valid sources contain "web" but not # "confluence" and the user asks for confluence related things, the LLM will select "web" since # confluence is accessed as a website. This cannot be fixed without also reducing the capability # to match things like repository->github, website->web, etc. # This is generally not a big issue though as if the company has confluence, hopefully they add # a connector for it or the user is aware that confluence has not been added. SOURCE_FILTER_PROMPT = f""" Given a user query, extract relevant source filters for use in a downstream search tool. Respond with a json containing the source filters or null if no specific sources are referenced. ONLY extract sources when the user is explicitly limiting the scope of where information is \ coming from. The user may provide invalid source filters, ignore those. The valid sources are: {{valid_sources}} {{web_source_warning}} {{file_source_warning}} ALWAYS answer with ONLY a json with the key "{SOURCES_KEY}". \ The value for "{SOURCES_KEY}" must be null or a list of valid sources. Sample Response: {{sample_response}} """.strip() WEB_SOURCE_WARNING = """ Note: The "web" source only applies when the user specifies "website" in the query. \ It does not apply to tools such as Confluence, GitHub, etc. that have a website.
""".strip() FILE_SOURCE_WARNING = """ Note: The "file" source only applies when the user refers to uploaded files in the query. """.strip() # Use the following for easy viewing of prompts if __name__ == "__main__": print(TIME_FILTER_PROMPT) print("------------------") print(SOURCE_FILTER_PROMPT) ================================================ FILE: backend/onyx/prompts/image_analysis.py ================================================ # Used for creating embeddings of images for vector search DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """ You are an assistant for summarizing images for retrieval. Summarize the content of the following image and be as precise as possible. The summary will be embedded and used to retrieve the original image. Therefore, write a concise summary of the image that is optimized for retrieval. """ # Prompt for generating image descriptions with filename context DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """ Describe precisely and concisely what the image shows. """ # Used for analyzing images in response to user queries at search time DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = ( "You are an AI assistant specialized in describing images.\n" "You will receive a user question plus an image URL. Provide a concise textual answer.\n" "Focus on aspects of the image that are relevant to the user's question.\n" "Be specific and detailed about visual elements that directly address the query.\n" ) ================================================ FILE: backend/onyx/prompts/kg_prompts.py ================================================ # Standards SEPARATOR_LINE = "-------" SEPARATOR_LINE_LONG = "---------------" NO_EXTRACTION = "No extraction of knowledge graph objects was feasible."
YES = "yes" NO = "no" # Framing/Support/Template Prompts ENTITY_TYPE_SETTING_PROMPT = f""" {SEPARATOR_LINE} {{entity_types}} {SEPARATOR_LINE} """.strip() RELATIONSHIP_TYPE_SETTING_PROMPT = f""" Here are the types of relationships: {SEPARATOR_LINE} {{relationship_types}} {SEPARATOR_LINE} """.strip() EXTRACTION_FORMATTING_PROMPT = r""" {{"entities": [::' (please use that capitalization). If allowed options \ are provided above, you can only extract those types of entities! Again, there should be an 'Other' \ option. Pick this if none of the others apply.>], "relationships": [::__\ __::'>], "terms": ['>] }} """.strip() QUERY_ENTITY_EXTRACTION_FORMATTING_PROMPT = r""" {{"entities": [::' (please use that capitalization)>. Each entity \ also should be followed by a list of comma-separated attribute filters for the entity, if referred to in the \ question for that entity. CRITICAL: you can only use attributes that are mentioned above for the \ entity type in question. Example: 'ACCOUNT::* -- [account_type: customer, status: active]' if the question is \ 'list all customer accounts', and ACCOUNT was an entity type with these attribute key/values allowed.] 
\ "time_filter": }} """.strip() QUERY_RELATIONSHIP_EXTRACTION_FORMATTING_PROMPT = r""" {{"relationships": [::__\ __::'>] }} """.strip() EXAMPLE_1 = r""" {{"entities": ["ACCOUNT::Nike", "CONCERN::*"], "relationships": ["ACCOUNT::Nike__had__CONCERN::*"], "terms": []}} """.strip() EXAMPLE_2 = r""" {{"entities": ["ACCOUNT::Nike", "CONCERN::performance"], "relationships": ["ACCOUNT::*__had_issues__CONCERN::performance"], "terms": ["performance issue"]}} """.strip() EXAMPLE_3 = r""" {{"entities": ["ACCOUNT::Nike", "CONCERN::performance", "CONCERN::user_experience"], "relationships": ["ACCOUNT::Nike__had__CONCERN::performance", "ACCOUNT::Nike__solved__CONCERN::user_experience"], "terms": ["performance", "user experience"]}} """.strip() EXAMPLE_4 = r""" {{"entities": ["ACCOUNT::Nike", "FEATURE::dashboard", "CONCERN::performance"], "relationships": ["ACCOUNT::Nike__had__CONCERN::performance", "ACCOUNT::Nike__had_issues__FEATURE::dashboard", "ACCOUNT::NIKE__gets_value_from__FEATURE::dashboard"], "terms": ["value", "performance"]}} """.strip() RELATIONSHIP_EXAMPLE_1 = r""" 'Which issues did Nike report?' and the extracted entities were found to be: "ACCOUNT::Nike", "CONCERN::*" then a valid relationship extraction could be: {{"relationships": ["ACCOUNT::Nike__had__CONCERN::*"]}} """.strip() RELATIONSHIP_EXAMPLE_2 = r""" 'Did Nike say anything about performance issues?' and the extracted entities were found to be: "ACCOUNT::Nike", "CONCERN::performance" then a much more suitable relationship extraction could be: {{"relationships": ["ACCOUNT::*__had_issues__CONCERN::performance"]}} """.strip() RELATIONSHIP_EXAMPLE_3 = r""" 'Did Nike report some performance issues with our solution? 
And were they happy that the user experience issue got solved?', \ and the extracted entities were found to be: "ACCOUNT::Nike", "CONCERN::performance", "CONCERN::user_experience" then a valid relationship extraction could be: {{"relationships": ["ACCOUNT::Nike__had__CONCERN::performance", "ACCOUNT::Nike__solved__CONCERN::user_experience"]}} """.strip() RELATIONSHIP_EXAMPLE_4 = r""" 'Nike reported some performance issues with our dashboard solution, but do they think it delivers great value nevertheless?' \ and the extracted entities were found to be: "ACCOUNT::Nike", "FEATURE::dashboard", "CONCERN::performance" then a valid relationship extraction could be: Example 4: {{"relationships": ["ACCOUNT::Nike__had__CONCERN::performance", "ACCOUNT::Nike__had_issues__FEATURE::dashboard", "ACCOUNT::NIKE__gets_value_from__FEATURE::dashboard"]}} Explanation: - Nike did report performance concerns - Nike had problems with the dashboard, which is a feature - We are interested in the value relationship between Nike and the dashboard feature """.strip() RELATIONSHIP_EXAMPLE_5 = r""" 'In which emails did Nike discuss their issues with the dashboard?' 
\ and the extracted entities were found to be: "ACCOUNT::Nike", "FEATURE::dashboard", "EMAIL::*" then a valid relationship extraction could be: {{"relationships": ["ACCOUNT::Nike__had__CONCERN::*", "ACCOUNT::Nike__had_issues__FEATURE::dashboard", "ACCOUNT::NIKE__in__EMAIL::*", "EMAIL::*__discusses__FEATURE::dashboard", "EMAIL::*Nike__had__CONCERN::* "]}} Explanation: - Nike did report unspecified concerns - Nike had problems with the dashboard, which is a feature - We are interested in emails that Nike exchanged with us """.strip() RELATIONSHIP_EXAMPLE_6 = r""" 'List the last 5 emails that Lisa exchanged with Nike:' \ and the extracted entities were found to be: "ACCOUNT::Nike", "EMAIL::*", "EMPLOYEE::Lisa" then a valid relationship extraction could be: {{"relationships": ["ACCOUNT::Nike__had__CONCERN::*", "ACCOUNT::Nike__had_issues__FEATURE::dashboard", "ACCOUNT::NIKE__in__EMAIL::*"]}} Explanation: - Nike did report unspecified concerns - Nike had problems with the dashboard, which is a feature - We are interested in emails that Nike exchanged with us """.strip() ENTITY_EXAMPLE_1 = r""" {{"entities": ["ACCOUNT::Nike--[]", "CONCERN::*--[]"]}} """.strip() ENTITY_EXAMPLE_2 = r""" {{"entities": ["ACCOUNT::Nike--[]", "CONCERN::performance--[]"]}} """.strip() ENTITY_EXAMPLE_3 = r""" {{"entities": ["ACCOUNT::*--[]", "CONCERN::performance--[]", "CONCERN::user_experience--[]"]}} """.strip() ENTITY_EXAMPLE_4 = r""" {{"entities": ["ACCOUNT::*--[]", "CONCERN::performance--[degree: severe]"]}} """.strip() MASTER_EXTRACTION_PROMPT = f""" You are an expert in the area of knowledge extraction in order to construct a knowledge graph. You are given a text \ and asked to extract entities, relationships, and terms from it that you can reliably identify. Here are the entity types that are available for extraction. Some of them may have a description, others \ should be obvious. Also, for a given entity allowed options may be provided. 
If allowed options are provided, \ you can only extract those types of entities! If no allowed options are provided, take your best guess. You can ONLY extract entities of these types and relationships between objects of these types: {SEPARATOR_LINE} {ENTITY_TYPE_SETTING_PROMPT} {SEPARATOR_LINE} Please format your answer in this format: {SEPARATOR_LINE} {EXTRACTION_FORMATTING_PROMPT} {SEPARATOR_LINE} The list above is the exclusive list of entities you can choose from! Here are some important additional instructions. (For the purpose of illustration, assume that \ "ACCOUNT", "CONCERN", and "FEATURE" are all in the list of entity types above, and that the actual \ entities shown fall into the allowed options. Note that this \ is just assumed for these examples, but you MUST use only the entities above for the actual extraction!) - You can either extract specific entities if a specific entity is referred to, or you can refer to the entity type. * if the entity type is referred to in general, you would use '*' as the entity name in the extraction. As an example, if the text would say: 'Nike reported that they had issues' then a valid extraction could be: Example 1: {EXAMPLE_1} * If on the other hand the text would say: 'Nike reported that they had performance issues' then a much more suitable extraction could be: Example 2: {EXAMPLE_2} - You can extract multiple relationships between the same two entity types. As an example, if the text would say: 'Nike reported some performance issues with our solution, but they are very happy that the user experience issue got solved.' then a valid extraction could be: Example 3: {EXAMPLE_3} - You can extract multiple relationships between the same two actual entities if you think that \ there are multiple relationships between them based on the text. As an example, if the text would say: 'Nike reported some performance issues with our dashboard solution, but they think it delivers great value.'
then a valid extraction could be: Example 4: {EXAMPLE_4} Note that effectively a three-way relationship (Nike - performance issues - dashboard) is extracted as two individual \ relationships. - Again, - you should only extract entities belonging to the entity types above - but do extract all that you \ can reliably identify in the text - refer to 'all' entities of an entity type listed above by using '*' as the entity name - only extract important relationships that signify something non-trivial, expressing things like \ needs, wants, likes, dislikes, plans, interests, lack of interest, problems the account is having, etc. - you MUST only use the initial list of entities provided! Ignore the entities in the examples unless \ they are also part of the initial list of entities! This is essential! - only extract relationships between the entities extracted first! {SEPARATOR_LINE} Here is the text you are asked to extract knowledge from, if needed with additional information about any participants: {SEPARATOR_LINE} ---content--- {SEPARATOR_LINE} """.strip() QUERY_ENTITY_EXTRACTION_PROMPT = f""" You are an expert in the area of knowledge extraction and using knowledge graphs. You are given a question \ and asked to extract entities (with attributes if applicable) that you can reliably identify, which will then be matched with a known entity in the knowledge graph. You are also asked to extract time constraint information \ from the QUESTION. Some time constraints will be captured by entity attributes if \ the entity type has a fitting attribute (example: 'created_at' could be a candidate for that), other times we will extract an explicit time filter if no attribute fits. (Note regarding 'last', 'first', etc.: DO NOT \ imply the need for a time filter just because the question asks for something that is not the current date. \ They will relate to ordering that we will handle separately later).
In case useful, today is ---today_date--- and the user asking is ---user_name---, which may or may not be relevant. Here are the entity types that are available for extraction. Some of them may have \ a description, others should be obvious. Also, notice that some may have attributes associated with them, which will \ be important later. You can ONLY extract entities of these types: {SEPARATOR_LINE} {ENTITY_TYPE_SETTING_PROMPT} {SEPARATOR_LINE} The list above is the exclusive list of entities you can choose from! Also, note that there are fixed relationship types between these entities. Please consider those \ as well to make sure that you are not missing implicit entities! Implicit entities are often \ in verbs ('emailed to', 'talked to', ...). Also, they may be used to connect entities that are \ clearly in the question. {SEPARATOR_LINE} {RELATIONSHIP_TYPE_SETTING_PROMPT} {SEPARATOR_LINE} Here are some important additional instructions. (For the purpose of illustration, assume that \ "ACCOUNT", "CONCERN", "EMAIL", and "FEATURE" are all in the list of entity types above, and the \ attribute options for "CONCERN" include 'degree' with possible values that include 'severe'. Note that this \ is just assumed for these examples, but you MUST use only the entities above for the actual extraction!) - You can either extract specific entities if a specific entity is referred to, or you can refer to the entity type. * if the entity type is referred to in general, you would use '*' as the entity name in the extraction. As an example, if the question would say: 'Which issues did Nike report?' then a valid entity and term extraction could be: Example 1: {ENTITY_EXAMPLE_1} * If on the other hand the question would say: 'Did Nike say anything about performance issues?' then a much more suitable entity and term extraction could be: Example 2: {ENTITY_EXAMPLE_2} * Then, if the question is: 'Who reported performance issues?'
then a suitable entity and term extraction could be: Example 3: {ENTITY_EXAMPLE_3} * Then, if we inquire about an entity with a specific attribute: 'Who reported severe performance issues?' then a suitable entity and term extraction could be: Example 4: {ENTITY_EXAMPLE_4} - Again, - you should only extract entities belonging to the entity types above - but do extract all that you \ can reliably identify in the text - if you refer to all/any/an unspecified entity of an entity type listed above, use '*' as the entity name - similarly, if a specific entity type is referred to in general, you should use '*' as the entity name - you MUST only use the initial list of entities provided! Ignore the entities in the examples unless \ they are also part of the initial list of entities! This is essential! - don't forget to provide answers also to the event filtering and whether documents need to be inspected! - 'who' often refers to individuals or accounts. - see whether any of the entities are supposed to be narrowed down by an attribute value. The precise attribute \ and the value would need to be taken from the specification, as the question may use different words and the \ actual attribute may be implied. - don't just look at the entities that are mentioned in the question but also those that the question \ may be about. - be very careful that you only extract attributes that are listed above for the entity type in question! Do \ not make up attributes even if they are implied! Particularly if there is a relationship type that would \ actually represent that information, you MUST not extract the information as an attribute. We \ will extract the relationship type later. - For the values of attributes, look at the possible values above! For example 'open' may refer to \ 'backlog', 'todo', 'in progress', etc.
In cases like that, construct a ';'-separated list of values that you think may fit \ what is implied in the question (in the example: 'open; backlog; todo; in progress'). Also, if you think the name or the title of an entity is given but name or title are not mentioned \ explicitly as an attribute, then you should indeed extract the name/title as the entity name. {SEPARATOR_LINE} Here is the question you are asked to extract desired entities and time filters from: {SEPARATOR_LINE} ---content--- {SEPARATOR_LINE} Please format your answer in this format: {SEPARATOR_LINE} {QUERY_ENTITY_EXTRACTION_FORMATTING_PROMPT} {SEPARATOR_LINE} """.strip() QUERY_RELATIONSHIP_EXTRACTION_PROMPT = f""" You are an expert in the area of knowledge extraction and using knowledge graphs. You are given a question \ and previously you were asked to identify known entities in the question. Now you are asked to extract \ the relationships between the entities you have identified earlier. First off as background, here are the entity types that are known to the system: {SEPARATOR_LINE} ---entity_types--- {SEPARATOR_LINE} Here are the entities you have identified earlier: {SEPARATOR_LINE} ---identified_entities--- {SEPARATOR_LINE} Note that the notation for the entities is ::. Here are the options for the relationship types(!) between the entities you have identified earlier \ as well as relationship types between the identified entities and other entities \ not explicitly mentioned: {SEPARATOR_LINE} ---relationship_type_options--- {SEPARATOR_LINE} These types are, if any were identified, formatted as \ ____, and they \ limit the allowed relationships that you can extract. You would then, however, use the actual full entities as in: ::____::. Note: should be a word or two that captures the nature \ of the relationship.
Common relationships may be: 'likes', 'dislikes', 'uses', 'is interested in', 'mentions', \ 'addresses', 'participates in', etc., but look at the text to find the most appropriate relationship. \ Use spaces here for word separation. Please format your answer in this format: {SEPARATOR_LINE} {QUERY_RELATIONSHIP_EXTRACTION_FORMATTING_PROMPT} {SEPARATOR_LINE} The list above is the exclusive list of entities and relationship types you can choose from! Here are some important additional instructions. (For the purpose of illustration, assume that \ "ACCOUNT", "CONCERN", and "FEATURE" are all in the list of entity types above. Note that this \ is just assumed for these examples, but you MUST use only the entities above for the actual extraction!) - You can either extract specific entities if a specific entity is referred to, or you can refer to the entity type. * if the entity type is referred to in general, you would use '*' as the entity name in the extraction. As an example, if the question would say: {RELATIONSHIP_EXAMPLE_1} * If on the other hand the question would say: {RELATIONSHIP_EXAMPLE_2} - You can extract multiple relationships between the same two entity types. As an example, if the question would say: {RELATIONSHIP_EXAMPLE_3} - You can extract multiple relationships between the same two actual entities if you think that \ there are multiple relationships between them based on the question. As an example, if the question would say: {RELATIONSHIP_EXAMPLE_4} Note that effectively a three-way relationship (Nike - performance issues - dashboard) is extracted as two individual \ relationships. - Again, - you can only extract relationships between the entities extracted earlier - you can only extract the relationships that match the listed relationship types - if in doubt and there are multiple relationships between the same two entities, you can extract \ all of those that may fit with the question.
- think carefully through the question to decide which types of relationships should be extracted and which should not. Other important notes: - For questions that really try to explore in general what a certain entity was involved in, like 'what did Paul Smith do \ in the last 3 months?', and Paul Smith has been extracted, e.g., as an entity of type 'EMPLOYEE', then you need to extract \ all of the possible relationships an employee Paul Smith could have. - You are not forced to use all or any of the relationship types listed above. Really look at the question to \ determine which relationships are explicitly or implicitly referred to in the question. {SEPARATOR_LINE} Here is the question you are asked to extract desired entities, relationships, and terms from: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} """.strip() GENERAL_CHUNK_PREPROCESSING_PROMPT = """ This is a part of a document that you need to extract information (entities, relationships) from. Note: when you extract relationships, please make sure that: - if you see a relationship for one of our employees, you should extract the relationship both for the employee AND \ VENDOR::{vendor}. - if you see a relationship for one of the representatives of other accounts, you should extract the relationship \ only for the account ACCOUNT::! -- And here is the content: {content} """.strip() ### Source-specific prompts CALL_CHUNK_PREPROCESSING_PROMPT = """ This is a call between employees of the VENDOR's company and representatives of one or more ACCOUNTs (usually one). \ When you extract information based on the instructions, please make sure that you properly attribute the information \ to the correct employee and account.
\ Here are the participants (name component of email) from us ({vendor}): {participant_string} Here are the participants (name component of email) from the other account(s): {account_participant_string} In the text it should be easy to associate a name with the email, and then with the account ('us' vs 'them'). If in doubt, \ look at the context and try to identify whether the statement comes from the other account. If you are not sure, ignore. Note: when you extract relationships, please make sure that: - if you see a relationship for one of our employees, you should extract the relationship both for the employee AND \ VENDOR::{vendor}. - if you see a relationship for one of the representatives of other accounts, you should extract the relationship \ only for the account ACCOUNT::! -- And here is the content: {content} """.strip() CALL_DOCUMENT_CLASSIFICATION_PROMPT = """ This is the beginning of a call between employees of the VENDOR's company ({vendor}) and other participants. Your task is to classify the call into one of the following categories: {category_options} Please also consider the participants when you perform your classification task - they can be important indicators \ for the category. Please format your answer as a string in the format: REASONING: - CATEGORY: -- And here is the beginning of the call, including title and participants: {beginning_of_call_content} """.strip() STRATEGY_GENERATION_PROMPT = f""" Now you need to decide what type of strategy to use to answer a given question, how ultimately \ the answer should be formatted to match the user's expectation, and what an appropriate question \ to/about 'one object or one set of objects' may be, should the answer logically benefit from a divide \ and conquer strategy, or it naturally relates to one or few individual objects. Also, you are \ supposed to determine whether a divide and conquer strategy would be appropriate. 
Here are the entity types that are available in the knowledge graph: {SEPARATOR_LINE} ---possible_entities--- {SEPARATOR_LINE} Here are the relationship types that are available in the knowledge graph: {SEPARATOR_LINE} ---possible_relationships--- {SEPARATOR_LINE} Here is the question whose answer is ultimately sought: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} And here are the entities and relationships that have been extracted earlier from this question: {SEPARATOR_LINE} ---entities--- ---relationships--- {SEPARATOR_LINE} Here are more instructions: a) Regarding the strategy, there are three aspects to it: a1) "Search Type": Should the question be answered as a SEARCH ('filtered search'), or as a SQL ('SQL query search')? The options are: 1. SEARCH: A filtered search simply uses the entities and relationships that you extracted earlier and \ applies them as filters to search the underlying documents, which are properly indexed. Examples are \ 'what did Nike say about the Analyzer product?', or 'what did I say in my calls with Nike about pricing?'. So this \ is used really when there is *no implicit or explicit constraint or requirements* on underlying source documents \ outside of filters, and there is no ordering, no limiting their number, etc. So use this for a question that \ tries to get information *across* documents which may be filtered by their related relationships and entities, but without \ other constraints. 2. SQL: Choose this option if the question either requires counting of entities (e.g. 'how many calls...'), or \ if the query refers to specific entities that first need to be identified and then analyzed/searched/listed. \ Examples here are 'what did I say about pricing in my call with Nike last week?' (the specific call needs to \ be identified first and then analyzed), \ 'what are the next steps of our two largest opportunities?', or 'summarize my 3 most recent customer calls'. 
So \ this is used if there *are implicit constraints* on the underlying source documents beyond filtering, including \ ordering, limiting, etc. Use this also if the answer expects to analyze each source independently as part \ of the overall answer. Note: - here, you should look at the extracted entities and relationships and judge whether using them as filters \ (using an *and*) would be appropriate to identify the range of underlying sources, or whether more \ calculations would be needed to find the underlying sources ('last 2...', etc.). - It is also *critical* to look at the attributes of the entities! You can only use the given attributes (and their values, if given) as where conditions, etc., in a SQL statement. So if you think you would 'want to' have a where condition but there is no appropriate attribute, then you should not use the SQL strategy but the SEARCH strategy. (A Search can always look through data and see what is the best fit; SQL needs to be more specific.) On the other hand, if the question maps well to the entities and attributes, then SQL may be a good choice. - Likely, if there are questions 'about something', then this is only used in a SQL statement or a filter \ if it shows up as an entity or relationship in the extracted entities and relationships. Otherwise, it will \ be part of the analysis/search, not the document identification. - again, note that we can only FILTER (SEARCH) or COMPUTE (SQL) using the extracted entities (and their attributes) and relationships. \ So do not think that if there is another term in the question, it should be included in the SQL statement. \ It cannot. a2) "Search Strategy": If a SQL search is chosen, i.e., documents have to be identified first, there are two approaches: 1. SIMPLE: You think you can answer the question using a database that is aware of the entities, relationships \ above, and is generally suitable if it is enough to either list or count entities, return dates, etc.
Usually, \ 'SIMPLE' is chosen for questions of the form 'how many...' (always), or 'list the...' (often), 'when was...', \ 'what did (someone) work on...', etc. Often it is also used in cases like 'what did John work on since April?'. Here, \ the user would expect to just see the list. So choose 'SIMPLE' here unless there are REALLY CLEAR \ follow-up instructions for each item (like 'summarize...', 'analyze...', 'what are the main points of...'). If \ it is a 'what did...'-type question, choose 'SIMPLE'! 2. DEEP: You think you really should ALSO leverage the actual text of sources to answer the question, which sits \ in a vector database. Examples are 'what is discussed in...', 'summarize', 'what is the discussion about...',\ 'how does... relate to...', 'are there any mentions of... in...', 'what are the main points in...', \ 'what are the next steps...', etc. Those are usually questions 'about' \ the entities retrieved from the knowledge graph, or questions about the underlying sources. Your task is to decide which of the two strategies to use. a3) "Relationship Detection": You need to evaluate whether the question involves any relationships between entities (of the same type) \ or between entities and relationships. Respond with 'RELATIONSHIPS' or 'NO_RELATIONSHIPS'. b) Regarding the format of the answer: there are also two types of formats available to you: 1. LIST: The user would expect an answer as a bullet point list of objects, likely with text associated with each \ bullet point (or sub-bullet). This will be clearer once the data is available. 2. TEXT: The user would expect the question to be answered in text form. Your task is to decide which of the two formats to use. c) Regarding the broken down question for one object: Always generate a broken_down_question if the question ultimately pertains to specific objects, even if it seems to be \ a singular object.
- If the question is of type 'how many...', or similar, then imagine that the individual objects have been \ found and you want to ask each object something that illustrates why/in what way that object relates to the \ question. (question: 'How many cars are fast?' -> broken_down_question: 'How fast is this car?') - Assume the answer would either i) best be generated by first analyzing one object at a time, then aggregating \ the results, or ii) directly relates to one or few objects found through matching suitable criteria. - The key is to drop any filtering/criteria matching as the objects are already filtered by the criteria. Also, do not \ try to verify here whether the object in question actually satisfies a filter criterion, but rather see \ what it says/does etc. In other words, use this to identify more details about the object, as it relates \ to the original question. (Example: question: 'What did our oil & gas customers say about the new product?' -> broken_down_question: \ 'What did this customer say about the new product?', or: question: 'What was in the email from Frank?' -> broken_down_question: 'What is in this email?') d) Regarding the divide and conquer strategy: You are supposed to decide whether a divide and conquer strategy would be appropriate. That means, do you think \ that in order to answer the question, it would be good to first analyze one object at a time, and then aggregate the \ results? Or should the information rather be analyzed as a whole? This would be 'yes' or 'no'. Please answer in json format in this form: {{ "search_type": , "search_strategy": , "relationship_detection": , "format": , "broken_down_question": , "divide_and_conquer": }} Do not include any other text or explanations. """ SOURCE_DETECTION_PROMPT = f""" You are an expert in generating, understanding and analyzing SQL statements. You are given an original SQL statement that returns a list of entities from a table or \ an aggregation of entities from a table.
Your task will be to \ identify the source documents that are relevant to what the SQL statement is returning. The task is actually quite simple. There are two tables involved - relationship_table and entity_table. \ relationship_table was used to generate the original SQL statement. Again, returning entities \ or aggregations of entities. The second table, entity_table contains the entities and \ the corresponding source_documents. All you need to do is to appropriately join the \ entity_table table on the entities that would be retrieved from the original SQL statement, \ and then return the source_documents from the entity_table table. For your orientation, the relationship_table table has this structure: - Table name: relationship_table - Columns: - relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \ It is of the form \ ::____:: \ [example: ACCOUNT::Nike__has__CONCERN::performance]. Note that this is NOT UNIQUE! - source_entity (str): the id of the source ENTITY/NODE in the relationship [example: ACCOUNT::Nike] - source_entity_attributes (json): the attributes of the source entity/node [example: {{"account_type": "customer"}}] - target_entity (str): the id of the target ENTITY/NODE in the relationship [example: CONCERN::performance] - target_entity_attributes (json): the attributes of the target entity/node [example: {{"degree": "severe"}}] - source_entity_type (str): the type of the source entity/node [example: ACCOUNT]. Only the entity types provided \ below are valid. - target_entity_type (str): the type of the target entity/node [example: CONCERN]. Only the entity types provided \ below are valid. - relationship_type (str): the type of the relationship, formatted as \ ____. So the explicit entity_names have \ been removed. 
[example: ACCOUNT__has__CONCERN] - source_date (str): the 'event' date of the source document [example: 2021-01-01] The second table, entity_table, has this structure: - Table name: entity_table - Columns: - entity (str): The name of the ENTITY, which is unique in this table. source_entity and target_entity \ in the relationship_table table are the same as entity in this table. - source_document (str): the id of the document that contains the entity. Again, ultimately, your task is to join the entity_table table on the entities that would be retrieved from the \ original SQL statement, and then return the source_documents from the entity_table table. The way to do that is to create a common table expression for the original SQL statement and join the \ entity_table table suitably on the entities. Here is the *original* SQL statement: {SEPARATOR_LINE} ---original_sql_statement--- {SEPARATOR_LINE} Please structure your answer using , ,, start and end tags as in: [think very briefly through the problem step by step, not more than 2-3 sentences] \ [the new SQL statement that returns the source documents involved in the original SQL statement] """.strip() ENTITY_SOURCE_DETECTION_PROMPT = f""" You are an expert in generating, understanding and analyzing SQL statements. You are given a SQL statement that returned an aggregation of entities in a table. \ Your task will be to identify the source documents for the entities involved in \ the answer. For example, should the original SQL statement be \ 'SELECT COUNT(entity) FROM entity_table where entity_type = "ACCOUNT"' \ then you should return the source documents that contain the entities of type 'ACCOUNT'. The table has this structure: - Table name: entity_table - Columns: - entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \ It is of the form :: [example: ACCOUNT::625482894]. - entity_type (str): the type of the entity [example: ACCOUNT]. 
- entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}] - source_document (str): the id of the document that contains the entity. Note that the combination of \ id_name and source_document IS UNIQUE! - source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00] Specifically, the table contains the 'source_document' column, which is the id of the source document that \ contains the core information about the entity. Make sure that you do not return more documents, i.e. if there \ is a limit on source documents in the original SQL statement, the new SQL statement needs to have \ the same limit. CRITICAL NOTES: - Only return source documents and nothing else! Your task is then to create a new SQL statement that returns the source documents that are relevant to what the \ original SQL statement is returning. So the source document of every row used in the original SQL statement should \ be included in the result of the new SQL statement, and then you should apply a 'distinct'. Here is the *original* SQL statement: {SEPARATOR_LINE} ---original_sql_statement--- {SEPARATOR_LINE} Please structure your answer using , ,, start and end tags as in: [think very briefly through the problem step by step, not more than 2-3 sentences] \ [the new SQL statement that returns the source documents involved in the original SQL statement] """.strip() ENTITY_TABLE_DESCRIPTION = f"""\ - Table name: entity_table - Columns: - entity (str): The name of the ENTITY, combining the nature of the entity and the id of the entity. \ It is of the form :: [example: ACCOUNT::625482894]. - entity_type (str): the type of the entity [example: ACCOUNT]. - entity_attributes (json): the attributes of the entity [example: {{"priority": "high", "status": "active"}}] - source_document (str): the id of the document that contains the entity. Note that the combination of \ id_name and source_document IS UNIQUE! 
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00] {SEPARATOR_LINE} Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \ identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \ their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \ the entity type may also often be referred to. {SEPARATOR_LINE} ---entity_types--- {SEPARATOR_LINE} """ RELATIONSHIP_TABLE_DESCRIPTION = f"""\ - Table name: relationship_table - Columns: - relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \ It is of the form \ ::____:: \ [example: ACCOUNT::Nike__has__CONCERN::performance]. Note that this is NOT UNIQUE! - source_entity (str): the id of the source ENTITY/NODE in the relationship [example: ACCOUNT::Nike] - source_entity_attributes (json): the attributes of the source entity/node [example: {{"account_type": "customer"}}] - target_entity (str): the id of the target ENTITY/NODE in the relationship [example: CONCERN::performance] - target_entity_attributes (json): the attributes of the target entity/node [example: {{"degree": "severe"}}] - source_entity_type (str): the type of the source entity/node [example: ACCOUNT]. Only the entity types provided \ below are valid. - target_entity_type (str): the type of the target entity/node [example: CONCERN]. Only the entity types provided \ below are valid. - relationship_type (str): the type of the relationship, formatted as \ ____. So the explicit entity_names have \ been removed. [example: ACCOUNT__has__CONCERN] - source_document (str): the id of the document that contains the relationship. Note that the combination of \ id_name and source_document IS UNIQUE! 
- source_date (timestamp): the 'event' date of the source document [example: 2025-04-25 21:43:31.054741+00] {SEPARATOR_LINE} Importantly, here are the entity (node) types that you can use, with a short description of what they mean. You may need to \ identify the proper entity type through its description. Also notice the allowed attributes for each entity type and \ their values, if provided. Of particular importance is the 'subtype' attribute, if provided, as this is how \ the entity type may also often be referred to. {SEPARATOR_LINE} ---entity_types--- {SEPARATOR_LINE} Here are the relationship types that are in the table, denoted as ____. In the table, the actual relationships are not quite of this form, but each is followed by '::' \ in the relationship id as shown above. {SEPARATOR_LINE} ---relationship_types--- {SEPARATOR_LINE} """ SIMPLE_SQL_PROMPT = f""" You are an expert in generating a SQL statement that only uses ONE TABLE that captures RELATIONSHIPS \ between TWO ENTITIES. The table has the following structure: {SEPARATOR_LINE} {RELATIONSHIP_TABLE_DESCRIPTION} Here is the question you are supposed to translate into a SQL statement: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} To help you, we have already identified the entities and relationships that the SQL statement likely *should* use (but note the \ exception below!). The entities also contain the list of attributes and attribute values that should specify the entity. \ The format is ::--[:, \ :, ...]. {SEPARATOR_LINE} Identified entities with attributes in query: ---query_entities_with_attributes--- These are the entities that should be used in the SQL statement. However, \ note that these are the entities (with potential attributes) that were *matches* of Knowledge Graph identified with the \ entities originally identified in the original question. As such, they may have id names that may not mean much by themselves, \ eg ACCOUNT::a74f332.
Here is the mapping of entities originally identified (whose role in the query should be obvious) with \ the entities that were matched to them in the Knowledge Graph: ---entity_explanation_string--- -- Here are relationships that were identified as explicitly or implicitly referred to in the question: ---query_relationships--- (Again, if applicable, the entities contained in the relationships are the same as the entities in the \ query_entities_with_attributes, and those are the correct ones to use in the SQL statement.) {SEPARATOR_LINE} CRITICAL SPECIAL CASE: - if an identified entity is of the form ::*, or an identified relationship contains an \ entity of this form, this refers to *any* entity of that type. Correspondingly, the SQL query should use the *entity type*, \ and possibly the relationship type, but not the entity with the * itself. \ Example: if you see 'ACCOUNT::*', that means any account matches. So if you are supposed to count the 'ACCOUNT::*', \ you should count the entities of entity_type 'ACCOUNT'. IMPORTANT NOTES: - The id_name of each relationship has the format \ ____. - The relationship id_names are NOT UNIQUE, only the combinations of relationship id_name and source_document_id are unique. \ That is because each relationship is extracted from a document. So make sure you use the proper 'distinct's! - If the SQL contains a 'SELECT DISTINCT' clause and an ORDER BY clause, then you MUST include the columns from the ORDER BY \ clause ALSO IN THE SELECT DISTINCT CLAUSE! This is very important! (This is a postgres db., so this is a MUST!). \ You MUST NOT have a column in the ORDER BY clause that is not ALSO in the SELECT DISTINCT clause! - If you join the relationship table on itself using the source_node or target_node, you need to make sure that you also \ join on the source_document_id. - The id_name of each node/entity has the format ::, where 'entity_type_id_name' \ and 'name' are columns and \ the values and can be used for filtering. 
- The table can be joined on itself on source nodes and/or target nodes if needed. - the SQL statement MUST ultimately only return NODES/ENTITIES (not relationships!), or aggregations of \ entities/nodes(count, avg, max, min, etc.). \ Again, DO NOT compose a SQL statement that returns id_name of relationships. - You CAN ONLY return ENTITIES or COUNTS (or other aggregations) of ENTITIES, or you can return \ source_date (but only if the question asks for event dates or times). DO NOT return \ source documents or counts of source documents, or relationships or counts of relationships! \ Those can only appear in where clauses, ordering etc., but they cannot be returned or ultimately \ counted here! source_date and date operations can appear in select statements, particularly if \ there is time ordering or grouping involved. - ENTITIES can be target_entity or source_entity. Think about the allowed relationships and the \ question to decide which one you want! - It is ok to generate nested SQL as long as it is correct postgres syntax! - Attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \ "attributes ->> '' = ''". - The SELECT clause MUST only contain entities or aggregations/counts of entities, or, in cases the \ question was about dates or times, then it can also include source_date. But source_document MUST NEVER appear \ in the SELECT clause! - Again, NEVER count or retrieve source documents in SELECT CLAUSE, whether it is in combination with \ entities, with a distinct, etc. NO source_document in SELECT CLAUSE! So NEVER produce a \ 'SELECT COUNT(source_entity, source_document)...' - Please think about whether you are interested in source entities or target entities! For that purpose, \ consider the allowed relationship types to make sure you select or count the correct one! - Again, ALWAYS make sure that EACH COLUMN in an ORDER-BY clause IS ALSO IN THE SELECT CLAUSE! Remind yourself \ of that in the reasoning. 
- Be careful with dates! Often a date will refer to the source date, which is the date when \ an underlying piece of information was updated. However, if the attributes of an entity contain \ time information as well (like 'started_at', 'completed_at', etc.), then you should really look at \ the wording to see whether you should use a date in the attributes or the event date. - Dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like attributes! \ So please use that format, particularly if you use date comparisons (>, <, ...) - Again, NO 'relationship' or 'source_document' in the SELECT CLAUSE, be it as direct columns or in aggregations! - Careful with SORT! Really think in which order you want to sort if you have multiple columns you \ want to sort by. If the sorting is time-based and there is a limit for example, then you do want to have a suitable date \ variable as the first column to sort by. - When doing a SORT on an attribute value of an entity, you MUST also apply a WHERE clause to filter \ for entities that have the attribute value set. For example, if you want to sort the target entity \ by the attribute 'created_date', you must also have a WHERE clause that checks whether the target \ entity attribute contains 'created_date'. This is vital for proper ordering with null values. - Usually, you will want to retrieve or count entities, maybe with attributes. But you almost always want to \ have entities involved in the SELECT clause. - Questions like 'What did Paul work on last week?' should generally be handled by finding all entities \ that reasonably relate to 'work entities' that are i) related to Paul, and ii) that were created or \ updated (by him) last week. So this would likely be a UNION of multiple queries. - If you do joins, consider the possibility that the second entity does not exist for all examples. \ Therefore joins should generally be LEFT joins (or RIGHT joins) as appropriate.
Think about which \ entities you are interested in, and which ones provide attributes. Another important note: - For questions that really try to explore what a certain entity was involved in like 'what did Paul Smith do \ in the last 3 months?', and Paul Smith has been extracted, e.g., as an entity of type 'EMPLOYEE', you will \ want to consider all entities that Paul Smith may be related to that satisfy any potential other conditions. - Joins should always be made on entities, not source documents! - Try to be as efficient as possible. APPROACH: Please think through this step by step. Make sure that you include all columns in the ORDER BY clause \ also in the SELECT DISTINCT clause, \ if applicable! And again, joins should generally be LEFT JOINS! Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---. Please structure your answer using , , , start and end tags as in: [think through the logic but do so extremely briefly! Not more than 3-4 sentences.] [the SQL statement that you generate to satisfy the task] """.strip() # TODO: remove following before merging after enough testing SIMPLE_SQL_CORRECTION_PROMPT = f""" You are an expert in reviewing and fixing SQL statements. Here is a draft SQL statement that you should consider as generally capturing the information intended. \ However, it may or may not be syntactically 100% correct for our postgresql database. Guidance: - Think about whether attributes should be numbers or strings. You may need to convert them. - If we use SELECT DISTINCT we need to have the ORDER BY columns in the \ SELECT statement as well! And it needs to be in the EXACT FORM! So if a \ conversion took place, make sure to include the conversion in the SELECT and the ORDER BY clause! - never should 'source_document' be in the SELECT clause! Remove if present!
- if there are joins, they must be on entities, never source documents - if there are joins, consider the possibility that the second entity does not exist for all examples.\ Therefore consider using LEFT joins (or RIGHT joins) as appropriate. Draft SQL: {SEPARATOR_LINE} ---draft_sql--- {SEPARATOR_LINE} Please structure your answer using , , , start and end tags as in: [think briefly through the problem step by step] [the corrected (or original one, if correct) SQL statement] """.strip() SIMPLE_ENTITY_SQL_PROMPT = f""" You are an expert in generating a SQL statement that only uses ONE TABLE that captures ENTITIES \ and their attributes and other data. The table has the following structure: {SEPARATOR_LINE} {ENTITY_TABLE_DESCRIPTION} Here is the question you are supposed to translate into a SQL statement: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} To help you, we already have identified the entities that the SQL statement likely *should* use (but note the \ exception below!). The entities as written below also contain the list of attributes and attribute values \ that should specify the entity. \ The format is ::--[:, \ :, ...]. {SEPARATOR_LINE} Identified entities with attributes in query: ---query_entities_with_attributes--- These are the entities that should be used in the SQL statement. However, \ note that these are the entities (with potential attributes) that were *matches* of Knowledge Graph identified with the \ entities originally identified in the original question. As such, they may have id names that may not mean much by themselves, \ eg ACCOUNT::a74f332. Here is the mapping of entities originally identified (whose role in the query should be obvious) with \ the entities that were matched to them in the Knowledge Graph: ---entity_explanation_string--- -- {SEPARATOR_LINE} CRITICAL SPECIAL CASE: - if an identified entity is of the form ::*, or an identified relationship contains an \ entity of this form, this refers to *any* entity of that type. 
Correspondingly, the SQL query should use the *entity type*, \ but not the entity with the * itself. \ Example: if you see 'ACCOUNT::*', that means any account matches. So if you are supposed to count the 'ACCOUNT::*', \ you should count the entities of entity_type 'ACCOUNT'. IMPORTANT NOTES: - The entities are unique in the table. - If the SQL contains a 'SELECT DISTINCT' clause and an ORDER BY clause, then you MUST include the columns from the ORDER BY \ clause ALSO IN THE SELECT DISTINCT CLAUSE! This is very important! (This is a postgres db., so this is a MUST!). \ You MUST NOT have a column in the ORDER BY clause that is not ALSO in the SELECT DISTINCT clause! - The table cannot be joined on itself. - You CAN ONLY return ENTITIES or COUNTS (or other aggregations) of ENTITIES, or you can return \ source_date (but only if the question asks for event dates or times, and then the \ corresponding entity must also be returned). - Generally, the query can only return ENTITIES or aggregations of ENTITIES: - if individual entities are returned, then you MUST also return the source_document. \ If the source date was requested, you can return that too. - if aggregations of entities are returned, then you can only aggregate the entities. - Attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \ "attributes ->> '' = ''". - Again, ALWAYS make sure that EACH COLUMN in an ORDER-BY clause IS ALSO IN THE SELECT CLAUSE! Remind yourself \ of that in the reasoning. - Be careful with dates! Often a date will refer to the source data, which is the date when \ an underlying piece of information was updated. However, if the attributes of an entity may contain \ time information as well (like 'started_at', 'completed_at', etc.), then you should really look at \ the wording to see whether you should use a date in the attributes or the event date. 
- Dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like attributes! \ So please use that format, particularly if you use date comparisons (>, <, ...) - Careful with SORT! Really think in which order you want to sort if you have multiple columns you \ want to sort by. If the sorting is time-based and there is a limit for example, then you do want to have a suitable date \ variable as the first column to sort by. - When doing a SORT on an attribute value of an entity, you MUST also apply a WHERE clause to filter \ for entities that have the attribute value set. For example, if you want to sort the target entity \ by the attribute 'created_date', you must also have a WHERE clause that checks whether the target \ entity attribute contains 'created_date'. This is vital for proper ordering with null values. - Usually, you will want to retrieve or count entities, maybe with attributes. But you almost always want to \ have entities involved in the SELECT clause. - You MUST ONLY rely on the entity attributes provided! This is essential! Do not assume \ other attributes exist...they don't! Note that there will often be a search using the results \ of this query. So if there is information in the question that does not fit the provided attributes, \ you should not use it here but rely on the later search! - Try to be as efficient as possible. APPROACH: Please think through this step by step. Make sure that you include all columns in the ORDER BY clause \ also in the SELECT DISTINCT clause, \ if applicable! Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---. Please structure your answer using , , , start and end tags as in: [think through the logic but do so extremely briefly! Not more than 3-4 sentences.] [the SQL statement that you generate to satisfy the task] """.strip() SIMPLE_SQL_ERROR_FIX_PROMPT = f""" You are an expert at fixing SQL statements.
You will be provided with a SQL statement that aims to address \ a question, but it contains an error. Your task is to fix the SQL statement, based on the error message. Here is the description of the table that the SQL statement is supposed to use: ---table_description--- Here is the question you are supposed to translate into a SQL statement: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} Here is the SQL statement that you should fix: {SEPARATOR_LINE} ---sql_statement--- {SEPARATOR_LINE} Here is the error message that was returned: {SEPARATOR_LINE} ---error_message--- {SEPARATOR_LINE} Note that if the error states that the sql statement did not return any results, it is possible that the \ sql statement is correct, but the question is not addressable with the information in the knowledge graph. \ If you are absolutely certain that is the case, you may return the original sql statement. Here are a couple of common errors that you may encounter: - source_document is in the SELECT clause -> remove it - columns used in ORDER BY must also appear in the SELECT DISTINCT clause - consider carefully the type of the columns you are using, especially for attributes. You may need to cast them - dates are ALWAYS in string format of the form YYYY-MM-DD, for source date as well as for date-like attributes! \ So please use that format, particularly if you use date comparisons (>, <, ...) - attributes are stored in the attributes json field. As this is postgres, querying for those must be done as \ "attributes ->> '' = ''" (or "attributes ? ''" to check for existence). - if you are using joins and the sql returned no results, make sure you are using the appropriate join type (LEFT, RIGHT, etc.); \ it is possible that the second entity does not exist for all examples. - (ignore if using entity_table) if using the relationship_table and the sql returned no results, make sure you are \ selecting the correct column!
Use the available relationship types to determine whether to use the source or target entity. APPROACH: Please think through this step by step. Please also bear in mind that the sql statement is written in postgres syntax. Also, in case it is important, today is ---today_date--- and the user/employee asking is ---user_name---. Please structure your answer using , , , start and end tags as in: [think through the logic but do so extremely briefly! Not more than 3-4 sentences.] [the SQL statement that you generate to satisfy the task] """ SEARCH_FILTER_CONSTRUCTION_PROMPT = f""" You need to prepare a search across text segments that contain the information necessary to \ answer a question. The text segments have tags that can be used to filter for the relevant segments. \ The key tags are suitable entities and relationships of a knowledge graph, as well as underlying source documents. Your overall task is to find the filters and structures that are needed to filter a database to \ properly address a user question. You will be given: - the user question - a description of all of the potential entity types involved - a list of 'global' entities and relationships that should be filtered by, given the question - the structure of a schema that was used to derive additional entity filters - a SQL statement that was generated to derive those filters - the results that were generated using the SQL statement. This can have multiple rows, \ and those will be the 'local' filters (which will later mean that each retrieved result will \ need to match at least one of the conditions that you will generate).
- the results of another query that asked for the underlying source documents that resulted \ in the answers of the SQL statement Here is the information: 1) The overall user question {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} 2) Here is a description of all of the entity types: {SEPARATOR_LINE} ---entity_type_descriptions--- {SEPARATOR_LINE} 3) Here are the lists of entity and relationship filters that were derived from the question: {SEPARATOR_LINE} Entity filters: ---entity_filters--- -- Relationship filters: ---relationship_filters--- {SEPARATOR_LINE} 4) Here are the columns of a table in a database that has a lot of knowledge about the \ data: {SEPARATOR_LINE} - relationship (str): The name of the RELATIONSHIP, combining the nature of the relationship and the names of the entities. \ It is of the form \ ::____:: \ [example: ACCOUNT::Nike__has__CONCERN::performance]. Note that this is NOT UNIQUE! - source_entity (str): the id of the source ENTITY/NODE in the relationship [example: ACCOUNT::Nike] - source_entity_attributes (json): the attributes of the source entity/node [example: {{"account_type": "customer"}}] - target_entity (str): the id of the target ENTITY/NODE in the relationship [example: CONCERN::performance] - target_entity_attributes (json): the attributes of the target entity/node [example: {{"degree": "severe"}}] - source_entity_type (str): the type of the source entity/node [example: ACCOUNT]. Only the entity types provided \ below are valid. - target_entity_type (str): the type of the target entity/node [example: CONCERN]. Only the entity types provided \ below are valid. - relationship_type (str): the type of the relationship, formatted as \ ____. So the explicit entity_names have \ been removed. [example: ACCOUNT__has__CONCERN] - source_document (str): the id of the document that contains the relationship. Note that the combination of \ id_name and source_document IS UNIQUE! 
- source_date (str): the 'event' date of the source document [example: 2021-01-01] {SEPARATOR_LINE} 5) Here is a query that was generated for that table to provide additional filters: {SEPARATOR_LINE} ---sql_query--- {SEPARATOR_LINE} 6) Here are the results of that SQL query. (Consider the schema description and the \ structure of the entities to interpret the results) {SEPARATOR_LINE} ---sql_results--- {SEPARATOR_LINE} 7) Here are the results of the other query that provided the underlying source documents \ using the schema: {SEPARATOR_LINE} ---source_document_results--- {SEPARATOR_LINE} Here is the detailed set of tasks that you should perform, including the proper output format for you: Please reply as a json dictionary in this form: {{ "global_entity_filters": , "global_relationship_filters": , "local_entity_filters": , "source_document_filters": , "structure": }} Again - DO NOT FORGET - here is the user question that motivates this whole task: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} Your json dictionary answer: """.strip() OUTPUT_FORMAT_NO_EXAMPLES_PROMPT = f""" You need to format an answer to a research question. \ You will see what the desired output is, the original question, and the unformatted answer to the research question. \ Your purpose is to generate the answer respecting the desired format. Notes: - Note that you are a language model and that answers may or may not be perfect. To communicate \ this to the user, consider phrases like 'I found [10 accounts]...', or 'Here are a number of [goals] that \ I found...] - Please DO NOT mention the explicit output format in your answer. Just use it to inform the formatting. 
Here is the unformatted answer to the research question: {SEPARATOR_LINE} ---introductory_answer--- {SEPARATOR_LINE} Here is the original question: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} And finally, here is the desired output format: {SEPARATOR_LINE} ---output_format--- {SEPARATOR_LINE} Please start generating the answer, without any explanation. There should be no real modifications to \ the text; after all, all you need to do here is formatting. \ Your Answer: """.strip() OUTPUT_FORMAT_PROMPT = f""" You need to format the answers to a research question that was generated using one or more objects. \ An overall introductory answer may be provided to you, as well as the research results for each individual object. \ You will also be provided with the original question as background, and the desired format. \ Your purpose is to generate a consolidated and FORMATTED answer that starts off with the introductory \ answer, and then formats the research results for each individual object in the desired format. \ Do not add any other text please! Notes: - Note that you are a language model and that answers may or may not be perfect. To communicate \ this to the user, consider phrases like 'I found [10 accounts]...', or 'Here are a number of [goals] that \ I found...' - Please DO NOT mention the explicit output format in your answer. Just use it to inform the formatting. - DO NOT add any content to the introductory answer! Here is the original question for your background: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} Here is the desired output format: {SEPARATOR_LINE} ---output_format--- {SEPARATOR_LINE} Here is the introductory answer: {SEPARATOR_LINE} ---introductory_answer--- {SEPARATOR_LINE} Here are the research results that you should (respecting the target format) return in a formatted way: {SEPARATOR_LINE} ---research_results--- {SEPARATOR_LINE} Please start generating the answer, without any explanation.
After all, all you need to do here is formatting. \ Your Answer: """.strip() OUTPUT_FORMAT_NO_OVERALL_ANSWER_PROMPT = f""" You need to format the return of research on multiple objects. The research results will be given \ to you as a string. You will also see what the desired output is, as well as the original question. \ Your purpose is to generate the answer respecting the desired format. Notes: - Note that you are a language model and that answers may or may not be perfect. To communicate \ this to the user, consider phrases like 'I found [10 accounts]...', or 'Here are a number of [goals] that \ I found...] - Please DO NOT mention the explicit output format in your answer. Just use it to inform the formatting. - Often, you are also provided with a list of explicit examples. If - AND ONLY IF - the list is not \ empty, then these should be listed at the end with the text: '... Here are some examples of what I found: ...' - Again if the list of examples is an empty string then skip this section! Do not use the \ results data for this purpose instead! (They will already be handled in the answer.) - Even if the desired output format is 'text', make sure that you keep the individual research results \ separated by bullet points, and mention the object name first, followed by a new line. The object name \ is at the beginning of the research result, and should be in the format ::. Here is the original question: {SEPARATOR_LINE} ---question--- {SEPARATOR_LINE} And finally, here is the desired output format: {SEPARATOR_LINE} ---output_format--- {SEPARATOR_LINE} Here are the research results that you should properly format: {SEPARATOR_LINE} ---research_results--- {SEPARATOR_LINE} Please start generating the answer, without any explanation. After all, all you need to do here is formatting. 
\ Your Answer: """.strip() KG_OBJECT_SOURCE_RESEARCH_PROMPT = f""" You are an expert in extracting relevant structured information from a list of documents that \ should relate to one object. You are presented with a list of documents that have been determined to be \ relevant to the task of interest. Your goal is to extract the information asked around these topics: You should look at the documents - in no particular order! - and extract the information that relates \ to a question: {SEPARATOR_LINE} {{question}} {SEPARATOR_LINE} Here are the documents you are supposed to search through: -- {{document_text}} {SEPARATOR_LINE} Note: in this case, please do NOT cite your sources. This is very important! Please now generate the answer to the question given the documents: """.strip() KG_SEARCH_PROMPT = f""" You are an expert in extracting relevant structured information from a list of documents that \ should relate to one object. You are presented with a list of documents that have been determined to be \ relevant to the task of interest. Your goal is to extract the information asked around these topics: You should look at the documents and extract the information that relates \ to a question: {SEPARATOR_LINE} {{question}} {SEPARATOR_LINE} Here are the documents you are supposed to search through: -- {{document_text}} {SEPARATOR_LINE} Note: in this case, please DO cite your sources. This is very important! \ Use the format []. Ie, use [1], [2], and NOT [1,2] if \ there are two documents to cite, etc. 
\ Please now generate the answer to the question given the documents: """.strip() # KG Beta Assistant System Prompt KG_BETA_ASSISTANT_SYSTEM_PROMPT = """You are a knowledge graph assistant that helps users explore and \ understand relationships between entities.""" KG_BETA_ASSISTANT_TASK_PROMPT = """Help users explore and understand the knowledge graph by answering \ questions about entities and their relationships.""" # Just in case, for best practice, send a system message with key rules. # (The db user permissions executing the SQL will avoid issues anyway, # but it does not hurt to put multiple checks in place.) SQL_INSTRUCTIONS_RELATIONSHIP_PROMPT = """ You are an expert at generating SQL queries to answer questions about a knowledge graph. You will be given a lot of instructions later, but here are the rules that MUST BE FOLLOWED: - the SQL generated MUST only use one table, named 'relationship_table'. \ This table cannot be defined or overwritten by the user or by the resulting SQL \ statement; it MUST be treated as an existing table in the database. - self-joins of the 'relationship_table' are allowed, as well as common table expressions \ that reference only the 'relationship_table'. - no other table or view can in any way, shape, or form be \ involved in the generated SQL. - no other database operations can be generated except for those that query the 'relationship_table'. \ (WHERE, GROUP BY, etc. are certainly allowed, but no other database table can be used in the generated SQL.) """ SQL_INSTRUCTIONS_ENTITY_PROMPT = """ You are an expert at generating SQL queries to answer questions about a knowledge graph. You will be given a lot of instructions later, but here are the rules that MUST BE FOLLOWED: - the SQL generated MUST only use one table, named 'entity_table'. \ This table cannot be defined or overwritten by the user or by the resulting SQL \ statement; it MUST be treated as an existing table in the database.
- common table expressions that reference only the 'entity_table' are allowed. - no other table or view of a potential underlying schema can in any way or shape be \ involved in the generated SQL. - no other database operations can be generated except for those that query the 'entity_table'. \ (WHERE, GROUP BY, etc. are certainly allowed, but no other database table can be used in the generated SQL.) """ ================================================ FILE: backend/onyx/prompts/prompt_template.py ================================================ import re from onyx.prompts.prompt_utils import replace_current_datetime_tag class PromptTemplate: """ A class for building prompt templates with placeholders. Useful when building templates with json schemas, as {} will not work with f-strings. Unlike string.replace, this class will raise an error if the fields are missing. """ DEFAULT_PATTERN = r"---([a-zA-Z0-9_]+)---" def __init__(self, template: str, pattern: str = DEFAULT_PATTERN): self._pattern_str = pattern self._pattern = re.compile(pattern) self._template = template self._fields: set[str] = set(self._pattern.findall(template)) def build(self, **kwargs: str) -> str: """ Build the prompt template with the given fields. Will raise an error if the fields are missing. Will ignore fields that are not in the template. """ missing = self._fields - set(kwargs.keys()) if missing: raise ValueError(f"Missing required fields: {missing}.") built = self._replace_fields(kwargs) return self._postprocess(built) def partial_build(self, **kwargs: str) -> "PromptTemplate": """ Returns another PromptTemplate with the given fields replaced. Will ignore fields that are not in the template. 
""" new_template = self._replace_fields(kwargs) return PromptTemplate(new_template, self._pattern_str) def _replace_fields(self, field_vals: dict[str, str]) -> str: def repl(match: re.Match) -> str: key = match.group(1) return field_vals.get(key, match.group(0)) return self._pattern.sub(repl, self._template) def _postprocess(self, text: str) -> str: """Apply global replacements such as [[CURRENT_DATETIME]].""" if not text: return text # Ensure [[CURRENT_DATETIME]] matches shared prompt formatting return replace_current_datetime_tag( text, full_sentence=True, include_day_of_week=True, ) ================================================ FILE: backend/onyx/prompts/prompt_utils.py ================================================ from datetime import datetime from typing import cast from langchain_core.messages import BaseMessage from onyx.configs.constants import DocumentSource from onyx.prompts.chat_prompts import ADDITIONAL_INFO from onyx.prompts.chat_prompts import CITATION_GUIDANCE_REPLACEMENT_PAT from onyx.prompts.chat_prompts import COMPANY_DESCRIPTION_BLOCK from onyx.prompts.chat_prompts import COMPANY_NAME_BLOCK from onyx.prompts.chat_prompts import DATETIME_REPLACEMENT_PAT from onyx.prompts.chat_prompts import REMINDER_TAG_REPLACEMENT_PAT from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE from onyx.prompts.constants import CODE_BLOCK_PAT from onyx.prompts.constants import REMINDER_TAG_DESCRIPTION from onyx.server.settings.store import load_settings from onyx.utils.logger import setup_logger logger = setup_logger() _BASIC_TIME_STR = "The current date is {datetime_info}." 
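The `PromptTemplate` class above uses `---field---` markers instead of `{field}` so that literal braces (for example, in JSON schemas) survive untouched. The sketch below is a minimal, standalone illustration of that substitution logic, not the Onyx class itself: `MiniPromptTemplate` is a hypothetical name, it reuses the same default regex, and it deliberately omits the current-datetime post-processing step.

```python
import re

# Same placeholder syntax as PromptTemplate.DEFAULT_PATTERN: ---field_name---
_PLACEHOLDER = re.compile(r"---([a-zA-Z0-9_]+)---")


class MiniPromptTemplate:
    """Hypothetical, stripped-down stand-in for PromptTemplate (no datetime post-processing)."""

    def __init__(self, template: str) -> None:
        self.template = template
        # Every ---name--- marker becomes a required field for build().
        self.fields: set[str] = set(_PLACEHOLDER.findall(template))

    def _replace(self, values: dict[str, str]) -> str:
        # Markers without a supplied value are left as-is (match.group(0)).
        return _PLACEHOLDER.sub(
            lambda m: values.get(m.group(1), m.group(0)), self.template
        )

    def partial_build(self, **values: str) -> "MiniPromptTemplate":
        # Fill some markers now; the rest stay as placeholders for later.
        return MiniPromptTemplate(self._replace(values))

    def build(self, **values: str) -> str:
        # Unlike str.replace, a missing field is an error; extra kwargs are ignored.
        missing = self.fields - set(values)
        if missing:
            raise ValueError(f"Missing required fields: {sorted(missing)}")
        return self._replace(values)


# Literal JSON braces survive because only ---name--- markers are substituted.
template = MiniPromptTemplate('Question: ---question---\nReturn JSON like {"answer": "..."}')
print(template.build(question="Who owns project X?"))
```

Leaving unknown markers in place is what makes `partial_build` composable: a template can be filled in stages, and only the final `build` call enforces that every field has been supplied.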
def get_current_llm_day_time( include_day_of_week: bool = True, full_sentence: bool = True, include_hour_min: bool = False, ) -> str: current_datetime = datetime.now() # Format looks like: "October 16, 2023 14:30" if include_hour_min, otherwise "October 16, 2023" formatted_datetime = ( current_datetime.strftime("%B %d, %Y %H:%M") if include_hour_min else current_datetime.strftime("%B %d, %Y") ) day_of_week = current_datetime.strftime("%A") if full_sentence: return f"The current day and time is {day_of_week} {formatted_datetime}" if include_day_of_week: return f"{day_of_week} {formatted_datetime}" return f"{formatted_datetime}" def replace_current_datetime_tag( prompt_str: str, *, full_sentence: bool = False, include_day_of_week: bool = True, ) -> str: datetime_str = get_current_llm_day_time( full_sentence=full_sentence, include_day_of_week=include_day_of_week, ) if DATETIME_REPLACEMENT_PAT in prompt_str: prompt_str = prompt_str.replace(DATETIME_REPLACEMENT_PAT, datetime_str) return prompt_str def replace_citation_guidance_tag( prompt_str: str, *, should_cite_documents: bool = False, include_all_guidance: bool = False, ) -> tuple[str, bool]: """ Replace {{CITATION_GUIDANCE}} placeholder with citation guidance if needed. 
Returns: tuple[str, bool]: (prompt_with_replacement, should_append_fallback) - prompt_with_replacement: The prompt with placeholder replaced (or unchanged if not present) - should_append_fallback: True if citation guidance should be appended (placeholder is not present and citations are needed) """ placeholder_was_present = CITATION_GUIDANCE_REPLACEMENT_PAT in prompt_str if not placeholder_was_present: # Placeholder not present - caller should append if citations are needed should_append = ( should_cite_documents or include_all_guidance ) and REQUIRE_CITATION_GUIDANCE not in prompt_str return prompt_str, should_append citation_guidance = ( REQUIRE_CITATION_GUIDANCE if should_cite_documents or include_all_guidance else "" ) prompt_str = prompt_str.replace( CITATION_GUIDANCE_REPLACEMENT_PAT, citation_guidance, ) return prompt_str, False def replace_reminder_tag(prompt_str: str) -> str: """Replace {{REMINDER_TAG_DESCRIPTION}} with the reminder tag content.""" if REMINDER_TAG_REPLACEMENT_PAT in prompt_str: prompt_str = prompt_str.replace( REMINDER_TAG_REPLACEMENT_PAT, REMINDER_TAG_DESCRIPTION ) return prompt_str def handle_onyx_date_awareness( prompt_str: str, # We always replace the pattern {{CURRENT_DATETIME}} if it shows up # but if it doesn't show up and the prompt is datetime aware, add it to the prompt at the end. datetime_aware: bool = False, ) -> str: """ If there is a {{CURRENT_DATETIME}} tag, replace it with the current date and time no matter what. If the prompt is datetime aware, and there are no datetime tags, add it to the prompt. Do nothing otherwise. This can later be expanded to support other tags. 
""" prompt_with_datetime = replace_current_datetime_tag( prompt_str, full_sentence=False, include_day_of_week=True, ) if prompt_with_datetime != prompt_str: return prompt_with_datetime if datetime_aware: return prompt_str + ADDITIONAL_INFO.format( datetime_info=_BASIC_TIME_STR.format( datetime_info=get_current_llm_day_time() ) ) return prompt_str def get_company_context() -> str | None: prompt_str = None try: workspace_settings = load_settings() company_name = workspace_settings.company_name company_description = workspace_settings.company_description if not company_name and not company_description: return None prompt_str = "" if company_name: prompt_str += COMPANY_NAME_BLOCK.format(company_name=company_name) if company_description: prompt_str += COMPANY_DESCRIPTION_BLOCK.format( company_description=company_description ) return prompt_str except Exception as e: logger.error(f"Error handling company awareness: {e}") return None # Maps connector enum string to a more natural language representation for the LLM # If not on the list, uses the original but slightly cleaned up, see below CONNECTOR_NAME_MAP = { "web": "Website", "requesttracker": "Request Tracker", "github": "GitHub", "file": "File Upload", } def clean_up_source(source_str: str) -> str: if source_str in CONNECTOR_NAME_MAP: return CONNECTOR_NAME_MAP[source_str] return source_str.replace("_", " ").title() def build_doc_context_str( semantic_identifier: str, source_type: DocumentSource, content: str, metadata_dict: dict[str, str | list[str]], updated_at: datetime | None, ind: int, include_metadata: bool = True, ) -> str: context_str = "" if include_metadata: context_str += f"DOCUMENT {ind}: {semantic_identifier}\n" context_str += f"Source: {clean_up_source(source_type)}\n" for k, v in metadata_dict.items(): if isinstance(v, list): v_str = ", ".join(v) context_str += f"{k.capitalize()}: {v_str}\n" else: context_str += f"{k.capitalize()}: {v}\n" if updated_at: update_str = updated_at.strftime("%B %d, %Y 
%H:%M") context_str += f"Updated: {update_str}\n" context_str += f"{CODE_BLOCK_PAT.format(content.strip())}\n\n\n" return context_str _PER_MESSAGE_TOKEN_BUFFER = 7 def find_last_index(lst: list[int], max_prompt_tokens: int) -> int: """From the back, find the index of the last element to include before the list exceeds the maximum""" running_sum = 0 if not lst: logger.warning("Empty message history passed to find_last_index") return 0 last_ind = 0 for i in range(len(lst) - 1, -1, -1): running_sum += lst[i] + _PER_MESSAGE_TOKEN_BUFFER if running_sum > max_prompt_tokens: last_ind = i + 1 break if last_ind >= len(lst): logger.error( f"Last message alone is too large! max_prompt_tokens: {max_prompt_tokens}, message_token_counts: {lst}" ) raise ValueError("Last message alone is too large!") return last_ind def drop_messages_history_overflow( messages_with_token_cnts: list[tuple[BaseMessage, int]], max_allowed_tokens: int, ) -> list[BaseMessage]: """As message history grows, messages need to be dropped starting from the furthest in the past. 
The System message should be kept if at all possible and the latest user input which is inserted in the prompt template must be included""" final_messages: list[BaseMessage] = [] messages, token_counts = cast( tuple[list[BaseMessage], list[int]], zip(*messages_with_token_cnts) ) system_msg = ( final_messages[0] if final_messages and final_messages[0].type == "system" else None ) history_msgs = messages[:-1] final_msg = messages[-1] if final_msg.type != "human": if final_msg.type != "tool": raise ValueError("Last message must be user input OR a tool result") else: final_msgs = messages[-3:] history_msgs = messages[:-3] else: final_msgs = [final_msg] # Start dropping from the history if necessary ind_prev_msg_start = find_last_index( token_counts, max_prompt_tokens=max_allowed_tokens ) if system_msg and ind_prev_msg_start <= len(history_msgs): final_messages.append(system_msg) final_messages.extend(history_msgs[ind_prev_msg_start:]) final_messages.extend(final_msgs) return final_messages ================================================ FILE: backend/onyx/prompts/search_prompts.py ================================================ # How it works and rationale: # First - this works best empirically across multiple LLMs, some of this is back-explaining reasons based on results. # # The system prompt is kept simple and as similar to typical system prompts as possible to stay within training distribution. # The history is passed through as a list of messages, this should allow the LLM to more easily understand what is going on. # The special tokens and separators let the LLM more easily disregard no longer relevant past messages. # The last message is dynamically created and has a detailed description of the actual task. # This is based on the assumption that users give much more varied requests in their prompts and LLMs are well adjusted to this. # The proximity of the instructions and the lack of any breaks should also let the LLM follow the task more clearly. 
# # For document verification, the history is not included as the queries should ideally be standalone enough. # To keep it simple, it is just a single simple prompt. SEMANTIC_QUERY_REPHRASE_SYSTEM_PROMPT = """ You are an assistant that reformulates the last user message into a standalone, self-contained query suitable for \ semantic search. Your goal is to output a single natural language query that captures the full meaning of the user's \ most recent message. It should be fully semantic and natural language unless the user query is already a keyword query. \ When relevant, you bring in context from the history or knowledge about the user. The current date is {current_date}. """ SEMANTIC_QUERY_REPHRASE_USER_PROMPT = """ Given the chat history above (if any) and the final user query (provided below), provide a standalone query that is as representative of the user query as possible. In most cases, it should be exactly the same as the last user query. \ It should be fully semantic and natural language unless the user query is already a keyword query. \ Focus on the last user message, in most cases the history and extra context should be ignored. For a query like "What are the use cases for product X", your output should remain "What are the use cases for product X". \ It should remain semantic, and as close to the original query as possible. There is nothing additional needed \ from the history or that should be removed / replaced from the query. For modifications, you can: 1. Insert relevant context from the chat history. For example: "How do I set it up?" -> "How do I set up software Y?" (assuming the conversation was about software Y) 2. Remove asks or requests not related to the searching. For example: "Can you summarize the calls with example company" -> "calls with example company" "Can you find me the document that goes over all of the software to set up on an engineer's first day?" -> \ "all of the software to set up on an engineer's first day" 3. 
Fill in relevant information about the user. For example: "What document did I write last week?" -> "What document did John Doe write last week?" (assuming the user is John Doe) {additional_context} ========================= CRITICAL: ONLY provide the standalone query and nothing else. Final user query: {user_query} """.strip() KEYWORD_REPHRASE_SYSTEM_PROMPT = """ You are an assistant that reformulates the last user message into a set of standalone keyword queries suitable for a keyword \ search engine. Your goal is to output keyword queries that optimize finding relevant documents to answer the user query. \ When relevant, you bring in context from the history or knowledge about the user. The current date is {current_date}. """ KEYWORD_REPHRASE_USER_PROMPT = """ Given the chat history above (if any) and the final user query (provided below), provide a set of keyword only queries that can help find relevant documents. Provide a single query per line (where each query consists of one or more keywords). \ The queries must be purely keywords and not contain any natural language. \ Each query should have as few keywords as necessary to represent the user's search intent. Guidelines: - Do not provide more than 3 queries. - Do not replace or expand niche, proprietary, or obscure terms - Focus on the last user message, in most cases the history and any extra context should be ignored. {additional_context} ========================= CRITICAL: ONLY provide the keyword queries, one set of keywords per line and nothing else. Final user query: {user_query} """.strip() REPHRASE_CONTEXT_PROMPT = """ In most cases the following additional context is not needed. If relevant, here is some information about the user: {user_info} Here are some memories about the user: {memories} """ # This prompt is intended to be fairly lenient since there are additional filters downstream. # There are now multiple places for misleading docs to get dropped so each one can be a bit more lax. 
# As models get better, it's likely better to include more context than not, some questionably # useful stuff may be helpful downstream. # Adding the ! option to allow better models to handle questions where all of the documents are # necessary to make a good determination. # If a document is by far the best and is a very obvious inclusion, add a ! after the section_id to indicate that it should \ # be included in full. Example output: [8, 2!, 5]. DOCUMENT_SELECTION_PROMPT = """ Select the most relevant document sections for the user's query (maximum {max_sections}).{extra_instructions} # Document Sections ``` {formatted_doc_sections} ``` # User Query ``` {user_query} ``` # Selection Criteria - Choose sections most relevant to answering the query, if at all in doubt, include the section. - Even if only a tiny part of the section is relevant, include it. - It is ok to select multiple sections from the same document. - Consider indirect connections and supporting context to be valuable. - If the section is not directly helpful but the document seems relevant, there is an opportunity \ later to expand the section and read more from the document so include the section. # Output Format Return ONLY section_ids as a comma-separated list, ordered by relevance: [most_relevant_section_id, second_most_relevant_section_id, ...] Section IDs: """.strip() TRY_TO_FILL_TO_MAX_INSTRUCTIONS = """ Try to fill the list to the maximum number of sections if possible without including non-relevant or misleading sections. """ # Some models are trained heavily to reason in the actual output so we allow some flexibility in the prompt. # Downstream of the model, we will attempt to parse the output to extract the number. # This inference will not have a system prompt as it's a single message task more like the traditional ones. # LLMs should do better with just this type of next word prediction. 
# Opted to not include metadata here as the doc was already selected by the previous step that has it. # Also, hopefully it leans toward not throwing out documents, as there are not many bad ones that make it to this stage. # If anything, it's mostly because of something misleading, otherwise this step should be treated as 95% expansion/filtering. DOCUMENT_CONTEXT_SELECTION_PROMPT = """ Analyze the relevance of document sections to a search query and classify according to the categories \ described at the end of the prompt. # Document Title / Metadata ``` {document_title} ``` # Section Above: ``` {section_above} ``` # Main Section: ``` {main_section} ``` # Section Below: ``` {section_below} ``` # User Query: ``` {user_query} ``` # Classification Categories: **0 - NOT_RELEVANT** - Main section and surrounding sections do not help answer the query or provide meaningful, relevant information. - Appears on topic but refers to a different context or subject (could lead to potential confusion or misdirection). \ It is important to avoid conflating different contexts and subjects - if the document is related to the query but not about \ the correct subject. Example: "How much did we quote ACME for project X", "ACME paid us $100,000 for project Y". **1 - MAIN_SECTION_ONLY** - Main section contains useful information relevant to the query. - Adjacent sections do not provide additional directly relevant information. **2 - INCLUDE_ADJACENT_SECTIONS** - The main section AND adjacent sections are all useful for answering the user query. - The surrounding sections provide relevant information that does not exist in the main section. - Even if only 1 of the adjacent sections is useful or there is a small piece in either that is useful. - Additional unseen sections are unlikely to contain valuable related information. **3 - INCLUDE_FULL_DOCUMENT** - Additional unseen sections are likely to contain valuable related information to the query.
## Additional Decision Notes - If only a small piece of the document is useful - use classification 1 or 2, do not use 0. - If the document is on topic and provides additional context that might be useful in \ combination with other documents - use classification 1, 2 or 3, do not use 0. CRITICAL: ONLY output the NUMBER of the situation most applicable to the query and sections provided (0, 1, 2, or 3). Situation Number: """.strip() ================================================ FILE: backend/onyx/prompts/tool_prompts.py ================================================ # ruff: noqa: E501, W605 start # If there are any tools, this section is included, the sections below are for the available tools TOOL_SECTION_HEADER = "\n# Tools\n\n" # This section is included if there are search type tools, currently internal_search and web_search TOOL_DESCRIPTION_SEARCH_GUIDANCE = """ For questions that can be answered from existing knowledge, answer the user directly without using any tools. \ If you suspect your knowledge is outdated or for topics where things are rapidly changing, use search tools to get more context. \ For statements that may be describing or referring to a document, run a search for the document. \ In ambiguous cases, favor searching to get more context. When using any search type tool, do not make any assumptions and stay as faithful to the user's query as possible. \ Between internal and web search (if both are available), think about if the user's query is likely better answered by team internal sources or online web pages. \ When searching for information, if the initial results cannot fully answer the user's query, try again with different tools or arguments. \ Do not repeat the same or very similar queries if it already has been run in the chat history. If it is unclear which tool to use, consider using multiple in parallel to be efficient with time. 
""".lstrip() INTERNAL_SEARCH_GUIDANCE = """ ## internal_search Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include: - Internal information: any time where there may be some information stored in internal applications that could help better answer the query. - Niche/Specific information: information that is likely not found in public sources, things specific to a project or product, team, process, etc. - Keyword Queries: queries that are heavily keyword based are often internal document search queries. - Ambiguity: questions about something that is not widely known or understood. Never provide more than 3 queries at once to `internal_search`. """.lstrip() WEB_SEARCH_GUIDANCE = """ ## web_search Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include: - Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving. - Accuracy: if the cost of outdated/inaccurate information is high. - Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled} """.lstrip() WEB_SEARCH_SITE_DISABLED_GUIDANCE = """ Do not use the "site:" operator in your web search queries. """.lstrip() OPEN_URLS_GUIDANCE = """ ## open_url Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \ You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \ Do not open URLs that are image files like .png, .jpg, etc. You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL. 
""".lstrip() PYTHON_TOOL_GUIDANCE = """ ## python Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds. Any files uploaded to the chat will automatically be available in the execution environment's current directory. \ The current directory in the file system can be used to save and persist user files. Files written to the current directory will be returned with a `file_link`. \ Use this to give the user a way to download the file OR to display generated images. Internet access for this session is disabled. Do not make external web requests or API calls as they will fail. Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL. IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call. """.lstrip() GENERATE_IMAGE_GUIDANCE = """ ## generate_image NEVER use generate_image unless the user specifically requests an image. For edits/variations of a previously generated image, pass `reference_image_file_ids` with the `file_id` values returned by earlier `generate_image` tool results. """.lstrip() MEMORY_GUIDANCE = """ ## add_memory Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \ Only add memories that are specific, likely to remain true, and likely to be useful later. \ Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests. """.lstrip() TOOL_CALL_FAILURE_PROMPT = """ LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.
""".strip() # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/prompts/user_info.py ================================================ # ruff: noqa: E501, W605 start USER_INFORMATION_HEADER = "\n# User Information\n\n" BASIC_INFORMATION_PROMPT = """ ## Basic Information User name: {user_name} User email: {user_email}{user_role} """.lstrip() # This line only shows up if the user has configured their role. USER_ROLE_PROMPT = """ User role: {user_role} """.lstrip() # Team information should be a paragraph style description of the user's team. TEAM_INFORMATION_PROMPT = """ ## Team Information {team_information} """.lstrip() # User preferences should be a paragraph style description of the user's preferences. USER_PREFERENCES_PROMPT = """ ## User Preferences {user_preferences} """.lstrip() # User memories should look something like: # - Memory 1 # - Memory 2 # - Memory 3 USER_MEMORIES_PROMPT = """ ## User Memories {user_memories} """.lstrip() # ruff: noqa: E501, W605 end ================================================ FILE: backend/onyx/redis/iam_auth.py ================================================ """ Redis IAM Authentication Module This module provides Redis IAM authentication functionality for AWS ElastiCache. Unlike RDS IAM auth, Redis IAM auth relies on IAM roles and policies rather than generating authentication tokens. Key functions: - configure_redis_iam_auth: Configure Redis connection parameters for IAM auth - create_redis_ssl_context_if_iam: Create SSL context for secure connections """ import ssl from typing import Any def configure_redis_iam_auth(connection_kwargs: dict[str, Any]) -> None: """ Configure Redis connection parameters for IAM authentication. Modifies the connection_kwargs dict in-place to: 1. Remove password (not needed with IAM) 2. Enable SSL with system CA certificates 3. 
Set proper SSL context for secure connections """ # Remove password as it's not needed with IAM authentication if "password" in connection_kwargs: del connection_kwargs["password"] # Ensure SSL is enabled for IAM authentication connection_kwargs["ssl"] = True connection_kwargs["ssl_context"] = create_redis_ssl_context_if_iam() def create_redis_ssl_context_if_iam() -> ssl.SSLContext: """Create an SSL context for Redis IAM authentication using system CA certificates.""" # Use system CA certificates by default - no need for additional CA files ssl_context = ssl.create_default_context() ssl_context.check_hostname = True ssl_context.verify_mode = ssl.CERT_REQUIRED return ssl_context ================================================ FILE: backend/onyx/redis/redis_connector.py ================================================ import redis from onyx.redis.redis_connector_delete import RedisConnectorDelete from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync from onyx.redis.redis_connector_prune import RedisConnectorPrune from onyx.redis.redis_connector_stop import RedisConnectorStop from onyx.redis.redis_pool import get_redis_client # TODO: reduce dependence on redis class RedisConnector: """Composes several classes to simplify interacting with a connector and its associated background tasks / associated redis interactions.""" def __init__(self, tenant_id: str, cc_pair_id: int) -> None: """id: a connector credential pair id""" self.tenant_id: str = tenant_id self.cc_pair_id: int = cc_pair_id self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id) self.stop = RedisConnectorStop(tenant_id, cc_pair_id, self.redis) self.prune = RedisConnectorPrune(tenant_id, cc_pair_id, self.redis) self.delete = RedisConnectorDelete(tenant_id, cc_pair_id, self.redis) self.permissions = RedisConnectorPermissionSync( tenant_id, cc_pair_id, self.redis ) self.external_group_sync 
= RedisConnectorExternalGroupSync( tenant_id, cc_pair_id, self.redis ) @staticmethod def get_id_from_fence_key(key: str) -> str | None: """ Extracts the object ID from a fence key in the format `PREFIX_fence_X`. Args: key (str): The fence key string. Returns: str | None: The extracted ID if the key is in the correct format, otherwise None. """ parts = key.split("_") if len(parts) != 3: return None object_id = parts[2] return object_id @staticmethod def get_id_from_task_id(task_id: str) -> str | None: """ Extracts the object ID from a task ID string. This method assumes the task ID is formatted as `prefix_objectid_suffix`, where: - `prefix` is an arbitrary string (e.g., the name of the task or entity), - `objectid` is the ID you want to extract, - `suffix` is another arbitrary string (e.g., a UUID). Example: If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`, this method will return the string `"1"`. Args: task_id (str): The task ID string from which to extract the object ID. Returns: str | None: The extracted object ID if the task ID is in the correct format, otherwise None. """ # example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc parts = task_id.split("_") if len(parts) != 3: return None object_id = parts[1] return object_id def db_lock_key(self, search_settings_id: int) -> str: """ Key for the db lock for an indexing attempt. Prevents multiple modifications to the current indexing attempt row from multiple docfetching/docprocessing tasks.
""" return f"da_lock:indexing:db_{self.cc_pair_id}/{search_settings_id}" ================================================ FILE: backend/onyx/redis/redis_connector_delete.py ================================================ import time from datetime import datetime from typing import cast from uuid import uuid4 import redis from celery import Celery from pydantic import BaseModel from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.db.document import construct_document_id_select_for_connector_credential_pair class RedisConnectorDeletePayload(BaseModel): num_tasks: int | None submitted: datetime class RedisConnectorDelete: """Manages interactions with redis for deletion tasks. 
Should only be accessed through RedisConnector.""" PREFIX = "connectordeletion" FENCE_PREFIX = f"{PREFIX}_fence" # "connectordeletion_fence" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks TASKSET_PREFIX = f"{PREFIX}_taskset" # "connectordeletion_taskset" TASKSET_TTL = FENCE_TTL # used to signal the overall workflow is still active # it's impossible to get the exact state of the system at a single point in time # so we need a signal with a TTL to bridge gaps in our checks ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = 3600 def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: self.tenant_id: str = tenant_id self.id = id self.redis = redis self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" self.active_key = f"{self.ACTIVE_PREFIX}_{id}" def taskset_clear(self) -> None: self.redis.delete(self.taskset_key) def get_remaining(self) -> int: # todo: move into fence remaining = cast(int, self.redis.scard(self.taskset_key)) return remaining @property def fenced(self) -> bool: return bool(self.redis.exists(self.fence_key)) @property def payload(self) -> RedisConnectorDeletePayload | None: # read related data and evaluate/print task progress fence_bytes = cast(bytes, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") payload = RedisConnectorDeletePayload.model_validate_json(cast(str, fence_str)) return payload def set_fence(self, payload: RedisConnectorDeletePayload | None) -> None: if not payload: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.fence_key) return self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL) self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) def set_active(self) -> None: """This sets a signal to keep the deletion flow from getting cleaned up within the expiration time.
The slack in timing is needed because simply checking the celery queue and task status is subject to race conditions.""" self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL) def active(self) -> bool: return bool(self.redis.exists(self.active_key)) def _generate_task_id(self) -> str: # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" # we prefix the task id so it's easier to keep track of who created the task # aka "connectordeletion_1_dd32ded3-00aa-4884-8b21-42f8332e7fac" return f"{self.PREFIX}_{self.id}_{uuid4()}" def generate_tasks( self, celery_app: Celery, db_session: Session, lock: RedisLock, ) -> int | None: """Returns None if the cc_pair doesn't exist. Otherwise, returns an int with the number of generated tasks.""" last_lock_time = time.monotonic() cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=int(self.id), ) if not cc_pair: return None num_tasks_sent = 0 stmt = construct_document_id_select_for_connector_credential_pair( cc_pair.connector_id, cc_pair.credential_id ) for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc_id = cast(str, doc_id) current_time = time.monotonic() if current_time - last_lock_time >= ( CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 ): lock.reacquire() last_lock_time = current_time custom_task_id = self._generate_task_id() # add to the tracking taskset in redis BEFORE creating the celery task.
# note: the taskset key is scoped to this cc_pair via its id suffix (see __init__) self.redis.sadd(self.taskset_key, custom_task_id) self.redis.expire(self.taskset_key, self.TASKSET_TTL) # connector deletion cleanup tasks are sent at medium priority celery_app.send_task( OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK, kwargs=dict( document_id=doc_id, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, tenant_id=self.tenant_id, ), queue=OnyxCeleryQueues.CONNECTOR_DELETION, task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ignore_result=True, ) num_tasks_sent += 1 return num_tasks_sent def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.active_key) self.redis.delete(self.taskset_key) self.redis.delete(self.fence_key) @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorDelete.TASKSET_PREFIX}_{id}" r.srem(taskset_key, task_id) return @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" for key in r.scan_iter(RedisConnectorDelete.ACTIVE_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorDelete.TASKSET_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorDelete.FENCE_PREFIX + "*"): r.delete(key) ================================================ FILE: backend/onyx/redis/redis_connector_doc_perm_sync.py ================================================ import time from datetime import datetime from logging import Logger from typing import Any from typing import cast from typing import NamedTuple import redis from pydantic import BaseModel from redis.lock import Lock as RedisLock from onyx.access.models import DocExternalAccess from onyx.access.models import ElementExternalAccess from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT from
onyx.configs.constants import OnyxRedisConstants from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT from onyx.utils.variable_functionality import fetch_versioned_implementation class PermissionSyncResult(NamedTuple): """Result of a permission sync operation. Attributes: num_updated: Number of elements (documents or hierarchy nodes) successfully updated num_errors: Number of elements (documents or hierarchy nodes) that failed to update """ num_updated: int num_errors: int class RedisConnectorPermissionSyncPayload(BaseModel): id: str submitted: datetime started: datetime | None celery_task_id: str | None class RedisConnectorPermissionSync: """Manages interactions with redis for doc permission sync tasks. Should only be accessed through RedisConnector.""" PREFIX = "connectordocpermissionsync" FENCE_PREFIX = f"{PREFIX}_fence" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks # phase 1 - generator task and progress signals GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectordocpermissionsync+generator GENERATOR_PROGRESS_PREFIX = ( PREFIX + "_generator_progress" ) # connectordocpermissionsync_generator_progress GENERATOR_COMPLETE_PREFIX = ( PREFIX + "_generator_complete" ) # connectordocpermissionsync_generator_complete TASKSET_PREFIX = f"{PREFIX}_taskset" # connectordocpermissionsync_taskset SUBTASK_PREFIX = f"{PREFIX}+sub" # connectordocpermissionsync+sub # used to signal the overall workflow is still active # it's impossible to get the exact state of the system at a single point in time # so we need a signal with a TTL to bridge gaps in our checks ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2 def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: self.tenant_id: str = tenant_id self.id = id self.redis = redis self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}" self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}" self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}"
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}" self.active_key = f"{self.ACTIVE_PREFIX}_{id}" def taskset_clear(self) -> None: self.redis.delete(self.taskset_key) def generator_clear(self) -> None: self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) def get_remaining(self) -> int: remaining = cast(int, self.redis.scard(self.taskset_key)) return remaining def get_active_task_count(self) -> int: """Count of active permission sync tasks""" count = 0 for _ in self.redis.sscan_iter( OnyxRedisConstants.ACTIVE_FENCES, RedisConnectorPermissionSync.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT, ): count += 1 return count @property def fenced(self) -> bool: return bool(self.redis.exists(self.fence_key)) @property def payload(self) -> RedisConnectorPermissionSyncPayload | None: # read related data and evaluate/print task progress fence_bytes = cast(Any, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") payload = RedisConnectorPermissionSyncPayload.model_validate_json( cast(str, fence_str) ) return payload def set_fence( self, payload: RedisConnectorPermissionSyncPayload | None, ) -> None: if not payload: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.fence_key) return self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL) self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) def set_active(self) -> None: """This sets a signal to keep the permissioning flow from getting cleaned up within the expiration time. 
The slack in timing is needed because simply checking the celery queue and task status is subject to race conditions.""" self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL) def active(self) -> bool: return bool(self.redis.exists(self.active_key)) @property def generator_complete(self) -> int | None: """the fence payload is an int representing the starting number of permission sync tasks to be processed ... just after the generator completes.""" fence_bytes = self.redis.get(self.generator_complete_key) if fence_bytes is None: return None if fence_bytes == b"None": return None fence_int = int(cast(bytes, fence_bytes).decode()) return fence_int @generator_complete.setter def generator_complete(self, payload: int | None) -> None: """Set the payload to an int to set the fence, otherwise if None it will be deleted""" if payload is None: self.redis.delete(self.generator_complete_key) return self.redis.set(self.generator_complete_key, payload, ex=self.FENCE_TTL) def update_db( self, lock: RedisLock | None, new_permissions: list[ElementExternalAccess], source_string: str, connector_id: int, credential_id: int, task_logger: Logger | None = None, ) -> PermissionSyncResult: """Update permissions for documents and hierarchy nodes.
Returns: PermissionSyncResult containing counts of successful updates and errors """ last_lock_time = time.monotonic() element_update_permissions_fn = fetch_versioned_implementation( "onyx.background.celery.tasks.doc_permission_syncing.tasks", "element_update_permissions", ) num_permissions = 0 num_errors = 0 # Write each element's permissions directly to the DB (see the NOTE below) for permissions in new_permissions: current_time = time.monotonic() if lock and current_time - last_lock_time >= ( CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4 ): lock.reacquire() last_lock_time = current_time if ( permissions.external_access.num_entries > permissions.external_access.MAX_NUM_ENTRIES ): if task_logger: num_users = len(permissions.external_access.external_user_emails) num_groups = len( permissions.external_access.external_user_group_ids ) element_id = ( permissions.doc_id if isinstance(permissions, DocExternalAccess) else permissions.raw_node_id ) task_logger.warning( f"Permissions length exceeded, skipping...: " f"{element_id} " f"{num_users=} {num_groups=} " f"{permissions.external_access.MAX_NUM_ENTRIES=}" ) continue # NOTE(rkuo): this used to fire a task instead of directly writing to the DB, # but the permissions can be excessively large if sent over the wire. # On the other hand, the downside of doing db updates here is that we can # block and fail if we can't make the calls to the DB ... but that's probably # a rare enough case to be acceptable.
# This call can raise internally due to DB issues; catch exceptions per element # so that one failure does not break the entire sync try: element_update_permissions_fn( self.tenant_id, permissions, source_string, connector_id, credential_id, ) num_permissions += 1 except Exception: num_errors += 1 if task_logger: element_id = ( permissions.doc_id if isinstance(permissions, DocExternalAccess) else permissions.raw_node_id ) task_logger.exception( f"Failed to update permissions for element {element_id}" ) # Continue processing other elements return PermissionSyncResult(num_updated=num_permissions, num_errors=num_errors) def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.active_key) self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) self.redis.delete(self.taskset_key) self.redis.delete(self.fence_key) @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorPermissionSync.TASKSET_PREFIX}_{id}" r.srem(taskset_key, task_id) return @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"): r.delete(key) for key in r.scan_iter( RedisConnectorPermissionSync.GENERATOR_COMPLETE_PREFIX + "*" ): r.delete(key) for key in r.scan_iter( RedisConnectorPermissionSync.GENERATOR_PROGRESS_PREFIX + "*" ): r.delete(key) for key in r.scan_iter(RedisConnectorPermissionSync.FENCE_PREFIX + "*"): r.delete(key) ================================================ FILE: backend/onyx/redis/redis_connector_ext_group_sync.py ================================================ from datetime import datetime from typing import cast import redis from celery import Celery from pydantic import BaseModel from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session from onyx.configs.constants import OnyxRedisConstants from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT class RedisConnectorExternalGroupSyncPayload(BaseModel): id: str submitted: datetime started: datetime | None celery_task_id: str | None class RedisConnectorExternalGroupSync: """Manages interactions with redis for external group syncing tasks. Should only be accessed through RedisConnector.""" PREFIX = "connectorexternalgroupsync" FENCE_PREFIX = f"{PREFIX}_fence" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks # phase 1 - generator task and progress signals GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorexternalgroupsync+generator GENERATOR_PROGRESS_PREFIX = ( PREFIX + "_generator_progress" ) # connectorexternalgroupsync_generator_progress GENERATOR_COMPLETE_PREFIX = ( PREFIX + "_generator_complete" ) # connectorexternalgroupsync_generator_complete TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorexternalgroupsync_taskset SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorexternalgroupsync+sub # used to signal the overall workflow is still active # it's impossible to get the exact state of the system at a single point in time # so we need a signal with a TTL to bridge gaps in our checks ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = 3600 def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: self.tenant_id: str = tenant_id self.id = id self.redis = redis self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}" self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}" self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}" self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}" self.active_key = f"{self.ACTIVE_PREFIX}_{id}" def taskset_clear(self) -> None: self.redis.delete(self.taskset_key) def generator_clear(self) -> None:
self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) def get_remaining(self) -> int: # todo: move into fence remaining = cast(int, self.redis.scard(self.taskset_key)) return remaining def get_active_task_count(self) -> int: """Count of active external group syncing tasks""" count = 0 for _ in self.redis.sscan_iter( OnyxRedisConstants.ACTIVE_FENCES, RedisConnectorExternalGroupSync.FENCE_PREFIX + "*", count=SCAN_ITER_COUNT_DEFAULT, ): count += 1 return count @property def fenced(self) -> bool: return bool(self.redis.exists(self.fence_key)) @property def payload(self) -> RedisConnectorExternalGroupSyncPayload | None: # read related data and evaluate/print task progress fence_raw = self.redis.get(self.fence_key) if fence_raw is None: return None fence_bytes = cast(bytes, fence_raw) fence_str = fence_bytes.decode("utf-8") payload = RedisConnectorExternalGroupSyncPayload.model_validate_json( cast(str, fence_str) ) return payload def set_fence( self, payload: RedisConnectorExternalGroupSyncPayload | None, ) -> None: if not payload: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.fence_key) return self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL) self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) def set_active(self) -> None: """This sets a signal to keep the external group sync flow from getting cleaned up within the expiration time. The slack in timing is needed because simply checking the celery queue and task status is subject to race conditions.""" self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL) def active(self) -> bool: return bool(self.redis.exists(self.active_key)) @property def generator_complete(self) -> int | None: """the fence payload is an int representing the starting number of external group syncing tasks to be processed ... just after the generator completes.
""" fence_bytes = self.redis.get(self.generator_complete_key) if fence_bytes is None: return None if fence_bytes == b"None": return None fence_int = int(cast(bytes, fence_bytes).decode()) return fence_int @generator_complete.setter def generator_complete(self, payload: int | None) -> None: """Set the payload to an int to set the fence, otherwise if None it will be deleted""" if payload is None: self.redis.delete(self.generator_complete_key) return self.redis.set(self.generator_complete_key, payload, ex=self.FENCE_TTL) def generate_tasks( self, celery_app: Celery, db_session: Session, lock: RedisLock | None, ) -> int | None: pass def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.active_key) self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) self.redis.delete(self.taskset_key) self.redis.delete(self.fence_key) @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}" r.srem(taskset_key, task_id) return @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" for key in r.scan_iter(RedisConnectorExternalGroupSync.ACTIVE_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorExternalGroupSync.TASKSET_PREFIX + "*"): r.delete(key) for key in r.scan_iter( RedisConnectorExternalGroupSync.GENERATOR_COMPLETE_PREFIX + "*" ): r.delete(key) for key in r.scan_iter( RedisConnectorExternalGroupSync.GENERATOR_PROGRESS_PREFIX + "*" ): r.delete(key) for key in r.scan_iter(RedisConnectorExternalGroupSync.FENCE_PREFIX + "*"): r.delete(key) ================================================ FILE: backend/onyx/redis/redis_connector_index.py ================================================ from datetime import datetime from pydantic import BaseModel class RedisConnectorIndexPayload(BaseModel): index_attempt_id: int | None started: 
datetime | None submitted: datetime celery_task_id: str | None ================================================ FILE: backend/onyx/redis/redis_connector_prune.py ================================================ import time from datetime import datetime from typing import cast from uuid import uuid4 import redis from celery import Celery from pydantic import BaseModel from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import CELERY_PRUNING_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT class RedisConnectorPrunePayload(BaseModel): id: str submitted: datetime started: datetime | None celery_task_id: str | None class RedisConnectorPrune: """Manages interactions with redis for pruning tasks. 
Should only be accessed through RedisConnector.""" PREFIX = "connectorpruning" FENCE_PREFIX = f"{PREFIX}_fence" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks # phase 1 - generator task and progress signals GENERATORTASK_PREFIX = f"{PREFIX}+generator" # connectorpruning+generator GENERATOR_PROGRESS_PREFIX = ( PREFIX + "_generator_progress" ) # connectorpruning_generator_progress GENERATOR_COMPLETE_PREFIX = ( PREFIX + "_generator_complete" ) # connectorpruning_generator_complete TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpruning_taskset TASKSET_TTL = FENCE_TTL SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpruning+sub # used to signal the overall workflow is still active # it's impossible to get the exact state of the system at a single point in time # so we need a signal with a TTL to bridge gaps in our checks ACTIVE_PREFIX = PREFIX + "_active" ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2 def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: self.tenant_id: str = tenant_id self.id = id self.redis = redis self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" self.generator_task_key = f"{self.GENERATORTASK_PREFIX}_{id}" self.generator_progress_key = f"{self.GENERATOR_PROGRESS_PREFIX}_{id}" self.generator_complete_key = f"{self.GENERATOR_COMPLETE_PREFIX}_{id}" self.taskset_key = f"{self.TASKSET_PREFIX}_{id}" self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}" self.active_key = f"{self.ACTIVE_PREFIX}_{id}" def taskset_clear(self) -> None: self.redis.delete(self.taskset_key) def generator_clear(self) -> None: self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) def get_remaining(self) -> int: # todo: move into fence remaining = cast(int, self.redis.scard(self.taskset_key)) return remaining def get_active_task_count(self) -> int: """Count of active pruning tasks""" count = 0 for _ in self.redis.sscan_iter( OnyxRedisConstants.ACTIVE_FENCES, RedisConnectorPrune.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT, ): count += 1 return count @property def fenced(self) -> bool: return bool(self.redis.exists(self.fence_key)) @property def payload(self) -> RedisConnectorPrunePayload | None: # read related data and evaluate/print task progress fence_bytes = cast(bytes, self.redis.get(self.fence_key)) if fence_bytes is None: return None fence_str = fence_bytes.decode("utf-8") payload = RedisConnectorPrunePayload.model_validate_json(cast(str, fence_str)) return payload def set_fence( self, payload: RedisConnectorPrunePayload | None, ) -> None: if not payload: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.fence_key) return self.redis.set(self.fence_key, payload.model_dump_json(), ex=self.FENCE_TTL) self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) def set_active(self) -> None: """This sets a signal to keep the pruning flow from getting cleaned up within the expiration time. The slack in timing is needed because simply checking the celery queue and task status is subject to race conditions.""" self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL) def active(self) -> bool: return bool(self.redis.exists(self.active_key)) @property def generator_complete(self) -> int | None: """the fence payload is an int representing the starting number of pruning tasks to be processed ...
just after the generator completes.""" fence_bytes = self.redis.get(self.generator_complete_key) if fence_bytes is None: return None fence_int = int(cast(bytes, fence_bytes)) return fence_int @generator_complete.setter def generator_complete(self, payload: int | None) -> None: """Set the payload to an int to set the fence, otherwise if None it will be deleted""" if payload is None: self.redis.delete(self.generator_complete_key) return self.redis.set(self.generator_complete_key, payload, ex=self.FENCE_TTL) def generate_tasks( self, documents_to_prune: set[str], celery_app: Celery, db_session: Session, lock: RedisLock | None, ) -> int | None: last_lock_time = time.monotonic() async_results = [] cc_pair = get_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=int(self.id), ) if not cc_pair: return None for doc_id in documents_to_prune: current_time = time.monotonic() if lock and current_time - last_lock_time >= ( CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4 ): lock.reacquire() last_lock_time = current_time # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" # the actual redis key is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" # we prefix the task id so it's easier to keep track of who created the task # aka "connectorpruning+sub_1_dd32ded3-00aa-4884-8b21-42f8332e7fac" custom_task_id = f"{self.subtask_prefix}_{uuid4()}" # add to the tracking taskset in redis BEFORE creating the celery task.
self.redis.sadd(self.taskset_key, custom_task_id) self.redis.expire(self.taskset_key, self.TASKSET_TTL) # pruning cleanup tasks are sent at medium priority result = celery_app.send_task( OnyxCeleryTask.DOCUMENT_BY_CC_PAIR_CLEANUP_TASK, kwargs=dict( document_id=doc_id, connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, tenant_id=self.tenant_id, ), queue=OnyxCeleryQueues.CONNECTOR_DELETION, task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ignore_result=True, ) async_results.append(result) return len(async_results) def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.active_key) self.redis.delete(self.generator_progress_key) self.redis.delete(self.generator_complete_key) self.redis.delete(self.taskset_key) self.redis.delete(self.fence_key) @staticmethod def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None: taskset_key = f"{RedisConnectorPrune.TASKSET_PREFIX}_{id}" r.srem(taskset_key, task_id) return @staticmethod def reset_all(r: redis.Redis) -> None: """Deletes all redis values for all connectors""" for key in r.scan_iter(RedisConnectorPrune.ACTIVE_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorPrune.TASKSET_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorPrune.GENERATOR_COMPLETE_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorPrune.GENERATOR_PROGRESS_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorPrune.FENCE_PREFIX + "*"): r.delete(key) ================================================ FILE: backend/onyx/redis/redis_connector_stop.py ================================================ import redis class RedisConnectorStop: """Manages interactions with redis for stop signaling.
Should only be accessed through RedisConnector.""" PREFIX = "connectorstop" FENCE_PREFIX = f"{PREFIX}_fence" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks # if this timeout is exceeded, the caller may decide to take more # drastic measures TIMEOUT_PREFIX = f"{PREFIX}_timeout" TIMEOUT_TTL = 300 def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None: self.tenant_id: str = tenant_id self.id: int = id self.redis = redis self.fence_key: str = f"{self.FENCE_PREFIX}_{id}" self.timeout_key: str = f"{self.TIMEOUT_PREFIX}_{id}" @property def fenced(self) -> bool: return bool(self.redis.exists(self.fence_key)) def set_fence(self, value: bool) -> None: if not value: self.redis.delete(self.fence_key) return self.redis.set(self.fence_key, 0, ex=self.FENCE_TTL) @property def timed_out(self) -> bool: return not bool(self.redis.exists(self.timeout_key)) def set_timeout(self) -> None: """After calling this, call timed_out to determine if the timeout has been exceeded.""" self.redis.set(f"{self.timeout_key}", 0, ex=self.TIMEOUT_TTL) @staticmethod def reset_all(r: redis.Redis) -> None: for key in r.scan_iter(RedisConnectorStop.FENCE_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisConnectorStop.TIMEOUT_PREFIX + "*"): r.delete(key) ================================================ FILE: backend/onyx/redis/redis_connector_utils.py ================================================ from sqlalchemy.orm import Session from onyx.db.connector_credential_pair import get_connector_credential_pair from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import TaskStatus from onyx.db.models import TaskQueueState from onyx.redis.redis_connector import RedisConnector from onyx.server.documents.models import DeletionAttemptSnapshot def _get_deletion_status( connector_id: int, credential_id: int, db_session: Session, tenant_id: str, ) -> TaskQueueState | None: """We no longer store TaskQueueState in the DB for a deletion 
attempt. This function populates TaskQueueState by just checking redis. """ cc_pair = get_connector_credential_pair( connector_id=connector_id, credential_id=credential_id, db_session=db_session ) if not cc_pair: return None redis_connector = RedisConnector(tenant_id, cc_pair.id) if redis_connector.delete.fenced: return TaskQueueState( task_id="", task_name=redis_connector.delete.fence_key, status=TaskStatus.STARTED, ) if cc_pair.status == ConnectorCredentialPairStatus.DELETING: return TaskQueueState( task_id="", task_name=redis_connector.delete.fence_key, status=TaskStatus.PENDING, ) return None def get_deletion_attempt_snapshot( connector_id: int, credential_id: int, db_session: Session, tenant_id: str, ) -> DeletionAttemptSnapshot | None: deletion_task = _get_deletion_status( connector_id, credential_id, db_session, tenant_id ) if not deletion_task: return None return DeletionAttemptSnapshot( connector_id=connector_id, credential_id=credential_id, status=deletion_task.status, ) ================================================ FILE: backend/onyx/redis/redis_document_set.py ================================================ import time from typing import cast from uuid import uuid4 import redis from celery import Celery from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import OnyxRedisConstants from onyx.db.document_set import construct_document_id_select_by_docset from onyx.redis.redis_object_helper import RedisObjectHelper class RedisDocumentSet(RedisObjectHelper): PREFIX = "documentset" FENCE_PREFIX = PREFIX + "_fence" FENCE_TTL = 7 * 24 * 60 * 60 # 7 days - defensive TTL to prevent memory leaks 
TASKSET_PREFIX = PREFIX + "_taskset" TASKSET_TTL = FENCE_TTL def __init__(self, tenant_id: str, id: int) -> None: super().__init__(tenant_id, str(id)) @property def fenced(self) -> bool: return bool(self.redis.exists(self.fence_key)) def set_fence(self, payload: int | None) -> None: if payload is None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.fence_key) return self.redis.set(self.fence_key, payload, ex=self.FENCE_TTL) self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) @property def payload(self) -> int | None: bytes = self.redis.get(self.fence_key) if bytes is None: return None progress = int(cast(int, bytes)) return progress def generate_tasks( self, max_tasks: int, # noqa: ARG002 celery_app: Celery, db_session: Session, redis_client: Redis, lock: RedisLock, tenant_id: str, ) -> tuple[int, int] | None: """Max tasks is ignored for now until we can build the logic to mark the document set up to date over multiple batches. """ last_lock_time = time.monotonic() num_tasks_sent = 0 stmt = construct_document_id_select_by_docset(int(self._id), current_only=False) for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT): doc_id = cast(str, doc_id) current_time = time.monotonic() if current_time - last_lock_time >= ( CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4 ): lock.reacquire() last_lock_time = current_time # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac" # the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac" # we prefix the task id so it's easier to keep track of who created the task # aka "documentset_1_dd32ded3-00aa-4884-8b21-42f8332e7fac" custom_task_id = f"{self.task_id_prefix}_{uuid4()}" # add to the set BEFORE creating the task.
redis_client.sadd(self.taskset_key, custom_task_id) redis_client.expire(self.taskset_key, self.TASKSET_TTL) celery_app.send_task( OnyxCeleryTask.VESPA_METADATA_SYNC_TASK, kwargs=dict(document_id=doc_id, tenant_id=tenant_id), queue=OnyxCeleryQueues.VESPA_METADATA_SYNC, task_id=custom_task_id, priority=OnyxCeleryPriority.MEDIUM, ) num_tasks_sent += 1 return num_tasks_sent, num_tasks_sent def reset(self) -> None: self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key) self.redis.delete(self.taskset_key) self.redis.delete(self.fence_key) @staticmethod def reset_all(r: redis.Redis) -> None: for key in r.scan_iter(RedisDocumentSet.TASKSET_PREFIX + "*"): r.delete(key) for key in r.scan_iter(RedisDocumentSet.FENCE_PREFIX + "*"): r.delete(key) ================================================ FILE: backend/onyx/redis/redis_hierarchy.py ================================================ """Redis cache operations for hierarchy node ancestor resolution. This module provides a Redis-based cache for hierarchy node parent relationships, enabling fast ancestor path resolution without repeated database queries. The cache stores node_id -> parent_id mappings for all hierarchy nodes of a given source type. When resolving ancestors for a document, we walk up the tree using Redis lookups instead of database queries. 
Cache Strategy: - Nodes are cached per source type with a 6-hour TTL - During docfetching, nodes are added to cache as they're upserted to Postgres - If the cache is stale (TTL expired during long-running job), one worker does a full refresh from DB while others wait - If a node is still not found after refresh, we log an error and fall back to using only the SOURCE-type node as the ancestor """ from typing import cast from typing import TYPE_CHECKING from pydantic import BaseModel from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.db.enums import HierarchyNodeType from onyx.db.hierarchy import ensure_source_node_exists as db_ensure_source_node_exists from onyx.db.hierarchy import get_all_hierarchy_nodes_for_source from onyx.utils.logger import setup_logger if TYPE_CHECKING: from onyx.db.models import HierarchyNode as DBHierarchyNode logger = setup_logger() # Cache TTL: 6 hours in seconds HIERARCHY_CACHE_TTL_SECONDS = 6 * 60 * 60 # Lock timeout for cache refresh: 5 minutes HIERARCHY_CACHE_LOCK_TIMEOUT_SECONDS = 5 * 60 # Lock acquisition timeout: 60 seconds HIERARCHY_CACHE_LOCK_ACQUIRE_TIMEOUT_SECONDS = 60 MAX_DEPTH = 1000 class HierarchyNodeCacheEntry(BaseModel): """Represents a hierarchy node for caching purposes.""" node_id: int parent_id: int | None node_type: HierarchyNodeType raw_node_id: str @classmethod def from_db_model(cls, node: "DBHierarchyNode") -> "HierarchyNodeCacheEntry": """Create a cache entry from a SQLAlchemy HierarchyNode model.""" return cls( node_id=node.id, parent_id=node.parent_id, node_type=node.node_type, raw_node_id=node.raw_node_id, ) def _cache_key(source: DocumentSource) -> str: """Get the Redis hash key for hierarchy node cache of a given source. 
This hash stores: node_id -> "parent_id:node_type" """ return f"hierarchy_cache:{source.value}" def _raw_id_cache_key(source: DocumentSource) -> str: """Get the Redis hash key for raw_node_id -> node_id mapping. This hash stores: raw_node_id -> node_id """ return f"hierarchy_cache_rawid:{source.value}" def _source_node_key(source: DocumentSource) -> str: """Get the Redis key for the SOURCE-type node ID of a given source. This is a simple string key storing the database ID of the SOURCE node. """ return f"hierarchy_source_node:{source.value}" def _loading_lock_key(source: DocumentSource) -> str: """Get the Redis lock key for cache loading of a given source.""" return f"hierarchy_cache_loading:{source.value}" def _construct_parent_value(parent_id: int | None, node_type: HierarchyNodeType) -> str: """Construct the cached value string from parent_id and node_type. Format: "parent_id:node_type" where parent_id is an empty string if None. """ parent_str = str(parent_id) if parent_id is not None else "" return f"{parent_str}:{node_type.value}" def _unpack_parent_value(value: str) -> tuple[int | None, HierarchyNodeType | None]: """Unpack a cached value string back into (parent_id, node_type). An empty field unpacks to None; a non-integer parent_id or an unknown node_type raises ValueError. """ parts = value.split(":", 1) parent_str = parts[0] node_type_str = parts[1] if len(parts) > 1 else "" parent_id = int(parent_str) if parent_str else None node_type = HierarchyNodeType(node_type_str) if node_type_str else None return parent_id, node_type def cache_hierarchy_node( redis_client: Redis, source: DocumentSource, entry: HierarchyNodeCacheEntry, ) -> None: """ Add or update a single hierarchy node in the Redis cache. Called during docfetching when nodes are upserted to Postgres. Stores the parent chain mapping, raw_id -> node_id mapping, and SOURCE node ID (if this is a SOURCE-type node).
Args: redis_client: Redis client with tenant prefixing source: The document source (e.g., CONFLUENCE, GOOGLE_DRIVE) entry: The hierarchy node cache entry """ cache_key = _cache_key(source) raw_id_key = _raw_id_cache_key(source) # Store parent chain: node_id -> "parent_id:node_type" value = _construct_parent_value(entry.parent_id, entry.node_type) redis_client.hset(cache_key, str(entry.node_id), value) # Store raw_id -> node_id mapping redis_client.hset(raw_id_key, entry.raw_node_id, str(entry.node_id)) # If this is the SOURCE node, store its ID in the dedicated key if entry.node_type == HierarchyNodeType.SOURCE: source_node_key = _source_node_key(source) redis_client.set(source_node_key, str(entry.node_id)) redis_client.expire(source_node_key, HIERARCHY_CACHE_TTL_SECONDS) # Refresh TTL on every write (ensures cache stays alive during long indexing) redis_client.expire(cache_key, HIERARCHY_CACHE_TTL_SECONDS) redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS) def cache_hierarchy_nodes_batch( redis_client: Redis, source: DocumentSource, entries: list[HierarchyNodeCacheEntry], ) -> None: """ Add or update multiple hierarchy nodes in the Redis cache. 
Args: redis_client: Redis client with tenant prefixing source: The document source entries: List of HierarchyNodeCacheEntry objects """ if not entries: return cache_key = _cache_key(source) raw_id_key = _raw_id_cache_key(source) source_node_key = _source_node_key(source) # Build mappings for batch insert parent_mapping: dict[str, str] = {} raw_id_mapping: dict[str, str] = {} source_node_id: int | None = None for entry in entries: parent_mapping[str(entry.node_id)] = _construct_parent_value( entry.parent_id, entry.node_type ) raw_id_mapping[entry.raw_node_id] = str(entry.node_id) # Track the SOURCE node if we encounter it if entry.node_type == HierarchyNodeType.SOURCE: source_node_id = entry.node_id # Use hset with mapping for batch insert redis_client.hset(cache_key, mapping=parent_mapping) redis_client.hset(raw_id_key, mapping=raw_id_mapping) # Cache the SOURCE node ID if found if source_node_id is not None: redis_client.set(source_node_key, str(source_node_id)) redis_client.expire(source_node_key, HIERARCHY_CACHE_TTL_SECONDS) redis_client.expire(cache_key, HIERARCHY_CACHE_TTL_SECONDS) redis_client.expire(raw_id_key, HIERARCHY_CACHE_TTL_SECONDS) def evict_hierarchy_nodes_from_cache( redis_client: Redis, source: DocumentSource, raw_node_ids: list[str], ) -> None: """Remove specific hierarchy nodes from the Redis cache. Deletes entries from both the parent-chain hash and the raw_id→node_id hash. 
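The two-step eviction (resolve raw IDs to node IDs, then delete from both hashes) can be sketched with plain dicts standing in for the two Redis hashes. The function name `evict` and its arguments are hypothetical; the real function issues `HMGET` and `HDEL` calls instead.

```python
def evict(
    parent_hash: dict[str, str],
    raw_id_hash: dict[str, str],
    raw_node_ids: list[str],
) -> None:
    # parent_hash: node_id -> "parent_id:node_type"
    # raw_id_hash: raw_node_id -> node_id
    if not raw_node_ids:
        return
    # Resolve node IDs BEFORE deleting the raw-ID entries, otherwise the
    # mapping needed to find them is already gone. Unknown raw IDs are skipped,
    # just as HMGET returns None for missing fields.
    node_ids = [raw_id_hash[r] for r in raw_node_ids if r in raw_id_hash]
    for node_id in node_ids:
        parent_hash.pop(node_id, None)
    for raw_id in raw_node_ids:
        raw_id_hash.pop(raw_id, None)


parents = {"1": ":source", "2": "1:folder"}
raw_ids = {"drive-root": "1", "folder-a": "2"}
evict(parents, raw_ids, ["folder-a", "missing"])
print(parents, raw_ids)  # {'1': ':source'} {'drive-root': '1'}
```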
""" if not raw_node_ids: return cache_key = _cache_key(source) raw_id_key = _raw_id_cache_key(source) # Look up node_ids so we can remove them from the parent-chain hash raw_values = cast(list[str | None], redis_client.hmget(raw_id_key, raw_node_ids)) node_id_strs = [v for v in raw_values if v is not None] if node_id_strs: redis_client.hdel(cache_key, *node_id_strs) redis_client.hdel(raw_id_key, *raw_node_ids) def get_node_id_from_raw_id( redis_client: Redis, source: DocumentSource, raw_node_id: str, ) -> tuple[int | None, bool]: """ Get the database node_id for a raw_node_id from the cache. Returns: Tuple of (node_id or None, found_in_cache) - If found_in_cache is False, the raw_id doesn't exist in cache - If found_in_cache is True, node_id is the database ID """ raw_id_key = _raw_id_cache_key(source) value = redis_client.hget(raw_id_key, raw_node_id) if value is None: return None, False # Decode bytes if needed value_str: str if isinstance(value, bytes): value_str = value.decode("utf-8") else: value_str = str(value) return int(value_str), True def get_parent_id_from_cache( redis_client: Redis, source: DocumentSource, node_id: int, ) -> tuple[int | None, bool]: """ Get the parent_id for a node from the cache. 
Returns: Tuple of (parent_id or None, found_in_cache) - If found_in_cache is False, the node doesn't exist in cache - If found_in_cache is True, parent_id is the actual parent (or None for root) """ cache_key = _cache_key(source) value = redis_client.hget(cache_key, str(node_id)) if value is None: return None, False # Decode bytes if needed value_str: str if isinstance(value, bytes): value_str = value.decode("utf-8") else: value_str = str(value) parent_id, _ = _unpack_parent_value(value_str) return parent_id, True def is_cache_populated(redis_client: Redis, source: DocumentSource) -> bool: """Check if the cache has any entries for this source.""" cache_key = _cache_key(source) # redis.exists returns int (number of keys that exist) exists_result: int = redis_client.exists(cache_key) # type: ignore[assignment] return exists_result > 0 def refresh_hierarchy_cache_from_db( redis_client: Redis, db_session: Session, source: DocumentSource, ) -> None: """ Refresh the entire hierarchy cache for a source from the database. This function acquires a distributed lock to ensure only one worker performs the refresh. Other workers will wait for the refresh to complete. Args: redis_client: Redis client with tenant prefixing db_session: SQLAlchemy session for database access source: The document source to refresh """ lock_key = _loading_lock_key(source) # Try to acquire lock - if we can't get it, someone else is refreshing lock: RedisLock = redis_client.lock( lock_key, timeout=HIERARCHY_CACHE_LOCK_TIMEOUT_SECONDS, blocking=True, blocking_timeout=HIERARCHY_CACHE_LOCK_ACQUIRE_TIMEOUT_SECONDS, ) acquired = lock.acquire(blocking=True) if not acquired: logger.warning( f"Could not acquire lock for hierarchy cache refresh for source {source.value} - another worker may be refreshing" ) return try: # Always refresh from DB when called - new nodes may have been added # since the cache was last populated. The lock ensures only one worker # does the refresh at a time. 
logger.info(f"Refreshing hierarchy cache for source {source.value} from DB") # Load all nodes for this source from DB nodes = get_all_hierarchy_nodes_for_source(db_session, source) if not nodes: logger.warning(f"No hierarchy nodes found in DB for source {source.value}") return # Batch insert into cache cache_entries = [HierarchyNodeCacheEntry.from_db_model(node) for node in nodes] cache_hierarchy_nodes_batch(redis_client, source, cache_entries) logger.info( f"Refreshed hierarchy cache for {source.value} with {len(nodes)} nodes" ) finally: try: lock.release() except Exception as e: logger.warning(f"Error releasing hierarchy cache lock: {e}") def _walk_ancestor_chain( redis_client: Redis, source: DocumentSource, start_node_id: int, db_session: Session, ) -> list[int]: """ Walk up the hierarchy tree from a node, collecting all ancestor IDs. Internal helper used by both get_ancestors_from_node_id and get_ancestors_from_raw_id. """ ancestors: list[int] = [] current_id: int | None = start_node_id visited: set[int] = set() while current_id is not None and len(ancestors) < MAX_DEPTH: if current_id in visited: logger.error( f"Cycle detected in hierarchy for source {source.value} at node {current_id}. Ancestors so far: {ancestors}" ) break visited.add(current_id) ancestors.append(current_id) parent_id, found = get_parent_id_from_cache(redis_client, source, current_id) if not found: logger.debug( f"Cache miss for hierarchy node {current_id} of source {source.value}, attempting refresh" ) refresh_hierarchy_cache_from_db(redis_client, db_session, source) parent_id, found = get_parent_id_from_cache( redis_client, source, current_id ) if not found: logger.error( f"Hierarchy node {current_id} not found in cache for source {source.value} even after refresh." ) break current_id = parent_id if len(ancestors) >= MAX_DEPTH: logger.error( f"Hit max depth {MAX_DEPTH} traversing hierarchy for source " f"{source.value}. Possible infinite loop or very deep hierarchy." 
) return ancestors def get_ancestors_from_raw_id( redis_client: Redis, source: DocumentSource, parent_hierarchy_raw_node_id: str | None, db_session: Session, ) -> list[int]: """ Get all ancestor hierarchy node IDs from a raw_node_id. This is the main entry point for getting ancestors from a document's parent_hierarchy_raw_node_id. It resolves the raw_id to a database ID via Redis cache, then walks up the tree. No DB calls are made unless the cache is stale. Args: redis_client: Redis client with tenant prefixing source: The document source parent_hierarchy_raw_node_id: The document's parent raw node ID (from connector) db_session: DB session for cache refresh if needed Returns: List of ancestor hierarchy node IDs from parent to root (inclusive). Returns list with just SOURCE node ID if parent is None or not found. """ # If no parent specified, return just the SOURCE node if parent_hierarchy_raw_node_id is None: source_node_id = get_source_node_id_from_cache(redis_client, db_session, source) return [source_node_id] if source_node_id else [] # Resolve raw_id to node_id via Redis node_id, found = get_node_id_from_raw_id( redis_client, source, parent_hierarchy_raw_node_id ) if not found: # Cache miss - try refresh logger.debug( f"Cache miss for raw_node_id '{parent_hierarchy_raw_node_id}' of source {source.value}, attempting refresh" ) refresh_hierarchy_cache_from_db(redis_client, db_session, source) node_id, found = get_node_id_from_raw_id( redis_client, source, parent_hierarchy_raw_node_id ) if not found or node_id is None: logger.error( f"Raw node ID '{parent_hierarchy_raw_node_id}' not found in cache " f"for source {source.value}. Falling back to SOURCE node only." 
) source_node_id = get_source_node_id_from_cache(redis_client, db_session, source) return [source_node_id] if source_node_id else [] # Walk up the ancestor chain return _walk_ancestor_chain(redis_client, source, node_id, db_session) def get_source_node_id_from_cache( redis_client: Redis, db_session: Session, source: DocumentSource, ) -> int | None: """ Get the SOURCE-type node ID for a given source from cache. If it is not cached, refreshes the cache from the DB and checks again. Returns: The ID of the SOURCE node, or None if not found. """ source_node_key = _source_node_key(source) # Try to get from dedicated SOURCE node key value = redis_client.get(source_node_key) if value is not None: if isinstance(value, bytes): value = value.decode("utf-8") if not isinstance(value, str): raise ValueError(f"SOURCE node value is not a string: {value}") return int(value) # Not in cache - try refresh from DB refresh_hierarchy_cache_from_db(redis_client, db_session, source) # Try again after refresh value = redis_client.get(source_node_key) if value is not None: if isinstance(value, bytes): value = value.decode("utf-8") if not isinstance(value, str): raise ValueError(f"SOURCE node value is not a string: {value}") return int(value) logger.error(f"SOURCE node not found for source {source.value}") return None def clear_hierarchy_cache(redis_client: Redis, source: DocumentSource) -> None: """Clear the hierarchy cache for a source (useful for testing).""" cache_key = _cache_key(source) raw_id_key = _raw_id_cache_key(source) source_node_key = _source_node_key(source) redis_client.delete(cache_key) redis_client.delete(raw_id_key) redis_client.delete(source_node_key) def ensure_source_node_exists( redis_client: Redis, db_session: Session, source: DocumentSource, ) -> int: """ Ensure that a SOURCE-type hierarchy node exists for the given source and cache it. This is the primary entry point for ensuring hierarchy infrastructure is set up for a source before processing documents.
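The cache-aside flow this function follows (check Redis, fall back to the database, then populate Redis) can be sketched with a dict as the cache and a callable as the database. Everything here is hypothetical: `get_or_create_source_id` and `fake_db` are illustrative names, and the key only follows the `_source_node_key` format with a placeholder source value. The sketch omits concurrency handling; per this docstring, the real path relies on an idempotent DB helper.

```python
from collections.abc import Callable


def get_or_create_source_id(
    cache: dict[str, int],
    key: str,
    load_from_db: Callable[[], int],
) -> int:
    # Fast path: the ID is already cached.
    cached = cache.get(key)
    if cached is not None:
        return cached
    # Slow path: ask the database (which creates the row if needed), then cache it.
    node_id = load_from_db()
    cache[key] = node_id
    return node_id


db_calls = 0


def fake_db() -> int:
    # Hypothetical stand-in for the DB helper; counts how often it is hit.
    global db_calls
    db_calls += 1
    return 17


cache: dict[str, int] = {}
key = "hierarchy_source_node:example"
first = get_or_create_source_id(cache, key, fake_db)
second = get_or_create_source_id(cache, key, fake_db)
print(first, second, db_calls)  # 17 17 1
```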
It should be called early in the indexing pipeline (e.g., at the start of docfetching or hierarchy fetching). The function: 1. Checks Redis cache for existing SOURCE node ID 2. If not cached, ensures the SOURCE node exists in the database 3. Caches the SOURCE node in Redis for fast subsequent lookups This is idempotent and safe to call multiple times concurrently. Args: redis_client: Redis client with tenant prefixing db_session: SQLAlchemy session for database operations source: The document source type (e.g., GOOGLE_DRIVE, CONFLUENCE) Returns: The database ID of the SOURCE-type hierarchy node """ # First check if we already have it cached source_node_key = _source_node_key(source) cached_value = redis_client.get(source_node_key) if cached_value is not None: value_str: str if isinstance(cached_value, bytes): value_str = cached_value.decode("utf-8") else: value_str = str(cached_value) return int(value_str) # Not cached - ensure it exists in DB and cache it source_node = db_ensure_source_node_exists(db_session, source, commit=True) # Cache the SOURCE node cache_entry = HierarchyNodeCacheEntry.from_db_model(source_node) cache_hierarchy_node(redis_client, source, cache_entry) logger.info( f"Ensured SOURCE node exists and cached for {source.value}: id={source_node.id}" ) return source_node.id ================================================ FILE: backend/onyx/redis/redis_object_helper.py ================================================ from abc import ABC from abc import abstractmethod from celery import Celery from redis import Redis from redis.lock import Lock as RedisLock from sqlalchemy.orm import Session from onyx.redis.redis_pool import get_redis_client class RedisObjectHelper(ABC): PREFIX = "base" FENCE_PREFIX = PREFIX + "_fence" TASKSET_PREFIX = PREFIX + "_taskset" def __init__(self, tenant_id: str, id: str): self._tenant_id: str = tenant_id self._id: str = id self.redis = get_redis_client(tenant_id=tenant_id) @property def task_id_prefix(self) -> str: return 
f"{self.PREFIX}_{self._id}" @property def fence_key(self) -> str: # example: documentset_fence_1 return f"{self.FENCE_PREFIX}_{self._id}" @property def taskset_key(self) -> str: # example: documentset_taskset_1 return f"{self.TASKSET_PREFIX}_{self._id}" @staticmethod def get_id_from_fence_key(key: str) -> str | None: """ Extracts the object ID from a fence key in the format `PREFIX_fence_X`, where neither `PREFIX` nor `X` contains an underscore. Args: key (str): The fence key string. Returns: str | None: The extracted ID if the key is in the correct format, otherwise None. """ parts = key.split("_") if len(parts) != 3: return None object_id = parts[2] return object_id @staticmethod def get_id_from_task_id(task_id: str) -> str | None: """ Extracts the object ID from a task ID string. This method assumes the task ID is formatted as `prefix_objectid_suffix`, where none of the three parts contains an underscore: - `prefix` identifies the task or entity (e.g., `documentset`), - `objectid` is the ID you want to extract, - `suffix` is typically a UUID (its hyphens are fine). Example: If the input `task_id` is `documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc`, this method will return the string `"1"`. Args: task_id (str): The task ID string from which to extract the object ID. Returns: str | None: The extracted object ID if the task ID is in the correct format, otherwise None. """ # example: task_id=documentset_1_cbfdc96a-80ca-4312-a242-0bb68da3c1dc parts = task_id.split("_") if len(parts) != 3: return None object_id = parts[1] return object_id @abstractmethod def generate_tasks( self, max_tasks: int, celery_app: Celery, db_session: Session, redis_client: Redis, lock: RedisLock, tenant_id: str, ) -> tuple[int, int] | None: """First element should be the number of actual tasks generated, second should be the number of docs that were candidates to be synced for the cc pair. This distinction matters when we are syncing stale docs referenced by multiple connectors.
In a single pass across multiple cc pairs, we only want a task to be created for a particular document id the first time we see it. The rest can be skipped.""" ================================================ FILE: backend/onyx/redis/redis_pool.py ================================================ import asyncio import functools import json import ssl import threading from collections.abc import Callable from typing import Any from typing import cast from typing import Optional import redis from fastapi import Request from redis import asyncio as aioredis from redis.client import Redis from redis.lock import Lock as RedisLock from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HEALTH_CHECK_INTERVAL from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PASSWORD from onyx.configs.app_configs import REDIS_POOL_MAX_CONNECTIONS from onyx.configs.app_configs import REDIS_PORT from onyx.configs.app_configs import REDIS_REPLICA_HOST from onyx.configs.app_configs import REDIS_SSL from onyx.configs.app_configs import REDIS_SSL_CA_CERTS from onyx.configs.app_configs import REDIS_SSL_CERT_REQS from onyx.configs.app_configs import USE_REDIS_IAM_AUTH from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME from onyx.configs.constants import REDIS_SOCKET_KEEPALIVE_OPTIONS from onyx.redis.iam_auth import configure_redis_iam_auth from onyx.redis.iam_auth import create_redis_ssl_context_if_iam from onyx.utils.logger import setup_logger from shared_configs.configs import DEFAULT_REDIS_PREFIX from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() SCAN_ITER_COUNT_DEFAULT = 4096 class TenantRedis(redis.Redis): def __init__(self, tenant_id: str, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) self.tenant_id: str = tenant_id def _prefixed(self, key: str | bytes | memoryview) -> str | bytes |
memoryview: prefix: str = f"{self.tenant_id}:" if isinstance(key, str): if key.startswith(prefix): return key else: return prefix + key elif isinstance(key, bytes): prefix_bytes = prefix.encode() if key.startswith(prefix_bytes): return key else: return prefix_bytes + key elif isinstance(key, memoryview): key_bytes = key.tobytes() prefix_bytes = prefix.encode() if key_bytes.startswith(prefix_bytes): return key else: return memoryview(prefix_bytes + key_bytes) else: raise TypeError(f"Unsupported key type: {type(key)}") def _prefix_method(self, method: Callable) -> Callable: @functools.wraps(method) def wrapper(*args: Any, **kwargs: Any) -> Any: if "name" in kwargs: kwargs["name"] = self._prefixed(kwargs["name"]) elif len(args) > 0: args = (self._prefixed(args[0]),) + args[1:] return method(*args, **kwargs) return wrapper def _prefix_scan_iter(self, method: Callable) -> Callable: @functools.wraps(method) def wrapper(*args: Any, **kwargs: Any) -> Any: # Prefix the match pattern if provided if "match" in kwargs: kwargs["match"] = self._prefixed(kwargs["match"]) elif len(args) > 0: args = (self._prefixed(args[0]),) + args[1:] # Get the iterator iterator = method(*args, **kwargs) # Remove prefix from returned keys prefix = f"{self.tenant_id}:".encode() prefix_len = len(prefix) for key in iterator: if isinstance(key, bytes) and key.startswith(prefix): yield key[prefix_len:] else: yield key return wrapper def __getattribute__(self, item: str) -> Any: original_attr = super().__getattribute__(item) methods_to_wrap = [ "lock", "unlock", "get", "set", "setex", "delete", "exists", "incrby", "hset", "hget", "getset", "owned", "reacquire", "create_lock", "startswith", "smembers", "sismember", "sadd", "srem", "scard", "hexists", "hdel", "ttl", "pttl", ] # Regular methods that need simple prefixing if item == "scan_iter" or item == "sscan_iter": return self._prefix_scan_iter(original_attr) elif item in methods_to_wrap and callable(original_attr): return
self._prefix_method(original_attr) return original_attr class RedisPool: _instance: Optional["RedisPool"] = None _lock: threading.Lock = threading.Lock() _pool: redis.BlockingConnectionPool _replica_pool: redis.BlockingConnectionPool def __new__(cls) -> "RedisPool": if not cls._instance: with cls._lock: if not cls._instance: cls._instance = super(RedisPool, cls).__new__(cls) cls._instance._init_pools() return cls._instance def _init_pools(self) -> None: self._pool = RedisPool.create_pool(ssl=REDIS_SSL) self._replica_pool = RedisPool.create_pool( host=REDIS_REPLICA_HOST, ssl=REDIS_SSL ) def get_client(self, tenant_id: str) -> Redis: return TenantRedis(tenant_id, connection_pool=self._pool) def get_replica_client(self, tenant_id: str) -> Redis: return TenantRedis(tenant_id, connection_pool=self._replica_pool) def get_raw_client(self) -> Redis: """ Returns a Redis client with direct access to the primary connection pool, without tenant prefixing. """ return redis.Redis(connection_pool=self._pool) def get_raw_replica_client(self) -> Redis: """ Returns a Redis client with direct access to the replica connection pool, without tenant prefixing. """ return redis.Redis(connection_pool=self._replica_pool) @staticmethod def create_pool( host: str = REDIS_HOST, port: int = REDIS_PORT, db: int = REDIS_DB_NUMBER, password: str = REDIS_PASSWORD, max_connections: int = REDIS_POOL_MAX_CONNECTIONS, ssl_ca_certs: str | None = REDIS_SSL_CA_CERTS, ssl_cert_reqs: str = REDIS_SSL_CERT_REQS, ssl: bool = False, ) -> redis.BlockingConnectionPool: """ Create a Redis connection pool with appropriate SSL configuration. SSL Configuration Priority: 1. IAM Authentication (USE_REDIS_IAM_AUTH=true): Uses system CA certificates 2. Regular SSL (REDIS_SSL=true): Uses custom SSL configuration 3. No SSL: Standard connection without encryption Note: IAM authentication automatically enables SSL and takes precedence over regular SSL configuration to ensure proper security. 
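The priority order described above can be summarized as a small pure function. This is a hypothetical sketch: `pool_settings` and its dict keys are illustrative names that only summarize which branch of `create_pool` wins; they are not redis-py parameter names.

```python
from typing import Any


def pool_settings(use_iam_auth: bool, use_ssl: bool, password: str) -> dict[str, Any]:
    # Summarizes the three branches of create_pool; the keys are illustrative.
    if use_iam_auth:
        # IAM auth takes precedence and always implies SSL; no password is sent.
        return {"mode": "iam", "ssl": True, "password": None}
    if use_ssl:
        # Regular SSL keeps the configured password and custom CA settings.
        return {"mode": "ssl", "ssl": True, "password": password}
    return {"mode": "plain", "ssl": False, "password": password}


# Both flags set: IAM wins even though REDIS_SSL is also true.
print(pool_settings(True, True, "secret")["mode"])  # iam
```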
We use BlockingConnectionPool because it will block and wait for a connection rather than raise an error if max_connections is reached. This behavior is far more deterministic and better aligned with how we want to use Redis.""" # Using ConnectionPool is not well documented. # Useful examples: https://github.com/redis/redis-py/issues/780 # Handle IAM authentication if USE_REDIS_IAM_AUTH: # For IAM authentication, we don't use a password # and ensure SSL is enabled with proper context ssl_context = create_redis_ssl_context_if_iam() return redis.BlockingConnectionPool( host=host, port=port, db=db, password=None, # No password with IAM auth max_connections=max_connections, timeout=None, health_check_interval=REDIS_HEALTH_CHECK_INTERVAL, socket_keepalive=True, socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS, connection_class=redis.SSLConnection, ssl_context=ssl_context, # Use IAM auth SSL context ) if ssl: return redis.BlockingConnectionPool( host=host, port=port, db=db, password=password, max_connections=max_connections, timeout=None, health_check_interval=REDIS_HEALTH_CHECK_INTERVAL, socket_keepalive=True, socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS, connection_class=redis.SSLConnection, ssl_ca_certs=ssl_ca_certs, ssl_cert_reqs=ssl_cert_reqs, ) return redis.BlockingConnectionPool( host=host, port=port, db=db, password=password, max_connections=max_connections, timeout=None, health_check_interval=REDIS_HEALTH_CHECK_INTERVAL, socket_keepalive=True, socket_keepalive_options=REDIS_SOCKET_KEEPALIVE_OPTIONS, ) redis_pool = RedisPool() # # Usage example # redis_pool = RedisPool() # redis_client = redis_pool.get_client(tenant_id) # # Example of setting and getting a value # redis_client.set('key', 'value') # value = redis_client.get('key') # print(value.decode()) # Output: value def get_redis_client( *, # This argument will be deprecated in the future tenant_id: str | None = None, ) -> Redis: """ Returns a Redis client with tenant-specific key prefixing.
This ensures proper data isolation between tenants by automatically prefixing all Redis keys with the tenant ID. Use this when working with tenant-specific data that should be isolated from other tenants. """ if tenant_id is None: tenant_id = get_current_tenant_id() return redis_pool.get_client(tenant_id) def get_redis_replica_client( *, # this argument will be deprecated in the future tenant_id: str | None = None, ) -> Redis: """ Returns a Redis replica client with tenant-specific key prefixing. Similar to get_redis_client(), but connects to a read replica when available. This ensures proper data isolation between tenants by automatically prefixing all Redis keys with the tenant ID. Use this for read-heavy operations on tenant-specific data. """ if tenant_id is None: tenant_id = get_current_tenant_id() return redis_pool.get_replica_client(tenant_id) def get_shared_redis_client() -> Redis: """ Returns a Redis client with a shared namespace prefix. Unlike tenant-specific clients, this uses a common prefix for all keys, creating a shared namespace accessible across all tenants. Use this for data that should be shared across the application and isn't specific to any individual tenant. """ return redis_pool.get_client(DEFAULT_REDIS_PREFIX) def get_shared_redis_replica_client() -> Redis: """ Returns a Redis replica client with a shared namespace prefix. Similar to get_shared_redis_client(), but connects to a read replica when available. Uses a common prefix for all keys, creating a shared namespace. Use this for read-heavy operations on data that should be shared across the application. """ return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX) def get_raw_redis_client() -> Redis: """ Returns a Redis client that doesn't apply tenant prefixing to keys. Use this only when you need to access Redis directly without tenant isolation or any key prefixing. Typically needed for integrating with external systems or libraries that have inflexible key requirements. 
    Warning: Be careful with this client as it bypasses tenant isolation.
    """
    return redis_pool.get_raw_client()


def get_raw_redis_replica_client() -> Redis:
    """
    Returns a Redis replica client that doesn't apply tenant prefixing to keys.

    Similar to get_raw_redis_client(), but connects to a read replica when available.
    Use this for read-heavy operations that need direct Redis access
    without tenant isolation or key prefixing.

    Warning: Be careful with this client as it bypasses tenant isolation.
    """
    return redis_pool.get_raw_replica_client()


SSL_CERT_REQS_MAP = {
    "none": ssl.CERT_NONE,
    "optional": ssl.CERT_OPTIONAL,
    "required": ssl.CERT_REQUIRED,
}


_async_redis_connection: aioredis.Redis | None = None
_async_lock = asyncio.Lock()


async def get_async_redis_connection() -> aioredis.Redis:
    """
    Provides a shared async Redis connection, using the same configs (host, port, SSL, etc.).
    Ensures that the connection is created only once (lazily) and reused for all future calls.
    """
    global _async_redis_connection

    # If we haven't yet created an async Redis connection, we need to create one
    if _async_redis_connection is None:
        # Acquire the lock to ensure that only one coroutine attempts to create the connection
        async with _async_lock:
            # Double-check inside the lock to avoid race conditions
            if _async_redis_connection is None:
                # Load env vars or your config variables
                connection_kwargs: dict[str, Any] = {
                    "host": REDIS_HOST,
                    "port": REDIS_PORT,
                    "db": REDIS_DB_NUMBER,
                    "password": REDIS_PASSWORD,
                    "max_connections": REDIS_POOL_MAX_CONNECTIONS,
                    "health_check_interval": REDIS_HEALTH_CHECK_INTERVAL,
                    "socket_keepalive": True,
                    "socket_keepalive_options": REDIS_SOCKET_KEEPALIVE_OPTIONS,
                }

                if USE_REDIS_IAM_AUTH:
                    configure_redis_iam_auth(connection_kwargs)
                elif REDIS_SSL:
                    ssl_context = ssl.create_default_context()

                    if REDIS_SSL_CA_CERTS:
                        ssl_context.load_verify_locations(REDIS_SSL_CA_CERTS)
                    ssl_context.check_hostname = False

                    # Map your string to the proper ssl.CERT_* constant
                    ssl_context.verify_mode = SSL_CERT_REQS_MAP.get(
                        REDIS_SSL_CERT_REQS, ssl.CERT_NONE
                    )

                    connection_kwargs["ssl"] = ssl_context

                # Create a new Redis connection (or connection pool) with SSL configuration
                _async_redis_connection = aioredis.Redis(**connection_kwargs)

    # Return the established connection (or pool) for all future operations
    return _async_redis_connection


async def retrieve_auth_token_data(token: str) -> dict | None:
    """Validate auth token against Redis and return token data.

    Args:
        token: The raw authentication token string.

    Returns:
        Token data dict if valid, None if invalid/expired.
    """
    try:
        redis = await get_async_redis_connection()
        redis_key = REDIS_AUTH_KEY_PREFIX + token
        token_data_str = await redis.get(redis_key)

        if not token_data_str:
            logger.debug(f"Token key {redis_key} not found or expired in Redis")
            return None

        return json.loads(token_data_str)
    except json.JSONDecodeError:
        logger.error("Error decoding token data from Redis")
        return None
    except Exception as e:
        logger.error(f"Unexpected error in retrieve_auth_token_data: {str(e)}")
        raise ValueError(f"Unexpected error in retrieve_auth_token_data: {str(e)}")


async def retrieve_auth_token_data_from_redis(request: Request) -> dict | None:
    """Validate auth token from request cookie.
    Wrapper for backwards compatibility.
    """
    token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
    if not token:
        logger.debug("No auth token cookie found")
        return None

    return await retrieve_auth_token_data(token)


# WebSocket token prefix (separate from regular auth tokens)
REDIS_WS_TOKEN_PREFIX = "ws_token:"
# WebSocket tokens expire after 60 seconds
WS_TOKEN_TTL_SECONDS = 60
# Rate limit: max tokens per user per window
WS_TOKEN_RATE_LIMIT_MAX = 10
WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS = 60
REDIS_WS_TOKEN_RATE_LIMIT_PREFIX = "ws_token_rate:"


class WsTokenRateLimitExceeded(Exception):
    """Raised when a user exceeds the WS token generation rate limit."""


async def store_ws_token(token: str, user_id: str) -> None:
    """Store a short-lived WebSocket authentication token in Redis.

    Args:
        token: The generated WS token.
        user_id: The user ID to associate with this token.

    Raises:
        WsTokenRateLimitExceeded: If the user has exceeded the rate limit.
    """
    redis = await get_async_redis_connection()

    # Atomically increment and check rate limit to avoid TOCTOU races
    rate_limit_key = REDIS_WS_TOKEN_RATE_LIMIT_PREFIX + user_id
    pipe = redis.pipeline()
    pipe.incr(rate_limit_key)
    pipe.expire(rate_limit_key, WS_TOKEN_RATE_LIMIT_WINDOW_SECONDS)
    results = await pipe.execute()
    new_count = results[0]

    if new_count > WS_TOKEN_RATE_LIMIT_MAX:
        # Over limit: decrement back since we won't use this slot
        await redis.decr(rate_limit_key)
        logger.warning(f"WS token rate limit exceeded for user {user_id}")
        raise WsTokenRateLimitExceeded(
            f"Rate limit exceeded. Maximum {WS_TOKEN_RATE_LIMIT_MAX} tokens per minute."
        )

    # Store the actual token
    redis_key = REDIS_WS_TOKEN_PREFIX + token
    token_data = json.dumps({"sub": user_id})
    await redis.set(redis_key, token_data, ex=WS_TOKEN_TTL_SECONDS)


async def retrieve_ws_token_data(token: str) -> dict | None:
    """Validate a WebSocket token and return the token data.

    This uses GETDEL for atomic get-and-delete to prevent race conditions
    where the same token could be used twice.
    Args:
        token: The WS token to validate.

    Returns:
        Token data dict with 'sub' (user ID) if valid, None if invalid/expired.
    """
    try:
        redis = await get_async_redis_connection()
        redis_key = REDIS_WS_TOKEN_PREFIX + token

        # Atomic get-and-delete to prevent race conditions (Redis 6.2+)
        token_data_str = await redis.getdel(redis_key)

        if not token_data_str:
            return None

        return json.loads(token_data_str)
    except json.JSONDecodeError:
        logger.error("Error decoding WS token data from Redis")
        return None
    except Exception as e:
        logger.error(f"Unexpected error in retrieve_ws_token_data: {str(e)}")
        return None


def redis_lock_dump(lock: RedisLock, r: Redis) -> None:
    # diagnostic logging for lock errors
    name = lock.name
    ttl = r.ttl(name)
    locked = lock.locked()
    owned = lock.owned()
    local_token: str | None = lock.local.token

    remote_token_raw = r.get(lock.name)
    if remote_token_raw:
        remote_token_bytes = cast(bytes, remote_token_raw)
        remote_token = remote_token_bytes.decode("utf-8")
    else:
        remote_token = None

    logger.warning(
        f"RedisLock diagnostic: "
        f"name={name} "
        f"locked={locked} "
        f"owned={owned} "
        f"local_token={local_token} "
        f"remote_token={remote_token} "
        f"ttl={ttl}"
    )


================================================
FILE: backend/onyx/redis/redis_usergroup.py
================================================
import time
from typing import cast
from uuid import uuid4

import redis
from celery import Celery
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session

from onyx.configs.app_configs import DB_YIELD_PER_DEFAULT
from onyx.configs.constants import CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisConstants
from onyx.redis.redis_object_helper import RedisObjectHelper
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version


class RedisUserGroup(RedisObjectHelper):
    PREFIX = "usergroup"
    FENCE_PREFIX = PREFIX + "_fence"
    FENCE_TTL = 7 * 24 * 60 * 60  # 7 days - defensive TTL to prevent memory leaks
    TASKSET_PREFIX = PREFIX + "_taskset"
    TASKSET_TTL = FENCE_TTL

    def __init__(self, tenant_id: str, id: int) -> None:
        super().__init__(tenant_id, str(id))

    @property
    def fenced(self) -> bool:
        if self.redis.exists(self.fence_key):
            return True

        return False

    def set_fence(self, payload: int | None) -> None:
        if payload is None:
            self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
            self.redis.delete(self.fence_key)
            return

        self.redis.set(self.fence_key, payload, ex=self.FENCE_TTL)
        self.redis.sadd(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)

    @property
    def payload(self) -> int | None:
        bytes = self.redis.get(self.fence_key)
        if bytes is None:
            return None

        progress = int(cast(int, bytes))
        return progress

    def generate_tasks(
        self,
        max_tasks: int,  # noqa: ARG002
        celery_app: Celery,
        db_session: Session,
        redis_client: Redis,
        lock: RedisLock,
        tenant_id: str,
    ) -> tuple[int, int] | None:
        """Max tasks is ignored for now until we can build the logic to mark the
        user group up to date over multiple batches.
        """
        last_lock_time = time.monotonic()
        num_tasks_sent = 0

        if not global_version.is_ee_version():
            return 0, 0

        try:
            construct_document_id_select_by_usergroup = fetch_versioned_implementation(
                "onyx.db.user_group",
                "construct_document_id_select_by_usergroup",
            )
        except ModuleNotFoundError:
            return 0, 0

        stmt = construct_document_id_select_by_usergroup(int(self._id))
        for doc_id in db_session.scalars(stmt).yield_per(DB_YIELD_PER_DEFAULT):
            doc_id = cast(str, doc_id)
            current_time = time.monotonic()
            if current_time - last_lock_time >= (
                CELERY_VESPA_SYNC_BEAT_LOCK_TIMEOUT / 4
            ):
                lock.reacquire()
                last_lock_time = current_time

            # celery's default task id format is "dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # the key for the result is "celery-task-meta-dd32ded3-00aa-4884-8b21-42f8332e7fac"
            # we prefix the task id so it's easier to keep track of who created the task
            # aka "documentset_1_6dd32ded3-00aa-4884-8b21-42f8332e7fac"
            custom_task_id = f"{self.task_id_prefix}_{uuid4()}"

            # add to the set BEFORE creating the task.
            redis_client.sadd(self.taskset_key, custom_task_id)
            redis_client.expire(self.taskset_key, self.TASKSET_TTL)

            celery_app.send_task(
                OnyxCeleryTask.VESPA_METADATA_SYNC_TASK,
                kwargs=dict(document_id=doc_id, tenant_id=tenant_id),
                queue=OnyxCeleryQueues.VESPA_METADATA_SYNC,
                task_id=custom_task_id,
                priority=OnyxCeleryPriority.MEDIUM,
            )

            num_tasks_sent += 1

        return num_tasks_sent, num_tasks_sent

    def reset(self) -> None:
        self.redis.srem(OnyxRedisConstants.ACTIVE_FENCES, self.fence_key)
        self.redis.delete(self.taskset_key)
        self.redis.delete(self.fence_key)

    @staticmethod
    def reset_all(r: redis.Redis) -> None:
        for key in r.scan_iter(RedisUserGroup.TASKSET_PREFIX + "*"):
            r.delete(key)

        for key in r.scan_iter(RedisUserGroup.FENCE_PREFIX + "*"):
            r.delete(key)


================================================
FILE: backend/onyx/redis/redis_utils.py
================================================
from onyx.redis.redis_connector_delete import RedisConnectorDelete
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_prune import RedisConnectorPrune
from onyx.redis.redis_document_set import RedisDocumentSet
from onyx.redis.redis_usergroup import RedisUserGroup


def is_fence(key_bytes: bytes) -> bool:
    key_str = key_bytes.decode("utf-8")
    if key_str.startswith(RedisDocumentSet.FENCE_PREFIX):
        return True
    if key_str.startswith(RedisUserGroup.FENCE_PREFIX):
        return True
    if key_str.startswith(RedisConnectorDelete.FENCE_PREFIX):
        return True
    if key_str.startswith(RedisConnectorPrune.FENCE_PREFIX):
        return True
    if key_str.startswith(RedisConnectorPermissionSync.FENCE_PREFIX):
        return True

    return False


================================================
FILE: backend/onyx/secondary_llm_flows/__init__.py
================================================


================================================
FILE: backend/onyx/secondary_llm_flows/chat_session_naming.py
================================================
from onyx.chat.llm_step import translate_history_to_llm_format
from onyx.chat.models import ChatMessageSimple
from onyx.configs.constants import MessageType
from onyx.llm.interfaces import LLM
from onyx.llm.models import ReasoningEffort
from onyx.llm.utils import llm_response_to_string
from onyx.prompts.chat_prompts import CHAT_NAMING_REMINDER
from onyx.prompts.chat_prompts import CHAT_NAMING_SYSTEM_PROMPT
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger

logger = setup_logger()


def generate_chat_session_name(
    chat_history: list[ChatMessageSimple],
    llm: LLM,
) -> str:
    system_prompt = ChatMessageSimple(
        message=CHAT_NAMING_SYSTEM_PROMPT,
        token_count=100,
        message_type=MessageType.SYSTEM,
    )

    reminder_prompt = ChatMessageSimple(
        message=CHAT_NAMING_REMINDER,
        token_count=100,
        message_type=MessageType.USER_REMINDER,
    )

    complete_message_history = [system_prompt] + chat_history + [reminder_prompt]

    llm_facing_history = translate_history_to_llm_format(
        complete_message_history, llm.config
    )

    # Call LLM with Braintrust tracing
    with llm_generation_span(
        llm=llm, flow="chat_session_naming", input_messages=llm_facing_history
    ) as span_generation:
        response = llm.invoke(llm_facing_history, reasoning_effort=ReasoningEffort.OFF)
        record_llm_response(span_generation, response)
        new_name_raw = llm_response_to_string(response)

    return new_name_raw.strip().strip('"')


================================================
FILE: backend/onyx/secondary_llm_flows/document_filter.py
================================================
import json
import re

from onyx.context.search.models import ContextExpansionType
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.llm.interfaces import LLM
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.prompts.search_prompts import DOCUMENT_CONTEXT_SELECTION_PROMPT
from onyx.prompts.search_prompts import DOCUMENT_SELECTION_PROMPT
from onyx.prompts.search_prompts import TRY_TO_FILL_TO_MAX_INSTRUCTIONS
from onyx.tools.tool_implementations.search.constants import (
    MAX_CHUNKS_FOR_RELEVANCE,
)
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger

logger = setup_logger()


def select_chunks_for_relevance(
    section: InferenceSection,
    max_chunks: int = MAX_CHUNKS_FOR_RELEVANCE,
) -> list[InferenceChunk]:
    """Select a subset of chunks from a section based on center chunk position.

    Logic:
    - Always include the center chunk
    - If there are chunks directly next to it by index, grab the preceding and following
    - Otherwise grab 2 in the direction that does exist (2 before or 2 after)
    - If there are not enough in either direction, just grab what's available
    - If there are no other chunks, just use the central chunk

    Args:
        section: InferenceSection with center_chunk and chunks
        max_chunks: Maximum number of chunks to select (default: MAX_CHUNKS_FOR_RELEVANCE)

    Returns:
        List of selected InferenceChunks ordered by position
    """
    if max_chunks <= 0:
        return []

    center_chunk = section.center_chunk
    all_chunks = section.chunks

    # Find the index of the center chunk in the chunks list
    try:
        center_index = next(
            i
            for i, chunk in enumerate(all_chunks)
            if chunk.chunk_id == center_chunk.chunk_id
        )
    except StopIteration:
        # If center chunk not found in chunks list, just return center chunk
        return [center_chunk]

    if max_chunks == 1:
        return [center_chunk]

    # Calculate how many chunks to take before and after
    chunks_needed = max_chunks - 1  # minus 1 for center chunk

    # Determine available chunks before and after center
    chunks_before_available = center_index
    chunks_after_available = len(all_chunks) - center_index - 1

    # Start with balanced distribution (1 before, 1 after for max_chunks=3)
    chunks_before = min(chunks_needed // 2, chunks_before_available)
    chunks_after = min(chunks_needed // 2, chunks_after_available)

    # Allocate remaining chunks to whichever direction has availability
    remaining = chunks_needed - chunks_before - chunks_after
    if remaining > 0:
        # Try to add more chunks before center if available
        if chunks_before_available > chunks_before:
            additional_before = min(remaining, chunks_before_available - chunks_before)
            chunks_before += additional_before
            remaining -= additional_before

        # Try to add more chunks after center if available
        if remaining > 0 and chunks_after_available > chunks_after:
            additional_after = min(remaining, chunks_after_available - chunks_after)
            chunks_after += additional_after

    # Select the chunks
    start_index = center_index - chunks_before
    end_index = center_index + chunks_after + 1  # +1 to include center and chunks after

    return all_chunks[start_index:end_index]


def classify_section_relevance(
    document_title: str,
    section_text: str,
    user_query: str,
    llm: LLM,
    section_above_text: str | None,
    section_below_text: str | None,
) -> ContextExpansionType:
    """Use LLM to classify section relevance and determine context expansion type.
    Args:
        document_title: Title of the document that contains the section
        section_text: The text content of the section to classify
        user_query: The user's search query
        llm: LLM instance to use for classification
        section_above_text: Text content from chunks above the section
        section_below_text: Text content from chunks below the section

    Returns:
        ContextExpansionType indicating how the section should be expanded
    """
    # Build the prompt
    prompt_text = DOCUMENT_CONTEXT_SELECTION_PROMPT.format(
        document_title=document_title,
        main_section=section_text,
        section_above=section_above_text if section_above_text else "N/A",
        section_below=section_below_text if section_below_text else "N/A",
        user_query=user_query,
    )

    # Default to MAIN_SECTION_ONLY
    default_classification = ContextExpansionType.MAIN_SECTION_ONLY

    # Call LLM for classification with Braintrust tracing
    try:
        prompt_msg = UserMessage(content=prompt_text)
        with llm_generation_span(
            llm=llm, flow="classify_section_relevance", input_messages=[prompt_msg]
        ) as span_generation:
            response = llm.invoke(
                prompt=prompt_msg,
                reasoning_effort=ReasoningEffort.OFF,
            )
            record_llm_response(span_generation, response)
            llm_response = response.choice.message.content

        if not llm_response:
            logger.warning(
                "LLM returned empty response for context selection, defaulting to MAIN_SECTION_ONLY"
            )
            classification = default_classification
        else:
            # Parse the response to extract the situation number (0-3)
            numbers = re.findall(r"\b[0-3]\b", llm_response)
            if numbers:
                situation = int(numbers[-1])
                # Map situation number to ContextExpansionType
                situation_to_type = {
                    0: ContextExpansionType.NOT_RELEVANT,
                    1: ContextExpansionType.MAIN_SECTION_ONLY,
                    2: ContextExpansionType.INCLUDE_ADJACENT_SECTIONS,
                    3: ContextExpansionType.FULL_DOCUMENT,
                }
                classification = situation_to_type.get(
                    situation, default_classification
                )
            else:
                logger.warning(
                    f"Could not parse situation number from LLM response: {llm_response}"
                )
                classification = default_classification

    except Exception as e:
        logger.error(f"Error calling LLM for context selection: {e}")
        classification = default_classification

    # To save some effort down the line, if there is nothing surrounding,
    # don't allow a classification of adjacent or whole doc
    if (
        not section_above_text
        and not section_below_text
        and classification != ContextExpansionType.NOT_RELEVANT
    ):
        classification = ContextExpansionType.MAIN_SECTION_ONLY

    return classification


def select_sections_for_expansion(
    sections: list[InferenceSection],
    user_query: str,
    llm: LLM,
    max_sections: int = 10,
    max_chunks_per_section: int | None = MAX_CHUNKS_FOR_RELEVANCE,
    try_to_fill_to_max: bool = False,
) -> tuple[list[InferenceSection], list[str] | None]:
    """Use LLM to select the most relevant document sections for expansion.

    Args:
        sections: List of InferenceSection objects to select from
        user_query: The user's search query
        llm: LLM instance to use for selection
        max_sections: Maximum number of sections to select (default: 10)
        max_chunks_per_section: Maximum chunks to consider per section (default: MAX_CHUNKS_FOR_RELEVANCE)
        try_to_fill_to_max: If True, add TRY_TO_FILL_TO_MAX_INSTRUCTIONS to the prompt (default: False)

    Returns:
        A tuple of:
        - Filtered list of InferenceSection objects selected by the LLM
        - List of document IDs for sections marked with "!" by the LLM, or None if none.
          Note: The "!" marker support exists in parsing but is not currently used
          because the prompt does not instruct the LLM to use it.
    """
    if not sections:
        return [], None

    # Create a mapping of section ID to section
    section_map: dict[str, InferenceSection] = {}
    sections_dict: list[dict[str, str | int | list[str]]] = []

    for idx, section in enumerate(sections):
        # Create a unique ID for each section
        section_id = f"{idx}"
        section_map[section_id] = section

        # Format the section for the LLM
        chunk = section.center_chunk

        # Combine primary and secondary owners for authors
        authors = None
        if chunk.primary_owners or chunk.secondary_owners:
            authors = []
            if chunk.primary_owners:
                authors.extend(chunk.primary_owners)
            if chunk.secondary_owners:
                authors.extend(chunk.secondary_owners)

        # Format updated_at as ISO string if available
        updated_at_str = None
        if chunk.updated_at:
            updated_at_str = chunk.updated_at.isoformat()

        # Convert metadata to JSON string
        metadata_str = json.dumps(chunk.metadata)

        # Select only the most relevant chunks from the section to avoid flooding
        # the LLM with too much content from documents with many matching sections
        if max_chunks_per_section is not None:
            selected_chunks = select_chunks_for_relevance(
                section, max_chunks_per_section
            )
            selected_content = " ".join(chunk.content for chunk in selected_chunks)
        else:
            selected_content = section.combined_content

        section_dict: dict[str, str | int | list[str]] = {
            "section_id": idx,
            "title": chunk.semantic_identifier,
        }
        # Only include updated_at if not None
        if updated_at_str is not None:
            section_dict["updated_at"] = updated_at_str
        # Only include authors if not None
        if authors is not None:
            section_dict["authors"] = authors
        section_dict["source_type"] = str(chunk.source_type)
        section_dict["metadata"] = metadata_str
        section_dict["content"] = selected_content
        sections_dict.append(section_dict)

    # Build the prompt
    extra_instructions = TRY_TO_FILL_TO_MAX_INSTRUCTIONS if try_to_fill_to_max else ""
    prompt_text = UserMessage(
        content=DOCUMENT_SELECTION_PROMPT.format(
            max_sections=max_sections,
            extra_instructions=extra_instructions,
            formatted_doc_sections=json.dumps(sections_dict, indent=2),
            user_query=user_query,
        )
    )

    # Call LLM for selection with Braintrust tracing
    try:
        with llm_generation_span(
            llm=llm, flow="select_sections_for_expansion", input_messages=[prompt_text]
        ) as span_generation:
            response = llm.invoke(
                prompt=[prompt_text], reasoning_effort=ReasoningEffort.OFF
            )
            record_llm_response(span_generation, response)
            llm_response = response.choice.message.content

        if not llm_response:
            logger.warning(
                "LLM returned empty response for document selection, returning first max_sections"
            )
            return sections[:max_sections], None

        # Parse the response to extract section IDs
        # Look for patterns like [1, 2, 3] or [1,2,3] with flexible whitespace/newlines
        # Also handle unbracketed comma-separated lists like "1, 2, 3"
        # Track which sections have "!" marker (e.g., "1, 2!, 3" or "[1, 2!, 3]")
        section_ids = []
        sections_with_exclamation = set()  # Track section IDs that have "!" marker

        # First try to find a bracketed list
        bracket_pattern = r"\[([^\]]+)\]"
        bracket_match = re.search(bracket_pattern, llm_response)

        if bracket_match:
            # Extract the content between brackets
            list_content = bracket_match.group(1)
            # Split by comma, preserving the parts
            parts = [part.strip() for part in list_content.split(",")]

            for part in parts:
                # Check if this part has an exclamation mark
                has_exclamation = "!" in part
                # Extract the number (digits only)
                numbers = re.findall(r"\d+", part)
                if numbers:
                    section_id = numbers[0]
                    section_ids.append(section_id)
                    if has_exclamation:
                        sections_with_exclamation.add(section_id)
        else:
            # Try to find an unbracketed comma-separated list
            # Look for patterns like "1, 2, 3" or "1, 2!, 3"
            # This regex finds sequences of digits optionally followed by "!" and separated by commas
            comma_list_pattern = r"\b\d+!?\b(?:\s*,\s*\b\d+!?\b)*"
            comma_match = re.search(comma_list_pattern, llm_response)

            if comma_match:
                # Extract the matched comma-separated list
                list_content = comma_match.group(0)
                parts = [part.strip() for part in list_content.split(",")]

                for part in parts:
                    # Check if this part has an exclamation mark
                    has_exclamation = "!" in part
                    # Extract the number (digits only)
                    numbers = re.findall(r"\d+", part)
                    if numbers:
                        section_id = numbers[0]
                        section_ids.append(section_id)
                        if has_exclamation:
                            sections_with_exclamation.add(section_id)
            else:
                # Fallback: try to extract all numbers from the response
                # Also check for "!" after numbers
                number_pattern = r"\b(\d+)(!)?\b"
                matches = re.finditer(number_pattern, llm_response)
                for match in matches:
                    section_id = match.group(1)
                    has_exclamation = match.group(2) == "!"
                    section_ids.append(section_id)
                    if has_exclamation:
                        sections_with_exclamation.add(section_id)

        if not section_ids:
            logger.warning(
                f"Could not parse section IDs from LLM response: {llm_response}"
            )
            return sections[:max_sections], None

        # Filter sections based on LLM selection
        # Skip out-of-range IDs and don't count them toward max_sections
        selected_sections = []
        document_ids_with_exclamation = []  # Collect document_ids for sections with "!"
        num_sections = len(sections)

        for section_id_str in section_ids:
            # Convert to int
            try:
                section_id_int = int(section_id_str)
            except ValueError:
                logger.warning(f"Could not convert section ID to int: {section_id_str}")
                continue

            # Check if in valid range
            if section_id_int < 0 or section_id_int >= num_sections:
                logger.warning(
                    f"Section ID {section_id_int} is out of range [0, {num_sections - 1}], skipping"
                )
                continue

            # Convert back to string for section_map lookup
            section_id = str(section_id_int)
            if section_id in section_map:
                section = section_map[section_id]
                selected_sections.append(section)

                # If this section has an exclamation mark, collect its document_id
                if section_id_str in sections_with_exclamation:
                    document_id = section.center_chunk.document_id
                    if document_id not in document_ids_with_exclamation:
                        document_ids_with_exclamation.append(document_id)

            # Stop if we've reached max_sections valid selections
            if len(selected_sections) >= max_sections:
                break

        if not selected_sections:
            logger.warning(
                "No valid sections selected from LLM response, returning first max_sections"
            )
            return sections[:max_sections], None

        # Collect all selected document IDs
        selected_document_ids = [
            section.center_chunk.document_id for section in selected_sections
        ]

        logger.debug(
            f"LLM selected {len(selected_sections)} valid sections from {len(sections)} total candidates. "
            f"Selected document IDs: {selected_document_ids}. "
            f"Document IDs with exclamation: {document_ids_with_exclamation if document_ids_with_exclamation else []}"
        )

        # Return document_ids if any sections had exclamation marks, otherwise None
        return selected_sections, (
            document_ids_with_exclamation if document_ids_with_exclamation else None
        )

    except Exception as e:
        logger.error(f"Error calling LLM for document selection: {e}")
        return sections[:max_sections], None


================================================
FILE: backend/onyx/secondary_llm_flows/memory_update.py
================================================
from onyx.configs.constants import MessageType
from onyx.llm.interfaces import LLM
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import UserMessage
from onyx.prompts.basic_memory import FULL_MEMORY_UPDATE_PROMPT
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger
from onyx.utils.text_processing import parse_llm_json_response

logger = setup_logger()

# Maximum number of user messages to include
MAX_USER_MESSAGES = 3
MAX_CHARS_PER_MESSAGE = 500


def _format_chat_history(chat_history: list[ChatMinimalTextMessage]) -> str:
    user_messages = [
        msg for msg in chat_history if msg.message_type == MessageType.USER
    ]

    if not user_messages:
        return "No chat history available."
    # Take the last N user messages
    recent_user_messages = user_messages[-MAX_USER_MESSAGES:]

    formatted_parts = []
    for i, msg in enumerate(recent_user_messages, start=1):
        if len(msg.message) > MAX_CHARS_PER_MESSAGE:
            truncated_message = msg.message[:MAX_CHARS_PER_MESSAGE] + "[...truncated]"
        else:
            truncated_message = msg.message
        formatted_parts.append(f"\nUser message:\n{truncated_message}\n")

    return "".join(formatted_parts).strip()


def _format_existing_memories(existing_memories: list[str]) -> str:
    """Format existing memories as a numbered list (1-indexed for readability)."""
    if not existing_memories:
        return "No existing memories."

    formatted_lines = []
    for i, memory in enumerate(existing_memories, start=1):
        formatted_lines.append(f"{i}. {memory}")

    return "\n".join(formatted_lines)


def _format_user_basic_information(
    user_name: str | None,
    user_email: str | None,
    user_role: str | None,
) -> str:
    """Format user basic information, only including fields that have values."""
    lines = []
    if user_name:
        lines.append(f"User name: {user_name}")
    if user_email:
        lines.append(f"User email: {user_email}")
    if user_role:
        lines.append(f"User role: {user_role}")

    if not lines:
        return ""

    return "\n\n# User Basic Information\n" + "\n".join(lines)


def process_memory_update(
    new_memory: str,
    existing_memories: list[str],
    chat_history: list[ChatMinimalTextMessage],
    llm: LLM,
    user_name: str | None = None,
    user_email: str | None = None,
    user_role: str | None = None,
) -> tuple[str, int | None]:
    """
    Determine if a memory should be added or updated.

    Uses the LLM to analyze the new memory against existing memories and
    determine whether to add it as new or update an existing memory.
    Args:
        new_memory: The new memory text from the memory tool
        existing_memories: List of existing memory strings
        chat_history: Recent chat history for context
        llm: LLM instance to use for the decision
        user_name: Optional user name for context
        user_email: Optional user email for context
        user_role: Optional user role for context

    Returns:
        Tuple of (memory_text, index_to_replace)
        - memory_text: The final memory text to store
        - index_to_replace: Index in existing_memories to replace, or None if adding new
    """
    # Format inputs for the prompt
    formatted_chat_history = _format_chat_history(chat_history)
    formatted_memories = _format_existing_memories(existing_memories)
    formatted_user_info = _format_user_basic_information(
        user_name, user_email, user_role
    )

    # Build the prompt
    prompt = FULL_MEMORY_UPDATE_PROMPT.format(
        chat_history=formatted_chat_history,
        user_basic_information=formatted_user_info,
        existing_memories=formatted_memories,
        new_memory=new_memory,
    )

    # Call LLM with Braintrust tracing
    try:
        prompt_msg = UserMessage(content=prompt)
        with llm_generation_span(
            llm=llm, flow="memory_update", input_messages=[prompt_msg]
        ) as span_generation:
            response = llm.invoke(
                prompt=prompt_msg, reasoning_effort=ReasoningEffort.OFF
            )
            record_llm_response(span_generation, response)
            content = response.choice.message.content
    except Exception as e:
        logger.warning(f"LLM invocation failed for memory update: {e}")
        return (new_memory, None)

    # Handle empty response
    if not content:
        logger.warning(
            "LLM returned empty response for memory update, defaulting to add"
        )
        return (new_memory, None)

    # Parse JSON response
    parsed_response = parse_llm_json_response(content)
    if not parsed_response:
        logger.warning(
            f"Failed to parse JSON from LLM response: {content[:200]}..., defaulting to add"
        )
        return (new_memory, None)

    # Extract fields from response
    operation = parsed_response.get("operation", "add").lower()
    memory_id = parsed_response.get("memory_id")
    memory_text = parsed_response.get("memory_text", new_memory)
    # Ensure memory_text is valid
    if not memory_text or not isinstance(memory_text, str):
        memory_text = new_memory

    # Handle add operation
    if operation == "add":
        logger.debug("Memory update operation: add")
        return (memory_text, None)

    # Handle update operation
    if operation == "update":
        # Validate memory_id
        if memory_id is None:
            logger.warning("Update operation specified but no memory_id provided")
            return (memory_text, None)

        # Convert memory_id to integer if it's a string
        try:
            memory_id_int = int(memory_id)
        except (ValueError, TypeError):
            logger.warning(f"Invalid memory_id format: {memory_id}")
            return (memory_text, None)

        # Convert from 1-indexed (LLM response) to 0-indexed (internal)
        index_to_replace = memory_id_int - 1

        # Validate index is in range
        if index_to_replace < 0 or index_to_replace >= len(existing_memories):
            logger.warning(
                f"memory_id {memory_id_int} out of range (1-{len(existing_memories)}), defaulting to add"
            )
            return (memory_text, None)

        logger.debug(f"Memory update operation: update at index {index_to_replace}")
        return (memory_text, index_to_replace)

    # Unknown operation, default to add
    logger.warning(f"Unknown operation '{operation}', defaulting to add")
    return (memory_text, None)


================================================
FILE: backend/onyx/secondary_llm_flows/query_expansion.py
================================================
from onyx.configs.constants import MessageType
from onyx.llm.interfaces import LLM
from onyx.llm.models import AssistantMessage
from onyx.llm.models import ChatCompletionMessage
from onyx.llm.models import ReasoningEffort
from onyx.llm.models import SystemMessage
from onyx.llm.models import UserMessage
from onyx.prompts.prompt_utils import get_current_llm_day_time
from onyx.prompts.search_prompts import KEYWORD_REPHRASE_SYSTEM_PROMPT
from onyx.prompts.search_prompts import KEYWORD_REPHRASE_USER_PROMPT
from onyx.prompts.search_prompts import REPHRASE_CONTEXT_PROMPT
from onyx.prompts.search_prompts import SEMANTIC_QUERY_REPHRASE_SYSTEM_PROMPT
from onyx.prompts.search_prompts import SEMANTIC_QUERY_REPHRASE_USER_PROMPT
from onyx.tools.models import ChatMinimalTextMessage
from onyx.tracing.llm_utils import llm_generation_span
from onyx.tracing.llm_utils import record_llm_response
from onyx.utils.logger import setup_logger

logger = setup_logger()


def _build_additional_context(
    user_info: str | None = None,
    memories: list[str] | None = None,
) -> str:
    """Build additional context section for query rephrasing/expansion.

    Returns empty string if both user_info and memories are None/empty.
    Otherwise returns formatted context with "N/A" for missing fields.
    """
    has_user_info = user_info and user_info.strip()
    has_memories = memories and any(m.strip() for m in memories)

    if not has_user_info and not has_memories:
        return ""

    formatted_user_info = user_info if has_user_info else "N/A"
    formatted_memories = (
        "\n".join(f"- {memory}" for memory in memories)
        if has_memories and memories
        else "N/A"
    )

    return REPHRASE_CONTEXT_PROMPT.format(
        user_info=formatted_user_info,
        memories=formatted_memories,
    )


def _build_message_history(
    history: list[ChatMinimalTextMessage],
) -> list[ChatCompletionMessage]:
    """Convert ChatMinimalTextMessage list to ChatCompletionMessage list."""
    messages: list[ChatCompletionMessage] = []

    for msg in history:
        if msg.message_type == MessageType.USER:
            user_msg = UserMessage(content=msg.message)
            messages.append(user_msg)
        elif msg.message_type == MessageType.ASSISTANT:
            assistant_msg = AssistantMessage(content=msg.message)
            messages.append(assistant_msg)

    return messages


def semantic_query_rephrase(
    history: list[ChatMinimalTextMessage],
    llm: LLM,
    user_info: str | None = None,
    memories: list[str] | None = None,
) -> str:
    """Rephrase a query into a standalone query using chat history context.

    Converts the user's query into a self-contained search query that incorporates
    relevant context from the chat history and optional user information/memories.
Args: history: Chat message history. Must contain at least one user message. llm: Language model to use for rephrasing user_info: Optional user information for personalization memories: Optional user memories for personalization Returns: Rephrased standalone query string Raises: ValueError: If history is empty or contains no user messages RuntimeError: If LLM fails to generate a rephrased query """ if not history: raise ValueError("History cannot be empty for query rephrasing") # Find the last user message in the history last_user_message_idx = None for i in range(len(history) - 1, -1, -1): if history[i].message_type == MessageType.USER: last_user_message_idx = i break if last_user_message_idx is None: raise ValueError("History must contain at least one user message") # Extract the last user query user_query = history[last_user_message_idx].message # Build additional context section additional_context = _build_additional_context(user_info, memories) current_datetime_str = get_current_llm_day_time( include_day_of_week=True, full_sentence=False ) # Build system message with current date system_msg = SystemMessage( content=SEMANTIC_QUERY_REPHRASE_SYSTEM_PROMPT.format( current_date=current_datetime_str ) ) # Convert chat history to message format (excluding the last user message and everything after it) messages: list[ChatCompletionMessage] = [system_msg] messages.extend(_build_message_history(history[:last_user_message_idx])) # Add the last message as the user prompt with instructions final_user_msg = UserMessage( content=SEMANTIC_QUERY_REPHRASE_USER_PROMPT.format( additional_context=additional_context, user_query=user_query ) ) messages.append(final_user_msg) # Call LLM and return result with Braintrust tracing with llm_generation_span( llm=llm, flow="semantic_query_rephrase", input_messages=messages ) as span_generation: response = llm.invoke(prompt=messages, reasoning_effort=ReasoningEffort.OFF) record_llm_response(span_generation, response) final_query = 
response.choice.message.content if not final_query: # It's ok if some other queries fail, this one is likely the best one # It also can't fail in parsing so we should be able to guarantee a valid query here. raise RuntimeError("LLM failed to generate a rephrased query") return final_query def keyword_query_expansion( history: list[ChatMinimalTextMessage], llm: LLM, user_info: str | None = None, memories: list[str] | None = None, ) -> list[str] | None: """Expand a query into multiple keyword-only queries using chat history context. Converts the user's query into a set of keyword-based search queries (max 3) that incorporate relevant context from the chat history and optional user information/memories. Returns a list of keyword queries. Args: history: Chat message history. Must contain at least one user message. llm: Language model to use for keyword expansion user_info: Optional user information for personalization memories: Optional user memories for personalization Returns: List of keyword-only query strings (max 3), or empty list if generation fails Raises: ValueError: If history is empty or contains no user messages """ if not history: raise ValueError("History cannot be empty for keyword query expansion") # Find the last user message in the history last_user_message_idx = None for i in range(len(history) - 1, -1, -1): if history[i].message_type == MessageType.USER: last_user_message_idx = i break if last_user_message_idx is None: raise ValueError("History must contain at least one user message") # Extract the last user query user_query = history[last_user_message_idx].message # Build additional context section additional_context = _build_additional_context(user_info, memories) current_datetime_str = get_current_llm_day_time( include_day_of_week=True, full_sentence=False ) # Build system message with current date system_msg = SystemMessage( content=KEYWORD_REPHRASE_SYSTEM_PROMPT.format(current_date=current_datetime_str) ) # Convert chat history to message format 
(excluding the last user message and everything after it) messages: list[ChatCompletionMessage] = [system_msg] messages.extend(_build_message_history(history[:last_user_message_idx])) # Add the last message as the user prompt with instructions final_user_msg = UserMessage( content=KEYWORD_REPHRASE_USER_PROMPT.format( additional_context=additional_context, user_query=user_query ) ) messages.append(final_user_msg) # Call LLM and return result with Braintrust tracing with llm_generation_span( llm=llm, flow="keyword_query_expansion", input_messages=messages ) as span_generation: response = llm.invoke(prompt=messages, reasoning_effort=ReasoningEffort.OFF) record_llm_response(span_generation, response) content = response.choice.message.content # Parse the response - each line is a separate keyword query if not content: return [] queries = [line.strip() for line in content.strip().split("\n") if line.strip()] return queries ================================================ FILE: backend/onyx/secondary_llm_flows/source_filter.py ================================================ from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.llm.interfaces import LLM from onyx.utils.logger import setup_logger logger = setup_logger() def strings_to_document_sources(source_strs: list[str]) -> list[DocumentSource]: sources = [] for s in source_strs: try: sources.append(DocumentSource(s)) except ValueError: logger.warning(f"Failed to translate {s} to a DocumentSource") return sources def extract_source_filter( query: str, llm: LLM, db_session: Session ) -> list[DocumentSource] | None: # Can reference onyx/prompts/filter_extration.py for previous implementation prompts raise NotImplementedError("This function should not be getting called right now") ================================================ FILE: backend/onyx/secondary_llm_flows/time_filter.py ================================================ from datetime import datetime from datetime import 
timezone from dateutil.parser import parse from onyx.llm.interfaces import LLM from onyx.utils.logger import setup_logger logger = setup_logger() def best_match_time(time_str: str) -> datetime | None: preferred_formats = ["%m/%d/%Y", "%m-%d-%Y"] for fmt in preferred_formats: try: # We don't know whether the user is in the same timezone as the API server, # so assume queries are in UTC; an offset of a few hours (if any) # shouldn't make a significant difference dt = datetime.strptime(time_str, fmt) return dt.replace(tzinfo=timezone.utc) except ValueError: continue # If the above formats don't match, try using dateutil's parser try: dt = parse(time_str) return ( dt.astimezone(timezone.utc) if dt.tzinfo else dt.replace(tzinfo=timezone.utc) ) except (ValueError, OverflowError): # dateutil raises ParserError (a ValueError subclass) for unparseable input # and OverflowError for dates outside the representable range return None def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]: """Returns a datetime if a hard time filter should be applied for the given query Additionally returns a bool, True if more recently updated Documents should be heavily favored""" raise NotImplementedError("This function should not be getting called right now") # def _get_time_filter_messages(query: str) -> list[dict[str, str]]: # messages = [ # { # "role": "system", # "content": TIME_FILTER_PROMPT.format( # current_day_time_str=get_current_llm_day_time() # ), # }, # { # "role": "user", # "content": "What documents in Confluence were written in the last two quarters", # }, # { # "role": "assistant", # "content": json.dumps( # { # "filter_type": "hard cutoff", # "filter_value": "quarter", # "value_multiple": 2, # } # ), # }, # {"role": "user", "content": "What's the latest on project Corgies?"}, # { # "role": "assistant", # "content": json.dumps({"filter_type": "favor recent"}), # }, # { # "role": "user", # "content": "Which customer asked about security 
features in February of 2022?", # }, # { # "role": "assistant", # "content": json.dumps( # {"filter_type": "hard cutoff", "date": 
"02/01/2022"} # ), # }, # {"role": "user", "content": query}, # ] # return messages # def _extract_time_filter_from_llm_out( # model_out: str, # ) -> tuple[datetime | None, bool]: # """Returns a datetime for a hard cutoff and a bool for if the""" # try: # model_json = json.loads(model_out, strict=False) # except json.JSONDecodeError: # return None, False # # If filter type is not present, just assume something has gone wrong # # Potentially model has identified a date and just returned that but # # better to be conservative and not identify the wrong filter. # if "filter_type" not in model_json: # return None, False # if "hard" in model_json["filter_type"] or "recent" in model_json["filter_type"]: # favor_recent = "recent" in model_json["filter_type"] # if "date" in model_json: # extracted_time = best_match_time(model_json["date"]) # if extracted_time is not None: # # LLM struggles to understand the concept of not sensitive within a time range # # So if a time is extracted, just go with that alone # return extracted_time, False # time_diff = None # multiplier = 1.0 # if "value_multiple" in model_json: # try: # multiplier = float(model_json["value_multiple"]) # except ValueError: # pass # if "filter_value" in model_json: # filter_value = model_json["filter_value"] # if "day" in filter_value: # time_diff = timedelta(days=multiplier) # elif "week" in filter_value: # time_diff = timedelta(weeks=multiplier) # elif "month" in filter_value: # # Have to just use the average here, too complicated to calculate exact day # # based on current day etc. 
# time_diff = timedelta(days=multiplier * 30.437) # elif "quarter" in filter_value: # time_diff = timedelta(days=multiplier * 91.25) # elif "year" in filter_value: # time_diff = timedelta(days=multiplier * 365) # if time_diff is not None: # current = datetime.now(timezone.utc) # # LLM struggles to understand the concept of not sensitive within a time range # # So if a time is extracted, just go with that alone # return current - time_diff, False # # If we failed to extract a hard filter, just pass back the value of favor recent # return None, favor_recent # return None, False # messages = _get_time_filter_messages(query) # filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages) # model_output = message_to_string(llm.invoke_langchain(filled_llm_prompt)) # logger.debug(model_output) # return _extract_time_filter_from_llm_out(model_output) ================================================ FILE: backend/onyx/seeding/__init__.py ================================================ ================================================ FILE: backend/onyx/server/__init__.py ================================================ ================================================ FILE: backend/onyx/server/api_key/api.py ================================================ from fastapi import APIRouter from fastapi import Depends from sqlalchemy.orm import Session from onyx.auth.users import current_admin_user from onyx.db.api_key import ApiKeyDescriptor from onyx.db.api_key import fetch_api_keys from onyx.db.api_key import insert_api_key from onyx.db.api_key import regenerate_api_key from onyx.db.api_key import remove_api_key from onyx.db.api_key import update_api_key from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.server.api_key.models import APIKeyArgs router = APIRouter(prefix="/admin/api-key") @router.get("") def list_api_keys( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> list[ApiKeyDescriptor]: 
return fetch_api_keys(db_session) @router.post("") def create_api_key( api_key_args: APIKeyArgs, user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ApiKeyDescriptor: return insert_api_key(db_session, api_key_args, user.id) @router.post("/{api_key_id}/regenerate") def regenerate_existing_api_key( api_key_id: int, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ApiKeyDescriptor: return regenerate_api_key(db_session, api_key_id) @router.patch("/{api_key_id}") def update_existing_api_key( api_key_id: int, api_key_args: APIKeyArgs, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> ApiKeyDescriptor: return update_api_key(db_session, api_key_id, api_key_args) @router.delete("/{api_key_id}") def delete_api_key( api_key_id: int, _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> None: remove_api_key(db_session, api_key_id) ================================================ FILE: backend/onyx/server/api_key/models.py ================================================ from pydantic import BaseModel from onyx.auth.schemas import UserRole class APIKeyArgs(BaseModel): name: str | None = None role: UserRole = UserRole.BASIC ================================================ FILE: backend/onyx/server/api_key_usage.py ================================================ """API key and PAT usage tracking for cloud usage limits.""" from fastapi import Depends from fastapi import Request from sqlalchemy.orm import Session from onyx.auth.api_key import get_hashed_api_key_from_request from onyx.auth.pat import get_hashed_pat_from_request from onyx.db.engine.sql_engine import get_session from onyx.db.usage import increment_usage from onyx.db.usage import UsageType from onyx.server.usage_limits import check_usage_and_raise from onyx.server.usage_limits import is_usage_limits_enabled from onyx.utils.logger import setup_logger from 
shared_configs.contextvars import get_current_tenant_id logger = setup_logger() def check_api_key_usage( request: Request, db_session: Session = Depends(get_session), ) -> None: """ FastAPI dependency that checks and tracks API key/PAT usage limits. This should be added as a dependency to endpoints that accept API key or PAT authentication and should be usage-limited. """ if not is_usage_limits_enabled(): return # Check if request is authenticated via API key or PAT is_api_key_request = get_hashed_api_key_from_request(request) is not None is_pat_request = get_hashed_pat_from_request(request) is not None if not is_api_key_request and not is_pat_request: return tenant_id = get_current_tenant_id() # Check usage limit check_usage_and_raise( db_session=db_session, usage_type=UsageType.API_CALLS, tenant_id=tenant_id, pending_amount=1, ) # Increment usage counter increment_usage( db_session=db_session, usage_type=UsageType.API_CALLS, amount=1, ) db_session.commit() ================================================ FILE: backend/onyx/server/auth_check.py ================================================ from typing import cast from fastapi import FastAPI from fastapi.dependencies.models import Dependant from starlette.routing import BaseRoute from onyx.auth.users import current_admin_user from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_limited_user from onyx.auth.users import current_user from onyx.auth.users import current_user_from_websocket from onyx.auth.users import current_user_with_expired_token from onyx.configs.app_configs import APP_API_PREFIX from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop PUBLIC_ENDPOINT_SPECS = [ # built-in documentation functions ("/openapi.json", {"GET", "HEAD"}), ("/docs", {"GET", "HEAD"}), ("/docs/oauth2-redirect", {"GET", "HEAD"}), ("/redoc", {"GET", "HEAD"}), # should always be callable, will just return 401 
if not authenticated ("/me", {"GET"}), # just returns 200 to validate that the server is up ("/health", {"GET"}), # just returns auth type, needs to be accessible before the user is logged # in to determine what flow to give the user ("/auth/type", {"GET"}), # just gets the version of Onyx (e.g. 0.3.11) ("/version", {"GET"}), # Gets stable and beta versions for Onyx docker images ("/versions", {"GET"}), # stuff related to basic auth ("/auth/refresh", {"POST"}), ("/auth/register", {"POST"}), ("/auth/login", {"POST"}), ("/auth/logout", {"POST"}), ("/auth/forgot-password", {"POST"}), ("/auth/reset-password", {"POST"}), ("/auth/request-verify-token", {"POST"}), ("/auth/verify", {"POST"}), ("/users/me", {"GET"}), ("/users/me", {"PATCH"}), ("/users/{id}", {"GET"}), ("/users/{id}", {"PATCH"}), ("/users/{id}", {"DELETE"}), # oauth ("/auth/oauth/authorize", {"GET"}), ("/auth/oauth/callback", {"GET"}), # oidc ("/auth/oidc/authorize", {"GET"}), ("/auth/oidc/callback", {"GET"}), # saml ("/auth/saml/authorize", {"GET"}), ("/auth/saml/callback", {"POST"}), ("/auth/saml/callback", {"GET"}), ("/auth/saml/logout", {"POST"}), # anonymous user on cloud ("/tenants/anonymous-user", {"POST"}), ("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator # craft webapp proxy — access enforced per-session via sharing_scope in handler ("/build/sessions/{session_id}/webapp", {"GET"}), ("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}), ] def is_route_in_spec_list( route: BaseRoute, public_endpoint_specs: list[tuple[str, set[str]]] ) -> bool: if not hasattr(route, "path") or not hasattr(route, "methods"): return False # try adding the prefix AND not adding the prefix, since some endpoints # are not prefixed (e.g. 
/openapi.json) if (route.path, route.methods) in public_endpoint_specs: return True processed_global_prefix = f"/{APP_API_PREFIX.strip('/')}" if APP_API_PREFIX else "" if not processed_global_prefix: return False for endpoint_spec in public_endpoint_specs: base_path, methods = endpoint_spec prefixed_path = f"{processed_global_prefix}/{base_path.strip('/')}" if prefixed_path == route.path and route.methods == methods: return True return False def check_router_auth( application: FastAPI, public_endpoint_specs: list[tuple[str, set[str]]] = PUBLIC_ENDPOINT_SPECS, ) -> None: """Ensures that all endpoints on the passed in application either (1) have auth enabled OR (2) are explicitly marked as a public endpoint """ control_plane_dep = fetch_ee_implementation_or_noop( "onyx.server.tenants.access", "control_plane_dep" ) current_cloud_superuser = fetch_ee_implementation_or_noop( "onyx.auth.users", "current_cloud_superuser" ) verify_scim_token = fetch_ee_implementation_or_noop( "onyx.server.scim.auth", "verify_scim_token" ) for route in application.routes: # explicitly marked as public if is_route_in_spec_list(route, public_endpoint_specs): continue # check for auth found_auth = False route_dependant_obj = cast( Dependant | None, route.dependant if hasattr(route, "dependant") else None ) if route_dependant_obj: for dependency in route_dependant_obj.dependencies: depends_fn = dependency.cache_key[0] if ( depends_fn == current_limited_user or depends_fn == current_user or depends_fn == current_admin_user or depends_fn == current_curator_or_admin_user or depends_fn == current_user_with_expired_token or depends_fn == current_chat_accessible_user or depends_fn == current_user_from_websocket or depends_fn == control_plane_dep or depends_fn == current_cloud_superuser or depends_fn == verify_scim_token ): found_auth = True break if not found_auth: # uncomment to print out all route(s) that are missing auth # print(f"(\"{route.path}\", {set(route.methods)}),") raise RuntimeError( 
f"Did not find user dependency in private route - {route}" ) ================================================ FILE: backend/onyx/server/documents/__init__.py ================================================ ================================================ FILE: backend/onyx/server/documents/cc_pair.py ================================================ from datetime import datetime from http import HTTPStatus from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Query from fastapi.responses import JSONResponse from sqlalchemy import select from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_user from onyx.background.celery.tasks.pruning.tasks import ( try_creating_prune_generator_task, ) from onyx.background.celery.versioned_apps.client import app as client_app from onyx.background.indexing.models import IndexAttemptErrorPydantic from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import PUBLIC_API_TAGS from onyx.connectors.exceptions import ValidationError from onyx.connectors.factory import validate_ccpair_for_user from onyx.db.connector import delete_connector from onyx.db.connector_credential_pair import add_credential_to_connector from onyx.db.connector_credential_pair import ( get_connector_credential_pair_from_id_for_user, ) from onyx.db.connector_credential_pair import remove_credential_from_connector from onyx.db.connector_credential_pair import ( update_connector_credential_pair_from_id, ) from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair from onyx.db.document import get_document_counts_for_cc_pairs from onyx.db.document import get_documents_for_cc_pair from onyx.db.engine.sql_engine import get_session from onyx.db.enums import AccessType from onyx.db.enums import 
ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import PermissionSyncStatus from onyx.db.index_attempt import count_index_attempt_errors_for_cc_pair from onyx.db.index_attempt import count_index_attempts_for_cc_pair from onyx.db.index_attempt import get_index_attempt_errors_for_cc_pair from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from onyx.db.index_attempt import ( get_latest_successful_index_attempt_for_cc_pair_id, ) from onyx.db.index_attempt import get_paginated_index_attempts_for_cc_pair_id from onyx.db.indexing_coordination import IndexingCoordination from onyx.db.models import IndexAttempt from onyx.db.models import User from onyx.db.permission_sync_attempt import ( get_latest_doc_permission_sync_attempt_for_cc_pair, ) from onyx.db.permission_sync_attempt import ( get_recent_doc_permission_sync_attempts_for_cc_pair, ) from onyx.redis.redis_connector import RedisConnector from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import CCPairFullInfo from onyx.server.documents.models import CCPropertyUpdateRequest from onyx.server.documents.models import CCStatusUpdateRequest from onyx.server.documents.models import ConnectorCredentialPairIdentifier from onyx.server.documents.models import ConnectorCredentialPairMetadata from onyx.server.documents.models import DocumentSyncStatus from onyx.server.documents.models import IndexAttemptSnapshot from onyx.server.documents.models import PaginatedReturn from onyx.server.documents.models import PermissionSyncAttemptSnapshot from onyx.server.models import StatusResponse from onyx.utils.logger import setup_logger from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/manage") 
@router.get("/admin/cc-pair/{cc_pair_id}/index-attempts", tags=PUBLIC_API_TAGS) def get_cc_pair_index_attempts( cc_pair_id: int, page_num: int = Query(0, ge=0), page_size: int = Query(10, ge=1, le=1000), user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> PaginatedReturn[IndexAttemptSnapshot]: if user: user_has_access = verify_user_has_access_to_cc_pair( cc_pair_id, db_session, user, get_editable=False ) if not user_has_access: raise HTTPException( status_code=400, detail="CC Pair not found for current user permissions" ) total_count = count_index_attempts_for_cc_pair( db_session=db_session, cc_pair_id=cc_pair_id, ) index_attempts = get_paginated_index_attempts_for_cc_pair_id( db_session=db_session, cc_pair_id=cc_pair_id, page=page_num, page_size=page_size, ) return PaginatedReturn( items=[ IndexAttemptSnapshot.from_index_attempt_db_model(index_attempt) for index_attempt in index_attempts ], total_items=total_count, ) @router.get("/admin/cc-pair/{cc_pair_id}/permission-sync-attempts") def get_cc_pair_permission_sync_attempts( cc_pair_id: int, page_num: int = Query(0, ge=0), page_size: int = Query(10, ge=1, le=1000), user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> PaginatedReturn[PermissionSyncAttemptSnapshot]: if user: user_has_access = verify_user_has_access_to_cc_pair( cc_pair_id, db_session, user, get_editable=False ) if not user_has_access: raise HTTPException( status_code=400, detail="CC Pair not found for current user permissions" ) # Fetch up to 1000 of the most recent permission sync attempts for this cc pair, # then paginate them in memory all_attempts = get_recent_doc_permission_sync_attempts_for_cc_pair( cc_pair_id=cc_pair_id, limit=1000, db_session=db_session, ) start_idx = page_num * page_size end_idx = start_idx + page_size paginated_attempts = all_attempts[start_idx:end_idx] items = [ 
PermissionSyncAttemptSnapshot.from_permission_sync_attempt_db_model(attempt) for attempt in paginated_attempts ] return 
PaginatedReturn( items=items, total_items=len(all_attempts), ) @router.get("/admin/cc-pair/{cc_pair_id}", tags=PUBLIC_API_TAGS) def get_cc_pair_full_info( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> CCPairFullInfo: tenant_id = get_current_tenant_id() cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id, db_session, user, get_editable=False ) if not cc_pair: raise HTTPException( status_code=404, detail="CC Pair not found for current user permissions" ) editable_cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id, db_session, user, get_editable=True ) is_editable_for_current_user = editable_cc_pair is not None document_count_info_list = list( get_document_counts_for_cc_pairs( db_session=db_session, cc_pairs=[ ConnectorCredentialPairIdentifier( connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, ) ], ) ) documents_indexed = ( document_count_info_list[0][-1] if document_count_info_list else 0 ) latest_attempt = get_latest_index_attempt_for_cc_pair_id( db_session=db_session, connector_credential_pair_id=cc_pair_id, secondary_index=False, only_finished=False, ) latest_successful_attempt = get_latest_successful_index_attempt_for_cc_pair_id( db_session=db_session, connector_credential_pair_id=cc_pair_id, ) # Get latest permission sync attempt for status latest_permission_sync_attempt = None if cc_pair.access_type == AccessType.SYNC: latest_permission_sync_attempt = ( get_latest_doc_permission_sync_attempt_for_cc_pair( db_session=db_session, connector_credential_pair_id=cc_pair_id, ) ) return CCPairFullInfo.from_models( cc_pair_model=cc_pair, number_of_index_attempts=count_index_attempts_for_cc_pair( db_session=db_session, cc_pair_id=cc_pair_id, ), last_index_attempt=latest_attempt, last_successful_index_time=( latest_successful_attempt.time_started if latest_successful_attempt else None ), latest_deletion_attempt=get_deletion_attempt_snapshot( 
connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, db_session=db_session, tenant_id=tenant_id, ), num_docs_indexed=documents_indexed, is_editable_for_current_user=is_editable_for_current_user, indexing=bool( latest_attempt and latest_attempt.status == IndexingStatus.IN_PROGRESS ), last_permission_sync_attempt_status=( latest_permission_sync_attempt.status if latest_permission_sync_attempt else None ), permission_syncing=bool( latest_permission_sync_attempt and latest_permission_sync_attempt.status == PermissionSyncStatus.IN_PROGRESS ), last_permission_sync_attempt_finished=( latest_permission_sync_attempt.time_finished if latest_permission_sync_attempt else None ), last_permission_sync_attempt_error_message=( latest_permission_sync_attempt.error_message if latest_permission_sync_attempt else None ), ) @router.put("/admin/cc-pair/{cc_pair_id}/status", tags=PUBLIC_API_TAGS) def update_cc_pair_status( cc_pair_id: int, status_update_request: CCStatusUpdateRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> JSONResponse: """This method returns nearly immediately. It simply sets some signals and optimistically assumes any running background processes will clean themselves up. This is done to improve the perceived end user experience. Returns HTTPStatus.OK if everything finished. 
""" tenant_id = get_current_tenant_id() cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=True, ) if not cc_pair: raise HTTPException( status_code=400, detail="Connection not found for current user's permissions", ) redis_connector = RedisConnector(tenant_id, cc_pair_id) if status_update_request.status == ConnectorCredentialPairStatus.PAUSED: redis_connector.stop.set_fence(True) # Request cancellation for any active indexing attempts for this cc_pair active_attempts = ( db_session.execute( select(IndexAttempt).where( IndexAttempt.connector_credential_pair_id == cc_pair_id, IndexAttempt.status.in_( [IndexingStatus.NOT_STARTED, IndexingStatus.IN_PROGRESS] ), ) ) .scalars() .all() ) for attempt in active_attempts: try: IndexingCoordination.request_cancellation(db_session, attempt.id) # Revoke the task to prevent it from running if attempt.celery_task_id: client_app.control.revoke(attempt.celery_task_id) logger.info( f"Requested cancellation for active indexing attempt {attempt.id} " f"due to connector pause: cc_pair={cc_pair_id}" ) except Exception: logger.exception( f"Failed to request cancellation for indexing attempt {attempt.id}" ) else: redis_connector.stop.set_fence(False) update_connector_credential_pair_from_id( db_session=db_session, cc_pair_id=cc_pair_id, status=status_update_request.status, ) db_session.commit() # this speeds up the start of indexing by firing the check immediately client_app.send_task( OnyxCeleryTask.CHECK_FOR_INDEXING, kwargs=dict(tenant_id=tenant_id), priority=OnyxCeleryPriority.HIGH, ) return JSONResponse( status_code=HTTPStatus.OK, content={"message": str(HTTPStatus.OK)} ) @router.put("/admin/cc-pair/{cc_pair_id}/name") def update_cc_pair_name( cc_pair_id: int, new_name: str, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: cc_pair = get_connector_credential_pair_from_id_for_user( 
cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=True, ) if not cc_pair: raise HTTPException( status_code=400, detail="CC Pair not found for current user's permissions" ) try: cc_pair.name = new_name db_session.commit() return StatusResponse( success=True, message="Name updated successfully", data=cc_pair_id ) except IntegrityError: db_session.rollback() raise HTTPException(status_code=400, detail="Name must be unique") @router.put("/admin/cc-pair/{cc_pair_id}/property") def update_cc_pair_property( cc_pair_id: int, update_request: CCPropertyUpdateRequest, # in seconds user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=True, ) if not cc_pair: raise HTTPException( status_code=400, detail="CC Pair not found for current user's permissions" ) # Can we centralize logic for updating connector properties # so that we don't need to manually validate everywhere? if update_request.name == "refresh_frequency": cc_pair.connector.refresh_freq = int(update_request.value) cc_pair.connector.validate_refresh_freq() db_session.commit() msg = "Refresh frequency updated successfully" elif update_request.name == "pruning_frequency": cc_pair.connector.prune_freq = int(update_request.value) cc_pair.connector.validate_prune_freq() db_session.commit() msg = "Pruning frequency updated successfully" else: raise HTTPException( status_code=400, detail=f"Property name {update_request.name} is not valid." 
) return StatusResponse(success=True, message=msg, data=cc_pair_id) @router.get("/admin/cc-pair/{cc_pair_id}/last_pruned") def get_cc_pair_last_pruned( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> datetime | None: cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=False, ) if not cc_pair: raise HTTPException( status_code=400, detail="cc_pair not found for current user's permissions", ) return cc_pair.last_pruned @router.post("/admin/cc-pair/{cc_pair_id}/prune", tags=PUBLIC_API_TAGS) def prune_cc_pair( cc_pair_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[list[int]]: """Triggers pruning on a particular cc_pair immediately""" tenant_id = get_current_tenant_id() cc_pair = get_connector_credential_pair_from_id_for_user( cc_pair_id=cc_pair_id, db_session=db_session, user=user, get_editable=False, ) if not cc_pair: raise HTTPException( status_code=400, detail="Connection not found for current user's permissions", ) r = get_redis_client() redis_connector = RedisConnector(tenant_id, cc_pair_id) if redis_connector.prune.fenced: raise HTTPException( status_code=HTTPStatus.CONFLICT, detail="Pruning task already in progress.", ) logger.info( f"Pruning cc_pair: cc_pair={cc_pair_id} " f"connector={cc_pair.connector_id} " f"credential={cc_pair.credential_id} " f"{cc_pair.connector.name} connector." 
) payload_id = try_creating_prune_generator_task( client_app, cc_pair, db_session, r, tenant_id ) if not payload_id: raise HTTPException( status_code=HTTPStatus.INTERNAL_SERVER_ERROR, detail="Pruning task creation failed.", ) logger.info(f"Pruning queued: cc_pair={cc_pair.id} id={payload_id}") return StatusResponse( success=True, message="Successfully created the pruning task.", ) @router.get("/admin/cc-pair/{cc_pair_id}/get-docs-sync-status") def get_docs_sync_status( cc_pair_id: int, _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[DocumentSyncStatus]: all_docs_for_cc_pair = get_documents_for_cc_pair( db_session=db_session, cc_pair_id=cc_pair_id, ) return [DocumentSyncStatus.from_model(doc) for doc in all_docs_for_cc_pair] @router.get("/admin/cc-pair/{cc_pair_id}/errors", tags=PUBLIC_API_TAGS) def get_cc_pair_indexing_errors( cc_pair_id: int, include_resolved: bool = Query(False), page_num: int = Query(0, ge=0), page_size: int = Query(10, ge=1, le=100), _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> PaginatedReturn[IndexAttemptErrorPydantic]: """Returns the indexing errors for a given CC Pair. Supports pagination via the page_num and page_size query params. Args: cc_pair_id: ID of the connector-credential pair to get errors for include_resolved: Whether to include resolved errors in the results page_num: Page number for pagination, starting at 0 page_size: Number of errors to return per page _: Current user, must be curator or admin db_session: Database session Returns: Paginated list of indexing errors for the CC pair.
""" total_count = count_index_attempt_errors_for_cc_pair( db_session=db_session, cc_pair_id=cc_pair_id, unresolved_only=not include_resolved, ) index_attempt_errors = get_index_attempt_errors_for_cc_pair( db_session=db_session, cc_pair_id=cc_pair_id, unresolved_only=not include_resolved, page=page_num, page_size=page_size, ) return PaginatedReturn( items=[IndexAttemptErrorPydantic.from_model(e) for e in index_attempt_errors], total_items=total_count, ) @router.put( "/connector/{connector_id}/credential/{credential_id}", tags=PUBLIC_API_TAGS ) def associate_credential_to_connector( connector_id: int, credential_id: int, metadata: ConnectorCredentialPairMetadata, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), tenant_id: str = Depends(get_current_tenant_id), ) -> StatusResponse[int]: """NOTE(rkuo): internally discussed and the consensus is this endpoint and create_connector_with_mock_credential should be combined. The intent of this endpoint is to handle connectors that actually need credentials. 
""" fetch_ee_implementation_or_noop( "onyx.db.user_group", "validate_object_creation_for_user", None )( db_session=db_session, user=user, target_group_ids=metadata.groups, object_is_public=metadata.access_type == AccessType.PUBLIC, object_is_perm_sync=metadata.access_type == AccessType.SYNC, object_is_new=True, ) try: validate_ccpair_for_user( connector_id, credential_id, metadata.access_type, db_session ) response = add_credential_to_connector( db_session=db_session, user=user, connector_id=connector_id, credential_id=credential_id, cc_pair_name=metadata.name, access_type=metadata.access_type, auto_sync_options=metadata.auto_sync_options, groups=metadata.groups, processing_mode=metadata.processing_mode, ) # trigger indexing immediately client_app.send_task( OnyxCeleryTask.CHECK_FOR_INDEXING, priority=OnyxCeleryPriority.HIGH, kwargs={"tenant_id": tenant_id}, ) logger.info( f"associate_credential_to_connector - running check_for_indexing: cc_pair={response.data}" ) return response except ValidationError as e: # If validation fails, delete the connector and commit the changes # Ensures we don't leave invalid connectors in the database # NOTE: consensus is that it makes sense to unify connector and ccpair creation flows # which would rid us of needing to handle cases like these delete_connector(db_session, connector_id) db_session.commit() raise HTTPException( status_code=400, detail="Connector validation error: " + str(e) ) except IntegrityError as e: logger.error(f"IntegrityError: {e}") delete_connector(db_session, connector_id) db_session.commit() raise HTTPException(status_code=400, detail="Name must be unique") except Exception as e: logger.exception(f"Unexpected error: {e}") raise HTTPException(status_code=500, detail="Unexpected error") @router.delete( "/connector/{connector_id}/credential/{credential_id}", tags=PUBLIC_API_TAGS ) def dissociate_credential_from_connector( connector_id: int, credential_id: int, user: User = Depends(current_user), db_session: 
Session = Depends(get_session), ) -> StatusResponse[int]: return remove_credential_from_connector( connector_id, credential_id, user, db_session ) ================================================ FILE: backend/onyx/server/documents/connector.py ================================================ import json import math import mimetypes import os import zipfile from datetime import datetime from io import BytesIO from typing import Any from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import File from fastapi import Form from fastapi import HTTPException from fastapi import Query from fastapi import Request from fastapi import Response from fastapi import UploadFile from google.oauth2.credentials import Credentials from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.auth.email_utils import send_email from onyx.auth.users import current_admin_user from onyx.auth.users import current_chat_accessible_user from onyx.auth.users import current_curator_or_admin_user from onyx.auth.users import current_user from onyx.background.celery.tasks.pruning.tasks import ( try_creating_prune_generator_task, ) from onyx.background.celery.versioned_apps.client import app as client_app from onyx.configs.app_configs import EMAIL_CONFIGURED from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH from onyx.configs.constants import DocumentSource from onyx.configs.constants import FileOrigin from onyx.configs.constants import MilestoneRecordType from onyx.configs.constants import ONYX_METADATA_FILENAME from onyx.configs.constants import OnyxCeleryPriority from onyx.configs.constants import OnyxCeleryTask from onyx.configs.constants import PUBLIC_API_TAGS from onyx.connectors.exceptions import ConnectorValidationError from onyx.connectors.factory import validate_ccpair_for_user from onyx.connectors.google_utils.google_auth import ( get_google_oauth_creds, ) 
from onyx.connectors.google_utils.google_kv import ( build_service_account_creds, ) from onyx.connectors.google_utils.google_kv import ( delete_google_app_cred, ) from onyx.connectors.google_utils.google_kv import ( delete_service_account_key, ) from onyx.connectors.google_utils.google_kv import get_auth_url from onyx.connectors.google_utils.google_kv import ( get_google_app_cred, ) from onyx.connectors.google_utils.google_kv import ( get_service_account_key, ) from onyx.connectors.google_utils.google_kv import ( update_credential_access_tokens, ) from onyx.connectors.google_utils.google_kv import ( upsert_google_app_cred, ) from onyx.connectors.google_utils.google_kv import ( upsert_service_account_key, ) from onyx.connectors.google_utils.google_kv import verify_csrf from onyx.connectors.google_utils.shared_constants import DB_CREDENTIALS_DICT_TOKEN_KEY from onyx.connectors.google_utils.shared_constants import ( GoogleOAuthAuthenticationMethod, ) from onyx.db.connector import create_connector from onyx.db.connector import delete_connector from onyx.db.connector import fetch_connector_by_id from onyx.db.connector import fetch_connectors from onyx.db.connector import fetch_unique_document_sources from onyx.db.connector import get_connector_credential_ids from onyx.db.connector import mark_ccpair_with_indexing_trigger from onyx.db.connector import update_connector from onyx.db.connector_credential_pair import add_credential_to_connector from onyx.db.connector_credential_pair import ( fetch_connector_credential_pair_for_connector, ) from onyx.db.connector_credential_pair import get_cc_pair_groups_for_ids from onyx.db.connector_credential_pair import get_connector_credential_pair from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user from onyx.db.connector_credential_pair import ( get_connector_credential_pairs_for_user_parallel, ) from onyx.db.connector_credential_pair import verify_user_has_access_to_cc_pair from onyx.db.credentials 
import cleanup_gmail_credentials from onyx.db.credentials import cleanup_google_drive_credentials from onyx.db.credentials import create_credential from onyx.db.credentials import delete_service_account_credentials from onyx.db.credentials import fetch_credential_by_id_for_user from onyx.db.deletion_attempt import check_deletion_attempt_is_allowed from onyx.db.document import get_document_counts_for_all_cc_pairs from onyx.db.engine.sql_engine import get_session from onyx.db.enums import AccessType from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingMode from onyx.db.enums import ProcessingMode from onyx.db.federated import fetch_all_federated_connectors_parallel from onyx.db.index_attempt import get_index_attempts_for_cc_pair from onyx.db.index_attempt import get_latest_index_attempts_by_status from onyx.db.index_attempt import get_latest_index_attempts_parallel from onyx.db.index_attempt import ( get_latest_successful_index_attempts_parallel, ) from onyx.db.models import ConnectorCredentialPair from onyx.db.models import FederatedConnector from onyx.db.models import IndexAttempt from onyx.db.models import IndexingStatus from onyx.db.models import User from onyx.db.models import UserRole from onyx.file_processing.file_types import PLAIN_TEXT_MIME_TYPE from onyx.file_processing.file_types import WORD_PROCESSING_MIME_TYPE from onyx.file_store.file_store import FileStore from onyx.file_store.file_store import get_default_file_store from onyx.key_value_store.interface import KvKeyNotFoundError from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import AuthStatus from onyx.server.documents.models import AuthUrl from onyx.server.documents.models import ConnectorBase from onyx.server.documents.models import ConnectorCredentialPairIdentifier from onyx.server.documents.models import ConnectorFileInfo from onyx.server.documents.models import ConnectorFilesResponse from onyx.server.documents.models import 
ConnectorIndexingStatusLite from onyx.server.documents.models import ConnectorIndexingStatusLiteResponse from onyx.server.documents.models import ConnectorRequestSubmission from onyx.server.documents.models import ConnectorSnapshot from onyx.server.documents.models import ConnectorStatus from onyx.server.documents.models import ConnectorUpdateRequest from onyx.server.documents.models import CredentialBase from onyx.server.documents.models import CredentialSnapshot from onyx.server.documents.models import DocsCountOperator from onyx.server.documents.models import FailedConnectorIndexingStatus from onyx.server.documents.models import FileUploadResponse from onyx.server.documents.models import GDriveCallback from onyx.server.documents.models import GmailCallback from onyx.server.documents.models import GoogleAppCredentials from onyx.server.documents.models import GoogleServiceAccountCredentialRequest from onyx.server.documents.models import GoogleServiceAccountKey from onyx.server.documents.models import IndexedSourcesResponse from onyx.server.documents.models import IndexingStatusRequest from onyx.server.documents.models import ObjectCreationIdResponse from onyx.server.documents.models import RunConnectorRequest from onyx.server.documents.models import SourceSummary from onyx.server.federated.models import FederatedConnectorStatus from onyx.server.models import StatusResponse from onyx.server.utils_vector_db import require_vector_db from onyx.utils.logger import setup_logger from onyx.utils.telemetry import mt_cloud_telemetry from onyx.utils.threadpool_concurrency import CallableProtocol from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() _GMAIL_CREDENTIAL_ID_COOKIE_NAME = "gmail_credential_id" _GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id" 
_INDEXING_STATUS_PAGE_SIZE = 10 SEEN_ZIP_DETAIL = "Only one zip file is allowed per file connector, \ use the ingestion APIs for multiple files" router = APIRouter(prefix="/manage", dependencies=[Depends(require_vector_db)]) """Admin only API endpoints""" @router.get("/admin/connector/gmail/app-credential") def check_google_app_gmail_credentials_exist( _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return {"client_id": get_google_app_cred(DocumentSource.GMAIL).web.client_id} except KvKeyNotFoundError: raise HTTPException(status_code=404, detail="Google App Credentials not found") @router.put("/admin/connector/gmail/app-credential") def upsert_google_app_gmail_credentials( app_credentials: GoogleAppCredentials, _: User = Depends(current_admin_user) ) -> StatusResponse: try: upsert_google_app_cred(app_credentials, DocumentSource.GMAIL) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully saved Google App Credentials" ) @router.delete("/admin/connector/gmail/app-credential") def delete_google_app_gmail_credentials( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: try: delete_google_app_cred(DocumentSource.GMAIL) cleanup_gmail_credentials(db_session=db_session) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully deleted Google App Credentials" ) @router.get("/admin/connector/google-drive/app-credential") def check_google_app_credentials_exist( _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return { "client_id": get_google_app_cred(DocumentSource.GOOGLE_DRIVE).web.client_id } except KvKeyNotFoundError: raise HTTPException(status_code=404, detail="Google App Credentials not found") @router.put("/admin/connector/google-drive/app-credential") def upsert_google_app_credentials( 
app_credentials: GoogleAppCredentials, _: User = Depends(current_admin_user) ) -> StatusResponse: try: upsert_google_app_cred(app_credentials, DocumentSource.GOOGLE_DRIVE) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully saved Google App Credentials" ) @router.delete("/admin/connector/google-drive/app-credential") def delete_google_app_credentials( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: try: delete_google_app_cred(DocumentSource.GOOGLE_DRIVE) cleanup_google_drive_credentials(db_session=db_session) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully deleted Google App Credentials" ) @router.get("/admin/connector/gmail/service-account-key") def check_google_service_gmail_account_key_exist( _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return { "service_account_email": get_service_account_key( DocumentSource.GMAIL ).client_email } except KvKeyNotFoundError: raise HTTPException( status_code=404, detail="Google Service Account Key not found" ) @router.put("/admin/connector/gmail/service-account-key") def upsert_google_service_gmail_account_key( service_account_key: GoogleServiceAccountKey, _: User = Depends(current_admin_user) ) -> StatusResponse: try: upsert_service_account_key(service_account_key, DocumentSource.GMAIL) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully saved Google Service Account Key" ) @router.delete("/admin/connector/gmail/service-account-key") def delete_google_service_gmail_account_key( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: try: delete_service_account_key(DocumentSource.GMAIL) cleanup_gmail_credentials(db_session=db_session) except 
KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully deleted Google Service Account Key" ) @router.get("/admin/connector/google-drive/service-account-key") def check_google_service_account_key_exist( _: User = Depends(current_curator_or_admin_user), ) -> dict[str, str]: try: return { "service_account_email": get_service_account_key( DocumentSource.GOOGLE_DRIVE ).client_email } except KvKeyNotFoundError: raise HTTPException( status_code=404, detail="Google Service Account Key not found" ) @router.put("/admin/connector/google-drive/service-account-key") def upsert_google_service_account_key( service_account_key: GoogleServiceAccountKey, _: User = Depends(current_admin_user) ) -> StatusResponse: try: upsert_service_account_key(service_account_key, DocumentSource.GOOGLE_DRIVE) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully saved Google Service Account Key" ) @router.delete("/admin/connector/google-drive/service-account-key") def delete_google_service_account_key( _: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: try: delete_service_account_key(DocumentSource.GOOGLE_DRIVE) cleanup_google_drive_credentials(db_session=db_session) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) return StatusResponse( success=True, message="Successfully deleted Google Service Account Key" ) @router.put("/admin/connector/google-drive/service-account-credential") def upsert_service_account_credential( service_account_credential_request: GoogleServiceAccountCredentialRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: """Special API which allows the creation of a credential for a service account. 
Combines the input with the saved service account key to create an entry in the `Credential` table.""" try: credential_base = build_service_account_creds( DocumentSource.GOOGLE_DRIVE, primary_admin_email=service_account_credential_request.google_primary_admin, name="Service Account (uploaded)", ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) # first delete all existing service account credentials delete_service_account_credentials(user, db_session, DocumentSource.GOOGLE_DRIVE) # this service account credential is not a personal credential credential = create_credential( credential_data=credential_base, user=user, db_session=db_session ) return ObjectCreationIdResponse(id=credential.id) @router.put("/admin/connector/gmail/service-account-credential") def upsert_gmail_service_account_credential( service_account_credential_request: GoogleServiceAccountCredentialRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: """Special API which allows the creation of a credential for a service account.
Combines the input with the saved service account key to create an entry in the `Credential` table.""" try: credential_base = build_service_account_creds( DocumentSource.GMAIL, primary_admin_email=service_account_credential_request.google_primary_admin, ) except KvKeyNotFoundError as e: raise HTTPException(status_code=400, detail=str(e)) # first delete all existing service account credentials delete_service_account_credentials(user, db_session, DocumentSource.GMAIL) # this service account credential is not a personal credential credential = create_credential( credential_data=credential_base, user=user, db_session=db_session ) return ObjectCreationIdResponse(id=credential.id) @router.get("/admin/connector/google-drive/check-auth/{credential_id}") def check_drive_tokens( credential_id: int, user: User = Depends(current_admin_user), db_session: Session = Depends(get_session), ) -> AuthStatus: db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session) if not db_credentials or not db_credentials.credential_json: return AuthStatus(authenticated=False) credential_json = db_credentials.credential_json.get_value(apply_mask=False) if DB_CREDENTIALS_DICT_TOKEN_KEY not in credential_json: return AuthStatus(authenticated=False) token_json_str = str(credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY]) google_drive_creds = get_google_oauth_creds( token_json_str=token_json_str, source=DocumentSource.GOOGLE_DRIVE, ) if google_drive_creds is None: return AuthStatus(authenticated=False) return AuthStatus(authenticated=True) def save_zip_metadata_to_file_store( zf: zipfile.ZipFile, file_store: FileStore ) -> str | None: """ Extract .onyx_metadata.json from zip and save to file store. Returns the file_id or None if no metadata file exists.
""" try: metadata_file_info = zf.getinfo(ONYX_METADATA_FILENAME) with zf.open(metadata_file_info, "r") as metadata_file: metadata_bytes = metadata_file.read() # Validate that it's valid JSON before saving try: json.loads(metadata_bytes) except json.JSONDecodeError as e: logger.warning(f"Unable to load {ONYX_METADATA_FILENAME}: {e}") raise HTTPException( status_code=400, detail=f"Unable to load {ONYX_METADATA_FILENAME}: {e}", ) # Save to file store file_id = file_store.save_file( content=BytesIO(metadata_bytes), display_name=ONYX_METADATA_FILENAME, file_origin=FileOrigin.CONNECTOR_METADATA, file_type="application/json", ) return file_id except KeyError: logger.info(f"No {ONYX_METADATA_FILENAME} file") return None def is_zip_file(file: UploadFile) -> bool: """ Check if the file is a zip file by content type or filename. """ return bool( ( file.content_type and file.content_type.startswith( ( "application/zip", "application/x-zip-compressed", # Sent by some Windows clients "application/x-zip", "multipart/x-zip", ) ) ) or (file.filename and file.filename.lower().endswith(".zip")) ) def upload_files( files: list[UploadFile], file_origin: FileOrigin = FileOrigin.CONNECTOR, unzip: bool = True, ) -> FileUploadResponse: # Skip hidden files and directories (any path component starting with "."), which covers macOS metadata such as .DS_Store def should_process_file(file_path: str) -> bool: normalized_path = os.path.normpath(file_path) return not any(part.startswith(".") for part in normalized_path.split(os.sep)) deduped_file_paths = [] deduped_file_names = [] zip_metadata_file_id: str | None = None try: file_store = get_default_file_store() seen_zip = False for file in files: if not file.filename: logger.warning("File has no filename, skipping") continue if is_zip_file(file): if seen_zip: raise HTTPException(status_code=400, detail=SEEN_ZIP_DETAIL) seen_zip = True # Validate the zip by opening it (catches corrupt/non-zip files) with zipfile.ZipFile(file.file, "r") as zf: if unzip: 
zip_metadata_file_id = save_zip_metadata_to_file_store( zf, file_store
) for file_info in zf.namelist(): if zf.getinfo(file_info).is_dir(): continue if not should_process_file(file_info): continue sub_file_bytes = zf.read(file_info) mime_type, __ = mimetypes.guess_type(file_info) if mime_type is None: mime_type = "application/octet-stream" file_id = file_store.save_file( content=BytesIO(sub_file_bytes), display_name=os.path.basename(file_info), file_origin=file_origin, file_type=mime_type, ) deduped_file_paths.append(file_id) deduped_file_names.append(os.path.basename(file_info)) continue # Store the zip as-is (unzip=False) file.file.seek(0) file_id = file_store.save_file( content=file.file, display_name=file.filename, file_origin=file_origin, file_type=file.content_type or "application/zip", ) deduped_file_paths.append(file_id) deduped_file_names.append(file.filename) continue # Since we can't render docx files in the UI, # we store them in the file store as plain text if file.content_type == WORD_PROCESSING_MIME_TYPE: # Lazy load to avoid importing markitdown when not needed from onyx.file_processing.extract_file_text import read_docx_file text, _ = read_docx_file(file.file, file.filename) file_id = file_store.save_file( content=BytesIO(text.encode("utf-8")), display_name=file.filename, file_origin=file_origin, file_type=PLAIN_TEXT_MIME_TYPE, ) else: file_id = file_store.save_file( content=file.file, display_name=file.filename, file_origin=file_origin, file_type=file.content_type or "text/plain", ) deduped_file_paths.append(file_id) deduped_file_names.append(file.filename) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) return FileUploadResponse( file_paths=deduped_file_paths, file_names=deduped_file_names, zip_metadata_file_id=zip_metadata_file_id, ) def _normalize_file_names_for_backwards_compatibility( file_locations: list[str], file_names: list[str] ) -> list[str]: """ Ensures file_names list is the same length as file_locations for backwards compatibility. 
In legacy data, file_names might not exist or be shorter than file_locations. If file_names is shorter, pads it with corresponding file_locations values. """ return file_names + file_locations[len(file_names) :] def _fetch_and_check_file_connector_cc_pair_permissions( connector_id: int, user: User, db_session: Session, require_editable: bool, ) -> ConnectorCredentialPair: cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id) if cc_pair is None: raise HTTPException( status_code=404, detail="No Connector-Credential Pair found for this connector", ) has_requested_access = verify_user_has_access_to_cc_pair( cc_pair_id=cc_pair.id, db_session=db_session, user=user, get_editable=require_editable, ) if has_requested_access: return cc_pair # Special case: global curators should be able to manage files # for public file connectors even when they are not the creator. if ( require_editable and user.role == UserRole.GLOBAL_CURATOR and cc_pair.access_type == AccessType.PUBLIC ): return cc_pair raise HTTPException( status_code=403, detail="Access denied. 
User cannot manage files for this connector.", ) @router.post("/admin/connector/file/upload", tags=PUBLIC_API_TAGS) def upload_files_api( files: list[UploadFile], unzip: bool = True, _: User = Depends(current_curator_or_admin_user), ) -> FileUploadResponse: return upload_files(files, FileOrigin.OTHER, unzip=unzip) @router.get("/admin/connector/{connector_id}/files", tags=PUBLIC_API_TAGS) def list_connector_files( connector_id: int, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorFilesResponse: """List all files in a file connector.""" connector = fetch_connector_by_id(connector_id, db_session) if connector is None: raise HTTPException(status_code=404, detail="Connector not found") if connector.source != DocumentSource.FILE: raise HTTPException( status_code=400, detail="This endpoint only works with file connectors" ) _ = _fetch_and_check_file_connector_cc_pair_permissions( connector_id=connector_id, user=user, db_session=db_session, require_editable=False, ) file_locations = connector.connector_specific_config.get("file_locations", []) file_names = connector.connector_specific_config.get("file_names", []) # Normalize file_names for backwards compatibility with legacy data file_names = _normalize_file_names_for_backwards_compatibility( file_locations, file_names ) file_store = get_default_file_store() files = [] for file_id, file_name in zip(file_locations, file_names): try: file_record = file_store.read_file_record(file_id) file_size = None upload_date = None if file_record: file_size = file_store.get_file_size(file_id) upload_date = ( file_record.created_at.isoformat() if file_record.created_at else None ) files.append( ConnectorFileInfo( file_id=file_id, file_name=file_name, file_size=file_size, upload_date=upload_date, ) ) except Exception as e: logger.warning(f"Error reading file record for {file_id}: {e}") # Include file with basic info even if record fetch fails files.append( ConnectorFileInfo( 
file_id=file_id, file_name=file_name, ) ) return ConnectorFilesResponse(files=files) @router.post("/admin/connector/{connector_id}/files/update", tags=PUBLIC_API_TAGS) def update_connector_files( connector_id: int, files: list[UploadFile] | None = File(None), file_ids_to_remove: str = Form("[]"), user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> FileUploadResponse: """ Update files in a connector by adding new files and/or removing existing ones. This is an atomic operation that validates, updates the connector config, and triggers indexing. """ files = files or [] connector = fetch_connector_by_id(connector_id, db_session) if connector is None: raise HTTPException(status_code=404, detail="Connector not found") if connector.source != DocumentSource.FILE: raise HTTPException( status_code=400, detail="This endpoint only works with file connectors" ) # Get the connector-credential pair for indexing/pruning triggers # and validate user permissions for file management. 
cc_pair = _fetch_and_check_file_connector_cc_pair_permissions( connector_id=connector_id, user=user, db_session=db_session, require_editable=True, ) # Parse file IDs to remove try: file_ids_list = json.loads(file_ids_to_remove) except json.JSONDecodeError: raise HTTPException(status_code=400, detail="Invalid file_ids_to_remove format") if not isinstance(file_ids_list, list): raise HTTPException( status_code=400, detail="file_ids_to_remove must be a JSON-encoded list", ) # Get current connector config current_config = connector.connector_specific_config current_file_locations = current_config.get("file_locations", []) current_file_names = current_config.get("file_names", []) current_zip_metadata_file_id = current_config.get("zip_metadata_file_id") # Load existing metadata from file store if available file_store = get_default_file_store() current_zip_metadata: dict[str, Any] = {} if current_zip_metadata_file_id: try: metadata_io = file_store.read_file( file_id=current_zip_metadata_file_id, mode="b" ) metadata_bytes = metadata_io.read() loaded_metadata = json.loads(metadata_bytes) if isinstance(loaded_metadata, list): current_zip_metadata = {d["filename"]: d for d in loaded_metadata} else: current_zip_metadata = loaded_metadata except Exception as e: logger.warning(f"Failed to load existing metadata file: {e}") raise HTTPException( status_code=500, detail="Failed to load existing connector metadata file", ) # Upload new files if any new_file_paths = [] new_file_names_list = [] new_zip_metadata_file_id: str | None = None new_zip_metadata: dict[str, Any] = {} if files and len(files) > 0: upload_response = upload_files(files, FileOrigin.CONNECTOR) new_file_paths = upload_response.file_paths new_file_names_list = upload_response.file_names new_zip_metadata_file_id = upload_response.zip_metadata_file_id # Load new metadata from file store if available if new_zip_metadata_file_id: try: metadata_io = file_store.read_file( file_id=new_zip_metadata_file_id, mode="b" ) 
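Both metadata reads in this function accept two stored shapes: a list of per-file records that each carry a `"filename"` key, or a dict already keyed by filename. The merge further down then drops entries for removed files and lets new uploads win on conflicts. A standalone sketch of that normalization and merge follows; `load_zip_metadata` and `merge_zip_metadata` are hypothetical names, and factoring the load into one helper would remove the duplicated read-and-normalize block.

```python
import json
from typing import Any


def load_zip_metadata(raw: bytes) -> dict[str, Any]:
    """Normalize stored zip metadata into a dict keyed by filename.

    Accepts either a list of per-file dicts (each with a "filename" key)
    or a dict that is already keyed by filename.
    """
    loaded = json.loads(raw)  # json.loads accepts bytes as well as str
    if isinstance(loaded, list):
        return {entry["filename"]: entry for entry in loaded}
    return loaded


def merge_zip_metadata(
    current: dict[str, Any],
    new: dict[str, Any],
    removed_names: set[str],
) -> dict[str, Any]:
    """Drop metadata for removed files, then let new uploads override."""
    merged = {k: v for k, v in current.items() if k not in removed_names}
    merged.update(new)
    return merged
```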
metadata_bytes = metadata_io.read() loaded_metadata = json.loads(metadata_bytes) if isinstance(loaded_metadata, list): new_zip_metadata = {d["filename"]: d for d in loaded_metadata} else: new_zip_metadata = loaded_metadata except Exception as e: logger.warning(f"Failed to load new metadata file: {e}") # Remove specified files files_to_remove_set = set(file_ids_list) # Normalize file_names for backwards compatibility with legacy data current_file_names = _normalize_file_names_for_backwards_compatibility( current_file_locations, current_file_names ) remaining_file_locations = [] remaining_file_names = [] removed_file_names = set() for file_id, file_name in zip(current_file_locations, current_file_names): if file_id not in files_to_remove_set: remaining_file_locations.append(file_id) remaining_file_names.append(file_name) else: removed_file_names.add(file_name) # Combine remaining files with new files final_file_locations = remaining_file_locations + new_file_paths final_file_names = remaining_file_names + new_file_names_list # Validate that at least one file remains if not final_file_locations: raise HTTPException( status_code=400, detail="Cannot remove all files from connector. 
At least one file must remain.", ) # Merge and filter metadata (remove metadata for deleted files) final_zip_metadata = { key: value for key, value in current_zip_metadata.items() if key not in removed_file_names } final_zip_metadata.update(new_zip_metadata) # Save merged metadata to file store if we have any metadata final_zip_metadata_file_id: str | None = None if final_zip_metadata: final_zip_metadata_file_id = file_store.save_file( content=BytesIO(json.dumps(final_zip_metadata).encode("utf-8")), display_name=ONYX_METADATA_FILENAME, file_origin=FileOrigin.CONNECTOR_METADATA, file_type="application/json", ) # Update connector config updated_config = { **current_config, "file_locations": final_file_locations, "file_names": final_file_names, "zip_metadata_file_id": final_zip_metadata_file_id, } # Remove old zip_metadata dict if present (backwards compatibility cleanup) updated_config.pop("zip_metadata", None) connector_base = ConnectorBase( name=connector.name, source=connector.source, input_type=connector.input_type, connector_specific_config=updated_config, refresh_freq=connector.refresh_freq, prune_freq=connector.prune_freq, indexing_start=connector.indexing_start, ) updated_connector = update_connector(connector_id, connector_base, db_session) if updated_connector is None: raise HTTPException( status_code=500, detail="Failed to update connector configuration" ) # Trigger re-indexing for new files and pruning for removed files try: tenant_id = get_current_tenant_id() # If files were added, mark for UPDATE indexing (only new docs) if new_file_paths: mark_ccpair_with_indexing_trigger( cc_pair.id, IndexingMode.UPDATE, db_session ) # Send task to check for indexing immediately client_app.send_task( OnyxCeleryTask.CHECK_FOR_INDEXING, kwargs={"tenant_id": tenant_id}, priority=OnyxCeleryPriority.HIGH, ) logger.info( f"Marked cc_pair {cc_pair.id} for UPDATE indexing (new files) for connector {connector_id}" ) # If files were removed, trigger pruning immediately if 
file_ids_list: r = get_redis_client() payload_id = try_creating_prune_generator_task( client_app, cc_pair, db_session, r, tenant_id ) if payload_id: logger.info( f"Triggered pruning for cc_pair {cc_pair.id} (removed files) for connector " f"{connector_id}, payload_id={payload_id}" ) else: logger.warning( f"Failed to trigger pruning for cc_pair {cc_pair.id} (removed files) for connector {connector_id}" ) except Exception as e: logger.error(f"Failed to trigger re-indexing after file update: {e}") return FileUploadResponse( file_paths=final_file_locations, file_names=final_file_names, zip_metadata_file_id=final_zip_metadata_file_id, ) @router.get("/admin/connector", tags=PUBLIC_API_TAGS) def get_connectors_by_credential( _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), credential: int | None = None, ) -> list[ConnectorSnapshot]: """Get a list of connectors. Allow filtering by a specific credential id.""" connectors = fetch_connectors(db_session) filtered_connectors = [] for connector in connectors: if connector.source == DocumentSource.INGESTION_API: # don't include INGESTION_API, as it's a system level # connector not manageable by the user continue if credential is not None: found = False for cc_pair in connector.credentials: if credential == cc_pair.credential_id: found = True break if not found: continue filtered_connectors.append(ConnectorSnapshot.from_connector_db_model(connector)) return filtered_connectors # Retrieves most recent failure cases for connectors that are currently failing @router.get("/admin/connector/failed-indexing-status", tags=PUBLIC_API_TAGS) def get_currently_failed_indexing_status( secondary_index: bool = False, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), get_editable: bool = Query( False, description="If true, return only connectors the user can edit" ), ) -> list[FailedConnectorIndexingStatus]: # Get the latest failed indexing attempts
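`get_currently_failed_indexing_status` keeps a failed attempt only if no successful attempt for the same cc_pair was updated later. Its list comprehension applies that rule with a nested `any()`, which costs O(failed × successful). The sketch below applies the same rule with a per-pair map of the latest success time, which is linear; `Attempt` and `still_failing` are hypothetical stand-ins for the `IndexAttempt` model's two relevant fields and the filtering step, not Onyx code.

```python
from dataclasses import dataclass
from datetime import datetime


@dataclass
class Attempt:
    # Hypothetical stand-in for the IndexAttempt fields used by the filter.
    connector_credential_pair_id: int
    time_updated: datetime


def still_failing(failed: list[Attempt], succeeded: list[Attempt]) -> list[Attempt]:
    """Keep failed attempts that have no later success for the same cc_pair."""
    # Latest success time per cc_pair, built in one pass.
    latest_success: dict[int, datetime] = {}
    for attempt in succeeded:
        pair_id = attempt.connector_credential_pair_id
        prev = latest_success.get(pair_id)
        if prev is None or attempt.time_updated > prev:
            latest_success[pair_id] = attempt.time_updated

    kept: list[Attempt] = []
    for attempt in failed:
        success_time = latest_success.get(attempt.connector_credential_pair_id)
        # Drop the failure only if a success was recorded strictly after it.
        if success_time is None or success_time <= attempt.time_updated:
            kept.append(attempt)
    return kept
```

The follow-up step that keeps only cc_pairs with a remaining failure could likewise compare against a set of `connector_credential_pair_id` values instead of scanning the attempts for every pair.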
latest_failed_indexing_attempts = get_latest_index_attempts_by_status( secondary_index=secondary_index, db_session=db_session, status=IndexingStatus.FAILED, ) # Get the latest successful indexing attempts latest_successful_indexing_attempts = get_latest_index_attempts_by_status( secondary_index=secondary_index, db_session=db_session, status=IndexingStatus.SUCCESS, ) # Get all connector credential pairs cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, get_editable=get_editable, ) # Filter out failed attempts that have a more recent successful attempt filtered_failed_attempts = [ failed_attempt for failed_attempt in latest_failed_indexing_attempts if not any( success_attempt.connector_credential_pair_id == failed_attempt.connector_credential_pair_id and success_attempt.time_updated > failed_attempt.time_updated for success_attempt in latest_successful_indexing_attempts ) ] # Filter cc_pairs to include only those with failed attempts cc_pairs = [ cc_pair for cc_pair in cc_pairs if any( attempt.connector_credential_pair == cc_pair for attempt in filtered_failed_attempts ) ] # Create a mapping of cc_pair_id to its latest failed index attempt cc_pair_to_latest_index_attempt = { attempt.connector_credential_pair_id: attempt for attempt in filtered_failed_attempts } indexing_statuses = [] for cc_pair in cc_pairs: # Skip DefaultCCPair if cc_pair.name == "DefaultCCPair": continue latest_index_attempt = cc_pair_to_latest_index_attempt.get(cc_pair.id) indexing_statuses.append( FailedConnectorIndexingStatus( cc_pair_id=cc_pair.id, name=cc_pair.name, error_msg=( latest_index_attempt.error_msg if latest_index_attempt else None ), connector_id=cc_pair.connector_id, credential_id=cc_pair.credential_id, is_deletable=check_deletion_attempt_is_allowed( connector_credential_pair=cc_pair, db_session=db_session, allow_scheduled=True, ) is None, ) ) return indexing_statuses @router.get("/admin/connector/status", tags=PUBLIC_API_TAGS) def 
get_connector_status( user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[ConnectorStatus]: # This endpoint is only used for document set and group creation/editing # Therefore, it is okay to get non-editable, but public cc_pairs cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, eager_load_connector=True, eager_load_credential=True, eager_load_user=True, get_editable=False, ) group_cc_pair_relationships = get_cc_pair_groups_for_ids( db_session=db_session, cc_pair_ids=[cc_pair.id for cc_pair in cc_pairs], ) group_cc_pair_relationships_dict: dict[int, list[int]] = {} for relationship in group_cc_pair_relationships: group_cc_pair_relationships_dict.setdefault(relationship.cc_pair_id, []).append( relationship.user_group_id ) # Pre-compute credential_ids per connector to avoid N+1 lazy loads connector_to_credential_ids: dict[int, list[int]] = {} for cc_pair in cc_pairs: connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append( cc_pair.credential_id ) return [ ConnectorStatus( cc_pair_id=cc_pair.id, name=cc_pair.name, connector=ConnectorSnapshot.from_connector_db_model( cc_pair.connector, credential_ids=connector_to_credential_ids.get( cc_pair.connector_id, [] ), ), credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential), access_type=cc_pair.access_type, groups=group_cc_pair_relationships_dict.get(cc_pair.id, []), ) for cc_pair in cc_pairs if cc_pair.name != "DefaultCCPair" and cc_pair.connector and cc_pair.credential ] @router.post("/admin/connector/indexing-status", tags=PUBLIC_API_TAGS) def get_connector_indexing_status( request: IndexingStatusRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> list[ConnectorIndexingStatusLiteResponse]: tenant_id = get_current_tenant_id() # NOTE: If the connector is deleting behind the scenes, # accessing cc_pairs can be inconsistent and members like #
connector or credential may be None. # Additional checks are done to make sure the connector and credential still exist. # TODO: make this one query ... possibly eager load or wrap in a read transaction # to avoid the complexity of trying to error check throughout the function # see https://stackoverflow.com/questions/75758327/ # sqlalchemy-method-connection-for-bind-is-already-in-progress # for why we can't pass in the current db_session to these functions if MOCK_CONNECTOR_FILE_PATH: import json with open(MOCK_CONNECTOR_FILE_PATH, "r") as f: raw_data = json.load(f) connector_indexing_statuses = [ ConnectorIndexingStatusLite(**status) for status in raw_data ] return [ ConnectorIndexingStatusLiteResponse( source=DocumentSource.FILE, summary=SourceSummary( total_connectors=100, active_connectors=100, public_connectors=100, total_docs_indexed=100000, ), current_page=1, total_pages=1, indexing_statuses=connector_indexing_statuses, ) ] parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [ # Get editable connector/credential pairs ( lambda: get_connector_credential_pairs_for_user_parallel( user, True, None, True, True, False, True, request.source ), (), ), # Get federated connectors (fetch_all_federated_connectors_parallel, ()), # Get most recent index attempts ( lambda: get_latest_index_attempts_parallel( request.secondary_index, True, False ), (), ), # Get most recent finished index attempts ( lambda: get_latest_index_attempts_parallel( request.secondary_index, True, True ), (), ), # Get most recent successful index attempts ( lambda: get_latest_successful_index_attempts_parallel( request.secondary_index, ), (), ), ] if user and user.role == UserRole.ADMIN: ( editable_cc_pairs, federated_connectors, latest_index_attempts, latest_finished_index_attempts, latest_successful_index_attempts, ) = run_functions_tuples_in_parallel(parallel_functions) non_editable_cc_pairs = [] else: parallel_functions.append( ( lambda: 
get_connector_credential_pairs_for_user_parallel( user, False, None, True, True, False, True, request.source ), (), ), ) ( editable_cc_pairs, federated_connectors, latest_index_attempts, latest_finished_index_attempts, latest_successful_index_attempts, non_editable_cc_pairs, ) = run_functions_tuples_in_parallel(parallel_functions) # Cast results to proper types non_editable_cc_pairs = cast(list[ConnectorCredentialPair], non_editable_cc_pairs) editable_cc_pairs = cast(list[ConnectorCredentialPair], editable_cc_pairs) federated_connectors = cast(list[FederatedConnector], federated_connectors) latest_index_attempts = cast(list[IndexAttempt], latest_index_attempts) latest_finished_index_attempts = cast( list[IndexAttempt], latest_finished_index_attempts ) latest_successful_index_attempts = cast( list[IndexAttempt], latest_successful_index_attempts ) document_count_info = get_document_counts_for_all_cc_pairs(db_session) # Create lookup dictionaries for efficient access cc_pair_to_document_cnt: dict[tuple[int, int], int] = { (connector_id, credential_id): cnt for connector_id, credential_id, cnt in document_count_info } def _attempt_lookup( attempts: list[IndexAttempt], ) -> dict[int, IndexAttempt]: return {attempt.connector_credential_pair_id: attempt for attempt in attempts} cc_pair_to_latest_index_attempt = _attempt_lookup(latest_index_attempts) cc_pair_to_latest_finished_index_attempt = _attempt_lookup( latest_finished_index_attempts ) cc_pair_to_latest_successful_index_attempt = _attempt_lookup( latest_successful_index_attempts ) def build_connector_indexing_status( cc_pair: ConnectorCredentialPair, is_editable: bool, ) -> ConnectorIndexingStatusLite | None: if cc_pair.name == "DefaultCCPair": return None latest_attempt = cc_pair_to_latest_index_attempt.get(cc_pair.id) latest_finished_attempt = cc_pair_to_latest_finished_index_attempt.get( cc_pair.id ) latest_successful_attempt = cc_pair_to_latest_successful_index_attempt.get( cc_pair.id ) doc_count = 
cc_pair_to_document_cnt.get( (cc_pair.connector_id, cc_pair.credential_id), 0 ) return _get_connector_indexing_status_lite( cc_pair, latest_attempt, latest_finished_attempt, ( latest_successful_attempt.time_started if latest_successful_attempt else None ), is_editable, doc_count, ) # Process editable cc_pairs editable_statuses: list[ConnectorIndexingStatusLite] = [] for cc_pair in editable_cc_pairs: status = build_connector_indexing_status(cc_pair, True) if status: editable_statuses.append(status) # Process non-editable cc_pairs non_editable_statuses: list[ConnectorIndexingStatusLite] = [] for cc_pair in non_editable_cc_pairs: status = build_connector_indexing_status(cc_pair, False) if status: non_editable_statuses.append(status) # Process federated connectors federated_statuses: list[FederatedConnectorStatus] = [] for federated_connector in federated_connectors: federated_status = FederatedConnectorStatus( id=federated_connector.id, source=federated_connector.source, name=f"{federated_connector.source.replace('_', ' ').title()}", ) federated_statuses.append(federated_status) source_to_summary: dict[DocumentSource, SourceSummary] = {} # Apply filters only if any are provided has_filters = bool( request.access_type_filters or request.last_status_filters or ( request.docs_count_operator is not None and request.docs_count_value is not None ) or request.name_filter ) if has_filters: editable_statuses = _apply_connector_status_filters( editable_statuses, request.access_type_filters, request.last_status_filters, request.docs_count_operator, request.docs_count_value, request.name_filter, ) non_editable_statuses = _apply_connector_status_filters( non_editable_statuses, request.access_type_filters, request.last_status_filters, request.docs_count_operator, request.docs_count_value, request.name_filter, ) federated_statuses = _apply_federated_connector_status_filters( federated_statuses, request.name_filter, ) # Calculate source summary for connector_status in ( 
editable_statuses + non_editable_statuses + federated_statuses ): if isinstance(connector_status, FederatedConnectorStatus): source = connector_status.source.to_non_federated_source() else: source = connector_status.source # Skip if source is None (federated connectors without mapping) if source is None: continue if source not in source_to_summary: source_to_summary[source] = SourceSummary( total_connectors=0, active_connectors=0, public_connectors=0, total_docs_indexed=0, ) source_to_summary[source].total_connectors += 1 if isinstance(connector_status, ConnectorIndexingStatusLite): if connector_status.cc_pair_status == ConnectorCredentialPairStatus.ACTIVE: source_to_summary[source].active_connectors += 1 if connector_status.access_type == AccessType.PUBLIC: source_to_summary[source].public_connectors += 1 source_to_summary[ source ].total_docs_indexed += connector_status.docs_indexed # Track admin page visit for analytics mt_cloud_telemetry( tenant_id=tenant_id, distinct_id=str(user.id), event=MilestoneRecordType.VISITED_ADMIN_PAGE, ) # Group statuses by source for pagination source_to_all_statuses: dict[ DocumentSource, list[ConnectorIndexingStatusLite | FederatedConnectorStatus] ] = {} # Group by source for connector_status in ( editable_statuses + non_editable_statuses + federated_statuses ): if isinstance(connector_status, FederatedConnectorStatus): source = connector_status.source.to_non_federated_source() else: source = connector_status.source # Skip if source is None (federated connectors without mapping) if source is None: continue if source not in source_to_all_statuses: source_to_all_statuses[source] = [] source_to_all_statuses[source].append(connector_status) # Create paginated response objects by source response_list: list[ConnectorIndexingStatusLiteResponse] = [] source_list = list(source_to_all_statuses.keys()) source_list.sort() for source in source_list: statuses = source_to_all_statuses[source] # Get current page for this source (default to page 
1, 1-indexed) current_page = request.source_to_page.get(source, 1) # Calculate start and end indices for pagination (convert to 0-indexed) start_idx = (current_page - 1) * _INDEXING_STATUS_PAGE_SIZE end_idx = start_idx + _INDEXING_STATUS_PAGE_SIZE if request.get_all_connectors: page_statuses = statuses else: # Get the page slice for this source page_statuses = statuses[start_idx:end_idx] # Create response object for this source if page_statuses: # Only include sources that have data on this page response_list.append( ConnectorIndexingStatusLiteResponse( source=source, summary=source_to_summary[source], current_page=current_page, total_pages=math.ceil(len(statuses) / _INDEXING_STATUS_PAGE_SIZE), indexing_statuses=page_statuses, ) ) return response_list def _get_connector_indexing_status_lite( cc_pair: ConnectorCredentialPair, latest_index_attempt: IndexAttempt | None, latest_finished_index_attempt: IndexAttempt | None, last_successful_index_time: datetime | None, is_editable: bool, document_cnt: int, ) -> ConnectorIndexingStatusLite | None: # TODO remove this to enable ingestion API if cc_pair.name == "DefaultCCPair": return None connector = cc_pair.connector credential = cc_pair.credential if not connector or not credential: # This may happen if background deletion is happening return None in_progress = bool( latest_index_attempt and latest_index_attempt.status == IndexingStatus.IN_PROGRESS ) return ConnectorIndexingStatusLite( cc_pair_id=cc_pair.id, name=cc_pair.name, source=cc_pair.connector.source, access_type=cc_pair.access_type, cc_pair_status=cc_pair.status, is_editable=is_editable, in_progress=in_progress, in_repeated_error_state=cc_pair.in_repeated_error_state, last_finished_status=( latest_finished_index_attempt.status if latest_finished_index_attempt else None ), last_status=latest_index_attempt.status if latest_index_attempt else None, last_success=last_successful_index_time, docs_indexed=document_cnt, latest_index_attempt_docs_indexed=( 
latest_index_attempt.total_docs_indexed if latest_index_attempt else None ), ) def _apply_connector_status_filters( statuses: list[ConnectorIndexingStatusLite], access_type_filters: list[AccessType], last_status_filters: list[IndexingStatus], docs_count_operator: DocsCountOperator | None, docs_count_value: int | None, name_filter: str | None, ) -> list[ConnectorIndexingStatusLite]: """Apply filters to a list of ConnectorIndexingStatusLite objects""" filtered_statuses: list[ConnectorIndexingStatusLite] = [] for status in statuses: # Filter by access type if access_type_filters and status.access_type not in access_type_filters: continue # Filter by last status if last_status_filters and status.last_status not in last_status_filters: continue # Filter by document count if docs_count_operator and docs_count_value is not None: if docs_count_operator == DocsCountOperator.GREATER_THAN and not ( status.docs_indexed > docs_count_value ): continue elif docs_count_operator == DocsCountOperator.LESS_THAN and not ( status.docs_indexed < docs_count_value ): continue elif ( docs_count_operator == DocsCountOperator.EQUAL_TO and status.docs_indexed != docs_count_value ): continue # Filter by name if status.name: if name_filter and name_filter.lower() not in status.name.lower(): continue else: if name_filter: continue filtered_statuses.append(status) return filtered_statuses def _apply_federated_connector_status_filters( statuses: list[FederatedConnectorStatus], name_filter: str | None, ) -> list[FederatedConnectorStatus]: filtered_statuses: list[FederatedConnectorStatus] = [] for status in statuses: if name_filter and name_filter.lower() not in status.name.lower(): continue filtered_statuses.append(status) return filtered_statuses def _validate_connector_allowed(source: DocumentSource) -> None: valid_connectors = [ x for x in ENABLED_CONNECTOR_TYPES.replace("_", "").split(",") if x ] if not valid_connectors: return for connector_type in valid_connectors: if 
source.value.lower().replace("_", "") == connector_type: return raise ValueError( "This connector type has been disabled by your system admin. Please contact them to get it enabled if you wish to use it." ) @router.post("/admin/connector", tags=PUBLIC_API_TAGS) def create_connector_from_model( connector_data: ConnectorUpdateRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ObjectCreationIdResponse: tenant_id = get_current_tenant_id() try: _validate_connector_allowed(connector_data.source) fetch_ee_implementation_or_noop( "onyx.db.user_group", "validate_object_creation_for_user", None )( db_session=db_session, user=user, target_group_ids=connector_data.groups, object_is_public=connector_data.access_type == AccessType.PUBLIC, object_is_perm_sync=connector_data.access_type == AccessType.SYNC, object_is_new=True, ) connector_base = connector_data.to_connector_base() connector_response = create_connector( db_session=db_session, connector_data=connector_base, ) mt_cloud_telemetry( tenant_id=tenant_id, distinct_id=str(user.id), event=MilestoneRecordType.CREATED_CONNECTOR, ) return connector_response except ValueError as e: logger.error(f"Error creating connector: {e}") raise HTTPException(status_code=400, detail=str(e)) @router.post("/admin/connector-with-mock-credential") def create_connector_with_mock_credential( connector_data: ConnectorUpdateRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse: tenant_id = get_current_tenant_id() fetch_ee_implementation_or_noop( "onyx.db.user_group", "validate_object_creation_for_user", None )( db_session=db_session, user=user, target_group_ids=connector_data.groups, object_is_public=connector_data.access_type == AccessType.PUBLIC, object_is_perm_sync=connector_data.access_type == AccessType.SYNC, ) try: _validate_connector_allowed(connector_data.source) connector_response = create_connector( 
db_session=db_session, connector_data=connector_data, ) mock_credential = CredentialBase( credential_json={}, admin_public=True, source=connector_data.source, ) credential = create_credential( credential_data=mock_credential, user=user, db_session=db_session, ) # Store the created connector and credential IDs connector_id = cast(int, connector_response.id) credential_id = credential.id validate_ccpair_for_user( connector_id=connector_id, credential_id=credential_id, access_type=connector_data.access_type, db_session=db_session, ) response = add_credential_to_connector( db_session=db_session, user=user, connector_id=connector_id, credential_id=credential_id, access_type=connector_data.access_type, cc_pair_name=connector_data.name, groups=connector_data.groups, ) # trigger indexing immediately client_app.send_task( OnyxCeleryTask.CHECK_FOR_INDEXING, priority=OnyxCeleryPriority.HIGH, kwargs={"tenant_id": tenant_id}, ) logger.info( f"create_connector_with_mock_credential - running check_for_indexing: cc_pair={response.data}" ) mt_cloud_telemetry( tenant_id=tenant_id, distinct_id=str(user.id), event=MilestoneRecordType.CREATED_CONNECTOR, ) return response except ConnectorValidationError as e: raise HTTPException( status_code=400, detail="Connector validation error: " + str(e) ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @router.patch("/admin/connector/{connector_id}", tags=PUBLIC_API_TAGS) def update_connector_from_model( connector_id: int, connector_data: ConnectorUpdateRequest, user: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> ConnectorSnapshot | StatusResponse[int]: cc_pair = fetch_connector_credential_pair_for_connector(db_session, connector_id) try: _validate_connector_allowed(connector_data.source) fetch_ee_implementation_or_noop( "onyx.db.user_group", "validate_object_creation_for_user", None )( db_session=db_session, user=user, target_group_ids=connector_data.groups, 
object_is_public=connector_data.access_type == AccessType.PUBLIC, object_is_perm_sync=connector_data.access_type == AccessType.SYNC, object_is_owned_by_user=cc_pair and user and cc_pair.creator_id == user.id, ) connector_base = connector_data.to_connector_base() except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) updated_connector = update_connector(connector_id, connector_base, db_session) if updated_connector is None: raise HTTPException( status_code=404, detail=f"Connector {connector_id} does not exist" ) return ConnectorSnapshot( id=updated_connector.id, name=updated_connector.name, source=updated_connector.source, input_type=updated_connector.input_type, connector_specific_config=updated_connector.connector_specific_config, refresh_freq=updated_connector.refresh_freq, prune_freq=updated_connector.prune_freq, credential_ids=[ association.credential.id for association in updated_connector.credentials ], indexing_start=updated_connector.indexing_start, time_created=updated_connector.time_created, time_updated=updated_connector.time_updated, ) @router.delete( "/admin/connector/{connector_id}", response_model=StatusResponse[int], tags=PUBLIC_API_TAGS, ) def delete_connector_by_id( connector_id: int, _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: try: with db_session.begin(): return delete_connector( db_session=db_session, connector_id=connector_id, ) except AssertionError: raise HTTPException(status_code=400, detail="Connector is not deletable") @router.post("/admin/connector/run-once", tags=PUBLIC_API_TAGS) def connector_run_once( run_info: RunConnectorRequest, _: User = Depends(current_curator_or_admin_user), db_session: Session = Depends(get_session), ) -> StatusResponse[int]: """Used to trigger indexing on a set of cc_pairs associated with a single connector.""" tenant_id = get_current_tenant_id() connector_id = run_info.connector_id specified_credential_ids = 
run_info.credential_ids try: possible_credential_ids = get_connector_credential_ids( run_info.connector_id, db_session ) except ValueError: raise HTTPException( status_code=404, detail=f"Connector by id {connector_id} does not exist.", ) if not specified_credential_ids: credential_ids = possible_credential_ids else: if set(specified_credential_ids).issubset(set(possible_credential_ids)): credential_ids = specified_credential_ids else: raise HTTPException( status_code=400, detail="Not all specified credentials are associated with connector", ) if not credential_ids: raise HTTPException( status_code=400, detail="Connector has no valid credentials, cannot create index attempts.", ) try: num_triggers = trigger_indexing_for_cc_pair( credential_ids, connector_id, run_info.from_beginning, tenant_id, db_session, ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) logger.info("connector_run_once - running check_for_indexing") msg = f"Marked {num_triggers} index attempts with indexing triggers." 
return StatusResponse( success=True, message=msg, data=num_triggers, ) """Endpoints for basic users""" @router.get("/connector/gmail/authorize/{credential_id}") def gmail_auth( response: Response, credential_id: str, _: User = Depends(current_user) ) -> AuthUrl: # set a cookie that we can read in the callback (used for `verify_csrf`) response.set_cookie( key=_GMAIL_CREDENTIAL_ID_COOKIE_NAME, value=credential_id, httponly=True, max_age=600, ) return AuthUrl(auth_url=get_auth_url(int(credential_id), DocumentSource.GMAIL)) @router.get("/connector/google-drive/authorize/{credential_id}") def google_drive_auth( response: Response, credential_id: str, _: User = Depends(current_user) ) -> AuthUrl: # set a cookie that we can read in the callback (used for `verify_csrf`) response.set_cookie( key=_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME, value=credential_id, httponly=True, max_age=600, ) return AuthUrl( auth_url=get_auth_url(int(credential_id), DocumentSource.GOOGLE_DRIVE) ) @router.get("/connector/gmail/callback") def gmail_callback( request: Request, callback: GmailCallback = Depends(), user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> StatusResponse: credential_id_cookie = request.cookies.get(_GMAIL_CREDENTIAL_ID_COOKIE_NAME) if credential_id_cookie is None or not credential_id_cookie.isdigit(): raise HTTPException( status_code=401, detail="Request did not pass CSRF verification." 
        )
    credential_id = int(credential_id_cookie)
    verify_csrf(credential_id, callback.state)

    credentials: Credentials | None = update_credential_access_tokens(
        callback.code,
        credential_id,
        user,
        db_session,
        DocumentSource.GMAIL,
        GoogleOAuthAuthenticationMethod.UPLOADED,
    )
    if credentials is None:
        raise HTTPException(
            status_code=500, detail="Unable to fetch Gmail access tokens"
        )

    return StatusResponse(success=True, message="Updated Gmail access tokens")


@router.get("/connector/google-drive/callback")
def google_drive_callback(
    request: Request,
    callback: GDriveCallback = Depends(),
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    credential_id_cookie = request.cookies.get(_GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME)
    if credential_id_cookie is None or not credential_id_cookie.isdigit():
        raise HTTPException(
            status_code=401, detail="Request did not pass CSRF verification."
        )
    credential_id = int(credential_id_cookie)
    verify_csrf(credential_id, callback.state)

    credentials: Credentials | None = update_credential_access_tokens(
        callback.code,
        credential_id,
        user,
        db_session,
        DocumentSource.GOOGLE_DRIVE,
        GoogleOAuthAuthenticationMethod.UPLOADED,
    )
    if credentials is None:
        raise HTTPException(
            status_code=500, detail="Unable to fetch Google Drive access tokens"
        )

    return StatusResponse(success=True, message="Updated Google Drive access tokens")


@router.get("/connector", tags=PUBLIC_API_TAGS)
def get_connectors(
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[ConnectorSnapshot]:
    connectors = fetch_connectors(db_session)
    return [
        ConnectorSnapshot.from_connector_db_model(connector)
        for connector in connectors
        # don't include INGESTION_API, as it's not a "real"
        # connector like those created by the user
        if connector.source != DocumentSource.INGESTION_API
    ]


@router.get("/indexed-sources", tags=PUBLIC_API_TAGS)
def get_indexed_sources(
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> IndexedSourcesResponse:
    sources = sorted(
        fetch_unique_document_sources(db_session), key=lambda source: source.value
    )
    return IndexedSourcesResponse(sources=sources)


@router.get("/connector/{connector_id}", tags=PUBLIC_API_TAGS)
def get_connector_by_id(
    connector_id: int,
    _: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ConnectorSnapshot | StatusResponse[int]:
    connector = fetch_connector_by_id(connector_id, db_session)
    if connector is None:
        raise HTTPException(
            status_code=404, detail=f"Connector {connector_id} does not exist"
        )

    return ConnectorSnapshot(
        id=connector.id,
        name=connector.name,
        source=connector.source,
        indexing_start=connector.indexing_start,
        input_type=connector.input_type,
        connector_specific_config=connector.connector_specific_config,
        refresh_freq=connector.refresh_freq,
        prune_freq=connector.prune_freq,
        credential_ids=[
            association.credential.id for association in connector.credentials
        ],
        time_created=connector.time_created,
        time_updated=connector.time_updated,
    )


@router.post("/connector-request")
def submit_connector_request(
    request_data: ConnectorRequestSubmission,
    user: User = Depends(current_user),
) -> StatusResponse:
    """
    Submit a connector request for Cloud deployments.
    Tracks via PostHog telemetry and sends email to hello@onyx.app.
    """
    tenant_id = get_current_tenant_id()

    connector_name = request_data.connector_name.strip()
    if not connector_name:
        raise HTTPException(status_code=400, detail="Connector name cannot be empty")

    user_email = user.email

    # Track connector request via PostHog telemetry (Cloud only)
    from shared_configs.configs import MULTI_TENANT

    if MULTI_TENANT:
        mt_cloud_telemetry(
            tenant_id=tenant_id,
            distinct_id=str(user.id),
            event=MilestoneRecordType.REQUESTED_CONNECTOR,
            properties={
                "connector_name": connector_name,
                "user_email": user.email,
            },
        )

    # Send email notification (if email is configured)
    if EMAIL_CONFIGURED:
        try:
            subject = "Onyx Craft Connector Request"
            email_body_text = f"""A new connector request has been submitted:

Connector Name: {connector_name}
User Email: {user_email or "Not provided (anonymous user)"}
Tenant ID: {tenant_id}
"""
            email_body_html = f"""

    A new connector request has been submitted:

    • Connector Name: {connector_name}
    • User Email: {user_email or "Not provided (anonymous user)"}
    • Tenant ID: {tenant_id}
    """
            send_email(
                user_email="hello@onyx.app",
                subject=subject,
                html_body=email_body_html,
                text_body=email_body_text,
            )
            logger.info(
                f"Connector request email sent to hello@onyx.app for connector: {connector_name}"
            )
        except Exception as e:
            # Log error but don't fail the request if email fails
            logger.error(
                f"Failed to send connector request email for {connector_name}: {e}"
            )

    logger.info(
        f"Connector request submitted: {connector_name} by user {user_email or 'anonymous'} (tenant: {tenant_id})"
    )

    return StatusResponse(
        success=True,
        message="Connector request submitted successfully. We'll prioritize popular requests!",
    )


class BasicCCPairInfo(BaseModel):
    has_successful_run: bool
    source: DocumentSource
    status: ConnectorCredentialPairStatus


@router.get("/connector-status", tags=PUBLIC_API_TAGS)
def get_basic_connector_indexing_status(
    user: User = Depends(current_chat_accessible_user),
    db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
    cc_pairs = get_connector_credential_pairs_for_user(
        db_session=db_session,
        eager_load_connector=True,
        get_editable=False,
        user=user,
    )

    # NOTE: This endpoint excludes Craft connectors
    return [
        BasicCCPairInfo(
            has_successful_run=cc_pair.last_successful_index_time is not None,
            source=cc_pair.connector.source,
            status=cc_pair.status,
        )
        for cc_pair in cc_pairs
        if cc_pair.connector.source != DocumentSource.INGESTION_API
        and cc_pair.processing_mode == ProcessingMode.REGULAR
    ]


def trigger_indexing_for_cc_pair(
    specified_credential_ids: list[int],
    connector_id: int,
    from_beginning: bool,
    tenant_id: str,
    db_session: Session,
) -> int:
    try:
        possible_credential_ids = get_connector_credential_ids(connector_id, db_session)
    except ValueError as e:
        raise ValueError(f"Connector by id {connector_id} does not exist: {str(e)}")

    if not specified_credential_ids:
        credential_ids = possible_credential_ids
    else:
        if set(specified_credential_ids).issubset(set(possible_credential_ids)):
            credential_ids = specified_credential_ids
        else:
            raise ValueError(
                "Not all specified credentials are associated with connector"
            )

    if not credential_ids:
        raise ValueError(
            "Connector has no valid credentials, cannot create index attempts."
        )

    # Prevents index attempts for cc pairs that already have an index attempt currently running
    skipped_credentials = [
        credential_id
        for credential_id in credential_ids
        if get_index_attempts_for_cc_pair(
            cc_pair_identifier=ConnectorCredentialPairIdentifier(
                connector_id=connector_id,
                credential_id=credential_id,
            ),
            only_current=True,
            db_session=db_session,
            disinclude_finished=True,
        )
    ]

    connector_credential_pairs = [
        get_connector_credential_pair(
            db_session=db_session,
            connector_id=connector_id,
            credential_id=credential_id,
        )
        for credential_id in credential_ids
        if credential_id not in skipped_credentials
    ]

    num_triggers = 0
    for cc_pair in connector_credential_pairs:
        if cc_pair is not None:
            indexing_mode = IndexingMode.UPDATE
            if from_beginning:
                indexing_mode = IndexingMode.REINDEX

            mark_ccpair_with_indexing_trigger(cc_pair.id, indexing_mode, db_session)
            num_triggers += 1

            logger.info(
                f"connector_run_once - marking cc_pair with indexing trigger: "
                f"connector={connector_id} "
                f"cc_pair={cc_pair.id} "
                f"indexing_trigger={indexing_mode}"
            )

    priority = OnyxCeleryPriority.HIGH

    # run the beat task to pick up the triggers immediately
    logger.info(f"Sending indexing check task with priority {priority}")
    client_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=priority,
        kwargs={"tenant_id": tenant_id},
    )

    return num_triggers


================================================
FILE: backend/onyx/server/documents/credential.py
================================================
import json

from fastapi import APIRouter
from fastapi import Depends
from fastapi import File
from fastapi import Form
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from sqlalchemy.orm import Session

from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.connectors.factory import validate_ccpair_for_user
from onyx.db.credentials import alter_credential
from onyx.db.credentials import cleanup_gmail_credentials
from onyx.db.credentials import create_credential
from onyx.db.credentials import CREDENTIAL_PERMISSIONS_TO_IGNORE
from onyx.db.credentials import delete_credential
from onyx.db.credentials import delete_credential_for_user
from onyx.db.credentials import fetch_credential_by_id_for_user
from onyx.db.credentials import fetch_credentials_by_source_for_user
from onyx.db.credentials import fetch_credentials_for_user
from onyx.db.credentials import swap_credentials_connector
from onyx.db.credentials import update_credential
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import DocumentSource
from onyx.db.models import User
from onyx.server.documents.models import CredentialBase
from onyx.server.documents.models import CredentialDataUpdateRequest
from onyx.server.documents.models import CredentialSnapshot
from onyx.server.documents.models import CredentialSwapRequest
from onyx.server.documents.models import ObjectCreationIdResponse
from onyx.server.documents.private_key_types import FILE_TYPE_TO_FILE_PROCESSOR
from onyx.server.documents.private_key_types import PrivateKeyFileTypes
from onyx.server.documents.private_key_types import ProcessPrivateKeyFileProtocol
from onyx.server.models import StatusResponse
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()


router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)


def _ignore_credential_permissions(source: DocumentSource) -> bool:
    return source in CREDENTIAL_PERMISSIONS_TO_IGNORE


"""Admin-only endpoints"""


@router.get("/admin/credential")
def list_credentials_admin(
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    """Lists all public credentials"""
    credentials = fetch_credentials_for_user(
        db_session=db_session,
        user=user,
        get_editable=False,
    )
    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.get("/admin/similar-credentials/{source_type}")
def get_cc_source_full_info(
    source_type: DocumentSource,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
    get_editable: bool = Query(
        False, description="If true, return editable credentials"
    ),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials_by_source_for_user(
        db_session=db_session,
        user=user,
        document_source=source_type,
        get_editable=get_editable,
    )

    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.delete("/admin/credential/{credential_id}")
def delete_credential_by_id_admin(
    credential_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    """Same as the user endpoint, but can delete any credential (not just the user's own)"""
    delete_credential(db_session=db_session, credential_id=credential_id)
    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )


@router.put("/admin/credential/swap")
def swap_credentials_for_connector(
    credential_swap_req: CredentialSwapRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    validate_ccpair_for_user(
        credential_swap_req.connector_id,
        credential_swap_req.new_credential_id,
        credential_swap_req.access_type,
        db_session,
    )

    connector_credential_pair = swap_credentials_connector(
        new_credential_id=credential_swap_req.new_credential_id,
        connector_id=credential_swap_req.connector_id,
        db_session=db_session,
        user=user,
    )
    return StatusResponse(
        success=True,
        message="Credential swapped successfully",
        data=connector_credential_pair.id,
    )


@router.post("/credential")
def create_credential_from_model(
    credential_info: CredentialBase,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
    if not _ignore_credential_permissions(credential_info.source):
        fetch_ee_implementation_or_noop(
            "onyx.db.user_group", "validate_object_creation_for_user", None
        )(
            db_session=db_session,
            user=user,
            target_group_ids=credential_info.groups,
            object_is_public=credential_info.curator_public,
        )

    # Temporary fix for empty Google App credentials
    if credential_info.source == DocumentSource.GMAIL:
        cleanup_gmail_credentials(db_session=db_session)

    credential = create_credential(credential_info, user, db_session)
    return ObjectCreationIdResponse(
        id=credential.id,
        credential=CredentialSnapshot.from_credential_db_model(credential),
    )


@router.post("/credential/private-key")
def create_credential_with_private_key(
    credential_json: str = Form(...),
    admin_public: bool = Form(False),
    curator_public: bool = Form(False),
    groups: list[int] = Form([]),
    name: str | None = Form(None),
    source: str = Form(...),
    user: User = Depends(current_curator_or_admin_user),
    uploaded_file: UploadFile = File(...),
    field_key: str = Form(...),
    type_definition_key: str = Form(...),
    db_session: Session = Depends(get_session),
) -> ObjectCreationIdResponse:
    try:
        credential_data = json.loads(credential_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON in credential_json: {str(e)}",
        )

    private_key_processor: ProcessPrivateKeyFileProtocol | None = (
        FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
    )
    if private_key_processor is None:
        raise HTTPException(
            status_code=400,
            detail="Invalid type definition key for private key file",
        )
    private_key_content: str = private_key_processor(uploaded_file)

    credential_data[field_key] = private_key_content

    credential_info = CredentialBase(
        credential_json=credential_data,
        admin_public=admin_public,
        curator_public=curator_public,
        groups=groups,
        name=name,
        source=DocumentSource(source),
    )

    if not _ignore_credential_permissions(DocumentSource(source)):
        fetch_ee_implementation_or_noop(
            "onyx.db.user_group", "validate_object_creation_for_user", None
        )(
            db_session=db_session,
            user=user,
            target_group_ids=groups,
            object_is_public=curator_public,
        )

    # Temporary fix for empty Google App credentials
    if DocumentSource(source) == DocumentSource.GMAIL:
        cleanup_gmail_credentials(db_session=db_session)

    credential = create_credential(credential_info, user, db_session)
    return ObjectCreationIdResponse(
        id=credential.id,
        credential=CredentialSnapshot.from_credential_db_model(credential),
    )


"""Endpoints for all"""


@router.get("/credential")
def list_credentials(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[CredentialSnapshot]:
    credentials = fetch_credentials_for_user(db_session=db_session, user=user)
    return [
        CredentialSnapshot.from_credential_db_model(credential)
        for credential in credentials
    ]


@router.get("/credential/{credential_id}")
def get_credential_by_id(
    credential_id: int,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialSnapshot | StatusResponse[int]:
    credential = fetch_credential_by_id_for_user(
        credential_id,
        user,
        db_session,
        get_editable=False,
    )
    if credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    return CredentialSnapshot.from_credential_db_model(credential)


@router.put("/admin/credential/{credential_id}")
def update_credential_data(
    credential_id: int,
    credential_update: CredentialDataUpdateRequest,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialBase:
    credential = alter_credential(
        credential_id,
        credential_update.name,
        credential_update.credential_json,
        user,
        db_session,
    )

    if credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    return CredentialSnapshot.from_credential_db_model(credential)


@router.put("/admin/credential/private-key/{credential_id}")
def update_credential_private_key(
    credential_id: int,
    name: str = Form(...),
    credential_json: str = Form(...),
    uploaded_file: UploadFile = File(...),
    field_key: str = Form(...),
    type_definition_key: str = Form(...),
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialBase:
    try:
        credential_data = json.loads(credential_json)
    except json.JSONDecodeError as e:
        raise HTTPException(
            status_code=400,
            detail=f"Invalid JSON in credential_json: {str(e)}",
        )

    private_key_processor: ProcessPrivateKeyFileProtocol | None = (
        FILE_TYPE_TO_FILE_PROCESSOR.get(PrivateKeyFileTypes(type_definition_key))
    )
    if private_key_processor is None:
        raise HTTPException(
            status_code=400,
            detail="Invalid type definition key for private key file",
        )
    private_key_content: str = private_key_processor(uploaded_file)
    credential_data[field_key] = private_key_content

    credential = alter_credential(
        credential_id,
        name,
        credential_data,
        user,
        db_session,
    )

    if credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    return CredentialSnapshot.from_credential_db_model(credential)


@router.patch("/credential/{credential_id}")
def update_credential_from_model(
    credential_id: int,
    credential_data: CredentialBase,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> CredentialSnapshot | StatusResponse[int]:
    updated_credential = update_credential(
        credential_id, credential_data, user, db_session
    )
    if updated_credential is None:
        raise HTTPException(
            status_code=401,
            detail=f"Credential {credential_id} does not exist or does not belong to user",
        )

    # Get credential_json value - use masking for API responses
    credential_json_value = (
        updated_credential.credential_json.get_value(apply_mask=True)
        if updated_credential.credential_json
        else {}
    )

    return CredentialSnapshot(
        source=updated_credential.source,
        id=updated_credential.id,
        credential_json=credential_json_value,
        user_id=updated_credential.user_id,
        name=updated_credential.name,
        admin_public=updated_credential.admin_public,
        time_created=updated_credential.time_created,
        time_updated=updated_credential.time_updated,
        curator_public=updated_credential.curator_public,
    )


@router.delete("/credential/{credential_id}")
def delete_credential_by_id(
    credential_id: int,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    delete_credential_for_user(
        credential_id,
        user,
        db_session,
    )

    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )


@router.delete("/credential/force/{credential_id}")
def force_delete_credential_by_id(
    credential_id: int,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StatusResponse:
    delete_credential_for_user(credential_id, user, db_session, True)

    return StatusResponse(
        success=True, message="Credential deleted successfully", data=credential_id
    )


================================================
FILE: backend/onyx/server/documents/document.py
================================================
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.context.search.models import IndexFilters
from onyx.context.search.preprocessing.access_filters import (
    build_access_filters_for_user,
)
from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.prompts.prompt_utils import build_doc_context_str
from onyx.server.documents.models import ChunkInfo
from onyx.server.documents.models import DocumentInfo
from onyx.server.utils_vector_db import require_vector_db


router = APIRouter(prefix="/document")


# Use a query parameter because FastAPI would interpret URL-style document_ids
# (which contain slashes) as a different path
@router.get("/document-size-info", dependencies=[Depends(require_vector_db)])
def get_document_info(
    document_id: str = Query(...),
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DocumentInfo:
    search_settings = get_current_search_settings(db_session)
    # This flow is for search so we do not get all indices.
    document_index = get_default_document_index(search_settings, None, db_session)

    user_acl_filters = build_access_filters_for_user(user, db_session)
    inference_chunks = document_index.id_based_retrieval(
        chunk_requests=[VespaChunkRequest(document_id=document_id)],
        filters=IndexFilters(access_control_list=user_acl_filters),
    )

    if not inference_chunks:
        raise HTTPException(status_code=404, detail="Document not found")

    contents = [chunk.content for chunk in inference_chunks]

    combined_contents = "\n".join(contents)

    # get actual document context used for LLM
    first_chunk = inference_chunks[0]
    tokenizer_encode = get_tokenizer(
        provider_type=search_settings.provider_type,
        model_name=search_settings.model_name,
    ).encode
    full_context_str = build_doc_context_str(
        semantic_identifier=first_chunk.semantic_identifier,
        source_type=first_chunk.source_type,
        content=combined_contents,
        metadata_dict=first_chunk.metadata,
        updated_at=first_chunk.updated_at,
        ind=0,
    )

    return DocumentInfo(
        num_chunks=len(inference_chunks),
        num_tokens=len(tokenizer_encode(full_context_str)),
    )


@router.get("/chunk-info", dependencies=[Depends(require_vector_db)])
def get_chunk_info(
    document_id: str = Query(...),
    chunk_id: int = Query(...),
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ChunkInfo:
    search_settings = get_current_search_settings(db_session)
    # This flow is for search so we do not get all indices.
    document_index = get_default_document_index(search_settings, None, db_session)

    user_acl_filters = build_access_filters_for_user(user, db_session)
    chunk_request = VespaChunkRequest(
        document_id=document_id,
        min_chunk_ind=chunk_id,
        max_chunk_ind=chunk_id,
    )

    inference_chunks = document_index.id_based_retrieval(
        chunk_requests=[chunk_request],
        filters=IndexFilters(access_control_list=user_acl_filters),
        batch_retrieval=True,
    )

    if not inference_chunks:
        raise HTTPException(status_code=404, detail="Chunk not found")

    chunk_content = inference_chunks[0].content

    tokenizer_encode = get_tokenizer(
        provider_type=search_settings.provider_type,
        model_name=search_settings.model_name,
    ).encode

    return ChunkInfo(
        content=chunk_content, num_tokens=len(tokenizer_encode(chunk_content))
    )


================================================
FILE: backend/onyx/server/documents/document_utils.py
================================================
from cryptography.hazmat.primitives.serialization import pkcs12

from onyx.utils.logger import setup_logger

logger = setup_logger()


def _is_password_related_error(error: Exception) -> bool:
    """
    Check if the exception indicates a password-related issue rather than a format issue.
    """
    error_msg = str(error).lower()
    password_keywords = ["mac", "integrity", "password", "authentication", "verify"]
    return any(keyword in error_msg for keyword in password_keywords)


def validate_pkcs12_content(file_bytes: bytes) -> bool:
    """
    Validate that the file content is actually a PKCS#12 file.
    This performs basic format validation without requiring passwords.
    """
    try:
        # Basic file size check
        if len(file_bytes) < 10:
            logger.debug("File too small to be a valid PKCS#12 file")
            return False

        # Check for PKCS#12 magic bytes/ASN.1 structure
        # PKCS#12 files start with ASN.1 SEQUENCE tag (0x30)
        if file_bytes[0] != 0x30:
            logger.debug("File does not start with ASN.1 SEQUENCE tag")
            return False

        # Try to parse the outer ASN.1 structure without password validation
        # This checks if the file has the basic PKCS#12 structure
        try:
            # Attempt to load just to validate the basic format
            # We expect this to fail due to password, but it should fail with a specific error
            pkcs12.load_key_and_certificates(file_bytes, password=None)
            return True
        except ValueError as e:
            # Check if the error is related to password (expected) vs format issues
            if _is_password_related_error(e):
                # These errors indicate the file format is correct but password is wrong/missing
                logger.debug(
                    f"PKCS#12 format appears valid, password-related error: {e}"
                )
                return True
            else:
                # Other ValueError likely indicates format issues
                logger.debug(f"PKCS#12 format validation failed: {e}")
                return False
        except Exception as e:
            # Try with empty password as fallback
            try:
                pkcs12.load_key_and_certificates(file_bytes, password=b"")
                return True
            except ValueError as e2:
                if _is_password_related_error(e2):
                    logger.debug(
                        f"PKCS#12 format appears valid with empty password attempt: {e2}"
                    )
                    return True
                else:
                    logger.debug(
                        f"PKCS#12 validation failed on both attempts: {e}, {e2}"
                    )
                    return False
            except Exception:
                logger.debug(f"PKCS#12 validation failed: {e}")
                return False

    except Exception as e:
        logger.debug(f"Unexpected error during PKCS#12 validation: {e}")
        return False


================================================
FILE: backend/onyx/server/documents/models.py
================================================
from collections.abc import Sequence
from datetime import datetime
from datetime import timezone
from datetime import UTC
from enum import Enum
from typing import Any
from typing import Generic
from typing import TypeVar
from uuid import UUID

from pydantic import BaseModel
from pydantic import ConfigDict
from pydantic import Field

from onyx.configs.app_configs import MASK_CREDENTIAL_PREFIX
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import PermissionSyncStatus
from onyx.db.enums import ProcessingMode
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import DocPermissionSyncAttempt
from onyx.db.models import Document as DbDocument
from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.federated.models import FederatedConnectorStatus
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop


class DocumentSyncStatus(BaseModel):
    doc_id: str
    last_synced: datetime | None
    last_modified: datetime | None

    @classmethod
    def from_model(cls, doc: DbDocument) -> "DocumentSyncStatus":
        return DocumentSyncStatus(
            doc_id=doc.id,
            last_synced=doc.last_synced,
            last_modified=doc.last_modified,
        )


class DocumentInfo(BaseModel):
    num_chunks: int
    num_tokens: int


class ChunkInfo(BaseModel):
    content: str
    num_tokens: int


class IndexedSourcesResponse(BaseModel):
    model_config = ConfigDict(use_enum_values=True)

    sources: list[DocumentSource]


class DeletionAttemptSnapshot(BaseModel):
    connector_id: int
    credential_id: int
    status: TaskStatus


class ConnectorBase(BaseModel):
    name: str
    source: DocumentSource
    input_type: InputType
    connector_specific_config: dict[str, Any]
    # In seconds; None for a one-time index with no refresh
    refresh_freq: int | None = None
    prune_freq: int | None = None
    indexing_start: datetime | None = None


class ConnectorUpdateRequest(ConnectorBase):
    access_type: AccessType
    groups: list[int] = Field(default_factory=list)

    def to_connector_base(self) -> ConnectorBase:
        return ConnectorBase(**self.model_dump(exclude={"access_type", "groups"}))


class ConnectorSnapshot(ConnectorBase):
    id: int
    credential_ids: list[int]
    time_created: datetime
    time_updated: datetime
    source: DocumentSource

    @classmethod
    def from_connector_db_model(
        cls, connector: Connector, credential_ids: list[int] | None = None
    ) -> "ConnectorSnapshot":
        return ConnectorSnapshot(
            id=connector.id,
            name=connector.name,
            source=connector.source,
            input_type=connector.input_type,
            connector_specific_config=connector.connector_specific_config,
            refresh_freq=connector.refresh_freq,
            prune_freq=connector.prune_freq,
            credential_ids=(
                credential_ids
                or [association.credential.id for association in connector.credentials]
            ),
            indexing_start=connector.indexing_start,
            time_created=connector.time_created,
            time_updated=connector.time_updated,
        )


class CredentialSwapRequest(BaseModel):
    new_credential_id: int
    connector_id: int
    access_type: AccessType


class CredentialDataUpdateRequest(BaseModel):
    name: str
    credential_json: dict[str, Any]


class CredentialBase(BaseModel):
    credential_json: dict[str, Any]
    # if `true`, then all Admins will have access to the credential
    admin_public: bool
    source: DocumentSource
    name: str | None = None
    curator_public: bool = False
    groups: list[int] = Field(default_factory=list)


class CredentialSnapshot(CredentialBase):
    id: int
    user_id: UUID | None
    user_email: str | None = None
    time_created: datetime
    time_updated: datetime

    @classmethod
    def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
        # Get the credential_json value with appropriate masking
        if credential.credential_json is None:
            credential_json_value: dict[str, Any] = {}
        elif MASK_CREDENTIAL_PREFIX:
            credential_json_value = credential.credential_json.get_value(
                apply_mask=True
            )
        else:
            credential_json_value = credential.credential_json.get_value(
                apply_mask=False
            )

        return CredentialSnapshot(
            id=credential.id,
            credential_json=credential_json_value,
            user_id=credential.user_id,
            user_email=credential.user.email if credential.user else None,
            admin_public=credential.admin_public,
            time_created=credential.time_created,
            time_updated=credential.time_updated,
            source=credential.source or DocumentSource.NOT_APPLICABLE,
            name=credential.name,
            curator_public=credential.curator_public,
        )


class IndexAttemptSnapshot(BaseModel):
    id: int
    status: IndexingStatus | None
    from_beginning: bool
    new_docs_indexed: int  # only includes completely new docs
    total_docs_indexed: int  # includes docs that are updated
    docs_removed_from_index: int
    error_msg: str | None
    error_count: int
    full_exception_trace: str | None
    time_started: str | None
    time_updated: str
    poll_range_start: datetime | None = None
    poll_range_end: datetime | None = None

    @classmethod
    def from_index_attempt_db_model(
        cls, index_attempt: IndexAttempt
    ) -> "IndexAttemptSnapshot":
        return IndexAttemptSnapshot(
            id=index_attempt.id,
            status=index_attempt.status,
            from_beginning=index_attempt.from_beginning,
            new_docs_indexed=index_attempt.new_docs_indexed or 0,
            total_docs_indexed=index_attempt.total_docs_indexed or 0,
            docs_removed_from_index=index_attempt.docs_removed_from_index or 0,
            error_msg=index_attempt.error_msg,
            error_count=len(index_attempt.error_rows),
            full_exception_trace=index_attempt.full_exception_trace,
            time_started=(
                index_attempt.time_started.isoformat()
                if index_attempt.time_started
                else None
            ),
            time_updated=index_attempt.time_updated.isoformat(),
            poll_range_start=index_attempt.poll_range_start,
            poll_range_end=index_attempt.poll_range_end,
        )


# These are the types currently supported by the pagination hook.
# More API endpoints can be refactored and added here for use with the pagination hook.
PaginatedType = TypeVar("PaginatedType", bound=BaseModel)


class PermissionSyncAttemptSnapshot(BaseModel):
    id: int
    status: PermissionSyncStatus
    error_message: str | None
    total_docs_synced: int
    docs_with_permission_errors: int
    time_created: str
    time_started: str | None
    time_finished: str | None

    @classmethod
    def from_permission_sync_attempt_db_model(
        cls, attempt: DocPermissionSyncAttempt
    ) -> "PermissionSyncAttemptSnapshot":
        return PermissionSyncAttemptSnapshot(
            id=attempt.id,
            status=attempt.status,
            error_message=attempt.error_message,
            total_docs_synced=attempt.total_docs_synced or 0,
            docs_with_permission_errors=attempt.docs_with_permission_errors or 0,
            time_created=attempt.time_created.isoformat(),
            time_started=(
                attempt.time_started.isoformat() if attempt.time_started else None
            ),
            time_finished=(
                attempt.time_finished.isoformat() if attempt.time_finished else None
            ),
        )


class PaginatedReturn(BaseModel, Generic[PaginatedType]):
    items: list[PaginatedType]
    total_items: int


class CCPairFullInfo(BaseModel):
    id: int
    name: str
    status: ConnectorCredentialPairStatus
    in_repeated_error_state: bool
    num_docs_indexed: int
    connector: ConnectorSnapshot
    credential: CredentialSnapshot
    number_of_index_attempts: int
    last_index_attempt_status: IndexingStatus | None
    latest_deletion_attempt: DeletionAttemptSnapshot | None
    access_type: AccessType
    is_editable_for_current_user: bool
    deletion_failure_message: str | None
    indexing: bool
    creator: UUID | None
    creator_email: str | None

    # information on syncing/indexing
    last_indexed: datetime | None
    last_pruned: datetime | None
    # accounts for both doc sync and group sync
    last_full_permission_sync: datetime | None
    overall_indexing_speed: float | None
    latest_checkpoint_description: str | None

    # permission sync attempt status
    last_permission_sync_attempt_status: PermissionSyncStatus | None
    permission_syncing: bool
    last_permission_sync_attempt_finished: datetime | None
    last_permission_sync_attempt_error_message: str | None

    @classmethod
    def _get_last_full_permission_sync(
        cls, cc_pair_model: ConnectorCredentialPair
    ) -> datetime | None:
        check_if_source_requires_external_group_sync = fetch_ee_implementation_or_noop(
            "onyx.external_permissions.sync_params",
            "source_requires_external_group_sync",
            noop_return_value=False,
        )
        check_if_source_requires_doc_sync = fetch_ee_implementation_or_noop(
            "onyx.external_permissions.sync_params",
            "source_requires_doc_sync",
            noop_return_value=False,
        )

        needs_group_sync = check_if_source_requires_external_group_sync(
            cc_pair_model.connector.source
        )
        needs_doc_sync = check_if_source_requires_doc_sync(
            cc_pair_model.connector.source
        )

        last_group_sync = (
            cc_pair_model.last_time_external_group_sync
            if needs_group_sync
            else datetime.now(UTC)
        )
        last_doc_sync = (
            cc_pair_model.last_time_perm_sync if needs_doc_sync else datetime.now(UTC)
        )

        # if either is still None at this point, it means sync is necessary but
        # has never completed.
        if last_group_sync is None or last_doc_sync is None:
            return None

        return min(last_group_sync, last_doc_sync)

    @classmethod
    def from_models(
        cls,
        cc_pair_model: ConnectorCredentialPair,
        latest_deletion_attempt: DeletionAttemptSnapshot | None,
        number_of_index_attempts: int,
        last_index_attempt: IndexAttempt | None,
        num_docs_indexed: int,  # not ideal, but this must be computed separately
        is_editable_for_current_user: bool,
        indexing: bool,
        last_successful_index_time: datetime | None = None,
        last_permission_sync_attempt_status: PermissionSyncStatus | None = None,
        permission_syncing: bool = False,
        last_permission_sync_attempt_finished: datetime | None = None,
        last_permission_sync_attempt_error_message: str | None = None,
    ) -> "CCPairFullInfo":
        # figure out if we need to artificially deflate the number of docs indexed.
        # This is required since the total number of docs indexed by a CC Pair is
        # updated before the new docs for an indexing attempt. If we don't do this,
        # there is a mismatch between these two numbers which may confuse users.
last_indexing_status = last_index_attempt.status if last_index_attempt else None if ( # only need to do this if the last indexing attempt is still in progress last_indexing_status == IndexingStatus.IN_PROGRESS and number_of_index_attempts == 1 and last_index_attempt and last_index_attempt.new_docs_indexed ): num_docs_indexed = ( last_index_attempt.new_docs_indexed if last_index_attempt else 0 ) overall_indexing_speed = num_docs_indexed / ( ( datetime.now(tz=timezone.utc) - cc_pair_model.connector.time_created ).total_seconds() / 60 ) return cls( id=cc_pair_model.id, name=cc_pair_model.name, status=cc_pair_model.status, in_repeated_error_state=cc_pair_model.in_repeated_error_state, num_docs_indexed=num_docs_indexed, connector=ConnectorSnapshot.from_connector_db_model( cc_pair_model.connector, credential_ids=[cc_pair_model.credential_id], ), credential=CredentialSnapshot.from_credential_db_model( cc_pair_model.credential ), number_of_index_attempts=number_of_index_attempts, last_index_attempt_status=last_indexing_status, latest_deletion_attempt=latest_deletion_attempt, access_type=cc_pair_model.access_type, is_editable_for_current_user=is_editable_for_current_user, deletion_failure_message=cc_pair_model.deletion_failure_message, indexing=indexing, creator=cc_pair_model.creator_id, creator_email=( cc_pair_model.creator.email if cc_pair_model.creator else None ), last_indexed=last_successful_index_time, last_pruned=cc_pair_model.last_pruned, last_full_permission_sync=cls._get_last_full_permission_sync(cc_pair_model), overall_indexing_speed=overall_indexing_speed, latest_checkpoint_description=None, last_permission_sync_attempt_status=last_permission_sync_attempt_status, permission_syncing=permission_syncing, last_permission_sync_attempt_finished=last_permission_sync_attempt_finished, last_permission_sync_attempt_error_message=last_permission_sync_attempt_error_message, ) class CeleryTaskStatus(BaseModel): id: str name: str status: TaskStatus start_time: datetime | None 
register_time: datetime | None class FailedConnectorIndexingStatus(BaseModel): """Simplified version of ConnectorIndexingStatus for failed indexing attempts""" cc_pair_id: int name: str error_msg: str | None is_deletable: bool connector_id: int credential_id: int class ConnectorStatus(BaseModel): """ Represents the status of a connector, including indexing status related information """ cc_pair_id: int name: str connector: ConnectorSnapshot credential: CredentialSnapshot access_type: AccessType groups: list[int] class ConnectorIndexingStatus(ConnectorStatus): """Represents the full indexing status of a connector""" cc_pair_status: ConnectorCredentialPairStatus # this is separate from `cc_pair_status` above, since a connector can be `INITIAL_INDEXING`, `ACTIVE`, # or `PAUSED` and still be in a repeated error state. in_repeated_error_state: bool owner: str last_finished_status: IndexingStatus | None last_status: IndexingStatus | None last_success: datetime | None latest_index_attempt: IndexAttemptSnapshot | None docs_indexed: int in_progress: bool class DocsCountOperator(str, Enum): GREATER_THAN = ">" LESS_THAN = "<" EQUAL_TO = "=" class ConnectorIndexingStatusLite(BaseModel): cc_pair_id: int name: str source: DocumentSource access_type: AccessType cc_pair_status: ConnectorCredentialPairStatus in_progress: bool in_repeated_error_state: bool last_finished_status: IndexingStatus | None last_status: IndexingStatus | None last_success: datetime | None is_editable: bool docs_indexed: int latest_index_attempt_docs_indexed: int | None class SourceSummary(BaseModel): total_connectors: int active_connectors: int public_connectors: int total_docs_indexed: int class ConnectorIndexingStatusLiteResponse(BaseModel): source: DocumentSource summary: SourceSummary current_page: int total_pages: int indexing_statuses: Sequence[ConnectorIndexingStatusLite | FederatedConnectorStatus] class ConnectorCredentialPairIdentifier(BaseModel): connector_id: int credential_id: int class 
ConnectorCredentialPairMetadata(BaseModel): name: str access_type: AccessType auto_sync_options: dict[str, Any] | None = None groups: list[int] = Field(default_factory=list) processing_mode: ProcessingMode = ProcessingMode.REGULAR class CCStatusUpdateRequest(BaseModel): status: ConnectorCredentialPairStatus class ConnectorCredentialPairDescriptor(BaseModel): id: int name: str connector: ConnectorSnapshot credential: CredentialSnapshot access_type: AccessType class CCPairSummary(BaseModel): """Simplified connector-credential pair information with just essential data""" id: int name: str source: DocumentSource access_type: AccessType @classmethod def from_cc_pair_descriptor( cls, descriptor: ConnectorCredentialPairDescriptor ) -> "CCPairSummary": return cls( id=descriptor.id, name=descriptor.name, source=descriptor.connector.source, access_type=descriptor.access_type, ) class RunConnectorRequest(BaseModel): connector_id: int credential_ids: list[int] | None = None from_beginning: bool = False class ConnectorRequestSubmission(BaseModel): connector_name: str class CCPropertyUpdateRequest(BaseModel): name: str value: str """Connectors Models""" class GoogleAppWebCredentials(BaseModel): client_id: str project_id: str auth_uri: str token_uri: str auth_provider_x509_cert_url: str client_secret: str redirect_uris: list[str] javascript_origins: list[str] class GoogleAppCredentials(BaseModel): web: GoogleAppWebCredentials class GoogleServiceAccountKey(BaseModel): type: str project_id: str private_key_id: str private_key: str client_email: str client_id: str auth_uri: str token_uri: str auth_provider_x509_cert_url: str client_x509_cert_url: str universe_domain: str class GoogleServiceAccountCredentialRequest(BaseModel): google_primary_admin: str | None = None # email of user to impersonate class FileUploadResponse(BaseModel): file_paths: list[str] file_names: list[str] zip_metadata_file_id: str | None # File ID pointing to metadata in file store class 
ConnectorFileInfo(BaseModel): file_id: str file_name: str file_size: int | None = None upload_date: str | None = None class ConnectorFilesResponse(BaseModel): files: list[ConnectorFileInfo] class ObjectCreationIdResponse(BaseModel): id: int credential: CredentialSnapshot | None = None class AuthStatus(BaseModel): authenticated: bool class AuthUrl(BaseModel): auth_url: str class GmailCallback(BaseModel): state: str code: str class GDriveCallback(BaseModel): state: str code: str class IndexingStatusRequest(BaseModel): secondary_index: bool = False source: DocumentSource | None = None access_type_filters: list[AccessType] = Field(default_factory=list) last_status_filters: list[IndexingStatus] = Field(default_factory=list) docs_count_operator: DocsCountOperator | None = None docs_count_value: int | None = None name_filter: str | None = None source_to_page: dict[DocumentSource, int] = Field(default_factory=dict) get_all_connectors: bool = False ================================================ FILE: backend/onyx/server/documents/private_key_types.py ================================================ import base64 from enum import Enum from typing import Protocol from fastapi import HTTPException from fastapi import UploadFile from onyx.server.documents.document_utils import validate_pkcs12_content class ProcessPrivateKeyFileProtocol(Protocol): def __call__(self, file: UploadFile) -> str: """ Accepts a file-like object, validates the file (e.g., checks extension and content), and returns its contents as a base64-encoded string if valid. Raises an exception if validation fails. """ ... class PrivateKeyFileTypes(Enum): SHAREPOINT_PFX_FILE = "sharepoint_pfx_file" def process_sharepoint_private_key_file(file: UploadFile) -> str: """ Process and validate a private key file upload. Validates both the file extension and file content to ensure it's a valid PKCS#12 file. Content validation prevents attacks that rely on file extension spoofing. 
""" # First check file extension (basic filter) if not (file.filename and file.filename.lower().endswith(".pfx")): raise HTTPException( status_code=400, detail="Invalid file type. Only .pfx files are supported." ) # Read file content for validation and processing private_key_bytes = file.file.read() # Validate file content to prevent extension spoofing attacks if not validate_pkcs12_content(private_key_bytes): raise HTTPException( status_code=400, detail="Invalid file content. The uploaded file does not appear to be a valid PKCS#12 (.pfx) file.", ) # Convert to base64 if validation passes pfx_64 = base64.b64encode(private_key_bytes).decode("ascii") return pfx_64 FILE_TYPE_TO_FILE_PROCESSOR: dict[ PrivateKeyFileTypes, ProcessPrivateKeyFileProtocol ] = { PrivateKeyFileTypes.SHAREPOINT_PFX_FILE: process_sharepoint_private_key_file, } ================================================ FILE: backend/onyx/server/documents/standard_oauth.py ================================================ import json import uuid from typing import Annotated from typing import cast from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Query from fastapi import Request from pydantic import BaseModel from pydantic import ValidationError from sqlalchemy.orm import Session from onyx.auth.users import current_user from onyx.configs.app_configs import WEB_DOMAIN from onyx.configs.constants import DocumentSource from onyx.connectors.interfaces import OAuthConnector from onyx.db.credentials import create_credential from onyx.db.engine.sql_engine import get_session from onyx.db.models import User from onyx.redis.redis_pool import get_redis_client from onyx.server.documents.models import CredentialBase from onyx.utils.logger import setup_logger from onyx.utils.subclasses import find_all_subclasses_in_package from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/connector/oauth") 
_OAUTH_STATE_KEY_FMT = "oauth_state:{state}" _OAUTH_STATE_EXPIRATION_SECONDS = 10 * 60 # 10 minutes _DESIRED_RETURN_URL_KEY = "desired_return_url" _ADDITIONAL_KWARGS_KEY = "additional_kwargs" # Cache for OAuth connectors, populated at module load time _OAUTH_CONNECTORS: dict[DocumentSource, type[OAuthConnector]] = {} def _discover_oauth_connectors() -> dict[DocumentSource, type[OAuthConnector]]: """Walk through the connectors package to find all OAuthConnector implementations""" global _OAUTH_CONNECTORS if _OAUTH_CONNECTORS: # Return cached connectors if already discovered return _OAUTH_CONNECTORS # Import submodules using package-based discovery to avoid sys.path mutations oauth_connectors = find_all_subclasses_in_package( cast(type[OAuthConnector], OAuthConnector), "onyx.connectors" ) _OAUTH_CONNECTORS = {cls.oauth_id(): cls for cls in oauth_connectors} return _OAUTH_CONNECTORS # Discover OAuth connectors at module load time _discover_oauth_connectors() def _get_additional_kwargs( request: Request, connector_cls: type[OAuthConnector], args_to_ignore: list[str] ) -> dict[str, str]: # get additional kwargs from request # e.g. anything except for desired_return_url additional_kwargs_dict = { k: v for k, v in request.query_params.items() if k not in args_to_ignore } try: # validate connector_cls.AdditionalOauthKwargs(**additional_kwargs_dict) except ValidationError: raise HTTPException( status_code=400, detail=( f"Invalid additional kwargs. 
Got {additional_kwargs_dict}, expected " f"{connector_cls.AdditionalOauthKwargs.model_json_schema()}" ), ) return additional_kwargs_dict class AuthorizeResponse(BaseModel): redirect_url: str @router.get("/authorize/{source}") def oauth_authorize( request: Request, source: DocumentSource, desired_return_url: Annotated[str | None, Query()] = None, _: User = Depends(current_user), ) -> AuthorizeResponse: """Initiates the OAuth flow by redirecting to the provider's auth page""" tenant_id = get_current_tenant_id() oauth_connectors = _discover_oauth_connectors() if source not in oauth_connectors: raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") connector_cls = oauth_connectors[source] base_url = WEB_DOMAIN # get additional kwargs from request # e.g. anything except for desired_return_url additional_kwargs = _get_additional_kwargs( request, connector_cls, ["desired_return_url"] ) # store state in redis if not desired_return_url: desired_return_url = f"{base_url}/admin/connectors/{source}?step=0" redis_client = get_redis_client(tenant_id=tenant_id) state = str(uuid.uuid4()) redis_client.set( _OAUTH_STATE_KEY_FMT.format(state=state), json.dumps( { _DESIRED_RETURN_URL_KEY: desired_return_url, _ADDITIONAL_KWARGS_KEY: additional_kwargs, } ), ex=_OAUTH_STATE_EXPIRATION_SECONDS, ) return AuthorizeResponse( redirect_url=connector_cls.oauth_authorization_url( base_url, state, additional_kwargs ) ) class CallbackResponse(BaseModel): redirect_url: str @router.get("/callback/{source}") def oauth_callback( source: DocumentSource, code: Annotated[str, Query()], state: Annotated[str, Query()], db_session: Session = Depends(get_session), user: User = Depends(current_user), ) -> CallbackResponse: """Handles the OAuth callback and exchanges the code for tokens""" oauth_connectors = _discover_oauth_connectors() if source not in oauth_connectors: raise HTTPException(status_code=400, detail=f"Unknown OAuth source: {source}") connector_cls = 
oauth_connectors[source] # get state from redis redis_client = get_redis_client() oauth_state_bytes = cast( bytes, redis_client.get(_OAUTH_STATE_KEY_FMT.format(state=state)) ) if not oauth_state_bytes: raise HTTPException(status_code=400, detail="Invalid OAuth state") oauth_state = json.loads(oauth_state_bytes.decode("utf-8")) desired_return_url = cast(str, oauth_state[_DESIRED_RETURN_URL_KEY]) additional_kwargs = cast(dict[str, str], oauth_state[_ADDITIONAL_KWARGS_KEY]) base_url = WEB_DOMAIN token_info = connector_cls.oauth_code_to_token(base_url, code, additional_kwargs) # Create a new credential with the token info credential_data = CredentialBase( credential_json=token_info, admin_public=True, # Or based on some logic/parameter source=source, name=f"{source.title()} OAuth Credential", ) credential = create_credential( credential_data=credential_data, user=user, db_session=db_session, ) # TODO: use a library for url handling sep = "&" if "?" in desired_return_url else "?" return CallbackResponse( redirect_url=f"{desired_return_url}{sep}credentialId={credential.id}" ) class OAuthAdditionalKwargDescription(BaseModel): name: str display_name: str description: str class OAuthDetails(BaseModel): oauth_enabled: bool additional_kwargs: list[OAuthAdditionalKwargDescription] @router.get("/details/{source}") def oauth_details( source: DocumentSource, _: User = Depends(current_user), ) -> OAuthDetails: oauth_connectors = _discover_oauth_connectors() if source not in oauth_connectors: return OAuthDetails( oauth_enabled=False, additional_kwargs=[], ) connector_cls = oauth_connectors[source] additional_kwarg_descriptions = [] for key, value in connector_cls.AdditionalOauthKwargs.model_json_schema()[ "properties" ].items(): additional_kwarg_descriptions.append( OAuthAdditionalKwargDescription( name=key, display_name=value.get("title", key), description=value.get("description", ""), ) ) return OAuthDetails( oauth_enabled=True, additional_kwargs=additional_kwarg_descriptions, ) 
================================================ FILE: backend/onyx/server/evals/__init__.py ================================================ ================================================ FILE: backend/onyx/server/evals/models.py ================================================ from pydantic import BaseModel class EvalRunAck(BaseModel): """Response model for evaluation runs""" success: bool ================================================ FILE: backend/onyx/server/features/__init__.py ================================================ ================================================ FILE: backend/onyx/server/features/build/.gitignore ================================================ sandbox/kubernetes/docker/templates/venv/** sandbox/kubernetes/docker/demo_data/** ================================================ FILE: backend/onyx/server/features/build/AGENTS.template.md ================================================ # AGENTS.md You are an AI agent powering **Onyx Craft**. You create interactive web applications, dashboards, and documents from company knowledge. You run in a secure sandbox with access to the user's knowledge sources. These knowledge sources contain organizational context, such as meeting notes, emails, Slack messages, and other organizational data, which you must use to answer the user's questions. {{USER_CONTEXT}} ## Configuration - **LLM**: {{LLM_PROVIDER_NAME}} / {{LLM_MODEL_NAME}} - **Next.js**: Running on port {{NEXTJS_PORT}} (already started, so do NOT run `npm run dev`) {{DISABLED_TOOLS_SECTION}} ## Environment Ephemeral VM with Python 3.11 and Node v22. Virtual environment at `.venv/` includes numpy, pandas, matplotlib, scipy. Install packages: `pip install ` or `npm install ` (from `outputs/web`). {{ORG_INFO_SECTION}} ## Skills {{AVAILABLE_SKILLS_SECTION}} Read the relevant SKILL.md before starting work that the skill covers. ## Recommended Task Approach Methodology When presented with a task, you typically: 1. 
Analyze the request to understand what's being asked 2. Break down complex problems into manageable steps and sub-questions 3. Use appropriate tools and methods to address each step 4. Provide clear communication throughout the process 5. Deliver results in a helpful and organized manner Follow this two-step pattern for most tasks: ### Step 1: Information Retrieval 1. **Search** knowledge sources using `find`, `grep`, or direct file reads. Start your search at the root of the `files/` directory to get a general sense of which subdirectories to explore further, especially when looking for a person. Their name may appear capitalized as a proper noun or entirely in lowercase. 2. **Extract** relevant data from JSON documents 3. **Summarize** key findings before proceeding **Tip**: Use `find`, `grep`, or `glob` to search files directly rather than navigating directories one at a time. ### Step 2: Output Generation 1. **Choose format**: Web app for interactive/visual, Markdown for reports, or direct response for quick answers 2. **Build** the output using retrieved information 3. **Verify** the output renders correctly and includes accurate data ## Behavior Guidelines - **Accuracy**: Do not make any assumptions about the user. Any conclusions you reach must be supported by the provided data. - **Completeness**: For any task requiring data from the knowledge sources, look at ALL sources that may be relevant to the user's question and use them in your final response. Check Google Drive if applicable. - **Explicitly state** which sources were checked and which had no relevant data - **Search ALL knowledge sources** for the person's name/email when answering questions about a person's activities, not just the obvious ones. - **Task Management**: For any non-trivial task involving multiple steps, you should organize your work and track progress. This helps users understand what you're doing and ensures nothing is missed. 
- **Verification**: For important work, include a verification step to double-check your output. This could involve testing functionality, reviewing for accuracy, or validating against requirements. - Critical execution rule: If you say you're about to do something, actually do it in the same turn (run the tool call right after). - Check off completed TODOs before reporting progress. - Your main goal is to follow the USER's instructions at each message - Don't mention tool names to the user; describe actions naturally. ## Knowledge Sources The `files/` directory contains JSON documents from various knowledge sources. Here's what's available: {{KNOWLEDGE_SOURCES_SECTION}} ### Document Format Files are JSON with: `title`, `source`, `metadata`, `sections[{text, link}]`. **Important**: The `files/` directory is read-only. Do NOT attempt to write to it. ## Outputs All outputs go in the `outputs/` directory. | Format | Use For | | ------------ | ---------------------------------------- | | **Web App** | Interactive dashboards, data exploration | | **Markdown** | Reports, analyses, documentation | | **Response** | Quick answers, lookups | You can also generate other output formats if you think they more directly answer the user's question ### Web Apps Use `outputs/web` with Next.js 16.1.1, React v19, Tailwind, Recharts, shadcn/ui. ### Markdown Save to `outputs/markdown/*.md`. Use clear headings and tables. ## Questions to Ask - Did you check all relevant sources that could be useful in addressing the user's question? - Did you generate the correct output format that the user requested? - Did you answer the user's question thoroughly? 
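The template above contains `{{NAME}}` placeholders, such as `{{LLM_PROVIDER_NAME}}` and `{{NEXTJS_PORT}}`, that are filled in before the file reaches the agent. The code that performs this substitution is not shown in this section. The following is therefore only a minimal sketch that assumes plain string substitution. `render_agents_template` is a hypothetical helper, not the project's function.

```python
import re

# Matches placeholders such as {{USER_CONTEXT}} or {{NEXTJS_PORT}}
PLACEHOLDER_RE = re.compile(r"\{\{([A-Z0-9_]+)\}\}")


def render_agents_template(template: str, values: dict[str, str]) -> str:
    """Hypothetical renderer: replace every {{NAME}} with values[NAME]."""
    # Fail loudly on unknown placeholders instead of leaking raw {{...}} markers
    missing = sorted(
        {name for name in PLACEHOLDER_RE.findall(template) if name not in values}
    )
    if missing:
        raise ValueError(f"Missing template values: {', '.join(missing)}")
    # A function replacement inserts each value literally (no backslash escapes)
    return PLACEHOLDER_RE.sub(lambda m: values[m.group(1)], template)


if __name__ == "__main__":
    snippet = "- **LLM**: {{LLM_PROVIDER_NAME}} / {{LLM_MODEL_NAME}}"
    print(
        render_agents_template(
            snippet, {"LLM_PROVIDER_NAME": "openai", "LLM_MODEL_NAME": "gpt-4o"}
        )
    )
```

Passing a function to `re.sub`, rather than a replacement string, keeps backslashes in values such as file paths from being interpreted as escape sequences. Rejecting missing names surfaces typos in placeholder keys before the agent ever sees the prompt.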
================================================ FILE: backend/onyx/server/features/build/__init__.py ================================================ # Build feature module ================================================ FILE: backend/onyx/server/features/build/api/api.py ================================================ import re from collections.abc import Iterator from pathlib import Path from uuid import UUID import httpx from fastapi import APIRouter from fastapi import Depends from fastapi import HTTPException from fastapi import Request from fastapi import Response from fastapi.responses import RedirectResponse from fastapi.responses import StreamingResponse from sqlalchemy.orm import Session from onyx.auth.users import current_user from onyx.auth.users import optional_user from onyx.configs.constants import DocumentSource from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user from onyx.db.engine.sql_engine import get_session from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.enums import IndexingStatus from onyx.db.enums import ProcessingMode from onyx.db.enums import SharingScope from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id from onyx.db.models import BuildSession from onyx.db.models import User from onyx.server.features.build.api.messages_api import router as messages_router from onyx.server.features.build.api.models import BuildConnectorInfo from onyx.server.features.build.api.models import BuildConnectorListResponse from onyx.server.features.build.api.models import BuildConnectorStatus from onyx.server.features.build.api.models import RateLimitResponse from onyx.server.features.build.api.rate_limit import get_user_rate_limit_status from onyx.server.features.build.api.sessions_api import router as sessions_router from onyx.server.features.build.api.user_library import router as user_library_router from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id from 
onyx.server.features.build.sandbox import get_sandbox_manager from onyx.server.features.build.session.manager import SessionManager from onyx.server.features.build.utils import is_onyx_craft_enabled from onyx.utils.logger import setup_logger logger = setup_logger() _TEMPLATES_DIR = Path(__file__).parent / "templates" _WEBAPP_HMR_FIXER_TEMPLATE = (_TEMPLATES_DIR / "webapp_hmr_fixer.js").read_text() def require_onyx_craft_enabled(user: User = Depends(current_user)) -> User: """ Dependency that checks if Onyx Craft is enabled for the user. Raises HTTP 403 if Onyx Craft is disabled via feature flag. """ if not is_onyx_craft_enabled(user): raise HTTPException( status_code=403, detail="Onyx Craft is not available", ) return user router = APIRouter(prefix="/build", dependencies=[Depends(require_onyx_craft_enabled)]) # Include sub-routers for sessions, messages, and user library router.include_router(sessions_router, tags=["build"]) router.include_router(messages_router, tags=["build"]) router.include_router(user_library_router, tags=["build"]) # ----------------------------------------------------------------------------- # Rate Limiting # ----------------------------------------------------------------------------- @router.get("/limit", response_model=RateLimitResponse) def get_rate_limit( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> RateLimitResponse: """Get rate limit information for the current user.""" return get_user_rate_limit_status(user, db_session) # ----------------------------------------------------------------------------- # Build Connectors # ----------------------------------------------------------------------------- @router.get("/connectors", response_model=BuildConnectorListResponse) def get_build_connectors( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> BuildConnectorListResponse: """Get all connectors for the build admin panel. 
Returns connector-credential pairs with simplified status information. On the build configure page, all users (including admins) only see connectors they own/created. Users can create new connectors if they don't have one of a type. """ # Fetch both FILE_SYSTEM (standard connectors) and RAW_BINARY (User Library) connectors file_system_cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, get_editable=False, eager_load_connector=True, eager_load_credential=True, processing_mode=ProcessingMode.FILE_SYSTEM, ) raw_binary_cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, get_editable=False, eager_load_connector=True, eager_load_credential=True, processing_mode=ProcessingMode.RAW_BINARY, ) cc_pairs = file_system_cc_pairs + raw_binary_cc_pairs # Filter to only show connectors created by the current user # All users (including admins) must create their own connectors on the build configure page if user: cc_pairs = [cc_pair for cc_pair in cc_pairs if cc_pair.creator_id == user.id] connectors: list[BuildConnectorInfo] = [] for cc_pair in cc_pairs: # Skip ingestion API connectors and default pairs if cc_pair.connector.source == DocumentSource.INGESTION_API: continue if cc_pair.name == "DefaultCCPair": continue # Determine status error_message: str | None = None has_ever_succeeded = cc_pair.last_successful_index_time is not None if cc_pair.status == ConnectorCredentialPairStatus.DELETING: status = BuildConnectorStatus.DELETING elif cc_pair.status == ConnectorCredentialPairStatus.INVALID: # If connector has succeeded before but credentials are now invalid, # show as connected_with_errors so user can still disable demo data if has_ever_succeeded: status = BuildConnectorStatus.CONNECTED_WITH_ERRORS error_message = "Connector credentials are invalid" else: status = BuildConnectorStatus.ERROR error_message = "Connector credentials are invalid" else: # Check latest index attempt for errors latest_attempt = 
get_latest_index_attempt_for_cc_pair_id( db_session=db_session, connector_credential_pair_id=cc_pair.id, secondary_index=False, only_finished=True, ) if latest_attempt and latest_attempt.status == IndexingStatus.FAILED: # If connector has succeeded before but latest attempt failed, # show as connected_with_errors if has_ever_succeeded: status = BuildConnectorStatus.CONNECTED_WITH_ERRORS else: status = BuildConnectorStatus.ERROR error_message = latest_attempt.error_msg elif ( latest_attempt and latest_attempt.status == IndexingStatus.COMPLETED_WITH_ERRORS ): # Completed with errors - if it has succeeded before, show as connected_with_errors if has_ever_succeeded: status = BuildConnectorStatus.CONNECTED_WITH_ERRORS else: status = BuildConnectorStatus.ERROR error_message = "Indexing completed with errors" elif cc_pair.status == ConnectorCredentialPairStatus.PAUSED: status = BuildConnectorStatus.CONNECTED elif cc_pair.last_successful_index_time is None: # Never successfully indexed - check if currently indexing # First check cc_pair status for scheduled/initial indexing if cc_pair.status in ( ConnectorCredentialPairStatus.SCHEDULED, ConnectorCredentialPairStatus.INITIAL_INDEXING, ): status = BuildConnectorStatus.INDEXING else: in_progress_attempt = get_latest_index_attempt_for_cc_pair_id( db_session=db_session, connector_credential_pair_id=cc_pair.id, secondary_index=False, only_finished=False, ) if ( in_progress_attempt and in_progress_attempt.status == IndexingStatus.IN_PROGRESS ): status = BuildConnectorStatus.INDEXING elif ( in_progress_attempt and in_progress_attempt.status == IndexingStatus.NOT_STARTED ): status = BuildConnectorStatus.INDEXING else: # Has a finished attempt but never succeeded - likely error status = BuildConnectorStatus.ERROR error_message = ( latest_attempt.error_msg if latest_attempt else "Initial indexing failed" ) else: status = BuildConnectorStatus.CONNECTED connectors.append( BuildConnectorInfo( cc_pair_id=cc_pair.id, 
connector_id=cc_pair.connector.id, credential_id=cc_pair.credential.id, source=cc_pair.connector.source.value, name=cc_pair.name or cc_pair.connector.name or "Unnamed", status=status, docs_indexed=0, # Would need to query for this last_indexed=cc_pair.last_successful_index_time, error_message=error_message, ) ) return BuildConnectorListResponse(connectors=connectors) # Headers to skip when proxying. # Hop-by-hop headers must not be forwarded, and set-cookie is stripped to # prevent LLM-generated apps from setting cookies on the parent Onyx domain. EXCLUDED_HEADERS = { "content-encoding", "content-length", "transfer-encoding", "connection", "set-cookie", } def _stream_response(response: httpx.Response) -> Iterator[bytes]: """Stream the response content in chunks.""" for chunk in response.iter_bytes(chunk_size=8192): yield chunk def _inject_hmr_fixer(content: bytes, session_id: str) -> bytes: """Inject a script that stubs root-scoped Next HMR websocket connections.""" base = f"/api/build/sessions/{session_id}/webapp" script = f"" text = content.decode("utf-8") text = re.sub( r"(]*>)", lambda m: m.group(0) + script, text, count=1, flags=re.IGNORECASE, ) return text.encode("utf-8") def _rewrite_asset_paths(content: bytes, session_id: str) -> bytes: """Rewrite Next.js asset paths to go through the proxy.""" webapp_base_path = f"/api/build/sessions/{session_id}/webapp" escaped_webapp_base_path = webapp_base_path.replace("/", r"\/") hmr_paths = ("/_next/webpack-hmr", "/_next/hmr") text = content.decode("utf-8") # Anchor on delimiter so already-prefixed URLs (from assetPrefix) aren't double-rewritten. 
    for delim in ('"', "'", "("):
        text = text.replace(f"{delim}/_next/", f"{delim}{webapp_base_path}/_next/")
        text = re.sub(
            rf"{re.escape(delim)}https?://[^/\"')]+/_next/",
            f"{delim}{webapp_base_path}/_next/",
            text,
        )
        text = re.sub(
            rf"{re.escape(delim)}wss?://[^/\"')]+/_next/",
            f"{delim}{webapp_base_path}/_next/",
            text,
        )
    text = text.replace(r"\/_next\/", rf"{escaped_webapp_base_path}\/_next\/")
    text = re.sub(
        r"https?:\\\/\\\/[^\"']+?\\\/_next\\\/",
        rf"{escaped_webapp_base_path}\/_next\/",
        text,
    )
    text = re.sub(
        r"wss?:\\\/\\\/[^\"']+?\\\/_next\\\/",
        rf"{escaped_webapp_base_path}\/_next\/",
        text,
    )
    for hmr_path in hmr_paths:
        escaped_hmr_path = hmr_path.replace("/", r"\/")
        text = text.replace(
            f"{webapp_base_path}{hmr_path}",
            hmr_path,
        )
        text = text.replace(
            f"{escaped_webapp_base_path}{escaped_hmr_path}",
            escaped_hmr_path,
        )
    text = re.sub(
        r'"(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)"',
        f'"{webapp_base_path}\\1"',
        text,
    )
    text = re.sub(
        r"'(/(?:[a-zA-Z0-9_-]+/)*[a-zA-Z0-9_-]+\.json)'",
        f"'{webapp_base_path}\\1'",
        text,
    )
    text = text.replace('"/favicon.ico', f'"{webapp_base_path}/favicon.ico')
    return text.encode("utf-8")


def _rewrite_proxy_response_headers(
    headers: dict[str, str], session_id: str
) -> dict[str, str]:
    """Rewrite response headers that can leak root-scoped asset URLs."""
    link = headers.get("link")
    if link:
        webapp_base_path = f"/api/build/sessions/{session_id}/webapp"
        rewritten_link = re.sub(
            r"]+/_next/",
            f"<{webapp_base_path}/_next/",
            link,
        )
        rewritten_link = rewritten_link.replace(
            " str:
    """Get the internal URL for a session's Next.js server.

    Uses the sandbox manager to get the correct URL for both local and
    Kubernetes environments.
    Args:
        session_id: The build session ID
        db_session: Database session

    Returns:
        Internal URL to proxy requests to

    Raises:
        HTTPException: If session not found, port not allocated, or sandbox not found
    """
    session = db_session.get(BuildSession, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if session.nextjs_port is None:
        raise HTTPException(status_code=503, detail="Session port not allocated")

    if session.user_id is None:
        raise HTTPException(status_code=404, detail="User not found")

    sandbox = get_sandbox_by_user_id(db_session, session.user_id)
    if sandbox is None:
        raise HTTPException(status_code=404, detail="Sandbox not found")

    sandbox_manager = get_sandbox_manager()
    return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)


def _proxy_request(
    path: str, request: Request, session_id: UUID, db_session: Session
) -> StreamingResponse | Response:
    """Proxy a request to the sandbox's Next.js server."""
    base_url = _get_sandbox_url(session_id, db_session)

    # Build the target URL
    target_url = f"{base_url}/{path.lstrip('/')}"

    # Include query params if present
    if request.query_params:
        target_url = f"{target_url}?{request.query_params}"

    logger.debug(f"Proxying request to: {target_url}")

    try:
        # Make the request to the target URL
        with httpx.Client(timeout=30.0, follow_redirects=True) as client:
            response = client.get(
                target_url,
                headers={
                    key: value
                    for key, value in request.headers.items()
                    if key.lower() not in ("host", "content-length")
                },
            )

            # Build response headers, excluding hop-by-hop headers
            response_headers = {
                key: value
                for key, value in response.headers.items()
                if key.lower() not in EXCLUDED_HEADERS
            }
            response_headers = _rewrite_proxy_response_headers(
                response_headers, str(session_id)
            )

            content_type = response.headers.get("content-type", "")

            # For HTML/CSS/JS responses, rewrite asset paths
            if any(ct in content_type for ct in REWRITABLE_CONTENT_TYPES):
                content = _rewrite_asset_paths(response.content, str(session_id))
                if "text/html" in content_type:
                    content = _inject_hmr_fixer(content, str(session_id))
                return Response(
                    content=content,
                    status_code=response.status_code,
                    headers=response_headers,
                    media_type=content_type,
                )

            return StreamingResponse(
                content=_stream_response(response),
                status_code=response.status_code,
                headers=response_headers,
                media_type=content_type or None,
            )
    except httpx.TimeoutException:
        logger.error(f"Timeout while proxying request to {target_url}")
        raise HTTPException(status_code=504, detail="Gateway timeout")
    except httpx.RequestError as e:
        logger.error(f"Error proxying request to {target_url}: {e}")
        raise HTTPException(status_code=502, detail="Bad gateway")


def _check_webapp_access(
    session_id: UUID, user: User | None, db_session: Session
) -> BuildSession:
    """Check if user can access a session's webapp.

    - public_global: accessible by anyone (no auth required)
    - public_org: accessible by any authenticated user
    - private: only accessible by the session owner
    """
    session = db_session.get(BuildSession, session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    if session.sharing_scope == SharingScope.PUBLIC_GLOBAL:
        return session

    if user is None:
        raise HTTPException(status_code=401, detail="Authentication required")

    if session.sharing_scope == SharingScope.PRIVATE and session.user_id != user.id:
        raise HTTPException(status_code=404, detail="Session not found")

    return session


_OFFLINE_HTML_PATH = _TEMPLATES_DIR / "webapp_offline.html"


def _offline_html_response() -> Response:
    """Return a branded Craft HTML page when the sandbox is not reachable.

    Design mirrors the default Craft web template (outputs/web/app/page.tsx):
    terminal window aesthetic with Minecraft-themed typing animation.
    """
    html = _OFFLINE_HTML_PATH.read_text()
    return Response(content=html, status_code=503, media_type="text/html")


# Public router for webapp proxy — no authentication required
# (access controlled per-session via sharing_scope)
public_build_router = APIRouter(prefix="/build")


@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
@public_build_router.get(
    "/sessions/{session_id}/webapp/{path:path}", response_model=None
)
def get_webapp(
    session_id: UUID,
    request: Request,
    path: str = "",
    user: User | None = Depends(optional_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
    """Proxy the webapp for a specific session (root and subpaths).

    Accessible without authentication when sharing_scope is public_global.
    Returns a friendly offline page when the sandbox is not running.
    """
    try:
        _check_webapp_access(session_id, user, db_session)
    except HTTPException as e:
        if e.status_code == 401:
            return RedirectResponse(url="/auth/login", status_code=302)
        raise
    try:
        return _proxy_request(path, request, session_id, db_session)
    except HTTPException as e:
        if e.status_code in (502, 503, 504):
            return _offline_html_response()
        raise


# =============================================================================
# Sandbox Management Endpoints
# =============================================================================


@router.post("/sandbox/reset", response_model=None)
def reset_sandbox(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Response:
    """Reset the user's sandbox by terminating it and cleaning up all sessions.

    This endpoint terminates the user's shared sandbox container/pod and
    cleans up all session workspaces. Useful for "start fresh" functionality.

    After calling this endpoint, the next session creation will provision a
    new sandbox.
    """
    session_manager = SessionManager(db_session)

    try:
        success = session_manager.terminate_user_sandbox(user.id)
        if not success:
            raise HTTPException(
                status_code=404,
                detail="No sandbox found for user",
            )
        db_session.commit()
    except HTTPException:
        raise
    except Exception as e:
        db_session.rollback()
        logger.error(f"Failed to reset sandbox for user {user.id}: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to reset sandbox: {e}",
        )

    return Response(status_code=204)

================================================
FILE: backend/onyx/server/features/build/api/messages_api.py
================================================
"""API endpoints for Build Mode message management."""

from collections.abc import Generator
from uuid import UUID

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session

from onyx.auth.users import current_user
from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.engine.sql_engine import get_session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.models import User
from onyx.server.features.build.api.models import MessageListResponse
from onyx.server.features.build.api.models import MessageRequest
from onyx.server.features.build.api.models import MessageResponse
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
from onyx.server.features.build.session.manager import RateLimitError
from onyx.server.features.build.session.manager import SessionManager
from onyx.utils.logger import setup_logger

logger = setup_logger()

router = APIRouter()


def check_build_rate_limits(
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> None:
    """
    Dependency to check build mode rate limits before processing the request.
    Raises HTTPException(429) if rate limit is exceeded.
    Follows the same pattern as chat's check_token_rate_limits.
    """
    session_manager = SessionManager(db_session)

    try:
        session_manager.check_rate_limit(user)
    except RateLimitError as e:
        raise HTTPException(
            status_code=429,
            detail=str(e),
        )


@router.get("/sessions/{session_id}/messages", tags=PUBLIC_API_TAGS)
def list_messages(
    session_id: UUID,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> MessageListResponse:
    """Get all messages for a build session."""
    session_manager = SessionManager(db_session)

    messages = session_manager.list_messages(session_id, user.id)
    if messages is None:
        raise HTTPException(status_code=404, detail="Session not found")

    return MessageListResponse(
        messages=[MessageResponse.from_model(msg) for msg in messages]
    )


@router.post("/sessions/{session_id}/send-message", tags=PUBLIC_API_TAGS)
def send_message(
    session_id: UUID,
    request: MessageRequest,
    user: User = Depends(current_user),
    _rate_limit_check: None = Depends(check_build_rate_limits),
) -> StreamingResponse:
    """
    Send a message to the CLI agent and stream the response.

    Enforces rate limiting before executing the agent (via dependency).
    Returns a Server-Sent Events (SSE) stream with the agent's response.

    Follows the same pattern as /chat/send-chat-message for consistency.
    """

    def stream_generator() -> Generator[str, None, None]:
        """Stream generator that manages its own database session.

        This is necessary because StreamingResponse consumes the generator
        AFTER the endpoint returns, at which point FastAPI's dependency-injected
        db_session has already been closed.

        By creating a new session inside the generator, we ensure the session
        remains open for the entire streaming duration.
        """
        # Capture user info needed for streaming (user object may not be available
        # after the endpoint returns due to dependency cleanup)
        user_id = user.id
        message_content = request.content

        with get_session_with_current_tenant() as db_session:
            # Update sandbox heartbeat - this is the only place we track activity
            # for determining when a sandbox should be put to sleep
            sandbox = get_sandbox_by_user_id(db_session, user_id)
            if sandbox and sandbox.status.is_active():
                update_sandbox_heartbeat(db_session, sandbox.id)

            session_manager = SessionManager(db_session)
            yield from session_manager.send_message(
                session_id, user_id, message_content
            )

    # Stream the CLI agent's response
    return StreamingResponse(
        stream_generator(),
        media_type="text/event-stream",
        headers={
            "Cache-Control": "no-cache",
            "Connection": "keep-alive",
            "X-Accel-Buffering": "no",  # Disable nginx buffering
        },
    )

================================================
FILE: backend/onyx/server/features/build/api/models.py
================================================
from datetime import datetime
from enum import Enum
from typing import Any
from typing import TYPE_CHECKING
from typing import Union

from pydantic import BaseModel

from onyx.configs.constants import MessageType
from onyx.db.enums import ArtifactType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.server.features.build.sandbox.models import (
    FilesystemEntry as FileSystemEntry,
)

if TYPE_CHECKING:
    from onyx.db.models import Sandbox
    from onyx.db.models import BuildSession


# ===== Session Models =====


class SessionCreateRequest(BaseModel):
    """Request to create a new build session."""

    name: str | None = None  # Optional session name
    demo_data_enabled: bool = True  # Whether to enable demo org_info data in sandbox
    user_work_area: str | None = None  # User's work area (e.g., "engineering")
    user_level: str | None = None  # User's level (e.g., "ic", "manager")
    # LLM selection from user's cookie
    llm_provider_type: str | None = None  # Provider type (e.g., "anthropic", "openai")
    llm_model_name: str | None = None  # Model name (e.g., "claude-opus-4-5")


class SessionUpdateRequest(BaseModel):
    """Request to update a build session.

    If name is None, the session name will be auto-generated using an LLM.
    """

    name: str | None = None


class SessionNameGenerateResponse(BaseModel):
    """Response containing a generated session name."""

    name: str


class SandboxResponse(BaseModel):
    """Sandbox metadata in session response."""

    id: str
    status: SandboxStatus
    container_id: str | None
    created_at: datetime
    last_heartbeat: datetime | None

    @classmethod
    def from_model(cls, sandbox: Any) -> "SandboxResponse":
        """Convert Sandbox ORM model to response."""
        return cls(
            id=str(sandbox.id),
            status=sandbox.status,
            container_id=sandbox.container_id,
            created_at=sandbox.created_at,
            last_heartbeat=sandbox.last_heartbeat,
        )


class ArtifactResponse(BaseModel):
    """Artifact metadata in session response."""

    id: str
    session_id: str
    type: ArtifactType
    name: str
    path: str
    preview_url: str | None
    created_at: datetime
    updated_at: datetime

    @classmethod
    def from_model(cls, artifact: Any) -> "ArtifactResponse":
        """Convert Artifact ORM model to response."""
        return cls(
            id=str(artifact.id),
            session_id=str(artifact.session_id),
            type=artifact.type,
            name=artifact.name,
            path=artifact.path,
            preview_url=getattr(artifact, "preview_url", None),
            created_at=artifact.created_at,
            updated_at=artifact.updated_at,
        )


class SessionResponse(BaseModel):
    """Response containing session details."""

    id: str
    user_id: str | None
    name: str | None
    status: BuildSessionStatus
    created_at: datetime
    last_activity_at: datetime
    nextjs_port: int | None
    sandbox: SandboxResponse | None
    artifacts: list[ArtifactResponse]
    sharing_scope: SharingScope

    @classmethod
    def from_model(
        cls, session: "BuildSession", sandbox: Union["Sandbox", None] = None
    ) -> "SessionResponse":
        """Convert BuildSession ORM model to response.
        Args:
            session: BuildSession ORM model
            sandbox: Optional Sandbox ORM model. Since sandboxes are now user-owned
                (not session-owned), the sandbox must be passed separately.
        """
        return cls(
            id=str(session.id),
            user_id=str(session.user_id) if session.user_id else None,
            name=session.name,
            status=session.status,
            created_at=session.created_at,
            last_activity_at=session.last_activity_at,
            nextjs_port=session.nextjs_port,
            sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
            artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
            sharing_scope=session.sharing_scope,
        )


class DetailedSessionResponse(SessionResponse):
    """Extended session response with sandbox state details.

    Used for single-session endpoints where we compute expensive fields
    like session_loaded_in_sandbox.
    """

    session_loaded_in_sandbox: bool

    @classmethod
    def from_session_response(
        cls,
        base: SessionResponse,
        session_loaded_in_sandbox: bool,
    ) -> "DetailedSessionResponse":
        return cls(
            **base.model_dump(),
            session_loaded_in_sandbox=session_loaded_in_sandbox,
        )


class SessionListResponse(BaseModel):
    """Response containing list of sessions."""

    sessions: list[SessionResponse]


class SetSessionSharingRequest(BaseModel):
    """Request to set the sharing scope of a session."""

    sharing_scope: SharingScope


class SetSessionSharingResponse(BaseModel):
    """Response after setting session sharing scope."""

    session_id: str
    sharing_scope: SharingScope


# ===== Message Models =====


class MessageRequest(BaseModel):
    """Request to send a message to the CLI agent."""

    content: str


class MessageResponse(BaseModel):
    """Response containing message details.

    All message data is stored in message_metadata as JSON (the raw ACP packet).
    The turn_index groups all assistant responses under the user prompt
    they respond to.
    Packet types in message_metadata:
    - user_message: {type: "user_message", content: {...}}
    - agent_message: {type: "agent_message", content: {...}}
    - agent_thought: {type: "agent_thought", content: {...}}
    - tool_call_progress: {type: "tool_call_progress", status: "completed", ...}
    - agent_plan_update: {type: "agent_plan_update", entries: [...]}
    """

    id: str
    session_id: str
    turn_index: int
    type: MessageType
    message_metadata: dict[str, Any]
    created_at: datetime

    @classmethod
    def from_model(cls, message: Any) -> "MessageResponse":
        """Convert BuildMessage ORM model to response."""
        return cls(
            id=str(message.id),
            session_id=str(message.session_id),
            turn_index=message.turn_index,
            type=message.type,
            message_metadata=message.message_metadata,
            created_at=message.created_at,
        )


class MessageListResponse(BaseModel):
    """Response containing list of messages."""

    messages: list[MessageResponse]


# ===== Legacy Models (for compatibility with other code) =====


class CreateSessionRequest(BaseModel):
    task: str
    available_sources: list[str] | None = None


class CreateSessionResponse(BaseModel):
    session_id: str


class ExecuteRequest(BaseModel):
    task: str
    context: str | None = None


class ArtifactInfo(BaseModel):
    artifact_type: str  # "webapp", "file", "markdown", "image"
    path: str
    filename: str
    mime_type: str | None = None


class SessionStatus(BaseModel):
    session_id: str
    status: str  # "idle", "running", "completed", "failed"
    webapp_url: str | None = None


class DirectoryListing(BaseModel):
    path: str  # Current directory path
    entries: list[FileSystemEntry]  # Contents


class WebappInfo(BaseModel):
    has_webapp: bool  # Whether a webapp exists in outputs/web
    webapp_url: str | None  # URL to access the webapp (e.g., http://localhost:3015)
    status: str  # Sandbox status (running, terminated, etc.)
    ready: bool  # Whether the NextJS dev server is actually responding
    sharing_scope: SharingScope


# ===== File Upload Models =====


class UploadResponse(BaseModel):
    """Response after successful file upload."""

    filename: str  # Sanitized filename
    path: str  # Relative path in sandbox (e.g., "attachments/doc.pdf")
    size_bytes: int  # File size in bytes


# ===== Rate Limit Models =====


class RateLimitResponse(BaseModel):
    """Rate limit information."""

    is_limited: bool
    limit_type: str  # "weekly" or "total"
    messages_used: int
    limit: int
    reset_timestamp: str | None = None


# ===== Pre-Provisioned Session Check Models =====


class PreProvisionedCheckResponse(BaseModel):
    """Response for checking if a pre-provisioned session is still valid (empty)."""

    valid: bool  # True if session exists and has no messages
    session_id: str | None = None  # Session ID if valid, None otherwise


# ===== Build Connector Models =====


class BuildConnectorStatus(str, Enum):
    """Status of a build connector."""

    NOT_CONNECTED = "not_connected"
    CONNECTED = "connected"
    CONNECTED_WITH_ERRORS = "connected_with_errors"
    INDEXING = "indexing"
    ERROR = "error"
    DELETING = "deleting"


class BuildConnectorInfo(BaseModel):
    """Simplified connector info for build admin panel."""

    cc_pair_id: int
    connector_id: int
    credential_id: int
    source: str
    name: str
    status: BuildConnectorStatus
    docs_indexed: int
    last_indexed: datetime | None
    error_message: str | None = None


class BuildConnectorListResponse(BaseModel):
    """List of build connectors."""

    connectors: list[BuildConnectorInfo]


# ===== Suggestion Bubble Models =====


class SuggestionTheme(str, Enum):
    """Theme/category of a follow-up suggestion."""

    ADD = "add"
    QUESTION = "question"


class SuggestionBubble(BaseModel):
    """A single follow-up suggestion bubble."""

    theme: SuggestionTheme
    text: str


class GenerateSuggestionsRequest(BaseModel):
    """Request to generate follow-up suggestions."""

    user_message: str  # First user message
    assistant_message: str  # First assistant text response (accumulated)


class GenerateSuggestionsResponse(BaseModel):
    """Response containing generated suggestions."""

    suggestions: list[SuggestionBubble]


class PptxPreviewResponse(BaseModel):
    """Response with PPTX slide preview metadata."""

    slide_count: int
    slide_paths: list[str]  # Relative paths to slide JPEGs within session workspace
    cached: bool  # Whether result was served from cache

================================================
FILE: backend/onyx/server/features/build/api/packet_logger.py
================================================
"""Comprehensive packet and ACP event logger for build mode debugging.

Logs all packets, JSON-RPC messages, and ACP events during build mode streaming.
Provides detailed tracing for the entire agent loop and communication flow.

Log output locations (in priority order):
1. /var/log/onyx/packets.log (for Docker - mounted to host via docker-compose volumes)
2. backend/log/packets.log (for local dev without Docker)
3. backend/onyx/server/features/build/packets.log (fallback)

Enable logging by setting LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true.

Features:
- Rotating log with max 5000 lines (configurable via BUILD_PACKET_LOG_MAX_LINES)
- Automatically trims oldest entries when limit is exceeded
- Visual separators between message streams for easy reading
"""

import json
import logging
import os
import threading
import time
from pathlib import Path
from typing import Any
from uuid import UUID

# Default max lines to keep in the log file (acts like a deque)
DEFAULT_MAX_LOG_LINES = 5000


class PacketLogger:
    """Comprehensive logger for ACP/OpenCode communication and packet streaming.

    Logs:
    - All JSON-RPC requests sent to the agent
    - All JSON-RPC responses/notifications received from the agent
    - All ACP events emitted during streaming
    - Session and sandbox lifecycle events
    - Timing information for debugging performance

    The log file is kept to a maximum number of lines (default 5000) to prevent
    unbounded growth.
    When the limit is exceeded, the oldest lines are trimmed.
    """

    _instance: "PacketLogger | None" = None
    _initialized: bool

    def __new__(cls) -> "PacketLogger":
        if cls._instance is None:
            cls._instance = super().__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self) -> None:
        if self._initialized:
            return
        self._initialized = True

        # Enable via LOG_LEVEL=DEBUG or BUILD_PACKET_LOGGING=true
        log_level = os.getenv("LOG_LEVEL", "").upper()
        packet_logging = os.getenv("BUILD_PACKET_LOGGING", "").lower()
        self._enabled = log_level == "DEBUG" or packet_logging in ("true", "1", "yes")

        self._logger: logging.Logger | None = None
        self._log_file_path: Path | None = None
        self._session_start_times: dict[str, float] = {}

        # Max lines to keep in log file
        try:
            self._max_lines = int(
                os.getenv("BUILD_PACKET_LOG_MAX_LINES", str(DEFAULT_MAX_LOG_LINES))
            )
        except ValueError:
            self._max_lines = DEFAULT_MAX_LOG_LINES

        # Lock for thread-safe file operations
        self._file_lock = threading.Lock()

        # Track approximate line count to avoid reading file too often
        self._approx_line_count = 0
        self._lines_since_last_trim = 0
        # Trim every N lines written to avoid constant file reads
        self._trim_interval = 500

        if self._enabled:
            self._setup_logger()

    def _get_log_file_path(self) -> Path:
        """Determine the best log file path based on environment.

        Priority:
        1. /var/log/onyx/packets.log - Docker environment (mounted to host)
        2. backend/log/packets.log - Local dev (same dir as other logs)
        3. backend/onyx/server/features/build/packets.log - Fallback
        """
        # Option 1: Docker environment - use /var/log/onyx which is mounted
        docker_log_dir = Path("/var/log/onyx")
        if docker_log_dir.exists() and docker_log_dir.is_dir():
            return docker_log_dir / "packets.log"

        # Option 2: Local dev - use backend/log directory (same as other debug logs)
        # This file lives in backend/onyx/server/features/build/api/, so
        # parents[5] is backend/ (parents[4] would be backend/onyx/)
        backend_dir = Path(__file__).parents[5]  # up to backend/
        local_log_dir = backend_dir / "log"
        if local_log_dir.exists() and local_log_dir.is_dir():
            return local_log_dir / "packets.log"

        # Option 3: Fallback to build directory
        build_dir = Path(__file__).parents[1]
        return build_dir / "packets.log"

    def _setup_logger(self) -> None:
        """Set up the file handler for packet logging."""
        self._log_file_path = self._get_log_file_path()

        # Ensure parent directory exists
        self._log_file_path.parent.mkdir(parents=True, exist_ok=True)

        self._logger = logging.getLogger("build.packets")
        self._logger.setLevel(logging.DEBUG)
        self._logger.propagate = False

        self._logger.handlers.clear()

        # Use append mode
        handler = logging.FileHandler(self._log_file_path, mode="a", encoding="utf-8")
        handler.setLevel(logging.DEBUG)
        # Include timestamp in each log entry
        handler.setFormatter(
            logging.Formatter(
                "%(asctime)s.%(msecs)03d | %(message)s", "%Y-%m-%d %H:%M:%S"
            )
        )
        self._logger.addHandler(handler)

        # Initialize line count from existing file
        self._init_line_count()

    def _init_line_count(self) -> None:
        """Initialize the approximate line count from the existing log file."""
        if not self._log_file_path or not self._log_file_path.exists():
            self._approx_line_count = 0
            return

        try:
            with open(self._log_file_path, "r", encoding="utf-8", errors="ignore") as f:
                self._approx_line_count = sum(1 for _ in f)
        except Exception:
            self._approx_line_count = 0

    def _maybe_trim_log(self) -> None:
        """Trim the log file if it exceeds the max line limit.

        This is called periodically (every _trim_interval lines) to avoid
        reading the file on every write.
        """
        self._lines_since_last_trim += 1
        if self._lines_since_last_trim < self._trim_interval:
            return

        self._lines_since_last_trim = 0
        self._trim_log_file()

    def _trim_log_file(self) -> None:
        """Trim the log file to keep only the last max_lines."""
        if not self._log_file_path or not self._log_file_path.exists():
            return

        with self._file_lock:
            try:
                # Read all lines
                with open(
                    self._log_file_path, "r", encoding="utf-8", errors="ignore"
                ) as f:
                    lines = f.readlines()

                current_count = len(lines)
                self._approx_line_count = current_count

                # If under limit, nothing to do
                if current_count <= self._max_lines:
                    return

                # Keep only the last max_lines
                lines_to_keep = lines[-self._max_lines :]

                # Close the logger's file handler temporarily
                if self._logger:
                    for handler in self._logger.handlers:
                        handler.close()

                # Rewrite the file with trimmed content
                with open(self._log_file_path, "w", encoding="utf-8") as f:
                    f.writelines(lines_to_keep)

                # Reopen the handler
                if self._logger:
                    self._logger.handlers.clear()
                    handler = logging.FileHandler(
                        self._log_file_path, mode="a", encoding="utf-8"
                    )
                    handler.setLevel(logging.DEBUG)
                    handler.setFormatter(
                        logging.Formatter(
                            "%(asctime)s.%(msecs)03d | %(message)s", "%Y-%m-%d %H:%M:%S"
                        )
                    )
                    self._logger.addHandler(handler)

                self._approx_line_count = len(lines_to_keep)
            except Exception:
                pass  # Silently ignore errors during trim

    def clear_log_file(self) -> None:
        """Clear the log file contents.

        Note: With the rotating log approach, this is optional. The log will
        automatically trim itself. But this can still be useful to start fresh.
        """
        if not self._enabled or not self._log_file_path:
            return

        with self._file_lock:
            try:
                # Close the logger's file handler temporarily
                if self._logger:
                    for handler in self._logger.handlers:
                        handler.close()

                # Truncate the file
                with open(self._log_file_path, "w", encoding="utf-8") as f:
                    f.write("")  # Empty the file

                # Reopen the handler
                if self._logger:
                    self._logger.handlers.clear()
                    handler = logging.FileHandler(
                        self._log_file_path, mode="a", encoding="utf-8"
                    )
                    handler.setLevel(logging.DEBUG)
                    handler.setFormatter(
                        logging.Formatter(
                            "%(asctime)s.%(msecs)03d | %(message)s", "%Y-%m-%d %H:%M:%S"
                        )
                    )
                    self._logger.addHandler(handler)

                self._approx_line_count = 0
                self._lines_since_last_trim = 0
            except Exception:
                pass  # Silently ignore errors

    @property
    def is_enabled(self) -> bool:
        """Check if logging is enabled."""
        return self._enabled and self._logger is not None

    def _format_uuid(self, value: Any) -> str:
        """Format UUID for logging (shortened for readability)."""
        if isinstance(value, UUID):
            return str(value)[:8]
        if isinstance(value, str) and len(value) >= 8:
            return value[:8]
        return str(value)

    def _write_log(self, message: str) -> None:
        """Internal method to write a log message and trigger trim check.

        Args:
            message: The formatted log message
        """
        if not self._logger:
            return
        self._logger.debug(message)
        self._maybe_trim_log()

    def log(self, packet_type: str, payload: dict[str, Any] | None = None) -> None:
        """Log a packet as JSON.

        Args:
            packet_type: The type of packet
            payload: The packet payload
        """
        if not self._enabled or not self._logger:
            return

        try:
            output = json.dumps(payload, indent=2, default=str) if payload else "{}"
            self._write_log(f"[PACKET] {packet_type}\n{output}")
        except Exception:
            self._write_log(f"[PACKET] {packet_type}\n{payload}")

    def log_raw(self, label: str, data: Any) -> None:
        """Log raw data with a label.
        Args:
            label: A label for this log entry
            data: Any data to log
        """
        if not self._enabled or not self._logger:
            return

        try:
            if isinstance(data, (dict, list)):
                output = json.dumps(data, indent=2, default=str)
            else:
                output = str(data)
            self._write_log(f"[RAW] {label}\n{output}")
        except Exception:
            self._write_log(f"[RAW] {label}\n{data}")

    # =========================================================================
    # JSON-RPC Communication Logging
    # =========================================================================

    def log_jsonrpc_request(
        self,
        method: str,
        request_id: int | None,
        params: dict[str, Any] | None = None,
        context: str = "",
    ) -> None:
        """Log a JSON-RPC request being sent to the agent.

        Args:
            method: The JSON-RPC method name
            request_id: The request ID (None for notifications)
            params: The request parameters
            context: Additional context (e.g., "local", "k8s")
        """
        if not self._enabled or not self._logger:
            return

        try:
            req_type = "REQUEST" if request_id is not None else "NOTIFICATION"
            ctx_prefix = f"[{context}] " if context else ""
            params_str = json.dumps(params, indent=2, default=str) if params else "{}"
            id_str = f" id={request_id}" if request_id is not None else ""
            self._write_log(
                f"{ctx_prefix}[JSONRPC-OUT] {req_type} {method}{id_str}\n{params_str}"
            )
        except Exception as e:
            self._write_log(f"[JSONRPC-OUT] {method} (logging error: {e})")

    def log_jsonrpc_response(
        self,
        request_id: int | None,
        result: dict[str, Any] | None = None,
        error: dict[str, Any] | None = None,
        context: str = "",
    ) -> None:
        """Log a JSON-RPC response received from the agent.
        Args:
            request_id: The request ID this is responding to
            result: The result payload (if success)
            error: The error payload (if error)
            context: Additional context (e.g., "local", "k8s")
        """
        if not self._enabled or not self._logger:
            return

        try:
            ctx_prefix = f"[{context}] " if context else ""
            id_str = f" id={request_id}" if request_id is not None else ""
            if error:
                error_str = json.dumps(error, indent=2, default=str)
                self._write_log(
                    f"{ctx_prefix}[JSONRPC-IN] RESPONSE{id_str} ERROR\n{error_str}"
                )
            else:
                result_str = (
                    json.dumps(result, indent=2, default=str) if result else "{}"
                )
                self._write_log(
                    f"{ctx_prefix}[JSONRPC-IN] RESPONSE{id_str}\n{result_str}"
                )
        except Exception as e:
            self._write_log(f"[JSONRPC-IN] RESPONSE (logging error: {e})")

    def log_jsonrpc_notification(
        self,
        method: str,
        params: dict[str, Any] | None = None,
        context: str = "",
    ) -> None:
        """Log a JSON-RPC notification received from the agent.

        Args:
            method: The notification method name
            params: The notification parameters
            context: Additional context (e.g., "local", "k8s")
        """
        if not self._enabled or not self._logger:
            return

        try:
            ctx_prefix = f"[{context}] " if context else ""
            params_str = json.dumps(params, indent=2, default=str) if params else "{}"
            self._write_log(
                f"{ctx_prefix}[JSONRPC-IN] NOTIFICATION {method}\n{params_str}"
            )
        except Exception as e:
            self._write_log(f"[JSONRPC-IN] NOTIFICATION {method} (logging error: {e})")

    def log_jsonrpc_raw_message(
        self,
        direction: str,
        message: dict[str, Any] | str,
        context: str = "",
    ) -> None:
        """Log a raw JSON-RPC message (for debugging parsing issues).
Args: direction: "IN" or "OUT" message: The raw message (dict or string) context: Additional context """ if not self._enabled or not self._logger: return try: ctx_prefix = f"[{context}] " if context else "" if isinstance(message, dict): msg_str = json.dumps(message, indent=2, default=str) else: msg_str = str(message) self._write_log(f"{ctx_prefix}[JSONRPC-RAW-{direction}]\n{msg_str}") except Exception as e: self._write_log(f"[JSONRPC-RAW-{direction}] (logging error: {e})") # ========================================================================= # ACP Event Logging # ========================================================================= def log_acp_event( self, event_type: str, event_data: dict[str, Any], sandbox_id: UUID | str | None = None, session_id: UUID | str | None = None, ) -> None: """Log an ACP event being emitted. Args: event_type: The ACP event type (e.g., "agent_message_chunk") event_data: The full event data sandbox_id: The sandbox ID (optional, for context) session_id: The session ID (optional, for context) """ if not self._enabled or not self._logger: return try: ctx_parts = [] if sandbox_id: ctx_parts.append(f"sandbox={self._format_uuid(sandbox_id)}") if session_id: ctx_parts.append(f"session={self._format_uuid(session_id)}") ctx = f" ({', '.join(ctx_parts)})" if ctx_parts else "" # For message chunks, show truncated content for readability display_data = event_data.copy() if event_type in ("agent_message_chunk", "agent_thought_chunk"): content = display_data.get("content", {}) if isinstance(content, dict) and "text" in content: text = content.get("text", "") if len(text) > 200: display_data["content"] = { **content, "text": text[:200] + f"... 
({len(text)} chars total)", } event_str = json.dumps(display_data, indent=2, default=str) self._write_log(f"[ACP-EVENT] {event_type}{ctx}\n{event_str}") except Exception as e: self._write_log(f"[ACP-EVENT] {event_type} (logging error: {e})") def log_acp_event_yielded( self, event_type: str, event_obj: Any, sandbox_id: UUID | str | None = None, session_id: UUID | str | None = None, ) -> None: """Log an ACP event object being yielded from the generator. Args: event_type: The ACP event type event_obj: The Pydantic event object sandbox_id: The sandbox ID (optional) session_id: The session ID (optional) """ if not self._enabled or not self._logger: return try: if hasattr(event_obj, "model_dump"): event_data = event_obj.model_dump(mode="json", by_alias=True) else: event_data = {"raw": str(event_obj)} self.log_acp_event(event_type, event_data, sandbox_id, session_id) except Exception as e: self._write_log(f"[ACP-EVENT] {event_type} (logging error: {e})") # ========================================================================= # Session and Sandbox Lifecycle Logging # ========================================================================= def log_session_start( self, session_id: UUID | str, sandbox_id: UUID | str, message_preview: str = "", ) -> None: """Log the start of a message streaming session. Args: session_id: The session ID sandbox_id: The sandbox ID message_preview: First 100 chars of the user message """ if not self._enabled or not self._logger: return session_key = str(session_id) self._session_start_times[session_key] = time.time() preview = ( message_preview[:100] + "..." if len(message_preview) > 100 else message_preview ) self._write_log( f"[SESSION-START] session={self._format_uuid(session_id)} " f"sandbox={self._format_uuid(sandbox_id)}\n" f" message: {preview}" ) def log_session_end( self, session_id: UUID | str, success: bool = True, error: str | None = None, events_count: int = 0, ) -> None: """Log the end of a message streaming session. 
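`log_acp_event` shortens `agent_message_chunk` and `agent_thought_chunk` text longer than 200 characters. `log_session_start` cuts message previews at 100 characters. Below is a minimal sketch of both rules as standalone functions. `truncate_chunk_text` and `preview` are hypothetical names.

```python
from typing import Any


def truncate_chunk_text(content: dict[str, Any], limit: int = 200) -> dict[str, Any]:
    # Mirrors log_acp_event: long "text" is replaced in a new dict,
    # so the caller's event data is never mutated.
    text = content.get("text", "")
    if len(text) <= limit:
        return content
    return {**content, "text": text[:limit] + f"... ({len(text)} chars total)"}


def preview(message: str, limit: int = 100) -> str:
    # Mirrors log_session_start: append "..." only when the message was cut.
    return message[:limit] + "..." if len(message) > limit else message
```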
Args: session_id: The session ID success: Whether the session completed successfully error: Error message if failed events_count: Number of events emitted """ if not self._enabled or not self._logger: return session_key = str(session_id) start_time = self._session_start_times.pop(session_key, None) duration_ms = (time.time() - start_time) * 1000 if start_time else 0 status = "SUCCESS" if success else "FAILED" error_str = f"\n error: {error}" if error else "" self._write_log( f"[SESSION-END] session={self._format_uuid(session_id)} " f"status={status} duration={duration_ms:.0f}ms events={events_count}" f"{error_str}" ) def log_acp_client_start( self, sandbox_id: UUID | str, session_id: UUID | str, cwd: str, context: str = "", ) -> None: """Log ACP client initialization. Args: sandbox_id: The sandbox ID session_id: The session ID cwd: Working directory context: "local" or "k8s" """ if not self._enabled or not self._logger: return ctx_prefix = f"[{context}] " if context else "" self._write_log( f"{ctx_prefix}[ACP-CLIENT-START] " f"sandbox={self._format_uuid(sandbox_id)} " f"session={self._format_uuid(session_id)}\n" f" cwd: {cwd}" ) def log_acp_client_stop( self, sandbox_id: UUID | str, session_id: UUID | str, context: str = "", ) -> None: """Log ACP client shutdown. Args: sandbox_id: The sandbox ID session_id: The session ID context: "local" or "k8s" """ if not self._enabled or not self._logger: return ctx_prefix = f"[{context}] " if context else "" self._write_log( f"{ctx_prefix}[ACP-CLIENT-STOP] sandbox={self._format_uuid(sandbox_id)} session={self._format_uuid(session_id)}" ) # ========================================================================= # Streaming State Logging # ========================================================================= def log_streaming_state_update( self, session_id: UUID | str, state_type: str, details: dict[str, Any] | None = None, ) -> None: """Log streaming state changes. 
Args: session_id: The session ID state_type: Type of state change (e.g., "chunk_accumulated", "saved_to_db") details: Additional details """ if not self._enabled or not self._logger: return try: details_str = "" if details: details_str = "\n" + json.dumps(details, indent=2, default=str) self._write_log( f"[STREAMING-STATE] session={self._format_uuid(session_id)} type={state_type}{details_str}" ) except Exception as e: self._write_log(f"[STREAMING-STATE] {state_type} (logging error: {e})") def log_sse_emit( self, event_type: str, session_id: UUID | str | None = None, ) -> None: """Log an SSE event being emitted to the frontend. Args: event_type: The event type being emitted session_id: The session ID """ if not self._enabled or not self._logger: return session_str = f" session={self._format_uuid(session_id)}" if session_id else "" self._write_log(f"[SSE-EMIT] {event_type}{session_str}") # Singleton instance _packet_logger: PacketLogger | None = None def get_packet_logger() -> PacketLogger: """Get the singleton packet logger instance.""" global _packet_logger if _packet_logger is None: _packet_logger = PacketLogger() return _packet_logger def log_separator(label: str = "") -> None: """Log a visual separator for readability in the log file. Args: label: Optional label for the separator """ logger = get_packet_logger() if not logger.is_enabled or not logger._logger: return separator = "=" * 80 if label: logger._write_log(f"\n{separator}\n{label}\n{separator}") else: logger._write_log(f"\n{separator}") ================================================ FILE: backend/onyx/server/features/build/api/packets.py ================================================ """Build Mode packet types for streaming agent responses. This module defines CUSTOM Onyx packet types that extend ACP (Agent Client Protocol). ACP events are passed through directly from the agent - this module only contains Onyx-specific extensions (currently a single error packet).
All packets use SSE (Server-Sent Events) format with `event: message` and include a `type` field to distinguish packet types. ACP events (passed through directly from acp.schema): - agent_message_chunk: Text/image content from agent - agent_thought_chunk: Agent's internal reasoning - tool_call_start: Tool invocation started - tool_call_progress: Tool execution progress/result - agent_plan_update: Agent's execution plan - current_mode_update: Agent mode change - prompt_response: Agent finished processing - error: An error occurred Custom Onyx packets (defined here): - error: Onyx-specific errors (e.g., session not found) Based on: - Agent Client Protocol (ACP): https://agentclientprotocol.com """ from datetime import datetime from datetime import timezone from typing import Any from typing import Literal from pydantic import BaseModel from pydantic import Field # ============================================================================= # Base Packet Type # ============================================================================= class BasePacket(BaseModel): """Base packet with common fields for all custom Onyx packet types.""" type: str timestamp: str = Field( default_factory=lambda: datetime.now(tz=timezone.utc).isoformat() ) # ============================================================================= # Custom Onyx Packets # ============================================================================= class ErrorPacket(BasePacket): """An Onyx-specific error occurred (e.g., session not found, sandbox not running).""" type: Literal["error"] = "error" message: str code: int | None = None details: dict[str, Any] | None = None # ============================================================================= # Union Type for Custom Onyx Packets # ============================================================================= BuildPacket = ErrorPacket ================================================ FILE: backend/onyx/server/features/build/api/rate_limit.py 
================================================ """Rate limiting logic for Build Mode.""" from datetime import datetime from datetime import timedelta from datetime import timezone from typing import Literal from sqlalchemy.orm import Session from onyx.db.models import User from onyx.feature_flags.factory import get_default_feature_flag_provider from onyx.server.features.build.api.models import RateLimitResponse from onyx.server.features.build.api.subscription_check import is_user_subscribed from onyx.server.features.build.configs import CRAFT_PAID_USER_RATE_LIMIT from onyx.server.features.build.db.rate_limit import count_user_messages_in_window from onyx.server.features.build.db.rate_limit import count_user_messages_total from onyx.server.features.build.db.rate_limit import get_oldest_message_timestamp from onyx.server.features.build.utils import CRAFT_HAS_USAGE_LIMITS from shared_configs.configs import MULTI_TENANT # Default limit for free/non-subscribed users (not configurable) FREE_USER_RATE_LIMIT = 5 def _should_skip_rate_limiting(user: User) -> bool: """ Check if rate limiting should be skipped for this user. Currently grants unlimited usage to dev tenant users (tenant_dev). Controlled via PostHog feature flag. Returns: True to skip rate limiting (unlimited), False to apply normal limits """ # NOTE: We can modify the posthog flag to return more detail about a limit # i.e. can set variable limits per user and tenant via PostHog instead of env vars # to avoid re-deploying on every limit change feature_flag_provider = get_default_feature_flag_provider() # Flag returns True for users who SHOULD be rate limited # We negate to get: True = skip rate limiting has_rate_limit = feature_flag_provider.feature_enabled( CRAFT_HAS_USAGE_LIMITS, user.id, ) return not has_rate_limit def get_user_rate_limit_status( user: User, db_session: Session, ) -> RateLimitResponse: """ Get the rate limit status for a user. 
Rate limits: - Cloud (MULTI_TENANT=true): - Subscribed users: CRAFT_PAID_USER_RATE_LIMIT messages per week (configurable, default 25) - Non-subscribed users: 5 messages (lifetime total) - Per-user overrides via PostHog feature flag - Self-hosted (MULTI_TENANT=false): - Unlimited (no rate limiting) Args: user: The authenticated user db_session: Database session Returns: RateLimitResponse with current limit status """ # Self-hosted deployments have no rate limits if not MULTI_TENANT: return RateLimitResponse( is_limited=False, limit_type="weekly", messages_used=0, limit=0, # 0 indicates unlimited reset_timestamp=None, ) # Check if user should skip rate limiting (e.g., dev tenant users) if _should_skip_rate_limiting(user): return RateLimitResponse( is_limited=False, limit_type="weekly", messages_used=-1, limit=0, # 0 indicates unlimited reset_timestamp=None, ) # Determine subscription status is_subscribed = is_user_subscribed(user, db_session) # Get limit based on subscription status limit = CRAFT_PAID_USER_RATE_LIMIT if is_subscribed else FREE_USER_RATE_LIMIT # Limit type: weekly for subscribed users, total for free limit_type: Literal["weekly", "total"] = "weekly" if is_subscribed else "total" # Count messages if limit_type == "weekly": # Subscribed: rolling 7-day window cutoff_time = datetime.now(tz=timezone.utc) - timedelta(days=7) messages_used = count_user_messages_in_window(user.id, cutoff_time, db_session) # Calculate reset timestamp (when oldest message ages out) # Only show reset time if user is at or over the limit if messages_used >= limit: oldest_msg = get_oldest_message_timestamp(user.id, cutoff_time, db_session) if oldest_msg: reset_time = oldest_msg + timedelta(days=7) reset_timestamp = reset_time.isoformat() else: reset_timestamp = None else: reset_timestamp = None else: # Non-subscribed: lifetime total messages_used = count_user_messages_total(user.id, db_session) reset_timestamp = None return RateLimitResponse( is_limited=messages_used >= limit, 
limit_type=limit_type, messages_used=messages_used, limit=limit, reset_timestamp=reset_timestamp, ) ================================================ FILE: backend/onyx/server/features/build/api/sessions_api.py ================================================ """API endpoints for Build Mode session management.""" from uuid import UUID from fastapi import APIRouter from fastapi import Depends from fastapi import File from fastapi import HTTPException from fastapi import Response from fastapi import UploadFile from sqlalchemy import exists from sqlalchemy.orm import Session from onyx.auth.users import current_user from onyx.db.engine.sql_engine import get_session from onyx.db.enums import BuildSessionStatus from onyx.db.enums import SandboxStatus from onyx.db.models import BuildMessage from onyx.db.models import User from onyx.redis.redis_pool import get_redis_client from onyx.server.features.build.api.models import ArtifactResponse from onyx.server.features.build.api.models import DetailedSessionResponse from onyx.server.features.build.api.models import DirectoryListing from onyx.server.features.build.api.models import GenerateSuggestionsRequest from onyx.server.features.build.api.models import GenerateSuggestionsResponse from onyx.server.features.build.api.models import PptxPreviewResponse from onyx.server.features.build.api.models import PreProvisionedCheckResponse from onyx.server.features.build.api.models import SessionCreateRequest from onyx.server.features.build.api.models import SessionListResponse from onyx.server.features.build.api.models import SessionNameGenerateResponse from onyx.server.features.build.api.models import SessionResponse from onyx.server.features.build.api.models import SessionUpdateRequest from onyx.server.features.build.api.models import SetSessionSharingRequest from onyx.server.features.build.api.models import SetSessionSharingResponse from onyx.server.features.build.api.models import SuggestionBubble from 
onyx.server.features.build.api.models import SuggestionTheme from onyx.server.features.build.api.models import UploadResponse from onyx.server.features.build.api.models import WebappInfo from onyx.server.features.build.configs import SANDBOX_BACKEND from onyx.server.features.build.configs import SandboxBackend from onyx.server.features.build.db.build_session import allocate_nextjs_port from onyx.server.features.build.db.build_session import get_build_session from onyx.server.features.build.db.build_session import set_build_session_sharing_scope from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat from onyx.server.features.build.db.sandbox import update_sandbox_status__no_commit from onyx.server.features.build.sandbox import get_sandbox_manager from onyx.server.features.build.session.manager import SessionManager from onyx.server.features.build.session.manager import UploadLimitExceededError from onyx.server.features.build.utils import sanitize_filename from onyx.server.features.build.utils import validate_file from onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/sessions") # ============================================================================= # Session Management Endpoints # ============================================================================= @router.get("", response_model=SessionListResponse) def list_sessions( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> SessionListResponse: """List all build sessions for the current user.""" session_manager = SessionManager(db_session) sessions = session_manager.list_sessions(user.id) # Get the user's sandbox (shared across all sessions) sandbox = get_sandbox_by_user_id(db_session, user.id) return 
SessionListResponse( sessions=[SessionResponse.from_model(session, sandbox) for session in sessions] ) # Lock timeout for session creation (should be longer than max provision time) SESSION_CREATE_LOCK_TIMEOUT_SECONDS = 300 @router.post("", response_model=DetailedSessionResponse) def create_session( request: SessionCreateRequest, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> DetailedSessionResponse: """ Create or get an existing empty build session. Creates a sandbox with the necessary file structure and returns a session ID. Uses SessionManager for session and sandbox provisioning. This endpoint is atomic - if sandbox provisioning fails, no database records are created (transaction is rolled back). Uses Redis lock to prevent race conditions when multiple requests try to create/provision a session for the same user concurrently. """ tenant_id = get_current_tenant_id() redis_client = get_redis_client(tenant_id=tenant_id) # Lock on user_id to prevent concurrent session creation for the same user # This prevents race conditions where two requests both see sandbox as SLEEPING # and both try to provision, with one deleting the other's work lock_key = f"session_create:{user.id}" lock = redis_client.lock(lock_key, timeout=SESSION_CREATE_LOCK_TIMEOUT_SECONDS) # blocking=True means wait if another create is in progress acquired = lock.acquire( blocking=True, blocking_timeout=SESSION_CREATE_LOCK_TIMEOUT_SECONDS ) if not acquired: raise HTTPException( status_code=503, detail="Session creation timed out waiting for lock", ) try: session_manager = SessionManager(db_session) build_session = session_manager.get_or_create_empty_session( user.id, user_work_area=( request.user_work_area if request.demo_data_enabled else None ), user_level=request.user_level if request.demo_data_enabled else None, llm_provider_type=request.llm_provider_type, llm_model_name=request.llm_model_name, demo_data_enabled=request.demo_data_enabled, ) db_session.commit() 
sandbox = get_sandbox_by_user_id(db_session, user.id) base_response = SessionResponse.from_model(build_session, sandbox) return DetailedSessionResponse.from_session_response( base_response, session_loaded_in_sandbox=True ) except ValueError as e: logger.exception("Session creation failed") db_session.rollback() raise HTTPException(status_code=429, detail=str(e)) except Exception as e: db_session.rollback() logger.error(f"Session creation failed: {e}") raise HTTPException(status_code=500, detail=f"Session creation failed: {e}") finally: if lock.owned(): lock.release() @router.get("/{session_id}", response_model=DetailedSessionResponse) def get_session_details( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> DetailedSessionResponse: """ Get details of a specific build session. Returns session_loaded_in_sandbox to indicate if the session workspace exists in the running sandbox. """ session_manager = SessionManager(db_session) session = session_manager.get_session(session_id, user.id) if session is None: raise HTTPException(status_code=404, detail="Session not found") # Get the user's sandbox to include in response sandbox = get_sandbox_by_user_id(db_session, user.id) # Check if session workspace exists in the sandbox session_loaded = False if sandbox and sandbox.status == SandboxStatus.RUNNING: sandbox_manager = get_sandbox_manager() session_loaded = sandbox_manager.session_workspace_exists( sandbox.id, session_id ) base_response = SessionResponse.from_model(session, sandbox) return DetailedSessionResponse.from_session_response( base_response, session_loaded_in_sandbox=session_loaded ) @router.get( "/{session_id}/pre-provisioned-check", response_model=PreProvisionedCheckResponse ) def check_pre_provisioned_session( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> PreProvisionedCheckResponse: """ Check if a pre-provisioned session is still valid (empty). 
Used by the frontend to poll and detect when another tab has used the session. A session is considered valid if it has no messages yet. Returns: - valid=True, session_id=str(session_id) if the session is still empty - valid=False, session_id=None if the session has messages or doesn't exist """ session = get_build_session(session_id, user.id, db_session) if session is None: return PreProvisionedCheckResponse(valid=False, session_id=None) # Check if session is still empty (no messages = pre-provisioned) has_messages = db_session.query( exists().where(BuildMessage.session_id == session_id) ).scalar() if not has_messages: return PreProvisionedCheckResponse(valid=True, session_id=str(session_id)) # Session has messages - it's no longer a valid pre-provisioned session return PreProvisionedCheckResponse(valid=False, session_id=None) @router.post("/{session_id}/generate-name", response_model=SessionNameGenerateResponse) def generate_session_name( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> SessionNameGenerateResponse: """Generate a session name using an LLM based on the first user message.""" session_manager = SessionManager(db_session) generated_name = session_manager.generate_session_name(session_id, user.id) if generated_name is None: raise HTTPException(status_code=404, detail="Session not found") return SessionNameGenerateResponse(name=generated_name) @router.post( "/{session_id}/generate-suggestions", response_model=GenerateSuggestionsResponse ) def generate_suggestions( session_id: UUID, request: GenerateSuggestionsRequest, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> GenerateSuggestionsResponse: """Generate follow-up suggestions based on the first exchange in a session.""" session_manager = SessionManager(db_session) # Verify session exists and belongs to user session = session_manager.get_session(session_id, user.id) if session is None: raise HTTPException(status_code=404,
detail="Session not found") # Generate suggestions suggestions_data = session_manager.generate_followup_suggestions( user_message=request.user_message, assistant_message=request.assistant_message, ) # Convert to response model suggestions = [ SuggestionBubble( theme=SuggestionTheme(item["theme"]), text=item["text"], ) for item in suggestions_data ] return GenerateSuggestionsResponse(suggestions=suggestions) @router.put("/{session_id}/name", response_model=SessionResponse) def update_session_name( session_id: UUID, request: SessionUpdateRequest, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> SessionResponse: """Update the name of a build session.""" session_manager = SessionManager(db_session) session = session_manager.update_session_name(session_id, user.id, request.name) if session is None: raise HTTPException(status_code=404, detail="Session not found") # Get the user's sandbox to include in response sandbox = get_sandbox_by_user_id(db_session, user.id) return SessionResponse.from_model(session, sandbox) @router.patch("/{session_id}/public") def set_session_public( session_id: UUID, request: SetSessionSharingRequest, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> SetSessionSharingResponse: """Set the sharing scope of a build session's webapp.""" updated = set_build_session_sharing_scope( session_id, user.id, request.sharing_scope, db_session ) if not updated: raise HTTPException(status_code=404, detail="Session not found") return SetSessionSharingResponse( session_id=str(session_id), sharing_scope=updated.sharing_scope, ) @router.delete("/{session_id}", response_model=None) def delete_session( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: """Delete a build session and all associated data. This endpoint is atomic - if sandbox termination fails, the session is NOT deleted (transaction is rolled back). 
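This commit-or-rollback pattern can be shown with the standard library's `sqlite3` standing in for SQLAlchemy. The table, `delete_session_atomically`, and the `terminate_sandbox` callback are all hypothetical. The point is that the DELETE only becomes durable if the external cleanup succeeds before `commit()`.

```python
import sqlite3
from collections.abc import Callable


def delete_session_atomically(
    conn: sqlite3.Connection, session_id: str, terminate_sandbox: Callable[[str], None]
) -> bool:
    try:
        cur = conn.execute("DELETE FROM build_session WHERE id = ?", (session_id,))
        if cur.rowcount == 0:
            conn.rollback()
            return False  # the endpoint maps this case to a 404
        terminate_sandbox(session_id)  # may raise; nothing is committed yet
        conn.commit()
        return True
    except Exception:
        # Cleanup failed: undo the DELETE so the session row is preserved.
        conn.rollback()
        raise
```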
""" session_manager = SessionManager(db_session) try: success = session_manager.delete_session(session_id, user.id) if not success: raise HTTPException(status_code=404, detail="Session not found") db_session.commit() except HTTPException: # Re-raise HTTP exceptions (like 404) without rollback raise except Exception as e: # Sandbox termination failed - rollback to preserve session db_session.rollback() logger.error(f"Failed to delete session {session_id}: {e}") raise HTTPException( status_code=500, detail=f"Failed to delete session: {e}", ) return Response(status_code=204) # Lock timeout should be longer than max restore time (5 minutes) RESTORE_LOCK_TIMEOUT_SECONDS = 300 @router.post("/{session_id}/restore", response_model=DetailedSessionResponse) def restore_session( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> DetailedSessionResponse: """Restore sandbox and load session snapshot. Blocks until complete. Uses Redis lock to ensure only one restore runs per sandbox at a time. If another restore is in progress, waits for it to complete. Handles two cases: 1. Sandbox is SLEEPING: Re-provision pod, then load session snapshot 2. Sandbox is RUNNING but session not loaded: Just load session snapshot Returns immediately if session workspace already exists in pod. Always returns session_loaded_in_sandbox=True on success. 
""" session = get_build_session(session_id, user.id, db_session) if not session: raise HTTPException(status_code=404, detail="Session not found") sandbox = get_sandbox_by_user_id(db_session, user.id) if not sandbox: raise HTTPException(status_code=404, detail="Sandbox not found") # If sandbox is already running, check if session workspace exists sandbox_manager = get_sandbox_manager() tenant_id = get_current_tenant_id() # Need to do some work - acquire Redis lock redis_client = get_redis_client(tenant_id=tenant_id) lock_key = f"sandbox_restore:{sandbox.id}" lock = redis_client.lock(lock_key, timeout=RESTORE_LOCK_TIMEOUT_SECONDS) # Non-blocking: if another restore is already running, return 409 immediately # instead of making the user wait. The frontend will retry. acquired = lock.acquire(blocking=False) if not acquired: raise HTTPException( status_code=409, detail="Restore already in progress", ) try: # Re-fetch sandbox status (may have changed while waiting for lock) db_session.refresh(sandbox) # Also re-check if session workspace exists (another request may have # restored it while we were waiting) if sandbox.status == SandboxStatus.RUNNING: is_healthy = sandbox_manager.health_check(sandbox.id, timeout=10.0) if is_healthy and sandbox_manager.session_workspace_exists( sandbox.id, session_id ): session.status = BuildSessionStatus.ACTIVE update_sandbox_heartbeat(db_session, sandbox.id) base_response = SessionResponse.from_model(session, sandbox) return DetailedSessionResponse.from_session_response( base_response, session_loaded_in_sandbox=True ) if not is_healthy: logger.warning( f"Sandbox {sandbox.id} marked as RUNNING but pod is unhealthy/missing. Entering recovery mode." 
) # Terminate to clean up any lingering K8s resources sandbox_manager.terminate(sandbox.id) update_sandbox_status__no_commit( db_session, sandbox.id, SandboxStatus.TERMINATED ) db_session.commit() db_session.refresh(sandbox) # Fall through to TERMINATED handling below session_manager = SessionManager(db_session) llm_config = session_manager._get_llm_config(None, None) if sandbox.status in (SandboxStatus.SLEEPING, SandboxStatus.TERMINATED): # Mark as PROVISIONING before the long-running provision() call # so other requests know work is in progress update_sandbox_status__no_commit( db_session, sandbox.id, SandboxStatus.PROVISIONING ) db_session.commit() sandbox_manager.provision( sandbox_id=sandbox.id, user_id=user.id, tenant_id=tenant_id, llm_config=llm_config, ) # Mark as RUNNING after successful provision update_sandbox_status__no_commit( db_session, sandbox.id, SandboxStatus.RUNNING ) db_session.commit() # 2. Check if session workspace needs to be loaded if sandbox.status == SandboxStatus.RUNNING: workspace_exists = sandbox_manager.session_workspace_exists( sandbox.id, session_id ) if not workspace_exists: # Allocate port if not already set (needed for both snapshot restore and fresh setup) if not session.nextjs_port: session.nextjs_port = allocate_nextjs_port(db_session) # Commit port allocation before long-running operations db_session.commit() # Only Kubernetes backend supports snapshot restoration snapshot = None if SANDBOX_BACKEND == SandboxBackend.KUBERNETES: snapshot = get_latest_snapshot_for_session(db_session, session_id) if snapshot: try: sandbox_manager.restore_snapshot( sandbox_id=sandbox.id, session_id=session_id, snapshot_storage_path=snapshot.storage_path, tenant_id=tenant_id, nextjs_port=session.nextjs_port, llm_config=llm_config, use_demo_data=session.demo_data_enabled, ) session.status = BuildSessionStatus.ACTIVE db_session.commit() except Exception as e: logger.error( f"Snapshot restore failed for session {session_id}: {e}" ) 
session.nextjs_port = None db_session.commit() raise else: # No snapshot - set up fresh workspace sandbox_manager.setup_session_workspace( sandbox_id=sandbox.id, session_id=session_id, llm_config=llm_config, nextjs_port=session.nextjs_port, ) session.status = BuildSessionStatus.ACTIVE db_session.commit() else: logger.warning( f"Sandbox {sandbox.id} status is {sandbox.status} after re-provision, expected RUNNING" ) except Exception as e: logger.error(f"Failed to restore session {session_id}: {e}", exc_info=True) raise HTTPException( status_code=500, detail=f"Failed to restore session: {e}", ) finally: if lock.owned(): lock.release() # Update heartbeat to mark sandbox as active after successful restore update_sandbox_heartbeat(db_session, sandbox.id) base_response = SessionResponse.from_model(session, sandbox) return DetailedSessionResponse.from_session_response( base_response, session_loaded_in_sandbox=True ) # ============================================================================= # Artifact Endpoints # ============================================================================= @router.get( "/{session_id}/artifacts", response_model=list[ArtifactResponse], ) def list_artifacts( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> list[dict]: """List artifacts generated in the session.""" user_id: UUID = user.id session_manager = SessionManager(db_session) artifacts = session_manager.list_artifacts(session_id, user_id) if artifacts is None: raise HTTPException(status_code=404, detail="Session not found") return artifacts @router.get("/{session_id}/files", response_model=DirectoryListing) def list_directory( session_id: UUID, path: str = "", user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> DirectoryListing: """ List files and directories in the sandbox. 
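`SessionManager.list_directory` is not shown in this file. The endpoint only relies on it raising `ValueError` messages that contain 'path traversal', 'not found', or 'not a directory'. A common way to produce the first of these is to resolve the requested path and check that it stays under the session root. The following is a hypothetical sketch of such a check. `resolve_within` is not an Onyx function, and the sketch uses `pathlib.Path.is_relative_to` (Python 3.9+).

```python
from pathlib import Path


def resolve_within(root: Path, relative: str) -> Path:
    # Resolve ".." segments and symlinks, then require the result to stay under root.
    base = root.resolve()
    target = (base / relative).resolve()
    if not target.is_relative_to(base):
        raise ValueError(f"Access denied: path traversal detected for {relative!r}")
    if not target.exists():
        raise ValueError(f"Path not found: {relative!r}")
    if not target.is_dir():
        raise ValueError(f"Path is not a directory: {relative!r}")
    return target
```

The wording of each message is chosen so that the endpoint's substring checks map it to 403, 404, or 400 respectively.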
Args: session_id: The session ID path: Relative path from sandbox root (empty string for root) Returns: DirectoryListing with sorted entries (directories first, then files) """ user_id: UUID = user.id session_manager = SessionManager(db_session) try: listing = session_manager.list_directory(session_id, user_id, path) except ValueError as e: error_message = str(e) if "path traversal" in error_message.lower(): raise HTTPException(status_code=403, detail="Access denied") elif "not found" in error_message.lower(): raise HTTPException(status_code=404, detail="Directory not found") elif "not a directory" in error_message.lower(): raise HTTPException(status_code=400, detail="Path is not a directory") raise HTTPException(status_code=400, detail=error_message) if listing is None: raise HTTPException(status_code=404, detail="Session not found") return listing @router.get("/{session_id}/artifacts/{path:path}") def download_artifact( session_id: UUID, path: str, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: """Download a specific artifact file.""" user_id: UUID = user.id session_manager = SessionManager(db_session) try: result = session_manager.download_artifact(session_id, user_id, path) except ValueError as e: error_message = str(e) if ( "path traversal" in error_message.lower() or "access denied" in error_message.lower() ): raise HTTPException(status_code=403, detail="Access denied") elif "directory" in error_message.lower(): raise HTTPException(status_code=400, detail="Cannot download directory") raise HTTPException(status_code=400, detail=error_message) if result is None: raise HTTPException(status_code=404, detail="Artifact not found") content, mime_type, filename = result # Handle Unicode filenames in Content-Disposition header # HTTP headers require Latin-1 encoding, so we use RFC 5987 for Unicode try: # Try Latin-1 encoding first (ASCII-compatible filenames) filename.encode("latin-1") content_disposition = f'attachment; 
filename="{filename}"' except UnicodeEncodeError: # Use RFC 5987 encoding for Unicode filenames from urllib.parse import quote encoded_filename = quote(filename, safe="") content_disposition = f"attachment; filename*=UTF-8''{encoded_filename}" return Response( content=content, media_type=mime_type, headers={ "Content-Disposition": content_disposition, }, ) @router.get("/{session_id}/export-docx/{path:path}") def export_docx( session_id: UUID, path: str, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: """Export a markdown file as DOCX.""" session_manager = SessionManager(db_session) try: result = session_manager.export_docx(session_id, user.id, path) except ValueError as e: error_message = str(e) if ( "path traversal" in error_message.lower() or "access denied" in error_message.lower() ): raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=400, detail=error_message) if result is None: raise HTTPException(status_code=404, detail="File not found") docx_bytes, filename = result try: filename.encode("latin-1") content_disposition = f'attachment; filename="{filename}"' except UnicodeEncodeError: from urllib.parse import quote encoded_filename = quote(filename, safe="") content_disposition = f"attachment; filename*=UTF-8''{encoded_filename}" return Response( content=docx_bytes, media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document", headers={"Content-Disposition": content_disposition}, ) @router.get("/{session_id}/pptx-preview/{path:path}") def get_pptx_preview( session_id: UUID, path: str, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> PptxPreviewResponse: """Generate slide image previews for a PPTX file.""" session_manager = SessionManager(db_session) try: result = session_manager.get_pptx_preview(session_id, user.id, path) except ValueError as e: error_message = str(e) if ( "path traversal" in error_message.lower() 
or "access denied" in error_message.lower() ): raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=400, detail=error_message) if result is None: raise HTTPException(status_code=404, detail="Session not found") return PptxPreviewResponse(**result) @router.get("/{session_id}/webapp-info", response_model=WebappInfo) def get_webapp_info( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> WebappInfo: """ Get webapp information for a session. Returns whether a webapp exists, its URL, and the sandbox status. """ user_id: UUID = user.id session_manager = SessionManager(db_session) webapp_info = session_manager.get_webapp_info(session_id, user_id) if webapp_info is None: raise HTTPException(status_code=404, detail="Session not found") return WebappInfo(**webapp_info) @router.get("/{session_id}/webapp-download") def download_webapp( session_id: UUID, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: """ Download the webapp directory as a zip file. Returns the entire outputs/web directory as a zip archive. """ user_id: UUID = user.id session_manager = SessionManager(db_session) result = session_manager.download_webapp_zip(session_id, user_id) if result is None: raise HTTPException(status_code=404, detail="Webapp not found") zip_bytes, filename = result return Response( content=zip_bytes, media_type="application/zip", headers={ "Content-Disposition": f'attachment; filename="{filename}"', }, ) @router.get("/{session_id}/download-directory/{path:path}") def download_directory( session_id: UUID, path: str, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: """ Download a directory as a zip file. Returns the specified directory as a zip archive. 
""" user_id: UUID = user.id session_manager = SessionManager(db_session) try: result = session_manager.download_directory(session_id, user_id, path) except ValueError as e: error_message = str(e) if "path traversal" in error_message.lower(): raise HTTPException(status_code=403, detail="Access denied") raise HTTPException(status_code=400, detail=error_message) if result is None: raise HTTPException(status_code=404, detail="Directory not found") zip_bytes, filename = result return Response( content=zip_bytes, media_type="application/zip", headers={ "Content-Disposition": f'attachment; filename="{filename}"', }, ) @router.post("/{session_id}/upload", response_model=UploadResponse) def upload_file_endpoint( session_id: UUID, file: UploadFile = File(...), user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> UploadResponse: """Upload a file to the session's sandbox. The file will be placed in the sandbox's attachments directory. """ user_id: UUID = user.id session_manager = SessionManager(db_session) if not file.filename: raise HTTPException(status_code=400, detail="File has no filename") # Read file content (use sync file interface) content = file.file.read() # Validate file (extension, mime type, size) is_valid, error = validate_file(file.filename, file.content_type, len(content)) if not is_valid: raise HTTPException(status_code=400, detail=error) # Sanitize filename safe_filename = sanitize_filename(file.filename) try: relative_path, _ = session_manager.upload_file( session_id=session_id, user_id=user_id, filename=safe_filename, content=content, ) except UploadLimitExceededError as e: # Return 429 for limit exceeded errors raise HTTPException(status_code=429, detail=str(e)) except ValueError as e: error_message = str(e) if "not found" in error_message.lower(): raise HTTPException(status_code=404, detail=error_message) raise HTTPException(status_code=400, detail=error_message) return UploadResponse( filename=safe_filename, 
path=relative_path, size_bytes=len(content), ) @router.delete("/{session_id}/files/{path:path}", response_model=None) def delete_file_endpoint( session_id: UUID, path: str, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> Response: """Delete a file from the session's sandbox. Args: session_id: The session ID path: Relative path to the file (e.g., "attachments/doc.pdf") """ user_id: UUID = user.id session_manager = SessionManager(db_session) try: deleted = session_manager.delete_file(session_id, user_id, path) except ValueError as e: error_message = str(e) if "path traversal" in error_message.lower(): raise HTTPException(status_code=403, detail="Access denied") elif "not found" in error_message.lower(): raise HTTPException(status_code=404, detail=error_message) elif "directory" in error_message.lower(): raise HTTPException(status_code=400, detail="Cannot delete directory") raise HTTPException(status_code=400, detail=error_message) if not deleted: raise HTTPException(status_code=404, detail="File not found") return Response(status_code=204) ================================================ FILE: backend/onyx/server/features/build/api/subscription_check.py ================================================ """Subscription detection for Build Mode rate limiting.""" from sqlalchemy.orm import Session from onyx.configs.app_configs import DEV_MODE from onyx.db.models import User from onyx.server.usage_limits import is_tenant_on_trial_fn from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() def is_user_subscribed(user: User, db_session: Session) -> bool: # noqa: ARG001 """ Check if a user has an active subscription. 
For cloud (MULTI_TENANT=true): - Checks Stripe billing via control plane - Returns True if tenant is NOT on trial (subscribed = NOT on trial) For self-hosted (MULTI_TENANT=false): - No license check is performed here; returns True for any authenticated user Args: user: The user object (None for unauthenticated users) db_session: Database session (currently unused) Returns: True if user has active subscription, False otherwise """ if DEV_MODE: return True if user is None: return False if MULTI_TENANT: # Cloud: check Stripe billing via control plane tenant_id = get_current_tenant_id() try: on_trial = is_tenant_on_trial_fn(tenant_id) # Subscribed = NOT on trial return not on_trial except Exception as e: logger.warning(f"Subscription check failed for tenant {tenant_id}: {e}") # Default to non-subscribed (safer/more restrictive) return False return True ================================================ FILE: backend/onyx/server/features/build/api/templates/webapp_hmr_fixer.js ================================================ (function () { var WEBAPP_BASE = "__WEBAPP_BASE__"; var PROXIED_NEXT_PREFIX = WEBAPP_BASE + "/_next/"; var PROXIED_HMR_PREFIX = WEBAPP_BASE + "/_next/webpack-hmr"; var PROXIED_ALT_HMR_PREFIX = WEBAPP_BASE + "/_next/hmr"; function isHmrWebSocketUrl(url) { if (!url) return false; try { var parsedUrl = new URL(String(url), window.location.href); return ( parsedUrl.pathname.indexOf("/_next/webpack-hmr") === 0 || parsedUrl.pathname.indexOf("/_next/hmr") === 0 || parsedUrl.pathname.indexOf(PROXIED_HMR_PREFIX) === 0 || parsedUrl.pathname.indexOf(PROXIED_ALT_HMR_PREFIX) === 0 ); } catch (e) {} if (typeof url === "string") { return ( url.indexOf("/_next/webpack-hmr") === 0 || url.indexOf("/_next/hmr") === 0 || url.indexOf(PROXIED_HMR_PREFIX) === 0 || url.indexOf(PROXIED_ALT_HMR_PREFIX) === 0 ); } return false; } function rewriteNextAssetUrl(url) { if (!url) return url; try { var parsedUrl = new URL(String(url), window.location.href); if 
(parsedUrl.pathname.indexOf(PROXIED_NEXT_PREFIX)
=== 0) { return parsedUrl.pathname + parsedUrl.search + parsedUrl.hash; } if (parsedUrl.pathname.indexOf("/_next/") === 0) { return ( WEBAPP_BASE + parsedUrl.pathname + parsedUrl.search + parsedUrl.hash ); } } catch (e) {} if (typeof url === "string") { if (url.indexOf(PROXIED_NEXT_PREFIX) === 0) { return url; } if (url.indexOf("/_next/") === 0) { return WEBAPP_BASE + url; } } return url; } function createEvent(eventType) { return typeof Event === "function" ? new Event(eventType) : { type: eventType }; } function MockHmrWebSocket(url) { this.url = String(url); this.readyState = 1; this.bufferedAmount = 0; this.extensions = ""; this.protocol = ""; this.binaryType = "blob"; this.onopen = null; this.onmessage = null; this.onerror = null; this.onclose = null; this._l = {}; var socket = this; setTimeout(function () { socket._d("open", createEvent("open")); }, 0); } MockHmrWebSocket.CONNECTING = 0; MockHmrWebSocket.OPEN = 1; MockHmrWebSocket.CLOSING = 2; MockHmrWebSocket.CLOSED = 3; MockHmrWebSocket.prototype.addEventListener = function (eventType, callback) { (this._l[eventType] || (this._l[eventType] = [])).push(callback); }; MockHmrWebSocket.prototype.removeEventListener = function ( eventType, callback, ) { var listeners = this._l[eventType] || []; this._l[eventType] = listeners.filter(function (listener) { return listener !== callback; }); }; MockHmrWebSocket.prototype._d = function (eventType, eventValue) { var listeners = this._l[eventType] || []; for (var i = 0; i < listeners.length; i++) { listeners[i].call(this, eventValue); } var handler = this["on" + eventType]; if (typeof handler === "function") { handler.call(this, eventValue); } }; MockHmrWebSocket.prototype.send = function () {}; MockHmrWebSocket.prototype.close = function (code, reason) { if (this.readyState >= 2) return; this.readyState = 3; var closeEvent = createEvent("close"); closeEvent.code = code === undefined ? 
1000 : code; closeEvent.reason = reason || ""; closeEvent.wasClean = true; this._d("close", closeEvent); }; if (window.WebSocket) { var OriginalWebSocket = window.WebSocket; window.WebSocket = function (url, protocols) { if (isHmrWebSocketUrl(url)) { return new MockHmrWebSocket(rewriteNextAssetUrl(url)); } return protocols === undefined ? new OriginalWebSocket(url) : new OriginalWebSocket(url, protocols); }; window.WebSocket.prototype = OriginalWebSocket.prototype; Object.setPrototypeOf(window.WebSocket, OriginalWebSocket); ["CONNECTING", "OPEN", "CLOSING", "CLOSED"].forEach(function (stateKey) { window.WebSocket[stateKey] = OriginalWebSocket[stateKey]; }); } })(); ================================================ FILE: backend/onyx/server/features/build/api/templates/webapp_offline.html ================================================ Craft — Starting up
    crafting_table
    Sandbox is asleep...

    Ask the owner to open their Craft session to wake it up.

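The `download_artifact` and `export_docx` endpoints above repeat the same Content-Disposition logic: if the filename encodes as Latin-1 it goes into a plain quoted `filename` parameter, otherwise it is percent-encoded as UTF-8 and sent via the RFC 5987-style `filename*=UTF-8''...` parameter. The following is a minimal sketch of that logic factored into one helper; the name `build_content_disposition` is hypothetical and does not exist in the codebase, and it reproduces only the behavior visible in those two endpoints.

```python
from urllib.parse import quote


def build_content_disposition(filename: str) -> str:
    """Hypothetical helper mirroring the inline header logic in
    download_artifact / export_docx (not part of the Onyx codebase)."""
    try:
        # HTTP header values are Latin-1; if the name fits, use the plain parameter
        filename.encode("latin-1")
        return f'attachment; filename="{filename}"'
    except UnicodeEncodeError:
        # Percent-encode the UTF-8 bytes (nothing left "safe") for the extended parameter
        encoded_filename = quote(filename, safe="")
        return f"attachment; filename*=UTF-8''{encoded_filename}"


if __name__ == "__main__":
    print(build_content_disposition("report.pdf"))
    print(build_content_disposition("报告.pptx"))
```

Two observations about the original behavior that this sketch keeps unchanged: `download_webapp` and `download_directory` build their header with a plain quoted filename and no Unicode fallback, and in the Latin-1 branch a filename containing a double quote is inserted unescaped. RFC 6266 (Appendix D) also advises sending a plain `filename` fallback alongside `filename*` for older clients; like the endpoints, this sketch emits only one of the two.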
    ================================================ FILE: backend/onyx/server/features/build/api/user_library.py ================================================ """API endpoints for User Library file management in Craft. This module provides endpoints for uploading and managing raw binary files (xlsx, pptx, docx, csv, etc.) that are stored directly in S3 for sandbox access. Files are stored at: s3://{bucket}/{tenant_id}/knowledge/{user_id}/user_library/{path} And synced to sandbox at: /workspace/files/user_library/{path} Known Issues / TODOs: - Memory: Upload endpoints read entire file content into memory (up to 500MB). Should be refactored to stream uploads directly to S3 via multipart upload for better memory efficiency under concurrent load. - Transaction safety: Multi-file uploads are not atomic. If the endpoint fails mid-batch (e.g., file 3 of 5 exceeds storage quota), files 1-2 are already persisted to S3 and DB. A partial upload is not catastrophic but the response implies atomicity that doesn't exist. 
""" import hashlib import mimetypes import re import zipfile from datetime import datetime from datetime import timezone from io import BytesIO from typing import Any from fastapi import APIRouter from fastapi import Depends from fastapi import File from fastapi import Form from fastapi import HTTPException from fastapi import Query from fastapi import UploadFile from pydantic import BaseModel from sqlalchemy.orm import Session from onyx.auth.users import current_user from onyx.background.celery.versioned_apps.client import app as celery_app from onyx.configs.constants import DocumentSource from onyx.configs.constants import OnyxCeleryQueues from onyx.configs.constants import OnyxCeleryTask from onyx.db.connector_credential_pair import update_connector_credential_pair from onyx.db.document import upsert_document_by_connector_credential_pair from onyx.db.document import upsert_documents from onyx.db.engine.sql_engine import get_session from onyx.db.enums import ConnectorCredentialPairStatus from onyx.db.models import User from onyx.document_index.interfaces import DocumentMetadata from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILE_SIZE_BYTES from onyx.server.features.build.configs import USER_LIBRARY_MAX_FILES_PER_UPLOAD from onyx.server.features.build.configs import USER_LIBRARY_MAX_TOTAL_SIZE_BYTES from onyx.server.features.build.configs import USER_LIBRARY_SOURCE_DIR from onyx.server.features.build.db.user_library import get_or_create_craft_connector from onyx.server.features.build.db.user_library import get_user_storage_bytes from onyx.server.features.build.indexing.persistent_document_writer import ( get_persistent_document_writer, ) from onyx.server.features.build.indexing.persistent_document_writer import ( PersistentDocumentWriter, ) from onyx.server.features.build.indexing.persistent_document_writer import ( S3PersistentDocumentWriter, ) from onyx.server.features.build.utils import sanitize_filename as api_sanitize_filename from 
onyx.utils.logger import setup_logger from shared_configs.contextvars import get_current_tenant_id logger = setup_logger() router = APIRouter(prefix="/user-library") # ============================================================================= # Pydantic Models # ============================================================================= class LibraryEntryResponse(BaseModel): """Response for a single library entry (file or directory).""" id: str # document_id name: str path: str is_directory: bool file_size: int | None mime_type: str | None sync_enabled: bool created_at: datetime children: list["LibraryEntryResponse"] | None = None class CreateDirectoryRequest(BaseModel): """Request to create a virtual directory.""" name: str parent_path: str = "/" class UploadResponse(BaseModel): """Response after successful file upload.""" entries: list[LibraryEntryResponse] total_uploaded: int total_size_bytes: int class ToggleSyncResponse(BaseModel): """Response after toggling file sync.""" success: bool sync_enabled: bool class DeleteFileResponse(BaseModel): """Response after deleting a file.""" success: bool deleted: str # ============================================================================= # Helper Functions # ============================================================================= def _sanitize_path(path: str) -> str: """Sanitize a file path, removing traversal attempts and normalizing. Removes '..' and '.' segments and ensures the path starts with '/'. Only allows alphanumeric characters, hyphens, underscores, dots, spaces, and forward slashes. All other characters are stripped. """ parts = path.split("/") sanitized_parts: list[str] = [] for p in parts: if not p or p == ".." or p == ".": continue # Strip any character not in the whitelist cleaned = re.sub(r"[^a-zA-Z0-9\-_. 
]", "", p) if cleaned: sanitized_parts.append(cleaned) return "/" + "/".join(sanitized_parts) def _build_document_id(user_id: str, path: str) -> str: """Build a document ID for a craft file. Deterministic: re-uploading the same file to the same path will produce the same document ID, allowing upsert to overwrite the previous record. Uses a hash of the path to avoid collisions from separator replacement (e.g., "/a/b_c" vs "/a_b/c" would collide with naive slash-to-underscore). """ path_hash = hashlib.sha256(path.encode()).hexdigest()[:16] return f"CRAFT_FILE__{user_id}__{path_hash}" def _trigger_sandbox_sync( user_id: str, tenant_id: str, source: str | None = None ) -> None: """Trigger sandbox file sync task. Args: user_id: The user ID whose sandbox should be synced tenant_id: The tenant ID for S3 path construction source: Optional source type (e.g., "user_library"). If specified, only syncs that source's directory with --delete flag. """ celery_app.send_task( OnyxCeleryTask.SANDBOX_FILE_SYNC, kwargs={"user_id": user_id, "tenant_id": tenant_id, "source": source}, queue=OnyxCeleryQueues.SANDBOX, ) def _validate_zip_contents( zip_file: zipfile.ZipFile, existing_usage: int, ) -> None: """Validate zip file contents before extraction. Checks file count limit and total decompressed size against storage quota. Raises HTTPException on validation failure. """ if len(zip_file.namelist()) > USER_LIBRARY_MAX_FILES_PER_UPLOAD: raise HTTPException( status_code=400, detail=f"Zip contains too many files. Maximum is {USER_LIBRARY_MAX_FILES_PER_UPLOAD}.", ) # Zip bomb protection: check total decompressed size before extracting declared_total = sum( info.file_size for info in zip_file.infolist() if not info.is_dir() ) if existing_usage + declared_total > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES: raise HTTPException( status_code=400, detail=( f"Zip decompressed size ({declared_total // (1024 * 1024)}MB) would exceed storage limit." 
), ) def _verify_ownership_and_get_document( document_id: str, user: User, db_session: Session, ) -> Any: """Verify the user owns the document and return it. Raises HTTPException on authorization failure or if document not found. """ from onyx.db.document import get_document user_prefix = f"CRAFT_FILE__{user.id}__" if not document_id.startswith(user_prefix): raise HTTPException( status_code=403, detail="Not authorized to modify this file" ) doc = get_document(document_id, db_session) if doc is None: raise HTTPException(status_code=404, detail="File not found") return doc def _store_and_track_file( *, writer: "PersistentDocumentWriter | S3PersistentDocumentWriter", file_path: str, content: bytes, content_type: str | None, user_id: str, connector_id: int, credential_id: int, db_session: Session, ) -> tuple[str, str]: """Write a file to storage and upsert its document record. Returns: Tuple of (document_id, storage_key) """ storage_key = writer.write_raw_file( path=file_path, content=content, content_type=content_type, ) doc_id = _build_document_id(user_id, file_path) doc_metadata = DocumentMetadata( connector_id=connector_id, credential_id=credential_id, document_id=doc_id, semantic_identifier=f"{USER_LIBRARY_SOURCE_DIR}{file_path}", first_link=storage_key, doc_metadata={ "storage_key": storage_key, "file_path": file_path, "file_size": len(content), "mime_type": content_type, "is_directory": False, }, ) upsert_documents(db_session, [doc_metadata]) upsert_document_by_connector_credential_pair( db_session, connector_id, credential_id, [doc_id] ) return doc_id, storage_key # ============================================================================= # API Endpoints # ============================================================================= @router.get("/tree") def get_library_tree( user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> list[LibraryEntryResponse]: """Get user's uploaded files as a tree structure. 
Returns all CRAFT_FILE documents for the user as a flat list: each entry's path encodes its position in the tree, and children is not populated here. """ from onyx.db.document import get_documents_by_source # Get CRAFT_FILE documents for this user (filtered at SQL level) user_docs = get_documents_by_source( db_session=db_session, source=DocumentSource.CRAFT_FILE, creator_id=user.id, ) # Build a flat list of entries (hierarchy is carried by each entry's path) entries: list[LibraryEntryResponse] = [] now = datetime.now(timezone.utc) for doc in user_docs: doc_metadata = doc.doc_metadata or {} entries.append( LibraryEntryResponse( id=doc.id, name=doc.semantic_id.split("/")[-1] if doc.semantic_id else "unknown", path=doc.semantic_id or "", is_directory=doc_metadata.get("is_directory", False), file_size=doc_metadata.get("file_size"), mime_type=doc_metadata.get("mime_type"), sync_enabled=not doc_metadata.get("sync_disabled", False), created_at=doc.last_modified or now, ) ) return entries @router.post("/upload") async def upload_files( files: list[UploadFile] = File(...), path: str = Form("/"), user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> UploadResponse: """Upload files directly to S3 and track in PostgreSQL. Files are stored as raw binary (no text extraction) for access by the sandbox agent using Python libraries like openpyxl, python-pptx, etc. """ tenant_id = get_current_tenant_id() if tenant_id is None: raise HTTPException(status_code=500, detail="Tenant ID not found") # Validate file count if len(files) > USER_LIBRARY_MAX_FILES_PER_UPLOAD: raise HTTPException( status_code=400, detail=f"Too many files.
Maximum is {USER_LIBRARY_MAX_FILES_PER_UPLOAD} per upload.", ) # Check cumulative storage usage existing_usage = get_user_storage_bytes(db_session, user.id) # Get or create connector connector_id, credential_id = get_or_create_craft_connector(db_session, user) # Get the persistent document writer writer = get_persistent_document_writer( user_id=str(user.id), tenant_id=tenant_id, ) uploaded_entries: list[LibraryEntryResponse] = [] total_size = 0 now = datetime.now(timezone.utc) # Sanitize the base path base_path = _sanitize_path(path) for file in files: # TODO: Stream directly to S3 via multipart upload instead of reading # entire file into memory. With 500MB max file size, this can OOM under # concurrent uploads. content = await file.read() file_size = len(content) # Validate individual file size if file_size > USER_LIBRARY_MAX_FILE_SIZE_BYTES: raise HTTPException( status_code=400, detail=f"File '{file.filename}' exceeds maximum size of {USER_LIBRARY_MAX_FILE_SIZE_BYTES // (1024 * 1024)}MB", ) # Validate cumulative storage (existing + this upload batch) total_size += file_size if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES: raise HTTPException( status_code=400, detail=f"Total storage would exceed maximum of {USER_LIBRARY_MAX_TOTAL_SIZE_BYTES // (1024 * 1024 * 1024)}GB", ) # Sanitize filename safe_filename = api_sanitize_filename(file.filename or "unnamed") file_path = f"{base_path}/{safe_filename}".replace("//", "/") doc_id, _ = _store_and_track_file( writer=writer, file_path=file_path, content=content, content_type=file.content_type, user_id=str(user.id), connector_id=connector_id, credential_id=credential_id, db_session=db_session, ) uploaded_entries.append( LibraryEntryResponse( id=doc_id, name=safe_filename, path=file_path, is_directory=False, file_size=file_size, mime_type=file.content_type, sync_enabled=True, created_at=now, ) ) # Mark connector as having succeeded (sets last_successful_index_time) # This allows the demo data toggle to be 
disabled update_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, status=ConnectorCredentialPairStatus.ACTIVE, net_docs=len(uploaded_entries), run_dt=now, ) # Trigger sandbox sync for user_library source only _trigger_sandbox_sync(str(user.id), tenant_id, source=USER_LIBRARY_SOURCE_DIR) logger.info( f"Uploaded {len(uploaded_entries)} files ({total_size} bytes) for user {user.id}" ) return UploadResponse( entries=uploaded_entries, total_uploaded=len(uploaded_entries), total_size_bytes=total_size, ) @router.post("/upload-zip") async def upload_zip( file: UploadFile = File(...), path: str = Form("/"), user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> UploadResponse: """Upload and extract a zip file, storing each extracted file to S3. Preserves the directory structure from the zip file. """ tenant_id = get_current_tenant_id() if tenant_id is None: raise HTTPException(status_code=500, detail="Tenant ID not found") # Read zip content content = await file.read() if len(content) > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES: raise HTTPException( status_code=400, detail=f"Zip file exceeds maximum size of {USER_LIBRARY_MAX_TOTAL_SIZE_BYTES // (1024 * 1024 * 1024)}GB", ) # Check cumulative storage usage existing_usage = get_user_storage_bytes(db_session, user.id) # Get or create connector connector_id, credential_id = get_or_create_craft_connector(db_session, user) # Get the persistent document writer writer = get_persistent_document_writer( user_id=str(user.id), tenant_id=tenant_id, ) uploaded_entries: list[LibraryEntryResponse] = [] total_size = 0 # Extract zip contents into a subfolder named after the zip file zip_name = api_sanitize_filename(file.filename or "upload") if zip_name.lower().endswith(".zip"): zip_name = zip_name[:-4] folder_path = f"{_sanitize_path(path)}/{zip_name}".replace("//", "/") base_path = folder_path now = datetime.now(timezone.utc) # Track all directory paths we need to 
create records for directory_paths: set[str] = set() try: with zipfile.ZipFile(BytesIO(content), "r") as zip_file: _validate_zip_contents(zip_file, existing_usage) for zip_info in zip_file.infolist(): # Skip hidden files and __MACOSX if ( zip_info.filename.startswith("__MACOSX") or "/." in zip_info.filename ): continue # Skip directories - we'll create records from file paths below if zip_info.is_dir(): continue # Read file content file_content = zip_file.read(zip_info.filename) file_size = len(file_content) # Validate individual file size if file_size > USER_LIBRARY_MAX_FILE_SIZE_BYTES: logger.warning(f"Skipping '{zip_info.filename}' - exceeds max size") continue total_size += file_size # Validate cumulative storage if existing_usage + total_size > USER_LIBRARY_MAX_TOTAL_SIZE_BYTES: raise HTTPException( status_code=400, detail=f"Total storage would exceed maximum of {USER_LIBRARY_MAX_TOTAL_SIZE_BYTES // (1024 * 1024 * 1024)}GB", ) # Build path preserving zip structure sanitized_zip_path = _sanitize_path(zip_info.filename) file_path = f"{base_path}{sanitized_zip_path}".replace("//", "/") file_name = file_path.split("/")[-1] # Collect all intermediate directories for this file parts = file_path.split("/") for i in range( 2, len(parts) ): # start at 2 to skip empty + first segment directory_paths.add("/".join(parts[:i])) # Guess content type content_type, _ = mimetypes.guess_type(file_name) doc_id, _ = _store_and_track_file( writer=writer, file_path=file_path, content=file_content, content_type=content_type, user_id=str(user.id), connector_id=connector_id, credential_id=credential_id, db_session=db_session, ) uploaded_entries.append( LibraryEntryResponse( id=doc_id, name=file_name, path=file_path, is_directory=False, file_size=file_size, mime_type=content_type, sync_enabled=True, created_at=now, ) ) except zipfile.BadZipFile: raise HTTPException(status_code=400, detail="Invalid zip file") # Create directory document records so they appear in the tree view if 
directory_paths: dir_doc_ids: list[str] = [] for dir_path in sorted(directory_paths): dir_doc_id = _build_document_id(str(user.id), dir_path) dir_doc_ids.append(dir_doc_id) dir_metadata = DocumentMetadata( connector_id=connector_id, credential_id=credential_id, document_id=dir_doc_id, semantic_identifier=f"{USER_LIBRARY_SOURCE_DIR}{dir_path}", first_link="", doc_metadata={"is_directory": True}, ) upsert_documents(db_session, [dir_metadata]) upsert_document_by_connector_credential_pair( db_session, connector_id, credential_id, dir_doc_ids ) # Mark connector as having succeeded (sets last_successful_index_time) # This allows the demo data toggle to be disabled update_connector_credential_pair( db_session=db_session, connector_id=connector_id, credential_id=credential_id, status=ConnectorCredentialPairStatus.ACTIVE, net_docs=len(uploaded_entries), run_dt=now, ) # Trigger sandbox sync for user_library source only _trigger_sandbox_sync(str(user.id), tenant_id, source=USER_LIBRARY_SOURCE_DIR) logger.info( f"Extracted {len(uploaded_entries)} files ({total_size} bytes) from zip for user {user.id}" ) return UploadResponse( entries=uploaded_entries, total_uploaded=len(uploaded_entries), total_size_bytes=total_size, ) @router.post("/directories") def create_directory( request: CreateDirectoryRequest, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> LibraryEntryResponse: """Create a virtual directory. Directories are tracked as documents with is_directory=True. No S3 object is created (S3 doesn't have real directories). 
""" # Get or create connector connector_id, credential_id = get_or_create_craft_connector(db_session, user) # Build path parent_path = _sanitize_path(request.parent_path) safe_name = api_sanitize_filename(request.name) dir_path = f"{parent_path}/{safe_name}".replace("//", "/") # Track in document table doc_id = _build_document_id(str(user.id), dir_path) doc_metadata = DocumentMetadata( connector_id=connector_id, credential_id=credential_id, document_id=doc_id, semantic_identifier=f"{USER_LIBRARY_SOURCE_DIR}{dir_path}", first_link="", doc_metadata={ "is_directory": True, }, ) upsert_documents(db_session, [doc_metadata]) upsert_document_by_connector_credential_pair( db_session, connector_id, credential_id, [doc_id] ) db_session.commit() return LibraryEntryResponse( id=doc_id, name=safe_name, path=dir_path, is_directory=True, file_size=None, mime_type=None, sync_enabled=True, created_at=datetime.now(timezone.utc), ) @router.patch("/files/{document_id}/toggle") def toggle_file_sync( document_id: str, enabled: bool = Query(...), user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> ToggleSyncResponse: """Enable/disable syncing a file to sandboxes. When sync is disabled, the file's metadata is updated with sync_disabled=True. The sandbox sync task will exclude these files when syncing to the sandbox. If the item is a directory, all children are also toggled. 
""" from onyx.db.document import get_documents_by_source from onyx.db.document import update_document_metadata__no_commit tenant_id = get_current_tenant_id() if tenant_id is None: raise HTTPException(status_code=500, detail="Tenant ID not found") doc = _verify_ownership_and_get_document(document_id, user, db_session) # Update metadata for this document new_metadata = dict(doc.doc_metadata or {}) new_metadata["sync_disabled"] = not enabled update_document_metadata__no_commit(db_session, document_id, new_metadata) # If this is a directory, also toggle all children doc_metadata = doc.doc_metadata or {} if doc_metadata.get("is_directory"): folder_path = doc.semantic_id if folder_path: all_docs = get_documents_by_source( db_session=db_session, source=DocumentSource.CRAFT_FILE, creator_id=user.id, ) for child_doc in all_docs: if child_doc.semantic_id and child_doc.semantic_id.startswith( folder_path + "/" ): child_metadata = dict(child_doc.doc_metadata or {}) child_metadata["sync_disabled"] = not enabled update_document_metadata__no_commit( db_session, child_doc.id, child_metadata ) db_session.commit() return ToggleSyncResponse(success=True, sync_enabled=enabled) @router.delete("/files/{document_id}") def delete_file( document_id: str, user: User = Depends(current_user), db_session: Session = Depends(get_session), ) -> DeleteFileResponse: """Delete a file from both S3 and the document table.""" from onyx.db.document import delete_document_by_id__no_commit tenant_id = get_current_tenant_id() if tenant_id is None: raise HTTPException(status_code=500, detail="Tenant ID not found") doc = _verify_ownership_and_get_document(document_id, user, db_session) # Delete from storage if it's a file (not directory) doc_metadata = doc.doc_metadata or {} if not doc_metadata.get("is_directory"): file_path = doc_metadata.get("file_path") if file_path: writer = get_persistent_document_writer( user_id=str(user.id), tenant_id=tenant_id, ) try: if isinstance(writer, 
S3PersistentDocumentWriter): writer.delete_raw_file_by_path(file_path) else: writer.delete_raw_file(file_path) except Exception as e: logger.warning(f"Failed to delete file at path {file_path}: {e}") else: # Fallback for documents created before file_path was stored storage_key = doc_metadata.get("storage_key") or doc_metadata.get("s3_key") if storage_key: writer = get_persistent_document_writer( user_id=str(user.id), tenant_id=tenant_id, ) try: if isinstance(writer, S3PersistentDocumentWriter): writer.delete_raw_file(storage_key) else: logger.warning( f"Cannot delete file in local mode without file_path: {document_id}" ) except Exception as e: logger.warning( f"Failed to delete storage object {storage_key}: {e}" ) # Delete from document table delete_document_by_id__no_commit(db_session, document_id) db_session.commit() # Trigger sync to apply changes _trigger_sandbox_sync(str(user.id), tenant_id, source=USER_LIBRARY_SOURCE_DIR) return DeleteFileResponse(success=True, deleted=document_id) ================================================ FILE: backend/onyx/server/features/build/configs.py ================================================ import os from enum import Enum from pathlib import Path class SandboxBackend(str, Enum): """Backend mode for sandbox operations. 
LOCAL: Development mode - no snapshots, no automatic cleanup KUBERNETES: Production mode - full snapshots and cleanup """ LOCAL = "local" KUBERNETES = "kubernetes" # Sandbox backend mode (controls snapshot and cleanup behavior) # "local" = no snapshots, no cleanup (for development) # "kubernetes" = full snapshots and cleanup (for production) SANDBOX_BACKEND = SandboxBackend(os.environ.get("SANDBOX_BACKEND", "local")) # Base directory path for persistent document storage (local filesystem) # Example: /var/onyx/file-system or /app/file-system PERSISTENT_DOCUMENT_STORAGE_PATH = os.environ.get( "PERSISTENT_DOCUMENT_STORAGE_PATH", "/app/file-system" ) # Demo Data Path # Local: Source tree path (relative to this file) # Kubernetes: Baked into container image at /workspace/demo_data _THIS_FILE = Path(__file__) DEMO_DATA_PATH = str( _THIS_FILE.parent / "sandbox" / "kubernetes" / "docker" / "demo_data" ) # Sandbox filesystem paths SANDBOX_BASE_PATH = os.environ.get("SANDBOX_BASE_PATH", "/tmp/onyx-sandboxes") OUTPUTS_TEMPLATE_PATH = os.environ.get("OUTPUTS_TEMPLATE_PATH", "/templates/outputs") VENV_TEMPLATE_PATH = os.environ.get("VENV_TEMPLATE_PATH", "/templates/venv") # Sandbox agent configuration SANDBOX_AGENT_COMMAND = os.environ.get("SANDBOX_AGENT_COMMAND", "opencode").split() # OpenCode disabled tools (comma-separated list) # Available tools: bash, edit, write, read, grep, glob, list, lsp, patch, # skill, todowrite, todoread, webfetch, question # Example: "question,webfetch" to disable user questions and web fetching _disabled_tools_str = os.environ.get("OPENCODE_DISABLED_TOOLS", "question") OPENCODE_DISABLED_TOOLS: list[str] = [ t.strip() for t in _disabled_tools_str.split(",") if t.strip() ] # Sandbox lifecycle configuration SANDBOX_IDLE_TIMEOUT_SECONDS = int( os.environ.get("SANDBOX_IDLE_TIMEOUT_SECONDS", "3600") ) SANDBOX_MAX_CONCURRENT_PER_ORG = int( os.environ.get("SANDBOX_MAX_CONCURRENT_PER_ORG", "10") ) # Sandbox snapshot storage SANDBOX_SNAPSHOTS_BUCKET = 
os.environ.get( "SANDBOX_SNAPSHOTS_BUCKET", "sandbox-snapshots" ) # Next.js preview server port range SANDBOX_NEXTJS_PORT_START = int(os.environ.get("SANDBOX_NEXTJS_PORT_START", "3010")) SANDBOX_NEXTJS_PORT_END = int(os.environ.get("SANDBOX_NEXTJS_PORT_END", "3100")) # File upload configuration MAX_UPLOAD_FILE_SIZE_MB = int(os.environ.get("BUILD_MAX_UPLOAD_FILE_SIZE_MB", "50")) MAX_UPLOAD_FILE_SIZE_BYTES = MAX_UPLOAD_FILE_SIZE_MB * 1024 * 1024 MAX_UPLOAD_FILES_PER_SESSION = int( os.environ.get("BUILD_MAX_UPLOAD_FILES_PER_SESSION", "20") ) MAX_TOTAL_UPLOAD_SIZE_MB = int(os.environ.get("BUILD_MAX_TOTAL_UPLOAD_SIZE_MB", "200")) MAX_TOTAL_UPLOAD_SIZE_BYTES = MAX_TOTAL_UPLOAD_SIZE_MB * 1024 * 1024 ATTACHMENTS_DIRECTORY = "attachments" # ============================================================================ # Kubernetes Sandbox Configuration # Only used when SANDBOX_BACKEND = "kubernetes" # ============================================================================ # Namespace where sandbox pods are created SANDBOX_NAMESPACE = os.environ.get("SANDBOX_NAMESPACE", "onyx-sandboxes") # Container image for sandbox pods # Should include Next.js template, opencode CLI, and demo_data zip SANDBOX_CONTAINER_IMAGE = os.environ.get( "SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.5" ) # S3 bucket for sandbox file storage (snapshots, knowledge files, uploads) # Path structure: s3://{bucket}/{tenant_id}/snapshots/{session_id}/{snapshot_id}.tar.gz # s3://{bucket}/{tenant_id}/knowledge/{user_id}/ # s3://{bucket}/{tenant_id}/uploads/{session_id}/ SANDBOX_S3_BUCKET = os.environ.get("SANDBOX_S3_BUCKET", "onyx-sandbox-files") # Service account for sandbox pods (NO IRSA - no AWS API access) SANDBOX_SERVICE_ACCOUNT_NAME = os.environ.get( "SANDBOX_SERVICE_ACCOUNT_NAME", "sandbox-runner" ) # Service account for init container (has IRSA for S3 access) SANDBOX_FILE_SYNC_SERVICE_ACCOUNT = os.environ.get( "SANDBOX_FILE_SYNC_SERVICE_ACCOUNT", "sandbox-file-sync" ) ENABLE_CRAFT = 
os.environ.get("ENABLE_CRAFT", "false").lower() == "true" # ============================================================================ # SSE Streaming Configuration # ============================================================================ # SSE keepalive interval in seconds - send keepalive comment if no events SSE_KEEPALIVE_INTERVAL = float(os.environ.get("SSE_KEEPALIVE_INTERVAL", "15.0")) # ============================================================================ # ACP (Agent Communication Protocol) Configuration # ============================================================================ # Timeout for ACP message processing in seconds # This is the maximum time to wait for a complete response from the agent ACP_MESSAGE_TIMEOUT = float(os.environ.get("ACP_MESSAGE_TIMEOUT", "900.0")) # ============================================================================ # Rate Limiting Configuration # ============================================================================ # Base rate limit for paid/subscribed users (messages per week) # Free users always get 5 messages total (not configurable) # Per-user overrides are managed via PostHog feature flag "craft-has-usage-limits" CRAFT_PAID_USER_RATE_LIMIT = int(os.environ.get("CRAFT_PAID_USER_RATE_LIMIT", "25")) # ============================================================================ # User Library Configuration # For user-uploaded raw files (xlsx, pptx, docx, etc.) 
in Craft # ============================================================================ # Maximum size per file in MB (default 500MB) USER_LIBRARY_MAX_FILE_SIZE_MB = int( os.environ.get("USER_LIBRARY_MAX_FILE_SIZE_MB", "500") ) USER_LIBRARY_MAX_FILE_SIZE_BYTES = USER_LIBRARY_MAX_FILE_SIZE_MB * 1024 * 1024 # Maximum total storage per user in GB (default 10GB) USER_LIBRARY_MAX_TOTAL_SIZE_GB = int( os.environ.get("USER_LIBRARY_MAX_TOTAL_SIZE_GB", "10") ) USER_LIBRARY_MAX_TOTAL_SIZE_BYTES = USER_LIBRARY_MAX_TOTAL_SIZE_GB * 1024 * 1024 * 1024 # Maximum files per single upload request (default 100) USER_LIBRARY_MAX_FILES_PER_UPLOAD = int( os.environ.get("USER_LIBRARY_MAX_FILES_PER_UPLOAD", "100") ) # String constants for User Library entities USER_LIBRARY_CONNECTOR_NAME = "User Library" USER_LIBRARY_CREDENTIAL_NAME = "User Library Credential" USER_LIBRARY_SOURCE_DIR = "user_library" ================================================ FILE: backend/onyx/server/features/build/db/__init__.py ================================================ # Database operations for the build feature ================================================ FILE: backend/onyx/server/features/build/db/build_session.py ================================================ """Database operations for Build Mode sessions.""" from datetime import datetime from typing import Any from uuid import UUID from sqlalchemy import desc from sqlalchemy import exists from sqlalchemy import select from sqlalchemy.orm import selectinload from sqlalchemy.orm import Session from onyx.configs.constants import MessageType from onyx.db.enums import BuildSessionStatus from onyx.db.enums import SandboxStatus from onyx.db.enums import SharingScope from onyx.db.models import Artifact from onyx.db.models import BuildMessage from onyx.db.models import BuildSession from onyx.db.models import LLMProvider as LLMProviderModel from onyx.db.models import Sandbox from onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_END from 
onyx.server.features.build.configs import SANDBOX_NEXTJS_PORT_START from onyx.server.manage.llm.models import LLMProviderView from onyx.utils.logger import setup_logger logger = setup_logger() def create_build_session__no_commit( user_id: UUID, db_session: Session, name: str | None = None, demo_data_enabled: bool = True, ) -> BuildSession: """Create a new build session for the given user. NOTE: This function uses flush() instead of commit(). The caller is responsible for committing the transaction when ready. Args: user_id: The user ID db_session: Database session name: Optional session name demo_data_enabled: Whether this session uses demo data (default True) """ session = BuildSession( user_id=user_id, name=name, status=BuildSessionStatus.ACTIVE, demo_data_enabled=demo_data_enabled, ) db_session.add(session) db_session.flush() logger.info( f"Created build session {session.id} for user {user_id} (demo_data={demo_data_enabled})" ) return session def get_build_session( session_id: UUID, user_id: UUID, db_session: Session, ) -> BuildSession | None: """Get a build session by ID, ensuring it belongs to the user.""" return ( db_session.query(BuildSession) .filter( BuildSession.id == session_id, BuildSession.user_id == user_id, ) .one_or_none() ) def get_user_build_sessions( user_id: UUID, db_session: Session, limit: int = 100, ) -> list[BuildSession]: """Get all build sessions for a user that have at least one message. Excludes empty (pre-provisioned) sessions from the listing. 
""" # Subquery to check if session has any messages has_messages = exists().where(BuildMessage.session_id == BuildSession.id) return ( db_session.query(BuildSession) .filter( BuildSession.user_id == user_id, has_messages, # Only sessions with messages ) .order_by(desc(BuildSession.created_at)) .limit(limit) .all() ) def get_empty_session_for_user( user_id: UUID, db_session: Session, demo_data_enabled: bool | None = None, ) -> BuildSession | None: """Get an empty (pre-provisioned) session for the user if one exists. Returns a session with no messages, or None if all sessions have messages. Args: user_id: The user ID db_session: Database session demo_data_enabled: Match sessions with this demo_data setting. If None, matches any session regardless of setting. """ # Subquery to check if session has any messages has_messages = exists().where(BuildMessage.session_id == BuildSession.id) query = db_session.query(BuildSession).filter( BuildSession.user_id == user_id, ~has_messages, # Sessions with no messages only ) if demo_data_enabled is not None: query = query.filter(BuildSession.demo_data_enabled == demo_data_enabled) return query.first() def update_session_activity( session_id: UUID, db_session: Session, ) -> None: """Update the last activity timestamp for a session.""" session = ( db_session.query(BuildSession) .filter(BuildSession.id == session_id) .one_or_none() ) if session: session.last_activity_at = datetime.utcnow() db_session.commit() def update_session_status( session_id: UUID, status: BuildSessionStatus, db_session: Session, ) -> None: """Update the status of a build session.""" session = ( db_session.query(BuildSession) .filter(BuildSession.id == session_id) .one_or_none() ) if session: session.status = status db_session.commit() logger.info(f"Updated build session {session_id} status to {status}") def set_build_session_sharing_scope( session_id: UUID, user_id: UUID, sharing_scope: SharingScope, db_session: Session, ) -> BuildSession | None: """Set the 
sharing scope of a build session. Only the session owner can change this setting. Returns the updated session, or None if not found/unauthorized. """ session = get_build_session(session_id, user_id, db_session) if not session: return None session.sharing_scope = sharing_scope db_session.commit() logger.info(f"Set build session {session_id} sharing_scope={sharing_scope}") return session def delete_build_session__no_commit( session_id: UUID, user_id: UUID, db_session: Session, ) -> bool: """Delete a build session and all related data. NOTE: This function uses flush() instead of commit(). The caller is responsible for committing the transaction when ready. """ session = get_build_session(session_id, user_id, db_session) if not session: return False db_session.delete(session) db_session.flush() logger.info(f"Deleted build session {session_id}") return True # Sandbox operations # NOTE: Most sandbox operations have moved to sandbox.py # These remain here for convenience in session-related workflows def update_sandbox_status( sandbox_id: UUID, status: SandboxStatus, db_session: Session, container_id: str | None = None, ) -> None: """Update the status of a sandbox.""" sandbox = db_session.query(Sandbox).filter(Sandbox.id == sandbox_id).one_or_none() if sandbox: sandbox.status = status if container_id is not None: sandbox.container_id = container_id sandbox.last_heartbeat = datetime.utcnow() db_session.commit() logger.info(f"Updated sandbox {sandbox_id} status to {status}") def update_sandbox_heartbeat( sandbox_id: UUID, db_session: Session, ) -> None: """Update the heartbeat timestamp for a sandbox.""" sandbox = db_session.query(Sandbox).filter(Sandbox.id == sandbox_id).one_or_none() if sandbox: sandbox.last_heartbeat = datetime.utcnow() db_session.commit() # Artifact operations def create_artifact( session_id: UUID, artifact_type: str, path: str, name: str, db_session: Session, ) -> Artifact: """Create a new artifact record.""" artifact = Artifact( session_id=session_id, 
type=artifact_type, path=path, name=name, ) db_session.add(artifact) db_session.commit() db_session.refresh(artifact) logger.info(f"Created artifact {artifact.id} for session {session_id}") return artifact def get_session_artifacts( session_id: UUID, db_session: Session, ) -> list[Artifact]: """Get all artifacts for a session.""" return ( db_session.query(Artifact) .filter(Artifact.session_id == session_id) .order_by(desc(Artifact.created_at)) .all() ) def update_artifact( artifact_id: UUID, db_session: Session, path: str | None = None, name: str | None = None, ) -> None: """Update artifact metadata.""" artifact = ( db_session.query(Artifact).filter(Artifact.id == artifact_id).one_or_none() ) if artifact: if path is not None: artifact.path = path if name is not None: artifact.name = name artifact.updated_at = datetime.utcnow() db_session.commit() logger.info(f"Updated artifact {artifact_id}") # Message operations def create_message( session_id: UUID, message_type: MessageType, turn_index: int, message_metadata: dict[str, Any], db_session: Session, ) -> BuildMessage: """Create a new message in a build session. All message data is stored in message_metadata as JSON. Args: session_id: Session UUID message_type: Type of message (USER, ASSISTANT, SYSTEM) turn_index: 0-indexed user message number this message belongs to message_metadata: Required structured data (the raw ACP packet JSON) db_session: Database session """ message = BuildMessage( session_id=session_id, turn_index=turn_index, type=message_type, message_metadata=message_metadata, ) db_session.add(message) db_session.commit() db_session.refresh(message) logger.info( f"Created {message_type.value} message {message.id} for session {session_id} " f"turn={turn_index} type={message_metadata.get('type')}" ) return message def update_message( message_id: UUID, message_metadata: dict[str, Any], db_session: Session, ) -> BuildMessage | None: """Update an existing message's metadata. 
Used for upserting agent_plan_update messages. Args: message_id: The message UUID to update message_metadata: New metadata to set db_session: Database session Returns: Updated BuildMessage or None if not found """ message = ( db_session.query(BuildMessage).filter(BuildMessage.id == message_id).first() ) if message is None: return None message.message_metadata = message_metadata db_session.commit() db_session.refresh(message) logger.info( f"Updated message {message_id} metadata type={message_metadata.get('type')}" ) return message def upsert_agent_plan( session_id: UUID, turn_index: int, plan_metadata: dict[str, Any], db_session: Session, existing_plan_id: UUID | None = None, ) -> BuildMessage: """Upsert an agent plan - update if exists, create if not. Each session/turn should only have one agent_plan_update message. This function updates the existing plan message or creates a new one. Args: session_id: Session UUID turn_index: Current turn index plan_metadata: The agent_plan_update packet data db_session: Database session existing_plan_id: ID of existing plan message to update (if known) Returns: The created or updated BuildMessage """ if existing_plan_id: # Fast path: we know the plan ID updated = update_message(existing_plan_id, plan_metadata, db_session) if updated: return updated # Check if a plan already exists for this session/turn existing_plan = ( db_session.query(BuildMessage) .filter( BuildMessage.session_id == session_id, BuildMessage.turn_index == turn_index, BuildMessage.message_metadata["type"].astext == "agent_plan_update", ) .first() ) if existing_plan: existing_plan.message_metadata = plan_metadata db_session.commit() db_session.refresh(existing_plan) logger.info( f"Updated agent_plan_update message {existing_plan.id} for session {session_id}" ) return existing_plan # Create new plan message return create_message( session_id=session_id, message_type=MessageType.ASSISTANT, turn_index=turn_index, message_metadata=plan_metadata, db_session=db_session, 
) def get_session_messages( session_id: UUID, db_session: Session, ) -> list[BuildMessage]: """Get all messages for a session, ordered by turn index and creation time.""" return ( db_session.query(BuildMessage) .filter(BuildMessage.session_id == session_id) .order_by(BuildMessage.turn_index, BuildMessage.created_at) .all() ) def _is_port_available(port: int) -> bool: """Check if a port is available by attempting to bind to it. Checks both IPv4 and IPv6 wildcard addresses to properly detect if anything is listening on the port, regardless of address family. """ import socket logger.debug(f"Checking if port {port} is available") # Check IPv4 wildcard (0.0.0.0) - this will detect any IPv4 listener try: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) sock.bind(("0.0.0.0", port)) logger.debug(f"Port {port} IPv4 wildcard bind successful") except OSError as e: logger.debug(f"Port {port} IPv4 wildcard not available: {e}") return False # Check IPv6 wildcard (::) - this will detect any IPv6 listener try: with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as sock: sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) # IPV6_V6ONLY must be False to allow dual-stack behavior sock.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0) sock.bind(("::", port)) logger.debug(f"Port {port} IPv6 wildcard bind successful") except OSError as e: logger.debug(f"Port {port} IPv6 wildcard not available: {e}") return False logger.debug(f"Port {port} is available") return True def allocate_nextjs_port(db_session: Session) -> int: """Allocate an available port for a new session. Finds the first available port in the configured range by checking both database allocations and system-level port availability. 
Args: db_session: Database session for querying allocated ports Returns: An available port number Raises: RuntimeError: If no ports are available in the configured range """ # Collect every port currently held by any session (nextjs_port is cleared when a sandbox goes to sleep) allocated_ports: set[int] = { port for (port,) in db_session.query(BuildSession.nextjs_port) .filter(BuildSession.nextjs_port.isnot(None)) .all() if port is not None } # Find first port that's not in DB and not currently bound for port in range(SANDBOX_NEXTJS_PORT_START, SANDBOX_NEXTJS_PORT_END): if port not in allocated_ports and _is_port_available(port): return port raise RuntimeError( f"No available ports in range [{SANDBOX_NEXTJS_PORT_START}, {SANDBOX_NEXTJS_PORT_END})" ) def mark_user_sessions_idle__no_commit(db_session: Session, user_id: UUID) -> int: """Mark all ACTIVE sessions for a user as IDLE. Called when a sandbox goes to sleep so the frontend knows these sessions need restoration before they can be used again. Args: db_session: Database session user_id: The user whose sessions should be marked idle Returns: Number of sessions updated """ result = ( db_session.query(BuildSession) .filter( BuildSession.user_id == user_id, BuildSession.status == BuildSessionStatus.ACTIVE, ) .update({BuildSession.status: BuildSessionStatus.IDLE}) ) db_session.flush() logger.info(f"Marked {result} sessions as IDLE for user {user_id}") return result def clear_nextjs_ports_for_user(db_session: Session, user_id: UUID) -> int: """Clear nextjs_port for all sessions belonging to a user. Called when sandbox goes to sleep to release port allocations.
Args: db_session: Database session user_id: The user whose sessions should have ports cleared Returns: Number of sessions updated """ result = ( db_session.query(BuildSession) .filter( BuildSession.user_id == user_id, BuildSession.nextjs_port.isnot(None), ) .update({BuildSession.nextjs_port: None}) ) db_session.flush() logger.info(f"Cleared {result} nextjs_port allocations for user {user_id}") return result def fetch_llm_provider_by_type_for_build_mode( db_session: Session, provider_type: str ) -> LLMProviderView | None: """Fetch an LLM provider by its provider type (e.g., "anthropic", "openai"). Resolution priority: 1. First try to find a provider named "build-mode-{type}" (e.g., "build-mode-anthropic") 2. If not found, fall back to any provider that matches the type Args: db_session: Database session provider_type: The provider type (e.g., "anthropic", "openai", "openrouter") Returns: LLMProviderView if found, None otherwise """ from onyx.db.llm import fetch_existing_llm_provider # First try to find a "build-mode-{type}" provider build_mode_name = f"build-mode-{provider_type}" provider_model = fetch_existing_llm_provider( name=build_mode_name, db_session=db_session ) # If not found, fall back to any provider that matches the type if not provider_model: provider_model = db_session.scalar( select(LLMProviderModel) .where(LLMProviderModel.provider == provider_type) .options( selectinload(LLMProviderModel.model_configurations), selectinload(LLMProviderModel.groups), selectinload(LLMProviderModel.personas), ) ) if not provider_model: return None return LLMProviderView.from_model(provider_model) ================================================ FILE: backend/onyx/server/features/build/db/rate_limit.py ================================================ """Database queries for Build Mode rate limiting.""" from datetime import datetime from uuid import UUID from sqlalchemy import func from sqlalchemy.orm import Session from onyx.configs.constants import MessageType from 
onyx.db.models import BuildMessage from onyx.db.models import BuildSession def count_user_messages_in_window( user_id: UUID, cutoff_time: datetime, db_session: Session, ) -> int: """ Count USER messages for a user since cutoff_time. Args: user_id: The user's UUID cutoff_time: Only count messages created at or after this time db_session: Database session Returns: Number of USER messages in the time window """ return ( db_session.query(func.count(BuildMessage.id)) .join(BuildSession, BuildMessage.session_id == BuildSession.id) .filter( BuildSession.user_id == user_id, BuildMessage.type == MessageType.USER, BuildMessage.created_at >= cutoff_time, ) .scalar() or 0 ) def count_user_messages_total(user_id: UUID, db_session: Session) -> int: """ Count all USER messages for a user (lifetime total). Args: user_id: The user's UUID db_session: Database session Returns: Total number of USER messages """ return ( db_session.query(func.count(BuildMessage.id)) .join(BuildSession, BuildMessage.session_id == BuildSession.id) .filter( BuildSession.user_id == user_id, BuildMessage.type == MessageType.USER, ) .scalar() or 0 ) def get_oldest_message_timestamp( user_id: UUID, cutoff_time: datetime, db_session: Session, ) -> datetime | None: """ Get the timestamp of the oldest USER message in the time window. Used to calculate when the rate limit will reset (when the oldest message ages out of the rolling window). 
Args: user_id: The user's UUID cutoff_time: Only consider messages created at or after this time db_session: Database session Returns: Timestamp of oldest message in window, or None if no messages """ return ( db_session.query(BuildMessage.created_at) .join(BuildSession, BuildMessage.session_id == BuildSession.id) .filter( BuildSession.user_id == user_id, BuildMessage.type == MessageType.USER, BuildMessage.created_at >= cutoff_time, ) .order_by(BuildMessage.created_at.asc()) .limit(1) .scalar() ) ================================================ FILE: backend/onyx/server/features/build/db/sandbox.py ================================================ """Database operations for CLI agent sandbox management.""" import datetime from uuid import UUID from sqlalchemy import and_ from sqlalchemy import func from sqlalchemy import or_ from sqlalchemy import select from sqlalchemy.orm import Session from onyx.db.enums import SandboxStatus from onyx.db.models import Sandbox from onyx.db.models import Snapshot from onyx.utils.logger import setup_logger logger = setup_logger() def create_sandbox__no_commit( db_session: Session, user_id: UUID, ) -> Sandbox: """Create a new sandbox record for a user. Sets last_heartbeat to now so that: 1. The sandbox has a proper idle timeout baseline from creation 2. Long-running provisioning doesn't cause the sandbox to appear "old" when it transitions to RUNNING NOTE: This function uses flush() instead of commit(). The caller is responsible for committing the transaction when ready. 
""" sandbox = Sandbox( user_id=user_id, status=SandboxStatus.PROVISIONING, last_heartbeat=datetime.datetime.now(datetime.timezone.utc), ) db_session.add(sandbox) db_session.flush() return sandbox def get_sandbox_by_user_id(db_session: Session, user_id: UUID) -> Sandbox | None: """Get sandbox by user ID (primary lookup method).""" stmt = select(Sandbox).where(Sandbox.user_id == user_id) return db_session.execute(stmt).scalar_one_or_none() def get_sandbox_by_session_id(db_session: Session, session_id: UUID) -> Sandbox | None: """Get sandbox by session ID (compatibility function). This function provides backwards compatibility during the transition to user-owned sandboxes. It looks up the session's user_id, then finds the user's sandbox. NOTE: This will be removed in a future phase when all callers are updated to use get_sandbox_by_user_id() directly. """ from onyx.db.models import BuildSession stmt = select(BuildSession.user_id).where(BuildSession.id == session_id) result = db_session.execute(stmt).scalar_one_or_none() if result is None: return None return get_sandbox_by_user_id(db_session, result) def get_sandbox_by_id(db_session: Session, sandbox_id: UUID) -> Sandbox | None: """Get sandbox by its ID.""" stmt = select(Sandbox).where(Sandbox.id == sandbox_id) return db_session.execute(stmt).scalar_one_or_none() def update_sandbox_status__no_commit( db_session: Session, sandbox_id: UUID, status: SandboxStatus, ) -> Sandbox: """Update sandbox status. When transitioning to RUNNING, also sets last_heartbeat to now. This ensures newly provisioned sandboxes have a proper idle timeout baseline (rather than being immediately considered idle due to NULL heartbeat). NOTE: This function uses flush() instead of commit(). The caller is responsible for committing the transaction when ready. 
""" sandbox = get_sandbox_by_id(db_session, sandbox_id) if not sandbox: raise ValueError(f"Sandbox {sandbox_id} not found") sandbox.status = status # Set heartbeat when sandbox becomes active to establish idle timeout baseline if status == SandboxStatus.RUNNING: sandbox.last_heartbeat = datetime.datetime.now(datetime.timezone.utc) db_session.flush() return sandbox def update_sandbox_heartbeat(db_session: Session, sandbox_id: UUID) -> Sandbox: """Update sandbox last_heartbeat to now.""" sandbox = get_sandbox_by_id(db_session, sandbox_id) if not sandbox: raise ValueError(f"Sandbox {sandbox_id} not found") sandbox.last_heartbeat = datetime.datetime.now(datetime.timezone.utc) db_session.commit() return sandbox def get_idle_sandboxes( db_session: Session, idle_threshold_seconds: int ) -> list[Sandbox]: """Get sandboxes that have been idle longer than threshold. Also includes sandboxes with NULL heartbeat, but only if they were created before the threshold (to avoid sweeping up brand-new sandboxes that may have NULL heartbeat due to edge cases like older rows or manual inserts). """ threshold_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( seconds=idle_threshold_seconds ) stmt = select(Sandbox).where( Sandbox.status == SandboxStatus.RUNNING, or_( Sandbox.last_heartbeat < threshold_time, and_( Sandbox.last_heartbeat.is_(None), Sandbox.created_at < threshold_time, ), ), ) return list(db_session.execute(stmt).scalars().all()) def get_running_sandbox_count_by_tenant( db_session: Session, tenant_id: str, # noqa: ARG001 ) -> int: """Get count of running sandboxes for a tenant (for limit enforcement). Note: tenant_id parameter is kept for API compatibility but is not used since Sandbox model no longer has tenant_id. This function returns the count of all running sandboxes. 
""" stmt = select(func.count(Sandbox.id)).where(Sandbox.status == SandboxStatus.RUNNING) result = db_session.execute(stmt).scalar() return result or 0 def create_snapshot__no_commit( db_session: Session, session_id: UUID, storage_path: str, size_bytes: int, ) -> Snapshot: """Create a snapshot record for a session. NOTE: Uses flush() instead of commit(). The caller (cleanup task) is responsible for committing after all snapshots + status updates are done, so the entire operation is atomic. """ snapshot = Snapshot( session_id=session_id, storage_path=storage_path, size_bytes=size_bytes, ) db_session.add(snapshot) db_session.flush() return snapshot def get_latest_snapshot_for_session( db_session: Session, session_id: UUID ) -> Snapshot | None: """Get most recent snapshot for a session.""" stmt = ( select(Snapshot) .where(Snapshot.session_id == session_id) .order_by(Snapshot.created_at.desc()) .limit(1) ) return db_session.execute(stmt).scalar_one_or_none() def get_snapshots_for_session(db_session: Session, session_id: UUID) -> list[Snapshot]: """Get all snapshots for a session, ordered by creation time descending.""" stmt = ( select(Snapshot) .where(Snapshot.session_id == session_id) .order_by(Snapshot.created_at.desc()) ) return list(db_session.execute(stmt).scalars().all()) def delete_old_snapshots( db_session: Session, tenant_id: str, # noqa: ARG001 retention_days: int, ) -> int: """Delete snapshots older than retention period, return count deleted. Note: tenant_id parameter is kept for API compatibility but is not used since Snapshot model no longer has tenant_id. This function deletes all snapshots older than the retention period. 
""" cutoff_time = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta( days=retention_days ) stmt = select(Snapshot).where( Snapshot.created_at < cutoff_time, ) old_snapshots = db_session.execute(stmt).scalars().all() count = 0 for snapshot in old_snapshots: db_session.delete(snapshot) count += 1 if count > 0: db_session.commit() return count def delete_snapshot(db_session: Session, snapshot_id: UUID) -> bool: """Delete a specific snapshot by ID. Returns True if deleted, False if not found.""" stmt = select(Snapshot).where(Snapshot.id == snapshot_id) snapshot = db_session.execute(stmt).scalar_one_or_none() if not snapshot: return False db_session.delete(snapshot) db_session.commit() return True ================================================ FILE: backend/onyx/server/features/build/db/user_library.py ================================================ """Database operations for User Library (CRAFT_FILE connector). Handles storage quota queries and connector/credential setup for the User Library feature in Craft. 
""" from uuid import UUID from sqlalchemy import and_ from sqlalchemy import cast from sqlalchemy import func from sqlalchemy import Integer from sqlalchemy import select from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.connectors.models import InputType from onyx.db.connector import create_connector from onyx.db.connector import fetch_connectors from onyx.db.connector_credential_pair import add_credential_to_connector from onyx.db.connector_credential_pair import ( get_connector_credential_pairs_for_user, ) from onyx.db.credentials import create_credential from onyx.db.credentials import fetch_credentials_for_user from onyx.db.enums import AccessType from onyx.db.enums import ProcessingMode from onyx.db.models import Connector from onyx.db.models import ConnectorCredentialPair from onyx.db.models import Document as DbDocument from onyx.db.models import DocumentByConnectorCredentialPair from onyx.db.models import User from onyx.server.documents.models import ConnectorBase from onyx.server.documents.models import CredentialBase from onyx.server.features.build.configs import USER_LIBRARY_CONNECTOR_NAME from onyx.server.features.build.configs import USER_LIBRARY_CREDENTIAL_NAME from onyx.utils.logger import setup_logger logger = setup_logger() def get_user_storage_bytes(db_session: Session, user_id: UUID) -> int: """Get total storage usage for a user's library files. Uses SQL aggregation to sum file_size from doc_metadata JSONB for all CRAFT_FILE documents owned by this user, avoiding loading all documents into Python memory. 
""" stmt = ( select( func.coalesce( func.sum( cast( DbDocument.doc_metadata["file_size"].as_string(), Integer, ) ), 0, ) ) .join( DocumentByConnectorCredentialPair, DbDocument.id == DocumentByConnectorCredentialPair.id, ) .join( ConnectorCredentialPair, and_( DocumentByConnectorCredentialPair.connector_id == ConnectorCredentialPair.connector_id, DocumentByConnectorCredentialPair.credential_id == ConnectorCredentialPair.credential_id, ), ) .join( Connector, ConnectorCredentialPair.connector_id == Connector.id, ) .where(Connector.source == DocumentSource.CRAFT_FILE) .where(ConnectorCredentialPair.creator_id == user_id) .where(DbDocument.doc_metadata["is_directory"].as_boolean().is_not(True)) ) result = db_session.execute(stmt).scalar() return int(result or 0) def get_or_create_craft_connector(db_session: Session, user: User) -> tuple[int, int]: """Get or create the CRAFT_FILE connector for a user. Returns: Tuple of (connector_id, credential_id) Note: We need to create a credential even though CRAFT_FILE doesn't require authentication. This is because Onyx's connector-credential pair system requires a credential for all connectors. The credential is empty ({}). This function handles recovery from partial creation failures by detecting orphaned connectors (connectors without cc_pairs) and completing their setup. 
""" # Check if user already has a complete CRAFT_FILE cc_pair cc_pairs = get_connector_credential_pairs_for_user( db_session=db_session, user=user, get_editable=False, eager_load_connector=True, eager_load_credential=True, processing_mode=ProcessingMode.RAW_BINARY, ) for cc_pair in cc_pairs: if ( cc_pair.connector.source == DocumentSource.CRAFT_FILE and cc_pair.creator_id == user.id ): return cc_pair.connector.id, cc_pair.credential.id # No cc_pair for this user — find or create the shared CRAFT_FILE connector existing_connectors = fetch_connectors( db_session, sources=[DocumentSource.CRAFT_FILE] ) connector_id: int | None = None for conn in existing_connectors: if conn.name == USER_LIBRARY_CONNECTOR_NAME: connector_id = conn.id break if connector_id is None: connector_data = ConnectorBase( name=USER_LIBRARY_CONNECTOR_NAME, source=DocumentSource.CRAFT_FILE, input_type=InputType.LOAD_STATE, connector_specific_config={"disabled_paths": []}, refresh_freq=None, prune_freq=None, ) connector_response = create_connector( db_session=db_session, connector_data=connector_data, ) connector_id = connector_response.id # Try to reuse an existing User Library credential for this user existing_credentials = fetch_credentials_for_user( db_session=db_session, user=user, ) credential = None for cred in existing_credentials: if ( cred.source == DocumentSource.CRAFT_FILE and cred.name == USER_LIBRARY_CREDENTIAL_NAME ): credential = cred break if credential is None: credential_data = CredentialBase( credential_json={}, admin_public=False, source=DocumentSource.CRAFT_FILE, name=USER_LIBRARY_CREDENTIAL_NAME, ) credential = create_credential( credential_data=credential_data, user=user, db_session=db_session, ) # Link them with RAW_BINARY processing mode add_credential_to_connector( db_session=db_session, connector_id=connector_id, credential_id=credential.id, user=user, cc_pair_name=USER_LIBRARY_CONNECTOR_NAME, access_type=AccessType.PRIVATE, groups=None, 
processing_mode=ProcessingMode.RAW_BINARY, ) db_session.commit() return connector_id, credential.id ================================================ FILE: backend/onyx/server/features/build/indexing/persistent_document_writer.py ================================================ """ Persistent Document Writer for writing indexed documents to local filesystem or S3 with hierarchical directory structure that mirrors the source organization. Local mode (SandboxBackend.LOCAL): Writes to local filesystem at {PERSISTENT_DOCUMENT_STORAGE_PATH}/{tenant_id}/knowledge/{user_id}/... Kubernetes mode (SandboxBackend.KUBERNETES): Writes to S3 at s3://{SANDBOX_S3_BUCKET}/{tenant_id}/knowledge/{user_id}/... This is the same location that kubernetes_sandbox_manager.py reads from when provisioning sandboxes. Both modes use consistent tenant/user-segregated paths for multi-tenant isolation. """ import hashlib import json import unicodedata from pathlib import Path from typing import Any from botocore.exceptions import ClientError from mypy_boto3_s3.client import S3Client from onyx.connectors.models import Document from onyx.server.features.build.configs import PERSISTENT_DOCUMENT_STORAGE_PATH from onyx.server.features.build.configs import SANDBOX_BACKEND from onyx.server.features.build.configs import SANDBOX_S3_BUCKET from onyx.server.features.build.configs import SandboxBackend from onyx.server.features.build.s3.s3_client import build_s3_client from onyx.utils.logger import setup_logger logger = setup_logger() # ============================================================================= # Shared Utilities for Path Building # ============================================================================= def sanitize_path_component(component: str, replace_slash: bool = True) -> str: """Sanitize a path component for file system / S3 key safety. Args: component: The path component to sanitize replace_slash: If True, replaces forward slashes (needed for local filesystem). 
                Set to False for S3 where `/` is a valid delimiter.

    Returns:
        Sanitized path component safe for use in file paths or S3 keys
    """
    # Normalize Unicode to compatibility-decomposed form (NFKD). This folds
    # compatibility variants (e.g. full-width letters, ligatures) into their
    # plain equivalents and splits accented characters into a base character
    # plus combining marks. Note: the combining marks are NOT removed below,
    # because they are category Mn, not Cf/Cc.
    normalized = unicodedata.normalize("NFKD", component)

    # Filter out Unicode format/control characters (categories Cf, Cc).
    # This removes invisible chars like U+2060 (WORD JOINER), zero-width spaces, etc.
    sanitized = "".join(
        c for c in normalized if unicodedata.category(c) not in ("Cf", "Cc")
    )

    # Replace spaces with underscores
    sanitized = sanitized.replace(" ", "_")

    # Replace characters that are problematic in file paths or S3 keys
    if replace_slash:
        sanitized = sanitized.replace("/", "_")
    sanitized = sanitized.replace("\\", "_").replace(":", "_")
    sanitized = sanitized.replace("<", "_").replace(">", "_").replace("|", "_")
    sanitized = sanitized.replace('"', "_").replace("?", "_").replace("*", "_")

    return sanitized.strip() or "unnamed"


def sanitize_filename(name: str, replace_slash: bool = True) -> str:
    """Sanitize name for use as filename.

    Args:
        name: The filename to sanitize
        replace_slash: Passed through to sanitize_path_component

    Returns:
        Sanitized filename, truncated with hash suffix if too long
    """
    sanitized = sanitize_path_component(name, replace_slash=replace_slash)
    if len(sanitized) > 200:
        # Keep first 150 chars + hash suffix for uniqueness
        hash_suffix = hashlib.sha256(name.encode()).hexdigest()[:16]
        return f"{sanitized[:150]}_{hash_suffix}"
    return sanitized


def normalize_leading_slash(path: str) -> str:
    """Ensure a path starts with exactly one leading slash."""
    return "/" + path.lstrip("/")


def get_base_filename(doc: Document, replace_slash: bool = True) -> str:
    """Get base filename from document, preferring semantic identifier.
Args: doc: The document to get filename for replace_slash: Passed through to sanitize_filename Returns: Sanitized base filename (without extension) """ name = doc.semantic_identifier or doc.title or doc.id return sanitize_filename(name, replace_slash=replace_slash) def build_document_subpath(doc: Document, replace_slash: bool = True) -> list[str]: """Build the source/hierarchy path components from a document. Returns path components like: [source, hierarchy_part1, hierarchy_part2, ...] This is the common part of the path that comes after user/tenant segregation. Args: doc: The document to build path for replace_slash: Passed through to sanitize_path_component Returns: List of sanitized path components """ parts: list[str] = [] # Source type (e.g., "google_drive", "confluence") parts.append(doc.source.value) # Get hierarchy from doc_metadata hierarchy: dict[str, Any] = ( doc.doc_metadata.get("hierarchy", {}) if doc.doc_metadata else {} ) source_path: list[str] = hierarchy.get("source_path", []) if source_path: parts.extend( [ sanitize_path_component(p, replace_slash=replace_slash) for p in source_path ] ) return parts def resolve_duplicate_filename( doc: Document, base_filename: str, has_duplicates: bool, replace_slash: bool = True, ) -> str: """Resolve filename, appending ID suffix if there are duplicates. Args: doc: The document (for ID extraction) base_filename: The base filename without extension has_duplicates: Whether there are other docs with the same base filename replace_slash: Passed through to sanitize_path_component Returns: Final filename with .json extension """ if has_duplicates: id_suffix = sanitize_path_component(doc.id, replace_slash=replace_slash) if len(id_suffix) > 50: id_suffix = hashlib.sha256(doc.id.encode()).hexdigest()[:16] return f"{base_filename}_{id_suffix}.json" return f"{base_filename}.json" def serialize_document(doc: Document) -> dict[str, Any]: """Serialize a document to a dictionary for JSON storage. 
Args: doc: The document to serialize Returns: Dictionary representation of the document """ return { "id": doc.id, "semantic_identifier": doc.semantic_identifier, "title": doc.title, "source": doc.source.value, "doc_updated_at": ( doc.doc_updated_at.isoformat() if doc.doc_updated_at else None ), "metadata": doc.metadata, "doc_metadata": doc.doc_metadata, "sections": [ {"text": s.text if hasattr(s, "text") else None, "link": s.link} for s in doc.sections ], "primary_owners": [o.model_dump() for o in (doc.primary_owners or [])], "secondary_owners": [o.model_dump() for o in (doc.secondary_owners or [])], } # ============================================================================= # Classes # ============================================================================= class PersistentDocumentWriter: """Writes indexed documents to local filesystem with hierarchical structure. Documents are stored in tenant/user-segregated paths: {base_path}/{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/document.json This enables per-tenant and per-user isolation for sandbox access control. 
""" def __init__( self, base_path: str, tenant_id: str, user_id: str, ): self.base_path = Path(base_path) self.tenant_id = tenant_id self.user_id = user_id def write_documents(self, documents: list[Document]) -> list[str]: """Write documents to local filesystem, returns written file paths.""" written_paths: list[str] = [] # Build a map of base filenames to detect duplicates # Key: (directory_path, base_filename) -> list of docs with that name filename_map: dict[tuple[Path, str], list[Document]] = {} for doc in documents: dir_path = self._build_directory_path(doc) base_filename = get_base_filename(doc, replace_slash=True) key = (dir_path, base_filename) if key not in filename_map: filename_map[key] = [] filename_map[key].append(doc) # Now write documents, appending ID if there are duplicates for (dir_path, base_filename), docs in filename_map.items(): has_duplicates = len(docs) > 1 for doc in docs: filename = resolve_duplicate_filename( doc, base_filename, has_duplicates, replace_slash=True ) path = dir_path / filename self._write_document(doc, path) written_paths.append(str(path)) return written_paths def _build_directory_path(self, doc: Document) -> Path: """Build directory path from document metadata. Documents are stored under tenant/user-segregated paths: {base_path}/{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/ This enables per-tenant and per-user isolation for sandbox access control. 
""" # Tenant and user segregation prefix (matches S3 path structure) parts = [self.tenant_id, "knowledge", self.user_id] # Add source and hierarchy from document parts.extend(build_document_subpath(doc, replace_slash=True)) return self.base_path / "/".join(parts) def _write_document(self, doc: Document, path: Path) -> None: """Serialize and write document to filesystem.""" content = serialize_document(doc) # Create parent directories if they don't exist path.parent.mkdir(parents=True, exist_ok=True) # Write the JSON file with open(path, "w", encoding="utf-8") as f: json.dump(content, f, indent=2, default=str) logger.debug(f"Wrote document to {path}") def write_raw_file( self, path: str, content: bytes, content_type: str | None = None, # noqa: ARG002 ) -> str: """Write a raw binary file to local filesystem (for User Library). Unlike write_documents which serializes Document objects to JSON, this method writes raw binary content directly. Used for user-uploaded files like xlsx, pptx. Args: path: Relative path within user's library (e.g., "/project-data/financials.xlsx") content: Raw binary content to write content_type: MIME type of the file (stored as metadata, unused locally) Returns: Full filesystem path where file was written """ # Build full path: {base_path}/{tenant}/knowledge/{user}/user_library/{path} normalized_path = normalize_leading_slash(path) full_path = ( self.base_path / self.tenant_id / "knowledge" / self.user_id / "user_library" / normalized_path.lstrip("/") ) # Create parent directories if they don't exist full_path.parent.mkdir(parents=True, exist_ok=True) # Write the raw binary content with open(full_path, "wb") as f: f.write(content) logger.debug(f"Wrote raw file to {full_path}") return str(full_path) def delete_raw_file(self, path: str) -> None: """Delete a raw file from local filesystem. 
Args: path: Relative path within user's library (e.g., "/project-data/financials.xlsx") """ # Build full path normalized_path = normalize_leading_slash(path) full_path = ( self.base_path / self.tenant_id / "knowledge" / self.user_id / "user_library" / normalized_path.lstrip("/") ) if full_path.exists(): full_path.unlink() logger.debug(f"Deleted raw file at {full_path}") else: logger.warning(f"File not found for deletion: {full_path}") class S3PersistentDocumentWriter: """Writes indexed documents to S3 with hierarchical structure. Documents are stored in tenant/user-segregated paths: s3://{bucket}/{tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/document.json This matches the location that KubernetesSandboxManager reads from when provisioning sandboxes (via the sidecar container's s5cmd sync command). """ def __init__(self, tenant_id: str, user_id: str): """Initialize S3PersistentDocumentWriter. Args: tenant_id: Tenant identifier for multi-tenant isolation user_id: User ID for user-segregated storage paths """ self.tenant_id = tenant_id self.user_id = user_id self.bucket = SANDBOX_S3_BUCKET self._s3_client: S3Client | None = None def _get_s3_client(self) -> S3Client: """Lazily initialize S3 client. Uses the craft-specific boto3 client which only supports IAM roles (IRSA). """ if self._s3_client is None: self._s3_client = build_s3_client() return self._s3_client def write_documents(self, documents: list[Document]) -> list[str]: """Write documents to S3, returns written S3 keys. 
Args: documents: List of documents to write Returns: List of S3 keys that were written """ written_keys: list[str] = [] # Build a map of base keys to detect duplicates # Key: (directory_prefix, base_filename) -> list of docs with that name key_map: dict[tuple[str, str], list[Document]] = {} for doc in documents: dir_prefix = self._build_directory_path(doc) base_filename = get_base_filename(doc, replace_slash=False) key = (dir_prefix, base_filename) if key not in key_map: key_map[key] = [] key_map[key].append(doc) # Now write documents, appending ID if there are duplicates s3_client = self._get_s3_client() for (dir_prefix, base_filename), docs in key_map.items(): has_duplicates = len(docs) > 1 for doc in docs: filename = resolve_duplicate_filename( doc, base_filename, has_duplicates, replace_slash=False ) s3_key = f"{dir_prefix}/{filename}" self._write_document(s3_client, doc, s3_key) written_keys.append(s3_key) return written_keys def _build_directory_path(self, doc: Document) -> str: """Build S3 key prefix from document metadata. 
Documents are stored under tenant/user-segregated paths: {tenant_id}/knowledge/{user_id}/{source}/{hierarchy}/ This matches the path that KubernetesSandboxManager syncs from: s5cmd sync "s3://{bucket}/{tenant_id}/knowledge/{user_id}/*" /workspace/files/ """ # Tenant and user segregation (matches K8s sandbox init container path) parts = [self.tenant_id, "knowledge", self.user_id] # Add source and hierarchy from document parts.extend(build_document_subpath(doc, replace_slash=False)) return "/".join(parts) def _write_document(self, s3_client: S3Client, doc: Document, s3_key: str) -> None: """Serialize and write document to S3.""" content = serialize_document(doc) json_content = json.dumps(content, indent=2, default=str) try: s3_client.put_object( Bucket=self.bucket, Key=s3_key, Body=json_content.encode("utf-8"), ContentType="application/json", ) logger.debug(f"Wrote document to s3://{self.bucket}/{s3_key}") except ClientError as e: logger.error(f"Failed to write to S3: {e}") raise def write_raw_file( self, path: str, content: bytes, content_type: str | None = None, ) -> str: """Write a raw binary file to S3 (for User Library). Unlike write_documents which serializes Document objects to JSON, this method writes raw binary content directly. Used for user-uploaded files like xlsx, pptx. 
Args: path: Relative path within user's library (e.g., "/project-data/financials.xlsx") content: Raw binary content to write content_type: MIME type of the file Returns: S3 key where file was written """ # Build S3 key: {tenant}/knowledge/{user}/user_library/{path} normalized_path = path.lstrip("/") s3_key = ( f"{self.tenant_id}/knowledge/{self.user_id}/user_library/{normalized_path}" ) s3_client = self._get_s3_client() try: s3_client.put_object( Bucket=self.bucket, Key=s3_key, Body=content, ContentType=content_type or "application/octet-stream", ) logger.debug(f"Wrote raw file to s3://{self.bucket}/{s3_key}") return s3_key except ClientError as e: logger.error(f"Failed to write raw file to S3: {e}") raise def delete_raw_file(self, s3_key: str) -> None: """Delete a raw file from S3. Args: s3_key: Full S3 key of the file to delete """ s3_client = self._get_s3_client() try: s3_client.delete_object(Bucket=self.bucket, Key=s3_key) logger.debug(f"Deleted raw file at s3://{self.bucket}/{s3_key}") except ClientError as e: logger.error(f"Failed to delete raw file from S3: {e}") raise def delete_raw_file_by_path(self, path: str) -> None: """Delete a raw file from S3 by its relative path. Args: path: Relative path within user's library (e.g., "/project-data/financials.xlsx") """ normalized_path = path.lstrip("/") s3_key = ( f"{self.tenant_id}/knowledge/{self.user_id}/user_library/{normalized_path}" ) self.delete_raw_file(s3_key) def get_persistent_document_writer( user_id: str, tenant_id: str, ) -> PersistentDocumentWriter | S3PersistentDocumentWriter: """Factory function to create a PersistentDocumentWriter with default configuration. Args: user_id: User ID for user-segregated storage paths. tenant_id: Tenant ID for multi-tenant isolation. Both local and S3 modes use consistent tenant/user-segregated paths: - Local: {base_path}/{tenant_id}/knowledge/{user_id}/... - S3: s3://{bucket}/{tenant_id}/knowledge/{user_id}/... 
Returns: PersistentDocumentWriter for local mode, S3PersistentDocumentWriter for K8s mode """ if SANDBOX_BACKEND == SandboxBackend.LOCAL: return PersistentDocumentWriter( base_path=PERSISTENT_DOCUMENT_STORAGE_PATH, tenant_id=tenant_id, user_id=user_id, ) elif SANDBOX_BACKEND == SandboxBackend.KUBERNETES: return S3PersistentDocumentWriter( tenant_id=tenant_id, user_id=user_id, ) else: raise ValueError(f"Unknown sandbox backend: {SANDBOX_BACKEND}") ================================================ FILE: backend/onyx/server/features/build/s3/s3_client.py ================================================ import boto3 from mypy_boto3_s3.client import S3Client from onyx.configs.app_configs import AWS_REGION_NAME def build_s3_client() -> S3Client: """Build an S3 client using IAM roles (IRSA)""" return boto3.client("s3", region_name=AWS_REGION_NAME) ================================================ FILE: backend/onyx/server/features/build/sandbox/README.md ================================================ # Onyx Sandbox System This directory contains the implementation of Onyx's sandbox system for running OpenCode agents in isolated environments. ## Overview The sandbox system provides isolated execution environments where OpenCode agents can build web applications, run code, and interact with knowledge files. Each sandbox includes: - **Next.js development environment** - Lightweight Next.js scaffold with shadcn/ui and Recharts for building UIs - **Python virtual environment** - Pre-installed packages for data processing - **OpenCode agent** - AI coding agent with access to tools and MCP servers - **Knowledge files** - Access to indexed documents and user uploads ## Architecture ### Deployment Modes 1. **Local Mode** (`SANDBOX_BACKEND=local`) - Sandboxes run as directories on the local filesystem - No automatic cleanup or snapshots - Suitable for development and testing 2. 
**Kubernetes Mode** (`SANDBOX_BACKEND=kubernetes`) - Sandboxes run as Kubernetes pods - Automatic snapshots to S3 - Auto-cleanup of idle sandboxes - Production-ready with resource isolation ### Directory Structure ``` /workspace/ # Sandbox root (in container) ├── outputs/ # Working directory │ ├── web/ # Lightweight Next.js app (shadcn/ui, Recharts) │ ├── slides/ # Generated presentations │ ├── markdown/ # Generated documents │ └── graphs/ # Generated visualizations ├── .venv/ # Python virtual environment ├── files/ # Symlink to knowledge files ├── attachments/ # User uploads ├── AGENTS.md # Agent instructions └── .opencode/ └── skills/ # Agent skills ``` ## Setup ### Running via Docker/Kubernetes (Zero Setup!) 🎉 **No setup required!** Just build and deploy: ```bash # Build backend image (includes both templates) cd backend docker build -f Dockerfile.sandbox-templates -t onyxdotapp/backend:latest . # Build sandbox container (lightweight runner) cd onyx/server/features/build/sandbox/kubernetes/docker docker build -t onyxdotapp/sandbox:latest . # Deploy with docker-compose or kubectl - sandboxes work immediately! ``` **How it works:** - **Backend image**: Contains both templates at build time: - Web template at `/templates/outputs/web` (lightweight Next.js scaffold, ~2MB) - Python venv template at `/templates/venv` (pre-installed packages, ~50MB) - **Init container** (Kubernetes only): Syncs knowledge files from S3 - **Sandbox startup**: Runs `npm install` (for fresh dependency locks) + `next dev` ### Running Backend Directly (Without Docker) **Only needed if you're running the Onyx backend outside of Docker.** Most developers use Docker and can skip this section. If you're running the backend Python process directly on your machine, you need templates at `/templates/`: #### Web Template The web template is a lightweight Next.js app (Next.js 16, React 19, shadcn/ui, Recharts) checked into the codebase at `backend/onyx/server/features/build/templates/outputs/web/`. 
For local development, create a symlink to this template: ```bash sudo mkdir -p /templates/outputs sudo ln -s $(pwd)/backend/onyx/server/features/build/templates/outputs/web /templates/outputs/web ``` #### Python Venv Template If you don't have a venv template, create it: ```bash # Use the utility script cd backend python -m onyx.server.features.build.sandbox.util.build_venv_template # Or manually python3 -m venv /templates/venv /templates/venv/bin/pip install -r backend/onyx/server/features/build/sandbox/kubernetes/docker/initial-requirements.txt ``` #### System Dependencies (for PPTX skill) The PPTX skill requires LibreOffice and Poppler for PDF conversion and thumbnail generation: **macOS:** ```bash brew install poppler brew install --cask libreoffice ``` Ensure `soffice` is on your PATH: ```bash export PATH="/Applications/LibreOffice.app/Contents/MacOS:$PATH" ``` **Linux (Debian/Ubuntu):** ```bash sudo apt-get install libreoffice-impress poppler-utils ``` **That's it!** When sandboxes are created: 1. Web template is copied from `/templates/outputs/web` 2. Python venv is copied from `/templates/venv` 3. `npm install` runs automatically to install fresh Next.js dependencies ## OpenCode Configuration Each sandbox includes an OpenCode agent configured with: - **LLM Provider**: Anthropic, OpenAI, Google, Bedrock, or Azure - **Extended thinking**: High reasoning effort / thinking budgets for complex tasks - **Tool permissions**: File operations, bash commands, web access - **Disabled tools**: Configurable via `OPENCODE_DISABLED_TOOLS` env var Configuration is generated dynamically in `templates/opencode_config.py`. 
## Key Components ### Managers - **`base.py`** - Abstract base class defining the sandbox interface - **`local/manager.py`** - Filesystem-based sandbox manager for local development - **`kubernetes/manager.py`** - Kubernetes-based sandbox manager for production ### Managers (Shared) - **`manager/directory_manager.py`** - Creates sandbox directory structure and copies templates - **`manager/snapshot_manager.py`** - Handles snapshot creation and restoration ### Utilities - **`util/opencode_config.py`** - Generates OpenCode configuration with MCP support - **`util/agent_instructions.py`** - Generates agent instructions (AGENTS.md) - **`util/build_venv_template.py`** - Utility to build Python venv template for local development ### Templates - **`../templates/outputs/web/`** - Lightweight Next.js scaffold (shadcn/ui, Recharts) versioned with the backend code ### Kubernetes Specific - **`kubernetes/docker/Dockerfile`** - Sandbox container image (runs Next.js + OpenCode) - **`kubernetes/docker/entrypoint.sh`** - Container startup script ## Environment Variables ### Core Settings ```bash # Sandbox backend mode SANDBOX_BACKEND=local|kubernetes # Default: local # Template paths (local mode) OUTPUTS_TEMPLATE_PATH=/templates/outputs # Default: /templates/outputs VENV_TEMPLATE_PATH=/templates/venv # Default: /templates/venv # Sandbox base path (local mode) SANDBOX_BASE_PATH=/tmp/onyx-sandboxes # Default: /tmp/onyx-sandboxes # OpenCode configuration OPENCODE_DISABLED_TOOLS=question # Comma-separated list, default: question ``` ### Kubernetes Settings ```bash # Kubernetes namespace SANDBOX_NAMESPACE=onyx-sandboxes # Default: onyx-sandboxes # Container image SANDBOX_CONTAINER_IMAGE=onyxdotapp/sandbox:latest # S3 bucket for snapshots and files SANDBOX_S3_BUCKET=onyx-sandbox-files # Default: onyx-sandbox-files # Service accounts SANDBOX_SERVICE_ACCOUNT_NAME=sandbox-runner # No AWS access SANDBOX_FILE_SYNC_SERVICE_ACCOUNT=sandbox-file-sync # Has S3 access via IRSA ``` ### Lifecycle 
Settings ```bash # Idle timeout before cleanup (seconds) SANDBOX_IDLE_TIMEOUT_SECONDS=900 # Default: 900 (15 minutes) # Max concurrent sandboxes per organization SANDBOX_MAX_CONCURRENT_PER_ORG=10 # Default: 10 # Next.js port range (local mode) SANDBOX_NEXTJS_PORT_START=3010 # Default: 3010 SANDBOX_NEXTJS_PORT_END=3100 # Default: 3100 ``` ## Testing ### Integration Tests ```bash # Test local sandbox provisioning uv run pytest backend/tests/integration/sandbox/test_local_sandbox.py # Test Kubernetes sandbox provisioning (requires k8s cluster) uv run pytest backend/tests/integration/sandbox/test_kubernetes_sandbox.py ``` ### Manual Testing ```bash # Start a local sandbox session curl -X POST http://localhost:3000/api/build/session \ -H "Content-Type: application/json" \ -d '{ "user_id": "user-123", "file_system_path": "/path/to/files" }' # Send a message to the agent curl -X POST http://localhost:3000/api/build/session/{session_id}/message \ -H "Content-Type: application/json" \ -d '{ "message": "Create a simple web page" }' ``` ## Troubleshooting ### Sandbox Stuck in PROVISIONING (Kubernetes) **Symptoms**: Sandbox status never changes from `PROVISIONING` **Solutions**: - Check pod logs: `kubectl logs -n onyx-sandboxes sandbox-{sandbox-id}` - Check init container: `kubectl logs -n onyx-sandboxes sandbox-{sandbox-id} -c file-sync` - Verify init container completed: `kubectl describe pod -n onyx-sandboxes sandbox-{sandbox-id}` - Check S3 bucket access: Ensure init container service account has IRSA configured ### Next.js Server Won't Start **Symptoms**: Sandbox provisioned but web preview doesn't load **Solutions**: - **Local mode**: Check if port is already in use - **Docker/K8s**: Check container logs: `kubectl logs -n onyx-sandboxes sandbox-{sandbox-id}` - Verify npm install succeeded (check entrypoint.sh logs) - Check that web template was copied: `kubectl exec -n onyx-sandboxes sandbox-{sandbox-id} -- ls /workspace/outputs/web` ### Templates Not Found (Local Mode) 
**Symptoms**: `RuntimeError: Sandbox templates are missing` **Solution**: Set up templates as described in the "Local Development" section above: ```bash # Symlink web template sudo ln -s $(pwd)/backend/onyx/server/features/build/templates/outputs/web /templates/outputs/web # Create Python venv python3 -m venv /templates/venv /templates/venv/bin/pip install -r backend/onyx/server/features/build/sandbox/kubernetes/docker/initial-requirements.txt ``` ### Permission Denied **Symptoms**: `Permission denied` error accessing `/templates/` **Solution**: Either use sudo when creating symlinks, or use custom paths: ```bash export OUTPUTS_TEMPLATE_PATH=$HOME/.onyx/templates/outputs export VENV_TEMPLATE_PATH=$HOME/.onyx/templates/venv # Then symlink to your home directory mkdir -p $HOME/.onyx/templates/outputs ln -s $(pwd)/backend/onyx/server/features/build/templates/outputs/web $HOME/.onyx/templates/outputs/web ``` ## Security Considerations ### Sandbox Isolation - **Kubernetes pods** run with restricted security context (non-root, no privilege escalation) - **Init containers** have S3 access for file sync, but main sandbox container does NOT - **Network policies** can restrict sandbox egress traffic - **Resource limits** prevent resource exhaustion ### Credentials Management - LLM API keys are passed as environment variables (not stored in sandbox) - User file access is read-only via symlinks - Snapshots are isolated per tenant in S3 ## Development ### Adding New MCP Servers 1. Add MCP configuration to `templates/opencode_config.py`: ```python config["mcp"] = { "my-mcp": { "type": "local", "command": ["npx", "@my/mcp@latest"], "enabled": True, } } ``` 2. Install required npm packages in web template (if needed) 3. Rebuild Docker image and templates ### Modifying Agent Instructions Edit `AGENTS.template.md` in the build directory. This is populated with dynamic content by `templates/agent_instructions.py`. 
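The internals of `templates/agent_instructions.py` are not shown here. As a rough illustration only, a minimal sketch of `{{NAME}}`-style placeholder substitution (the same marker style as the `{{KNOWLEDGE_SOURCES_SECTION}}` placeholder that `generate_agents_md.py` fills) could look like this. The function name `fill_agents_template` and the `USER_NAME` placeholder are hypothetical and not part of the codebase.

```python
import re

# Hypothetical helper, not the real agent_instructions.py API.
# Matches upper-case placeholders such as {{KNOWLEDGE_SOURCES_SECTION}}.
_PLACEHOLDER = re.compile(r"\{\{([A-Z0-9_]+)\}\}")


def fill_agents_template(template: str, values: dict[str, str]) -> str:
    """Replace {{NAME}} placeholders with values[NAME].

    Placeholders without a matching key are left untouched so that a
    later step can still fill them in.
    """

    def _substitute(match: re.Match[str]) -> str:
        # Fall back to the original "{{NAME}}" text when no value is given.
        return values.get(match.group(1), match.group(0))

    # A callable replacement inserts values literally (no backslash escapes).
    return _PLACEHOLDER.sub(_substitute, template)


if __name__ == "__main__":
    template = "# Agent Instructions\nUser: {{USER_NAME}}\n\n{{KNOWLEDGE_SOURCES_SECTION}}"
    print(fill_agents_template(template, {"USER_NAME": "Ada"}))
```

Leaving unknown placeholders intact suits the two-stage flow: `{{KNOWLEDGE_SOURCES_SECTION}}` is only filled later, inside the pod, after the user's files have been synced.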
### Adding New Tools/Permissions Update `templates/opencode_config.py` to add/remove tool permissions in the `permission` section. ## Template Details ### Web Template The lightweight Next.js template (`backend/onyx/server/features/build/templates/outputs/web/`) includes: - **Framework**: Next.js 16.1.4 with React 19.2.3 - **UI Library**: shadcn/ui components with Radix UI primitives - **Styling**: Tailwind CSS v4 with custom theming support - **Charts**: Recharts for data visualization - **Size**: ~2MB (excluding node_modules, which are installed fresh per sandbox) This template provides a modern development environment without the complexity of the full Onyx application, allowing agents to build custom UIs quickly. ### Python Venv Template The Python venv (`/templates/venv/`) includes packages from `initial-requirements.txt`: - Data and ML: pandas, numpy, scipy, scikit-learn, scikit-image, xgboost - Visualization: matplotlib, matplotlib-inline, matplotlib-venn, seaborn - Documents: python-pptx, openpyxl, pdfplumber, markitdown, lxml, defusedxml - Images: Pillow, opencv-python - Utilities: pydantic, google-genai ## References - [OpenCode Documentation](https://docs.opencode.ai) - [Next.js Documentation](https://nextjs.org/docs) - [shadcn/ui Components](https://ui.shadcn.com) ================================================ FILE: backend/onyx/server/features/build/sandbox/__init__.py ================================================ """ Sandbox module for CLI agent filesystem-based isolation. This module provides lightweight sandbox management for CLI-based AI agent sessions. Each sandbox is a directory on the local filesystem or a Kubernetes pod. Usage: from onyx.server.features.build.sandbox import get_sandbox_manager # Get the appropriate sandbox manager based on SANDBOX_BACKEND config sandbox_manager = get_sandbox_manager() # Use the sandbox manager sandbox_info = sandbox_manager.provision(...)
Module structure: - base.py: SandboxManager ABC and get_sandbox_manager() factory - models.py: Shared Pydantic models - local/: Local filesystem-based implementation for development - kubernetes/: Kubernetes pod-based implementation for production - internal/: Shared internal utilities (snapshot manager) """ from onyx.server.features.build.sandbox.base import get_sandbox_manager from onyx.server.features.build.sandbox.base import SandboxManager from onyx.server.features.build.sandbox.local.local_sandbox_manager import ( LocalSandboxManager, ) from onyx.server.features.build.sandbox.models import FilesystemEntry from onyx.server.features.build.sandbox.models import SandboxInfo from onyx.server.features.build.sandbox.models import SnapshotInfo __all__ = [ # Factory function (preferred) "get_sandbox_manager", # Interface "SandboxManager", # Implementations "LocalSandboxManager", # Models "SandboxInfo", "SnapshotInfo", "FilesystemEntry", ] ================================================ FILE: backend/onyx/server/features/build/sandbox/base.py ================================================ """Abstract base class and factory for sandbox operations. SandboxManager is the abstract interface for sandbox lifecycle management. Use get_sandbox_manager() to get the appropriate implementation based on SANDBOX_BACKEND. IMPORTANT: SandboxManager implementations must NOT interface with the database directly. All database operations should be handled by the caller (SessionManager, Celery tasks, etc.). 
Architecture Note (User-Shared Sandbox Model): - One sandbox (container/pod) is shared across all of a user's sessions - provision() creates the user's sandbox with shared files/ directory - setup_session_workspace() creates per-session workspace within the sandbox - cleanup_session_workspace() removes session workspace on session delete - terminate() destroys the entire sandbox (all sessions) """ import threading from abc import ABC from abc import abstractmethod from collections.abc import Generator from typing import Any from uuid import UUID from onyx.server.features.build.configs import SANDBOX_BACKEND from onyx.server.features.build.configs import SandboxBackend from onyx.server.features.build.sandbox.models import FilesystemEntry from onyx.server.features.build.sandbox.models import LLMProviderConfig from onyx.server.features.build.sandbox.models import SandboxInfo from onyx.server.features.build.sandbox.models import SnapshotResult from onyx.utils.logger import setup_logger logger = setup_logger() # ACPEvent is a union type defined in both local and kubernetes modules # Using Any here to avoid circular imports - the actual type checking # happens in the implementation modules ACPEvent = Any class SandboxManager(ABC): """Abstract interface for sandbox operations. 
Defines the contract for sandbox lifecycle management including: - Provisioning and termination (user-level) - Session workspace setup and cleanup (session-level) - Snapshot creation (session-level) - Health checks - Agent communication (session-level) - Filesystem operations (session-level) Directory Structure: $SANDBOX_ROOT/ ├── files/ # SHARED - symlink to user's persistent documents └── sessions/ ├── $session_id_1/ # Per-session workspace │ ├── outputs/ # Agent output for this session │ │ └── web/ # Next.js app │ ├── venv/ # Python virtual environment │ ├── skills/ # Opencode skills │ ├── AGENTS.md # Agent instructions │ ├── opencode.json # LLM config │ └── attachments/ └── $session_id_2/ └── ... IMPORTANT: Implementations must NOT interface with the database directly. All database operations should be handled by the caller. Use get_sandbox_manager() to get the appropriate implementation. """ @abstractmethod def provision( self, sandbox_id: UUID, user_id: UUID, tenant_id: str, llm_config: LLMProviderConfig, ) -> SandboxInfo: """Provision a new sandbox for a user. Creates the sandbox container/directory with: - sessions/ directory for per-session workspaces NOTE: This does NOT set up session-specific workspaces. Call setup_session_workspace() after provisioning to create a session workspace. Args: sandbox_id: Unique identifier for the sandbox user_id: User identifier who owns this sandbox tenant_id: Tenant identifier for multi-tenant isolation llm_config: LLM provider configuration (for default config) Returns: SandboxInfo with the provisioned sandbox details Raises: RuntimeError: If provisioning fails """ ... @abstractmethod def terminate(self, sandbox_id: UUID) -> None: """Terminate a sandbox and clean up all resources. Destroys the entire sandbox including all session workspaces. Use cleanup_session_workspace() to remove individual sessions. Args: sandbox_id: The sandbox ID to terminate """ ... 
@abstractmethod def setup_session_workspace( self, sandbox_id: UUID, session_id: UUID, llm_config: LLMProviderConfig, nextjs_port: int, file_system_path: str | None = None, snapshot_path: str | None = None, user_name: str | None = None, user_role: str | None = None, user_work_area: str | None = None, user_level: str | None = None, use_demo_data: bool = False, excluded_user_library_paths: list[str] | None = None, ) -> None: """Set up a session workspace within an existing sandbox. Creates the per-session directory structure: - sessions/$session_id/outputs/ (from snapshot or template) - sessions/$session_id/venv/ - sessions/$session_id/skills/ - sessions/$session_id/files/ (symlink to demo data or user files) - sessions/$session_id/AGENTS.md - sessions/$session_id/opencode.json - sessions/$session_id/attachments/ - sessions/$session_id/org_info/ (if demo data enabled) Args: sandbox_id: The sandbox ID (must be provisioned) session_id: The session ID for this workspace llm_config: LLM provider configuration for opencode.json nextjs_port: Port number for the Next.js dev server file_system_path: Path to user's knowledge/source files snapshot_path: Optional storage path to restore outputs from user_name: User's name for personalization in AGENTS.md user_role: User's role/title for personalization in AGENTS.md user_work_area: User's work area for demo persona (e.g., "engineering") user_level: User's level for demo persona (e.g., "ic", "manager") use_demo_data: If True, symlink files/ to demo data; else to user files excluded_user_library_paths: List of paths within user_library to exclude from the sandbox (e.g., ["/data/file.xlsx"]). Only applies when use_demo_data=False. Files at these paths won't be accessible. Raises: RuntimeError: If workspace setup fails """ ... @abstractmethod def cleanup_session_workspace( self, sandbox_id: UUID, session_id: UUID, nextjs_port: int | None = None, ) -> None: """Clean up a session workspace (on session delete). 1. Stop the Next.js dev server if running on nextjs_port 2.
Remove the session directory: sessions/$session_id/ Does NOT terminate the sandbox - other sessions may still be using it. Args: sandbox_id: The sandbox ID session_id: The session ID to clean up nextjs_port: Optional port where Next.js server is running """ ... @abstractmethod def create_snapshot( self, sandbox_id: UUID, session_id: UUID, tenant_id: str, ) -> SnapshotResult | None: """Create a snapshot of a session's outputs and attachments directories. Captures session-specific user data: - sessions/$session_id/outputs/ (generated artifacts, web apps) - sessions/$session_id/attachments/ (user uploaded files) Does NOT include: venv, skills, AGENTS.md, opencode.json, files symlink (these are regenerated during restore) Args: sandbox_id: The sandbox ID session_id: The session ID to snapshot tenant_id: Tenant identifier for storage path Returns: SnapshotResult with storage path and size, or None if: - Snapshots are disabled for this backend - No outputs directory exists (nothing to snapshot) Raises: RuntimeError: If snapshot creation fails """ ... @abstractmethod def restore_snapshot( self, sandbox_id: UUID, session_id: UUID, snapshot_storage_path: str, tenant_id: str, nextjs_port: int, llm_config: LLMProviderConfig, use_demo_data: bool = False, ) -> None: """Restore a session workspace from a snapshot. For Kubernetes: Downloads and extracts the snapshot, regenerates config files. For Local: No-op since workspaces persist on disk (no snapshots). Args: sandbox_id: The sandbox ID session_id: The session ID to restore snapshot_storage_path: Path to the snapshot in storage tenant_id: Tenant identifier for storage access nextjs_port: Port number for the NextJS dev server llm_config: LLM provider configuration for opencode.json use_demo_data: If True, symlink files/ to demo data Raises: RuntimeError: If snapshot restoration fails """ ... 
@abstractmethod def session_workspace_exists( self, sandbox_id: UUID, session_id: UUID, ) -> bool: """Check if a session's workspace directory exists in the sandbox. Used to determine if we need to restore from snapshot. Checks for sessions/$session_id/outputs/ directory. Args: sandbox_id: The sandbox ID session_id: The session ID to check Returns: True if the session workspace exists, False otherwise """ ... @abstractmethod def health_check(self, sandbox_id: UUID, timeout: float = 60.0) -> bool: """Check if the sandbox is healthy. Args: sandbox_id: The sandbox ID to check timeout: Timeout for the health check, in seconds Returns: True if sandbox is healthy, False otherwise """ ... @abstractmethod def send_message( self, sandbox_id: UUID, session_id: UUID, message: str, ) -> Generator[ACPEvent, None, None]: """Send a message to the CLI agent and stream typed ACP events. The agent runs in the session-specific workspace: sessions/$session_id/ Args: sandbox_id: The sandbox ID session_id: The session ID (determines workspace directory) message: The message content to send Yields: Typed ACP schema event objects Raises: RuntimeError: If agent communication fails """ ... @abstractmethod def list_directory( self, sandbox_id: UUID, session_id: UUID, path: str ) -> list[FilesystemEntry]: """List contents of a directory in the session's outputs directory. Args: sandbox_id: The sandbox ID session_id: The session ID path: Relative path within sessions/$session_id/outputs/ Returns: List of FilesystemEntry objects sorted with directories first, then by name Raises: ValueError: If path traversal attempted or path is not a directory """ ... @abstractmethod def read_file(self, sandbox_id: UUID, session_id: UUID, path: str) -> bytes: """Read a file from the session's workspace. Args: sandbox_id: The sandbox ID session_id: The session ID path: Relative path within sessions/$session_id/ Returns: File contents as bytes Raises: ValueError: If path traversal attempted or path is not a file """ ...
@abstractmethod def upload_file( self, sandbox_id: UUID, session_id: UUID, filename: str, content: bytes, ) -> str: """Upload a file to the session's attachments directory. Args: sandbox_id: The sandbox ID session_id: The session ID filename: Sanitized filename content: File content as bytes Returns: Relative path where file was saved (e.g., "attachments/doc.pdf") Raises: RuntimeError: If upload fails """ ... @abstractmethod def delete_file( self, sandbox_id: UUID, session_id: UUID, path: str, ) -> bool: """Delete a file from the session's workspace. Args: sandbox_id: The sandbox ID session_id: The session ID path: Relative path to the file (e.g., "attachments/doc.pdf") Returns: True if file was deleted, False if not found Raises: ValueError: If path traversal attempted """ ... @abstractmethod def get_upload_stats( self, sandbox_id: UUID, session_id: UUID, ) -> tuple[int, int]: """Get current file count and total size for a session's attachments. Args: sandbox_id: The sandbox ID session_id: The session ID Returns: Tuple of (file_count, total_size_bytes) """ ... @abstractmethod def get_webapp_url(self, sandbox_id: UUID, port: int) -> str: """Get the webapp URL for a session's Next.js server. Returns the appropriate URL based on the backend: - Local: Returns localhost URL with port - Kubernetes: Returns internal cluster service URL Args: sandbox_id: The sandbox ID port: The session's allocated Next.js port Returns: URL to access the webapp """ ... @abstractmethod def generate_pptx_preview( self, sandbox_id: UUID, session_id: UUID, pptx_path: str, cache_dir: str, ) -> tuple[list[str], bool]: """Convert PPTX to slide JPEG images for preview, with caching. Checks if cache_dir already has slides. If the PPTX is newer than the cached images (or no cache exists), runs soffice -> pdftoppm pipeline. 
Args: sandbox_id: The sandbox ID session_id: The session ID pptx_path: Relative path to the PPTX file within the session workspace cache_dir: Relative path for the cache directory (e.g., "outputs/.pptx-preview/abc123") Returns: Tuple of (slide_paths, cached) where slide_paths is a list of relative paths to slide JPEG images (within session workspace) and cached indicates whether the result was served from cache. Raises: ValueError: If file not found or conversion fails """ ... @abstractmethod def sync_files( self, sandbox_id: UUID, user_id: UUID, tenant_id: str, source: str | None = None, ) -> bool: """Sync files from S3 to the sandbox's /workspace/files directory. For Kubernetes backend: Executes `s5cmd sync` in the file-sync sidecar container. For Local backend: No-op since files are directly accessible via symlink. This is idempotent - only downloads changed files. File visibility in sessions is controlled via filtered symlinks in setup_session_workspace(), not at the sync level. Args: sandbox_id: The sandbox UUID user_id: The user ID (for S3 path construction) tenant_id: The tenant ID (for S3 path construction) source: Optional source type (e.g., "gmail", "google_drive"). If None, syncs all sources. If specified, only syncs that source's directory. Returns: True if sync was successful, False otherwise. """ ... def ensure_nextjs_running( self, sandbox_id: UUID, session_id: UUID, nextjs_port: int, ) -> None: """Ensure the Next.js server is running for a session. Default is a no-op — only meaningful for local backends that manage process lifecycles directly (e.g., LocalSandboxManager). 
Args: sandbox_id: The sandbox ID session_id: The session ID nextjs_port: The port the Next.js server should be listening on """ # Singleton instance cache for the factory _sandbox_manager_instance: SandboxManager | None = None _sandbox_manager_lock = threading.Lock() def get_sandbox_manager() -> SandboxManager: """Get the appropriate SandboxManager implementation based on SANDBOX_BACKEND. Returns: SandboxManager instance: - LocalSandboxManager for local backend (development) - KubernetesSandboxManager for kubernetes backend (production) """ global _sandbox_manager_instance if _sandbox_manager_instance is None: with _sandbox_manager_lock: if _sandbox_manager_instance is None: if SANDBOX_BACKEND == SandboxBackend.LOCAL: from onyx.server.features.build.sandbox.local.local_sandbox_manager import ( LocalSandboxManager, ) _sandbox_manager_instance = LocalSandboxManager() elif SANDBOX_BACKEND == SandboxBackend.KUBERNETES: from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import ( KubernetesSandboxManager, ) _sandbox_manager_instance = KubernetesSandboxManager() logger.info("Using KubernetesSandboxManager for sandbox operations") else: raise ValueError(f"Unknown sandbox backend: {SANDBOX_BACKEND}") return _sandbox_manager_instance ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/__init__.py ================================================ """Kubernetes-based sandbox implementation. This module provides the KubernetesSandboxManager for production deployments that run sandboxes as isolated Kubernetes pods. Internal implementation details (acp_http_client) are in the internal/ subdirectory and should not be used directly. 
""" from onyx.server.features.build.sandbox.kubernetes.kubernetes_sandbox_manager import ( KubernetesSandboxManager, ) __all__ = [ "KubernetesSandboxManager", ] ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/Dockerfile ================================================ # Sandbox Container Image # # User-shared sandbox model: # - One pod per user, shared across all user's sessions # - Session workspaces created via kubectl exec (setup_session_workspace) # - OpenCode agent runs via kubectl exec when needed # # Directory structure (created by init container + session setup): # /workspace/ # ├── demo_data/ # Demo data (baked into image, for demo sessions) # ├── files/ # User's knowledge files (synced from S3) # ├── skills/ # Agent skills (baked into image, copied per-session) # ├── templates/ # Output templates (baked into image) # └── sessions/ # Per-session workspaces (created via exec) # └── $session_id/ # ├── files/ # Symlink to /workspace/demo_data or /workspace/files # ├── outputs/ # ├── AGENTS.md # └── opencode.json FROM node:20-slim # Install system dependencies RUN apt-get update && apt-get install -y --no-install-recommends \ python3 \ python3-pip \ python3-venv \ curl \ git \ procps \ unzip \ \ libreoffice-core \ libreoffice-common \ libreoffice-impress \ libreoffice-draw \ poppler-utils \ gcc \ libc6-dev \ fontconfig \ fonts-dejavu-core \ fonts-liberation \ && rm -rf /var/lib/apt/lists/* # Create non-root user (matches pod securityContext) # Handle existing user/group with UID/GID 1000 in base image RUN EXISTING_USER=$(id -nu 1000 2>/dev/null || echo ""); \ EXISTING_GROUP=$(getent group 1000 | cut -d: -f1 2>/dev/null || echo ""); \ if [ -n "$EXISTING_GROUP" ] && [ "$EXISTING_GROUP" != "sandbox" ]; then \ groupmod -n sandbox $EXISTING_GROUP; \ elif [ -z "$EXISTING_GROUP" ]; then \ groupadd -g 1000 sandbox; \ fi; \ if [ -n "$EXISTING_USER" ] && [ "$EXISTING_USER" != "sandbox" ]; then \ usermod 
-l sandbox -g sandbox $EXISTING_USER; \ usermod -d /home/sandbox -m sandbox; \ usermod -s /bin/bash sandbox; \ elif [ -z "$EXISTING_USER" ]; then \ useradd -u 1000 -g sandbox -m -s /bin/bash sandbox; \ fi # Create workspace directories RUN mkdir -p /workspace/sessions /workspace/files /workspace/templates /workspace/demo_data && \ chown -R sandbox:sandbox /workspace # Copy outputs template (web app scaffold, without node_modules) COPY --exclude=.next --exclude=node_modules templates/outputs /workspace/templates/outputs RUN chown -R sandbox:sandbox /workspace/templates # Copy and extract demo data from zip file # Zip contains demo_data/ as root folder COPY demo_data.zip /tmp/demo_data.zip RUN unzip -q /tmp/demo_data.zip -d /workspace && \ rm /tmp/demo_data.zip && \ chown -R sandbox:sandbox /workspace/demo_data # Copy and install Python requirements into a venv COPY initial-requirements.txt /tmp/initial-requirements.txt RUN python3 -m venv /workspace/.venv && \ /workspace/.venv/bin/pip install --upgrade pip && \ /workspace/.venv/bin/pip install -r /tmp/initial-requirements.txt && \ rm /tmp/initial-requirements.txt && \ chown -R sandbox:sandbox /workspace/.venv # Add venv to PATH so python/pip use it by default ENV PATH="/workspace/.venv/bin:${PATH}" # Install pptxgenjs globally for creating presentations from scratch RUN npm install -g pptxgenjs # Install opencode CLI as sandbox user so it goes to their home directory USER sandbox RUN curl -fsSL https://opencode.ai/install | bash USER root # Add opencode to PATH (installs to ~/.opencode/bin) ENV PATH="/home/sandbox/.opencode/bin:${PATH}" # Copy agent skills (symlinked into each session's .opencode/skills/ at setup time) COPY --exclude=__pycache__ skills/ /workspace/skills/ # Set ownership RUN chown -R sandbox:sandbox /workspace # Copy scripts COPY generate_agents_md.py /usr/local/bin/generate_agents_md.py RUN chmod +x /usr/local/bin/generate_agents_md.py # Switch to non-root user USER sandbox WORKDIR /workspace #
Expose ports # - 3000: Next.js dev server (started per-session if needed) # - 8081: OpenCode ACP HTTP server (started via exec) EXPOSE 3000 8081 # Keep container alive - all work done via kubectl exec CMD ["sleep", "infinity"] ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/README.md ================================================ # Sandbox Container Image This directory contains the Dockerfile and resources for building the Onyx Craft sandbox container image. ## Directory Structure ``` docker/ ├── Dockerfile # Main container image definition ├── demo_data.zip # Demo data (extracted to /workspace/demo_data) ├── skills/ # Agent skills (image-generation, pptx, etc.) ├── templates/ │ └── outputs/ # Web app scaffold template (Next.js) ├── initial-requirements.txt # Python packages pre-installed in sandbox ├── generate_agents_md.py # Script to generate AGENTS.md for sessions └── README.md # This file ``` ## Building the Image The sandbox image must be built for **amd64** architecture since our Kubernetes cluster runs on x86_64 nodes. ### Build for amd64 only (fastest) ```bash cd backend/onyx/server/features/build/sandbox/kubernetes/docker docker build --platform linux/amd64 -t onyxdotapp/sandbox:v0.1.x . docker push onyxdotapp/sandbox:v0.1.x ``` ### Build multi-arch (recommended for flexibility) ```bash docker buildx build --platform linux/amd64,linux/arm64 \ -t onyxdotapp/sandbox:v0.1.x \ --push . ``` ### Update the `latest` tag After pushing a versioned tag, update `latest`: ```bash docker tag onyxdotapp/sandbox:v0.1.x onyxdotapp/sandbox:latest docker push onyxdotapp/sandbox:latest ``` Or with buildx: ```bash docker buildx build --platform linux/amd64,linux/arm64 \ -t onyxdotapp/sandbox:v0.1.x \ -t onyxdotapp/sandbox:latest \ --push . ``` ## Deploying a New Version 1. **Build and push** the new image (see above) 2. 
**Update the ConfigMap** in `cloud-deployment-yamls/danswer/configmap/env-configmap.yaml`: ```yaml SANDBOX_CONTAINER_IMAGE: "onyxdotapp/sandbox:v0.1.x" ``` 3. **Apply the ConfigMap**: ```bash kubectl apply -f configmap/env-configmap.yaml ``` 4. **Restart the API server** to pick up the new config: ```bash kubectl rollout restart deployment/api-server -n danswer ``` 5. **Delete existing sandbox pods** (they will be recreated with the new image): ```bash kubectl delete pods -n onyx-sandboxes -l app.kubernetes.io/component=sandbox ``` ## What's Baked Into the Image - **Base**: `node:20-slim` (Debian-based) - **Demo data**: `/workspace/demo_data/` - sample files for demo sessions - **Skills**: `/workspace/skills/` - agent skills (image-generation, pptx, etc.) - **Templates**: `/workspace/templates/outputs/` - Next.js web app scaffold - **Python venv**: `/workspace/.venv/` with packages from `initial-requirements.txt` - **OpenCode CLI**: Installed in `/home/sandbox/.opencode/bin/` ## Runtime Directory Structure When a session is created, the following structure is set up in the pod: ``` /workspace/ ├── demo_data/ # Baked into image ├── files/ # Mounted volume, synced from S3 ├── skills/ # Baked into image (agent skills) ├── templates/ # Baked into image └── sessions/ └── $session_id/ ├── .opencode/ │ └── skills/ # Symlink to /workspace/skills ├── files/ # Symlink to /workspace/demo_data or /workspace/files ├── outputs/ # Copied from templates, contains web app ├── attachments/ # User-uploaded files ├── org_info/ # Demo persona info (if demo mode) ├── AGENTS.md # Instructions for the AI agent └── opencode.json # OpenCode configuration ``` ## Troubleshooting ### Verify image exists on Docker Hub ```bash curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags" | jq '.results[].name' ``` ### Check what image a pod is using ```bash kubectl get pod -n onyx-sandboxes -o jsonpath='{.spec.containers[?(@.name=="sandbox")].image}' ``` 
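For local experimentation, the layout described under "Runtime Directory Structure" above can be approximated on an ordinary filesystem. The sketch below is hypothetical: `create_session_layout` is not part of the codebase (the real setup runs inside the pod via `kubectl exec` from `setup_session_workspace()`), it skips writing `AGENTS.md` and `opencode.json`, and it omits the filtering of user files controlled by `excluded_user_library_paths`.

```python
import shutil
from pathlib import Path


def create_session_layout(workspace: Path, session_id: str, use_demo_data: bool) -> Path:
    """Hypothetical local sketch of sessions/$session_id/ (not the real setup code).

    Expects workspace/ to already contain templates/outputs/, skills/,
    demo_data/ and files/, mirroring what the sandbox image provides.
    """
    session_dir = workspace / "sessions" / session_id
    # Fail loudly if the session already exists instead of clobbering it.
    session_dir.mkdir(parents=True, exist_ok=False)

    # outputs/ is a private copy of the web app scaffold template.
    shutil.copytree(workspace / "templates" / "outputs", session_dir / "outputs")

    # files/ is a symlink to demo data or to the user's synced files.
    files_source = workspace / ("demo_data" if use_demo_data else "files")
    (session_dir / "files").symlink_to(files_source, target_is_directory=True)

    # .opencode/skills/ is a symlink to the shared skills directory.
    (session_dir / ".opencode").mkdir()
    (session_dir / ".opencode" / "skills").symlink_to(
        workspace / "skills", target_is_directory=True
    )

    # attachments/ holds user uploads; org_info/ only exists in demo mode.
    (session_dir / "attachments").mkdir()
    if use_demo_data:
        (session_dir / "org_info").mkdir()

    return session_dir
```

Symlinking `files/` and `.opencode/skills/` instead of copying them keeps per-session setup cheap: files synced into `/workspace/files` appear in every non-demo session without another copy.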
================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/generate_agents_md.py ================================================ #!/usr/bin/env python3 """Generate AGENTS.md by scanning the files directory and populating the template. This script runs during session setup, AFTER files have been synced from S3 and the files symlink has been created. It reads an existing AGENTS.md (which contains the {{KNOWLEDGE_SOURCES_SECTION}} placeholder), replaces the placeholder by scanning the knowledge source directory, and writes it back. Usage: python3 generate_agents_md.py <agents_md_path> <files_path> Arguments: agents_md_path: Path to the AGENTS.md file to update in place files_path: Path to the files directory to scan for knowledge sources """ import sys from pathlib import Path # Type alias for connector info entries ConnectorInfoEntry = dict[str, str | int] # Connector information for generating knowledge sources section # Keys are normalized (lowercase, underscores) directory names # Each entry has: summary (with optional {subdirs}), file_pattern, scan_depth # NOTE: This is duplicated from agent_instructions.py to avoid circular imports CONNECTOR_INFO: dict[str, ConnectorInfoEntry] = { "google_drive": { "summary": "Documents and files from Google Drive.
This may contain information about a user and work they have done.", "file_pattern": "`FILE_NAME.json`", "scan_depth": 0, }, "gmail": { "summary": "Email conversations and threads", "file_pattern": "`FILE_NAME.json`", "scan_depth": 0, }, "linear": { "summary": "Engineering tickets from teams: {subdirs}", "file_pattern": "`[TEAM]/[TICKET_ID]_TICKET_TITLE.json`", "scan_depth": 2, }, "slack": { "summary": "Team messages from channels: {subdirs}", "file_pattern": "`[CHANNEL]/[AUTHOR]_in_[CHANNEL]__[MSG].json`", "scan_depth": 1, }, "github": { "summary": "Pull requests and code from: {subdirs}", "file_pattern": "`[ORG]/[REPO]/pull_requests/[PR_NUMBER]__[PR_TITLE].json`", "scan_depth": 2, }, "fireflies": { "summary": "Meeting transcripts from: {subdirs}", "file_pattern": "`[YYYY-MM]/CALL_TITLE.json`", "scan_depth": 1, }, "hubspot": { "summary": "CRM data including: {subdirs}", "file_pattern": "`[TYPE]/[RECORD_NAME].json`", "scan_depth": 1, }, "notion": { "summary": "Documentation and notes: {subdirs}", "file_pattern": "`PAGE_TITLE.json`", "scan_depth": 1, }, "user_library": { "summary": "User-uploaded files (spreadsheets, documents, presentations, etc.)", "file_pattern": "Any file format", "scan_depth": 1, }, } DEFAULT_SCAN_DEPTH = 1 def _normalize_connector_name(name: str) -> str: """Normalize a connector directory name for lookup.""" return name.lower().replace(" ", "_").replace("-", "_") def _scan_directory_to_depth( directory: Path, current_depth: int, max_depth: int, indent: str = " " ) -> list[str]: """Recursively scan directory up to max_depth levels.""" if current_depth >= max_depth: return [] lines: list[str] = [] try: subdirs = sorted( d for d in directory.iterdir() if d.is_dir() and not d.name.startswith(".") ) for subdir in subdirs[:10]: # Limit to 10 per level lines.append(f"{indent}- {subdir.name}/") # Recurse if we haven't hit max depth if current_depth + 1 < max_depth: nested = _scan_directory_to_depth( subdir, current_depth + 1, max_depth, indent + " " ) 
lines.extend(nested) if len(subdirs) > 10: lines.append(f"{indent}- ... and {len(subdirs) - 10} more") except Exception: pass return lines def build_knowledge_sources_section(files_path: Path) -> str: """Build combined knowledge sources section with summary, structure, and file patterns. This creates a single section per connector that includes: - What kind of data it contains (with actual subdirectory names) - The directory structure - The file naming pattern Args: files_path: Path to the files directory Returns: Formatted knowledge sources section """ if not files_path.exists(): return "No knowledge sources available." sections: list[str] = [] try: for item in sorted(files_path.iterdir()): if not item.is_dir() or item.name.startswith("."): continue normalized = _normalize_connector_name(item.name) info = CONNECTOR_INFO.get(normalized, {}) # Get subdirectory names subdirs: list[str] = [] try: subdirs = sorted( d.name for d in item.iterdir() if d.is_dir() and not d.name.startswith(".") )[:5] except Exception: pass # Build summary with subdirs summary_template = str(info.get("summary", f"Data from {item.name}")) if "{subdirs}" in summary_template and subdirs: subdir_str = ", ".join(subdirs) if len(subdirs) == 5: subdir_str += ", ..." 
summary = summary_template.format(subdirs=subdir_str) elif "{subdirs}" in summary_template: summary = summary_template.replace(": {subdirs}", "").replace( " {subdirs}", "" ) else: summary = summary_template # Build connector section file_pattern = str(info.get("file_pattern", "")) scan_depth = int(info.get("scan_depth", DEFAULT_SCAN_DEPTH)) lines = [f"### {item.name}/"] lines.append(f"{summary}.\n") # Add directory structure if depth > 0 if scan_depth > 0: lines.append("Directory structure:\n") nested = _scan_directory_to_depth(item, 0, scan_depth, "") if nested: lines.append("") lines.extend(nested) lines.append(f"\nFile format: {file_pattern}") sections.append("\n".join(lines)) except Exception as e: print( f"Warning: Error building knowledge sources section: {e}", file=sys.stderr ) return "Error scanning knowledge sources." if not sections: return "No knowledge sources available." return "\n\n".join(sections) def main() -> None: """Main entry point for container startup script. Reads an existing AGENTS.md, replaces the {{KNOWLEDGE_SOURCES_SECTION}} placeholder by scanning the files directory, and writes it back. 
Usage: python3 generate_agents_md.py """ if len(sys.argv) != 3: print( f"Usage: {sys.argv[0]} ", file=sys.stderr, ) sys.exit(1) agents_md_path = Path(sys.argv[1]) files_path = Path(sys.argv[2]) if not agents_md_path.exists(): print(f"Error: {agents_md_path} not found", file=sys.stderr) sys.exit(1) template = agents_md_path.read_text() # Resolve symlinks (handles both direct symlinks and dirs containing symlinks) resolved_files_path = files_path.resolve() knowledge_sources_section = build_knowledge_sources_section(resolved_files_path) # Replace placeholder and write back content = template.replace( "{{KNOWLEDGE_SOURCES_SECTION}}", knowledge_sources_section ) agents_md_path.write_text(content) print(f"Populated knowledge sources in {agents_md_path}") if __name__ == "__main__": main() ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/initial-requirements.txt ================================================ defusedxml>=0.7.1 google-genai>=1.0.0 lxml>=5.0.0 markitdown>=0.1.2 matplotlib==3.9.1 matplotlib-inline>=0.1.7 matplotlib-venn>=1.1.2 numpy==1.26.4 opencv-python>=4.11.0.86 openpyxl>=3.1.5 pandas==2.2.2 pdfplumber>=0.11.7 Pillow>=10.0.0 pydantic>=2.11.9 python-pptx>=1.0.2 scikit-image>=0.25.2 scikit-learn>=1.7.2 scipy>=1.16.2 seaborn>=0.13.2 xgboost>=3.0.5 ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/run-test.sh ================================================ #!/bin/bash # Run Kubernetes sandbox integration tests # # This script: # 1. Builds the onyx-backend Docker image # 2. Loads it into the kind cluster # 3. Deletes/recreates the test pod # 4. Waits for the pod to be ready # 5. 
Runs the pytest command inside the pod # # Usage: # ./run-test.sh [test_name] # # Examples: # ./run-test.sh # Run all tests # ./run-test.sh test_kubernetes_sandbox_provision # Run specific test set -e SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" PROJECT_ROOT="$(cd "$SCRIPT_DIR/../../../../../../../.." && pwd)" NAMESPACE="onyx-sandboxes" POD_NAME="sandbox-test" IMAGE_NAME="onyxdotapp/onyx-backend:latest" TEST_FILE="onyx/server/features/build/sandbox/kubernetes/test_kubernetes_sandbox.py" ENV_FILE="$PROJECT_ROOT/.vscode/.env" ORIGINAL_TEST_FILE="$PROJECT_ROOT/backend/tests/external_dependency_unit/craft/test_kubernetes_sandbox.py" cp "$ORIGINAL_TEST_FILE" "$PROJECT_ROOT/backend/$TEST_FILE" # Optional: specific test to run TEST_NAME="${1:-}" # Build env var arguments from .vscode/.env file for passing to the container ENV_VARS=() if [ -f "$ENV_FILE" ]; then echo "=== Loading environment variables from .vscode/.env ===" while IFS= read -r line || [ -n "$line" ]; do # Skip empty lines and comments [[ -z "$line" || "$line" =~ ^[[:space:]]*# ]] && continue # Skip lines without = [[ "$line" != *"="* ]] && continue # Add to env vars array ENV_VARS+=("$line") done < "$ENV_FILE" echo "Loaded ${#ENV_VARS[@]} environment variables" else echo "Warning: .vscode/.env not found, running without additional env vars" fi echo "=== Building onyx-backend Docker image ===" cd "$PROJECT_ROOT/backend" docker build -t "$IMAGE_NAME" -f Dockerfile . rm "$PROJECT_ROOT/backend/$TEST_FILE" echo "=== Loading image into kind cluster ===" kind load docker-image "$IMAGE_NAME" --name onyx 2>/dev/null || \ kind load docker-image "$IMAGE_NAME" 2>/dev/null || \ echo "Warning: Could not load into kind. 
If using minikube, run: minikube image load $IMAGE_NAME" echo "=== Deleting existing test pod (if any) ===" kubectl delete pod "$POD_NAME" -n "$NAMESPACE" --ignore-not-found=true echo "=== Creating test pod ===" kubectl apply -f "$SCRIPT_DIR/test-job.yaml" echo "=== Waiting for pod to be ready ===" kubectl wait --for=condition=Ready pod/"$POD_NAME" -n "$NAMESPACE" --timeout=120s echo "=== Running tests ===" if [ -n "$TEST_NAME" ]; then kubectl exec -it "$POD_NAME" -n "$NAMESPACE" -- \ env "${ENV_VARS[@]}" pytest "$TEST_FILE::$TEST_NAME" -v -s else kubectl exec -it "$POD_NAME" -n "$NAMESPACE" -- \ env "${ENV_VARS[@]}" pytest "$TEST_FILE" -v -s fi echo "=== Tests complete ===" ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/image-generation/SKILL.md ================================================ --- name: image-generation description: Generate images using nano banana. --- # Image Generation Skill Generate images using Nano Banana (Google Gemini Image API). Supports text-to-image and image-to-image generation with configurable options. 
## Setup ### Dependencies ```bash pip install google-genai Pillow ``` ### Environment Variable Set your API key: ```bash export GEMINI_API_KEY="your_api_key_here" ``` ## Usage ### Basic Text-to-Image ```bash python scripts/generate.py --prompt "A futuristic city at sunset with neon lights" --output city.png ``` ### With Aspect Ratio ```bash python scripts/generate.py \ --prompt "Mountain landscape with a lake" \ --output landscape.png \ --aspect-ratio 16:9 ``` ### Image-to-Image Mode Use a reference image to guide generation: ```bash python scripts/generate.py \ --prompt "Make it look like a watercolor painting" \ --input-image original.png \ --output watercolor.png ``` ### Generate Multiple Images ```bash python scripts/generate.py \ --prompt "Abstract colorful art" \ --output art.png \ --num-images 3 ``` ## Arguments | Argument | Short | Required | Default | Description | |----------|-------|----------|---------|-------------| | `--prompt` | `-p` | Yes | — | Text prompt describing the desired image | | `--output` | `-o` | No | `output.png` | Output path for the generated image | | `--model` | `-m` | No | `gemini-3-pro-image-preview` | Model to use for generation | | `--input-image` | `-i` | No | — | Reference image for image-to-image mode | | `--aspect-ratio` | `-a` | No | — | Aspect ratio: `1:1`, `16:9`, `9:16`, `4:3`, `3:4` (accepted, but `generate.py` does not currently forward it to the API) | | `--num-images` | `-n` | No | `1` | Number of images to generate | ## Available Models - `gemini-3-pro-image-preview` - Default model used by `scripts/generate.py` - `gemini-2.0-flash-preview-image-generation` - Fast, optimized for speed and lower latency - `imagen-3.0-generate-002` - High quality image generation (Imagen models are served through a separate image-generation method rather than `generate_content`, so they may not work with this script) ## Programmatic Usage Import the function directly in Python: ```python from scripts.generate import generate_image paths = 
generate_image( prompt="A serene mountain lake under moonlight", output_path="./outputs/lake.png", aspect_ratio="16:9", num_images=2, ) ``` ## Tips - **Detailed prompts work better**: Instead of "a cat", try "a fluffy orange tabby cat sitting on a windowsill, soft
morning light, photorealistic" - **Specify style**: Include style keywords like "digital art", "oil painting", "photorealistic", "anime style" - **Use aspect ratios**: Match the aspect ratio to your intended use (16:9 for landscapes, 9:16 for portraits/mobile) - **Image-to-image**: Great for style transfer, variations, or guided modifications of existing images ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/image-generation/scripts/generate.py ================================================ #!/usr/bin/env python3 """ Image generation script using Nano Banana (Google Gemini Image API). Supports text-to-image and image-to-image generation with configurable options. """ import argparse import base64 import os import sys from io import BytesIO from pathlib import Path from PIL import Image def load_image_as_base64(image_path: str) -> tuple[str, str]: """Load an image file and return base64 data and mime type.""" path = Path(image_path) if not path.exists(): raise FileNotFoundError(f"Image not found: {image_path}") # Determine mime type from extension ext = path.suffix.lower() mime_types = { ".png": "image/png", ".jpg": "image/jpeg", ".jpeg": "image/jpeg", ".gif": "image/gif", ".webp": "image/webp", } mime_type = mime_types.get(ext, "image/png") with open(image_path, "rb") as f: data = base64.b64encode(f.read()).decode("utf-8") return data, mime_type def generate_image( prompt: str, output_path: str, model: str = "gemini-3-pro-image-preview", input_image: str | None = None, aspect_ratio: str | None = None, # noqa: ARG001 num_images: int = 1, ) -> list[str]: """ Generate image(s) using Google Gemini / Nano Banana API. Args: prompt: Text description for image generation. output_path: Path to save the generated image(s). model: Model ID to use for generation. input_image: Optional path to reference image for image-to-image mode. 
aspect_ratio: Aspect ratio (e.g., "1:1", "16:9", "9:16", "4:3", "3:4"). num_images: Number of images to generate. Returns: List of paths to saved images. """ api_key = os.environ.get("GEMINI_API_KEY") or os.environ.get("GENAI_API_KEY") if not api_key: raise ValueError( "API key not found. Set GEMINI_API_KEY or GENAI_API_KEY environment variable." ) # lazy importing since very heavy libs from google import genai from google.genai import types client = genai.Client(api_key=api_key) # Build content parts parts: list[types.Part] = [] # Add reference image if provided (image-to-image mode) if input_image: img_data, mime_type = load_image_as_base64(input_image) parts.append( types.Part.from_bytes( data=base64.b64decode(img_data), mime_type=mime_type, ) ) # Add text prompt parts.append(types.Part.from_text(text=prompt)) # Build generation config generate_config = types.GenerateContentConfig( response_modalities=["TEXT", "IMAGE"], ) saved_paths: list[str] = [] output_dir = Path(output_path).parent output_dir.mkdir(parents=True, exist_ok=True) base_name = Path(output_path).stem extension = Path(output_path).suffix or ".png" for i in range(num_images): response = client.models.generate_content( model=model, contents=types.Content(parts=parts), config=generate_config, ) # Validate response if not response.candidates: raise ValueError("No candidates returned from the API") candidate = response.candidates[0] if not candidate.content or not candidate.content.parts: raise ValueError("No content parts returned from the API") # Process response parts image_count = 0 for part in candidate.content.parts: if part.inline_data is not None and part.inline_data.data is not None: # Extract and save the image image_data = part.inline_data.data image = Image.open(BytesIO(image_data)) # Generate output filename if num_images == 1 and image_count == 0: save_path = output_path else: save_path = str( output_dir / f"{base_name}_{i + 1}_{image_count + 1}{extension}" ) image.save(save_path) 
saved_paths.append(save_path) print(f"Saved: {save_path}") image_count += 1 elif part.text: # Print any text response from the model print(f"Model response: {part.text}") return saved_paths def main() -> None: """Main entry point for CLI usage.""" parser = argparse.ArgumentParser( description="Generate images using Nano Banana (Google Gemini Image API).", formatter_class=argparse.RawDescriptionHelpFormatter, epilog=""" Examples: # Basic text-to-image generation python generate.py --prompt "A futuristic city at sunset" --output city.png # Generate with specific aspect ratio python generate.py --prompt "Mountain landscape" --output landscape.png --aspect-ratio 16:9 # Image-to-image mode (use reference image) python generate.py --prompt "Make it more colorful" --input-image ref.png --output colorful.png # Generate multiple images python generate.py --prompt "Abstract art" --output art.png --num-images 3 """, ) parser.add_argument( "--prompt", "-p", type=str, required=True, help="Text prompt describing the desired image.", ) parser.add_argument( "--output", "-o", type=str, default="output.png", help="Output path for the generated image (default: output.png).", ) parser.add_argument( "--model", "-m", type=str, default="gemini-3-pro-image-preview", help="Model to use (default: gemini-3-pro-image-preview).", ) parser.add_argument( "--input-image", "-i", type=str, help="Path to reference image for image-to-image generation.", ) parser.add_argument( "--aspect-ratio", "-a", type=str, choices=["1:1", "16:9", "9:16", "4:3", "3:4"], help="Aspect ratio for the generated image.", ) parser.add_argument( "--num-images", "-n", type=int, default=1, help="Number of images to generate (default: 1).", ) args = parser.parse_args() try: saved_paths = generate_image( prompt=args.prompt, output_path=args.output, model=args.model, input_image=args.input_image, aspect_ratio=args.aspect_ratio, num_images=args.num_images, ) print(f"\nSuccessfully generated {len(saved_paths)} image(s):") for 
path in saved_paths: print(f" - {path}") except Exception as e: print(f"Error: {e}", file=sys.stderr) sys.exit(1) if __name__ == "__main__": main() ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/SKILL.md ================================================ --- name: pptx description: "Use this skill any time a .pptx file is involved in any way — as input, output, or both. This includes: creating slide decks, pitch decks, or presentations; reading, parsing, or extracting text from any .pptx file (even if the extracted content will be used elsewhere, like in an email or summary); editing, modifying, or updating existing presentations; combining or splitting slide files; working with templates, layouts, speaker notes, or comments. Trigger whenever the user mentions \"deck,\" \"slides,\" \"presentation,\" or references a .pptx filename, regardless of what they plan to do with the content afterward. If a .pptx file needs to be opened, created, or touched, use this skill." license: Proprietary. LICENSE.txt has complete terms --- # PPTX Skill > **Path convention**: All commands run from the **session workspace** (your working directory). Never `cd` into the skill directory. Prefix all skill scripts with `.opencode/skills/pptx/`. All generated files (unpacked dirs, output presentations, thumbnails, PDFs, images) go in `outputs/`. 
## Quick Reference | Task | Guide | |------|-------| | Read/analyze content | `python -m markitdown presentation.pptx` | | Edit or create from template | Read [editing.md](editing.md) | | Create from scratch | Read [pptxgenjs.md](pptxgenjs.md) | --- ## Reading Content ```bash # Text extraction python -m markitdown presentation.pptx # Visual overview python .opencode/skills/pptx/scripts/thumbnail.py presentation.pptx # Raw XML python .opencode/skills/pptx/scripts/office/unpack.py presentation.pptx outputs/unpacked/ ``` --- ## Editing Workflow **Read [editing.md](editing.md) for full details.** 1. Analyze template with `thumbnail.py` 2. Unpack → manipulate slides → edit content → clean → pack --- ## Creating from Scratch **Read [pptxgenjs.md](pptxgenjs.md) for full details.** Use when no template or reference presentation is available. --- ## Design Ideas **Don't create boring slides.** Plain bullets on a white background won't impress anyone. Consider ideas from this list for each slide. ### Before Starting - **Pick a bold, content-informed color palette**: The palette should feel designed for THIS topic. If swapping your colors into a completely different presentation would still "work," you haven't made specific enough choices. - **Dominance over equality**: One color should dominate (60-70% visual weight), with 1-2 supporting tones and one sharp accent. Never give all colors equal weight. - **Dark/light contrast**: Dark backgrounds for title + conclusion slides, light for content ("sandwich" structure). Or commit to dark throughout for a premium feel. - **Commit to a visual motif**: Pick ONE distinctive element and repeat it — rounded image frames, icons in colored circles, thick single-side borders. Carry it across every slide. ### Color Palettes Choose colors that match your topic — don't default to generic blue. 
Use these palettes as inspiration: | Theme | Primary | Secondary | Accent | |-------|---------|-----------|--------| | **Midnight Executive** | `1E2761` (navy) | `CADCFC` (ice blue) | `FFFFFF` (white) | | **Forest & Moss** | `2C5F2D` (forest) | `97BC62` (moss) | `F5F5F5` (cream) | | **Coral Energy** | `F96167` (coral) | `F9E795` (gold) | `2F3C7E` (navy) | | **Warm Terracotta** | `B85042` (terracotta) | `E7E8D1` (sand) | `A7BEAE` (sage) | | **Ocean Gradient** | `065A82` (deep blue) | `1C7293` (teal) | `21295C` (midnight) | | **Charcoal Minimal** | `36454F` (charcoal) | `F2F2F2` (off-white) | `212121` (black) | | **Teal Trust** | `028090` (teal) | `00A896` (seafoam) | `02C39A` (mint) | | **Berry & Cream** | `6D2E46` (berry) | `A26769` (dusty rose) | `ECE2D0` (cream) | | **Sage Calm** | `84B59F` (sage) | `69A297` (eucalyptus) | `50808E` (slate) | | **Cherry Bold** | `990011` (cherry) | `FCF6F5` (off-white) | `2F3C7E` (navy) | ### For Each Slide **Every slide needs a visual element** — image, chart, icon, or shape. Text-only slides are forgettable. **Layout options:** - Two-column (text left, illustration on right) - Icon + text rows (icon in colored circle, bold header, description below) - 2x2 or 2x3 grid (image on one side, grid of content blocks on other) - Half-bleed image (full left or right side) with content overlay **Data display:** - Large stat callouts (big numbers 60-72pt with small labels below) - Comparison columns (before/after, pros/cons, side-by-side options) - Timeline or process flow (numbered steps, arrows) **Visual polish:** - Icons in small colored circles next to section headers - Italic accent text for key stats or taglines ### Typography **Choose an interesting font pairing** — don't default to Arial. Pick a header font with personality and pair it with a clean body font. 
| Header Font | Body Font | |-------------|-----------| | Georgia | Calibri | | Arial Black | Arial | | Calibri | Calibri Light | | Cambria | Calibri | | Trebuchet MS | Calibri | | Impact | Arial | | Palatino | Garamond | | Consolas | Calibri | | Element | Size | |---------|------| | Slide title | 36-44pt bold | | Section header | 20-24pt bold | | Body text | 14-16pt | | Captions | 10-12pt muted | ### Spacing - 0.5" minimum margins - 0.3-0.5" between content blocks - Leave breathing room—don't fill every inch ### Avoid (Common Mistakes) - **Don't repeat the same layout** — vary columns, cards, and callouts across slides - **Don't center body text** — left-align paragraphs and lists; center only titles - **Don't skimp on size contrast** — titles need 36pt+ to stand out from 14-16pt body - **Don't default to blue** — pick colors that reflect the specific topic - **Don't mix spacing randomly** — choose 0.3" or 0.5" gaps and use consistently - **Don't style one slide and leave the rest plain** — commit fully or keep it simple throughout - **Don't create text-only slides** — add images, icons, charts, or visual elements; avoid plain title + bullets - **Don't forget text box padding** — when aligning lines or shapes with text edges, set `margin: 0` on the text box or offset the shape to account for padding - **Don't use low-contrast elements** — icons AND text need strong contrast against the background; avoid light text on light backgrounds or dark text on dark backgrounds - **NEVER use accent lines under titles** — these are a hallmark of AI-generated slides; use whitespace or background color instead --- ## QA (Required) **Assume there are problems. Your job is to find them.** Your first render is almost never correct. Approach QA as a bug hunt, not a confirmation step. If you found zero issues on first inspection, you weren't looking hard enough. ### Content QA ```bash python -m markitdown output.pptx ``` Check for missing content, typos, wrong order. 
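For scripted QA, the leftover-placeholder check can also run from Python. A minimal sketch, assuming markitdown's Python API exposes `MarkItDown().convert(path).text_content`; `find_placeholders` and `extract_text` are hypothetical helper names, not part of the skill's scripts. The pattern mirrors the `grep -iE` command used for template checks in this section.

```python
import re

# Case-insensitive pattern for common template placeholder text
# (same alternatives as the grep -iE pattern used for template checks)
PLACEHOLDER_RE = re.compile(
    r"xxxx|lorem|ipsum|this.*(page|slide).*layout", re.IGNORECASE
)


def find_placeholders(text: str) -> list[str]:
    """Return each line of extracted slide text that looks like leftover placeholder content."""
    return [line for line in text.splitlines() if PLACEHOLDER_RE.search(line)]


def extract_text(pptx_path: str) -> str:
    """Extract slide text with markitdown (assumed API; imported lazily)."""
    from markitdown import MarkItDown

    return MarkItDown().convert(pptx_path).text_content
```

Typical use would be `find_placeholders(extract_text("outputs/output.pptx"))`; like a grep with no output, an empty list means nothing matched.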
**When using templates, check for leftover placeholder text:** ```bash python -m markitdown output.pptx | grep -iE "xxxx|lorem|ipsum|this.*(page|slide).*layout" ``` If grep returns results, fix them before declaring success. ### Visual QA **⚠️ USE SUBAGENTS** — even for 2-3 slides. You've been staring at the code and will see what you expect, not what's there. Subagents have fresh eyes. Convert slides to images (see [Converting to Images](#converting-to-images)), then use this prompt: ``` Visually inspect these slides. Assume there are issues — find them. Look for: - Overlapping elements (text through shapes, lines through words, stacked elements) - Text overflow or cut off at edges/box boundaries - Decorative lines positioned for single-line text but title wrapped to two lines - Source citations or footers colliding with content above - Elements too close (< 0.3" gaps) or cards/sections nearly touching - Uneven gaps (large empty area in one place, cramped in another) - Insufficient margin from slide edges (< 0.5") - Columns or similar elements not aligned consistently - Low-contrast text (e.g., light gray text on cream-colored background) - Low-contrast icons (e.g., dark icons on dark backgrounds without a contrasting circle) - Text boxes too narrow causing excessive wrapping - Leftover placeholder content For each slide, list issues or areas of concern, even if minor. Read and analyze these images: 1. /path/to/slide-01.jpg (Expected: [brief description]) 2. /path/to/slide-02.jpg (Expected: [brief description]) Report ALL issues found, including minor ones. ``` ### Verification Loop 1. Generate slides → Convert to images → Inspect 2. **List issues found** (if none found, look again more critically) 3. Fix issues 4. **Re-verify affected slides** — one fix often creates another problem 5. 
Repeat until a full pass reveals no new issues. **Do not declare success until you've completed at least one fix-and-verify cycle.** --- ## Converting to Images Convert presentations to individual slide images for visual inspection: ```bash python .opencode/skills/pptx/scripts/office/soffice.py --headless --convert-to pdf --outdir outputs outputs/output.pptx pdftoppm -jpeg -r 150 outputs/output.pdf outputs/slide ``` This creates `outputs/slide-01.jpg`, `outputs/slide-02.jpg`, etc. To re-render specific slides after fixes: ```bash pdftoppm -jpeg -r 150 -f N -l N outputs/output.pdf outputs/slide-fixed ``` --- ## Dependencies - `pip install "markitdown[pptx]"` - text extraction - `pip install Pillow` - thumbnail grids - `npm install -g pptxgenjs` - creating from scratch - LibreOffice (`soffice`) - PDF conversion (auto-configured for sandboxed environments via `.opencode/skills/pptx/scripts/office/soffice.py`) - Poppler (`pdftoppm`) - PDF to images ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/editing.md ================================================ # Editing Presentations > **Path convention**: All commands run from the **session workspace**. Never `cd` into the skill directory. Prefix all skill scripts with `.opencode/skills/pptx/`. All generated files go in `outputs/`. ## Template-Based Workflow When using an existing presentation as a template: 1. **Analyze existing slides**: ```bash python .opencode/skills/pptx/scripts/thumbnail.py template.pptx outputs/thumbnails python -m markitdown template.pptx ``` Review `outputs/thumbnails.jpg` to see layouts, and markitdown output to see placeholder text. 2. **Plan slide mapping**: For each content section, choose a template slide. ⚠️ **USE VARIED LAYOUTS** — monotonous presentations are a common failure mode. Don't default to basic title + bullet slides.
Actively seek out: - Multi-column layouts (2-column, 3-column) - Image + text combinations - Full-bleed images with text overlay - Quote or callout slides - Section dividers - Stat/number callouts - Icon grids or icon + text rows **Avoid:** Repeating the same text-heavy layout for every slide. Match content type to layout style (e.g., key points → bullet slide, team info → multi-column, testimonials → quote slide). 3. **Unpack**: `python .opencode/skills/pptx/scripts/office/unpack.py template.pptx outputs/unpacked/` 4. **Build presentation** (do this yourself, not with subagents): - Delete unwanted slides (remove from ``) - Duplicate slides you want to reuse (`add_slide.py`) - Reorder slides in `` - **Complete all structural changes before step 5** 5. **Edit content**: Update text in each `slide{N}.xml`. **Use subagents here if available** — slides are separate XML files, so subagents can edit in parallel. 6. **Clean**: `python .opencode/skills/pptx/scripts/clean.py outputs/unpacked/` 7. **Pack**: `python .opencode/skills/pptx/scripts/office/pack.py outputs/unpacked/ outputs/output.pptx --original template.pptx` --- ## Scripts | Script | Purpose | |--------|---------| | `unpack.py` | Extract and pretty-print PPTX | | `add_slide.py` | Duplicate slide or create from layout | | `clean.py` | Remove orphaned files | | `pack.py` | Repack with validation | | `thumbnail.py` | Create visual grid of slides | ### unpack.py ```bash python .opencode/skills/pptx/scripts/office/unpack.py input.pptx outputs/unpacked/ ``` Extracts PPTX, pretty-prints XML, escapes smart quotes. ### add_slide.py ```bash python .opencode/skills/pptx/scripts/add_slide.py outputs/unpacked/ slide2.xml # Duplicate slide python .opencode/skills/pptx/scripts/add_slide.py outputs/unpacked/ slideLayout2.xml # From layout ``` Prints `` to add to `` at desired position. 
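To decide where the printed entry belongs, it helps to see the current slide order first. A minimal read-only sketch, assuming the standard unpacked layout (`ppt/presentation.xml` plus `ppt/_rels/presentation.xml.rels`); `list_slide_order` is a hypothetical helper, not one of the skill's scripts. It parses with `defusedxml.minidom`, as this skill recommends, and falls back to the standard library's `xml.dom.minidom` if defusedxml is missing.

```python
from pathlib import Path

try:
    # defusedxml hardens parsing against malicious XML; recommended by this skill
    from defusedxml import minidom
except ImportError:
    # Fallback so the sketch still runs where defusedxml is not installed
    from xml.dom import minidom


def list_slide_order(unpacked_dir: str) -> list[tuple[str, str]]:
    """Return (sldId id, slide part target) pairs in the order slides appear."""
    ppt = Path(unpacked_dir) / "ppt"
    # presentation.xml.rels maps relationship IDs such as rId7 to slide parts
    rels = minidom.parse(str(ppt / "_rels" / "presentation.xml.rels"))
    targets = {
        rel.getAttribute("Id"): rel.getAttribute("Target")
        for rel in rels.getElementsByTagName("Relationship")
    }
    # p:sldIdLst in presentation.xml lists slides in display order
    pres = minidom.parse(str(ppt / "presentation.xml"))
    return [
        (sld.getAttribute("id"), targets.get(sld.getAttribute("r:id"), "<missing rel>"))
        for sld in pres.getElementsByTagName("p:sldId")
    ]
```

For example, `list_slide_order("outputs/unpacked")` might return pairs like `("256", "slides/slide1.xml")`. The function only reads files; the actual reordering is still done by editing `presentation.xml` as described in this guide.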
### clean.py ```bash python .opencode/skills/pptx/scripts/clean.py outputs/unpacked/ ``` Removes slides not in ``, unreferenced media, orphaned rels. ### pack.py ```bash python .opencode/skills/pptx/scripts/office/pack.py outputs/unpacked/ outputs/output.pptx --original input.pptx ``` Validates, repairs, condenses XML, re-encodes smart quotes. ### thumbnail.py ```bash python .opencode/skills/pptx/scripts/thumbnail.py input.pptx outputs/thumbnails [--cols N] ``` Creates `outputs/thumbnails.jpg` with slide filenames as labels. Default 3 columns, max 12 per grid. **Use for template analysis only** (choosing layouts). For visual QA, use `soffice` + `pdftoppm` to create full-resolution individual slide images—see SKILL.md. --- ## Slide Operations Slide order is in `outputs/unpacked/ppt/presentation.xml` → ``. **Reorder**: Rearrange `` elements. **Delete**: Remove ``, then run `clean.py`. **See available layouts**: `ls outputs/unpacked/ppt/slideLayouts/` **Add**: Use `add_slide.py`. Never manually copy slide files—the script handles notes references, Content_Types.xml, and relationship IDs that manual copying misses. --- ## Editing Content **Subagents:** If available, use them here (after completing step 4). Each slide is a separate XML file, so subagents can edit in parallel. In your prompt to subagents, include: - The slide file path(s) to edit - **"Use the Edit tool for all changes"** - The formatting rules and common pitfalls below For each slide: 1. Read the slide's XML 2. Identify ALL placeholder content—text, images, charts, icons, captions 3. Replace each placeholder with final content **Use the Edit tool, not sed or Python scripts.** The Edit tool forces specificity about what to replace and where, yielding better reliability. ### Formatting Rules - **Bold all headers, subheadings, and inline labels**: Use `b="1"` on ``. 
This includes: - Slide titles - Section headers within a slide - Inline labels (e.g., "Status:", "Description:") at the start of a line - **Never use unicode bullets (•)**: Use proper list formatting with `` or `` - **Bullet consistency**: Let bullets inherit from the layout. Only specify `` or ``. --- ## Common Pitfalls ### Template Adaptation When source content has fewer items than the template: - **Remove excess elements entirely** (images, shapes, text boxes), don't just clear text - Check for orphaned visuals after clearing text content - Run visual QA to catch mismatched counts When replacing text with content of a different length: - **Shorter replacements**: Usually safe - **Longer replacements**: May overflow or wrap unexpectedly - Test with visual QA after text changes - Consider truncating or splitting content to fit the template's design constraints **Template slots ≠ Source items**: If the template has 4 team members but the source has only 3, delete the 4th member's entire group (image + text boxes), not just the text. ### Multi-Item Content If source has multiple items (numbered lists, multiple sections), create separate `` elements for each — **never concatenate into one string**. **❌ WRONG** — all items in one paragraph: ```xml Step 1: Do the first thing. Step 2: Do the second thing. ``` **✅ CORRECT** — separate paragraphs with bold headers: ```xml Step 1 Do the first thing. Step 2 ``` Copy `` from the original paragraph to preserve line spacing. Use `b="1"` on headers. ### Smart Quotes Handled automatically by unpack/pack. But the Edit tool converts smart quotes to ASCII.
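Because the Edit tool writes plain ASCII, curly quotes in new slide text have to be written as numeric character references. A minimal sketch of that conversion; `to_xml_entities` is a hypothetical helper, and it assumes the input is otherwise already valid XML text (it does not escape `&`, `<` or `>`).

```python
# Numeric character references for the four typographic quotes
SMART_QUOTE_ENTITIES = {
    "\u201c": "&#x201C;",  # left double quote
    "\u201d": "&#x201D;",  # right double quote
    "\u2018": "&#x2018;",  # left single quote
    "\u2019": "&#x2019;",  # right single quote / apostrophe
}


def to_xml_entities(text: str) -> str:
    """Replace curly quotes with XML character references, leaving other characters unchanged."""
    return text.translate(str.maketrans(SMART_QUOTE_ENTITIES))
```

The result is pure ASCII, so the Edit tool leaves it as written, and the pack step turns it back into real curly quotes in the saved file.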
**When adding new text with quotes, use XML entities:** ```xml the &#x201C;Agreement&#x201D; ``` | Character | Name | Unicode | XML Entity | |-----------|------|---------|------------| | `“` | Left double quote | U+201C | `&#x201C;` | | `”` | Right double quote | U+201D | `&#x201D;` | | `‘` | Left single quote | U+2018 | `&#x2018;` | | `’` | Right single quote | U+2019 | `&#x2019;` | ### Other - **Whitespace**: Use `xml:space="preserve"` on `` with leading/trailing spaces - **XML parsing**: Use `defusedxml.minidom`, not `xml.etree.ElementTree` (corrupts namespaces) ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/pptxgenjs.md ================================================ # PptxGenJS Tutorial ## Setup & Basic Structure ```javascript const pptxgen = require("pptxgenjs"); let pres = new pptxgen(); pres.layout = 'LAYOUT_16x9'; // or 'LAYOUT_16x10', 'LAYOUT_4x3', 'LAYOUT_WIDE' pres.author = 'Your Name'; pres.title = 'Presentation Title'; let slide = pres.addSlide(); slide.addText("Hello World!", { x: 0.5, y: 0.5, fontSize: 36, color: "363636" }); pres.writeFile({ fileName: "Presentation.pptx" }); ``` ## Layout Dimensions Slide dimensions (coordinates in inches): - `LAYOUT_16x9`: 10" × 5.625" (default) - `LAYOUT_16x10`: 10" × 6.25" - `LAYOUT_4x3`: 10" × 7.5" - `LAYOUT_WIDE`: 13.3" × 7.5" --- ## Text & Formatting ```javascript // Basic text slide.addText("Simple Text", { x: 1, y: 1, w: 8, h: 2, fontSize: 24, fontFace: "Arial", color: "363636", bold: true, align: "center", valign: "middle" }); // Character spacing (use charSpacing, not letterSpacing which is silently ignored) slide.addText("SPACED TEXT", { x: 1, y: 1, w: 8, h: 1, charSpacing: 6 }); // Rich text arrays slide.addText([ { text: "Bold ", options: { bold: true } }, { text: "Italic ", options: { italic: true } } ], { x: 1, y: 3, w: 8, h: 1 }); // Multi-line text (requires breakLine: true) slide.addText([ { text: "Line 1", options: { breakLine: true } }, { text: "Line 2",
options: { breakLine: true } }, { text: "Line 3" } // Last item doesn't need breakLine ], { x: 0.5, y: 0.5, w: 8, h: 2 }); // Text box margin (internal padding) slide.addText("Title", { x: 0.5, y: 0.3, w: 9, h: 0.6, margin: 0 // Use 0 when aligning text with other elements like shapes or icons }); ``` **Tip:** Text boxes have internal margin by default. Set `margin: 0` when you need text to align precisely with shapes, lines, or icons at the same x-position. --- ## Lists & Bullets ```javascript // ✅ CORRECT: Multiple bullets slide.addText([ { text: "First item", options: { bullet: true, breakLine: true } }, { text: "Second item", options: { bullet: true, breakLine: true } }, { text: "Third item", options: { bullet: true } } ], { x: 0.5, y: 0.5, w: 8, h: 3 }); // ❌ WRONG: Never use unicode bullets slide.addText("• First item", { ... }); // Creates double bullets // Sub-items and numbered lists { text: "Sub-item", options: { bullet: true, indentLevel: 1 } } { text: "First", options: { bullet: { type: "number" }, breakLine: true } } ``` --- ## Shapes ```javascript slide.addShape(pres.shapes.RECTANGLE, { x: 0.5, y: 0.8, w: 1.5, h: 3.0, fill: { color: "FF0000" }, line: { color: "000000", width: 2 } }); slide.addShape(pres.shapes.OVAL, { x: 4, y: 1, w: 2, h: 2, fill: { color: "0000FF" } }); slide.addShape(pres.shapes.LINE, { x: 1, y: 3, w: 5, h: 0, line: { color: "FF0000", width: 3, dashType: "dash" } }); // With transparency slide.addShape(pres.shapes.RECTANGLE, { x: 1, y: 1, w: 3, h: 2, fill: { color: "0088CC", transparency: 50 } }); // Rounded rectangle (rectRadius only works with ROUNDED_RECTANGLE, not RECTANGLE) // ⚠️ Don't pair with rectangular accent overlays — they won't cover rounded corners. Use RECTANGLE instead. 
slide.addShape(pres.shapes.ROUNDED_RECTANGLE, {
  x: 1, y: 1, w: 3, h: 2,
  fill: { color: "FFFFFF" }, rectRadius: 0.1
});

// With shadow
slide.addShape(pres.shapes.RECTANGLE, {
  x: 1, y: 1, w: 3, h: 2,
  fill: { color: "FFFFFF" },
  shadow: { type: "outer", color: "000000", blur: 6, offset: 2, angle: 135, opacity: 0.15 }
});
```

Shadow options:

| Property | Type | Range | Notes |
|----------|------|-------|-------|
| `type` | string | `"outer"`, `"inner"` | |
| `color` | string | 6-char hex (e.g. `"000000"`) | No `#` prefix, no 8-char hex (see Common Pitfalls) |
| `blur` | number | 0-100 pt | |
| `offset` | number | 0-200 pt | **Must be non-negative**: negative values corrupt the file |
| `angle` | number | 0-359 degrees | Direction the shadow falls (135 = bottom-right, 270 = upward) |
| `opacity` | number | 0.0-1.0 | Use this for transparency; never encode it in the color string |

To cast a shadow upward (e.g. on a footer bar), use `angle: 270` with a positive offset; do **not** use a negative offset.

**Note**: Gradient fills are not natively supported. Use a gradient image as a background instead.
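The ranges in the shadow table above are easy to violate by hand, and PptxGenJS mutates the shadow object it receives. Below is a minimal sketch, not part of PptxGenJS, that handles both: `checkedShadow` is a hypothetical helper. It merges overrides into the defaults from the "With shadow" example, rejects values outside the documented ranges, and returns a new object on every call.

```javascript
// Hypothetical helper (not a PptxGenJS API): builds a fresh shadow options
// object and rejects values outside the documented ranges.
function checkedShadow(overrides = {}) {
  const shadow = {
    type: "outer",
    color: "000000",
    blur: 6,
    offset: 2,
    angle: 135,
    opacity: 0.15,
    ...overrides,
  };

  if (shadow.type !== "outer" && shadow.type !== "inner") {
    throw new RangeError(`shadow.type must be "outer" or "inner", got ${shadow.type}`);
  }
  // Exactly six hex digits: no "#" prefix and no 8-char (alpha) form.
  if (!/^[0-9A-Fa-f]{6}$/.test(shadow.color)) {
    throw new RangeError(`shadow.color must be 6-char hex without "#", got ${shadow.color}`);
  }

  const inRange = (v, lo, hi) => typeof v === "number" && v >= lo && v <= hi;
  if (!inRange(shadow.blur, 0, 100)) throw new RangeError("shadow.blur must be 0-100");
  if (!inRange(shadow.offset, 0, 200)) throw new RangeError("shadow.offset must be 0-200 (never negative)");
  if (!inRange(shadow.angle, 0, 359)) throw new RangeError("shadow.angle must be 0-359");
  if (!inRange(shadow.opacity, 0, 1)) throw new RangeError("shadow.opacity must be 0.0-1.0");

  return shadow; // a new object on every call, so a mutated shadow is never shared
}

// An upward shadow for a footer bar: angle 270 with a positive offset.
const footerShadow = checkedShadow({ angle: 270, offset: 3 });
```

Call the helper once per shape, for example `slide.addShape(pres.shapes.RECTANGLE, { x: 0, y: 5.1, w: 10, h: 0.5, fill: { color: "1E293B" }, shadow: checkedShadow({ angle: 270, offset: 3 }) });`.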
---

## Images

### Image Sources

```javascript
// From file path
slide.addImage({ path: "images/chart.png", x: 1, y: 1, w: 5, h: 3 });

// From URL
slide.addImage({ path: "https://example.com/image.jpg", x: 1, y: 1, w: 5, h: 3 });

// From base64 (faster, no file I/O)
slide.addImage({ data: "image/png;base64,iVBORw0KGgo...", x: 1, y: 1, w: 5, h: 3 });
```

### Image Options

```javascript
slide.addImage({
  path: "image.png",
  x: 1, y: 1, w: 5, h: 3,
  rotate: 45,              // 0-359 degrees
  rounding: true,          // Circular crop
  transparency: 50,        // 0-100
  flipH: true,             // Horizontal flip
  flipV: false,            // Vertical flip
  altText: "Description",  // Accessibility
  hyperlink: { url: "https://example.com" }
});
```

### Image Sizing Modes

```javascript
// Contain - fit inside, preserve ratio
{ sizing: { type: 'contain', w: 4, h: 3 } }

// Cover - fill area, preserve ratio (may crop)
{ sizing: { type: 'cover', w: 4, h: 3 } }

// Crop - cut specific portion
{ sizing: { type: 'crop', x: 0.5, y: 0.5, w: 2, h: 2 } }
```

### Calculate Dimensions (preserve aspect ratio)

```javascript
const origWidth = 1978, origHeight = 923, maxHeight = 3.0;
const calcWidth = maxHeight * (origWidth / origHeight);
const centerX = (10 - calcWidth) / 2;  // 10 = slide width in inches for LAYOUT_16x9

slide.addImage({ path: "image.png", x: centerX, y: 1.2, w: calcWidth, h: maxHeight });
```

### Supported Formats

- **Standard**: PNG, JPG, GIF (animated GIFs work in Microsoft 365)
- **SVG**: Works in modern PowerPoint/Microsoft 365

---

## Icons

Use react-icons to generate SVG icons, then rasterize to PNG for universal compatibility.
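The icon ends up as a PNG that PptxGenJS scales to the `w`/`h` given in inches, so the pixel size passed to the rasterizer determines how sharp it looks. The following is a minimal sketch for choosing that size. `iconPixelSize` and `effectivePpi` are hypothetical helpers, and the 300 px-per-inch target is an assumption rather than a PptxGenJS setting. The 256 px floor mirrors the recommendation in the Note below.

```javascript
// Hypothetical helper: pick the pixel size to rasterize an icon that will be
// displayed at `displayInches` on the slide. The 300 px/inch target and the
// 256 px minimum are illustrative defaults, not library settings.
function iconPixelSize(displayInches, { pxPerInch = 300, minPx = 256 } = {}) {
  if (!(displayInches > 0)) {
    throw new RangeError("displayInches must be a positive number");
  }
  return Math.max(minPx, Math.ceil(displayInches * pxPerInch));
}

// Effective pixel density once the PNG is placed at `displayInches`.
function effectivePpi(pixelSize, displayInches) {
  return pixelSize / displayInches;
}

const small = iconPixelSize(0.5);      // 0.5 * 300 = 150, raised to the 256 floor
const large = iconPixelSize(1.5);      // 1.5 * 300 = 450
const ppi = effectivePpi(small, 0.5);  // 256 px shown at 0.5" = 512 px/inch
```

You can pass the result as the `size` argument of `iconToBase64Png`.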
### Setup

```javascript
const React = require("react");
const ReactDOMServer = require("react-dom/server");
const sharp = require("sharp");
const { FaCheckCircle, FaChartLine } = require("react-icons/fa");

function renderIconSvg(IconComponent, color = "#000000", size = 256) {
  return ReactDOMServer.renderToStaticMarkup(
    React.createElement(IconComponent, { color, size: String(size) })
  );
}

async function iconToBase64Png(IconComponent, color, size = 256) {
  const svg = renderIconSvg(IconComponent, color, size);
  const pngBuffer = await sharp(Buffer.from(svg)).png().toBuffer();
  return "image/png;base64," + pngBuffer.toString("base64");
}
```

### Add Icon to Slide

```javascript
// Run inside an async function: top-level await is unavailable in CommonJS scripts.
// "#4472C4" is a CSS color for the SVG, so the "#" is fine here (unlike PptxGenJS colors).
const iconData = await iconToBase64Png(FaCheckCircle, "#4472C4", 256);

slide.addImage({
  data: iconData,
  x: 1, y: 1, w: 0.5, h: 0.5  // Size in inches
});
```

**Note**: Use size 256 or higher for crisp icons. The size parameter controls the rasterization resolution, not the display size on the slide (which is set by `w` and `h` in inches).

### Icon Libraries

Install: `npm install -g react-icons react react-dom sharp`

Popular icon sets in react-icons:
- `react-icons/fa` - Font Awesome
- `react-icons/md` - Material Design
- `react-icons/hi` - Heroicons
- `react-icons/bi` - Bootstrap Icons

---

## Slide Backgrounds

```javascript
// Solid color
slide.background = { color: "F1F1F1" };

// Color with transparency
slide.background = { color: "FF3399", transparency: 50 };

// Image from URL
slide.background = { path: "https://example.com/bg.jpg" };

// Image from base64
slide.background = { data: "image/png;base64,iVBORw0KGgo..."
};
```

---

## Tables

```javascript
slide.addTable([
  ["Header 1", "Header 2"],
  ["Cell 1", "Cell 2"]
], {
  x: 1, y: 1, w: 8, h: 2,
  border: { pt: 1, color: "999999" }, fill: { color: "F1F1F1" }
});

// Advanced with merged cells
let tableData = [
  [{ text: "Header", options: { fill: { color: "6699CC" }, color: "FFFFFF", bold: true } }, "Cell"],
  [{ text: "Merged", options: { colspan: 2 } }]
];
slide.addTable(tableData, { x: 1, y: 3.5, w: 8, colW: [4, 4] });
```

---

## Charts

```javascript
// Bar chart
slide.addChart(pres.charts.BAR, [{
  name: "Sales", labels: ["Q1", "Q2", "Q3", "Q4"], values: [4500, 5500, 6200, 7100]
}], {
  x: 0.5, y: 0.6, w: 6, h: 3, barDir: 'col',
  showTitle: true, title: 'Quarterly Sales'
});

// Line chart
slide.addChart(pres.charts.LINE, [{
  name: "Temp", labels: ["Jan", "Feb", "Mar"], values: [32, 35, 42]
}], { x: 0.5, y: 4, w: 6, h: 3, lineSize: 3, lineSmooth: true });

// Pie chart
slide.addChart(pres.charts.PIE, [{
  name: "Share", labels: ["A", "B", "Other"], values: [35, 45, 20]
}], { x: 7, y: 1, w: 5, h: 4, showPercent: true });
```

### Better-Looking Charts

Default charts look dated.
Apply these options for a modern, clean appearance:

```javascript
slide.addChart(pres.charts.BAR, chartData, {
  x: 0.5, y: 1, w: 9, h: 4, barDir: "col",

  // Custom colors (match your presentation palette)
  chartColors: ["0D9488", "14B8A6", "5EEAD4"],

  // Clean background
  chartArea: { fill: { color: "FFFFFF" }, roundedCorners: true },

  // Muted axis labels
  catAxisLabelColor: "64748B",
  valAxisLabelColor: "64748B",

  // Subtle grid (value axis only)
  valGridLine: { color: "E2E8F0", size: 0.5 },
  catGridLine: { style: "none" },

  // Data labels on bars
  showValue: true,
  dataLabelPosition: "outEnd",
  dataLabelColor: "1E293B",

  // Hide legend for single series
  showLegend: false,
});
```

**Key styling options:**
- `chartColors: [...]` - hex colors for series/segments
- `chartArea: { fill, border, roundedCorners }` - chart background
- `catGridLine/valGridLine: { color, style, size }` - grid lines (`style: "none"` to hide)
- `lineSmooth: true` - curved lines (line charts)
- `legendPos: "r"` - legend position: "b", "t", "l", "r", "tr"

---

## Slide Masters

```javascript
pres.defineSlideMaster({
  title: 'TITLE_SLIDE', background: { color: '283A5E' },
  objects: [{
    placeholder: { options: { name: 'title', type: 'title', x: 1, y: 2, w: 8, h: 2 } }
  }]
});

let titleSlide = pres.addSlide({ masterName: "TITLE_SLIDE" });
titleSlide.addText("My Title", { placeholder: "title" });
```

---

## Common Pitfalls

⚠️ These issues cause file corruption, visual bugs, or broken output. Avoid them.

1. **NEVER use "#" with hex colors** - causes file corruption
   ```javascript
   color: "FF0000"      // ✅ CORRECT
   color: "#FF0000"     // ❌ WRONG
   ```

2. **NEVER encode opacity in hex color strings** - 8-char colors (e.g., `"00000020"`) corrupt the file. Use the `opacity` property instead.
   ```javascript
   shadow: { type: "outer", blur: 6, offset: 2, color: "00000020" }               // ❌ CORRUPTS FILE
   shadow: { type: "outer", blur: 6, offset: 2, color: "000000", opacity: 0.12 }  // ✅ CORRECT
   ```

3. **Use `bullet: true`** - NEVER unicode symbols like "•" (creates double bullets)

4. **Use `breakLine: true`** between array items, or the text runs together

5. **Avoid `lineSpacing` with bullets** - causes excessive gaps; use `paraSpaceAfter` instead

6. **Each presentation needs a fresh instance** - don't reuse `pptxgen()` objects

7. **NEVER reuse option objects across calls** - PptxGenJS mutates objects in-place (e.g. converting shadow values to EMU). Sharing one object between multiple calls corrupts the second shape.
   ```javascript
   const shadow = { type: "outer", blur: 6, offset: 2, color: "000000", opacity: 0.15 };
   slide.addShape(pres.shapes.RECTANGLE, { shadow, ... });  // ❌ second call gets already-converted values
   slide.addShape(pres.shapes.RECTANGLE, { shadow, ... });

   const makeShadow = () => ({ type: "outer", blur: 6, offset: 2, color: "000000", opacity: 0.15 });
   slide.addShape(pres.shapes.RECTANGLE, { shadow: makeShadow(), ... });  // ✅ fresh object each time
   slide.addShape(pres.shapes.RECTANGLE, { shadow: makeShadow(), ... });
   ```

8. **Don't use `ROUNDED_RECTANGLE` with accent borders** - rectangular overlay bars won't cover rounded corners. Use `RECTANGLE` instead.
   ```javascript
   // ❌ WRONG: Accent bar doesn't cover rounded corners
   slide.addShape(pres.shapes.ROUNDED_RECTANGLE, { x: 1, y: 1, w: 3, h: 1.5, fill: { color: "FFFFFF" } });
   slide.addShape(pres.shapes.RECTANGLE, { x: 1, y: 1, w: 0.08, h: 1.5, fill: { color: "0891B2" } });

   // ✅ CORRECT: Use RECTANGLE for clean alignment
   slide.addShape(pres.shapes.RECTANGLE, { x: 1, y: 1, w: 3, h: 1.5, fill: { color: "FFFFFF" } });
   slide.addShape(pres.shapes.RECTANGLE, { x: 1, y: 1, w: 0.08, h: 1.5, fill: { color: "0891B2" } });
   ```

---

## Quick Reference

- **Shapes**: RECTANGLE, OVAL, LINE, ROUNDED_RECTANGLE
- **Charts**: BAR, LINE, PIE, DOUGHNUT, SCATTER, BUBBLE, RADAR
- **Layouts**: LAYOUT_16x9 (10"×5.625"), LAYOUT_16x10, LAYOUT_4x3, LAYOUT_WIDE
- **Alignment**: "left", "center", "right"
- **Chart data labels**: "outEnd", "inEnd", "center"

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/__init__.py
================================================

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/add_slide.py
================================================
"""Add a new slide to an unpacked PPTX directory.

Usage: python add_slide.py

The source can be:
  - A slide file (e.g., slide2.xml) - duplicates the slide
  - A layout file (e.g., slideLayout2.xml) - creates from layout

Examples:
    python add_slide.py unpacked/ slide2.xml
    # Duplicates slide2, creates slide5.xml

    python add_slide.py unpacked/ slideLayout2.xml
    # Creates slide5.xml from slideLayout2.xml

To see available layouts: ls unpacked/ppt/slideLayouts/

Prints the element to add to presentation.xml.
"""

import re
import shutil
import sys
from pathlib import Path


def get_next_slide_number(slides_dir: Path) -> int:
    existing = [
        int(m.group(1))
        for f in slides_dir.glob("slide*.xml")
        if (m := re.match(r"slide(\d+)\.xml", f.name))
    ]
    return max(existing) + 1 if existing else 1


def create_slide_from_layout(unpacked_dir: Path, layout_file: str) -> None:
    slides_dir = unpacked_dir / "ppt" / "slides"
    rels_dir = slides_dir / "_rels"
    layouts_dir = unpacked_dir / "ppt" / "slideLayouts"

    layout_path = layouts_dir / layout_file
    if not layout_path.exists():
        print(f"Error: {layout_path} not found", file=sys.stderr)
        sys.exit(1)

    next_num = get_next_slide_number(slides_dir)
    dest = f"slide{next_num}.xml"
    dest_slide = slides_dir / dest
    dest_rels = rels_dir / f"{dest}.rels"

    slide_xml = """ """
    dest_slide.write_text(slide_xml, encoding="utf-8")

    rels_dir.mkdir(exist_ok=True)
    rels_xml = f""" """
    dest_rels.write_text(rels_xml, encoding="utf-8")

    _add_to_content_types(unpacked_dir, dest)

    rid = _add_to_presentation_rels(unpacked_dir, dest)

    next_slide_id = _get_next_slide_id(unpacked_dir)

    print(f"Created {dest} from {layout_file}")
    print(
        f'Add to presentation.xml : '
    )


def duplicate_slide(unpacked_dir: Path, source: str) -> None:
    slides_dir = unpacked_dir / "ppt" / "slides"
    rels_dir = slides_dir / "_rels"

    source_slide = slides_dir / source

    if not source_slide.exists():
        print(f"Error: {source_slide} not found", file=sys.stderr)
        sys.exit(1)

    next_num = get_next_slide_number(slides_dir)
    dest = f"slide{next_num}.xml"
    dest_slide = slides_dir / dest

    source_rels = rels_dir / f"{source}.rels"
    dest_rels = rels_dir / f"{dest}.rels"

    shutil.copy2(source_slide, dest_slide)

    if source_rels.exists():
        shutil.copy2(source_rels, dest_rels)

        rels_content = dest_rels.read_text(encoding="utf-8")
        rels_content = re.sub(
            r'\s*]*Type="[^"]*notesSlide"[^>]*/>\s*',
            "\n",
            rels_content,
        )
        dest_rels.write_text(rels_content, encoding="utf-8")

    _add_to_content_types(unpacked_dir, dest)

    rid = _add_to_presentation_rels(unpacked_dir, dest)

    next_slide_id = _get_next_slide_id(unpacked_dir)

    print(f"Created {dest} from {source}")
    print(
        f'Add to presentation.xml : '
    )


def _add_to_content_types(unpacked_dir: Path, dest: str) -> None:
    content_types_path = unpacked_dir / "[Content_Types].xml"
    content_types = content_types_path.read_text(encoding="utf-8")

    content_type = (
        "application/vnd.openxmlformats-officedocument.presentationml.slide+xml"
    )
    new_override = (
        f''
    )

    if f"/ppt/slides/{dest}" not in content_types:
        content_types = content_types.replace("", f" {new_override}\n")
        content_types_path.write_text(content_types, encoding="utf-8")


def _add_to_presentation_rels(unpacked_dir: Path, dest: str) -> str:
    pres_rels_path = unpacked_dir / "ppt" / "_rels" / "presentation.xml.rels"
    pres_rels = pres_rels_path.read_text(encoding="utf-8")

    rids = [int(m) for m in re.findall(r'Id="rId(\d+)"', pres_rels)]
    next_rid = max(rids) + 1 if rids else 1
    rid = f"rId{next_rid}"

    slide_type = (
        "http://schemas.openxmlformats.org/officeDocument/2006/relationships/slide"
    )
    new_rel = f''

    if f"slides/{dest}" not in pres_rels:
        pres_rels = pres_rels.replace(
            "", f" {new_rel}\n"
        )
        pres_rels_path.write_text(pres_rels, encoding="utf-8")

    return rid


def _get_next_slide_id(unpacked_dir: Path) -> int:
    pres_path = unpacked_dir / "ppt" / "presentation.xml"
    pres_content = pres_path.read_text(encoding="utf-8")
    slide_ids = [int(m) for m in re.findall(r']*id="(\d+)"', pres_content)]
    return max(slide_ids) + 1 if slide_ids else 256


def parse_source(source: str) -> tuple[str, str | None]:
    if source.startswith("slideLayout") and source.endswith(".xml"):
        return ("layout", source)

    return ("slide", None)


if __name__ == "__main__":
    if len(sys.argv) != 3:
        print("Usage: python add_slide.py ", file=sys.stderr)
        print("", file=sys.stderr)
        print("Source can be:", file=sys.stderr)
        print(" slide2.xml - duplicate an existing slide", file=sys.stderr)
        print(" slideLayout2.xml - create from a layout template",
              file=sys.stderr)
        print("", file=sys.stderr)
        print(
            "To see available layouts: ls /ppt/slideLayouts/",
            file=sys.stderr,
        )
        sys.exit(1)

    unpacked_dir = Path(sys.argv[1])
    source = sys.argv[2]

    if not unpacked_dir.exists():
        print(f"Error: {unpacked_dir} not found", file=sys.stderr)
        sys.exit(1)

    source_type, layout_file = parse_source(source)

    if source_type == "layout" and layout_file is not None:
        create_slide_from_layout(unpacked_dir, layout_file)
    else:
        duplicate_slide(unpacked_dir, source)

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/clean.py
================================================
"""Remove unreferenced files from an unpacked PPTX directory.

Usage: python clean.py

Example:
    python clean.py unpacked/

This script removes:
- Orphaned slides (not in sldIdLst) and their relationships
- [trash] directory (unreferenced files)
- Orphaned .rels files for deleted resources
- Unreferenced media, embeddings, charts, diagrams, drawings, ink files
- Unreferenced theme files
- Unreferenced notes slides
- Content-Type overrides for deleted files
"""

import re
import sys
from pathlib import Path

import defusedxml.minidom


def get_slides_in_sldidlst(unpacked_dir: Path) -> set[str]:
    pres_path = unpacked_dir / "ppt" / "presentation.xml"
    pres_rels_path = unpacked_dir / "ppt" / "_rels" / "presentation.xml.rels"

    if not pres_path.exists() or not pres_rels_path.exists():
        return set()

    rels_dom = defusedxml.minidom.parse(str(pres_rels_path))
    rid_to_slide = {}
    for rel in rels_dom.getElementsByTagName("Relationship"):
        rid = rel.getAttribute("Id")
        target = rel.getAttribute("Target")
        rel_type = rel.getAttribute("Type")
        if "slide" in rel_type and target.startswith("slides/"):
            rid_to_slide[rid] = target.replace("slides/", "")

    pres_content = pres_path.read_text(encoding="utf-8")
    referenced_rids = set(re.findall(r']*r:id="([^"]+)"', pres_content))

    return {rid_to_slide[rid] for rid in referenced_rids if rid in
            rid_to_slide}


def remove_orphaned_slides(unpacked_dir: Path) -> list[str]:
    slides_dir = unpacked_dir / "ppt" / "slides"
    slides_rels_dir = slides_dir / "_rels"
    pres_rels_path = unpacked_dir / "ppt" / "_rels" / "presentation.xml.rels"

    if not slides_dir.exists():
        return []

    referenced_slides = get_slides_in_sldidlst(unpacked_dir)
    removed = []

    for slide_file in slides_dir.glob("slide*.xml"):
        if slide_file.name not in referenced_slides:
            rel_path = slide_file.relative_to(unpacked_dir)
            slide_file.unlink()
            removed.append(str(rel_path))

            rels_file = slides_rels_dir / f"{slide_file.name}.rels"
            if rels_file.exists():
                rels_file.unlink()
                removed.append(str(rels_file.relative_to(unpacked_dir)))

    if removed and pres_rels_path.exists():
        rels_dom = defusedxml.minidom.parse(str(pres_rels_path))
        changed = False

        for rel in list(rels_dom.getElementsByTagName("Relationship")):
            target = rel.getAttribute("Target")
            if target.startswith("slides/"):
                slide_name = target.replace("slides/", "")
                if slide_name not in referenced_slides:
                    if rel.parentNode:
                        rel.parentNode.removeChild(rel)
                        changed = True

        if changed:
            with open(pres_rels_path, "wb") as f:
                f.write(rels_dom.toxml(encoding="utf-8"))

    return removed


def remove_trash_directory(unpacked_dir: Path) -> list[str]:
    trash_dir = unpacked_dir / "[trash]"
    removed = []

    if trash_dir.exists() and trash_dir.is_dir():
        for file_path in trash_dir.iterdir():
            if file_path.is_file():
                rel_path = file_path.relative_to(unpacked_dir)
                removed.append(str(rel_path))
                file_path.unlink()
        trash_dir.rmdir()

    return removed


def get_slide_referenced_files(unpacked_dir: Path) -> set:
    referenced = set()
    slides_rels_dir = unpacked_dir / "ppt" / "slides" / "_rels"

    if not slides_rels_dir.exists():
        return referenced

    for rels_file in slides_rels_dir.glob("*.rels"):
        dom = defusedxml.minidom.parse(str(rels_file))
        for rel in dom.getElementsByTagName("Relationship"):
            target = rel.getAttribute("Target")
            if not target:
                continue
            target_path = (rels_file.parent.parent / target).resolve()
            try:
                referenced.add(target_path.relative_to(unpacked_dir.resolve()))
            except ValueError:
                pass

    return referenced


def remove_orphaned_rels_files(unpacked_dir: Path) -> list[str]:
    resource_dirs = ["charts", "diagrams", "drawings"]
    removed = []
    slide_referenced = get_slide_referenced_files(unpacked_dir)

    for dir_name in resource_dirs:
        rels_dir = unpacked_dir / "ppt" / dir_name / "_rels"
        if not rels_dir.exists():
            continue

        for rels_file in rels_dir.glob("*.rels"):
            resource_file = rels_dir.parent / rels_file.name.replace(".rels", "")
            try:
                resource_rel_path = resource_file.resolve().relative_to(
                    unpacked_dir.resolve()
                )
            except ValueError:
                continue

            if not resource_file.exists() or resource_rel_path not in slide_referenced:
                rels_file.unlink()
                rel_path = rels_file.relative_to(unpacked_dir)
                removed.append(str(rel_path))

    return removed


def get_referenced_files(unpacked_dir: Path) -> set:
    referenced = set()

    for rels_file in unpacked_dir.rglob("*.rels"):
        dom = defusedxml.minidom.parse(str(rels_file))
        for rel in dom.getElementsByTagName("Relationship"):
            target = rel.getAttribute("Target")
            if not target:
                continue
            target_path = (rels_file.parent.parent / target).resolve()
            try:
                referenced.add(target_path.relative_to(unpacked_dir.resolve()))
            except ValueError:
                pass

    return referenced


def remove_orphaned_files(unpacked_dir: Path, referenced: set) -> list[str]:
    resource_dirs = [
        "media",
        "embeddings",
        "charts",
        "diagrams",
        "tags",
        "drawings",
        "ink",
    ]
    removed = []

    for dir_name in resource_dirs:
        dir_path = unpacked_dir / "ppt" / dir_name
        if not dir_path.exists():
            continue

        for file_path in dir_path.glob("*"):
            if not file_path.is_file():
                continue
            rel_path = file_path.relative_to(unpacked_dir)
            if rel_path not in referenced:
                file_path.unlink()
                removed.append(str(rel_path))

    theme_dir = unpacked_dir / "ppt" / "theme"
    if theme_dir.exists():
        for file_path in theme_dir.glob("theme*.xml"):
            rel_path = file_path.relative_to(unpacked_dir)
            if rel_path not in referenced:
                file_path.unlink()
                removed.append(str(rel_path))

                theme_rels = theme_dir / "_rels" / f"{file_path.name}.rels"
                if theme_rels.exists():
                    theme_rels.unlink()
                    removed.append(str(theme_rels.relative_to(unpacked_dir)))

    notes_dir = unpacked_dir / "ppt" / "notesSlides"
    if notes_dir.exists():
        for file_path in notes_dir.glob("*.xml"):
            if not file_path.is_file():
                continue
            rel_path = file_path.relative_to(unpacked_dir)
            if rel_path not in referenced:
                file_path.unlink()
                removed.append(str(rel_path))

        notes_rels_dir = notes_dir / "_rels"
        if notes_rels_dir.exists():
            for file_path in notes_rels_dir.glob("*.rels"):
                notes_file = notes_dir / file_path.name.replace(".rels", "")
                if not notes_file.exists():
                    file_path.unlink()
                    removed.append(str(file_path.relative_to(unpacked_dir)))

    return removed


def update_content_types(unpacked_dir: Path, removed_files: list[str]) -> None:
    ct_path = unpacked_dir / "[Content_Types].xml"
    if not ct_path.exists():
        return

    dom = defusedxml.minidom.parse(str(ct_path))
    changed = False

    for override in list(dom.getElementsByTagName("Override")):
        part_name = override.getAttribute("PartName").lstrip("/")
        if part_name in removed_files:
            if override.parentNode:
                override.parentNode.removeChild(override)
                changed = True

    if changed:
        with open(ct_path, "wb") as f:
            f.write(dom.toxml(encoding="utf-8"))


def clean_unused_files(unpacked_dir: Path) -> list[str]:
    all_removed = []

    slides_removed = remove_orphaned_slides(unpacked_dir)
    all_removed.extend(slides_removed)

    trash_removed = remove_trash_directory(unpacked_dir)
    all_removed.extend(trash_removed)

    while True:
        removed_rels = remove_orphaned_rels_files(unpacked_dir)
        referenced = get_referenced_files(unpacked_dir)
        removed_files = remove_orphaned_files(unpacked_dir, referenced)

        total_removed = removed_rels + removed_files
        if not total_removed:
            break

        all_removed.extend(total_removed)

    if all_removed:
        update_content_types(unpacked_dir, all_removed)

    return all_removed


if __name__ == "__main__":
    if len(sys.argv) != 2:
        print("Usage: python clean.py ",
              file=sys.stderr)
        print("Example: python clean.py unpacked/", file=sys.stderr)
        sys.exit(1)

    unpacked_dir = Path(sys.argv[1])

    if not unpacked_dir.exists():
        print(f"Error: {unpacked_dir} not found", file=sys.stderr)
        sys.exit(1)

    removed = clean_unused_files(unpacked_dir)

    if removed:
        print(f"Removed {len(removed)} unreferenced files:")
        for f in removed:
            print(f" {f}")
    else:
        print("No unreferenced files found")

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/helpers/__init__.py
================================================

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/helpers/merge_runs.py
================================================
"""Merge adjacent runs with identical formatting in DOCX.

Merges adjacent elements that have identical properties.
Works on runs in paragraphs and inside tracked changes (, ).
Also:
- Removes rsid attributes from runs (revision metadata that doesn't affect rendering)
- Removes proofErr elements (spell/grammar markers that block merging)
"""

from pathlib import Path

import defusedxml.minidom


def merge_runs(input_dir: str) -> tuple[int, str]:
    doc_xml = Path(input_dir) / "word" / "document.xml"

    if not doc_xml.exists():
        return 0, f"Error: {doc_xml} not found"

    try:
        dom = defusedxml.minidom.parseString(doc_xml.read_text(encoding="utf-8"))
        root = dom.documentElement

        _remove_elements(root, "proofErr")
        _strip_run_rsid_attrs(root)

        containers = {run.parentNode for run in _find_elements(root, "r")}

        merge_count = 0
        for container in containers:
            merge_count += _merge_runs_in(container)

        doc_xml.write_bytes(dom.toxml(encoding="UTF-8"))
        return merge_count, f"Merged {merge_count} runs"

    except Exception as e:
        return 0, f"Error: {e}"


def _find_elements(root, tag: str) -> list:
    results = []

    def traverse(node):
        if node.nodeType == node.ELEMENT_NODE:
            name = node.localName or node.tagName
            if name == tag or name.endswith(f":{tag}"):
                results.append(node)
            for child in node.childNodes:
                traverse(child)

    traverse(root)
    return results


def _get_child(parent, tag: str):
    for child in parent.childNodes:
        if child.nodeType == child.ELEMENT_NODE:
            name = child.localName or child.tagName
            if name == tag or name.endswith(f":{tag}"):
                return child
    return None


def _get_children(parent, tag: str) -> list:
    results = []
    for child in parent.childNodes:
        if child.nodeType == child.ELEMENT_NODE:
            name = child.localName or child.tagName
            if name == tag or name.endswith(f":{tag}"):
                results.append(child)
    return results


def _is_adjacent(elem1, elem2) -> bool:
    node = elem1.nextSibling
    while node:
        if node == elem2:
            return True
        if node.nodeType == node.ELEMENT_NODE:
            return False
        if node.nodeType == node.TEXT_NODE and node.data.strip():
            return False
        node = node.nextSibling
    return False


def _remove_elements(root, tag: str):
    for elem in _find_elements(root, tag):
        if elem.parentNode:
            elem.parentNode.removeChild(elem)


def _strip_run_rsid_attrs(root):
    for run in _find_elements(root, "r"):
        for attr in list(run.attributes.values()):
            if "rsid" in attr.name.lower():
                run.removeAttribute(attr.name)


def _merge_runs_in(container) -> int:
    merge_count = 0
    run = _first_child_run(container)

    while run:
        while True:
            next_elem = _next_element_sibling(run)
            if next_elem and _is_run(next_elem) and _can_merge(run, next_elem):
                _merge_run_content(run, next_elem)
                container.removeChild(next_elem)
                merge_count += 1
            else:
                break

        _consolidate_text(run)
        run = _next_sibling_run(run)

    return merge_count


def _first_child_run(container):
    for child in container.childNodes:
        if child.nodeType == child.ELEMENT_NODE and _is_run(child):
            return child
    return None


def _next_element_sibling(node):
    sibling = node.nextSibling
    while sibling:
        if sibling.nodeType == sibling.ELEMENT_NODE:
            return sibling
        sibling = sibling.nextSibling
    return None


def _next_sibling_run(node):
    sibling = node.nextSibling
    while sibling:
        if sibling.nodeType == sibling.ELEMENT_NODE:
            if _is_run(sibling):
                return sibling
        sibling = sibling.nextSibling
    return None


def _is_run(node) -> bool:
    name = node.localName or node.tagName
    return name == "r" or name.endswith(":r")


def _can_merge(run1, run2) -> bool:
    rpr1 = _get_child(run1, "rPr")
    rpr2 = _get_child(run2, "rPr")

    if (rpr1 is None) != (rpr2 is None):
        return False
    if rpr1 is None:
        return True
    return rpr1.toxml() == rpr2.toxml()


def _merge_run_content(target, source):
    for child in list(source.childNodes):
        if child.nodeType == child.ELEMENT_NODE:
            name = child.localName or child.tagName
            if name != "rPr" and not name.endswith(":rPr"):
                target.appendChild(child)


def _consolidate_text(run):
    t_elements = _get_children(run, "t")

    for i in range(len(t_elements) - 1, 0, -1):
        curr, prev = t_elements[i], t_elements[i - 1]

        if _is_adjacent(prev, curr):
            prev_text = prev.firstChild.data if prev.firstChild else ""
            curr_text = curr.firstChild.data if curr.firstChild else ""
            merged = prev_text + curr_text

            if prev.firstChild:
                prev.firstChild.data = merged
            else:
                prev.appendChild(run.ownerDocument.createTextNode(merged))

            if merged.startswith(" ") or merged.endswith(" "):
                prev.setAttribute("xml:space", "preserve")
            elif prev.hasAttribute("xml:space"):
                prev.removeAttribute("xml:space")

            run.removeChild(curr)

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/helpers/simplify_redlines.py
================================================
"""Simplify tracked changes by merging adjacent w:ins or w:del elements.

Merges adjacent elements from the same author into a single element.
Same for elements. This makes heavily-redlined documents easier to
work with by reducing the number of tracked change wrappers.

Rules:
- Only merges w:ins with w:ins, w:del with w:del (same element type)
- Only merges if same author (ignores timestamp differences)
- Only merges if truly adjacent (only whitespace between them)
"""

import xml.etree.ElementTree as ET
import zipfile
from pathlib import Path

import defusedxml.minidom

WORD_NS = "http://schemas.openxmlformats.org/wordprocessingml/2006/main"


def simplify_redlines(input_dir: str) -> tuple[int, str]:
    doc_xml = Path(input_dir) / "word" / "document.xml"

    if not doc_xml.exists():
        return 0, f"Error: {doc_xml} not found"

    try:
        dom = defusedxml.minidom.parseString(doc_xml.read_text(encoding="utf-8"))
        root = dom.documentElement

        merge_count = 0

        containers = _find_elements(root, "p") + _find_elements(root, "tc")

        for container in containers:
            merge_count += _merge_tracked_changes_in(container, "ins")
            merge_count += _merge_tracked_changes_in(container, "del")

        doc_xml.write_bytes(dom.toxml(encoding="UTF-8"))
        return merge_count, f"Simplified {merge_count} tracked changes"

    except Exception as e:
        return 0, f"Error: {e}"


def _merge_tracked_changes_in(container, tag: str) -> int:
    merge_count = 0

    tracked = [
        child
        for child in container.childNodes
        if child.nodeType
        == child.ELEMENT_NODE and _is_element(child, tag)
    ]

    if len(tracked) < 2:
        return 0

    i = 0
    while i < len(tracked) - 1:
        curr = tracked[i]
        next_elem = tracked[i + 1]

        if _can_merge_tracked(curr, next_elem):
            _merge_tracked_content(curr, next_elem)
            container.removeChild(next_elem)
            tracked.pop(i + 1)
            merge_count += 1
        else:
            i += 1

    return merge_count


def _is_element(node, tag: str) -> bool:
    name = node.localName or node.tagName
    return name == tag or name.endswith(f":{tag}")


def _get_author(elem) -> str:
    author = elem.getAttribute("w:author")
    if not author:
        for attr in elem.attributes.values():
            if attr.localName == "author" or attr.name.endswith(":author"):
                return attr.value
    return author


def _can_merge_tracked(elem1, elem2) -> bool:
    if _get_author(elem1) != _get_author(elem2):
        return False

    node = elem1.nextSibling
    while node and node != elem2:
        if node.nodeType == node.ELEMENT_NODE:
            return False
        if node.nodeType == node.TEXT_NODE and node.data.strip():
            return False
        node = node.nextSibling

    return True


def _merge_tracked_content(target, source):
    while source.firstChild:
        child = source.firstChild
        source.removeChild(child)
        target.appendChild(child)


def _find_elements(root, tag: str) -> list:
    results = []

    def traverse(node):
        if node.nodeType == node.ELEMENT_NODE:
            name = node.localName or node.tagName
            if name == tag or name.endswith(f":{tag}"):
                results.append(node)
            for child in node.childNodes:
                traverse(child)

    traverse(root)
    return results


def get_tracked_change_authors(doc_xml_path: Path) -> dict[str, int]:
    if not doc_xml_path.exists():
        return {}

    try:
        tree = ET.parse(doc_xml_path)
        root = tree.getroot()
    except ET.ParseError:
        return {}

    namespaces = {"w": WORD_NS}
    author_attr = f"{{{WORD_NS}}}author"

    authors: dict[str, int] = {}
    for tag in ["ins", "del"]:
        for elem in root.findall(f".//w:{tag}", namespaces):
            author = elem.get(author_attr)
            if author:
                authors[author] = authors.get(author, 0) + 1

    return authors


def _get_authors_from_docx(docx_path: Path) -> dict[str, int]:
    try:
        with zipfile.ZipFile(docx_path, "r") as zf:
            if "word/document.xml" not in zf.namelist():
                return {}

            with zf.open("word/document.xml") as f:
                tree = ET.parse(f)
                root = tree.getroot()

            namespaces = {"w": WORD_NS}
            author_attr = f"{{{WORD_NS}}}author"

            authors: dict[str, int] = {}
            for tag in ["ins", "del"]:
                for elem in root.findall(f".//w:{tag}", namespaces):
                    author = elem.get(author_attr)
                    if author:
                        authors[author] = authors.get(author, 0) + 1
            return authors
    except (zipfile.BadZipFile, ET.ParseError):
        return {}


def infer_author(
    modified_dir: Path, original_docx: Path, default: str = "Claude"
) -> str:
    modified_xml = modified_dir / "word" / "document.xml"
    modified_authors = get_tracked_change_authors(modified_xml)

    if not modified_authors:
        return default

    original_authors = _get_authors_from_docx(original_docx)

    new_changes: dict[str, int] = {}
    for author, count in modified_authors.items():
        original_count = original_authors.get(author, 0)
        diff = count - original_count
        if diff > 0:
            new_changes[author] = diff

    if not new_changes:
        return default

    if len(new_changes) == 1:
        return next(iter(new_changes))

    raise ValueError(
        f"Multiple authors added new changes: {new_changes}. Cannot infer which author to validate."
    )

================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/pack.py
================================================
"""Pack a directory into a DOCX, PPTX, or XLSX file.

Validates with auto-repair, condenses XML formatting, and creates the Office file.
Usage:
    python pack.py <input_directory> <output_file> [--original <file>] [--validate true|false]

Examples:
    python pack.py unpacked/ output.docx --original input.docx
    python pack.py unpacked/ output.pptx --validate false
"""

import argparse
import shutil
import sys
import tempfile
import zipfile
from pathlib import Path

import defusedxml.minidom

from validators import DOCXSchemaValidator
from validators import PPTXSchemaValidator
from validators import RedliningValidator


def pack(
    input_directory: str,
    output_file: str,
    original_file: str | None = None,
    validate: bool = True,
    infer_author_func=None,
) -> tuple[None, str]:
    input_dir = Path(input_directory)
    output_path = Path(output_file)
    suffix = output_path.suffix.lower()

    if not input_dir.is_dir():
        return None, f"Error: {input_dir} is not a directory"

    if suffix not in {".docx", ".pptx", ".xlsx"}:
        return None, f"Error: {output_file} must be a .docx, .pptx, or .xlsx file"

    if validate and original_file:
        original_path = Path(original_file)
        if original_path.exists():
            success, output = _run_validation(
                input_dir, original_path, suffix, infer_author_func
            )
            if output:
                print(output)
            if not success:
                return None, f"Error: Validation failed for {input_dir}"

    with tempfile.TemporaryDirectory() as temp_dir:
        temp_content_dir = Path(temp_dir) / "content"
        shutil.copytree(input_dir, temp_content_dir)

        for pattern in ["*.xml", "*.rels"]:
            for xml_file in temp_content_dir.rglob(pattern):
                _condense_xml(xml_file)

        output_path.parent.mkdir(parents=True, exist_ok=True)
        with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zf:
            for f in temp_content_dir.rglob("*"):
                if f.is_file():
                    zf.write(f, f.relative_to(temp_content_dir))

    return None, f"Successfully packed {input_dir} to {output_file}"


def _run_validation(
    unpacked_dir: Path,
    original_file: Path,
    suffix: str,
    infer_author_func=None,
) -> tuple[bool, str | None]:
    output_lines = []
    validators = []

    if suffix == ".docx":
        author = "Claude"
        if infer_author_func:
            try:
                author = infer_author_func(unpacked_dir, original_file)
            except ValueError as e:
                print(f"Warning: {e} Using default author 'Claude'.", file=sys.stderr)

        validators = [
            DOCXSchemaValidator(unpacked_dir, original_file),
            RedliningValidator(unpacked_dir, original_file, author=author),
        ]
    elif suffix == ".pptx":
        validators = [PPTXSchemaValidator(unpacked_dir, original_file)]

    if not validators:
        return True, None

    total_repairs = sum(v.repair() for v in validators)
    if total_repairs:
        output_lines.append(f"Auto-repaired {total_repairs} issue(s)")

    success = all(v.validate() for v in validators)

    if success:
        output_lines.append("All validations PASSED!")

    return success, "\n".join(output_lines) if output_lines else None


def _condense_xml(xml_file: Path) -> None:
    try:
        with open(xml_file, encoding="utf-8") as f:
            dom = defusedxml.minidom.parse(f)

        for element in dom.getElementsByTagName("*"):
            if element.tagName.endswith(":t"):
                continue

            for child in list(element.childNodes):
                if (
                    child.nodeType == child.TEXT_NODE
                    and child.nodeValue
                    and child.nodeValue.strip() == ""
                ) or child.nodeType == child.COMMENT_NODE:
                    element.removeChild(child)

        xml_file.write_bytes(dom.toxml(encoding="UTF-8"))
    except Exception as e:
        print(f"ERROR: Failed to parse {xml_file.name}: {e}", file=sys.stderr)
        raise


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Pack a directory into a DOCX, PPTX, or XLSX file"
    )
    parser.add_argument("input_directory", help="Unpacked Office document directory")
    parser.add_argument("output_file", help="Output Office file (.docx/.pptx/.xlsx)")
    parser.add_argument(
        "--original",
        help="Original file for validation comparison",
    )
    parser.add_argument(
        "--validate",
        type=lambda x: x.lower() == "true",
        default=True,
        metavar="true|false",
        help="Run validation with auto-repair (default: true)",
    )
    args = parser.parse_args()

    _, message = pack(
        args.input_directory,
        args.output_file,
        original_file=args.original,
        validate=args.validate,
    )
    print(message)

    if "Error" in message:
        sys.exit(1)


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-chart.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-chartDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-diagram.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-lockedCanvas.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-main.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-picture.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-spreadsheetDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/dml-wordprocessingDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/pml.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-additionalCharacteristics.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-bibliography.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-commonSimpleTypes.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-customXmlDataProperties.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-customXmlSchemaProperties.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-documentPropertiesCustom.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-documentPropertiesExtended.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-documentPropertiesVariantTypes.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-math.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/shared-relationshipReference.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/sml.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/vml-main.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/vml-officeDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/vml-presentationDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/vml-spreadsheetDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/vml-wordprocessingDrawing.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/wml.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ISO-IEC29500-4_2016/xml.xsd
================================================
See http://www.w3.org/XML/1998/namespace.html and http://www.w3.org/TR/REC-xml for information about this namespace.

This schema document describes the XML namespace, in a form suitable for import by other schema documents.

Note that local names in this namespace are intended to be defined only by the World Wide Web Consortium or its subgroups. The following names are currently defined in this namespace and should not be used with conflicting semantics by any Working Group, specification, or document instance:

base (as an attribute name): denotes an attribute whose value provides a URI to be used as the base for interpreting any relative URIs in the scope of the element on which it appears; its value is inherited. This name is reserved by virtue of its definition in the XML Base specification.

lang (as an attribute name): denotes an attribute whose value is a language code for the natural language of the content of any element; its value is inherited. This name is reserved by virtue of its definition in the XML specification.

space (as an attribute name): denotes an attribute whose value is a keyword indicating what whitespace processing discipline is intended for the content of the element; its value is inherited. This name is reserved by virtue of its definition in the XML specification.

Father (in any context at all): denotes Jon Bosak, the chair of the original XML Working Group. This name is reserved by the following decision of the W3C XML Plenary and XML Coordination groups:

In appreciation for his vision, leadership and dedication the W3C XML Plenary on this 10th day of February, 2000 reserves for Jon Bosak in perpetuity the XML name xml:Father

This schema defines attributes and an attribute group suitable for use by schemas wishing to allow xml:base, xml:lang or xml:space attributes on elements they define.

To enable this, such a schema must import this schema for the XML namespace, e.g. as follows:

<schema . . .>
 . . .
 <import namespace="http://www.w3.org/XML/1998/namespace"
         schemaLocation="http://www.w3.org/2001/03/xml.xsd"/>

Subsequently, qualified reference to any of the attributes or the group defined below will have the desired effect, e.g.

<type . . .>
 . . .
 <attributeGroup ref="xml:specialAttrs"/>

will define a type which will schema-validate an instance element with any of those attributes

In keeping with the XML Schema WG's standard versioning policy, this schema document will persist at http://www.w3.org/2001/03/xml.xsd. At the date of issue it can also be found at http://www.w3.org/2001/xml.xsd. The schema document at that URI may however change in the future, in order to remain compatible with the latest version of XML Schema itself. In other words, if the XML Schema namespace changes, the version of this document at http://www.w3.org/2001/xml.xsd will change accordingly; the version at http://www.w3.org/2001/03/xml.xsd will not change.

In due course, we should install the relevant ISO 2- and 3-letter codes as the enumerated possible values . . .

See http://www.w3.org/TR/xmlbase/ for information about this attribute.


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ecma/fouth-edition/opc-contentTypes.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ecma/fouth-edition/opc-coreProperties.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ecma/fouth-edition/opc-digSig.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/ecma/fouth-edition/opc-relationships.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/mce/mc.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-2010.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-2012.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-2018.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-cex-2018.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-cid-2016.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-sdtdatahash-2020.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/schemas/microsoft/wml-symex-2015.xsd
================================================


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/soffice.py
================================================
"""
Helper for running LibreOffice (soffice) in environments where AF_UNIX sockets may be blocked (e.g., sandboxed VMs). Detects the restriction at runtime and applies an LD_PRELOAD shim if needed.
Usage:
    from office.soffice import run_soffice, get_soffice_env

    # Option 1 – run soffice directly
    result = run_soffice(["--headless", "--convert-to", "pdf", "input.docx"])

    # Option 2 – get env dict for your own subprocess calls
    env = get_soffice_env()
    subprocess.run(["soffice", ...], env=env)
"""

import os
import socket
import subprocess
import tempfile
from pathlib import Path


def get_soffice_env() -> dict:
    env = os.environ.copy()
    env["SAL_USE_VCLPLUGIN"] = "svp"

    if _needs_shim():
        shim = _ensure_shim()
        env["LD_PRELOAD"] = str(shim)

    return env


def run_soffice(args: list[str], **kwargs) -> subprocess.CompletedProcess:
    env = get_soffice_env()
    return subprocess.run(["soffice"] + args, env=env, **kwargs)


_SHIM_SO = Path(tempfile.gettempdir()) / "lo_socket_shim.so"


def _needs_shim() -> bool:
    try:
        s = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
        s.close()
        return False
    except OSError:
        return True


def _ensure_shim() -> Path:
    if _SHIM_SO.exists():
        return _SHIM_SO

    src = Path(tempfile.gettempdir()) / "lo_socket_shim.c"
    src.write_text(_SHIM_SOURCE)
    subprocess.run(
        ["gcc", "-shared", "-fPIC", "-o", str(_SHIM_SO), str(src), "-ldl"],
        check=True,
        capture_output=True,
    )
    src.unlink()
    return _SHIM_SO


_SHIM_SOURCE = r"""
#define _GNU_SOURCE
#include
#include
#include
#include
#include
#include
#include

static int (*real_socket)(int, int, int);
static int (*real_socketpair)(int, int, int, int[2]);
static int (*real_listen)(int, int);
static int (*real_accept)(int, struct sockaddr *, socklen_t *);
static int (*real_close)(int);
static int (*real_read)(int, void *, size_t);

/* Per-FD bookkeeping (FDs >= 1024 are passed through unshimmed). */
static int is_shimmed[1024];
static int peer_of[1024];
static int wake_r[1024]; /* accept() blocks reading this */
static int wake_w[1024]; /* close() writes to this */
static int listener_fd = -1; /* FD that received listen() */

__attribute__((constructor))
static void init(void) {
    real_socket = dlsym(RTLD_NEXT, "socket");
    real_socketpair = dlsym(RTLD_NEXT, "socketpair");
    real_listen = dlsym(RTLD_NEXT, "listen");
    real_accept = dlsym(RTLD_NEXT, "accept");
    real_close = dlsym(RTLD_NEXT, "close");
    real_read = dlsym(RTLD_NEXT, "read");
    for (int i = 0; i < 1024; i++) {
        peer_of[i] = -1;
        wake_r[i] = -1;
        wake_w[i] = -1;
    }
}

/* ---- socket ---------------------------------------------------------- */

int socket(int domain, int type, int protocol) {
    if (domain == AF_UNIX) {
        int fd = real_socket(domain, type, protocol);
        if (fd >= 0) return fd;
        /* socket(AF_UNIX) blocked – fall back to socketpair(). */
        int sv[2];
        if (real_socketpair(domain, type, protocol, sv) == 0) {
            if (sv[0] >= 0 && sv[0] < 1024) {
                is_shimmed[sv[0]] = 1;
                peer_of[sv[0]] = sv[1];
                int wp[2];
                if (pipe(wp) == 0) {
                    wake_r[sv[0]] = wp[0];
                    wake_w[sv[0]] = wp[1];
                }
            }
            return sv[0];
        }
        errno = EPERM;
        return -1;
    }
    return real_socket(domain, type, protocol);
}

/* ---- listen ---------------------------------------------------------- */

int listen(int sockfd, int backlog) {
    if (sockfd >= 0 && sockfd < 1024 && is_shimmed[sockfd]) {
        listener_fd = sockfd;
        return 0;
    }
    return real_listen(sockfd, backlog);
}

/* ---- accept ---------------------------------------------------------- */

int accept(int sockfd, struct sockaddr *addr, socklen_t *addrlen) {
    if (sockfd >= 0 && sockfd < 1024 && is_shimmed[sockfd]) {
        /* Block until close() writes to the wake pipe. */
        if (wake_r[sockfd] >= 0) {
            char buf;
            real_read(wake_r[sockfd], &buf, 1);
        }
        errno = ECONNABORTED;
        return -1;
    }
    return real_accept(sockfd, addr, addrlen);
}

/* ---- close ----------------------------------------------------------- */

int close(int fd) {
    if (fd >= 0 && fd < 1024 && is_shimmed[fd]) {
        int was_listener = (fd == listener_fd);
        is_shimmed[fd] = 0;

        if (wake_w[fd] >= 0) { /* unblock accept() */
            char c = 0;
            write(wake_w[fd], &c, 1);
            real_close(wake_w[fd]);
            wake_w[fd] = -1;
        }
        if (wake_r[fd] >= 0) {
            real_close(wake_r[fd]);
            wake_r[fd] = -1;
        }
        if (peer_of[fd] >= 0) {
            real_close(peer_of[fd]);
            peer_of[fd] = -1;
        }

        if (was_listener)
            _exit(0); /* conversion done – exit */
    }
    return real_close(fd);
}
"""


if __name__ == "__main__":
    import sys

    result = run_soffice(sys.argv[1:])
    sys.exit(result.returncode)


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/unpack.py
================================================
"""Unpack Office files (DOCX, PPTX, XLSX) for editing.
Extracts the ZIP archive, pretty-prints XML files, and optionally:
- Merges adjacent runs with identical formatting (DOCX only)
- Simplifies adjacent tracked changes from same author (DOCX only)

Usage:
    python unpack.py <input_file> <output_directory> [options]

Examples:
    python unpack.py document.docx unpacked/
    python unpack.py presentation.pptx unpacked/
    python unpack.py document.docx unpacked/ --merge-runs false
"""

import argparse
import sys
import zipfile
from pathlib import Path

import defusedxml.minidom

from helpers.merge_runs import merge_runs as do_merge_runs
from helpers.simplify_redlines import simplify_redlines as do_simplify_redlines

SMART_QUOTE_REPLACEMENTS = {
    "\u201c": "&#x201C;",
    "\u201d": "&#x201D;",
    "\u2018": "&#x2018;",
    "\u2019": "&#x2019;",
}


def unpack(
    input_file: str,
    output_directory: str,
    merge_runs: bool = True,
    simplify_redlines: bool = True,
) -> tuple[None, str]:
    input_path = Path(input_file)
    output_path = Path(output_directory)
    suffix = input_path.suffix.lower()

    if not input_path.exists():
        return None, f"Error: {input_file} does not exist"

    if suffix not in {".docx", ".pptx", ".xlsx"}:
        return None, f"Error: {input_file} must be a .docx, .pptx, or .xlsx file"

    try:
        output_path.mkdir(parents=True, exist_ok=True)

        with zipfile.ZipFile(input_path, "r") as zf:
            zf.extractall(output_path)

        xml_files = list(output_path.rglob("*.xml")) + list(output_path.rglob("*.rels"))
        for xml_file in xml_files:
            _pretty_print_xml(xml_file)

        message = f"Unpacked {input_file} ({len(xml_files)} XML files)"

        if suffix == ".docx":
            if simplify_redlines:
                simplify_count, _ = do_simplify_redlines(str(output_path))
                message += f", simplified {simplify_count} tracked changes"

            if merge_runs:
                merge_count, _ = do_merge_runs(str(output_path))
                message += f", merged {merge_count} runs"

        for xml_file in xml_files:
            _escape_smart_quotes(xml_file)

        return None, message

    except zipfile.BadZipFile:
        return None, f"Error: {input_file} is not a valid Office file"
    except Exception as e:
        return None, f"Error unpacking: {e}"


def _pretty_print_xml(xml_file: Path) -> None:
    try:
        content = xml_file.read_text(encoding="utf-8")
        dom = defusedxml.minidom.parseString(content)
        xml_file.write_bytes(dom.toprettyxml(indent=" ", encoding="utf-8"))
    except Exception:
        pass


def _escape_smart_quotes(xml_file: Path) -> None:
    try:
        content = xml_file.read_text(encoding="utf-8")
        for char, entity in SMART_QUOTE_REPLACEMENTS.items():
            content = content.replace(char, entity)
        xml_file.write_text(content, encoding="utf-8")
    except Exception:
        pass


if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Unpack an Office file (DOCX, PPTX, XLSX) for editing"
    )
    parser.add_argument("input_file", help="Office file to unpack")
    parser.add_argument("output_directory", help="Output directory")
    parser.add_argument(
        "--merge-runs",
        type=lambda x: x.lower() == "true",
        default=True,
        metavar="true|false",
        help="Merge adjacent runs with identical formatting (DOCX only, default: true)",
    )
    parser.add_argument(
        "--simplify-redlines",
        type=lambda x: x.lower() == "true",
        default=True,
        metavar="true|false",
        help="Merge adjacent tracked changes from same author (DOCX only, default: true)",
    )
    args = parser.parse_args()

    _, message = unpack(
        args.input_file,
        args.output_directory,
        merge_runs=args.merge_runs,
        simplify_redlines=args.simplify_redlines,
    )
    print(message)

    if "Error" in message:
        sys.exit(1)


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/validate.py
================================================
"""
Command line tool to validate Office document XML files against XSD schemas and tracked changes.
Usage:
    python validate.py <path> [--original <original_file>] [--auto-repair] [--author NAME]

The first argument can be either:
- An unpacked directory containing the Office document XML files
- A packed Office file (.docx/.pptx/.xlsx) which will be unpacked to a temp directory

Auto-repair fixes:
- paraId/durableId values that exceed OOXML limits
- Missing xml:space="preserve" on w:t elements with whitespace
"""

import argparse
import sys
import tempfile
import zipfile
from pathlib import Path

from validators import DOCXSchemaValidator
from validators import PPTXSchemaValidator
from validators import RedliningValidator


def main():
    parser = argparse.ArgumentParser(description="Validate Office document XML files")
    parser.add_argument(
        "path",
        help="Path to unpacked directory or packed Office file (.docx/.pptx/.xlsx)",
    )
    parser.add_argument(
        "--original",
        required=False,
        default=None,
        help=(
            "Path to original file (.docx/.pptx/.xlsx). If omitted, all XSD errors "
            "are reported and redlining validation is skipped."
        ),
    )
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        help="Enable verbose output",
    )
    parser.add_argument(
        "--auto-repair",
        action="store_true",
        help="Automatically repair common issues (hex IDs, whitespace preservation)",
    )
    parser.add_argument(
        "--author",
        default="Claude",
        help="Author name for redlining validation (default: Claude)",
    )
    args = parser.parse_args()

    path = Path(args.path)
    assert path.exists(), f"Error: {path} does not exist"

    original_file = None
    if args.original:
        original_file = Path(args.original)
        assert original_file.is_file(), f"Error: {original_file} is not a file"
        assert original_file.suffix.lower() in [
            ".docx",
            ".pptx",
            ".xlsx",
        ], f"Error: {original_file} must be a .docx, .pptx, or .xlsx file"

    file_extension = (original_file or path).suffix.lower()
    assert file_extension in [
        ".docx",
        ".pptx",
        ".xlsx",
    ], f"Error: Cannot determine file type from {path}. Use --original or provide a .docx/.pptx/.xlsx file."

    if path.is_file() and path.suffix.lower() in [".docx", ".pptx", ".xlsx"]:
        temp_dir = tempfile.mkdtemp()
        with zipfile.ZipFile(path, "r") as zf:
            zf.extractall(temp_dir)
        unpacked_dir = Path(temp_dir)
    else:
        assert path.is_dir(), f"Error: {path} is not a directory or Office file"
        unpacked_dir = path

    match file_extension:
        case ".docx":
            validators = [
                DOCXSchemaValidator(unpacked_dir, original_file, verbose=args.verbose),
            ]
            if original_file:
                validators.append(
                    RedliningValidator(
                        unpacked_dir,
                        original_file,
                        verbose=args.verbose,
                        author=args.author,
                    )
                )
        case ".pptx":
            validators = [
                PPTXSchemaValidator(unpacked_dir, original_file, verbose=args.verbose),
            ]
        case _:
            print(f"Error: Validation not supported for file type {file_extension}")
            sys.exit(1)

    if args.auto_repair:
        total_repairs = sum(v.repair() for v in validators)
        if total_repairs:
            print(f"Auto-repaired {total_repairs} issue(s)")

    success = all(v.validate() for v in validators)

    if success:
        print("All validations PASSED!")

    sys.exit(0 if success else 1)


if __name__ == "__main__":
    main()


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/validators/__init__.py
================================================
"""
Validation modules for Word document processing.
"""

from .base import BaseSchemaValidator
from .docx import DOCXSchemaValidator
from .pptx import PPTXSchemaValidator
from .redlining import RedliningValidator

__all__ = [
    "BaseSchemaValidator",
    "DOCXSchemaValidator",
    "PPTXSchemaValidator",
    "RedliningValidator",
]


================================================
FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/validators/base.py
================================================
"""
Base validator with common validation logic for document files.
"""

import re
from pathlib import Path

import defusedxml.minidom
import lxml.etree


class BaseSchemaValidator:
    IGNORED_VALIDATION_ERRORS = [
        "hyphenationZone",
        "purl.org/dc/terms",
    ]

    UNIQUE_ID_REQUIREMENTS = {
        "comment": ("id", "file"),
        "commentrangestart": ("id", "file"),
        "commentrangeend": ("id", "file"),
        "bookmarkstart": ("id", "file"),
        "bookmarkend": ("id", "file"),
        "sldid": ("id", "file"),
        "sldmasterid": ("id", "global"),
        "sldlayoutid": ("id", "global"),
        "cm": ("authorid", "file"),
        "sheet": ("sheetid", "file"),
        "definedname": ("id", "file"),
        "cxnsp": ("id", "file"),
        "sp": ("id", "file"),
        "pic": ("id", "file"),
        "grpsp": ("id", "file"),
    }

    EXCLUDED_ID_CONTAINERS = {
        "sectionlst",
    }

    ELEMENT_RELATIONSHIP_TYPES = {}

    SCHEMA_MAPPINGS = {
        "word": "ISO-IEC29500-4_2016/wml.xsd",
        "ppt": "ISO-IEC29500-4_2016/pml.xsd",
        "xl": "ISO-IEC29500-4_2016/sml.xsd",
        "[Content_Types].xml": "ecma/fouth-edition/opc-contentTypes.xsd",
        "app.xml": "ISO-IEC29500-4_2016/shared-documentPropertiesExtended.xsd",
        "core.xml": "ecma/fouth-edition/opc-coreProperties.xsd",
        "custom.xml": "ISO-IEC29500-4_2016/shared-documentPropertiesCustom.xsd",
        ".rels": "ecma/fouth-edition/opc-relationships.xsd",
        "people.xml": "microsoft/wml-2012.xsd",
        "commentsIds.xml": "microsoft/wml-cid-2016.xsd",
        "commentsExtensible.xml": "microsoft/wml-cex-2018.xsd",
        "commentsExtended.xml": "microsoft/wml-2012.xsd",
        "chart": "ISO-IEC29500-4_2016/dml-chart.xsd",
        "theme": "ISO-IEC29500-4_2016/dml-main.xsd",
        "drawing": "ISO-IEC29500-4_2016/dml-main.xsd",
    }

    MC_NAMESPACE = "http://schemas.openxmlformats.org/markup-compatibility/2006"
    XML_NAMESPACE = "http://www.w3.org/XML/1998/namespace"

    PACKAGE_RELATIONSHIPS_NAMESPACE = (
        "http://schemas.openxmlformats.org/package/2006/relationships"
    )
    OFFICE_RELATIONSHIPS_NAMESPACE = (
        "http://schemas.openxmlformats.org/officeDocument/2006/relationships"
    )
    CONTENT_TYPES_NAMESPACE = (
        "http://schemas.openxmlformats.org/package/2006/content-types"
    )

    MAIN_CONTENT_FOLDERS = {"word", "ppt", "xl"}

    OOXML_NAMESPACES = {
        "http://schemas.openxmlformats.org/officeDocument/2006/math",
        "http://schemas.openxmlformats.org/officeDocument/2006/relationships",
        "http://schemas.openxmlformats.org/schemaLibrary/2006/main",
        "http://schemas.openxmlformats.org/drawingml/2006/main",
        "http://schemas.openxmlformats.org/drawingml/2006/chart",
        "http://schemas.openxmlformats.org/drawingml/2006/chartDrawing",
        "http://schemas.openxmlformats.org/drawingml/2006/diagram",
        "http://schemas.openxmlformats.org/drawingml/2006/picture",
        "http://schemas.openxmlformats.org/drawingml/2006/spreadsheetDrawing",
        "http://schemas.openxmlformats.org/drawingml/2006/wordprocessingDrawing",
        "http://schemas.openxmlformats.org/wordprocessingml/2006/main",
        "http://schemas.openxmlformats.org/presentationml/2006/main",
        "http://schemas.openxmlformats.org/spreadsheetml/2006/main",
        "http://schemas.openxmlformats.org/officeDocument/2006/sharedTypes",
        "http://www.w3.org/XML/1998/namespace",
    }

    def __init__(self, unpacked_dir, original_file=None, verbose=False):
        self.unpacked_dir = Path(unpacked_dir).resolve()
        self.original_file = Path(original_file) if original_file else None
        self.verbose = verbose

        self.schemas_dir = Path(__file__).parent.parent / "schemas"

        patterns = ["*.xml", "*.rels"]
        self.xml_files = [
            f for pattern in patterns for f in self.unpacked_dir.rglob(pattern)
        ]

        if not self.xml_files:
            print(f"Warning: No XML files found in {self.unpacked_dir}")

    def validate(self):
        raise NotImplementedError("Subclasses must implement the validate method")

    def repair(self) -> int:
        return self.repair_whitespace_preservation()

    def repair_whitespace_preservation(self) -> int:
        repairs = 0

        for xml_file in self.xml_files:
            try:
                content = xml_file.read_text(encoding="utf-8")
                dom = defusedxml.minidom.parseString(content)
                modified = False

                for elem in dom.getElementsByTagName("*"):
                    if elem.tagName.endswith(":t") and elem.firstChild:
                        text = elem.firstChild.nodeValue
                        if text and (
                            text.startswith((" ", "\t")) or text.endswith((" ", "\t"))
                        ):
                            if elem.getAttribute("xml:space") != "preserve":
                                elem.setAttribute("xml:space", "preserve")
                                text_preview = (
                                    repr(text[:30]) + "..."
                                    if len(text) > 30
                                    else repr(text)
                                )
                                print(
                                    f" Repaired: {xml_file.name}: Added xml:space='preserve' to {elem.tagName}: {text_preview}"
                                )
                                repairs += 1
                                modified = True

                if modified:
                    xml_file.write_bytes(dom.toxml(encoding="UTF-8"))

            except Exception:
                pass

        return repairs

    def validate_xml(self):
        errors = []

        for xml_file in self.xml_files:
            try:
                lxml.etree.parse(str(xml_file))
            except lxml.etree.XMLSyntaxError as e:
                errors.append(
                    f" {xml_file.relative_to(self.unpacked_dir)}: Line {e.lineno}: {e.msg}"
                )
            except Exception as e:
                errors.append(
                    f" {xml_file.relative_to(self.unpacked_dir)}: Unexpected error: {str(e)}"
                )

        if errors:
            print(f"FAILED - Found {len(errors)} XML violations:")
            for error in errors:
                print(error)
            return False
        else:
            if self.verbose:
                print("PASSED - All XML files are well-formed")
            return True

    def validate_namespaces(self):
        errors = []

        for xml_file in self.xml_files:
            try:
                root = lxml.etree.parse(str(xml_file)).getroot()
                declared = set(root.nsmap.keys()) - {None}

                for attr_val in [
                    v for k, v in root.attrib.items() if k.endswith("Ignorable")
                ]:
                    undeclared = set(attr_val.split()) - declared
                    errors.extend(
                        f" {xml_file.relative_to(self.unpacked_dir)}: Namespace '{ns}' in Ignorable but not declared"
                        for ns in undeclared
                    )
            except lxml.etree.XMLSyntaxError:
                continue

        if errors:
            print(f"FAILED - {len(errors)} namespace issues:")
            for error in errors:
                print(error)
            return False
        if self.verbose:
            print("PASSED - All namespace prefixes properly declared")
        return True

    def validate_unique_ids(self):
        errors = []
        global_ids = {}

        for xml_file in self.xml_files:
            try:
                root = lxml.etree.parse(str(xml_file)).getroot()
                file_ids = {}

                mc_elements = root.xpath(
                    ".//mc:AlternateContent", namespaces={"mc": self.MC_NAMESPACE}
                )
                for elem in mc_elements:
                    elem.getparent().remove(elem)

                for elem in root.iter():
                    tag = (
                        elem.tag.split("}")[-1].lower()
                        if "}" in elem.tag
                        else elem.tag.lower()
                    )

                    if tag in self.UNIQUE_ID_REQUIREMENTS:
                        in_excluded_container = any(
                            ancestor.tag.split("}")[-1].lower()
                            in self.EXCLUDED_ID_CONTAINERS
                            for ancestor in elem.iterancestors()
                        )
                        if in_excluded_container:
                            continue

                        attr_name, scope = self.UNIQUE_ID_REQUIREMENTS[tag]

                        id_value = None
                        for attr, value in elem.attrib.items():
                            attr_local = (
                                attr.split("}")[-1].lower()
                                if "}" in attr
                                else attr.lower()
                            )
                            if attr_local == attr_name:
                                id_value = value
                                break

                        if id_value is not None:
                            if scope == "global":
                                if id_value in global_ids:
                                    prev_file, prev_line, prev_tag = global_ids[
                                        id_value
                                    ]
                                    errors.append(
                                        f" {xml_file.relative_to(self.unpacked_dir)}: "
                                        f"Line {elem.sourceline}: Global ID '{id_value}' in <{tag}> "
                                        f"already used in {prev_file} at line {prev_line} in <{prev_tag}>"
                                    )
                                else:
                                    global_ids[id_value] = (
                                        xml_file.relative_to(self.unpacked_dir),
                                        elem.sourceline,
                                        tag,
                                    )
                            elif scope == "file":
                                key = (tag, attr_name)
                                if key not in file_ids:
                                    file_ids[key] = {}

                                if id_value in file_ids[key]:
                                    prev_line = file_ids[key][id_value]
                                    errors.append(
                                        f" {xml_file.relative_to(self.unpacked_dir)}: "
                                        f"Line {elem.sourceline}: Duplicate {attr_name}='{id_value}' in <{tag}> "
                                        f"(first occurrence at line {prev_line})"
                                    )
                                else:
                                    file_ids[key][id_value] = elem.sourceline

            except (lxml.etree.XMLSyntaxError, Exception) as e:
                errors.append(
                    f" {xml_file.relative_to(self.unpacked_dir)}: Error: {e}"
                )

        if errors:
            print(f"FAILED - Found {len(errors)} ID uniqueness violations:")
            for error in errors:
                print(error)
            return False
        else:
            if self.verbose:
                print("PASSED - All required IDs are unique")
            return True

    def validate_file_references(self):
        errors = []

        rels_files = list(self.unpacked_dir.rglob("*.rels"))

        if not rels_files:
            if self.verbose:
                print("PASSED - No .rels files found")
            return True

        all_files = []
        for file_path in self.unpacked_dir.rglob("*"):
            if (
                file_path.is_file()
                and file_path.name != "[Content_Types].xml"
                and not file_path.name.endswith(".rels")
            ):
all_files.append(file_path.resolve()) all_referenced_files = set() if self.verbose: print( f"Found {len(rels_files)} .rels files and {len(all_files)} target files" ) for rels_file in rels_files: try: rels_root = lxml.etree.parse(str(rels_file)).getroot() rels_dir = rels_file.parent referenced_files = set() broken_refs = [] for rel in rels_root.findall( ".//ns:Relationship", namespaces={"ns": self.PACKAGE_RELATIONSHIPS_NAMESPACE}, ): target = rel.get("Target") if target and not target.startswith(("http", "mailto:")): if target.startswith("/"): target_path = self.unpacked_dir / target.lstrip("/") elif rels_file.name == ".rels": target_path = self.unpacked_dir / target else: base_dir = rels_dir.parent target_path = base_dir / target try: target_path = target_path.resolve() if target_path.exists() and target_path.is_file(): referenced_files.add(target_path) all_referenced_files.add(target_path) else: broken_refs.append((target, rel.sourceline)) except (OSError, ValueError): broken_refs.append((target, rel.sourceline)) if broken_refs: rel_path = rels_file.relative_to(self.unpacked_dir) for broken_ref, line_num in broken_refs: errors.append( f" {rel_path}: Line {line_num}: Broken reference to {broken_ref}" ) except Exception as e: rel_path = rels_file.relative_to(self.unpacked_dir) errors.append(f" Error parsing {rel_path}: {e}") unreferenced_files = set(all_files) - all_referenced_files if unreferenced_files: for unref_file in sorted(unreferenced_files): unref_rel_path = unref_file.relative_to(self.unpacked_dir) errors.append(f" Unreferenced file: {unref_rel_path}") if errors: print(f"FAILED - Found {len(errors)} relationship validation errors:") for error in errors: print(error) print( "CRITICAL: These errors will cause the document to appear corrupt. " + "Broken references MUST be fixed, " + "and unreferenced files MUST be referenced or removed." 
) return False else: if self.verbose: print( "PASSED - All references are valid and all files are properly referenced" ) return True def validate_all_relationship_ids(self): import lxml.etree errors = [] for xml_file in self.xml_files: if xml_file.suffix == ".rels": continue rels_dir = xml_file.parent / "_rels" rels_file = rels_dir / f"{xml_file.name}.rels" if not rels_file.exists(): continue try: rels_root = lxml.etree.parse(str(rels_file)).getroot() rid_to_type = {} for rel in rels_root.findall( f".//{{{self.PACKAGE_RELATIONSHIPS_NAMESPACE}}}Relationship" ): rid = rel.get("Id") rel_type = rel.get("Type", "") if rid: if rid in rid_to_type: rels_rel_path = rels_file.relative_to(self.unpacked_dir) errors.append( f" {rels_rel_path}: Line {rel.sourceline}: " f"Duplicate relationship ID '{rid}' (IDs must be unique)" ) type_name = ( rel_type.split("/")[-1] if "/" in rel_type else rel_type ) rid_to_type[rid] = type_name xml_root = lxml.etree.parse(str(xml_file)).getroot() r_ns = self.OFFICE_RELATIONSHIPS_NAMESPACE rid_attrs_to_check = ["id", "embed", "link"] for elem in xml_root.iter(): for attr_name in rid_attrs_to_check: rid_attr = elem.get(f"{{{r_ns}}}{attr_name}") if not rid_attr: continue xml_rel_path = xml_file.relative_to(self.unpacked_dir) elem_name = ( elem.tag.split("}")[-1] if "}" in elem.tag else elem.tag ) if rid_attr not in rid_to_type: errors.append( f" {xml_rel_path}: Line {elem.sourceline}: " f"<{elem_name}> r:{attr_name} references non-existent relationship '{rid_attr}' " f"(valid IDs: {', '.join(sorted(rid_to_type.keys())[:5])}{'...' 
if len(rid_to_type) > 5 else ''})" ) elif attr_name == "id" and self.ELEMENT_RELATIONSHIP_TYPES: expected_type = self._get_expected_relationship_type( elem_name ) if expected_type: actual_type = rid_to_type[rid_attr] if expected_type not in actual_type.lower(): errors.append( f" {xml_rel_path}: Line {elem.sourceline}: " f"<{elem_name}> references '{rid_attr}' which points to '{actual_type}' " f"but should point to a '{expected_type}' relationship" ) except Exception as e: xml_rel_path = xml_file.relative_to(self.unpacked_dir) errors.append(f" Error processing {xml_rel_path}: {e}") if errors: print(f"FAILED - Found {len(errors)} relationship ID reference errors:") for error in errors: print(error) print("\nThese ID mismatches will cause the document to appear corrupt!") return False else: if self.verbose: print("PASSED - All relationship ID references are valid") return True def _get_expected_relationship_type(self, element_name): elem_lower = element_name.lower() if elem_lower in self.ELEMENT_RELATIONSHIP_TYPES: return self.ELEMENT_RELATIONSHIP_TYPES[elem_lower] if elem_lower.endswith("id") and len(elem_lower) > 2: prefix = elem_lower[:-2] if prefix.endswith("master"): return prefix.lower() elif prefix.endswith("layout"): return prefix.lower() else: if prefix == "sld": return "slide" return prefix.lower() if elem_lower.endswith("reference") and len(elem_lower) > 9: prefix = elem_lower[:-9] return prefix.lower() return None def validate_content_types(self): errors = [] content_types_file = self.unpacked_dir / "[Content_Types].xml" if not content_types_file.exists(): print("FAILED - [Content_Types].xml file not found") return False try: root = lxml.etree.parse(str(content_types_file)).getroot() declared_parts = set() declared_extensions = set() for override in root.findall( f".//{{{self.CONTENT_TYPES_NAMESPACE}}}Override" ): part_name = override.get("PartName") if part_name is not None: declared_parts.add(part_name.lstrip("/")) for default in root.findall( 
f".//{{{self.CONTENT_TYPES_NAMESPACE}}}Default" ): extension = default.get("Extension") if extension is not None: declared_extensions.add(extension.lower()) declarable_roots = { "sld", "sldLayout", "sldMaster", "presentation", "document", "workbook", "worksheet", "theme", } media_extensions = { "png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg", "gif": "image/gif", "bmp": "image/bmp", "tiff": "image/tiff", "wmf": "image/x-wmf", "emf": "image/x-emf", } all_files = list(self.unpacked_dir.rglob("*")) all_files = [f for f in all_files if f.is_file()] for xml_file in self.xml_files: path_str = str(xml_file.relative_to(self.unpacked_dir)).replace( "\\", "/" ) if any( skip in path_str for skip in [".rels", "[Content_Types]", "docProps/", "_rels/"] ): continue try: root_tag = lxml.etree.parse(str(xml_file)).getroot().tag root_name = root_tag.split("}")[-1] if "}" in root_tag else root_tag if root_name in declarable_roots and path_str not in declared_parts: errors.append( f" {path_str}: File with <{root_name}> root not declared in [Content_Types].xml" ) except Exception: continue for file_path in all_files: if file_path.suffix.lower() in {".xml", ".rels"}: continue if file_path.name == "[Content_Types].xml": continue if "_rels" in file_path.parts or "docProps" in file_path.parts: continue extension = file_path.suffix.lstrip(".").lower() if extension and extension not in declared_extensions: if extension in media_extensions: relative_path = file_path.relative_to(self.unpacked_dir) msg = ( f" {relative_path}: File with extension '{extension}' " f"not declared in [Content_Types].xml - should add: " f'<Default Extension="{extension}" ContentType="{media_extensions[extension]}"/>' ) errors.append(msg) except Exception as e: errors.append(f" Error parsing [Content_Types].xml: {e}") if errors: print(f"FAILED - Found {len(errors)} content type declaration errors:") for error in errors: print(error) return False else: if self.verbose: print( "PASSED - All content files are properly declared in [Content_Types].xml" ) return True def
validate_file_against_xsd(self, xml_file, verbose=False): xml_file = Path(xml_file).resolve() unpacked_dir = self.unpacked_dir.resolve() is_valid, current_errors = self._validate_single_file_xsd( xml_file, unpacked_dir ) if is_valid is None: return None, set() elif is_valid: return True, set() original_errors = self._get_original_file_errors(xml_file) assert current_errors is not None new_errors = current_errors - original_errors new_errors = { e for e in new_errors if not any(pattern in e for pattern in self.IGNORED_VALIDATION_ERRORS) } if new_errors: if verbose: relative_path = xml_file.relative_to(unpacked_dir) print(f"FAILED - {relative_path}: {len(new_errors)} new error(s)") for error in list(new_errors)[:3]: truncated = error[:250] + "..." if len(error) > 250 else error print(f" - {truncated}") return False, new_errors else: if verbose: print( f"PASSED - No new errors (original had {len(current_errors)} errors)" ) return True, set() def validate_against_xsd(self): new_errors = [] original_error_count = 0 valid_count = 0 skipped_count = 0 for xml_file in self.xml_files: relative_path = str(xml_file.relative_to(self.unpacked_dir)) is_valid, new_file_errors = self.validate_file_against_xsd( xml_file, verbose=False ) if is_valid is None: skipped_count += 1 continue elif is_valid and not new_file_errors: valid_count += 1 continue elif is_valid: original_error_count += 1 valid_count += 1 continue new_errors.append(f" {relative_path}: {len(new_file_errors)} new error(s)") for error in list(new_file_errors)[:3]: new_errors.append( f" - {error[:250]}..." 
if len(error) > 250 else f" - {error}" ) if self.verbose: print(f"Validated {len(self.xml_files)} files:") print(f" - Valid: {valid_count}") print(f" - Skipped (no schema): {skipped_count}") if original_error_count: print(f" - With original errors (ignored): {original_error_count}") print( f" - With NEW errors: {len(new_errors) > 0 and len([e for e in new_errors if not e.startswith(' ')]) or 0}" ) if new_errors: print("\nFAILED - Found NEW validation errors:") for error in new_errors: print(error) return False else: if self.verbose: print("\nPASSED - No new XSD validation errors introduced") return True def _get_schema_path(self, xml_file): if xml_file.name in self.SCHEMA_MAPPINGS: return self.schemas_dir / self.SCHEMA_MAPPINGS[xml_file.name] if xml_file.suffix == ".rels": return self.schemas_dir / self.SCHEMA_MAPPINGS[".rels"] if "charts/" in str(xml_file) and xml_file.name.startswith("chart"): return self.schemas_dir / self.SCHEMA_MAPPINGS["chart"] if "theme/" in str(xml_file) and xml_file.name.startswith("theme"): return self.schemas_dir / self.SCHEMA_MAPPINGS["theme"] if xml_file.parent.name in self.MAIN_CONTENT_FOLDERS: return self.schemas_dir / self.SCHEMA_MAPPINGS[xml_file.parent.name] return None def _clean_ignorable_namespaces(self, xml_doc): xml_string = lxml.etree.tostring(xml_doc, encoding="unicode") xml_copy = lxml.etree.fromstring(xml_string) for elem in xml_copy.iter(): attrs_to_remove = [] for attr in elem.attrib: if "{" in attr: ns = attr.split("}")[0][1:] if ns not in self.OOXML_NAMESPACES: attrs_to_remove.append(attr) for attr in attrs_to_remove: del elem.attrib[attr] self._remove_ignorable_elements(xml_copy) return lxml.etree.ElementTree(xml_copy) def _remove_ignorable_elements(self, root): elements_to_remove = [] for elem in list(root): if not hasattr(elem, "tag") or callable(elem.tag): continue tag_str = str(elem.tag) if tag_str.startswith("{"): ns = tag_str.split("}")[0][1:] if ns not in self.OOXML_NAMESPACES: elements_to_remove.append(elem) 
continue self._remove_ignorable_elements(elem) for elem in elements_to_remove: root.remove(elem) def _preprocess_for_mc_ignorable(self, xml_doc): root = xml_doc.getroot() if f"{{{self.MC_NAMESPACE}}}Ignorable" in root.attrib: del root.attrib[f"{{{self.MC_NAMESPACE}}}Ignorable"] return xml_doc def _validate_single_file_xsd(self, xml_file, base_path): schema_path = self._get_schema_path(xml_file) if not schema_path: return None, None try: with open(schema_path, "rb") as xsd_file: parser = lxml.etree.XMLParser() xsd_doc = lxml.etree.parse( xsd_file, parser=parser, base_url=str(schema_path) ) schema = lxml.etree.XMLSchema(xsd_doc) with open(xml_file, "r") as f: xml_doc = lxml.etree.parse(f) xml_doc, _ = self._remove_template_tags_from_text_nodes(xml_doc) xml_doc = self._preprocess_for_mc_ignorable(xml_doc) relative_path = xml_file.relative_to(base_path) if ( relative_path.parts and relative_path.parts[0] in self.MAIN_CONTENT_FOLDERS ): xml_doc = self._clean_ignorable_namespaces(xml_doc) if schema.validate(xml_doc): return True, set() else: errors = set() for error in schema.error_log: errors.add(error.message) return False, errors except Exception as e: return False, {str(e)} def _get_original_file_errors(self, xml_file): if self.original_file is None: return set() import tempfile import zipfile xml_file = Path(xml_file).resolve() unpacked_dir = self.unpacked_dir.resolve() relative_path = xml_file.relative_to(unpacked_dir) with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) with zipfile.ZipFile(self.original_file, "r") as zip_ref: zip_ref.extractall(temp_path) original_xml_file = temp_path / relative_path if not original_xml_file.exists(): return set() is_valid, errors = self._validate_single_file_xsd( original_xml_file, temp_path ) return errors if errors else set() def _remove_template_tags_from_text_nodes(self, xml_doc): warnings = [] template_pattern = re.compile(r"\{\{[^}]*\}\}") xml_string = lxml.etree.tostring(xml_doc, encoding="unicode") 
xml_copy = lxml.etree.fromstring(xml_string) def process_text_content(text, content_type): if not text: return text matches = list(template_pattern.finditer(text)) if matches: for match in matches: warnings.append( f"Found template tag in {content_type}: {match.group()}" ) return template_pattern.sub("", text) return text for elem in xml_copy.iter(): if not hasattr(elem, "tag") or callable(elem.tag): continue tag_str = str(elem.tag) if tag_str.endswith("}t") or tag_str == "t": continue elem.text = process_text_content(elem.text, "text content") elem.tail = process_text_content(elem.tail, "tail content") return lxml.etree.ElementTree(xml_copy), warnings if __name__ == "__main__": raise RuntimeError("This module should not be run directly.") ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/validators/docx.py ================================================ """ Validator for Word document XML files against XSD schemas. 
""" import random import re import tempfile import zipfile import defusedxml.minidom import lxml.etree from .base import BaseSchemaValidator class DOCXSchemaValidator(BaseSchemaValidator): WORD_2006_NAMESPACE = "http://schemas.openxmlformats.org/wordprocessingml/2006/main" W14_NAMESPACE = "http://schemas.microsoft.com/office/word/2010/wordml" W16CID_NAMESPACE = "http://schemas.microsoft.com/office/word/2016/wordml/cid" ELEMENT_RELATIONSHIP_TYPES = {} def validate(self): if not self.validate_xml(): return False all_valid = True if not self.validate_namespaces(): all_valid = False if not self.validate_unique_ids(): all_valid = False if not self.validate_file_references(): all_valid = False if not self.validate_content_types(): all_valid = False if not self.validate_against_xsd(): all_valid = False if not self.validate_whitespace_preservation(): all_valid = False if not self.validate_deletions(): all_valid = False if not self.validate_insertions(): all_valid = False if not self.validate_all_relationship_ids(): all_valid = False if not self.validate_id_constraints(): all_valid = False if not self.validate_comment_markers(): all_valid = False self.compare_paragraph_counts() return all_valid def validate_whitespace_preservation(self): errors = [] for xml_file in self.xml_files: if xml_file.name != "document.xml": continue try: root = lxml.etree.parse(str(xml_file)).getroot() for elem in root.iter(f"{{{self.WORD_2006_NAMESPACE}}}t"): if elem.text: text = elem.text if re.search(r"^[ \t\n\r]", text) or re.search( r"[ \t\n\r]$", text ): xml_space_attr = f"{{{self.XML_NAMESPACE}}}space" if ( xml_space_attr not in elem.attrib or elem.attrib[xml_space_attr] != "preserve" ): text_preview = ( repr(text)[:50] + "..." 
if len(repr(text)) > 50 else repr(text) ) errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: " f"Line {elem.sourceline}: w:t element with whitespace " f"missing xml:space='preserve': {text_preview}" ) except (lxml.etree.XMLSyntaxError, Exception) as e: errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: Error: {e}" ) if errors: print(f"FAILED - Found {len(errors)} whitespace preservation violations:") for error in errors: print(error) return False else: if self.verbose: print("PASSED - All whitespace is properly preserved") return True def validate_deletions(self): errors = [] for xml_file in self.xml_files: if xml_file.name != "document.xml": continue try: root = lxml.etree.parse(str(xml_file)).getroot() namespaces = {"w": self.WORD_2006_NAMESPACE} for t_elem in root.xpath(".//w:del//w:t", namespaces=namespaces): if t_elem.text: text_preview = ( repr(t_elem.text)[:50] + "..." if len(repr(t_elem.text)) > 50 else repr(t_elem.text) ) errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: " f"Line {t_elem.sourceline}: w:t found within w:del: {text_preview}" ) for instr_elem in root.xpath( ".//w:del//w:instrText", namespaces=namespaces ): text_preview = ( repr(instr_elem.text or "")[:50] + "..."
if len(repr(instr_elem.text or "")) > 50 else repr(instr_elem.text or "") ) errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: " f"Line {instr_elem.sourceline}: w:instrText found within w:del (use w:delInstrText): {text_preview}" ) except (lxml.etree.XMLSyntaxError, Exception) as e: errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: Error: {e}" ) if errors: print(f"FAILED - Found {len(errors)} deletion validation violations:") for error in errors: print(error) return False else: if self.verbose: print("PASSED - No w:t elements found within w:del elements") return True def count_paragraphs_in_unpacked(self): count = 0 for xml_file in self.xml_files: if xml_file.name != "document.xml": continue try: root = lxml.etree.parse(str(xml_file)).getroot() paragraphs = root.findall(f".//{{{self.WORD_2006_NAMESPACE}}}p") count = len(paragraphs) except Exception as e: print(f"Error counting paragraphs in unpacked document: {e}") return count def count_paragraphs_in_original(self): original = self.original_file if original is None: return 0 count = 0 try: with tempfile.TemporaryDirectory() as temp_dir: with zipfile.ZipFile(original, "r") as zip_ref: zip_ref.extractall(temp_dir) doc_xml_path = temp_dir + "/word/document.xml" root = lxml.etree.parse(doc_xml_path).getroot() paragraphs = root.findall(f".//{{{self.WORD_2006_NAMESPACE}}}p") count = len(paragraphs) except Exception as e: print(f"Error counting paragraphs in original document: {e}") return count def validate_insertions(self): errors = [] for xml_file in self.xml_files: if xml_file.name != "document.xml": continue try: root = lxml.etree.parse(str(xml_file)).getroot() namespaces = {"w": self.WORD_2006_NAMESPACE} invalid_elements = root.xpath( ".//w:ins//w:delText[not(ancestor::w:del)]", namespaces=namespaces ) for elem in invalid_elements: text_preview = ( repr(elem.text or "")[:50] + "..."
if len(repr(elem.text or "")) > 50 else repr(elem.text or "") ) errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: " f"Line {elem.sourceline}: w:delText within w:ins: {text_preview}" ) except (lxml.etree.XMLSyntaxError, Exception) as e: errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: Error: {e}" ) if errors: print(f"FAILED - Found {len(errors)} insertion validation violations:") for error in errors: print(error) return False else: if self.verbose: print("PASSED - No w:delText elements within w:ins elements") return True def compare_paragraph_counts(self): original_count = self.count_paragraphs_in_original() new_count = self.count_paragraphs_in_unpacked() diff = new_count - original_count diff_str = f"+{diff}" if diff > 0 else str(diff) print(f"\nParagraphs: {original_count} → {new_count} ({diff_str})") def _parse_id_value(self, val: str, base: int = 16) -> int: return int(val, base) def validate_id_constraints(self): errors = [] para_id_attr = f"{{{self.W14_NAMESPACE}}}paraId" durable_id_attr = f"{{{self.W16CID_NAMESPACE}}}durableId" for xml_file in self.xml_files: try: for elem in lxml.etree.parse(str(xml_file)).iter(): if val := elem.get(para_id_attr): if self._parse_id_value(val, base=16) >= 0x80000000: errors.append( f" {xml_file.name}:{elem.sourceline}: paraId={val} >= 0x80000000" ) if val := elem.get(durable_id_attr): if xml_file.name == "numbering.xml": try: if self._parse_id_value(val, base=10) >= 0x7FFFFFFF: errors.append( f" {xml_file.name}:{elem.sourceline}: durableId={val} >= 0x7FFFFFFF" ) except ValueError: errors.append( f" {xml_file.name}:{elem.sourceline}: durableId={val} must be decimal in numbering.xml" ) else: if self._parse_id_value(val, base=16) >= 0x7FFFFFFF: errors.append( f" {xml_file.name}:{elem.sourceline}: durableId={val} >= 0x7FFFFFFF" ) except Exception: pass if errors: print(f"FAILED - {len(errors)} ID constraint violations:") for e in errors: print(e) elif self.verbose: print("PASSED - All paraId/durableId values within
constraints") return not errors def validate_comment_markers(self): errors = [] document_xml = None comments_xml = None for xml_file in self.xml_files: if xml_file.name == "document.xml" and "word" in str(xml_file): document_xml = xml_file elif xml_file.name == "comments.xml": comments_xml = xml_file if not document_xml: if self.verbose: print("PASSED - No document.xml found (skipping comment validation)") return True try: doc_root = lxml.etree.parse(str(document_xml)).getroot() namespaces = {"w": self.WORD_2006_NAMESPACE} range_starts = { elem.get(f"{{{self.WORD_2006_NAMESPACE}}}id") for elem in doc_root.xpath( ".//w:commentRangeStart", namespaces=namespaces ) } range_ends = { elem.get(f"{{{self.WORD_2006_NAMESPACE}}}id") for elem in doc_root.xpath( ".//w:commentRangeEnd", namespaces=namespaces ) } references = { elem.get(f"{{{self.WORD_2006_NAMESPACE}}}id") for elem in doc_root.xpath( ".//w:commentReference", namespaces=namespaces ) } orphaned_ends = range_ends - range_starts for comment_id in sorted( orphaned_ends, key=lambda x: int(x) if x and x.isdigit() else 0 ): errors.append( f' document.xml: commentRangeEnd id="{comment_id}" has no matching commentRangeStart' ) orphaned_starts = range_starts - range_ends for comment_id in sorted( orphaned_starts, key=lambda x: int(x) if x and x.isdigit() else 0 ): errors.append( f' document.xml: commentRangeStart id="{comment_id}" has no matching commentRangeEnd' ) comment_ids = set() if comments_xml and comments_xml.exists(): comments_root = lxml.etree.parse(str(comments_xml)).getroot() comment_ids = { elem.get(f"{{{self.WORD_2006_NAMESPACE}}}id") for elem in comments_root.xpath( ".//w:comment", namespaces=namespaces ) } marker_ids = range_starts | range_ends | references invalid_refs = marker_ids - comment_ids for comment_id in sorted( invalid_refs, key=lambda x: int(x) if x and x.isdigit() else 0 ): if comment_id: errors.append( f' document.xml: marker id="{comment_id}" references non-existent comment' ) except 
(lxml.etree.XMLSyntaxError, Exception) as e: errors.append(f" Error parsing XML: {e}") if errors: print(f"FAILED - {len(errors)} comment marker violations:") for error in errors: print(error) return False else: if self.verbose: print("PASSED - All comment markers properly paired") return True def repair(self) -> int: repairs = super().repair() repairs += self.repair_durableId() return repairs def repair_durableId(self) -> int: repairs = 0 for xml_file in self.xml_files: try: content = xml_file.read_text(encoding="utf-8") dom = defusedxml.minidom.parseString(content) modified = False for elem in dom.getElementsByTagName("*"): if not elem.hasAttribute("w16cid:durableId"): continue durable_id = elem.getAttribute("w16cid:durableId") needs_repair = False if xml_file.name == "numbering.xml": try: needs_repair = ( self._parse_id_value(durable_id, base=10) >= 0x7FFFFFFF ) except ValueError: needs_repair = True else: try: needs_repair = ( self._parse_id_value(durable_id, base=16) >= 0x7FFFFFFF ) except ValueError: needs_repair = True if needs_repair: value = random.randint(1, 0x7FFFFFFE) if xml_file.name == "numbering.xml": new_id = str(value) else: new_id = f"{value:08X}" elem.setAttribute("w16cid:durableId", new_id) print( f" Repaired: {xml_file.name}: durableId {durable_id} → {new_id}" ) repairs += 1 modified = True if modified: xml_file.write_bytes(dom.toxml(encoding="UTF-8")) except Exception: pass return repairs if __name__ == "__main__": raise RuntimeError("This module should not be run directly.") ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/validators/pptx.py ================================================ """ Validator for PowerPoint presentation XML files against XSD schemas. 
""" import re from .base import BaseSchemaValidator class PPTXSchemaValidator(BaseSchemaValidator): PRESENTATIONML_NAMESPACE = ( "http://schemas.openxmlformats.org/presentationml/2006/main" ) ELEMENT_RELATIONSHIP_TYPES = { "sldid": "slide", "sldmasterid": "slidemaster", "notesmasterid": "notesmaster", "sldlayoutid": "slidelayout", "themeid": "theme", "tablestyleid": "tablestyles", } def validate(self): if not self.validate_xml(): return False all_valid = True if not self.validate_namespaces(): all_valid = False if not self.validate_unique_ids(): all_valid = False if not self.validate_uuid_ids(): all_valid = False if not self.validate_file_references(): all_valid = False if not self.validate_slide_layout_ids(): all_valid = False if not self.validate_content_types(): all_valid = False if not self.validate_against_xsd(): all_valid = False if not self.validate_notes_slide_references(): all_valid = False if not self.validate_all_relationship_ids(): all_valid = False if not self.validate_no_duplicate_slide_layouts(): all_valid = False return all_valid def validate_uuid_ids(self): import lxml.etree errors = [] uuid_pattern = re.compile( r"^[\{\(]?[0-9A-Fa-f]{8}-?[0-9A-Fa-f]{4}-?[0-9A-Fa-f]{4}-?[0-9A-Fa-f]{4}-?[0-9A-Fa-f]{12}[\}\)]?$" ) for xml_file in self.xml_files: try: root = lxml.etree.parse(str(xml_file)).getroot() for elem in root.iter(): for attr, value in elem.attrib.items(): attr_name = attr.split("}")[-1].lower() if attr_name == "id" or attr_name.endswith("id"): if self._looks_like_uuid(value): if not uuid_pattern.match(value): errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: " f"Line {elem.sourceline}: ID '{value}' appears to be " "a UUID but contains invalid hex characters" ) except (lxml.etree.XMLSyntaxError, Exception) as e: errors.append( f" {xml_file.relative_to(self.unpacked_dir)}: Error: {e}" ) if errors: print(f"FAILED - Found {len(errors)} UUID ID validation errors:") for error in errors: print(error) return False else: if self.verbose: 
print("PASSED - All UUID-like IDs contain valid hex values") return True def _looks_like_uuid(self, value): clean_value = value.strip("{}()").replace("-", "") return len(clean_value) == 32 and all(c.isalnum() for c in clean_value) def validate_slide_layout_ids(self): import lxml.etree errors = [] slide_masters = list(self.unpacked_dir.glob("ppt/slideMasters/*.xml")) if not slide_masters: if self.verbose: print("PASSED - No slide masters found") return True for slide_master in slide_masters: try: root = lxml.etree.parse(str(slide_master)).getroot() rels_file = slide_master.parent / "_rels" / f"{slide_master.name}.rels" if not rels_file.exists(): errors.append( f" {slide_master.relative_to(self.unpacked_dir)}: " f"Missing relationships file: {rels_file.relative_to(self.unpacked_dir)}" ) continue rels_root = lxml.etree.parse(str(rels_file)).getroot() valid_layout_rids = set() for rel in rels_root.findall( f".//{{{self.PACKAGE_RELATIONSHIPS_NAMESPACE}}}Relationship" ): rel_type = rel.get("Type", "") if "slideLayout" in rel_type: valid_layout_rids.add(rel.get("Id")) for sld_layout_id in root.findall( f".//{{{self.PRESENTATIONML_NAMESPACE}}}sldLayoutId" ): r_id = sld_layout_id.get( f"{{{self.OFFICE_RELATIONSHIPS_NAMESPACE}}}id" ) layout_id = sld_layout_id.get("id") if r_id and r_id not in valid_layout_rids: errors.append( f" {slide_master.relative_to(self.unpacked_dir)}: " f"Line {sld_layout_id.sourceline}: sldLayoutId with id='{layout_id}' " f"references r:id='{r_id}' which is not found in slide layout relationships" ) except (lxml.etree.XMLSyntaxError, Exception) as e: errors.append( f" {slide_master.relative_to(self.unpacked_dir)}: Error: {e}" ) if errors: print(f"FAILED - Found {len(errors)} slide layout ID validation errors:") for error in errors: print(error) print( "Remove invalid references or add missing slide layouts to the relationships file." 
) return False else: if self.verbose: print("PASSED - All slide layout IDs reference valid slide layouts") return True def validate_no_duplicate_slide_layouts(self): import lxml.etree errors = [] slide_rels_files = list(self.unpacked_dir.glob("ppt/slides/_rels/*.xml.rels")) for rels_file in slide_rels_files: try: root = lxml.etree.parse(str(rels_file)).getroot() layout_rels = [ rel for rel in root.findall( f".//{{{self.PACKAGE_RELATIONSHIPS_NAMESPACE}}}Relationship" ) if "slideLayout" in rel.get("Type", "") ] if len(layout_rels) > 1: errors.append( f" {rels_file.relative_to(self.unpacked_dir)}: has {len(layout_rels)} slideLayout references" ) except Exception as e: errors.append( f" {rels_file.relative_to(self.unpacked_dir)}: Error: {e}" ) if errors: print("FAILED - Found slides with duplicate slideLayout references:") for error in errors: print(error) return False else: if self.verbose: print("PASSED - All slides have exactly one slideLayout reference") return True def validate_notes_slide_references(self): import lxml.etree errors = [] notes_slide_references = {} slide_rels_files = list(self.unpacked_dir.glob("ppt/slides/_rels/*.xml.rels")) if not slide_rels_files: if self.verbose: print("PASSED - No slide relationship files found") return True for rels_file in slide_rels_files: try: root = lxml.etree.parse(str(rels_file)).getroot() for rel in root.findall( f".//{{{self.PACKAGE_RELATIONSHIPS_NAMESPACE}}}Relationship" ): rel_type = rel.get("Type", "") if "notesSlide" in rel_type: target = rel.get("Target", "") if target: normalized_target = target.replace("../", "") slide_name = rels_file.stem.replace(".xml", "") if normalized_target not in notes_slide_references: notes_slide_references[normalized_target] = [] notes_slide_references[normalized_target].append( (slide_name, rels_file) ) except (lxml.etree.XMLSyntaxError, Exception) as e: errors.append( f" {rels_file.relative_to(self.unpacked_dir)}: Error: {e}" ) for target, references in 
notes_slide_references.items(): if len(references) > 1: slide_names = [ref[0] for ref in references] errors.append( f" Notes slide '{target}' is referenced by multiple slides: {', '.join(slide_names)}" ) for slide_name, rels_file in references: errors.append(f" - {rels_file.relative_to(self.unpacked_dir)}") if errors: print( f"FAILED - Found {len([e for e in errors if not e.startswith(' ')])} notes slide reference validation errors:" ) for error in errors: print(error) print("Each slide may have at most one notes slide, and each notes slide must belong to a single slide.") return False else: if self.verbose: print("PASSED - All notes slide references are unique") return True if __name__ == "__main__": raise RuntimeError("This module should not be run directly.") ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/office/validators/redlining.py ================================================ """ Validator for tracked changes in Word documents. """ import subprocess import tempfile import zipfile from pathlib import Path class RedliningValidator: def __init__(self, unpacked_dir, original_docx, verbose=False, author="Claude"): self.unpacked_dir = Path(unpacked_dir) self.original_docx = Path(original_docx) self.verbose = verbose self.author = author self.namespaces = { "w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main" } def repair(self) -> int: return 0 def validate(self): modified_file = self.unpacked_dir / "word" / "document.xml" if not modified_file.exists(): print(f"FAILED - Modified document.xml not found at {modified_file}") return False try: import xml.etree.ElementTree as ET tree = ET.parse(modified_file) root = tree.getroot() del_elements = root.findall(".//w:del", self.namespaces) ins_elements = root.findall(".//w:ins", self.namespaces) author_del_elements = [ elem for elem in del_elements if elem.get(f"{{{self.namespaces['w']}}}author") == self.author ] 
author_ins_elements = [ elem for elem in
ins_elements if elem.get(f"{{{self.namespaces['w']}}}author") == self.author ] if not author_del_elements and not author_ins_elements: if self.verbose: print(f"PASSED - No tracked changes by {self.author} found.") return True except Exception: pass with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) try: with zipfile.ZipFile(self.original_docx, "r") as zip_ref: zip_ref.extractall(temp_path) except Exception as e: print(f"FAILED - Error unpacking original docx: {e}") return False original_file = temp_path / "word" / "document.xml" if not original_file.exists(): print( f"FAILED - Original document.xml not found in {self.original_docx}" ) return False try: import xml.etree.ElementTree as ET modified_tree = ET.parse(modified_file) modified_root = modified_tree.getroot() original_tree = ET.parse(original_file) original_root = original_tree.getroot() except ET.ParseError as e: print(f"FAILED - Error parsing XML files: {e}") return False self._remove_author_tracked_changes(original_root) self._remove_author_tracked_changes(modified_root) modified_text = self._extract_text_content(modified_root) original_text = self._extract_text_content(original_root) if modified_text != original_text: error_message = self._generate_detailed_diff( original_text, modified_text ) print(error_message) return False if self.verbose: print(f"PASSED - All changes by {self.author} are properly tracked") return True def _generate_detailed_diff(self, original_text, modified_text): error_parts = [ f"FAILED - Document text doesn't match after removing {self.author}'s tracked changes", "", "Likely causes:", " 1. Modified text inside another author's <w:ins> or <w:del> tags", " 2. Made edits without proper tracked changes", " 3.
Didn't nest <w:del> inside <w:ins> when deleting another's insertion", "", "For pre-redlined documents, use correct patterns:", " - To reject another's INSERTION: Nest <w:del> inside their <w:ins>", " - To restore another's DELETION: Add new <w:ins> AFTER their <w:del>", "", ] git_diff = self._get_git_word_diff(original_text, modified_text) if git_diff: error_parts.extend(["Differences:", "============", git_diff]) else: error_parts.append("Unable to generate word diff (git not available)") return "\n".join(error_parts) def _get_git_word_diff(self, original_text, modified_text): try: with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) original_file = temp_path / "original.txt" modified_file = temp_path / "modified.txt" original_file.write_text(original_text, encoding="utf-8") modified_file.write_text(modified_text, encoding="utf-8") result = subprocess.run( [ "git", "diff", "--word-diff=plain", "--word-diff-regex=.", "-U0", "--no-index", str(original_file), str(modified_file), ], capture_output=True, text=True, ) if result.stdout.strip(): lines = result.stdout.split("\n") content_lines = [] in_content = False for line in lines: if line.startswith("@@"): in_content = True continue if in_content and line.strip(): content_lines.append(line) if content_lines: return "\n".join(content_lines) result = subprocess.run( [ "git", "diff", "--word-diff=plain", "-U0", "--no-index", str(original_file), str(modified_file), ], capture_output=True, text=True, ) if result.stdout.strip(): lines = result.stdout.split("\n") content_lines = [] in_content = False for line in lines: if line.startswith("@@"): in_content = True continue if in_content and line.strip(): content_lines.append(line) return "\n".join(content_lines) except (subprocess.CalledProcessError, FileNotFoundError, Exception): pass return None def _remove_author_tracked_changes(self, root): ins_tag = f"{{{self.namespaces['w']}}}ins" del_tag = f"{{{self.namespaces['w']}}}del" author_attr = 
f"{{{self.namespaces['w']}}}author" for parent in
root.iter(): to_remove = [] for child in parent: if child.tag == ins_tag and child.get(author_attr) == self.author: to_remove.append(child) for elem in to_remove: parent.remove(elem) deltext_tag = f"{{{self.namespaces['w']}}}delText" t_tag = f"{{{self.namespaces['w']}}}t" for parent in root.iter(): to_process = [] for child in parent: if child.tag == del_tag and child.get(author_attr) == self.author: to_process.append((child, list(parent).index(child))) for del_elem, del_index in reversed(to_process): for elem in del_elem.iter(): if elem.tag == deltext_tag: elem.tag = t_tag for child in reversed(list(del_elem)): parent.insert(del_index, child) parent.remove(del_elem) def _extract_text_content(self, root): p_tag = f"{{{self.namespaces['w']}}}p" t_tag = f"{{{self.namespaces['w']}}}t" paragraphs = [] for p_elem in root.findall(f".//{p_tag}"): text_parts = [] for t_elem in p_elem.findall(f".//{t_tag}"): if t_elem.text: text_parts.append(t_elem.text) paragraph_text = "".join(text_parts) if paragraph_text: paragraphs.append(paragraph_text) return "\n".join(paragraphs) if __name__ == "__main__": raise RuntimeError("This module should not be run directly.") ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/preview.py ================================================ """Generate slide preview images from a PowerPoint file. Converts PPTX -> PDF -> JPEG slides with caching. If cached slides already exist and are up-to-date, returns them without reconverting. 
Output protocol (stdout): Line 1: status — one of CACHED, GENERATED, ERROR_NOT_FOUND, ERROR_NO_PDF Lines 2+: sorted absolute paths to slide-*.jpg files Usage: python preview.py /path/to/file.pptx /path/to/cache_dir """ import os import subprocess import sys from pathlib import Path # Allow importing office.soffice from the scripts directory sys.path.insert(0, str(Path(__file__).resolve().parent)) from office.soffice import run_soffice CONVERSION_DPI = 150 def _find_slides(directory: Path) -> list[str]: """Find slide-*.jpg files in directory, sorted by page number.""" slides = list(directory.glob("slide-*.jpg")) slides.sort(key=lambda p: int(p.stem.split("-")[-1])) return [str(s) for s in slides] def main() -> None: if len(sys.argv) != 3: print(f"Usage: {sys.argv[0]} <pptx_path> <cache_dir>", file=sys.stderr) sys.exit(1) pptx_path = Path(sys.argv[1]) cache_dir = Path(sys.argv[2]) if not pptx_path.is_file(): print("ERROR_NOT_FOUND") return # Check cache: if slides exist and are at least as new as the PPTX, reuse them cached_slides = _find_slides(cache_dir) if cached_slides: pptx_mtime = os.path.getmtime(pptx_path) oldest_slide_mtime = min(os.path.getmtime(s) for s in cached_slides) if oldest_slide_mtime >= pptx_mtime: print("CACHED") for slide in cached_slides: print(slide) return # Stale cache — remove old slides for slide in cached_slides: os.remove(slide) cache_dir.mkdir(parents=True, exist_ok=True) # Convert PPTX -> PDF via LibreOffice result = run_soffice( [ "--headless", "--convert-to", "pdf", "--outdir", str(cache_dir), str(pptx_path), ], capture_output=True, text=True, ) if result.returncode != 0: print("CONVERSION_ERROR", file=sys.stderr) sys.exit(1) # Find the generated PDF pdfs = sorted(cache_dir.glob("*.pdf")) if not pdfs: print("ERROR_NO_PDF") return pdf_file = pdfs[0] # Convert PDF -> JPEG slides result = subprocess.run( [ "pdftoppm", "-jpeg", "-r", str(CONVERSION_DPI), str(pdf_file), str(cache_dir / "slide"), ], capture_output=True, 
text=True, ) if result.returncode != 0:
print("CONVERSION_ERROR", file=sys.stderr) sys.exit(1) # Clean up PDF pdf_file.unlink(missing_ok=True) slides = _find_slides(cache_dir) print("GENERATED") for slide in slides: print(slide) if __name__ == "__main__": main() ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx/scripts/thumbnail.py ================================================ """Create thumbnail grids from PowerPoint presentation slides. Creates a grid layout of slide thumbnails for quick visual analysis. Labels each thumbnail with its XML filename (e.g., slide1.xml). Hidden slides are shown with a placeholder pattern. Usage: python thumbnail.py input.pptx [output_prefix] [--cols N] Examples: python thumbnail.py presentation.pptx # Creates: thumbnails.jpg python thumbnail.py template.pptx grid --cols 4 # Creates: grid.jpg (or grid-1.jpg, grid-2.jpg for large decks) """ import argparse import subprocess import sys import tempfile import zipfile from pathlib import Path import defusedxml.minidom from office.soffice import get_soffice_env from PIL import Image from PIL import ImageDraw from PIL import ImageFont THUMBNAIL_WIDTH = 300 CONVERSION_DPI = 100 MAX_COLS = 6 DEFAULT_COLS = 3 JPEG_QUALITY = 95 GRID_PADDING = 20 BORDER_WIDTH = 2 FONT_SIZE_RATIO = 0.10 LABEL_PADDING_RATIO = 0.4 def main(): parser = argparse.ArgumentParser( description="Create thumbnail grids from PowerPoint slides." 
) parser.add_argument("input", help="Input PowerPoint file (.pptx)") parser.add_argument( "output_prefix", nargs="?", default="thumbnails", help="Output prefix for image files (default: thumbnails)", ) parser.add_argument( "--cols", type=int, default=DEFAULT_COLS, help=f"Number of columns (default: {DEFAULT_COLS}, max: {MAX_COLS})", ) args = parser.parse_args() cols = min(args.cols, MAX_COLS) if args.cols > MAX_COLS: print(f"Warning: Columns limited to {MAX_COLS}") input_path = Path(args.input) if not input_path.exists() or input_path.suffix.lower() != ".pptx": print(f"Error: Invalid PowerPoint file: {args.input}", file=sys.stderr) sys.exit(1) output_path = Path(f"{args.output_prefix}.jpg") try: slide_info = get_slide_info(input_path) with tempfile.TemporaryDirectory() as temp_dir: temp_path = Path(temp_dir) visible_images = convert_to_images(input_path, temp_path) if not visible_images and not any(s["hidden"] for s in slide_info): print("Error: No slides found", file=sys.stderr) sys.exit(1) slides = build_slide_list(slide_info, visible_images, temp_path) grid_files = create_grids(slides, cols, THUMBNAIL_WIDTH, output_path) print(f"Created {len(grid_files)} grid(s):") for grid_file in grid_files: print(f" {grid_file}") except Exception as e: print(f"Error: {e}", file=sys.stderr) sys.exit(1) def get_slide_info(pptx_path: Path) -> list[dict]: with zipfile.ZipFile(pptx_path, "r") as zf: rels_content = zf.read("ppt/_rels/presentation.xml.rels").decode("utf-8") rels_dom = defusedxml.minidom.parseString(rels_content) rid_to_slide = {} for rel in rels_dom.getElementsByTagName("Relationship"): rid = rel.getAttribute("Id") target = rel.getAttribute("Target") rel_type = rel.getAttribute("Type") if "slide" in rel_type and target.startswith("slides/"): rid_to_slide[rid] = target.replace("slides/", "") pres_content = zf.read("ppt/presentation.xml").decode("utf-8") pres_dom = defusedxml.minidom.parseString(pres_content) slides = [] for sld_id in 
pres_dom.getElementsByTagName("p:sldId"): rid = sld_id.getAttribute("r:id") if rid in rid_to_slide: hidden = sld_id.getAttribute("show") == "0" slides.append({"name": rid_to_slide[rid], "hidden": hidden}) return slides def build_slide_list( slide_info: list[dict], visible_images: list[Path], temp_dir: Path, ) -> list[tuple[Path, str]]: if visible_images: with Image.open(visible_images[0]) as img: placeholder_size = img.size else: placeholder_size = (1920, 1080) slides = [] visible_idx = 0 for info in slide_info: if info["hidden"]: placeholder_path = temp_dir / f"hidden-{info['name']}.jpg" placeholder_img = create_hidden_placeholder(placeholder_size) placeholder_img.save(placeholder_path, "JPEG") slides.append((placeholder_path, f"{info['name']} (hidden)")) else: if visible_idx < len(visible_images): slides.append((visible_images[visible_idx], info["name"])) visible_idx += 1 return slides def create_hidden_placeholder(size: tuple[int, int]) -> Image.Image: img = Image.new("RGB", size, color="#F0F0F0") draw = ImageDraw.Draw(img) line_width = max(5, min(size) // 100) draw.line([(0, 0), size], fill="#CCCCCC", width=line_width) draw.line([(size[0], 0), (0, size[1])], fill="#CCCCCC", width=line_width) return img def convert_to_images(pptx_path: Path, temp_dir: Path) -> list[Path]: pdf_path = temp_dir / f"{pptx_path.stem}.pdf" result = subprocess.run( [ "soffice", "--headless", "--convert-to", "pdf", "--outdir", str(temp_dir), str(pptx_path), ], capture_output=True, text=True, env=get_soffice_env(), ) if result.returncode != 0 or not pdf_path.exists(): raise RuntimeError("PDF conversion failed") result = subprocess.run( [ "pdftoppm", "-jpeg", "-r", str(CONVERSION_DPI), str(pdf_path), str(temp_dir / "slide"), ], capture_output=True, text=True, ) if result.returncode != 0: raise RuntimeError("Image conversion failed") return sorted(temp_dir.glob("slide-*.jpg")) def create_grids( slides: list[tuple[Path, str]], cols: int, width: int, output_path: Path, ) -> list[str]: 
max_per_grid = cols * (cols + 1) grid_files = [] for chunk_idx, start_idx in enumerate(range(0, len(slides), max_per_grid)): end_idx = min(start_idx + max_per_grid, len(slides)) chunk_slides = slides[start_idx:end_idx] grid = create_grid(chunk_slides, cols, width) if len(slides) <= max_per_grid: grid_filename = output_path else: stem = output_path.stem suffix = output_path.suffix grid_filename = output_path.parent / f"{stem}-{chunk_idx + 1}{suffix}" grid_filename.parent.mkdir(parents=True, exist_ok=True) grid.save(str(grid_filename), quality=JPEG_QUALITY) grid_files.append(str(grid_filename)) return grid_files def create_grid( slides: list[tuple[Path, str]], cols: int, width: int, ) -> Image.Image: font_size = int(width * FONT_SIZE_RATIO) label_padding = int(font_size * LABEL_PADDING_RATIO) with Image.open(slides[0][0]) as img: aspect = img.height / img.width height = int(width * aspect) rows = (len(slides) + cols - 1) // cols grid_w = cols * width + (cols + 1) * GRID_PADDING grid_h = rows * (height + font_size + label_padding * 2) + (rows + 1) * GRID_PADDING grid = Image.new("RGB", (grid_w, grid_h), "white") draw = ImageDraw.Draw(grid) try: font = ImageFont.load_default(size=font_size) except Exception: font = ImageFont.load_default() for i, (img_path, slide_name) in enumerate(slides): row, col = i // cols, i % cols x = col * width + (col + 1) * GRID_PADDING y_base = ( row * (height + font_size + label_padding * 2) + (row + 1) * GRID_PADDING ) label = slide_name bbox = draw.textbbox((0, 0), label, font=font) text_w = bbox[2] - bbox[0] draw.text( (x + (width - text_w) // 2, y_base + label_padding), label, fill="black", font=font, ) y_thumbnail = y_base + label_padding + font_size + label_padding with Image.open(img_path) as img: img.thumbnail((width, height), Image.Resampling.LANCZOS) w, h = img.size tx = x + (width - w) // 2 ty = y_thumbnail + (height - h) // 2 grid.paste(img, (tx, ty)) if BORDER_WIDTH > 0: draw.rectangle( [ (tx - BORDER_WIDTH, ty - BORDER_WIDTH), 
(tx + w + BORDER_WIDTH - 1, ty + h + BORDER_WIDTH - 1), ], outline="gray", width=BORDER_WIDTH, ) return grid if __name__ == "__main__": main() ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/.gitignore ================================================ # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. # dependencies /node_modules /.pnp .pnp.* .yarn/* !.yarn/patches !.yarn/plugins !.yarn/releases !.yarn/versions # testing /coverage # next.js /.next/ /out/ # production /build # misc .DS_Store *.pem # debug npm-debug.log* yarn-debug.log* yarn-error.log* .pnpm-debug.log* # env files (can opt-in for committing if needed) .env* # vercel .vercel # typescript *.tsbuildinfo next-env.d.ts ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/AGENTS.md ================================================ # AGENTS.md This file provides guidance to AI agents when working on the web application within this directory. ## Important Notes - **The development server is already running** at a dynamically allocated port. Do NOT run `npm run dev` yourself. - **We do NOT use a `src` directory** - all code lives directly in the root folders (`app/`, `components/`, `lib/`, etc.) - If the app needs pre-computation (data processing, API calls, etc.), create a bash or python script called `prepare.sh`/`prepare.py` at the root of this directory - **CRITICAL: Create small, modular components** - Do NOT write everything in `page.tsx`. Break your UI into small, reusable components in the `components/` directory. Each component should have a single responsibility and be in its own file. 
## Data Preparation Scripts **CRITICAL: Always re-run data scripts after modifying them.** If a `prepare.sh` or `prepare.py` script exists at the root of this directory, it is responsible for generating/loading data that the frontend consumes. ### When to Run the Script You MUST run the data preparation script: 1. **After creating** the script for the first time 2. **After modifying** the script logic (new data sources, changed processing, etc.) 3. **After updating** any data files the script reads from 4. **Before testing** the frontend, whenever you're unsure whether the data is fresh ### How to Run ```bash # For bash scripts bash prepare.sh # For python scripts python prepare.py ``` ### Common Mistake ❌ **Updating the script but forgetting to run it** - This leaves stale data in place and the frontend won't reflect your changes. Always run the script immediately after modifying it. ## Commands ```bash npm run dev # Start development server (DO NOT RUN - already running) npm run lint # Run ESLint ``` ## Architecture This is a **Next.js 16.1.1** application using the **App Router** with **React 19** and **TypeScript**. It serves as a component showcase/template built on shadcn/ui. ### File Organization Philosophy **Prioritize small, incremental file writes.** Break your application into many small components rather than monolithic page files.
#### Component Organization ``` components/ ├── dashboard/ # Feature-specific components │ ├── stats-card.tsx │ ├── activity-feed.tsx │ └── recent-items.tsx ├── charts/ # Chart components │ ├── line-chart.tsx │ ├── bar-chart.tsx │ └── pie-chart.tsx ├── data/ # Data display components │ ├── data-table.tsx │ ├── filter-bar.tsx │ └── sort-controls.tsx └── layout/ # Layout components ├── header.tsx ├── sidebar.tsx └── footer.tsx ``` #### Page Structure Pages (`app/page.tsx`) should be **thin orchestration layers** that compose components: ```typescript // ✅ GOOD - page.tsx is just composition import { StatsCard } from "@/components/dashboard/stats-card"; import { ActivityFeed } from "@/components/dashboard/activity-feed"; import { RecentItems } from "@/components/dashboard/recent-items"; export default function DashboardPage() { return (

    Dashboard

    ); } // ❌ BAD - Everything in page.tsx (500+ lines of mixed logic) export default function DashboardPage() { // ... 500 lines of component logic, state, handlers, JSX ... } ``` #### Component Granularity Create a new component file when: - A UI section has distinct functionality (e.g., `user-profile-card.tsx`) - Logic exceeds ~50-100 lines - A pattern is reused 2+ times - Testing/maintenance would benefit from isolation **Example: Dashboard Feature** Instead of writing everything in `app/page.tsx`: ```typescript // components/dashboard/stats-card.tsx export function StatsCard({ title, value, trend }: StatsCardProps) { return ( {title}
    {value}
    {trend &&

    {trend}

    }
    ); } // components/dashboard/activity-feed.tsx export function ActivityFeed() { // Activity feed logic here } // components/dashboard/recent-items.tsx export function RecentItems() { // Recent items logic here } ``` #### Benefits of Small Components 1. **Incremental Development**: Write one component at a time, test, iterate 2. **Better Diffs**: Smaller files = clearer git diffs and easier reviews 3. **Reusability**: Components can be imported across pages 4. **Maintainability**: Easier to locate and fix issues 5. **Hot Reload Efficiency**: Changes to small files reload faster 6. **Parallel Development**: Multiple features can be worked on independently ### Tech Stack - **Framework**: Next.js 16.1.1 with App Router - **React**: React 19 - **Language**: TypeScript - **Styling**: Tailwind CSS v4 with CSS variables in OKLCH color space - **Charts**: recharts for data visualization - **UI Components**: shadcn/ui (53 components) built on Radix UI primitives - **Variants**: class-variance-authority (CVA) for component variants - **Class Merging**: `cn()` utility in `lib/utils.ts` (clsx + tailwind-merge) - **Theme**: Dark mode enforced (via `dark` class on `<html>`) ### Key Directories - `app/` - Next.js App Router pages and layouts - `components/ui/` - shadcn/ui component library (Button, Card, Dialog, etc.)
- `components/` - App-specific components - `hooks/` - Custom React hooks (e.g., `use-mobile.ts`) - `lib/` - Utilities (`cn()` function) ### Component Patterns - **Compound Components**: Components like `DropdownMenu`, `Dialog`, `Select` export multiple sub-components (Trigger, Content, Item) - **Variants via CVA**: Use the `variant` and `size` props for style/size variations (defined with CVA, e.g., `buttonVariants`) - **Radix UI Primitives**: UI components wrap Radix for accessibility ### Path Aliases All imports use `@/` alias (e.g., `@/components/ui/button`, `@/lib/utils`) ### shadcn/ui Configuration Located in `components.json`: - Style: `radix-nova` - RSC enabled - Icons: lucide-react ### Theme Variables Global CSS variables defined in `app/globals.css` control colors, radius, and spacing. **Dark mode is enforced site-wide** via the `dark` class on the `<html>` element in `app/layout.tsx`. All styling should assume dark mode is active. ### Dark Mode Priority - **Dark mode is the default and only theme** - do not design for light mode - The `dark` class is permanently set on `<html>` in `layout.tsx` - Use dark-appropriate colors: `bg-background`, `text-foreground`, etc. - Ensure sufficient contrast for dark backgrounds - Test all components in dark mode only ## Styling Guidelines ### CRITICAL: Use Only shadcn/ui Components **MINIMIZE freestyling and creating custom components.** This application uses a complete, professionally designed component library (shadcn/ui). You MUST use the existing components from `components/ui/` for most UI needs. #### Available shadcn/ui Components All components are in `components/ui/`. Import using `@/components/ui/component-name`.
**Layout & Structure:** - `Card` (`card.tsx`) - Content containers with CardHeader, CardTitle, CardDescription, CardContent, CardFooter - `Separator` (`separator.tsx`) - Horizontal/vertical dividers - `Tabs` (`tabs.tsx`) - Tabbed interfaces with Tabs, TabsList, TabsTrigger, TabsContent - `ScrollArea` (`scroll-area.tsx`) - Styled scrollable regions - `Resizable` (`resizable.tsx`) - Resizable panel layouts - `Drawer` (`drawer.tsx`) - Bottom/side drawer overlays - `Sidebar` (`sidebar.tsx`) - Application sidebar layout - `AspectRatio` (`aspect-ratio.tsx`) - Maintain aspect ratios **Forms & Inputs:** - `Button` (`button.tsx`) - Primary, secondary, destructive, outline, ghost, link variants - `ButtonGroup` (`button-group.tsx`) - Group of related buttons - `Input` (`input.tsx`) - Text inputs with various states - `InputGroup` (`input-group.tsx`) - Input with addons/icons - `Textarea` (`textarea.tsx`) - Multi-line text input - `Checkbox` (`checkbox.tsx`) - Checkboxes with indeterminate state - `RadioGroup` (`radio-group.tsx`) - Radio button groups - `Switch` (`switch.tsx`) - Toggle switches - `Select` (`select.tsx`) - Dropdown select menus - `NativeSelect` (`native-select.tsx`) - Native HTML select - `Combobox` (`combobox.tsx`) - Autocomplete select with search - `Command` (`command.tsx`) - Command palette/search interface - `Field` (`field.tsx`) - Form field wrapper with label and error - `Label` (`label.tsx`) - Form labels with proper accessibility - `Slider` (`slider.tsx`) - Range sliders - `Calendar` (`calendar.tsx`) - Date picker calendar - `Toggle` (`toggle.tsx`) - Toggle button - `ToggleGroup` (`toggle-group.tsx`) - Group of toggle buttons **Navigation:** - `NavigationMenu` (`navigation-menu.tsx`) - Complex navigation menus - `Menubar` (`menubar.tsx`) - Application menu bar - `Breadcrumb` (`breadcrumb.tsx`) - Breadcrumb navigation - `Pagination` (`pagination.tsx`) - Page navigation controls **Feedback & Overlays:** - `Dialog` (`dialog.tsx`) - Modal dialogs - 
`AlertDialog` (`alert-dialog.tsx`) - Confirmation dialogs - `Sheet` (`sheet.tsx`) - Side sheets/panels - `Popover` (`popover.tsx`) - Floating popovers - `HoverCard` (`hover-card.tsx`) - Hover-triggered cards - `Tooltip` (`tooltip.tsx`) - Tooltips on hover - `Sonner` (`sonner.tsx`) - Toast notifications - `Alert` (`alert.tsx`) - Static alert messages - `Progress` (`progress.tsx`) - Progress bars - `Skeleton` (`skeleton.tsx`) - Loading skeletons - `Spinner` (`spinner.tsx`) - Loading spinners - `Empty` (`empty.tsx`) - Empty state placeholder **Menus & Dropdowns:** - `DropdownMenu` (`dropdown-menu.tsx`) - Dropdown menus with submenus - `ContextMenu` (`context-menu.tsx`) - Right-click context menus **Data Display:** - `Table` (`table.tsx`) - Data tables with Table, TableHeader, TableBody, TableRow, TableCell, etc. - `Badge` (`badge.tsx`) - Status badges and tags - `Avatar` (`avatar.tsx`) - User avatars with fallbacks - `Accordion` (`accordion.tsx`) - Collapsible content sections - `Collapsible` (`collapsible.tsx`) - Simple collapse/expand - `Carousel` (`carousel.tsx`) - Image/content carousels - `Item` (`item.tsx`) - List item component - `Kbd` (`kbd.tsx`) - Keyboard shortcut display **Data Visualization:** - `Chart` (`chart.tsx`) - Chart wrapper with ChartContainer, ChartTooltip, ChartTooltipContent, ChartLegend, ChartLegendContent ### Component Usage Principles #### 1. **Never Create Custom Components** ```typescript // ❌ WRONG - Do not create freestyle components function CustomCard({ title, children }) { return (

    {title}

    {children}
    ); } // ✅ CORRECT - Use shadcn Card import { Card, CardHeader, CardTitle, CardContent } from "@/components/ui/card"; function MyComponent() { return ( Title Content here ); } ``` #### 2. **Use Component Variants, Don't Style Directly** ```typescript // ❌ WRONG - Applying custom Tailwind classes // ✅ CORRECT - Use Button variants import { Button } from "@/components/ui/button"; ``` #### 3. **Compose Compound Components** Many shadcn components export multiple sub-components. Use them as designed: ```typescript // ✅ Dropdown Menu Composition import { DropdownMenu, DropdownMenuTrigger, DropdownMenuContent, DropdownMenuItem, DropdownMenuSeparator, DropdownMenuLabel, } from "@/components/ui/dropdown-menu"; Actions Edit Delete ``` #### 4. **Use Layout Components for Structure** ```typescript // ✅ Use Card for content sections import { Card, CardHeader, CardTitle, CardDescription, CardContent, CardFooter } from "@/components/ui/card"; Dashboard Overview of your data {/* Your content */} ``` ### Styling Rules #### 1. **Spacing & Layout** Use Tailwind's utility classes for spacing, but stick to the design system: - Gap: `gap-2`, `gap-4`, `gap-6`, `gap-8` - Padding: `p-2`, `p-4`, `p-6`, `p-8` - Margins: Prefer `gap` and `space-y-*` over margins #### 2. **Colors** All colors come from CSS variables in `app/globals.css`. Use semantic color classes: - `bg-background`, `bg-foreground` - `bg-card`, `text-card-foreground` - `bg-primary`, `text-primary-foreground` - `bg-secondary`, `text-secondary-foreground` - `bg-muted`, `text-muted-foreground` - `bg-accent`, `text-accent-foreground` - `bg-destructive`, `text-destructive-foreground` - `border-border`, `border-input` - `ring-ring` **DO NOT use arbitrary color values** like `bg-blue-500` or `text-red-600`. 
#### **CRITICAL: Color Contrast Pairing Rules**

**Always pair background colors with their matching foreground colors.** The color system uses paired variables where each background has a corresponding text color designed for proper contrast.

| Background Class | Text Class to Use | Description |
|------------------|-------------------|-------------|
| `bg-background` | `text-foreground` | Main page background |
| `bg-card` | `text-card-foreground` | Card containers |
| `bg-primary` | `text-primary-foreground` | Primary buttons/accents |
| `bg-secondary` | `text-secondary-foreground` | Secondary elements |
| `bg-muted` | `text-muted-foreground` | Muted/subtle areas |
| `bg-accent` | `text-accent-foreground` | Accent highlights |
| `bg-destructive` | `text-destructive-foreground` | Error/delete actions |

**Examples:**

```typescript
// ✅ CORRECT - Matching background and foreground pairs
Content
Subtle text
// ❌ WRONG - Mismatched colors causing contrast issues
Invisible text!
May have poor contrast
```

**Key Rules:**

1. **Never use the same color for background and text** (e.g., `bg-foreground text-foreground`)
2. **Always use the `-foreground` variant for text** when using a colored background
3. **For text on `bg-background`**, use `text-foreground` (primary) or `text-muted-foreground` (secondary)
4. **Test visually** - if text is hard to read, you have a contrast problem

#### 3. **Typography**

Use Tailwind text utilities (no separate Typography component):

- Headings: `text-xl font-semibold`, `text-2xl font-bold`, etc.
- Body: `text-sm`, `text-base`
- Secondary text: `text-muted-foreground`
- Use semantic HTML elements for text (e.g., headings and paragraphs)
- **Always wrap text** - Use `max-w-prose` or `max-w-xl` for readable line lengths
- **Prevent overflow** - Use `break-words` or `truncate` for long text that might overflow containers

#### 4. **Responsive Design**

Use Tailwind's responsive prefixes:

```typescript

{/* Responsive grid */}
```

#### 5. **Icons**

Use Lucide React icons (already configured):

```typescript
import { Check, X, ChevronDown, User } from "lucide-react";
```

### Data Visualization

For charts and data visualization, use the **shadcn/ui Chart components** (`@/components/ui/chart`) which wrap recharts with consistent theming. Charts should be **elegant, informative, and digestible at a glance**.

#### Chart Design Principles

1. **Clarity over complexity** - A chart should communicate ONE key insight immediately
2. **Minimal visual noise** - Remove anything that doesn't add information
3. **Consistent styling** - Use `ChartConfig` for colors, not arbitrary values
4. **Responsive** - Always use `ChartContainer` (includes ResponsiveContainer)
5. **Accessible** - Use `ChartTooltip` with `ChartTooltipContent` for proper styling

#### Chart Type Selection

| Data Type | Recommended Chart | Use Case |
|-----------|-------------------|----------|
| Trend over time | `LineChart` or `AreaChart` | Stock prices, user growth, metrics over days/months |
| Comparing categories | `BarChart` | Revenue by product, users by region |
| Part of whole | `PieChart` or `RadialBarChart` | Market share, budget allocation |
| Distribution | `BarChart` (horizontal) | Survey responses, rating distribution |
| Correlation | `ScatterChart` | Price vs. quality, age vs. income |

#### shadcn/ui Chart Components

Always import from the shadcn chart component:

```typescript
import {
  ChartContainer,
  ChartTooltip,
  ChartTooltipContent,
  ChartLegend,
  ChartLegendContent,
  type ChartConfig,
} from "@/components/ui/chart";
import { LineChart, Line, XAxis, YAxis, CartesianGrid } from "recharts";
```

#### ChartConfig - Define Colors and Labels

The `ChartConfig` object defines colors and labels for your data series.
This ensures consistent theming:

```typescript
const chartConfig = {
  revenue: {
    label: "Revenue",
    color: "var(--chart-1)",
  },
  expenses: {
    label: "Expenses",
    color: "var(--chart-2)",
  },
} satisfies ChartConfig;
```

#### Basic Line Chart Template

```typescript
import {
  ChartContainer,
  ChartTooltip,
  ChartTooltipContent,
  type ChartConfig,
} from "@/components/ui/chart";
import { LineChart, Line, XAxis, YAxis, CartesianGrid } from "recharts";

const chartConfig = {
  value: {
    label: "Value",
    color: "var(--chart-1)",
  },
} satisfies ChartConfig;

} />
```

#### Bar Chart with Multiple Series

```typescript
const chartConfig = {
  revenue: {
    label: "Revenue",
    color: "var(--chart-1)",
  },
  expenses: {
    label: "Expenses",
    color: "var(--chart-2)",
  },
} satisfies ChartConfig;

} />
} />
```

#### Pie/Donut Chart

```typescript
const chartConfig = {
  desktop: { label: "Desktop", color: "var(--chart-1)" },
  mobile: { label: "Mobile", color: "var(--chart-2)" },
  tablet: { label: "Tablet", color: "var(--chart-3)" },
} satisfies ChartConfig;

} />
} />
```

#### Chart Styling Rules

**Colors (use CSS variables from globals.css):**

- `var(--chart-1)` through `var(--chart-5)` - Primary chart colors
- `var(--primary)` - For single-series emphasis
- `var(--muted)` - For de-emphasized data

**Color References in Charts:**

- In `ChartConfig`: Use `color: "var(--chart-1)"`
- In chart elements: Use `fill="var(--color-keyname)"` or `stroke="var(--color-keyname)"`
- The `keyname` matches the key in your `ChartConfig`

**Visual Cleanup:**

- Set `tickLine={false}` and `axisLine={false}` on axes for a cleaner look
- Use `vertical={false}` on `CartesianGrid` for horizontal-only grid lines
- Use `dot={false}` on line charts unless individual points matter
- Add `radius={4}` to bars for rounded corners
- Limit to 3-5 data series maximum per chart

**Avoid:**

- ❌ 3D effects
- ❌ More than 5-6 colors in one chart
- ❌ Legends with more than 5 items (simplify the data instead)
- ❌ Dual Y-axes (confusing - use two separate charts)
- ❌ Pie charts with more than 5-6 slices
- ❌ Custom tooltip styling - use `ChartTooltipContent`

#### Fallback to Raw Recharts

If shadcn/ui Chart components don't support a specific chart type (e.g., ScatterChart, ComposedChart, RadarChart), you can use recharts directly:

```typescript
import { ScatterChart, Scatter, XAxis, YAxis, CartesianGrid, Tooltip, ResponsiveContainer } from "recharts";
```

**When using raw recharts:**

- Still use CSS variables for colors (`var(--chart-1)`, etc.)
- Match styling to shadcn conventions (`tickLine={false}`, `axisLine={false}`)
- Style tooltips to match the design system

#### Data Accuracy Checklist

Before displaying a chart, verify:

- [ ] `ChartConfig` keys match your data's `dataKey` values
- [ ] Data values are correctly mapped to the right axes
- [ ] Axis labels match the data units (%, $, count, etc.)
- [ ] Time series data is sorted chronologically
- [ ] No missing data points that would break the visualization
- [ ] `ChartTooltip` with `ChartTooltipContent` is included
- [ ] Chart title/context makes the insight clear

### Common Patterns

#### Loading States

```typescript
import { Skeleton } from "@/components/ui/skeleton";

{isLoading ? ( ) : ( )}
```

#### Empty States

```typescript
import { Empty, EmptyHeader, EmptyTitle, EmptyDescription, EmptyMedia } from "@/components/ui/empty";
import { Inbox } from "lucide-react";

No data available There's nothing to display yet. Add some items to get started.
```

#### Interactive Lists

```typescript
import { ScrollArea } from "@/components/ui/scroll-area";
import { ItemGroup, Item, ItemContent, ItemTitle, ItemDescription, ItemMedia } from "@/components/ui/item";
import { FileText } from "lucide-react";

{items.map((item) => ( {item.name} {item.description} ))}
```

#### Form Fields

```typescript
import { Field, FieldLabel, FieldDescription, FieldError, FieldGroup } from "@/components/ui/field";
import { Input } from "@/components/ui/input";
import { Button } from "@/components/ui/button";

Email We'll never share your email. Password Password must be at least 8 characters.
```

### What NOT To Do

❌ **Don't create custom styled divs when a component exists**

❌ **Don't use arbitrary Tailwind colors** (use CSS variables)

❌ **Don't import UI libraries** like Material-UI, Ant Design, etc.

❌ **Don't use inline styles** except for dynamic values

❌ **Don't create custom form inputs** (use Field, Input, Select, etc. from components/ui)

❌ **Don't add new dependencies** without checking if shadcn covers it

❌ **Don't write everything in page.tsx** - break into separate component files

❌ **Don't design for light mode** - this site is dark mode only

❌ **Don't use `dark:` variants** - dark mode is always active, use base classes

### Development Workflow

1. **Plan the component structure** - Identify logical UI sections before writing code
2. **Create components incrementally** - Write one small component file at a time
3. **Test each component** - Verify it works before moving to the next
4. **Compose in page.tsx** - Import and arrange your components in the page
5. **Iterate** - Refine individual components without touching others

### Summary

This application has a **complete, production-ready component library**. Your job is to:

1. **Compose** shadcn/ui components (from `components/ui/`)
2. **Create small, focused component files** (in `components/`)
3. **Keep pages thin** - pages should orchestrate components, not contain implementation.

Think of yourself as assembling LEGO blocks: all the UI pieces you need already exist in `components/ui/`, and you create small, organized structures by composing them into feature-specific components.

================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/app/globals.css ================================================ @import "tailwindcss"; @import "tw-animate-css"; @import "shadcn/tailwind.css"; @custom-variant dark (&:is(.dark *)); @theme inline { --color-background: var(--background); --color-foreground: var(--foreground); --font-sans: var(--font-sans); --font-mono: var(--font-geist-mono); --color-sidebar-ring: var(--sidebar-ring); --color-sidebar-border: var(--sidebar-border); --color-sidebar-accent-foreground: var(--sidebar-accent-foreground); --color-sidebar-accent: var(--sidebar-accent); --color-sidebar-primary-foreground: var(--sidebar-primary-foreground); --color-sidebar-primary: var(--sidebar-primary); --color-sidebar-foreground: var(--sidebar-foreground); --color-sidebar: var(--sidebar); --color-chart-5: var(--chart-5); --color-chart-4: var(--chart-4); --color-chart-3: var(--chart-3); --color-chart-2: var(--chart-2); --color-chart-1: var(--chart-1); --color-ring: var(--ring); --color-input: var(--input); --color-border: var(--border); --color-destructive: var(--destructive); --color-accent-foreground: var(--accent-foreground); --color-accent: var(--accent); --color-muted-foreground: var(--muted-foreground); --color-muted: var(--muted); --color-secondary-foreground: var(--secondary-foreground); --color-secondary: var(--secondary); --color-primary-foreground: var(--primary-foreground); --color-primary: var(--primary); --color-popover-foreground: var(--popover-foreground); --color-popover: var(--popover); --color-card-foreground: var(--card-foreground); --color-card: var(--card);
--radius-sm: calc(var(--radius) - 4px); --radius-md: calc(var(--radius) - 2px); --radius-lg: var(--radius); --radius-xl: calc(var(--radius) + 4px); --radius-2xl: calc(var(--radius) + 8px); --radius-3xl: calc(var(--radius) + 12px); --radius-4xl: calc(var(--radius) + 16px); } :root { --background: oklch(1 0 0); --foreground: oklch(0.145 0 0); --card: oklch(1 0 0); --card-foreground: oklch(0.145 0 0); --popover: oklch(1 0 0); --popover-foreground: oklch(0.145 0 0); --primary: oklch(0.67 0.16 58); --primary-foreground: oklch(0.99 0.02 95); --secondary: oklch(0.967 0.001 286.375); --secondary-foreground: oklch(0.21 0.006 285.885); --muted: oklch(0.97 0 0); --muted-foreground: oklch(0.556 0 0); --accent: oklch(0.97 0 0); --accent-foreground: oklch(0.205 0 0); --destructive: oklch(0.58 0.22 27); --border: oklch(0.922 0 0); --input: oklch(0.922 0 0); --ring: oklch(0.708 0 0); --chart-1: oklch(0.88 0.15 92); --chart-2: oklch(0.77 0.16 70); --chart-3: oklch(0.67 0.16 58); --chart-4: oklch(0.56 0.15 49); --chart-5: oklch(0.47 0.12 46); --radius: 0.625rem; --sidebar: oklch(0.985 0 0); --sidebar-foreground: oklch(0.145 0 0); --sidebar-primary: oklch(0.67 0.16 58); --sidebar-primary-foreground: oklch(0.99 0.02 95); --sidebar-accent: oklch(0.97 0 0); --sidebar-accent-foreground: oklch(0.205 0 0); --sidebar-border: oklch(0.922 0 0); --sidebar-ring: oklch(0.708 0 0); } .dark { --background: oklch(0.145 0 0); --foreground: oklch(0.985 0 0); --card: oklch(0.205 0 0); --card-foreground: oklch(0.985 0 0); --popover: oklch(0.205 0 0); --popover-foreground: oklch(0.985 0 0); --primary: oklch(0.77 0.16 70); --primary-foreground: oklch(0.28 0.07 46); --secondary: oklch(0.274 0.006 286.033); --secondary-foreground: oklch(0.985 0 0); --muted: oklch(0.269 0 0); --muted-foreground: oklch(0.708 0 0); --accent: oklch(0.371 0 0); --accent-foreground: oklch(0.985 0 0); --destructive: oklch(0.704 0.191 22.216); --border: oklch(1 0 0 / 10%); --input: oklch(1 0 0 / 15%); --ring: oklch(0.556 0 0); /* 
Chart colors optimized for dark backgrounds - brighter and more vibrant */ --chart-1: oklch(0.82 0.18 140); --chart-2: oklch(0.75 0.2 200); --chart-3: oklch(0.7 0.22 280); --chart-4: oklch(0.78 0.18 50); --chart-5: oklch(0.72 0.2 330); --sidebar: oklch(0.205 0 0); --sidebar-foreground: oklch(0.985 0 0); --sidebar-primary: oklch(0.77 0.16 70); --sidebar-primary-foreground: oklch(0.28 0.07 46); --sidebar-accent: oklch(0.269 0 0); --sidebar-accent-foreground: oklch(0.985 0 0); --sidebar-border: oklch(1 0 0 / 10%); --sidebar-ring: oklch(0.556 0 0); } @layer base { * { @apply border-border outline-ring/50; } body { @apply bg-background text-foreground; } } ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/app/layout.tsx ================================================ import type { Metadata } from "next"; import { Geist, Geist_Mono, Inter } from "next/font/google"; import "./globals.css"; const inter = Inter({ subsets: ["latin"], variable: "--font-sans" }); const geistSans = Geist({ variable: "--font-geist-sans", subsets: ["latin"], }); const geistMono = Geist_Mono({ variable: "--font-geist-mono", subsets: ["latin"], }); export const metadata: Metadata = { title: "Onyx Craft", description: "Crafting your next great idea.", }; export default function RootLayout({ children, }: Readonly<{ children: React.ReactNode; }>) { return ( {children} ); } ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/app/page.tsx ================================================ "use client"; import { useState, useEffect, useRef } from "react"; const messages = [ "Punching wood...", "Gathering resources...", "Placing blocks...", "Crafting your workspace...", "Mining for dependencies...", "Smelting the code...", "Enchanting with magic...", "World generation complete...", "/gamemode 1", ]; const MESSAGE_COUNT = 
messages.length; const TYPE_DELAY = 40; const LINE_PAUSE = 800; const RESET_DELAY = 2000; export default function CraftingLoader() { const [display, setDisplay] = useState({ lines: [] as string[], currentText: "", }); const lineIndexRef = useRef(0); const charIndexRef = useRef(0); const lastUpdateRef = useRef(0); const timeoutRef = useRef<ReturnType<typeof setTimeout> | undefined>(undefined); const rafRef = useRef<number | undefined>(undefined); useEffect(() => { let isActive = true; const update = (now: number) => { if (!isActive) return; const lineIdx = lineIndexRef.current; const charIdx = charIndexRef.current; if (lineIdx >= MESSAGE_COUNT) { timeoutRef.current = setTimeout(() => { if (!isActive) return; lineIndexRef.current = 0; charIndexRef.current = 0; setDisplay({ lines: [], currentText: "" }); lastUpdateRef.current = performance.now(); rafRef.current = requestAnimationFrame(update); }, RESET_DELAY); return; } const msg = messages[lineIdx]; if (!msg) return; const elapsed = now - lastUpdateRef.current; if (charIdx < msg.length) { if (elapsed >= TYPE_DELAY) { charIndexRef.current = charIdx + 1; setDisplay((prev) => ({ lines: prev.lines, currentText: msg.substring(0, charIdx + 1), })); lastUpdateRef.current = now; } } else if (elapsed >= LINE_PAUSE) { setDisplay((prev) => ({ lines: [...prev.lines, msg], currentText: "", })); lineIndexRef.current = lineIdx + 1; charIndexRef.current = 0; lastUpdateRef.current = now; } rafRef.current = requestAnimationFrame(update); }; lastUpdateRef.current = performance.now(); rafRef.current = requestAnimationFrame(update); return () => { isActive = false; if (rafRef.current !== undefined) cancelAnimationFrame(rafRef.current); if (timeoutRef.current !== undefined) clearTimeout(timeoutRef.current); }; }, []); const { lines, currentText } = display; const hasCurrentText = currentText.length > 0; return (
    crafting_table
    {lines.map((line, i) => (
    /> {line}
    ))} {hasCurrentText && (
    /> {currentText}
    )} {!hasCurrentText && (
    />
    )}

    Crafting your next great idea...

    ); } ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/app/site.webmanifest ================================================ {"name":"","short_name":"","icons":[{"src":"/android-chrome-192x192.png","sizes":"192x192","type":"image/png"},{"src":"/android-chrome-512x512.png","sizes":"512x512","type":"image/png"}],"theme_color":"#ffffff","background_color":"#ffffff","display":"standalone"} ================================================ FILE: backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/outputs/web/components/component-example.tsx ================================================ "use client"; import * as React from "react"; import { Example, ExampleWrapper } from "@/components/example"; import { AlertDialog, AlertDialogAction, AlertDialogCancel, AlertDialogContent, AlertDialogDescription, AlertDialogFooter, AlertDialogHeader, AlertDialogMedia, AlertDialogTitle, AlertDialogTrigger, } from "@/components/ui/alert-dialog"; import { Badge } from "@/components/ui/badge"; import { Button } from "@/components/ui/button"; import { Card, CardAction, CardContent, CardDescription, CardFooter, CardHeader, CardTitle, } from "@/components/ui/card"; import { Combobox, ComboboxContent, ComboboxEmpty, ComboboxInput, ComboboxItem, ComboboxList, } from "@/components/ui/combobox"; import { DropdownMenu, DropdownMenuCheckboxItem, DropdownMenuContent, DropdownMenuGroup, DropdownMenuItem, DropdownMenuLabel, DropdownMenuPortal, DropdownMenuRadioGroup, DropdownMenuRadioItem, DropdownMenuSeparator, DropdownMenuShortcut, DropdownMenuSub, DropdownMenuSubContent, DropdownMenuSubTrigger, DropdownMenuTrigger, } from "@/components/ui/dropdown-menu"; import { Field, FieldGroup, FieldLabel } from "@/components/ui/field"; import { Input } from "@/components/ui/input"; import { Select, SelectContent, SelectGroup, SelectItem, SelectTrigger, SelectValue, } from "@/components/ui/select"; 
import { Textarea } from "@/components/ui/textarea"; import { PlusIcon, BluetoothIcon, MoreVerticalIcon, FileIcon, FolderIcon, FolderOpenIcon, FileCodeIcon, MoreHorizontalIcon, FolderSearchIcon, SaveIcon, DownloadIcon, EyeIcon, LayoutIcon, PaletteIcon, SunIcon, MoonIcon, MonitorIcon, UserIcon, CreditCardIcon, SettingsIcon, KeyboardIcon, LanguagesIcon, BellIcon, MailIcon, ShieldIcon, HelpCircleIcon, FileTextIcon, LogOutIcon, } from "lucide-react"; export function ComponentExample() { return ( ); } function CardExample() { return (
    Photo by mymind on Unsplash Observability Plus is replacing Monitoring Switch to the improved way to explore your data, with natural language. Monitoring will no longer be available on the Pro plan in November, 2025 Allow accessory to connect? Do you want to allow the USB accessory to connect to this device? Don't allow Allow Warning ); } const frameworks = [ "Next.js", "SvelteKit", "Nuxt.js", "Remix", "Astro", ] as const; function FormExample() { const [notifications, setNotifications] = React.useState({ email: true, sms: false, push: true, }); const [theme, setTheme] = React.useState("light"); return ( User Information Please fill in your details below File New File ⌘N New Folder ⇧⌘N Open Recent Recent Projects Project Alpha Project Beta More Projects Project Gamma Project Delta Browse... Save ⌘S Export ⇧⌘E View setNotifications({ ...notifications, email: checked === true, }) } > Show Sidebar setNotifications({ ...notifications, sms: checked === true, }) } > Show Status Bar Theme Appearance Light Dark System Account Profile ⇧⌘P Billing Settings Preferences Keyboard Shortcuts Language Notifications Notification Types setNotifications({ ...notifications, push: checked === true, }) } > Push Notifications setNotifications({ ...notifications, email: checked === true, }) } > Email Notifications Privacy & Security Help & Support Documentation Sign Out ⇧⌘Q
    Name Role
    Framework No frameworks found. {(item) => ( {item} )} Comments
    ================================================ FILE: backend/tests/integration/tests/pruning/website/courses.html ================================================ Above Multi-purpose Free Bootstrap Responsive Template

    Courses

    Courses We Offer

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus, vero mollitia velit ad consectetur. Alias, laborum excepturi nihil autem nemo numquam, ipsa architecto non, magni consequuntur quam.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus


    Web Development

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus

    Mobile Development

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus

    Responsive Design

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus

    ================================================ FILE: backend/tests/integration/tests/pruning/website/css/animate.css ================================================ @charset "UTF-8"; /* Animate.css - http://daneden.me/animate Licensed under the MIT license Copyright (c) 2013 Daniel Eden Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ body { /* Addresses a small issue in webkit: http://bit.ly/NEdoDq */ -webkit-backface-visibility: hidden; } .animated { -webkit-animation-duration: 1s; -moz-animation-duration: 1s; -o-animation-duration: 1s; animation-duration: 1s; -webkit-animation-fill-mode: both; -moz-animation-fill-mode: both; -o-animation-fill-mode: both; animation-fill-mode: both; } .animated.hinge { -webkit-animation-duration: 2s; -moz-animation-duration: 2s; -o-animation-duration: 2s; animation-duration: 2s; } @-webkit-keyframes flash { 0%, 50%, 100% { opacity: 1; } 25%, 75% { opacity: 0; } } @-moz-keyframes flash { 0%, 50%, 100% { opacity: 1; } 25%, 75% { opacity: 0; } } @-o-keyframes flash { 0%, 50%, 100% { opacity: 1; } 25%, 75% { opacity: 0; } } @keyframes flash { 0%, 50%, 100% { opacity: 1; } 25%, 75% { opacity: 0; } } .flash { -webkit-animation-name: flash; -moz-animation-name: flash; -o-animation-name: flash; animation-name: flash; } @-webkit-keyframes shake { 0%, 100% { -webkit-transform: translateX(0); } 10%, 30%, 50%, 70%, 90% { -webkit-transform: translateX(-10px); } 20%, 40%, 60%, 80% { -webkit-transform: translateX(10px); } } @-moz-keyframes shake { 0%, 100% { -moz-transform: translateX(0); } 10%, 30%, 50%, 70%, 90% { -moz-transform: translateX(-10px); } 20%, 40%, 60%, 80% { -moz-transform: translateX(10px); } } @-o-keyframes shake { 0%, 100% { -o-transform: translateX(0); } 10%, 30%, 50%, 70%, 90% { -o-transform: translateX(-10px); } 20%, 40%, 60%, 80% { -o-transform: translateX(10px); } } @keyframes shake { 0%, 100% { transform: translateX(0); } 10%, 30%, 50%, 70%, 90% { transform: translateX(-10px); } 20%, 40%, 60%, 80% { transform: translateX(10px); } } .shake { -webkit-animation-name: shake; -moz-animation-name: shake; -o-animation-name: shake; animation-name: shake; } @-webkit-keyframes bounce { 0%, 20%, 50%, 80%, 100% { -webkit-transform: translateY(0); } 40% { -webkit-transform: translateY(-30px); } 60% { -webkit-transform: translateY(-15px); } } @-moz-keyframes 
bounce { 0%, 20%, 50%, 80%, 100% { -moz-transform: translateY(0); } 40% { -moz-transform: translateY(-30px); } 60% { -moz-transform: translateY(-15px); } } @-o-keyframes bounce { 0%, 20%, 50%, 80%, 100% { -o-transform: translateY(0); } 40% { -o-transform: translateY(-30px); } 60% { -o-transform: translateY(-15px); } } @keyframes bounce { 0%, 20%, 50%, 80%, 100% { transform: translateY(0); } 40% { transform: translateY(-30px); } 60% { transform: translateY(-15px); } } .bounce { -webkit-animation-name: bounce; -moz-animation-name: bounce; -o-animation-name: bounce; animation-name: bounce; } @-webkit-keyframes tada { 0% { -webkit-transform: scale(1); } 10%, 20% { -webkit-transform: scale(0.9) rotate(-3deg); } 30%, 50%, 70%, 90% { -webkit-transform: scale(1.1) rotate(3deg); } 40%, 60%, 80% { -webkit-transform: scale(1.1) rotate(-3deg); } 100% { -webkit-transform: scale(1) rotate(0); } } @-moz-keyframes tada { 0% { -moz-transform: scale(1); } 10%, 20% { -moz-transform: scale(0.9) rotate(-3deg); } 30%, 50%, 70%, 90% { -moz-transform: scale(1.1) rotate(3deg); } 40%, 60%, 80% { -moz-transform: scale(1.1) rotate(-3deg); } 100% { -moz-transform: scale(1) rotate(0); } } @-o-keyframes tada { 0% { -o-transform: scale(1); } 10%, 20% { -o-transform: scale(0.9) rotate(-3deg); } 30%, 50%, 70%, 90% { -o-transform: scale(1.1) rotate(3deg); } 40%, 60%, 80% { -o-transform: scale(1.1) rotate(-3deg); } 100% { -o-transform: scale(1) rotate(0); } } @keyframes tada { 0% { transform: scale(1); } 10%, 20% { transform: scale(0.9) rotate(-3deg); } 30%, 50%, 70%, 90% { transform: scale(1.1) rotate(3deg); } 40%, 60%, 80% { transform: scale(1.1) rotate(-3deg); } 100% { transform: scale(1) rotate(0); } } .tada { -webkit-animation-name: tada; -moz-animation-name: tada; -o-animation-name: tada; animation-name: tada; } @-webkit-keyframes swing { 20%, 40%, 60%, 80%, 100% { -webkit-transform-origin: top center; } 20% { -webkit-transform: rotate(15deg); } 40% { -webkit-transform: rotate(-10deg); } 60% { 
-webkit-transform: rotate(5deg); } 80% { -webkit-transform: rotate(-5deg); } 100% { -webkit-transform: rotate(0deg); } } @-moz-keyframes swing { 20% { -moz-transform: rotate(15deg); } 40% { -moz-transform: rotate(-10deg); } 60% { -moz-transform: rotate(5deg); } 80% { -moz-transform: rotate(-5deg); } 100% { -moz-transform: rotate(0deg); } } @-o-keyframes swing { 20% { -o-transform: rotate(15deg); } 40% { -o-transform: rotate(-10deg); } 60% { -o-transform: rotate(5deg); } 80% { -o-transform: rotate(-5deg); } 100% { -o-transform: rotate(0deg); } } @keyframes swing { 20% { transform: rotate(15deg); } 40% { transform: rotate(-10deg); } 60% { transform: rotate(5deg); } 80% { transform: rotate(-5deg); } 100% { transform: rotate(0deg); } } .swing { -webkit-transform-origin: top center; -moz-transform-origin: top center; -o-transform-origin: top center; transform-origin: top center; -webkit-animation-name: swing; -moz-animation-name: swing; -o-animation-name: swing; animation-name: swing; } /* originally authored by Nick Pettit - https://github.com/nickpettit/glide */ @-webkit-keyframes wobble { 0% { -webkit-transform: translateX(0%); } 15% { -webkit-transform: translateX(-25%) rotate(-5deg); } 30% { -webkit-transform: translateX(20%) rotate(3deg); } 45% { -webkit-transform: translateX(-15%) rotate(-3deg); } 60% { -webkit-transform: translateX(10%) rotate(2deg); } 75% { -webkit-transform: translateX(-5%) rotate(-1deg); } 100% { -webkit-transform: translateX(0%); } } @-moz-keyframes wobble { 0% { -moz-transform: translateX(0%); } 15% { -moz-transform: translateX(-25%) rotate(-5deg); } 30% { -moz-transform: translateX(20%) rotate(3deg); } 45% { -moz-transform: translateX(-15%) rotate(-3deg); } 60% { -moz-transform: translateX(10%) rotate(2deg); } 75% { -moz-transform: translateX(-5%) rotate(-1deg); } 100% { -moz-transform: translateX(0%); } } @-o-keyframes wobble { 0% { -o-transform: translateX(0%); } 15% { -o-transform: translateX(-25%) rotate(-5deg); } 30% { -o-transform: 
translateX(20%) rotate(3deg); } 45% { -o-transform: translateX(-15%) rotate(-3deg); } 60% { -o-transform: translateX(10%) rotate(2deg); } 75% { -o-transform: translateX(-5%) rotate(-1deg); } 100% { -o-transform: translateX(0%); } } @keyframes wobble { 0% { transform: translateX(0%); } 15% { transform: translateX(-25%) rotate(-5deg); } 30% { transform: translateX(20%) rotate(3deg); } 45% { transform: translateX(-15%) rotate(-3deg); } 60% { transform: translateX(10%) rotate(2deg); } 75% { transform: translateX(-5%) rotate(-1deg); } 100% { transform: translateX(0%); } } .wobble { -webkit-animation-name: wobble; -moz-animation-name: wobble; -o-animation-name: wobble; animation-name: wobble; } /* originally authored by Nick Pettit - https://github.com/nickpettit/glide */ @-webkit-keyframes pulse { 0% { -webkit-transform: scale(1); } 50% { -webkit-transform: scale(1.1); } 100% { -webkit-transform: scale(1); } } @-moz-keyframes pulse { 0% { -moz-transform: scale(1); } 50% { -moz-transform: scale(1.1); } 100% { -moz-transform: scale(1); } } @-o-keyframes pulse { 0% { -o-transform: scale(1); } 50% { -o-transform: scale(1.1); } 100% { -o-transform: scale(1); } } @keyframes pulse { 0% { transform: scale(1); } 50% { transform: scale(1.1); } 100% { transform: scale(1); } } .pulse { -webkit-animation-name: pulse; -moz-animation-name: pulse; -o-animation-name: pulse; animation-name: pulse; } @-webkit-keyframes flip { 0% { -webkit-transform: perspective(400px) rotateY(0); -webkit-animation-timing-function: ease-out; } 40% { -webkit-transform: perspective(400px) translateZ(150px) rotateY(170deg); -webkit-animation-timing-function: ease-out; } 50% { -webkit-transform: perspective(400px) translateZ(150px) rotateY(190deg) scale(1); -webkit-animation-timing-function: ease-in; } 80% { -webkit-transform: perspective(400px) rotateY(360deg) scale(0.95); -webkit-animation-timing-function: ease-in; } 100% { -webkit-transform: perspective(400px) scale(1); -webkit-animation-timing-function: 
ease-in; } } @-moz-keyframes flip { 0% { -moz-transform: perspective(400px) rotateY(0); -moz-animation-timing-function: ease-out; } 40% { -moz-transform: perspective(400px) translateZ(150px) rotateY(170deg); -moz-animation-timing-function: ease-out; } 50% { -moz-transform: perspective(400px) translateZ(150px) rotateY(190deg) scale(1); -moz-animation-timing-function: ease-in; } 80% { -moz-transform: perspective(400px) rotateY(360deg) scale(0.95); -moz-animation-timing-function: ease-in; } 100% { -moz-transform: perspective(400px) scale(1); -moz-animation-timing-function: ease-in; } } @-o-keyframes flip { 0% { -o-transform: perspective(400px) rotateY(0); -o-animation-timing-function: ease-out; } 40% { -o-transform: perspective(400px) translateZ(150px) rotateY(170deg); -o-animation-timing-function: ease-out; } 50% { -o-transform: perspective(400px) translateZ(150px) rotateY(190deg) scale(1); -o-animation-timing-function: ease-in; } 80% { -o-transform: perspective(400px) rotateY(360deg) scale(0.95); -o-animation-timing-function: ease-in; } 100% { -o-transform: perspective(400px) scale(1); -o-animation-timing-function: ease-in; } } @keyframes flip { 0% { transform: perspective(400px) rotateY(0); animation-timing-function: ease-out; } 40% { transform: perspective(400px) translateZ(150px) rotateY(170deg); animation-timing-function: ease-out; } 50% { transform: perspective(400px) translateZ(150px) rotateY(190deg) scale(1); animation-timing-function: ease-in; } 80% { transform: perspective(400px) rotateY(360deg) scale(0.95); animation-timing-function: ease-in; } 100% { transform: perspective(400px) scale(1); animation-timing-function: ease-in; } } .flip { -webkit-backface-visibility: visible !important; -webkit-animation-name: flip; -moz-backface-visibility: visible !important; -moz-animation-name: flip; -o-backface-visibility: visible !important; -o-animation-name: flip; backface-visibility: visible !important; animation-name: flip; } @-webkit-keyframes flipInX { 0% { 
-webkit-transform: perspective(400px) rotateX(90deg); opacity: 0; } 40% { -webkit-transform: perspective(400px) rotateX(-10deg); } 70% { -webkit-transform: perspective(400px) rotateX(10deg); } 100% { -webkit-transform: perspective(400px) rotateX(0deg); opacity: 1; } } @-moz-keyframes flipInX { 0% { -moz-transform: perspective(400px) rotateX(90deg); opacity: 0; } 40% { -moz-transform: perspective(400px) rotateX(-10deg); } 70% { -moz-transform: perspective(400px) rotateX(10deg); } 100% { -moz-transform: perspective(400px) rotateX(0deg); opacity: 1; } } @-o-keyframes flipInX { 0% { -o-transform: perspective(400px) rotateX(90deg); opacity: 0; } 40% { -o-transform: perspective(400px) rotateX(-10deg); } 70% { -o-transform: perspective(400px) rotateX(10deg); } 100% { -o-transform: perspective(400px) rotateX(0deg); opacity: 1; } } @keyframes flipInX { 0% { transform: perspective(400px) rotateX(90deg); opacity: 0; } 40% { transform: perspective(400px) rotateX(-10deg); } 70% { transform: perspective(400px) rotateX(10deg); } 100% { transform: perspective(400px) rotateX(0deg); opacity: 1; } } .flipInX { -webkit-backface-visibility: visible !important; -webkit-animation-name: flipInX; -moz-backface-visibility: visible !important; -moz-animation-name: flipInX; -o-backface-visibility: visible !important; -o-animation-name: flipInX; backface-visibility: visible !important; animation-name: flipInX; } @-webkit-keyframes flipOutX { 0% { -webkit-transform: perspective(400px) rotateX(0deg); opacity: 1; } 100% { -webkit-transform: perspective(400px) rotateX(90deg); opacity: 0; } } @-moz-keyframes flipOutX { 0% { -moz-transform: perspective(400px) rotateX(0deg); opacity: 1; } 100% { -moz-transform: perspective(400px) rotateX(90deg); opacity: 0; } } @-o-keyframes flipOutX { 0% { -o-transform: perspective(400px) rotateX(0deg); opacity: 1; } 100% { -o-transform: perspective(400px) rotateX(90deg); opacity: 0; } } @keyframes flipOutX { 0% { transform: perspective(400px) rotateX(0deg); 
opacity: 1; } 100% { transform: perspective(400px) rotateX(90deg); opacity: 0; } } .flipOutX { -webkit-animation-name: flipOutX; -webkit-backface-visibility: visible !important; -moz-animation-name: flipOutX; -moz-backface-visibility: visible !important; -o-animation-name: flipOutX; -o-backface-visibility: visible !important; animation-name: flipOutX; backface-visibility: visible !important; } @-webkit-keyframes flipInY { 0% { -webkit-transform: perspective(400px) rotateY(90deg); opacity: 0; } 40% { -webkit-transform: perspective(400px) rotateY(-10deg); } 70% { -webkit-transform: perspective(400px) rotateY(10deg); } 100% { -webkit-transform: perspective(400px) rotateY(0deg); opacity: 1; } } @-moz-keyframes flipInY { 0% { -moz-transform: perspective(400px) rotateY(90deg); opacity: 0; } 40% { -moz-transform: perspective(400px) rotateY(-10deg); } 70% { -moz-transform: perspective(400px) rotateY(10deg); } 100% { -moz-transform: perspective(400px) rotateY(0deg); opacity: 1; } } @-o-keyframes flipInY { 0% { -o-transform: perspective(400px) rotateY(90deg); opacity: 0; } 40% { -o-transform: perspective(400px) rotateY(-10deg); } 70% { -o-transform: perspective(400px) rotateY(10deg); } 100% { -o-transform: perspective(400px) rotateY(0deg); opacity: 1; } } @keyframes flipInY { 0% { transform: perspective(400px) rotateY(90deg); opacity: 0; } 40% { transform: perspective(400px) rotateY(-10deg); } 70% { transform: perspective(400px) rotateY(10deg); } 100% { transform: perspective(400px) rotateY(0deg); opacity: 1; } } .flipInY { -webkit-backface-visibility: visible !important; -webkit-animation-name: flipInY; -moz-backface-visibility: visible !important; -moz-animation-name: flipInY; -o-backface-visibility: visible !important; -o-animation-name: flipInY; backface-visibility: visible !important; animation-name: flipInY; } @-webkit-keyframes flipOutY { 0% { -webkit-transform: perspective(400px) rotateY(0deg); opacity: 1; } 100% { -webkit-transform: perspective(400px) 
rotateY(90deg); opacity: 0; } } @-moz-keyframes flipOutY { 0% { -moz-transform: perspective(400px) rotateY(0deg); opacity: 1; } 100% { -moz-transform: perspective(400px) rotateY(90deg); opacity: 0; } } @-o-keyframes flipOutY { 0% { -o-transform: perspective(400px) rotateY(0deg); opacity: 1; } 100% { -o-transform: perspective(400px) rotateY(90deg); opacity: 0; } } @keyframes flipOutY { 0% { transform: perspective(400px) rotateY(0deg); opacity: 1; } 100% { transform: perspective(400px) rotateY(90deg); opacity: 0; } } .flipOutY { -webkit-backface-visibility: visible !important; -webkit-animation-name: flipOutY; -moz-backface-visibility: visible !important; -moz-animation-name: flipOutY; -o-backface-visibility: visible !important; -o-animation-name: flipOutY; backface-visibility: visible !important; animation-name: flipOutY; } @-webkit-keyframes fadeIn { 0% { opacity: 0; } 100% { opacity: 1; } } @-moz-keyframes fadeIn { 0% { opacity: 0; } 100% { opacity: 1; } } @-o-keyframes fadeIn { 0% { opacity: 0; } 100% { opacity: 1; } } @keyframes fadeIn { 0% { opacity: 0; } 100% { opacity: 1; } } .fadeIn { -webkit-animation-name: fadeIn; -moz-animation-name: fadeIn; -o-animation-name: fadeIn; animation-name: fadeIn; } @-webkit-keyframes fadeInUp { 0% { opacity: 0; -webkit-transform: translateY(20px); } 100% { opacity: 1; -webkit-transform: translateY(0); } } @-moz-keyframes fadeInUp { 0% { opacity: 0; -moz-transform: translateY(20px); } 100% { opacity: 1; -moz-transform: translateY(0); } } @-o-keyframes fadeInUp { 0% { opacity: 0; -o-transform: translateY(20px); } 100% { opacity: 1; -o-transform: translateY(0); } } @keyframes fadeInUp { 0% { opacity: 0; transform: translateY(20px); } 100% { opacity: 1; transform: translateY(0); } } .fadeInUp { -webkit-animation-name: fadeInUp; -moz-animation-name: fadeInUp; -o-animation-name: fadeInUp; animation-name: fadeInUp; } @-webkit-keyframes fadeInDown { 0% { opacity: 0; -webkit-transform: translateY(-20px); } 100% { opacity: 1; 
-webkit-transform: translateY(0); } } @-moz-keyframes fadeInDown { 0% { opacity: 0; -moz-transform: translateY(-20px); } 100% { opacity: 1; -moz-transform: translateY(0); } } @-o-keyframes fadeInDown { 0% { opacity: 0; -o-transform: translateY(-20px); } 100% { opacity: 1; -o-transform: translateY(0); } } @keyframes fadeInDown { 0% { opacity: 0; transform: translateY(-20px); } 100% { opacity: 1; transform: translateY(0); } } .fadeInDown { -webkit-animation-name: fadeInDown; -moz-animation-name: fadeInDown; -o-animation-name: fadeInDown; animation-name: fadeInDown; } @-webkit-keyframes fadeInLeft { 0% { opacity: 0; -webkit-transform: translateX(-20px); } 100% { opacity: 1; -webkit-transform: translateX(0); } } @-moz-keyframes fadeInLeft { 0% { opacity: 0; -moz-transform: translateX(-20px); } 100% { opacity: 1; -moz-transform: translateX(0); } } @-o-keyframes fadeInLeft { 0% { opacity: 0; -o-transform: translateX(-20px); } 100% { opacity: 1; -o-transform: translateX(0); } } @keyframes fadeInLeft { 0% { opacity: 0; transform: translateX(-20px); } 100% { opacity: 1; transform: translateX(0); } } .fadeInLeft { -webkit-animation-name: fadeInLeft; -moz-animation-name: fadeInLeft; -o-animation-name: fadeInLeft; animation-name: fadeInLeft; } @-webkit-keyframes fadeInRight { 0% { opacity: 0; -webkit-transform: translateX(20px); } 100% { opacity: 1; -webkit-transform: translateX(0); } } @-moz-keyframes fadeInRight { 0% { opacity: 0; -moz-transform: translateX(20px); } 100% { opacity: 1; -moz-transform: translateX(0); } } @-o-keyframes fadeInRight { 0% { opacity: 0; -o-transform: translateX(20px); } 100% { opacity: 1; -o-transform: translateX(0); } } @keyframes fadeInRight { 0% { opacity: 0; transform: translateX(20px); } 100% { opacity: 1; transform: translateX(0); } } .fadeInRight { -webkit-animation-name: fadeInRight; -moz-animation-name: fadeInRight; -o-animation-name: fadeInRight; animation-name: fadeInRight; } @-webkit-keyframes fadeInUpBig { 0% { opacity: 0; 
-webkit-transform: translateY(2000px); } 100% { opacity: 1; -webkit-transform: translateY(0); } } @-moz-keyframes fadeInUpBig { 0% { opacity: 0; -moz-transform: translateY(2000px); } 100% { opacity: 1; -moz-transform: translateY(0); } } @-o-keyframes fadeInUpBig { 0% { opacity: 0; -o-transform: translateY(2000px); } 100% { opacity: 1; -o-transform: translateY(0); } } @keyframes fadeInUpBig { 0% { opacity: 0; transform: translateY(2000px); } 100% { opacity: 1; transform: translateY(0); } } .fadeInUpBig { -webkit-animation-name: fadeInUpBig; -moz-animation-name: fadeInUpBig; -o-animation-name: fadeInUpBig; animation-name: fadeInUpBig; } @-webkit-keyframes fadeInDownBig { 0% { opacity: 0; -webkit-transform: translateY(-2000px); } 100% { opacity: 1; -webkit-transform: translateY(0); } } @-moz-keyframes fadeInDownBig { 0% { opacity: 0; -moz-transform: translateY(-2000px); } 100% { opacity: 1; -moz-transform: translateY(0); } } @-o-keyframes fadeInDownBig { 0% { opacity: 0; -o-transform: translateY(-2000px); } 100% { opacity: 1; -o-transform: translateY(0); } } @keyframes fadeInDownBig { 0% { opacity: 0; transform: translateY(-2000px); } 100% { opacity: 1; transform: translateY(0); } } .fadeInDownBig { -webkit-animation-name: fadeInDownBig; -moz-animation-name: fadeInDownBig; -o-animation-name: fadeInDownBig; animation-name: fadeInDownBig; } @-webkit-keyframes fadeInLeftBig { 0% { opacity: 0; -webkit-transform: translateX(-2000px); } 100% { opacity: 1; -webkit-transform: translateX(0); } } @-moz-keyframes fadeInLeftBig { 0% { opacity: 0; -moz-transform: translateX(-2000px); } 100% { opacity: 1; -moz-transform: translateX(0); } } @-o-keyframes fadeInLeftBig { 0% { opacity: 0; -o-transform: translateX(-2000px); } 100% { opacity: 1; -o-transform: translateX(0); } } @keyframes fadeInLeftBig { 0% { opacity: 0; transform: translateX(-2000px); } 100% { opacity: 1; transform: translateX(0); } } .fadeInLeftBig { -webkit-animation-name: fadeInLeftBig; -moz-animation-name: 
fadeInLeftBig; -o-animation-name: fadeInLeftBig; animation-name: fadeInLeftBig; } @-webkit-keyframes fadeInRightBig { 0% { opacity: 0; -webkit-transform: translateX(2000px); } 100% { opacity: 1; -webkit-transform: translateX(0); } } @-moz-keyframes fadeInRightBig { 0% { opacity: 0; -moz-transform: translateX(2000px); } 100% { opacity: 1; -moz-transform: translateX(0); } } @-o-keyframes fadeInRightBig { 0% { opacity: 0; -o-transform: translateX(2000px); } 100% { opacity: 1; -o-transform: translateX(0); } } @keyframes fadeInRightBig { 0% { opacity: 0; transform: translateX(2000px); } 100% { opacity: 1; transform: translateX(0); } } .fadeInRightBig { -webkit-animation-name: fadeInRightBig; -moz-animation-name: fadeInRightBig; -o-animation-name: fadeInRightBig; animation-name: fadeInRightBig; } @-webkit-keyframes fadeOut { 0% { opacity: 1; } 100% { opacity: 0; } } @-moz-keyframes fadeOut { 0% { opacity: 1; } 100% { opacity: 0; } } @-o-keyframes fadeOut { 0% { opacity: 1; } 100% { opacity: 0; } } @keyframes fadeOut { 0% { opacity: 1; } 100% { opacity: 0; } } .fadeOut { -webkit-animation-name: fadeOut; -moz-animation-name: fadeOut; -o-animation-name: fadeOut; animation-name: fadeOut; } @-webkit-keyframes fadeOutUp { 0% { opacity: 1; -webkit-transform: translateY(0); } 100% { opacity: 0; -webkit-transform: translateY(-20px); } } @-moz-keyframes fadeOutUp { 0% { opacity: 1; -moz-transform: translateY(0); } 100% { opacity: 0; -moz-transform: translateY(-20px); } } @-o-keyframes fadeOutUp { 0% { opacity: 1; -o-transform: translateY(0); } 100% { opacity: 0; -o-transform: translateY(-20px); } } @keyframes fadeOutUp { 0% { opacity: 1; transform: translateY(0); } 100% { opacity: 0; transform: translateY(-20px); } } .fadeOutUp { -webkit-animation-name: fadeOutUp; -moz-animation-name: fadeOutUp; -o-animation-name: fadeOutUp; animation-name: fadeOutUp; } @-webkit-keyframes fadeOutDown { 0% { opacity: 1; -webkit-transform: translateY(0); } 100% { opacity: 0; -webkit-transform: 
translateY(20px); } } @-moz-keyframes fadeOutDown { 0% { opacity: 1; -moz-transform: translateY(0); } 100% { opacity: 0; -moz-transform: translateY(20px); } } @-o-keyframes fadeOutDown { 0% { opacity: 1; -o-transform: translateY(0); } 100% { opacity: 0; -o-transform: translateY(20px); } } @keyframes fadeOutDown { 0% { opacity: 1; transform: translateY(0); } 100% { opacity: 0; transform: translateY(20px); } } .fadeOutDown { -webkit-animation-name: fadeOutDown; -moz-animation-name: fadeOutDown; -o-animation-name: fadeOutDown; animation-name: fadeOutDown; } @-webkit-keyframes fadeOutLeft { 0% { opacity: 1; -webkit-transform: translateX(0); } 100% { opacity: 0; -webkit-transform: translateX(-20px); } } @-moz-keyframes fadeOutLeft { 0% { opacity: 1; -moz-transform: translateX(0); } 100% { opacity: 0; -moz-transform: translateX(-20px); } } @-o-keyframes fadeOutLeft { 0% { opacity: 1; -o-transform: translateX(0); } 100% { opacity: 0; -o-transform: translateX(-20px); } } @keyframes fadeOutLeft { 0% { opacity: 1; transform: translateX(0); } 100% { opacity: 0; transform: translateX(-20px); } } .fadeOutLeft { -webkit-animation-name: fadeOutLeft; -moz-animation-name: fadeOutLeft; -o-animation-name: fadeOutLeft; animation-name: fadeOutLeft; } @-webkit-keyframes fadeOutRight { 0% { opacity: 1; -webkit-transform: translateX(0); } 100% { opacity: 0; -webkit-transform: translateX(20px); } } @-moz-keyframes fadeOutRight { 0% { opacity: 1; -moz-transform: translateX(0); } 100% { opacity: 0; -moz-transform: translateX(20px); } } @-o-keyframes fadeOutRight { 0% { opacity: 1; -o-transform: translateX(0); } 100% { opacity: 0; -o-transform: translateX(20px); } } @keyframes fadeOutRight { 0% { opacity: 1; transform: translateX(0); } 100% { opacity: 0; transform: translateX(20px); } } .fadeOutRight { -webkit-animation-name: fadeOutRight; -moz-animation-name: fadeOutRight; -o-animation-name: fadeOutRight; animation-name: fadeOutRight; } @-webkit-keyframes fadeOutUpBig { 0% { opacity: 1; 
-webkit-transform: translateY(0); } 100% { opacity: 0; -webkit-transform: translateY(-2000px); } } @-moz-keyframes fadeOutUpBig { 0% { opacity: 1; -moz-transform: translateY(0); } 100% { opacity: 0; -moz-transform: translateY(-2000px); } } @-o-keyframes fadeOutUpBig { 0% { opacity: 1; -o-transform: translateY(0); } 100% { opacity: 0; -o-transform: translateY(-2000px); } } @keyframes fadeOutUpBig { 0% { opacity: 1; transform: translateY(0); } 100% { opacity: 0; transform: translateY(-2000px); } } .fadeOutUpBig { -webkit-animation-name: fadeOutUpBig; -moz-animation-name: fadeOutUpBig; -o-animation-name: fadeOutUpBig; animation-name: fadeOutUpBig; } @-webkit-keyframes fadeOutDownBig { 0% { opacity: 1; -webkit-transform: translateY(0); } 100% { opacity: 0; -webkit-transform: translateY(2000px); } } @-moz-keyframes fadeOutDownBig { 0% { opacity: 1; -moz-transform: translateY(0); } 100% { opacity: 0; -moz-transform: translateY(2000px); } } @-o-keyframes fadeOutDownBig { 0% { opacity: 1; -o-transform: translateY(0); } 100% { opacity: 0; -o-transform: translateY(2000px); } } @keyframes fadeOutDownBig { 0% { opacity: 1; transform: translateY(0); } 100% { opacity: 0; transform: translateY(2000px); } } .fadeOutDownBig { -webkit-animation-name: fadeOutDownBig; -moz-animation-name: fadeOutDownBig; -o-animation-name: fadeOutDownBig; animation-name: fadeOutDownBig; } @-webkit-keyframes fadeOutLeftBig { 0% { opacity: 1; -webkit-transform: translateX(0); } 100% { opacity: 0; -webkit-transform: translateX(-2000px); } } @-moz-keyframes fadeOutLeftBig { 0% { opacity: 1; -moz-transform: translateX(0); } 100% { opacity: 0; -moz-transform: translateX(-2000px); } } @-o-keyframes fadeOutLeftBig { 0% { opacity: 1; -o-transform: translateX(0); } 100% { opacity: 0; -o-transform: translateX(-2000px); } } @keyframes fadeOutLeftBig { 0% { opacity: 1; transform: translateX(0); } 100% { opacity: 0; transform: translateX(-2000px); } } .fadeOutLeftBig { -webkit-animation-name: fadeOutLeftBig; 
-moz-animation-name: fadeOutLeftBig; -o-animation-name: fadeOutLeftBig; animation-name: fadeOutLeftBig; } @-webkit-keyframes fadeOutRightBig { 0% { opacity: 1; -webkit-transform: translateX(0); } 100% { opacity: 0; -webkit-transform: translateX(2000px); } } @-moz-keyframes fadeOutRightBig { 0% { opacity: 1; -moz-transform: translateX(0); } 100% { opacity: 0; -moz-transform: translateX(2000px); } } @-o-keyframes fadeOutRightBig { 0% { opacity: 1; -o-transform: translateX(0); } 100% { opacity: 0; -o-transform: translateX(2000px); } } @keyframes fadeOutRightBig { 0% { opacity: 1; transform: translateX(0); } 100% { opacity: 0; transform: translateX(2000px); } } .fadeOutRightBig { -webkit-animation-name: fadeOutRightBig; -moz-animation-name: fadeOutRightBig; -o-animation-name: fadeOutRightBig; animation-name: fadeOutRightBig; } @-webkit-keyframes bounceIn { 0% { opacity: 0; -webkit-transform: scale(0.3); } 50% { opacity: 1; -webkit-transform: scale(1.05); } 70% { -webkit-transform: scale(0.9); } 100% { -webkit-transform: scale(1); } } @-moz-keyframes bounceIn { 0% { opacity: 0; -moz-transform: scale(0.3); } 50% { opacity: 1; -moz-transform: scale(1.05); } 70% { -moz-transform: scale(0.9); } 100% { -moz-transform: scale(1); } } @-o-keyframes bounceIn { 0% { opacity: 0; -o-transform: scale(0.3); } 50% { opacity: 1; -o-transform: scale(1.05); } 70% { -o-transform: scale(0.9); } 100% { -o-transform: scale(1); } } @keyframes bounceIn { 0% { opacity: 0; transform: scale(0.3); } 50% { opacity: 1; transform: scale(1.05); } 70% { transform: scale(0.9); } 100% { transform: scale(1); } } .bounceIn { -webkit-animation-name: bounceIn; -moz-animation-name: bounceIn; -o-animation-name: bounceIn; animation-name: bounceIn; } @-webkit-keyframes bounceInUp { 0% { opacity: 0; -webkit-transform: translateY(2000px); } 60% { opacity: 1; -webkit-transform: translateY(-30px); } 80% { -webkit-transform: translateY(10px); } 100% { -webkit-transform: translateY(0); } } @-moz-keyframes bounceInUp { 
0% { opacity: 0; -moz-transform: translateY(2000px); } 60% { opacity: 1; -moz-transform: translateY(-30px); } 80% { -moz-transform: translateY(10px); } 100% { -moz-transform: translateY(0); } } @-o-keyframes bounceInUp { 0% { opacity: 0; -o-transform: translateY(2000px); } 60% { opacity: 1; -o-transform: translateY(-30px); } 80% { -o-transform: translateY(10px); } 100% { -o-transform: translateY(0); } } @keyframes bounceInUp { 0% { opacity: 0; transform: translateY(2000px); } 60% { opacity: 1; transform: translateY(-30px); } 80% { transform: translateY(10px); } 100% { transform: translateY(0); } } .bounceInUp { -webkit-animation-name: bounceInUp; -moz-animation-name: bounceInUp; -o-animation-name: bounceInUp; animation-name: bounceInUp; } @-webkit-keyframes bounceInDown { 0% { opacity: 0; -webkit-transform: translateY(-2000px); } 60% { opacity: 1; -webkit-transform: translateY(30px); } 80% { -webkit-transform: translateY(-10px); } 100% { -webkit-transform: translateY(0); } } @-moz-keyframes bounceInDown { 0% { opacity: 0; -moz-transform: translateY(-2000px); } 60% { opacity: 1; -moz-transform: translateY(30px); } 80% { -moz-transform: translateY(-10px); } 100% { -moz-transform: translateY(0); } } @-o-keyframes bounceInDown { 0% { opacity: 0; -o-transform: translateY(-2000px); } 60% { opacity: 1; -o-transform: translateY(30px); } 80% { -o-transform: translateY(-10px); } 100% { -o-transform: translateY(0); } } @keyframes bounceInDown { 0% { opacity: 0; transform: translateY(-2000px); } 60% { opacity: 1; transform: translateY(30px); } 80% { transform: translateY(-10px); } 100% { transform: translateY(0); } } .bounceInDown { -webkit-animation-name: bounceInDown; -moz-animation-name: bounceInDown; -o-animation-name: bounceInDown; animation-name: bounceInDown; } @-webkit-keyframes bounceInLeft { 0% { opacity: 0; -webkit-transform: translateX(-2000px); } 60% { opacity: 1; -webkit-transform: translateX(30px); } 80% { -webkit-transform: translateX(-10px); } 100% { 
-webkit-transform: translateX(0); } } @-moz-keyframes bounceInLeft { 0% { opacity: 0; -moz-transform: translateX(-2000px); } 60% { opacity: 1; -moz-transform: translateX(30px); } 80% { -moz-transform: translateX(-10px); } 100% { -moz-transform: translateX(0); } } @-o-keyframes bounceInLeft { 0% { opacity: 0; -o-transform: translateX(-2000px); } 60% { opacity: 1; -o-transform: translateX(30px); } 80% { -o-transform: translateX(-10px); } 100% { -o-transform: translateX(0); } } @keyframes bounceInLeft { 0% { opacity: 0; transform: translateX(-2000px); } 60% { opacity: 1; transform: translateX(30px); } 80% { transform: translateX(-10px); } 100% { transform: translateX(0); } } .bounceInLeft { -webkit-animation-name: bounceInLeft; -moz-animation-name: bounceInLeft; -o-animation-name: bounceInLeft; animation-name: bounceInLeft; } @-webkit-keyframes bounceInRight { 0% { opacity: 0; -webkit-transform: translateX(2000px); } 60% { opacity: 1; -webkit-transform: translateX(-30px); } 80% { -webkit-transform: translateX(10px); } 100% { -webkit-transform: translateX(0); } } @-moz-keyframes bounceInRight { 0% { opacity: 0; -moz-transform: translateX(2000px); } 60% { opacity: 1; -moz-transform: translateX(-30px); } 80% { -moz-transform: translateX(10px); } 100% { -moz-transform: translateX(0); } } @-o-keyframes bounceInRight { 0% { opacity: 0; -o-transform: translateX(2000px); } 60% { opacity: 1; -o-transform: translateX(-30px); } 80% { -o-transform: translateX(10px); } 100% { -o-transform: translateX(0); } } @keyframes bounceInRight { 0% { opacity: 0; transform: translateX(2000px); } 60% { opacity: 1; transform: translateX(-30px); } 80% { transform: translateX(10px); } 100% { transform: translateX(0); } } .bounceInRight { -webkit-animation-name: bounceInRight; -moz-animation-name: bounceInRight; -o-animation-name: bounceInRight; animation-name: bounceInRight; } @-webkit-keyframes bounceOut { 0% { -webkit-transform: scale(1); } 25% { -webkit-transform: scale(0.95); } 50% { opacity: 
1; -webkit-transform: scale(1.1); } 100% { opacity: 0; -webkit-transform: scale(0.3); } } @-moz-keyframes bounceOut { 0% { -moz-transform: scale(1); } 25% { -moz-transform: scale(0.95); } 50% { opacity: 1; -moz-transform: scale(1.1); } 100% { opacity: 0; -moz-transform: scale(0.3); } } @-o-keyframes bounceOut { 0% { -o-transform: scale(1); } 25% { -o-transform: scale(0.95); } 50% { opacity: 1; -o-transform: scale(1.1); } 100% { opacity: 0; -o-transform: scale(0.3); } } @keyframes bounceOut { 0% { transform: scale(1); } 25% { transform: scale(0.95); } 50% { opacity: 1; transform: scale(1.1); } 100% { opacity: 0; transform: scale(0.3); } } .bounceOut { -webkit-animation-name: bounceOut; -moz-animation-name: bounceOut; -o-animation-name: bounceOut; animation-name: bounceOut; } @-webkit-keyframes bounceOutUp { 0% { -webkit-transform: translateY(0); } 20% { opacity: 1; -webkit-transform: translateY(20px); } 100% { opacity: 0; -webkit-transform: translateY(-2000px); } } @-moz-keyframes bounceOutUp { 0% { -moz-transform: translateY(0); } 20% { opacity: 1; -moz-transform: translateY(20px); } 100% { opacity: 0; -moz-transform: translateY(-2000px); } } @-o-keyframes bounceOutUp { 0% { -o-transform: translateY(0); } 20% { opacity: 1; -o-transform: translateY(20px); } 100% { opacity: 0; -o-transform: translateY(-2000px); } } @keyframes bounceOutUp { 0% { transform: translateY(0); } 20% { opacity: 1; transform: translateY(20px); } 100% { opacity: 0; transform: translateY(-2000px); } } .bounceOutUp { -webkit-animation-name: bounceOutUp; -moz-animation-name: bounceOutUp; -o-animation-name: bounceOutUp; animation-name: bounceOutUp; } @-webkit-keyframes bounceOutDown { 0% { -webkit-transform: translateY(0); } 20% { opacity: 1; -webkit-transform: translateY(-20px); } 100% { opacity: 0; -webkit-transform: translateY(2000px); } } @-moz-keyframes bounceOutDown { 0% { -moz-transform: translateY(0); } 20% { opacity: 1; -moz-transform: translateY(-20px); } 100% { opacity: 0; 
-moz-transform: translateY(2000px); } } @-o-keyframes bounceOutDown { 0% { -o-transform: translateY(0); } 20% { opacity: 1; -o-transform: translateY(-20px); } 100% { opacity: 0; -o-transform: translateY(2000px); } } @keyframes bounceOutDown { 0% { transform: translateY(0); } 20% { opacity: 1; transform: translateY(-20px); } 100% { opacity: 0; transform: translateY(2000px); } } .bounceOutDown { -webkit-animation-name: bounceOutDown; -moz-animation-name: bounceOutDown; -o-animation-name: bounceOutDown; animation-name: bounceOutDown; } @-webkit-keyframes bounceOutLeft { 0% { -webkit-transform: translateX(0); } 20% { opacity: 1; -webkit-transform: translateX(20px); } 100% { opacity: 0; -webkit-transform: translateX(-2000px); } } @-moz-keyframes bounceOutLeft { 0% { -moz-transform: translateX(0); } 20% { opacity: 1; -moz-transform: translateX(20px); } 100% { opacity: 0; -moz-transform: translateX(-2000px); } } @-o-keyframes bounceOutLeft { 0% { -o-transform: translateX(0); } 20% { opacity: 1; -o-transform: translateX(20px); } 100% { opacity: 0; -o-transform: translateX(-2000px); } } @keyframes bounceOutLeft { 0% { transform: translateX(0); } 20% { opacity: 1; transform: translateX(20px); } 100% { opacity: 0; transform: translateX(-2000px); } } .bounceOutLeft { -webkit-animation-name: bounceOutLeft; -moz-animation-name: bounceOutLeft; -o-animation-name: bounceOutLeft; animation-name: bounceOutLeft; } @-webkit-keyframes bounceOutRight { 0% { -webkit-transform: translateX(0); } 20% { opacity: 1; -webkit-transform: translateX(-20px); } 100% { opacity: 0; -webkit-transform: translateX(2000px); } } @-moz-keyframes bounceOutRight { 0% { -moz-transform: translateX(0); } 20% { opacity: 1; -moz-transform: translateX(-20px); } 100% { opacity: 0; -moz-transform: translateX(2000px); } } @-o-keyframes bounceOutRight { 0% { -o-transform: translateX(0); } 20% { opacity: 1; -o-transform: translateX(-20px); } 100% { opacity: 0; -o-transform: translateX(2000px); } } @keyframes 
bounceOutRight { 0% { transform: translateX(0); } 20% { opacity: 1; transform: translateX(-20px); } 100% { opacity: 0; transform: translateX(2000px); } } .bounceOutRight { -webkit-animation-name: bounceOutRight; -moz-animation-name: bounceOutRight; -o-animation-name: bounceOutRight; animation-name: bounceOutRight; } @-webkit-keyframes rotateIn { 0% { -webkit-transform-origin: center center; -webkit-transform: rotate(-200deg); opacity: 0; } 100% { -webkit-transform-origin: center center; -webkit-transform: rotate(0); opacity: 1; } } @-moz-keyframes rotateIn { 0% { -moz-transform-origin: center center; -moz-transform: rotate(-200deg); opacity: 0; } 100% { -moz-transform-origin: center center; -moz-transform: rotate(0); opacity: 1; } } @-o-keyframes rotateIn { 0% { -o-transform-origin: center center; -o-transform: rotate(-200deg); opacity: 0; } 100% { -o-transform-origin: center center; -o-transform: rotate(0); opacity: 1; } } @keyframes rotateIn { 0% { transform-origin: center center; transform: rotate(-200deg); opacity: 0; } 100% { transform-origin: center center; transform: rotate(0); opacity: 1; } } .rotateIn { -webkit-animation-name: rotateIn; -moz-animation-name: rotateIn; -o-animation-name: rotateIn; animation-name: rotateIn; } @-webkit-keyframes rotateInUpLeft { 0% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(90deg); opacity: 0; } 100% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(0); opacity: 1; } } @-moz-keyframes rotateInUpLeft { 0% { -moz-transform-origin: left bottom; -moz-transform: rotate(90deg); opacity: 0; } 100% { -moz-transform-origin: left bottom; -moz-transform: rotate(0); opacity: 1; } } @-o-keyframes rotateInUpLeft { 0% { -o-transform-origin: left bottom; -o-transform: rotate(90deg); opacity: 0; } 100% { -o-transform-origin: left bottom; -o-transform: rotate(0); opacity: 1; } } @keyframes rotateInUpLeft { 0% { transform-origin: left bottom; transform: rotate(90deg); opacity: 0; } 100% { 
transform-origin: left bottom; transform: rotate(0); opacity: 1; } } .rotateInUpLeft { -webkit-animation-name: rotateInUpLeft; -moz-animation-name: rotateInUpLeft; -o-animation-name: rotateInUpLeft; animation-name: rotateInUpLeft; } @-webkit-keyframes rotateInDownLeft { 0% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(-90deg); opacity: 0; } 100% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(0); opacity: 1; } } @-moz-keyframes rotateInDownLeft { 0% { -moz-transform-origin: left bottom; -moz-transform: rotate(-90deg); opacity: 0; } 100% { -moz-transform-origin: left bottom; -moz-transform: rotate(0); opacity: 1; } } @-o-keyframes rotateInDownLeft { 0% { -o-transform-origin: left bottom; -o-transform: rotate(-90deg); opacity: 0; } 100% { -o-transform-origin: left bottom; -o-transform: rotate(0); opacity: 1; } } @keyframes rotateInDownLeft { 0% { transform-origin: left bottom; transform: rotate(-90deg); opacity: 0; } 100% { transform-origin: left bottom; transform: rotate(0); opacity: 1; } } .rotateInDownLeft { -webkit-animation-name: rotateInDownLeft; -moz-animation-name: rotateInDownLeft; -o-animation-name: rotateInDownLeft; animation-name: rotateInDownLeft; } @-webkit-keyframes rotateInUpRight { 0% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(-90deg); opacity: 0; } 100% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(0); opacity: 1; } } @-moz-keyframes rotateInUpRight { 0% { -moz-transform-origin: right bottom; -moz-transform: rotate(-90deg); opacity: 0; } 100% { -moz-transform-origin: right bottom; -moz-transform: rotate(0); opacity: 1; } } @-o-keyframes rotateInUpRight { 0% { -o-transform-origin: right bottom; -o-transform: rotate(-90deg); opacity: 0; } 100% { -o-transform-origin: right bottom; -o-transform: rotate(0); opacity: 1; } } @keyframes rotateInUpRight { 0% { transform-origin: right bottom; transform: rotate(-90deg); opacity: 0; } 100% { transform-origin: right bottom; 
transform: rotate(0); opacity: 1; } } .rotateInUpRight { -webkit-animation-name: rotateInUpRight; -moz-animation-name: rotateInUpRight; -o-animation-name: rotateInUpRight; animation-name: rotateInUpRight; } @-webkit-keyframes rotateInDownRight { 0% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(90deg); opacity: 0; } 100% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(0); opacity: 1; } } @-moz-keyframes rotateInDownRight { 0% { -moz-transform-origin: right bottom; -moz-transform: rotate(90deg); opacity: 0; } 100% { -moz-transform-origin: right bottom; -moz-transform: rotate(0); opacity: 1; } } @-o-keyframes rotateInDownRight { 0% { -o-transform-origin: right bottom; -o-transform: rotate(90deg); opacity: 0; } 100% { -o-transform-origin: right bottom; -o-transform: rotate(0); opacity: 1; } } @keyframes rotateInDownRight { 0% { transform-origin: right bottom; transform: rotate(90deg); opacity: 0; } 100% { transform-origin: right bottom; transform: rotate(0); opacity: 1; } } .rotateInDownRight { -webkit-animation-name: rotateInDownRight; -moz-animation-name: rotateInDownRight; -o-animation-name: rotateInDownRight; animation-name: rotateInDownRight; } @-webkit-keyframes rotateOut { 0% { -webkit-transform-origin: center center; -webkit-transform: rotate(0); opacity: 1; } 100% { -webkit-transform-origin: center center; -webkit-transform: rotate(200deg); opacity: 0; } } @-moz-keyframes rotateOut { 0% { -moz-transform-origin: center center; -moz-transform: rotate(0); opacity: 1; } 100% { -moz-transform-origin: center center; -moz-transform: rotate(200deg); opacity: 0; } } @-o-keyframes rotateOut { 0% { -o-transform-origin: center center; -o-transform: rotate(0); opacity: 1; } 100% { -o-transform-origin: center center; -o-transform: rotate(200deg); opacity: 0; } } @keyframes rotateOut { 0% { transform-origin: center center; transform: rotate(0); opacity: 1; } 100% { transform-origin: center center; transform: rotate(200deg); opacity: 
0; } } .rotateOut { -webkit-animation-name: rotateOut; -moz-animation-name: rotateOut; -o-animation-name: rotateOut; animation-name: rotateOut; } @-webkit-keyframes rotateOutUpLeft { 0% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(0); opacity: 1; } 100% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(-90deg); opacity: 0; } } @-moz-keyframes rotateOutUpLeft { 0% { -moz-transform-origin: left bottom; -moz-transform: rotate(0); opacity: 1; } 100% { -moz-transform-origin: left bottom; -moz-transform: rotate(-90deg); opacity: 0; } } @-o-keyframes rotateOutUpLeft { 0% { -o-transform-origin: left bottom; -o-transform: rotate(0); opacity: 1; } 100% { -o-transform-origin: left bottom; -o-transform: rotate(-90deg); opacity: 0; } } @keyframes rotateOutUpLeft { 0% { transform-origin: left bottom; transform: rotate(0); opacity: 1; } 100% { transform-origin: left bottom; transform: rotate(-90deg); opacity: 0; } } .rotateOutUpLeft { -webkit-animation-name: rotateOutUpLeft; -moz-animation-name: rotateOutUpLeft; -o-animation-name: rotateOutUpLeft; animation-name: rotateOutUpLeft; } @-webkit-keyframes rotateOutDownLeft { 0% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(0); opacity: 1; } 100% { -webkit-transform-origin: left bottom; -webkit-transform: rotate(90deg); opacity: 0; } } @-moz-keyframes rotateOutDownLeft { 0% { -moz-transform-origin: left bottom; -moz-transform: rotate(0); opacity: 1; } 100% { -moz-transform-origin: left bottom; -moz-transform: rotate(90deg); opacity: 0; } } @-o-keyframes rotateOutDownLeft { 0% { -o-transform-origin: left bottom; -o-transform: rotate(0); opacity: 1; } 100% { -o-transform-origin: left bottom; -o-transform: rotate(90deg); opacity: 0; } } @keyframes rotateOutDownLeft { 0% { transform-origin: left bottom; transform: rotate(0); opacity: 1; } 100% { transform-origin: left bottom; transform: rotate(90deg); opacity: 0; } } .rotateOutDownLeft { -webkit-animation-name: rotateOutDownLeft; 
-moz-animation-name: rotateOutDownLeft; -o-animation-name: rotateOutDownLeft; animation-name: rotateOutDownLeft; } @-webkit-keyframes rotateOutUpRight { 0% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(0); opacity: 1; } 100% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(90deg); opacity: 0; } } @-moz-keyframes rotateOutUpRight { 0% { -moz-transform-origin: right bottom; -moz-transform: rotate(0); opacity: 1; } 100% { -moz-transform-origin: right bottom; -moz-transform: rotate(90deg); opacity: 0; } } @-o-keyframes rotateOutUpRight { 0% { -o-transform-origin: right bottom; -o-transform: rotate(0); opacity: 1; } 100% { -o-transform-origin: right bottom; -o-transform: rotate(90deg); opacity: 0; } } @keyframes rotateOutUpRight { 0% { transform-origin: right bottom; transform: rotate(0); opacity: 1; } 100% { transform-origin: right bottom; transform: rotate(90deg); opacity: 0; } } .rotateOutUpRight { -webkit-animation-name: rotateOutUpRight; -moz-animation-name: rotateOutUpRight; -o-animation-name: rotateOutUpRight; animation-name: rotateOutUpRight; } @-webkit-keyframes rotateOutDownRight { 0% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(0); opacity: 1; } 100% { -webkit-transform-origin: right bottom; -webkit-transform: rotate(-90deg); opacity: 0; } } @-moz-keyframes rotateOutDownRight { 0% { -moz-transform-origin: right bottom; -moz-transform: rotate(0); opacity: 1; } 100% { -moz-transform-origin: right bottom; -moz-transform: rotate(-90deg); opacity: 0; } } @-o-keyframes rotateOutDownRight { 0% { -o-transform-origin: right bottom; -o-transform: rotate(0); opacity: 1; } 100% { -o-transform-origin: right bottom; -o-transform: rotate(-90deg); opacity: 0; } } @keyframes rotateOutDownRight { 0% { transform-origin: right bottom; transform: rotate(0); opacity: 1; } 100% { transform-origin: right bottom; transform: rotate(-90deg); opacity: 0; } } .rotateOutDownRight { -webkit-animation-name: rotateOutDownRight; 
-moz-animation-name: rotateOutDownRight; -o-animation-name: rotateOutDownRight; animation-name: rotateOutDownRight; } @-webkit-keyframes hinge { 0% { -webkit-transform: rotate(0); -webkit-transform-origin: top left; -webkit-animation-timing-function: ease-in-out; } 20%, 60% { -webkit-transform: rotate(80deg); -webkit-transform-origin: top left; -webkit-animation-timing-function: ease-in-out; } 40% { -webkit-transform: rotate(60deg); -webkit-transform-origin: top left; -webkit-animation-timing-function: ease-in-out; } 80% { -webkit-transform: rotate(60deg) translateY(0); opacity: 1; -webkit-transform-origin: top left; -webkit-animation-timing-function: ease-in-out; } 100% { -webkit-transform: translateY(700px); opacity: 0; } } @-moz-keyframes hinge { 0% { -moz-transform: rotate(0); -moz-transform-origin: top left; -moz-animation-timing-function: ease-in-out; } 20%, 60% { -moz-transform: rotate(80deg); -moz-transform-origin: top left; -moz-animation-timing-function: ease-in-out; } 40% { -moz-transform: rotate(60deg); -moz-transform-origin: top left; -moz-animation-timing-function: ease-in-out; } 80% { -moz-transform: rotate(60deg) translateY(0); opacity: 1; -moz-transform-origin: top left; -moz-animation-timing-function: ease-in-out; } 100% { -moz-transform: translateY(700px); opacity: 0; } } @-o-keyframes hinge { 0% { -o-transform: rotate(0); -o-transform-origin: top left; -o-animation-timing-function: ease-in-out; } 20%, 60% { -o-transform: rotate(80deg); -o-transform-origin: top left; -o-animation-timing-function: ease-in-out; } 40% { -o-transform: rotate(60deg); -o-transform-origin: top left; -o-animation-timing-function: ease-in-out; } 80% { -o-transform: rotate(60deg) translateY(0); opacity: 1; -o-transform-origin: top left; -o-animation-timing-function: ease-in-out; } 100% { -o-transform: translateY(700px); opacity: 0; } } @keyframes hinge { 0% { transform: rotate(0); transform-origin: top left; animation-timing-function: ease-in-out; } 20%, 60% { transform: 
rotate(80deg); transform-origin: top left; animation-timing-function: ease-in-out; } 40% { transform: rotate(60deg); transform-origin: top left; animation-timing-function: ease-in-out; } 80% { transform: rotate(60deg) translateY(0); opacity: 1; transform-origin: top left; animation-timing-function: ease-in-out; } 100% { transform: translateY(700px); opacity: 0; } } .hinge { -webkit-animation-name: hinge; -moz-animation-name: hinge; -o-animation-name: hinge; animation-name: hinge; } /* originally authored by Nick Pettit - https://github.com/nickpettit/glide */ @-webkit-keyframes rollIn { 0% { opacity: 0; -webkit-transform: translateX(-100%) rotate(-120deg); } 100% { opacity: 1; -webkit-transform: translateX(0px) rotate(0deg); } } @-moz-keyframes rollIn { 0% { opacity: 0; -moz-transform: translateX(-100%) rotate(-120deg); } 100% { opacity: 1; -moz-transform: translateX(0px) rotate(0deg); } } @-o-keyframes rollIn { 0% { opacity: 0; -o-transform: translateX(-100%) rotate(-120deg); } 100% { opacity: 1; -o-transform: translateX(0px) rotate(0deg); } } @keyframes rollIn { 0% { opacity: 0; transform: translateX(-100%) rotate(-120deg); } 100% { opacity: 1; transform: translateX(0px) rotate(0deg); } } .rollIn { -webkit-animation-name: rollIn; -moz-animation-name: rollIn; -o-animation-name: rollIn; animation-name: rollIn; } /* originally authored by Nick Pettit - https://github.com/nickpettit/glide */ @-webkit-keyframes rollOut { 0% { opacity: 1; -webkit-transform: translateX(0px) rotate(0deg); } 100% { opacity: 0; -webkit-transform: translateX(100%) rotate(120deg); } } @-moz-keyframes rollOut { 0% { opacity: 1; -moz-transform: translateX(0px) rotate(0deg); } 100% { opacity: 0; -moz-transform: translateX(100%) rotate(120deg); } } @-o-keyframes rollOut { 0% { opacity: 1; -o-transform: translateX(0px) rotate(0deg); } 100% { opacity: 0; -o-transform: translateX(100%) rotate(120deg); } } @keyframes rollOut { 0% { opacity: 1; transform: translateX(0px) rotate(0deg); } 100% { 
opacity: 0; transform: translateX(100%) rotate(120deg); } } .rollOut { -webkit-animation-name: rollOut; -moz-animation-name: rollOut; -o-animation-name: rollOut; animation-name: rollOut; } /* originally authored by Angelo Rohit - https://github.com/angelorohit */ @-webkit-keyframes lightSpeedIn { 0% { -webkit-transform: translateX(100%) skewX(-30deg); opacity: 0; } 60% { -webkit-transform: translateX(-20%) skewX(30deg); opacity: 1; } 80% { -webkit-transform: translateX(0%) skewX(-15deg); opacity: 1; } 100% { -webkit-transform: translateX(0%) skewX(0deg); opacity: 1; } } @-moz-keyframes lightSpeedIn { 0% { -moz-transform: translateX(100%) skewX(-30deg); opacity: 0; } 60% { -moz-transform: translateX(-20%) skewX(30deg); opacity: 1; } 80% { -moz-transform: translateX(0%) skewX(-15deg); opacity: 1; } 100% { -moz-transform: translateX(0%) skewX(0deg); opacity: 1; } } @-o-keyframes lightSpeedIn { 0% { -o-transform: translateX(100%) skewX(-30deg); opacity: 0; } 60% { -o-transform: translateX(-20%) skewX(30deg); opacity: 1; } 80% { -o-transform: translateX(0%) skewX(-15deg); opacity: 1; } 100% { -o-transform: translateX(0%) skewX(0deg); opacity: 1; } } @keyframes lightSpeedIn { 0% { transform: translateX(100%) skewX(-30deg); opacity: 0; } 60% { transform: translateX(-20%) skewX(30deg); opacity: 1; } 80% { transform: translateX(0%) skewX(-15deg); opacity: 1; } 100% { transform: translateX(0%) skewX(0deg); opacity: 1; } } .lightSpeedIn { -webkit-animation-name: lightSpeedIn; -moz-animation-name: lightSpeedIn; -o-animation-name: lightSpeedIn; animation-name: lightSpeedIn; -webkit-animation-timing-function: ease-out; -moz-animation-timing-function: ease-out; -o-animation-timing-function: ease-out; animation-timing-function: ease-out; } .animated.lightSpeedIn { -webkit-animation-duration: 0.5s; -moz-animation-duration: 0.5s; -o-animation-duration: 0.5s; animation-duration: 0.5s; } /* originally authored by Angelo Rohit - https://github.com/angelorohit */ @-webkit-keyframes 
lightSpeedOut { 0% { -webkit-transform: translateX(0%) skewX(0deg); opacity: 1; } 100% { -webkit-transform: translateX(100%) skewX(-30deg); opacity: 0; } } @-moz-keyframes lightSpeedOut { 0% { -moz-transform: translateX(0%) skewX(0deg); opacity: 1; } 100% { -moz-transform: translateX(100%) skewX(-30deg); opacity: 0; } } @-o-keyframes lightSpeedOut { 0% { -o-transform: translateX(0%) skewX(0deg); opacity: 1; } 100% { -o-transform: translateX(100%) skewX(-30deg); opacity: 0; } } @keyframes lightSpeedOut { 0% { transform: translateX(0%) skewX(0deg); opacity: 1; } 100% { transform: translateX(100%) skewX(-30deg); opacity: 0; } } .lightSpeedOut { -webkit-animation-name: lightSpeedOut; -moz-animation-name: lightSpeedOut; -o-animation-name: lightSpeedOut; animation-name: lightSpeedOut; -webkit-animation-timing-function: ease-in; -moz-animation-timing-function: ease-in; -o-animation-timing-function: ease-in; animation-timing-function: ease-in; } .animated.lightSpeedOut { -webkit-animation-duration: 0.25s; -moz-animation-duration: 0.25s; -o-animation-duration: 0.25s; animation-duration: 0.25s; } /* originally authored by Angelo Rohit - https://github.com/angelorohit */ @-webkit-keyframes wiggle { 0% { -webkit-transform: skewX(9deg); } 10% { -webkit-transform: skewX(-8deg); } 20% { -webkit-transform: skewX(7deg); } 30% { -webkit-transform: skewX(-6deg); } 40% { -webkit-transform: skewX(5deg); } 50% { -webkit-transform: skewX(-4deg); } 60% { -webkit-transform: skewX(3deg); } 70% { -webkit-transform: skewX(-2deg); } 80% { -webkit-transform: skewX(1deg); } 90% { -webkit-transform: skewX(0deg); } 100% { -webkit-transform: skewX(0deg); } } @-moz-keyframes wiggle { 0% { -moz-transform: skewX(9deg); } 10% { -moz-transform: skewX(-8deg); } 20% { -moz-transform: skewX(7deg); } 30% { -moz-transform: skewX(-6deg); } 40% { -moz-transform: skewX(5deg); } 50% { -moz-transform: skewX(-4deg); } 60% { -moz-transform: skewX(3deg); } 70% { -moz-transform: skewX(-2deg); } 80% { -moz-transform: 
skewX(1deg); } 90% { -moz-transform: skewX(0deg); } 100% { -moz-transform: skewX(0deg); } } @-o-keyframes wiggle { 0% { -o-transform: skewX(9deg); } 10% { -o-transform: skewX(-8deg); } 20% { -o-transform: skewX(7deg); } 30% { -o-transform: skewX(-6deg); } 40% { -o-transform: skewX(5deg); } 50% { -o-transform: skewX(-4deg); } 60% { -o-transform: skewX(3deg); } 70% { -o-transform: skewX(-2deg); } 80% { -o-transform: skewX(1deg); } 90% { -o-transform: skewX(0deg); } 100% { -o-transform: skewX(0deg); } } @keyframes wiggle { 0% { transform: skewX(9deg); } 10% { transform: skewX(-8deg); } 20% { transform: skewX(7deg); } 30% { transform: skewX(-6deg); } 40% { transform: skewX(5deg); } 50% { transform: skewX(-4deg); } 60% { transform: skewX(3deg); } 70% { transform: skewX(-2deg); } 80% { transform: skewX(1deg); } 90% { transform: skewX(0deg); } 100% { transform: skewX(0deg); } } .wiggle { -webkit-animation-name: wiggle; -moz-animation-name: wiggle; -o-animation-name: wiggle; animation-name: wiggle; -webkit-animation-timing-function: ease-in; -moz-animation-timing-function: ease-in; -o-animation-timing-function: ease-in; animation-timing-function: ease-in; } .animated.wiggle { -webkit-animation-duration: 0.75s; -moz-animation-duration: 0.75s; -o-animation-duration: 0.75s; animation-duration: 0.75s; } ================================================ FILE: backend/tests/integration/tests/pruning/website/css/custom-fonts.css ================================================ /* ================================================== Font-Face Icons ================================================== */ @font-face { font-family: "Icons"; src: url("../fonts/customicon/Icons.eot"); src: url("../fonts/customicon/Icons.eot?#iefix") format("embedded-opentype"), url("../fonts/customicon/Icons.woff") format("woff"), url("../fonts/customicon/Icons.ttf") format("truetype"), url("../fonts/customicon/Icons.svg#Icons") format("svg"); font-weight: normal; font-style: normal; } /* Use the following 
CSS code if you want to use data attributes for inserting your icons */ [data-icon]:before { font-family: "Icons"; content: attr(data-icon); speak: none; font-weight: normal; font-variant: normal; text-transform: none; line-height: 1; -webkit-font-smoothing: antialiased; } [class^="font-"]:before, [class*=" font-"]:before { font-family: "Icons"; speak: none; font-style: normal; font-weight: normal; font-variant: normal; text-transform: none; -webkit-font-smoothing: antialiased; } [class^="font-"], [class*=" font-"] { display: inline-block; line-height: 1em; } /* Use the following CSS code if you want to have a class per icon */ /* Instead of a list of all class selectors, you can use the generic selector below, but it's slower: [class*="font-icon-"] { */ .font-icon-zoom-out, .font-icon-zoom-in, .font-icon-wrench, .font-icon-waves, .font-icon-warning, .font-icon-volume-up, .font-icon-volume-off, .font-icon-volume-down, .font-icon-viewport, .font-icon-user, .font-icon-user-border, .font-icon-upload, .font-icon-upload-2, .font-icon-unlock, .font-icon-underline, .font-icon-tint, .font-icon-time, .font-icon-text, .font-icon-text-width, .font-icon-text-height, .font-icon-tags, .font-icon-tag, .font-icon-table, .font-icon-strikethrough, .font-icon-stop, .font-icon-step-forward, .font-icon-step-backward, .font-icon-stars, .font-icon-star, .font-icon-star-line, .font-icon-star-half, .font-icon-sort, .font-icon-sort-up, .font-icon-sort-down, .font-icon-social-zerply, .font-icon-social-youtube, .font-icon-social-yelp, .font-icon-social-yahoo, .font-icon-social-wordpress, .font-icon-social-virb, .font-icon-social-vimeo, .font-icon-social-viddler, .font-icon-social-twitter, .font-icon-social-tumblr, .font-icon-social-stumbleupon, .font-icon-social-soundcloud, .font-icon-social-skype, .font-icon-social-share-this, .font-icon-social-quora, .font-icon-social-pinterest, .font-icon-social-photobucket, .font-icon-social-paypal, .font-icon-social-myspace, .font-icon-social-linkedin, 
.font-icon-social-last-fm, .font-icon-social-grooveshark, .font-icon-social-google-plus, .font-icon-social-github, .font-icon-social-forrst, .font-icon-social-flickr, .font-icon-social-facebook, .font-icon-social-evernote, .font-icon-social-envato, .font-icon-social-email, .font-icon-social-dribbble, .font-icon-social-digg, .font-icon-social-deviant-art, .font-icon-social-blogger, .font-icon-social-behance, .font-icon-social-bebo, .font-icon-social-addthis, .font-icon-social-500px, .font-icon-sitemap, .font-icon-signout, .font-icon-signin, .font-icon-signal, .font-icon-shopping-cart, .font-icon-search, .font-icon-rss, .font-icon-road, .font-icon-retweet, .font-icon-resize-vertical, .font-icon-resize-vertical-2, .font-icon-resize-small, .font-icon-resize-horizontal, .font-icon-resize-horizontal-2, .font-icon-resize-fullscreen, .font-icon-resize-full, .font-icon-repeat, .font-icon-reorder, .font-icon-remove, .font-icon-remove-sign, .font-icon-remove-circle, .font-icon-read-more, .font-icon-random, .font-icon-question-sign, .font-icon-pushpin, .font-icon-pushpin-2, .font-icon-print, .font-icon-plus, .font-icon-plus-sign, .font-icon-play, .font-icon-picture, .font-icon-phone, .font-icon-phone-sign, .font-icon-phone-boxed, .font-icon-pause, .font-icon-paste, .font-icon-paper-clip, .font-icon-ok, .font-icon-ok-sign, .font-icon-ok-circle, .font-icon-music, .font-icon-move, .font-icon-money, .font-icon-minus, .font-icon-minus-sign, .font-icon-map, .font-icon-map-marker, .font-icon-map-marker-2, .font-icon-magnet, .font-icon-magic, .font-icon-lock, .font-icon-list, .font-icon-list-3, .font-icon-list-2, .font-icon-link, .font-icon-layer, .font-icon-key, .font-icon-italic, .font-icon-info, .font-icon-indent-right, .font-icon-indent-left, .font-icon-inbox, .font-icon-inbox-empty, .font-icon-home, .font-icon-heart, .font-icon-heart-line, .font-icon-headphones, .font-icon-headphones-line, .font-icon-headphones-line-2, .font-icon-headphones-2, .font-icon-hdd, .font-icon-group, 
.font-icon-grid, .font-icon-grid-large, .font-icon-globe_line, .font-icon-glass, .font-icon-glass_2, .font-icon-gift, .font-icon-forward, .font-icon-font, .font-icon-folder-open, .font-icon-folder-close, .font-icon-flag, .font-icon-fire, .font-icon-film, .font-icon-file, .font-icon-file-empty, .font-icon-fast-forward, .font-icon-fast-backward, .font-icon-facetime, .font-icon-eye, .font-icon-eye_disable, .font-icon-expand-view, .font-icon-expand-view-3, .font-icon-expand-view-2, .font-icon-expand-vertical, .font-icon-expand-horizontal, .font-icon-exclamation, .font-icon-email, .font-icon-email_2, .font-icon-eject, .font-icon-edit, .font-icon-edit-check, .font-icon-download, .font-icon-download_2, .font-icon-dashboard, .font-icon-credit-card, .font-icon-copy, .font-icon-comments, .font-icon-comments-line, .font-icon-comment, .font-icon-comment-line, .font-icon-columns, .font-icon-columns-2, .font-icon-cogs, .font-icon-cog, .font-icon-cloud, .font-icon-check, .font-icon-check-empty, .font-icon-certificate, .font-icon-camera, .font-icon-calendar, .font-icon-bullhorn, .font-icon-briefcase, .font-icon-bookmark, .font-icon-book, .font-icon-bolt, .font-icon-bold, .font-icon-blockquote, .font-icon-bell, .font-icon-beaker, .font-icon-barcode, .font-icon-ban-circle, .font-icon-ban-chart, .font-icon-ban-chart-2, .font-icon-backward, .font-icon-asterisk, .font-icon-arrow-simple-up, .font-icon-arrow-simple-up-circle, .font-icon-arrow-simple-right, .font-icon-arrow-simple-right-circle, .font-icon-arrow-simple-left, .font-icon-arrow-simple-left-circle, .font-icon-arrow-simple-down, .font-icon-arrow-simple-down-circle, .font-icon-arrow-round-up, .font-icon-arrow-round-up-circle, .font-icon-arrow-round-right, .font-icon-arrow-round-right-circle, .font-icon-arrow-round-left, .font-icon-arrow-round-left-circle, .font-icon-arrow-round-down, .font-icon-arrow-round-down-circle, .font-icon-arrow-light-up, .font-icon-arrow-light-round-up, .font-icon-arrow-light-round-up-circle, 
.font-icon-arrow-light-round-right, .font-icon-arrow-light-round-right-circle, .font-icon-arrow-light-round-left, .font-icon-arrow-light-round-left-circle, .font-icon-arrow-light-round-down, .font-icon-arrow-light-round-down-circle, .font-icon-arrow-light-right, .font-icon-arrow-light-left, .font-icon-arrow-light-down, .font-icon-align-right, .font-icon-align-left, .font-icon-align-justify, .font-icon-align-center, .font-icon-adjust { font-family: "Icons"; speak: none; font-style: normal; font-weight: normal; font-variant: normal; text-transform: none; line-height: 1; -webkit-font-smoothing: antialiased; } .font-icon-zoom-out:before { content: "\e000"; } .font-icon-zoom-in:before { content: "\e001"; } .font-icon-wrench:before { content: "\e002"; } .font-icon-waves:before { content: "\e003"; } .font-icon-warning:before { content: "\e004"; } .font-icon-volume-up:before { content: "\e005"; } .font-icon-volume-off:before { content: "\e006"; } .font-icon-volume-down:before { content: "\e007"; } .font-icon-viewport:before { content: "\e008"; } .font-icon-user:before { content: "\e009"; } .font-icon-user-border:before { content: "\e00a"; } .font-icon-upload:before { content: "\e00b"; } .font-icon-upload-2:before { content: "\e00c"; } .font-icon-unlock:before { content: "\e00d"; } .font-icon-underline:before { content: "\e00e"; } .font-icon-tint:before { content: "\e00f"; } .font-icon-time:before { content: "\e010"; } .font-icon-text:before { content: "\e011"; } .font-icon-text-width:before { content: "\e012"; } .font-icon-text-height:before { content: "\e013"; } .font-icon-tags:before { content: "\e014"; } .font-icon-tag:before { content: "\e015"; } .font-icon-table:before { content: "\e016"; } .font-icon-strikethrough:before { content: "\e017"; } .font-icon-stop:before { content: "\e018"; } .font-icon-step-forward:before { content: "\e019"; } .font-icon-step-backward:before { content: "\e01a"; } .font-icon-stars:before { content: "\e01b"; } .font-icon-star:before { 
content: "\e01c"; } .font-icon-star-line:before { content: "\e01d"; } .font-icon-star-half:before { content: "\e01e"; } .font-icon-sort:before { content: "\e01f"; } .font-icon-sort-up:before { content: "\e020"; } .font-icon-sort-down:before { content: "\e021"; } .font-icon-social-zerply:before { content: "\e022"; } .font-icon-social-youtube:before { content: "\e023"; } .font-icon-social-yelp:before { content: "\e024"; } .font-icon-social-yahoo:before { content: "\e025"; } .font-icon-social-wordpress:before { content: "\e026"; } .font-icon-social-virb:before { content: "\e027"; } .font-icon-social-vimeo:before { content: "\e028"; } .font-icon-social-viddler:before { content: "\e029"; } .font-icon-social-twitter:before { content: "\e02a"; } .font-icon-social-tumblr:before { content: "\e02b"; } .font-icon-social-stumbleupon:before { content: "\e02c"; } .font-icon-social-soundcloud:before { content: "\e02d"; } .font-icon-social-skype:before { content: "\e02e"; } .font-icon-social-share-this:before { content: "\e02f"; } .font-icon-social-quora:before { content: "\e030"; } .font-icon-social-pinterest:before { content: "\e031"; } .font-icon-social-photobucket:before { content: "\e032"; } .font-icon-social-paypal:before { content: "\e033"; } .font-icon-social-myspace:before { content: "\e034"; } .font-icon-social-linkedin:before { content: "\e035"; } .font-icon-social-last-fm:before { content: "\e036"; } .font-icon-social-grooveshark:before { content: "\e037"; } .font-icon-social-google-plus:before { content: "\e038"; } .font-icon-social-github:before { content: "\e039"; } .font-icon-social-forrst:before { content: "\e03a"; } .font-icon-social-flickr:before { content: "\e03b"; } .font-icon-social-facebook:before { content: "\e03c"; } .font-icon-social-evernote:before { content: "\e03d"; } .font-icon-social-envato:before { content: "\e03e"; } .font-icon-social-email:before { content: "\e03f"; } .font-icon-social-dribbble:before { content: "\e040"; } 
.font-icon-social-digg:before { content: "\e041"; } .font-icon-social-deviant-art:before { content: "\e042"; } .font-icon-social-blogger:before { content: "\e043"; } .font-icon-social-behance:before { content: "\e044"; } .font-icon-social-bebo:before { content: "\e045"; } .font-icon-social-addthis:before { content: "\e046"; } .font-icon-social-500px:before { content: "\e047"; } .font-icon-sitemap:before { content: "\e048"; } .font-icon-signout:before { content: "\e049"; } .font-icon-signin:before { content: "\e04a"; } .font-icon-signal:before { content: "\e04b"; } .font-icon-shopping-cart:before { content: "\e04c"; } .font-icon-search:before { content: "\e04d"; } .font-icon-rss:before { content: "\e04e"; } .font-icon-road:before { content: "\e04f"; } .font-icon-retweet:before { content: "\e050"; } .font-icon-resize-vertical:before { content: "\e051"; } .font-icon-resize-vertical-2:before { content: "\e052"; } .font-icon-resize-small:before { content: "\e053"; } .font-icon-resize-horizontal:before { content: "\e054"; } .font-icon-resize-horizontal-2:before { content: "\e055"; } .font-icon-resize-fullscreen:before { content: "\e056"; } .font-icon-resize-full:before { content: "\e057"; } .font-icon-repeat:before { content: "\e058"; } .font-icon-reorder:before { content: "\e059"; } .font-icon-remove:before { content: "\e05a"; } .font-icon-remove-sign:before { content: "\e05b"; } .font-icon-remove-circle:before { content: "\e05c"; } .font-icon-read-more:before { content: "\e05d"; } .font-icon-random:before { content: "\e05e"; } .font-icon-question-sign:before { content: "\e05f"; } .font-icon-pushpin:before { content: "\e060"; } .font-icon-pushpin-2:before { content: "\e061"; } .font-icon-print:before { content: "\e062"; } .font-icon-plus:before { content: "\e063"; } .font-icon-plus-sign:before { content: "\e064"; } .font-icon-play:before { content: "\e065"; } .font-icon-picture:before { content: "\e066"; } .font-icon-phone:before { content: "\e067"; } 
.font-icon-phone-sign:before { content: "\e068"; } .font-icon-phone-boxed:before { content: "\e069"; } .font-icon-pause:before { content: "\e06a"; } .font-icon-paste:before { content: "\e06b"; } .font-icon-paper-clip:before { content: "\e06c"; } .font-icon-ok:before { content: "\e06d"; } .font-icon-ok-sign:before { content: "\e06e"; } .font-icon-ok-circle:before { content: "\e06f"; } .font-icon-music:before { content: "\e070"; } .font-icon-move:before { content: "\e071"; } .font-icon-money:before { content: "\e072"; } .font-icon-minus:before { content: "\e073"; } .font-icon-minus-sign:before { content: "\e074"; } .font-icon-map:before { content: "\e075"; } .font-icon-map-marker:before { content: "\e076"; } .font-icon-map-marker-2:before { content: "\e077"; } .font-icon-magnet:before { content: "\e078"; } .font-icon-magic:before { content: "\e079"; } .font-icon-lock:before { content: "\e07a"; } .font-icon-list:before { content: "\e07b"; } .font-icon-list-3:before { content: "\e07c"; } .font-icon-list-2:before { content: "\e07d"; } .font-icon-link:before { content: "\e07e"; } .font-icon-layer:before { content: "\e07f"; } .font-icon-key:before { content: "\e080"; } .font-icon-italic:before { content: "\e081"; } .font-icon-info:before { content: "\e082"; } .font-icon-indent-right:before { content: "\e083"; } .font-icon-indent-left:before { content: "\e084"; } .font-icon-inbox:before { content: "\e085"; } .font-icon-inbox-empty:before { content: "\e086"; } .font-icon-home:before { content: "\e087"; } .font-icon-heart:before { content: "\e088"; } .font-icon-heart-line:before { content: "\e089"; } .font-icon-headphones:before { content: "\e08a"; } .font-icon-headphones-line:before { content: "\e08b"; } .font-icon-headphones-line-2:before { content: "\e08c"; } .font-icon-headphones-2:before { content: "\e08d"; } .font-icon-hdd:before { content: "\e08e"; } .font-icon-group:before { content: "\e08f"; } .font-icon-grid:before { content: "\e090"; } .font-icon-grid-large:before 
{ content: "\e091"; } .font-icon-globe_line:before { content: "\e092"; } .font-icon-glass:before { content: "\e093"; } .font-icon-glass_2:before { content: "\e094"; } .font-icon-gift:before { content: "\e095"; } .font-icon-forward:before { content: "\e096"; } .font-icon-font:before { content: "\e097"; } .font-icon-folder-open:before { content: "\e098"; } .font-icon-folder-close:before { content: "\e099"; } .font-icon-flag:before { content: "\e09a"; } .font-icon-fire:before { content: "\e09b"; } .font-icon-film:before { content: "\e09c"; } .font-icon-file:before { content: "\e09d"; } .font-icon-file-empty:before { content: "\e09e"; } .font-icon-fast-forward:before { content: "\e09f"; } .font-icon-fast-backward:before { content: "\e0a0"; } .font-icon-facetime:before { content: "\e0a1"; } .font-icon-eye:before { content: "\e0a2"; } .font-icon-eye_disable:before { content: "\e0a3"; } .font-icon-expand-view:before { content: "\e0a4"; } .font-icon-expand-view-3:before { content: "\e0a5"; } .font-icon-expand-view-2:before { content: "\e0a6"; } .font-icon-expand-vertical:before { content: "\e0a7"; } .font-icon-expand-horizontal:before { content: "\e0a8"; } .font-icon-exclamation:before { content: "\e0a9"; } .font-icon-email:before { content: "\e0aa"; } .font-icon-email_2:before { content: "\e0ab"; } .font-icon-eject:before { content: "\e0ac"; } .font-icon-edit:before { content: "\e0ad"; } .font-icon-edit-check:before { content: "\e0ae"; } .font-icon-download:before { content: "\e0af"; } .font-icon-download_2:before { content: "\e0b0"; } .font-icon-dashboard:before { content: "\e0b1"; } .font-icon-credit-card:before { content: "\e0b2"; } .font-icon-copy:before { content: "\e0b3"; } .font-icon-comments:before { content: "\e0b4"; } .font-icon-comments-line:before { content: "\e0b5"; } .font-icon-comment:before { content: "\e0b6"; } .font-icon-comment-line:before { content: "\e0b7"; } .font-icon-columns:before { content: "\e0b8"; } .font-icon-columns-2:before { content: 
"\e0b9"; } .font-icon-cogs:before { content: "\e0ba"; } .font-icon-cog:before { content: "\e0bb"; } .font-icon-cloud:before { content: "\e0bc"; } .font-icon-check:before { content: "\e0bd"; } .font-icon-check-empty:before { content: "\e0be"; } .font-icon-certificate:before { content: "\e0bf"; } .font-icon-camera:before { content: "\e0c0"; } .font-icon-calendar:before { content: "\e0c1"; } .font-icon-bullhorn:before { content: "\e0c2"; } .font-icon-briefcase:before { content: "\e0c3"; } .font-icon-bookmark:before { content: "\e0c4"; } .font-icon-book:before { content: "\e0c5"; } .font-icon-bolt:before { content: "\e0c6"; } .font-icon-bold:before { content: "\e0c7"; } .font-icon-blockquote:before { content: "\e0c8"; } .font-icon-bell:before { content: "\e0c9"; } .font-icon-beaker:before { content: "\e0ca"; } .font-icon-barcode:before { content: "\e0cb"; } .font-icon-ban-circle:before { content: "\e0cc"; } .font-icon-ban-chart:before { content: "\e0cd"; } .font-icon-ban-chart-2:before { content: "\e0ce"; } .font-icon-backward:before { content: "\e0cf"; } .font-icon-asterisk:before { content: "\e0d0"; } .font-icon-arrow-simple-up:before { content: "\e0d1"; } .font-icon-arrow-simple-up-circle:before { content: "\e0d2"; } .font-icon-arrow-simple-right:before { content: "\e0d3"; } .font-icon-arrow-simple-right-circle:before { content: "\e0d4"; } .font-icon-arrow-simple-left:before { content: "\e0d5"; } .font-icon-arrow-simple-left-circle:before { content: "\e0d6"; } .font-icon-arrow-simple-down:before { content: "\e0d7"; } .font-icon-arrow-simple-down-circle:before { content: "\e0d8"; } .font-icon-arrow-round-up:before { content: "\e0d9"; } .font-icon-arrow-round-up-circle:before { content: "\e0da"; } .font-icon-arrow-round-right:before { content: "\e0db"; } .font-icon-arrow-round-right-circle:before { content: "\e0dc"; } .font-icon-arrow-round-left:before { content: "\e0dd"; } .font-icon-arrow-round-left-circle:before { content: "\e0de"; } 
.font-icon-arrow-round-down:before { content: "\e0df"; } .font-icon-arrow-round-down-circle:before { content: "\e0e0"; } .font-icon-arrow-light-up:before { content: "\e0e1"; } .font-icon-arrow-light-round-up:before { content: "\e0e2"; } .font-icon-arrow-light-round-up-circle:before { content: "\e0e3"; } .font-icon-arrow-light-round-right:before { content: "\e0e4"; } .font-icon-arrow-light-round-right-circle:before { content: "\e0e5"; } .font-icon-arrow-light-round-left:before { content: "\e0e6"; } .font-icon-arrow-light-round-left-circle:before { content: "\e0e7"; } .font-icon-arrow-light-round-down:before { content: "\e0e8"; } .font-icon-arrow-light-round-down-circle:before { content: "\e0e9"; } .font-icon-arrow-light-right:before { content: "\e0ea"; } .font-icon-arrow-light-left:before { content: "\e0eb"; } .font-icon-arrow-light-down:before { content: "\e0ec"; } .font-icon-align-right:before { content: "\e0ed"; } .font-icon-align-left:before { content: "\e0ee"; } .font-icon-align-justify:before { content: "\e0ef"; } .font-icon-align-center:before { content: "\e0f0"; } .font-icon-adjust:before { content: "\e0f1"; } ================================================ FILE: backend/tests/integration/tests/pruning/website/css/fancybox/jquery.fancybox.css ================================================ /*! 
fancyBox v2.1.4 fancyapps.com | fancyapps.com/fancybox/#license */ .fancybox-wrap, .fancybox-skin, .fancybox-outer, .fancybox-inner, .fancybox-image, .fancybox-wrap iframe, .fancybox-wrap object, .fancybox-nav, .fancybox-nav span, .fancybox-tmp { padding: 0; margin: 0; border: 0; outline: none; vertical-align: top; } .fancybox-wrap { position: absolute; top: 0; left: 0; z-index: 8020; } .fancybox-skin { position: relative; background: #2f3238; color: #565656; text-shadow: none; -webkit-border-radius: 0; -moz-border-radius: 0; border-radius: 0; } .fancybox-opened { z-index: 8030; } .fancybox-opened .fancybox-skin { -webkit-box-shadow: none; -moz-box-shadow: none; box-shadow: none; } .fancybox-outer, .fancybox-inner { position: relative; } .fancybox-inner { overflow: hidden; } .fancybox-type-iframe .fancybox-inner { -webkit-overflow-scrolling: touch; } .fancybox-error { color: #444; font-size: 14px; line-height: 20px; margin: 0; padding: 15px; white-space: nowrap; } .fancybox-image, .fancybox-iframe { display: block; width: 100%; height: 100%; } .fancybox-image { max-width: 100%; max-height: 100%; } #fancybox-loading, .fancybox-close, .fancybox-prev span, .fancybox-next span { background-image: url("fancybox_sprite.png") !important; } #fancybox-loading { position: fixed; top: 50%; left: 50%; margin-top: -22px; margin-left: -22px; background-position: 0 -108px; opacity: 0.8; cursor: pointer; z-index: 8060; } #fancybox-loading div { width: 44px; height: 44px; background: url("fancybox_loading.gif") center center no-repeat; } .fancybox-close { position: absolute; right: 0; top: 0; width: 40px; height: 38px; cursor: pointer; z-index: 9000; background-image: none; opacity: 0.5; -webkit-transition: background 0.1s linear 0s, opacity 0.1s linear 0s; -moz-transition: background 0.1s linear 0s, opacity 0.1s linear 0s; -o-transition: background 0.1s linear 0s, opacity 0.1s linear 0s; transition: background 0.1s linear 0s, opacity 0.1s linear 0s; } .fancybox-close i { left: 
50%; top: 50%; margin: -11px 0 0 -11px; font-size: 22px; line-height: 1em; position: absolute; color: #ffffff; } .fancybox-close:hover { opacity: 1; } .fancybox-nav { position: absolute; top: 0; height: 100%; cursor: pointer; text-decoration: none; background: transparent url("blank.gif"); /* helps IE */ -webkit-tap-highlight-color: rgba(0, 0, 0, 0); z-index: 8040; } .fancybox-prev, .fancybox-prev span { left: 0; } .fancybox-next, .fancybox-next span { right: 0; } .fancybox-nav span { position: absolute; top: 50%; width: 44px; height: 32px; margin-top: -25px; cursor: pointer; z-index: 8040; background-image: none; background-color: #26292e; background-position-y: -38px; opacity: 0.5; -webkit-transition: background 0.1s linear 0s, opacity 0.1s linear 0s; -moz-transition: background 0.1s linear 0s, opacity 0.1s linear 0s; -o-transition: background 0.1s linear 0s, opacity 0.1s linear 0s; transition: background 0.1s linear 0s, opacity 0.1s linear 0s; } .fancybox-next span { background-position-y: -72px; } .fancybox-prev span i { left: 50%; top: 50%; margin: -15px 0 0 -17px; font-size: 30px; line-height: 1em; position: absolute; color: #ffffff; } .fancybox-next span i { left: 50%; top: 50%; margin: -15px 0 0 -15px; font-size: 30px; line-height: 1em; position: absolute; color: #ffffff; } .fancybox-nav:hover span { opacity: 1; } .fancybox-tmp { position: absolute; top: -99999px; left: -99999px; visibility: hidden; max-width: 99999px; max-height: 99999px; overflow: visible !important; } /* Overlay helper */ .fancybox-lock { margin: 0 !important; } .fancybox-overlay { position: absolute; top: 0; left: 0; overflow: hidden !important; display: none; z-index: 8010; background: url("fancybox_overlay.png"); } .fancybox-overlay-fixed { position: fixed; bottom: 0; right: 0; } .fancybox-lock .fancybox-overlay { overflow: auto; overflow-y: scroll; } /* Title helper */ .fancybox-title { visibility: hidden; position: relative; text-shadow: none; z-index: 8050; } .fancybox-opened 
.fancybox-title { visibility: visible; } .fancybox-opened .fancybox-title h4 { font-size: 24px; color: #fff; font-weight: 300; margin-bottom: 10px; } .fancybox-opened .fancybox-title p { font-size: 16px; font-weight: 300; color: #bbb; line-height: 1.6em; margin-bottom: 0; } .fancybox-title-float-wrap { position: absolute; bottom: 0; right: 50%; margin-bottom: -35px; z-index: 8050; text-align: center; } .fancybox-title-float-wrap .child { display: inline-block; margin-right: -100%; padding: 2px 20px; background: transparent; /* Fallback for web browsers that don't support RGBa */ background: rgba(0, 0, 0, 0.8); -webkit-border-radius: 15px; -moz-border-radius: 15px; border-radius: 15px; text-shadow: 0 1px 2px #222; color: #fff; font-weight: bold; line-height: 24px; white-space: nowrap; } .fancybox-title-outside-wrap { position: relative; margin-top: 10px; color: #fff; } .fancybox-title-inside-wrap { padding: 3px 30px 6px; background: #61b331; } .fancybox-title-over-wrap { position: absolute; bottom: 0; left: 0; color: #fff; padding: 10px; background: #000; background: rgba(0, 0, 0, 0.8); } @media (max-width: 480px) { .fancybox-nav span, .fancybox-nav:hover span, .fancybox-close, .fancybox-close:hover { background: transparent; } .fancybox-close i { left: 70px; top: 10px; } } @media (max-width: 320px) { .fancybox-close i { left: 30px; top: 20px; } } ================================================ FILE: backend/tests/integration/tests/pruning/website/css/font-awesome.css ================================================ /*!
* Font Awesome 4.0.3 by @davegandy - http://fontawesome.io - @fontawesome * License - http://fontawesome.io/license (Font: SIL OFL 1.1, CSS: MIT License) */ /* FONT PATH * -------------------------- */ @font-face { font-family: "FontAwesome"; src: url("../fonts/fontawesome-webfont.eot?v=4.0.3"); src: url("../fonts/fontawesome-webfont.eot?#iefix&v=4.0.3") format("embedded-opentype"), url("../fonts/fontawesome-webfont.woff?v=4.0.3") format("woff"), url("../fonts/fontawesome-webfont.ttf?v=4.0.3") format("truetype"), url("../fonts/fontawesome-webfont.svg?v=4.0.3#fontawesomeregular") format("svg"); font-weight: normal; font-style: normal; } .fa { display: inline-block; font-family: FontAwesome; font-style: normal; font-weight: normal; line-height: 1; -webkit-font-smoothing: antialiased; -moz-osx-font-smoothing: grayscale; } /* makes the font 33% larger relative to the icon container */ .fa-lg { font-size: 1.3333333333333333em; line-height: 0.75em; vertical-align: -15%; } .fa-2x { font-size: 2em; } .fa-3x { font-size: 3em; } .fa-4x { font-size: 4em; } .fa-5x { font-size: 5em; } .fa-fw { width: 1.2857142857142858em; text-align: center; } .fa-ul { padding-left: 0; margin-left: 2.142857142857143em; list-style-type: none; } .fa-ul > li { position: relative; } .fa-li { position: absolute; left: -2.142857142857143em; width: 2.142857142857143em; top: 0.14285714285714285em; text-align: center; } .fa-li.fa-lg { left: -1.8571428571428572em; } .fa-border { padding: 0.2em 0.25em 0.15em; border: solid 0.08em #eeeeee; border-radius: 0.1em; } .pull-right { float: right; } .pull-left { float: left; } .fa.pull-left { margin-right: 0.3em; } .fa.pull-right { margin-left: 0.3em; } .fa-spin { -webkit-animation: spin 2s infinite linear; -moz-animation: spin 2s infinite linear; -o-animation: spin 2s infinite linear; animation: spin 2s infinite linear; } @-moz-keyframes spin { 0% { -moz-transform: rotate(0deg); } 100% { -moz-transform: rotate(359deg); } } @-webkit-keyframes spin { 0% { 
-webkit-transform: rotate(0deg); } 100% { -webkit-transform: rotate(359deg); } } @-o-keyframes spin { 0% { -o-transform: rotate(0deg); } 100% { -o-transform: rotate(359deg); } } @-ms-keyframes spin { 0% { -ms-transform: rotate(0deg); } 100% { -ms-transform: rotate(359deg); } } @keyframes spin { 0% { transform: rotate(0deg); } 100% { transform: rotate(359deg); } } .fa-rotate-90 { filter: progid:DXImageTransform.Microsoft.BasicImage(rotation=1); -webkit-transform: rotate(90deg); -moz-transform: rotate(90deg); -ms-transform: rotate(90deg); -o-transform: rotate(90deg); transform: rotate(90deg); } .fa-rotate-180 { filter: progid:DXImageTransform.Microsoft.BasicImage(rotation=2); -webkit-transform: rotate(180deg); -moz-transform: rotate(180deg); -ms-transform: rotate(180deg); -o-transform: rotate(180deg); transform: rotate(180deg); } .fa-rotate-270 { filter: progid:DXImageTransform.Microsoft.BasicImage(rotation=3); -webkit-transform: rotate(270deg); -moz-transform: rotate(270deg); -ms-transform: rotate(270deg); -o-transform: rotate(270deg); transform: rotate(270deg); } .fa-flip-horizontal { filter: progid:DXImageTransform.Microsoft.BasicImage(rotation=0, mirror=1); -webkit-transform: scale(-1, 1); -moz-transform: scale(-1, 1); -ms-transform: scale(-1, 1); -o-transform: scale(-1, 1); transform: scale(-1, 1); } .fa-flip-vertical { filter: progid:DXImageTransform.Microsoft.BasicImage(rotation=2, mirror=1); -webkit-transform: scale(1, -1); -moz-transform: scale(1, -1); -ms-transform: scale(1, -1); -o-transform: scale(1, -1); transform: scale(1, -1); } .fa-stack { position: relative; display: inline-block; width: 2em; height: 2em; line-height: 2em; vertical-align: middle; } .fa-stack-1x, .fa-stack-2x { position: absolute; left: 0; width: 100%; text-align: center; } .fa-stack-1x { line-height: inherit; } .fa-stack-2x { font-size: 2em; } .fa-inverse { color: #ffffff; } /* Font Awesome uses the Unicode Private Use Area (PUA) to ensure screen readers do not read off random 
characters that represent icons */ .fa-glass:before { content: "\f000"; } .fa-music:before { content: "\f001"; } .fa-search:before { content: "\f002"; } .fa-envelope-o:before { content: "\f003"; } .fa-heart:before { content: "\f004"; } .fa-star:before { content: "\f005"; } .fa-star-o:before { content: "\f006"; } .fa-user:before { content: "\f007"; } .fa-film:before { content: "\f008"; } .fa-th-large:before { content: "\f009"; } .fa-th:before { content: "\f00a"; } .fa-th-list:before { content: "\f00b"; } .fa-check:before { content: "\f00c"; } .fa-times:before { content: "\f00d"; } .fa-search-plus:before { content: "\f00e"; } .fa-search-minus:before { content: "\f010"; } .fa-power-off:before { content: "\f011"; } .fa-signal:before { content: "\f012"; } .fa-gear:before, .fa-cog:before { content: "\f013"; } .fa-trash-o:before { content: "\f014"; } .fa-home:before { content: "\f015"; } .fa-file-o:before { content: "\f016"; } .fa-clock-o:before { content: "\f017"; } .fa-road:before { content: "\f018"; } .fa-download:before { content: "\f019"; } .fa-arrow-circle-o-down:before { content: "\f01a"; } .fa-arrow-circle-o-up:before { content: "\f01b"; } .fa-inbox:before { content: "\f01c"; } .fa-play-circle-o:before { content: "\f01d"; } .fa-rotate-right:before, .fa-repeat:before { content: "\f01e"; } .fa-refresh:before { content: "\f021"; } .fa-list-alt:before { content: "\f022"; } .fa-lock:before { content: "\f023"; } .fa-flag:before { content: "\f024"; } .fa-headphones:before { content: "\f025"; } .fa-volume-off:before { content: "\f026"; } .fa-volume-down:before { content: "\f027"; } .fa-volume-up:before { content: "\f028"; } .fa-qrcode:before { content: "\f029"; } .fa-barcode:before { content: "\f02a"; } .fa-tag:before { content: "\f02b"; } .fa-tags:before { content: "\f02c"; } .fa-book:before { content: "\f02d"; } .fa-bookmark:before { content: "\f02e"; } .fa-print:before { content: "\f02f"; } .fa-camera:before { content: "\f030"; } .fa-font:before { content: "\f031"; } 
.fa-bold:before { content: "\f032"; } .fa-italic:before { content: "\f033"; } .fa-text-height:before { content: "\f034"; } .fa-text-width:before { content: "\f035"; } .fa-align-left:before { content: "\f036"; } .fa-align-center:before { content: "\f037"; } .fa-align-right:before { content: "\f038"; } .fa-align-justify:before { content: "\f039"; } .fa-list:before { content: "\f03a"; } .fa-dedent:before, .fa-outdent:before { content: "\f03b"; } .fa-indent:before { content: "\f03c"; } .fa-video-camera:before { content: "\f03d"; } .fa-picture-o:before { content: "\f03e"; } .fa-pencil:before { content: "\f040"; } .fa-map-marker:before { content: "\f041"; } .fa-adjust:before { content: "\f042"; } .fa-tint:before { content: "\f043"; } .fa-edit:before, .fa-pencil-square-o:before { content: "\f044"; } .fa-share-square-o:before { content: "\f045"; } .fa-check-square-o:before { content: "\f046"; } .fa-arrows:before { content: "\f047"; } .fa-step-backward:before { content: "\f048"; } .fa-fast-backward:before { content: "\f049"; } .fa-backward:before { content: "\f04a"; } .fa-play:before { content: "\f04b"; } .fa-pause:before { content: "\f04c"; } .fa-stop:before { content: "\f04d"; } .fa-forward:before { content: "\f04e"; } .fa-fast-forward:before { content: "\f050"; } .fa-step-forward:before { content: "\f051"; } .fa-eject:before { content: "\f052"; } .fa-chevron-left:before { content: "\f053"; } .fa-chevron-right:before { content: "\f054"; } .fa-plus-circle:before { content: "\f055"; } .fa-minus-circle:before { content: "\f056"; } .fa-times-circle:before { content: "\f057"; } .fa-check-circle:before { content: "\f058"; } .fa-question-circle:before { content: "\f059"; } .fa-info-circle:before { content: "\f05a"; } .fa-crosshairs:before { content: "\f05b"; } .fa-times-circle-o:before { content: "\f05c"; } .fa-check-circle-o:before { content: "\f05d"; } .fa-ban:before { content: "\f05e"; } .fa-arrow-left:before { content: "\f060"; } .fa-arrow-right:before { content: "\f061"; } 
.fa-arrow-up:before { content: "\f062"; } .fa-arrow-down:before { content: "\f063"; } .fa-mail-forward:before, .fa-share:before { content: "\f064"; } .fa-expand:before { content: "\f065"; } .fa-compress:before { content: "\f066"; } .fa-plus:before { content: "\f067"; } .fa-minus:before { content: "\f068"; } .fa-asterisk:before { content: "\f069"; } .fa-exclamation-circle:before { content: "\f06a"; } .fa-gift:before { content: "\f06b"; } .fa-leaf:before { content: "\f06c"; } .fa-fire:before { content: "\f06d"; } .fa-eye:before { content: "\f06e"; } .fa-eye-slash:before { content: "\f070"; } .fa-warning:before, .fa-exclamation-triangle:before { content: "\f071"; } .fa-plane:before { content: "\f072"; } .fa-calendar:before { content: "\f073"; } .fa-random:before { content: "\f074"; } .fa-comment:before { content: "\f075"; } .fa-magnet:before { content: "\f076"; } .fa-chevron-up:before { content: "\f077"; } .fa-chevron-down:before { content: "\f078"; } .fa-retweet:before { content: "\f079"; } .fa-shopping-cart:before { content: "\f07a"; } .fa-folder:before { content: "\f07b"; } .fa-folder-open:before { content: "\f07c"; } .fa-arrows-v:before { content: "\f07d"; } .fa-arrows-h:before { content: "\f07e"; } .fa-bar-chart-o:before { content: "\f080"; } .fa-twitter-square:before { content: "\f081"; } .fa-facebook-square:before { content: "\f082"; } .fa-camera-retro:before { content: "\f083"; } .fa-key:before { content: "\f084"; } .fa-gears:before, .fa-cogs:before { content: "\f085"; } .fa-comments:before { content: "\f086"; } .fa-thumbs-o-up:before { content: "\f087"; } .fa-thumbs-o-down:before { content: "\f088"; } .fa-star-half:before { content: "\f089"; } .fa-heart-o:before { content: "\f08a"; } .fa-sign-out:before { content: "\f08b"; } .fa-linkedin-square:before { content: "\f08c"; } .fa-thumb-tack:before { content: "\f08d"; } .fa-external-link:before { content: "\f08e"; } .fa-sign-in:before { content: "\f090"; } .fa-trophy:before { content: "\f091"; } 
.fa-github-square:before { content: "\f092"; } .fa-upload:before { content: "\f093"; } .fa-lemon-o:before { content: "\f094"; } .fa-phone:before { content: "\f095"; } .fa-square-o:before { content: "\f096"; } .fa-bookmark-o:before { content: "\f097"; } .fa-phone-square:before { content: "\f098"; } .fa-twitter:before { content: "\f099"; } .fa-facebook:before { content: "\f09a"; } .fa-github:before { content: "\f09b"; } .fa-unlock:before { content: "\f09c"; } .fa-credit-card:before { content: "\f09d"; } .fa-rss:before { content: "\f09e"; } .fa-hdd-o:before { content: "\f0a0"; } .fa-bullhorn:before { content: "\f0a1"; } .fa-bell:before { content: "\f0f3"; } .fa-certificate:before { content: "\f0a3"; } .fa-hand-o-right:before { content: "\f0a4"; } .fa-hand-o-left:before { content: "\f0a5"; } .fa-hand-o-up:before { content: "\f0a6"; } .fa-hand-o-down:before { content: "\f0a7"; } .fa-arrow-circle-left:before { content: "\f0a8"; } .fa-arrow-circle-right:before { content: "\f0a9"; } .fa-arrow-circle-up:before { content: "\f0aa"; } .fa-arrow-circle-down:before { content: "\f0ab"; } .fa-globe:before { content: "\f0ac"; } .fa-wrench:before { content: "\f0ad"; } .fa-tasks:before { content: "\f0ae"; } .fa-filter:before { content: "\f0b0"; } .fa-briefcase:before { content: "\f0b1"; } .fa-arrows-alt:before { content: "\f0b2"; } .fa-group:before, .fa-users:before { content: "\f0c0"; } .fa-chain:before, .fa-link:before { content: "\f0c1"; } .fa-cloud:before { content: "\f0c2"; } .fa-flask:before { content: "\f0c3"; } .fa-cut:before, .fa-scissors:before { content: "\f0c4"; } .fa-copy:before, .fa-files-o:before { content: "\f0c5"; } .fa-paperclip:before { content: "\f0c6"; } .fa-save:before, .fa-floppy-o:before { content: "\f0c7"; } .fa-square:before { content: "\f0c8"; } .fa-bars:before { content: "\f0c9"; } .fa-list-ul:before { content: "\f0ca"; } .fa-list-ol:before { content: "\f0cb"; } .fa-strikethrough:before { content: "\f0cc"; } .fa-underline:before { content: "\f0cd"; } 
.fa-table:before { content: "\f0ce"; } .fa-magic:before { content: "\f0d0"; } .fa-truck:before { content: "\f0d1"; } .fa-pinterest:before { content: "\f0d2"; } .fa-pinterest-square:before { content: "\f0d3"; } .fa-google-plus-square:before { content: "\f0d4"; } .fa-google-plus:before { content: "\f0d5"; } .fa-money:before { content: "\f0d6"; } .fa-caret-down:before { content: "\f0d7"; } .fa-caret-up:before { content: "\f0d8"; } .fa-caret-left:before { content: "\f0d9"; } .fa-caret-right:before { content: "\f0da"; } .fa-columns:before { content: "\f0db"; } .fa-unsorted:before, .fa-sort:before { content: "\f0dc"; } .fa-sort-down:before, .fa-sort-asc:before { content: "\f0dd"; } .fa-sort-up:before, .fa-sort-desc:before { content: "\f0de"; } .fa-envelope:before { content: "\f0e0"; } .fa-linkedin:before { content: "\f0e1"; } .fa-rotate-left:before, .fa-undo:before { content: "\f0e2"; } .fa-legal:before, .fa-gavel:before { content: "\f0e3"; } .fa-dashboard:before, .fa-tachometer:before { content: "\f0e4"; } .fa-comment-o:before { content: "\f0e5"; } .fa-comments-o:before { content: "\f0e6"; } .fa-flash:before, .fa-bolt:before { content: "\f0e7"; } .fa-sitemap:before { content: "\f0e8"; } .fa-umbrella:before { content: "\f0e9"; } .fa-paste:before, .fa-clipboard:before { content: "\f0ea"; } .fa-lightbulb-o:before { content: "\f0eb"; } .fa-exchange:before { content: "\f0ec"; } .fa-cloud-download:before { content: "\f0ed"; } .fa-cloud-upload:before { content: "\f0ee"; } .fa-user-md:before { content: "\f0f0"; } .fa-stethoscope:before { content: "\f0f1"; } .fa-suitcase:before { content: "\f0f2"; } .fa-bell-o:before { content: "\f0a2"; } .fa-coffee:before { content: "\f0f4"; } .fa-cutlery:before { content: "\f0f5"; } .fa-file-text-o:before { content: "\f0f6"; } .fa-building-o:before { content: "\f0f7"; } .fa-hospital-o:before { content: "\f0f8"; } .fa-ambulance:before { content: "\f0f9"; } .fa-medkit:before { content: "\f0fa"; } .fa-fighter-jet:before { content: "\f0fb"; } 
.fa-beer:before { content: "\f0fc"; } .fa-h-square:before { content: "\f0fd"; } .fa-plus-square:before { content: "\f0fe"; } .fa-angle-double-left:before { content: "\f100"; } .fa-angle-double-right:before { content: "\f101"; } .fa-angle-double-up:before { content: "\f102"; } .fa-angle-double-down:before { content: "\f103"; } .fa-angle-left:before { content: "\f104"; } .fa-angle-right:before { content: "\f105"; } .fa-angle-up:before { content: "\f106"; } .fa-angle-down:before { content: "\f107"; } .fa-desktop:before { content: "\f108"; } .fa-laptop:before { content: "\f109"; } .fa-tablet:before { content: "\f10a"; } .fa-mobile-phone:before, .fa-mobile:before { content: "\f10b"; } .fa-circle-o:before { content: "\f10c"; } .fa-quote-left:before { content: "\f10d"; } .fa-quote-right:before { content: "\f10e"; } .fa-spinner:before { content: "\f110"; } .fa-circle:before { content: "\f111"; } .fa-mail-reply:before, .fa-reply:before { content: "\f112"; } .fa-github-alt:before { content: "\f113"; } .fa-folder-o:before { content: "\f114"; } .fa-folder-open-o:before { content: "\f115"; } .fa-smile-o:before { content: "\f118"; } .fa-frown-o:before { content: "\f119"; } .fa-meh-o:before { content: "\f11a"; } .fa-gamepad:before { content: "\f11b"; } .fa-keyboard-o:before { content: "\f11c"; } .fa-flag-o:before { content: "\f11d"; } .fa-flag-checkered:before { content: "\f11e"; } .fa-terminal:before { content: "\f120"; } .fa-code:before { content: "\f121"; } .fa-reply-all:before { content: "\f122"; } .fa-mail-reply-all:before { content: "\f122"; } .fa-star-half-empty:before, .fa-star-half-full:before, .fa-star-half-o:before { content: "\f123"; } .fa-location-arrow:before { content: "\f124"; } .fa-crop:before { content: "\f125"; } .fa-code-fork:before { content: "\f126"; } .fa-unlink:before, .fa-chain-broken:before { content: "\f127"; } .fa-question:before { content: "\f128"; } .fa-info:before { content: "\f129"; } .fa-exclamation:before { content: "\f12a"; } 
.fa-superscript:before { content: "\f12b"; } .fa-subscript:before { content: "\f12c"; } .fa-eraser:before { content: "\f12d"; } .fa-puzzle-piece:before { content: "\f12e"; } .fa-microphone:before { content: "\f130"; } .fa-microphone-slash:before { content: "\f131"; } .fa-shield:before { content: "\f132"; } .fa-calendar-o:before { content: "\f133"; } .fa-fire-extinguisher:before { content: "\f134"; } .fa-rocket:before { content: "\f135"; } .fa-maxcdn:before { content: "\f136"; } .fa-chevron-circle-left:before { content: "\f137"; } .fa-chevron-circle-right:before { content: "\f138"; } .fa-chevron-circle-up:before { content: "\f139"; } .fa-chevron-circle-down:before { content: "\f13a"; } .fa-html5:before { content: "\f13b"; } .fa-css3:before { content: "\f13c"; } .fa-anchor:before { content: "\f13d"; } .fa-unlock-alt:before { content: "\f13e"; } .fa-bullseye:before { content: "\f140"; } .fa-ellipsis-h:before { content: "\f141"; } .fa-ellipsis-v:before { content: "\f142"; } .fa-rss-square:before { content: "\f143"; } .fa-play-circle:before { content: "\f144"; } .fa-ticket:before { content: "\f145"; } .fa-minus-square:before { content: "\f146"; } .fa-minus-square-o:before { content: "\f147"; } .fa-level-up:before { content: "\f148"; } .fa-level-down:before { content: "\f149"; } .fa-check-square:before { content: "\f14a"; } .fa-pencil-square:before { content: "\f14b"; } .fa-external-link-square:before { content: "\f14c"; } .fa-share-square:before { content: "\f14d"; } .fa-compass:before { content: "\f14e"; } .fa-toggle-down:before, .fa-caret-square-o-down:before { content: "\f150"; } .fa-toggle-up:before, .fa-caret-square-o-up:before { content: "\f151"; } .fa-toggle-right:before, .fa-caret-square-o-right:before { content: "\f152"; } .fa-euro:before, .fa-eur:before { content: "\f153"; } .fa-gbp:before { content: "\f154"; } .fa-dollar:before, .fa-usd:before { content: "\f155"; } .fa-rupee:before, .fa-inr:before { content: "\f156"; } .fa-cny:before, .fa-rmb:before, 
.fa-yen:before, .fa-jpy:before { content: "\f157"; } .fa-ruble:before, .fa-rouble:before, .fa-rub:before { content: "\f158"; } .fa-won:before, .fa-krw:before { content: "\f159"; } .fa-bitcoin:before, .fa-btc:before { content: "\f15a"; } .fa-file:before { content: "\f15b"; } .fa-file-text:before { content: "\f15c"; } .fa-sort-alpha-asc:before { content: "\f15d"; } .fa-sort-alpha-desc:before { content: "\f15e"; } .fa-sort-amount-asc:before { content: "\f160"; } .fa-sort-amount-desc:before { content: "\f161"; } .fa-sort-numeric-asc:before { content: "\f162"; } .fa-sort-numeric-desc:before { content: "\f163"; } .fa-thumbs-up:before { content: "\f164"; } .fa-thumbs-down:before { content: "\f165"; } .fa-youtube-square:before { content: "\f166"; } .fa-youtube:before { content: "\f167"; } .fa-xing:before { content: "\f168"; } .fa-xing-square:before { content: "\f169"; } .fa-youtube-play:before { content: "\f16a"; } .fa-dropbox:before { content: "\f16b"; } .fa-stack-overflow:before { content: "\f16c"; } .fa-instagram:before { content: "\f16d"; } .fa-flickr:before { content: "\f16e"; } .fa-adn:before { content: "\f170"; } .fa-bitbucket:before { content: "\f171"; } .fa-bitbucket-square:before { content: "\f172"; } .fa-tumblr:before { content: "\f173"; } .fa-tumblr-square:before { content: "\f174"; } .fa-long-arrow-down:before { content: "\f175"; } .fa-long-arrow-up:before { content: "\f176"; } .fa-long-arrow-left:before { content: "\f177"; } .fa-long-arrow-right:before { content: "\f178"; } .fa-apple:before { content: "\f179"; } .fa-windows:before { content: "\f17a"; } .fa-android:before { content: "\f17b"; } .fa-linux:before { content: "\f17c"; } .fa-dribbble:before { content: "\f17d"; } .fa-skype:before { content: "\f17e"; } .fa-foursquare:before { content: "\f180"; } .fa-trello:before { content: "\f181"; } .fa-female:before { content: "\f182"; } .fa-male:before { content: "\f183"; } .fa-gittip:before { content: "\f184"; } .fa-sun-o:before { content: "\f185"; } 
.fa-moon-o:before { content: "\f186"; } .fa-archive:before { content: "\f187"; } .fa-bug:before { content: "\f188"; } .fa-vk:before { content: "\f189"; } .fa-weibo:before { content: "\f18a"; } .fa-renren:before { content: "\f18b"; } .fa-pagelines:before { content: "\f18c"; } .fa-stack-exchange:before { content: "\f18d"; } .fa-arrow-circle-o-right:before { content: "\f18e"; } .fa-arrow-circle-o-left:before { content: "\f190"; } .fa-toggle-left:before, .fa-caret-square-o-left:before { content: "\f191"; } .fa-dot-circle-o:before { content: "\f192"; } .fa-wheelchair:before { content: "\f193"; } .fa-vimeo-square:before { content: "\f194"; } .fa-turkish-lira:before, .fa-try:before { content: "\f195"; } .fa-plus-square-o:before { content: "\f196"; } ================================================ FILE: backend/tests/integration/tests/pruning/website/css/style.css ================================================ /* Author URI: http://webthemez.com/ Note: Licence under Creative Commons Attribution 3.0 Do not remove the back-link in this web template -------------------------------------------------------*/ @import url("http://fonts.googleapis.com/css?family=Noto+Serif:400,400italic,700|Open+Sans:400,600,700"); @import url("font-awesome.css"); @import url("animate.css"); body { font-family: "Open Sans", Arial, sans-serif; font-size: 14px; font-weight: 300; line-height: 1.6em; color: #656565; } a:active { outline: 0; } .clear { clear: both; } h1, h2, h3, h4, h5, h6 { font-family: "Open Sans", Arial, sans-serif; font-weight: 700; line-height: 1.1em; color: #333; margin-bottom: 20px; } .container { padding: 0 20px 0 20px; position: relative; } #wrapper { width: 100%; margin: 0; padding: 0; } .row, .row-fluid { margin-bottom: 30px; } .row .row, .row-fluid .row-fluid { margin-bottom: 30px; } .row.nomargin, .row-fluid.nomargin { margin-bottom: 0; } img.img-polaroid { margin: 0 0 20px 0; } .img-box { max-width: 100%; } /* Header ==================================== */ header 
.navbar { margin-bottom: 0; } .navbar-default { border: none; } .navbar-brand { color: #222; text-transform: uppercase; font-size: 24px; font-weight: 700; line-height: 1em; letter-spacing: -1px; margin-top: 13px; padding: 0 0 0 15px; } .navbar-default .navbar-brand { color: #61b331; } header .navbar-collapse ul.navbar-nav { float: right; margin-right: 0; } header .navbar-default { background-color: #ffffff; } header .nav li a:hover, header .nav li a:focus, header .nav li.active a, header .nav li.active a:hover, header .nav li a.dropdown-toggle:hover, header .nav li a.dropdown-toggle:focus, header .nav li.active ul.dropdown-menu li a:hover, header .nav li.active ul.dropdown-menu li.active a { -webkit-transition: all 0.3s ease; -moz-transition: all 0.3s ease; -ms-transition: all 0.3s ease; -o-transition: all 0.3s ease; transition: all 0.3s ease; } header .navbar-default .navbar-nav > .open > a, header .navbar-default .navbar-nav > .open > a:hover, header .navbar-default .navbar-nav > .open > a:focus { -webkit-transition: all 0.3s ease; -moz-transition: all 0.3s ease; -ms-transition: all 0.3s ease; -o-transition: all 0.3s ease; transition: all 0.3s ease; } header .navbar { min-height: 70px; padding: 18px 0; } header .navbar-nav > li { padding-bottom: 12px; padding-top: 12px; } header .navbar-nav > li > a { padding-bottom: 6px; padding-top: 5px; margin-left: 2px; line-height: 30px; font-weight: 700; -webkit-transition: all 0.3s ease; -moz-transition: all 0.3s ease; -ms-transition: all 0.3s ease; -o-transition: all 0.3s ease; transition: all 0.3s ease; } .dropdown-menu li a:hover { color: #fff !important; } header .nav .caret { border-bottom-color: #f5f5f5; border-top-color: #f5f5f5; } .navbar-default .navbar-nav > .active > a, .navbar-default .navbar-nav > .active > a:hover, .navbar-default .navbar-nav > .active > a:focus { background-color: #fff; } .navbar-default .navbar-nav > .open > a, .navbar-default .navbar-nav > .open > a:hover, .navbar-default .navbar-nav > 
.open > a:focus { background-color: #fff; } .dropdown-menu { box-shadow: none; border-radius: 0; border: none; } .dropdown-menu li:last-child { padding-bottom: 0 !important; margin-bottom: 0; } header .nav li .dropdown-menu { padding: 0; } header .nav li .dropdown-menu li a { line-height: 28px; padding: 3px 12px; } .item-thumbs img { margin-bottom: 15px; } .flex-control-paging li a.flex-active { background: #000; background: rgb(255, 255, 255); cursor: default; } .flex-control-paging li a { width: 30px; height: 11px; display: block; background: #666; background: rgba(0, 0, 0, 0.5); cursor: pointer; text-indent: -9999px; -webkit-border-radius: 20px; -moz-border-radius: 20px; -o-border-radius: 20px; border-radius: 20px; box-shadow: inset 0 0 3px rgba(0, 0, 0, 0.3); } .panel-title > a { color: inherit; color: #fff; } .panel-group .panel-heading + .panel-collapse .panel-body { border-top: 1px solid #ddd; color: #fff; background-color: #9c9c9c; } /* --- menu --- */ header .navigation { float: right; } header ul.nav li { border: none; margin: 0; } header ul.nav li a { font-size: 12px; border: none; font-weight: 700; text-transform: uppercase; } header ul.nav li ul li a { font-size: 12px; border: none; font-weight: 300; text-transform: uppercase; } .navbar .nav > li > a { color: #848484; text-shadow: none; border: 1px solid rgba(255, 255, 255, 0) !important; } .navbar .nav a:hover { background: none; color: #14a085 !important; } .navbar .nav > .active > a, .navbar .nav > .active > a:hover { background: none; font-weight: 700; } .navbar .nav > .active > a:active, .navbar .nav > .active > a:focus { background: none; outline: 0; font-weight: 700; } .navbar .nav li .dropdown-menu { z-index: 2000; } header ul.nav li ul { margin-top: 1px; } header ul.nav li ul li ul { margin: 1px 0 0 1px; } .dropdown-menu .dropdown i { position: absolute; right: 0; margin-top: 3px; padding-left: 20px; } .navbar .nav > li > .dropdown-menu:before { display: inline-block; border-right: none; 
border-bottom: none; border-left: none; border-bottom-color: none; content: none; } .navbar-default .navbar-nav > .active > a, .navbar-default .navbar-nav > .active > a:hover, .navbar-default .navbar-nav > .active > a:focus { color: #14a085; } ul.nav li.dropdown a { z-index: 1000; display: block; } select.selectmenu { display: none; } .pageTitle { color: #fff; margin: 30px 0 3px; display: inline-block; } #featured { width: 100%; background: #000; position: relative; margin: 0; padding: 0; } /* Sliders ==================================== */ /* --- flexslider --- */ #featured .flexslider { padding: 0; background: #fff; position: relative; zoom: 1; } .flex-direction-nav .flex-prev { left: 0px; } .flex-direction-nav .flex-next { right: 0px; } .flex-caption { zoom: 0; color: #1c1d21; margin: 0 auto; padding: 1px; position: absolute; vertical-align: bottom; text-align: center; background-color: rgba(255, 255, 255, 0.26); bottom: 5%; display: block; left: 0; right: 0; } .flex-caption h3 { color: #fff; letter-spacing: 1px; margin-bottom: 8px; text-transform: uppercase; } .flex-caption p { margin: 0 0 15px; } .skill-home { margin-bottom: 50px; } .c1 { border: #ed5441 1px solid; background: #ed5441; } .c2 { border: #d867b2 1px solid; background: #d867b2; } .c3 { border: #61b331 1px solid; background: #4bc567; } .c4 { border: #609cec 1px solid; background: #26aff0; } .skill-home .icons { padding: 33px 0 0 0; width: 100%; height: 178px; color: rgb(255, 255, 255); font-size: 42px; font-size: 76px; text-align: center; -ms-border-radius: 50%; -moz-border-radius: 50%; -webkit-border-radius: 50%; border-radius: 0; display: inline-table; } .skill-home h2 { padding-top: 20px; font-size: 36px; font-weight: 700; } .testimonial-solid { padding: 50px 0 60px 0; margin: 0 0 0 0; background: #efefef; text-align: center; } .testi-icon-area { text-align: center; position: absolute; top: -84px; margin: 0 auto; width: 100%; color: #000; } .testi-icon-area .quote { padding: 15px 0 0 0; margin: 
0 0 0 0; background: #ffffff; text-align: center; color: #26aff0; display: inline-table; width: 70px; height: 70px; -ms-border-radius: 50%; -moz-border-radius: 50%; -webkit-border-radius: 50%; border-radius: 0; font-size: 42px; border: 1px solid #26aff0; display: none; } .testi-icon-area .carousel-inner { margin: 20px 0; } .carousel-indicators { bottom: -30px; } .team-member { text-align: center; background-color: #f9f9f9; padding-bottom: 15px; } .fancybox-title-inside-wrap { padding: 3px 30px 6px; background: #292929; } .item_introtext { background-color: rgba(254, 254, 255, 0.66); margin: 0 auto; display: inline-block; padding: 25px; } .item_introtext span { font-size: 20px; display: block; font-weight: bold; } .item_introtext strong { font-size: 50px; display: block; padding: 14px 0 30px; } .item_introtext p { font-size: 20px !important; color: #1c1d21; font-weight: bold; } .form-control { border-radius: 0; } /* Testimonial ----------------------------------*/ .testimonial-area { padding: 0 0 0 0; margin: 0; background: url(../img/low-poly01.jpg) fixed center center; background-size: cover; -webkit-background-size: cover; -moz-background-size: cover; -ms-background-size: cover; color: red; } .testimonial-solid p { color: #1f1f1f; font-size: 16px; line-height: 30px; font-style: italic; } section.callaction { background: #fff; padding: 50px 0 0 0; } /* Content ==================================== */ #content { position: relative; background: #fff; padding: 50px 0 0px 0; } #content img { max-width: 100%; height: auto; } .cta-text { text-align: center; margin-top: 10px; } .big-cta .cta { margin-top: 10px; } .box { width: 100%; } .box-gray { background: #f8f8f8; padding: 20px 20px 30px; } .box-gray h4, .box-gray i { margin-bottom: 20px; } .box-bottom { padding: 20px 0; text-align: center; } .box-bottom a { color: #fff; font-weight: 700; } .box-bottom a:hover { color: #eee; text-decoration: none; } /* Bottom ==================================== */ #bottom { 
background: #fcfcfc; padding: 50px 0 0; } /* twitter */ #twitter-wrapper { text-align: center; width: 70%; margin: 0 auto; } #twitter em { font-style: normal; font-size: 13px; } #twitter em.twitterTime a { font-weight: 600; } #twitter ul { padding: 0; list-style: none; } #twitter ul li { font-size: 20px; line-height: 1.6em; font-weight: 300; margin-bottom: 20px; position: relative; word-break: break-word; } /* page headline ==================================== */ #inner-headline { background: #14a085; position: relative; margin: 0; padding: 0; color: #fefefe; /* margin: 15px; */ border-top: 10px solid #11967c; } #inner-headline .inner-heading h2 { color: #fff; margin: 20px 0 0 0; } /* --- breadcrumbs --- */ #inner-headline ul.breadcrumb { margin: 30px 0 0; float: left; } #inner-headline ul.breadcrumb li { margin-bottom: 0; padding-bottom: 0; } #inner-headline ul.breadcrumb li { font-size: 13px; color: #fff; } #inner-headline ul.breadcrumb li i { color: #dedede; } #inner-headline ul.breadcrumb li a { color: #fff; } ul.breadcrumb li a:hover { text-decoration: none; } /* Forms ============================= */ /* --- contact form ---- */ form#contactform input[type="text"] { width: 100%; border: 1px solid #f5f5f5; min-height: 40px; padding-left: 20px; font-size: 13px; padding-right: 20px; -webkit-box-sizing: border-box; -moz-box-sizing: border-box; box-sizing: border-box; } form#contactform textarea { border: 1px solid #f5f5f5; width: 100%; padding-left: 20px; padding-top: 10px; font-size: 13px; padding-right: 20px; -webkit-box-sizing: border-box; -moz-box-sizing: border-box; box-sizing: border-box; } form#contactform .validation { font-size: 11px; } #sendmessage { border: 1px solid #e6e6e6; background: #f6f6f6; display: none; text-align: center; padding: 15px 12px 15px 65px; margin: 10px 0; font-weight: 600; margin-bottom: 30px; } #sendmessage.show, .show { display: block; } form#commentform input[type="text"] { width: 100%; min-height: 40px; padding-left: 20px; 
font-size: 13px; padding-right: 20px; -webkit-box-sizing: border-box; -moz-box-sizing: border-box; box-sizing: border-box; -webkit-border-radius: 2px 2px 2px 2px; -moz-border-radius: 2px 2px 2px 2px; border-radius: 2px 2px 2px 2px; } form#commentform textarea { width: 100%; padding-left: 20px; padding-top: 10px; font-size: 13px; padding-right: 20px; -webkit-box-sizing: border-box; -moz-box-sizing: border-box; box-sizing: border-box; -webkit-border-radius: 2px 2px 2px 2px; -moz-border-radius: 2px 2px 2px 2px; border-radius: 2px 2px 2px 2px; } /* --- search form --- */ .search { float: right; margin: 35px 0 0; padding-bottom: 0; } #inner-headline form.input-append { margin: 0; padding: 0; } /* Portfolio ================================ */ .work-nav #filters { margin: 0; padding: 0; list-style: none; } .work-nav #filters li { margin: 0 10px 30px 0; padding: 0; float: left; } .work-nav #filters li a { color: #7f8289; font-size: 16px; display: block; } .work-nav #filters li a:hover { } .work-nav #filters li a.selected { color: #de5e60; } #thumbs { margin: 0; padding: 0; } #thumbs li { list-style-type: none; } .item-thumbs { position: relative; overflow: hidden; margin-bottom: 30px; cursor: pointer; } .item-thumbs a + img { width: 100%; } .item-thumbs .hover-wrap { position: absolute; display: block; width: 100%; height: 100%; opacity: 0; filter: alpha(opacity=0); -webkit-transition: all 450ms ease-out 0s; -moz-transition: all 450ms ease-out 0s; -o-transition: all 450ms ease-out 0s; transition: all 450ms ease-out 0s; -webkit-transform: rotateY(180deg) scale(0.5, 0.5); -moz-transform: rotateY(180deg) scale(0.5, 0.5); -ms-transform: rotateY(180deg) scale(0.5, 0.5); -o-transform: rotateY(180deg) scale(0.5, 0.5); transform: rotateY(180deg) scale(0.5, 0.5); } .item-thumbs:hover .hover-wrap, .item-thumbs.active .hover-wrap { opacity: 1; filter: alpha(opacity=100); -webkit-transform: rotateY(0deg) scale(1, 1); -moz-transform: rotateY(0deg) scale(1, 1); -ms-transform: 
rotateY(0deg) scale(1, 1); -o-transform: rotateY(0deg) scale(1, 1); transform: rotateY(0deg) scale(1, 1); } .item-thumbs .hover-wrap .overlay-img { position: absolute; width: 90%; height: 91%; opacity: 0.5; filter: alpha(opacity=80); background: #14a085; } .item-thumbs .hover-wrap .overlay-img-thumb { position: absolute; border-radius: 60px; top: 50%; left: 45%; margin: -16px 0 0 -16px; color: #fff; font-size: 32px; line-height: 1em; opacity: 1; filter: alpha(opacity=100); } ul.portfolio-categ { margin: 10px 0 30px 0; padding: 0; float: left; list-style: none; } ul.portfolio-categ li { margin: 0; float: left; list-style: none; font-size: 13px; font-weight: 600; border: 1px solid #d5d5d5; margin-right: 15px; } ul.portfolio-categ li a { display: block; padding: 8px 20px; color: #14a085; } ul.portfolio-categ li.active { border: 1px solid #d7d8d6; background-color: #eaeaea; } ul.portfolio-categ li.active a:hover, ul.portfolio-categ li a:hover, ul.portfolio-categ li a:focus, ul.portfolio-categ li a:active { text-decoration: none; outline: 0; } #accordion-alt3 .panel-heading h4 { font-size: 13px; line-height: 28px; color: #6b6b6b; } .panel .panel-heading h4 { font-weight: 400; } .panel-title { margin-top: 0; margin-bottom: 0; font-size: 15px; color: inherit; } .panel-group .panel { margin-bottom: 0; border-radius: 2px; } .panel { margin-bottom: 18px; background-color: #b9b9b9; border: 1px solid transparent; border-radius: 2px; -webkit-box-shadow: 0 1px 1px rgba(0, 0, 0, 0.05); box-shadow: 0 1px 1px rgba(0, 0, 0, 0.05); } #accordion-alt3 .panel-heading h4 a i { font-size: 13px; line-height: 18px; width: 18px; height: 18px; margin-right: 5px; color: #fff; text-align: center; border-radius: 50%; margin-left: 6px; } .progress.pb-sm { height: 6px !important; } .progress { box-shadow: inset 0 0 2px rgba(0, 0, 0, 0.1); } .progress { overflow: hidden; height: 18px; margin-bottom: 18px; background-color: #f5f5f5; border-radius: 2px; -webkit-box-shadow: inset 0 1px 2px rgba(0, 0, 
0, 0.1); box-shadow: inset 0 1px 2px rgba(0, 0, 0, 0.1); } .progress .progress-bar.progress-bar-red { background: #ed5441; } .progress .progress-bar.progress-bar-green { background: #51d466; } .progress .progress-bar.progress-bar-lblue { background: #32c8de; } /* --- portfolio detail --- */ .top-wrapper { margin-bottom: 20px; } .info-blocks { margin-bottom: 15px; } .info-blocks i.icon-info-blocks { float: left; color: #318fcf; font-size: 30px; min-width: 50px; margin-top: 6px; text-align: center; background-color: #efefef; padding: 15px; } .info-blocks .info-blocks-in { padding: 0 10px; overflow: hidden; } .info-blocks .info-blocks-in h3 { color: #555; font-size: 20px; line-height: 28px; margin: 0px; } .info-blocks .info-blocks-in p { font-size: 12px; } blockquote { font-size: 16px; font-weight: 400; font-family: "Noto Serif", serif; font-style: italic; padding-left: 0; color: #a2a2a2; line-height: 1.6em; border: none; } blockquote cite { display: block; font-size: 12px; color: #666; margin-top: 10px; } blockquote cite:before { content: "\2014 \0020"; } blockquote cite a, blockquote cite a:visited, blockquote cite a:visited { color: #555; } /* --- pullquotes --- */ .pullquote-left { display: block; color: #a2a2a2; font-family: "Noto Serif", serif; font-size: 14px; line-height: 1.6em; padding-left: 20px; } .pullquote-right { display: block; color: #a2a2a2; font-family: "Noto Serif", serif; font-size: 14px; line-height: 1.6em; padding-right: 20px; } /* --- button --- */ .btn { text-align: center; background: #318cca; color: #fff; border-radius: 0; padding: 10px 30px; } .btn-theme { color: #fff; } .btn-theme:hover { color: #eee; } /* --- list style --- */ ul.general { list-style: none; margin-left: 0; } ul.link-list { margin: 0; padding: 0; list-style: none; } ul.link-list li { margin: 0; padding: 2px 0 2px 0; list-style: none; } footer { background: #14a085; } footer ul.link-list li a { color: #ffffff; } footer ul.link-list li a:hover { color: #e2e2e2; } /* --- 
Heading style --- */ h4.heading { font-weight: 700; } .heading { margin-bottom: 30px; } .heading { position: relative; } .widgetheading { width: 100%; padding: 0; } #bottom .widgetheading { position: relative; border-bottom: #e6e6e6 1px solid; padding-bottom: 9px; } aside .widgetheading { position: relative; border-bottom: #e9e9e9 1px solid; padding-bottom: 9px; } footer .widgetheading { position: relative; } footer .widget .social-network { position: relative; } #bottom .widget .widgetheading span, aside .widget .widgetheading span, footer .widget .widgetheading span { position: absolute; width: 60px; height: 1px; bottom: -1px; right: 0; } .box-area { border: 1px solid #f3f3f3; padding: 0 15px 12px; padding-top: 41px; margin-top: -42px; text-align: left; background-color: #f9f9f9; position: relative; } /* --- Map --- */ .map { position: relative; margin-top: -50px; margin-bottom: 40px; } .map iframe { width: 100%; height: 450px; border: none; } .map-grid iframe { width: 100%; height: 350px; border: none; margin: 0 0 -5px 0; padding: 0; } ul.team-detail { margin: -10px 0 0 0; padding: 0; list-style: none; } ul.team-detail li { border-bottom: 1px dotted #e9e9e9; margin: 0 0 15px 0; padding: 0 0 15px 0; list-style: none; } ul.team-detail li label { font-size: 13px; } ul.team-detail li h4, ul.team-detail li label { margin-bottom: 0; } ul.team-detail li ul.social-network { border: none; margin: 0; padding: 0; } ul.team-detail li ul.social-network li { border: none; margin: 0; } ul.team-detail li ul.social-network li i { margin: 0; } .pricing-title { background: #fff; text-align: center; padding: 10px 0 10px 0; } .pricing-title h3 { font-weight: 600; margin-bottom: 0; } .pricing-offer { background: #fcfcfc; text-align: center; padding: 40px 0 40px 0; font-size: 18px; border-top: 1px solid #e6e6e6; border-bottom: 1px solid #e6e6e6; } .pricing-box.activeItem .pricing-offer { color: #fff; } .pricing-offer strong { font-size: 78px; line-height: 89px; } .pricing-offer sup { 
font-size: 28px; } .pricing-container { background: #fff; text-align: center; font-size: 14px; } .pricing-container strong { color: #353535; } .pricing-container ul { list-style: none; padding: 0; margin: 0; } .pricing-container ul li { border-bottom: 1px solid #f5f5f5; list-style: none; padding: 15px 0 15px 0; margin: 0 0 0 0; color: #222; } .pricing-action { margin: 0; background: #fcfcfc; text-align: center; padding: 20px 0 30px 0; } .pricing-wrapp { margin: 0 auto; width: 100%; background: #fd0000; } .pricing-box-item { border: 1px solid #f5f5f5; background: #f9f9f9; position: relative; margin: 0 0 20px 0; padding: 0; -webkit-box-shadow: 0 2px 0 rgba(0, 0, 0, 0.03); -moz-box-shadow: 0 2px 0 rgba(0, 0, 0, 0.03); box-shadow: 0 2px 0 rgba(0, 0, 0, 0.03); -webkit-box-sizing: border-box; -moz-box-sizing: border-box; box-sizing: border-box; } .pricing-box-item .pricing-heading { text-align: center; padding: 0px 0 0px 0; display: block; } .pricing-box-item.activeItem .pricing-heading { text-align: center; padding: 0px 0 1px 0; border-bottom: none; display: block; color: #fff; } .pricing-box-item.activeItem .pricing-heading h3 { } .pricing-box-item .pricing-heading h3 strong { font-size: 20px; font-weight: 700; letter-spacing: -1px; } .pricing-box-item .pricing-heading h3 { font-size: 35px; font-weight: 300; letter-spacing: -1px; } .pricing-box-item .pricing-terms { text-align: center; display: block; overflow: hidden; padding: 11px 0 5px; } .pricing-box-item .pricing-terms h6 { font-style: italic; margin-top: 10px; color: #14a085; font-size: 22px; font-family: "Noto Serif", serif; } .pricing-box-item .icon .price-circled { margin: 10px 10px 10px 0; display: inline-block !important; text-align: center !important; color: #fff; width: 68px; height: 68px; padding: 12px; font-size: 16px; font-weight: 700; line-height: 68px; text-shadow: none; cursor: pointer; background-color: #888; border-radius: 64px; -moz-border-radius: 64px; -webkit-border-radius: 64px; } 
.pricing-box-item .pricing-action { margin: 0; text-align: center; padding: 30px 0 30px 0; } /* ===== Widgets ===== */ /* --- flickr --- */ .widget .flickr_badge { width: 100%; } .widget .flickr_badge img { margin: 0 9px 20px 0; } footer .widget .flickr_badge { width: 100%; } footer .widget .flickr_badge img { margin: 0 9px 20px 0; } .flickr_badge img { width: 50px; height: 50px; float: left; margin: 0 9px 20px 0; } /* --- Recent post widget --- */ .recent-post { margin: 20px 0 0 0; padding: 0; line-height: 18px; } .recent-post h5 a:hover { text-decoration: none; } .recent-post .text h5 a { color: #353535; } footer { padding: 50px 0 0 0; color: #f8f8f8; } footer a { color: #fff; } footer a:hover { color: #eee; } footer h1, footer h2, footer h3, footer h4, footer h5, footer h6 { color: #fff; } footer address { line-height: 1.6em; color: #ffffff; } footer h5 a:hover, footer a:hover { text-decoration: none; } ul.social-network { list-style: none; margin: 0; } ul.social-network li { display: inline; margin: 0 5px; } #sub-footer { text-shadow: none; color: #f5f5f5; padding: 0; padding-top: 30px; margin: 20px 0 0 0; background: #14a085; } #sub-footer p { margin: 0; padding: 0; } #sub-footer span { color: #f5f5f5; } .copyright { text-align: left; font-size: 12px; } #sub-footer ul.social-network { float: right; } /* scroll to top */ .scrollup { position: fixed; width: 32px; height: 32px; bottom: 0px; right: 20px; background: #222; } a.scrollup { outline: 0; text-align: center; } a.scrollup:hover, a.scrollup:active, a.scrollup:focus { opacity: 1; text-decoration: none; } a.scrollup i { margin-top: 10px; color: #fff; } a.scrollup i:hover { text-decoration: none; } .absolute { position: absolute; } .relative { position: relative; } .aligncenter { text-align: center; } .aligncenter span { margin-left: 0; } .floatright { float: right; } .floatleft { float: left; } .floatnone { float: none; } .aligncenter { text-align: center; } img.pull-left, .align-left { float: left; margin: 
0 15px 15px 0; } .widget img.pull-left { float: left; margin: 0 15px 15px 0; } img.pull-right, .align-right { float: right; margin: 0 0 15px 15px; } article img.pull-left, article .align-left { float: left; margin: 5px 15px 15px 0; } article img.pull-right, article .align-right { float: right; margin: 5px 0 15px 15px; } /* ============================= */ .clear-marginbot { margin-bottom: 0; } .marginbot10 { margin-bottom: 10px; } .marginbot20 { margin-bottom: 20px; } .marginbot30 { margin-bottom: 30px; } .marginbot40 { margin-bottom: 40px; } .clear-margintop { margin-top: 0; } .margintop10 { margin-top: 10px; } .margintop20 { margin-top: 20px; } .margintop30 { margin-top: 30px; } .margintop40 { margin-top: 40px; } /* Media queries ============================= */ @media (min-width: 768px) and (max-width: 979px) { a.detail { background: none; width: 100%; } footer .widget form input#appendedInputButton { display: block; width: 91%; -webkit-border-radius: 4px 4px 4px 4px; -moz-border-radius: 4px 4px 4px 4px; border-radius: 4px 4px 4px 4px; } footer .widget form .input-append .btn { display: block; width: 100%; padding-right: 0; padding-left: 0; -webkit-box-sizing: border-box; -moz-box-sizing: border-box; box-sizing: border-box; margin-top: 10px; } ul.related-folio li { width: 156px; margin: 0 20px 0 0; } } @media (max-width: 767px) { body { padding-right: 0; padding-left: 0; } .navbar-brand { margin-top: 10px; border-bottom: none; } .navbar-header { margin-top: 20px; border-bottom: none; } .navbar-nav { border-top: none; float: none; width: 100%; } .navbar .nav > .active > a, .navbar .nav > .active > a:hover { background: none; font-weight: 700; color: #26aff0; } header .navbar-nav > li { padding-bottom: 0px; padding-top: 2px; } header .nav li .dropdown-menu { margin-top: 0; } .dropdown-menu { position: absolute; top: 0; left: 40px; z-index: 1000; display: none; float: left; min-width: 160px; padding: 5px 0; margin: 2px 0 0; font-size: 13px; list-style: none; 
background-color: #fff; background-clip: padding-box; border: 1px solid #f5f5f5; border: 1px solid rgba(0, 0, 0, 0.15); border-radius: 0; -webkit-box-shadow: 0 6px 12px rgba(0, 0, 0, 0.175); box-shadow: 0 6px 12px rgba(0, 0, 0, 0.175); } .navbar-collapse.collapse { border: none; overflow: hidden; } .box { border-bottom: 1px solid #e9e9e9; padding-bottom: 20px; } #featured .flexslider .slide-caption { width: 90%; padding: 2%; position: absolute; left: 0; bottom: -40px; } #inner-headline .breadcrumb { float: left; clear: both; width: 100%; } .breadcrumb > li { font-size: 13px; } ul.portfolio li article a i.icon-48 { width: 20px; height: 20px; font-size: 16px; line-height: 20px; } .left-sidebar { border-right: none; padding: 0 0 0 0; border-bottom: 1px dotted #e6e6e6; padding-bottom: 10px; margin-bottom: 40px; } .right-sidebar { margin-top: 30px; border-left: none; padding: 0 0 0 0; } footer .col-lg-1, footer .col-lg-2, footer .col-lg-3, footer .col-lg-4, footer .col-lg-5, footer .col-lg-6, footer .col-lg-7, footer .col-lg-8, footer .col-lg-9, footer .col-lg-10, footer .col-lg-11, footer .col-lg-12 { margin-bottom: 20px; } #sub-footer ul.social-network { float: left; } [class*="span"] { margin-bottom: 20px; } } @media (max-width: 480px) { .bottom-article a.pull-right { float: left; margin-top: 20px; } .search { float: left; } .flexslider .flex-caption { display: none; } .cta-text { margin: 0 auto; text-align: center; } ul.portfolio li article a i { width: 20px; height: 20px; font-size: 14px; } } .box-area:before { position: absolute; width: 100%; height: 100%; z-index: 0; background-color: red; content: ""; position: absolute; top: 7px; left: -1px; width: 100%; height: 23px; background: #f9f9f9; -moz-transform: skewY(-3deg); -o-transform: skewY(-3deg); -ms-transform: skewY(-3deg); -webkit-transform: skewY(-3deg); transform: skewY(11deg); background-size: cover; } .box-area:after { position: absolute; width: 100%; height: 100%; z-index: 0; background-color: red; 
content: ""; position: absolute; top: 7px; left: 1px; width: 100%; height: 22px; background: #f9f9f9; -moz-transform: skewY(-3deg); -o-transform: skewY(-3deg); -ms-transform: skewY(-3deg); -webkit-transform: skewY(-3deg); transform: skewY(-11deg); background-size: cover; } .box-area h3 { margin-top: -16px; z-index: 12; position: relative; } .courses { padding: 50px 0; } .carousel-indicators li { display: inline-block; border: 1px solid #929292; } .textbox { background-color: #efefef; padding: 4px 25px; } .textbox h3 { margin: 0; padding: 22px 0 14px; font-size: 18px; } ================================================ FILE: backend/tests/integration/tests/pruning/website/index.html ================================================ Above Multi-purpose Free Bootstrap Responsive Template

    Our Featured Courses

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus, vero mollitia velit ad consectetur. Alias, laborum excepturi nihil autem nemo numquam, ipsa architecto non, magni consequuntur quam.

    Web Development

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident

    UI Design

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident

    Interaction

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident

    User Experience

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident

    Courses We Offer

    Lorem ipsum dolor sit amet, consectetur adipisicing elit. Dolores quae porro consequatur aliquam, incidunt eius magni provident, doloribus omnis minus temporibus perferendis nesciunt quam repellendus nulla nemo ipsum odit corrupti consequuntur possimus, vero mollitia velit ad consectetur. Alias, laborum excepturi nihil autem nemo numquam, ipsa architecto non, magni consequuntur quam.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    Heading Course

    Lorem ipsum dolor sit amet, consectetuer adipiscing elit. Praesent vest sit amet, consec ibulum molestie lacus. Aenean nonummy hendrerit mauris. Phasellus porta.

    ================================================ FILE: backend/tests/integration/tests/pruning/website/js/animate.js ================================================ jQuery(document).ready(function ($) { //animate effect $(".e_flash").hover( function () { $(this).addClass("animated flash"); }, function () { $(this).removeClass("animated flash"); }, ); $(".e_bounce").hover( function () { $(this).addClass("animated bounce"); }, function () { $(this).removeClass("animated bounce"); }, ); $(".e_shake").hover( function () { $(this).addClass("animated shake"); }, function () { $(this).removeClass("animated shake"); }, ); $(".e_tada").hover( function () { $(this).addClass("animated tada"); }, function () { $(this).removeClass("animated tada"); }, ); $(".e_swing").hover( function () { $(this).addClass("animated swing"); }, function () { $(this).removeClass("animated swing"); }, ); $(".e_wobble").hover( function () { $(this).addClass("animated wobble"); }, function () { $(this).removeClass("animated wobble"); }, ); $(".e_wiggle").hover( function () { $(this).addClass("animated wiggle"); }, function () { $(this).removeClass("animated wiggle"); }, ); $(".e_pulse").hover( function () { $(this).addClass("animated pulse"); }, function () { $(this).removeClass("animated pulse"); }, ); $(".e_flip").hover( function () { $(this).addClass("animated flip"); }, function () { $(this).removeClass("animated flip"); }, ); $(".e_flipInX").hover( function () { $(this).addClass("animated flipInX"); }, function () { $(this).removeClass("animated flipInX"); }, ); $(".e_flipOutX").hover( function () { $(this).addClass("animated flipOutX"); }, function () { $(this).removeClass("animated flipOutX"); }, ); $(".e_flipInY").hover( function () { $(this).addClass("animated flipInY"); }, function () { $(this).removeClass("animated flipInY"); }, ); $(".e_flipOutY").hover( function () { $(this).addClass("animated flipOutY"); }, function () { $(this).removeClass("animated flipOutY"); }, ); //Fading 
entrances $(".e_fadeIn").hover( function () { $(this).addClass("animated fadeIn"); }, function () { $(this).removeClass("animated fadeIn"); }, ); $(".e_fadeInUp").hover( function () { $(this).addClass("animated fadeInUp"); }, function () { $(this).removeClass("animated fadeInUp"); }, ); $(".e_fadeInDown").hover( function () { $(this).addClass("animated fadeInDown"); }, function () { $(this).removeClass("animated fadeInDown"); }, ); $(".e_fadeInLeft").hover( function () { $(this).addClass("animated fadeInLeft"); }, function () { $(this).removeClass("animated fadeInLeft"); }, ); $(".e_fadeInRight").hover( function () { $(this).addClass("animated fadeInRight"); }, function () { $(this).removeClass("animated fadeInRight"); }, ); $(".e_fadeInUpBig").hover( function () { $(this).addClass("animated fadeInUpBig"); }, function () { $(this).removeClass("animated fadeInUpBig"); }, ); $(".e_fadeInDownBig").hover( function () { $(this).addClass("animated fadeInDownBig"); }, function () { $(this).removeClass("animated fadeInDownBig"); }, ); $(".e_fadeInLeftBig").hover( function () { $(this).addClass("animated fadeInLeftBig"); }, function () { $(this).removeClass("animated fadeInLeftBig"); }, ); $(".e_fadeInRightBig").hover( function () { $(this).addClass("animated fadeInRightBig"); }, function () { $(this).removeClass("animated fadeInRightBig"); }, ); //Fading exits $(".e_fadeOut").hover( function () { $(this).addClass("animated fadeOut"); }, function () { $(this).removeClass("animated fadeOut"); }, ); $(".e_fadeOutUp").hover( function () { $(this).addClass("animated fadeOutUp"); }, function () { $(this).removeClass("animated fadeOutUp"); }, ); $(".e_fadeOutDown").hover( function () { $(this).addClass("animated fadeOutDown"); }, function () { $(this).removeClass("animated fadeOutDown"); }, ); $(".e_fadeOutLeft").hover( function () 
{ $(this).addClass("animated fadeOutLeft"); }, function () { $(this).removeClass("animated fadeOutLeft"); }, ); $(".e_fadeOutRight").hover( function () { $(this).addClass("animated fadeOutRight"); }, function () { $(this).removeClass("animated fadeOutRight"); }, ); $(".e_fadeOutUpBig").hover( function () { $(this).addClass("animated fadeOutUpBig"); }, function () { $(this).removeClass("animated fadeOutUpBig"); }, ); $(".e_fadeOutDownBig").hover( function () { $(this).addClass("animated fadeOutDownBig"); }, function () { $(this).removeClass("animated fadeOutDownBig"); }, ); $(".e_fadeOutLeftBig").hover( function () { $(this).addClass("animated fadeOutLeftBig"); }, function () { $(this).removeClass("animated fadeOutLeftBig"); }, ); $(".e_fadeOutRightBig").hover( function () { $(this).addClass("animated fadeOutRightBig"); }, function () { $(this).removeClass("animated fadeOutRightBig"); }, ); //Bouncing entrances $(".e_bounceIn").hover( function () { $(this).addClass("animated bounceIn"); }, function () { $(this).removeClass("animated bounceIn"); }, ); $(".e_bounceInDown").hover( function () { $(this).addClass("animated bounceInDown"); }, function () { $(this).removeClass("animated bounceInDown"); }, ); $(".e_bounceInUp").hover( function () { $(this).addClass("animated bounceInUp"); }, function () { $(this).removeClass("animated bounceInUp"); }, ); $(".e_bounceInLeft").hover( function () { $(this).addClass("animated bounceInLeft"); }, function () { $(this).removeClass("animated bounceInLeft"); }, ); $(".e_bounceInRight").hover( function () { $(this).addClass("animated bounceInRight"); }, function () { $(this).removeClass("animated bounceInRight"); }, ); //Bouncing exits $(".e_bounceOut").hover( function () { $(this).addClass("animated bounceOut"); }, function () { $(this).removeClass("animated bounceOut"); }, ); $(".e_bounceOutDown").hover( function () { $(this).addClass("animated bounceOutDown"); }, function () { $(this).removeClass("animated bounceOutDown"); }, ); 
$(".e_bounceOutUp").hover( function () { $(this).addClass("animated bounceOutUp"); }, function () { $(this).removeClass("animated bounceOutUp"); }, ); $(".e_bounceOutLeft").hover( function () { $(this).addClass("animated bounceOutLeft"); }, function () { $(this).removeClass("animated bounceOutLeft"); }, ); $(".e_bounceOutRight").hover( function () { $(this).addClass("animated bounceOutRight"); }, function () { $(this).removeClass("animated bounceOutRight"); }, ); //Rotating entrances $(".e_rotateIn").hover( function () { $(this).addClass("animated rotateIn"); }, function () { $(this).removeClass("animated rotateIn"); }, ); $(".e_rotateInDownLeft").hover( function () { $(this).addClass("animated rotateInDownLeft"); }, function () { $(this).removeClass("animated rotateInDownLeft"); }, ); $(".e_rotateInDownRight").hover( function () { $(this).addClass("animated rotateInDownRight"); }, function () { $(this).removeClass("animated rotateInDownRight"); }, ); $(".e_rotateInUpRight").hover( function () { $(this).addClass("animated rotateInUpRight"); }, function () { $(this).removeClass("animated rotateInUpRight"); }, ); $(".e_rotateInUpLeft").hover( function () { $(this).addClass("animated rotateInUpLeft"); }, function () { $(this).removeClass("animated rotateInUpLeft"); }, ); //Rotating exits $(".e_rotateOut").hover( function () { $(this).addClass("animated rotateOut"); }, function () { $(this).removeClass("animated rotateOut"); }, ); $(".e_rotateOutDownLeft").hover( function () { $(this).addClass("animated rotateOutDownLeft"); }, function () { $(this).removeClass("animated rotateOutDownLeft"); }, ); $(".e_rotateOutDownRight").hover( function () { $(this).addClass("animated rotateOutDownRight"); }, function () { $(this).removeClass("animated rotateOutDownRight"); }, ); $(".e_rotateOutUpLeft").hover( function () { $(this).addClass("animated rotateOutUpLeft"); }, function () { $(this).removeClass("animated rotateOutUpLeft"); }, ); $(".e_rotateOutUpRight").hover( function () 
{ $(this).addClass("animated rotateOutUpRight"); }, function () { $(this).removeClass("animated rotateOutUpRight"); }, ); //Lightspeed $(".e_lightSpeedIn").hover( function () { $(this).addClass("animated lightSpeedIn"); }, function () { $(this).removeClass("animated lightSpeedIn"); }, ); $(".e_lightSpeedOut").hover( function () { $(this).addClass("animated lightSpeedOut"); }, function () { $(this).removeClass("animated lightSpeedOut"); }, ); //specials $(".e_hinge").hover( function () { $(this).addClass("animated hinge"); }, function () { $(this).removeClass("animated hinge"); }, ); $(".e_rollIn").hover( function () { $(this).addClass("animated rollIn"); }, function () { $(this).removeClass("animated rollIn"); }, ); $(".e_rollOut").hover( function () { $(this).addClass("animated rollOut"); }, function () { $(this).removeClass("animated rollOut"); }, ); }); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/custom.js ================================================ /*global jQuery:false */ jQuery(document).ready(function ($) { "use strict"; //add some elements with animate effect $(".big-cta").hover( function () { $(".cta a").addClass("animated shake"); }, function () { $(".cta a").removeClass("animated shake"); }, ); $(".box").hover( function () { $(this).find(".icon").addClass("animated fadeInDown"); $(this).find("p").addClass("animated fadeInUp"); }, function () { $(this).find(".icon").removeClass("animated fadeInDown"); $(this).find("p").removeClass("animated fadeInUp"); }, ); $(".accordion").on("show", function (e) { $(e.target) .prev(".accordion-heading") .find(".accordion-toggle") .addClass("active"); $(e.target) .prev(".accordion-heading") .find(".accordion-toggle i") .removeClass("icon-plus"); $(e.target) .prev(".accordion-heading") .find(".accordion-toggle i") .addClass("icon-minus"); }); $(".accordion").on("hide", function (e) { $(this).find(".accordion-toggle").not($(e.target)).removeClass("active"); 
$(this) .find(".accordion-toggle i") .not($(e.target)) .removeClass("icon-minus"); $(this).find(".accordion-toggle i").not($(e.target)).addClass("icon-plus"); }); // tooltip $(".social-network li a, .options_box .color a").tooltip(); // fancybox $(".fancybox").fancybox({ padding: 0, autoResize: true, beforeShow: function () { this.title = $(this.element).attr("title"); this.title = "

    " + this.title + "

    " + "

    " + $(this.element).parent().find("img").attr("alt") + "

    "; }, helpers: { title: { type: "inside" }, }, }); //scroll to top $(window).scroll(function () { if ($(this).scrollTop() > 100) { $(".scrollup").fadeIn(); } else { $(".scrollup").fadeOut(); } }); $(".scrollup").click(function () { $("html, body").animate({ scrollTop: 0 }, 1000); return false; }); $("#post-slider").flexslider({ // Primary Controls controlNav: false, //Boolean: Create navigation for paging control of each slide? Note: Leave true for manualControls usage directionNav: true, //Boolean: Create navigation for previous/next navigation? (true/false) prevText: "Previous", //String: Set the text for the "previous" directionNav item nextText: "Next", //String: Set the text for the "next" directionNav item // Secondary Navigation keyboard: true, //Boolean: Allow slider navigating via keyboard left/right keys multipleKeyboard: false, //{NEW} Boolean: Allow keyboard navigation to affect multiple sliders. Default behavior cuts out keyboard navigation with more than one slider present. mousewheel: false, //{UPDATED} Boolean: Requires jquery.mousewheel.js (https://github.com/brandonaaron/jquery-mousewheel) - Allows slider navigating via mousewheel pausePlay: false, //Boolean: Create pause/play dynamic element pauseText: "Pause", //String: Set the text for the "pause" pausePlay item playText: "Play", //String: Set the text for the "play" pausePlay item // Special properties controlsContainer: "", //{UPDATED} Selector: USE CLASS SELECTOR. Declare which container the navigation elements should be appended to. Default container is the FlexSlider element. Example use would be ".flexslider-container". Property is ignored if given element is not found. manualControls: "", //Selector: Declare custom control navigation. Examples would be ".flex-control-nav li" or "#tabs-nav li img", etc. The number of elements in your controlNav should match the number of slides/tabs. sync: "", //{NEW} Selector: Mirror the actions performed on this slider with another slider.
Use with care. asNavFor: "", //{NEW} Selector: Internal property exposed for turning the slider into a thumbnail navigation for another slider }); $("#main-slider").flexslider({ namespace: "flex-", //{NEW} String: Prefix string attached to the class of every element generated by the plugin selector: ".slides > li", //{NEW} Selector: Must match a simple pattern. '{container} > {slide}' -- Ignore pattern at your own peril animation: "fade", //String: Select your animation type, "fade" or "slide" easing: "swing", //{NEW} String: Determines the easing method used in jQuery transitions. jQuery easing plugin is supported! direction: "horizontal", //String: Select the sliding direction, "horizontal" or "vertical" reverse: false, //{NEW} Boolean: Reverse the animation direction animationLoop: true, //Boolean: Should the animation loop? If false, directionNav will receive "disable" classes at either end smoothHeight: false, //{NEW} Boolean: Allow height of the slider to animate smoothly in horizontal mode startAt: 0, //Integer: The slide that the slider should start on. Array notation (0 = first slide) slideshow: true, //Boolean: Animate slider automatically slideshowSpeed: 7000, //Integer: Set the speed of the slideshow cycling, in milliseconds animationSpeed: 600, //Integer: Set the speed of animations, in milliseconds initDelay: 0, //{NEW} Integer: Set an initialization delay, in milliseconds randomize: false, //Boolean: Randomize slide order // Usability features pauseOnAction: true, //Boolean: Pause the slideshow when interacting with control elements, highly recommended.
pauseOnHover: false, //Boolean: Pause the slideshow when hovering over slider, then resume when no longer hovering useCSS: true, //{NEW} Boolean: Slider will use CSS3 transitions if available touch: true, //{NEW} Boolean: Allow touch swipe navigation of the slider on touch-enabled devices video: false, //{NEW} Boolean: If using video in the slider, will prevent CSS3 3D Transforms to avoid graphical glitches // Primary Controls controlNav: true, //Boolean: Create navigation for paging control of each slide? Note: Leave true for manualControls usage directionNav: true, //Boolean: Create navigation for previous/next navigation? (true/false) prevText: "Previous", //String: Set the text for the "previous" directionNav item nextText: "Next", //String: Set the text for the "next" directionNav item // Secondary Navigation keyboard: true, //Boolean: Allow slider navigating via keyboard left/right keys multipleKeyboard: false, //{NEW} Boolean: Allow keyboard navigation to affect multiple sliders. Default behavior cuts out keyboard navigation with more than one slider present. mousewheel: false, //{UPDATED} Boolean: Requires jquery.mousewheel.js (https://github.com/brandonaaron/jquery-mousewheel) - Allows slider navigating via mousewheel pausePlay: false, //Boolean: Create pause/play dynamic element pauseText: "Pause", //String: Set the text for the "pause" pausePlay item playText: "Play", //String: Set the text for the "play" pausePlay item // Special properties controlsContainer: "", //{UPDATED} Selector: USE CLASS SELECTOR. Declare which container the navigation elements should be appended to. Default container is the FlexSlider element. Example use would be ".flexslider-container". Property is ignored if given element is not found. manualControls: "", //Selector: Declare custom control navigation. Examples would be ".flex-control-nav li" or "#tabs-nav li img", etc. The number of elements in your controlNav should match the number of slides/tabs.
sync: "", //{NEW} Selector: Mirror the actions performed on this slider with another slider. Use with care. asNavFor: "", //{NEW} Selector: Internal property exposed for turning the slider into a thumbnail navigation for another slider }); }); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/flexslider/jquery.flexslider.js ================================================ /* * jQuery FlexSlider v1.8 * http://www.woothemes.com/flexslider/ * * Copyright 2012 WooThemes * Free to use under the MIT license. * http://www.opensource.org/licenses/mit-license.php * * Contributing Author: Tyler Smith */ (function ($) { //FlexSlider: Object Instance $.flexslider = function (el, options) { var slider = $(el); // slider DOM reference for use outside of the plugin $.data(el, "flexslider", slider); slider.init = function () { slider.vars = $.extend({}, $.flexslider.defaults, options); $.data(el, "flexsliderInit", true); slider.container = $(".slides", slider).eq(0); slider.slides = $(".slides:first > li", slider); slider.count = slider.slides.length; slider.animating = false; slider.currentSlide = slider.vars.slideToStart; slider.animatingTo = slider.currentSlide; slider.atEnd = slider.currentSlide == 0 ? true : false; slider.eventType = "ontouchstart" in document.documentElement ? "touchstart" : "click"; slider.cloneCount = 0; slider.cloneOffset = 0; slider.manualPause = false; slider.vertical = slider.vars.slideDirection == "vertical"; slider.prop = slider.vertical ? 
"top" : "marginLeft"; slider.args = {}; //Test for WebKit CSS3 Animations slider.transitions = "webkitTransition" in document.body.style && slider.vars.useCSS; if (slider.transitions) slider.prop = "-webkit-transform"; //Test for controlsContainer if (slider.vars.controlsContainer != "") { slider.controlsContainer = $(slider.vars.controlsContainer).eq( $(".slides").index(slider.container), ); slider.containerExists = slider.controlsContainer.length > 0; } //Test for manualControls if (slider.vars.manualControls != "") { slider.manualControls = $( slider.vars.manualControls, slider.containerExists ? slider.controlsContainer : slider, ); slider.manualExists = slider.manualControls.length > 0; } /////////////////////////////////////////////////////////////////// // FlexSlider: Randomize Slides if (slider.vars.randomize) { slider.slides.sort(function () { return Math.round(Math.random()) - 0.5; }); slider.container.empty().append(slider.slides); } /////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////// // FlexSlider: Slider Animation Initialize if (slider.vars.animation.toLowerCase() == "slide") { if (slider.transitions) { slider.setTransition(0); } slider.css({ overflow: "hidden" }); if (slider.vars.animationLoop) { slider.cloneCount = 2; slider.cloneOffset = 1; slider.container .append(slider.slides.filter(":first").clone().addClass("clone")) .prepend(slider.slides.filter(":last").clone().addClass("clone")); } //create newSlides to capture possible clones slider.newSlides = $(".slides:first > li", slider); var sliderOffset = -1 * (slider.currentSlide + slider.cloneOffset); if (slider.vertical) { slider.newSlides.css({ display: "block", width: "100%", float: "left", }); slider.container .height((slider.count + slider.cloneCount) * 200 + "%") .css("position", "absolute") .width("100%"); //Timeout function to give browser enough time to get proper height initially setTimeout(function () {
slider .css({ position: "relative" }) .height(slider.slides.filter(":first").height()); slider.args[slider.prop] = slider.transitions ? "translate3d(0," + sliderOffset * slider.height() + "px,0)" : sliderOffset * slider.height() + "px"; slider.container.css(slider.args); }, 100); } else { slider.args[slider.prop] = slider.transitions ? "translate3d(" + sliderOffset * slider.width() + "px,0,0)" : sliderOffset * slider.width() + "px"; slider.container .width((slider.count + slider.cloneCount) * 200 + "%") .css(slider.args); //Timeout function to give browser enough time to get proper width initially setTimeout(function () { slider.newSlides .width(slider.width()) .css({ float: "left", display: "block" }); }, 100); } } else { //Default to fade //Not supporting fade CSS3 transitions right now slider.transitions = false; slider.slides .css({ width: "100%", float: "left", marginRight: "-100%" }) .eq(slider.currentSlide) .fadeIn(slider.vars.animationDuration); } /////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////// // FlexSlider: Control Nav if (slider.vars.controlNav) { if (slider.manualExists) { slider.controlNav = slider.manualControls; } else { var controlNavScaffold = $('<ol class="flex-control-nav"></ol>'); var j = 1; for (var i = 0; i < slider.count; i++) { controlNavScaffold.append("<li><a>" + j + "</a></li>"); j++; }
if (slider.containerExists) { $(slider.controlsContainer).append(controlNavScaffold); slider.controlNav = $( ".flex-control-nav li a", slider.controlsContainer, ); } else { slider.append(controlNavScaffold); slider.controlNav = $(".flex-control-nav li a", slider); } } slider.controlNav.eq(slider.currentSlide).addClass("active"); slider.controlNav.bind(slider.eventType, function (event) { event.preventDefault(); if (!$(this).hasClass("active")) { slider.controlNav.index($(this)) > slider.currentSlide ? (slider.direction = "next") : (slider.direction = "prev"); slider.flexAnimate( slider.controlNav.index($(this)), slider.vars.pauseOnAction, ); } }); } /////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////// //FlexSlider: Direction Nav if (slider.vars.directionNav) { var directionNavScaffold = $( '<ul class="flex-direction-nav"><li><a class="prev" href="#">' + slider.vars.prevText + '</a></li><li><a class="next" href="#">' + slider.vars.nextText + "</a></li></ul>", ); if (slider.containerExists) { $(slider.controlsContainer).append(directionNavScaffold); slider.directionNav = $( ".flex-direction-nav li a", slider.controlsContainer, ); } else { slider.append(directionNavScaffold); slider.directionNav = $(".flex-direction-nav li a", slider); } //Set initial disable styles if necessary if (!slider.vars.animationLoop) { if (slider.currentSlide == 0) { slider.directionNav.filter(".prev").addClass("disabled"); } else if (slider.currentSlide == slider.count - 1) { slider.directionNav.filter(".next").addClass("disabled"); } } slider.directionNav.bind(slider.eventType, function (event) { event.preventDefault(); var target = $(this).hasClass("next") ?
slider.getTarget("next") : slider.getTarget("prev"); if (slider.canAdvance(target)) { slider.flexAnimate(target, slider.vars.pauseOnAction); } }); } ////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////// //FlexSlider: Keyboard Nav if (slider.vars.keyboardNav && $("ul.slides").length == 1) { function keyboardMove(event) { if (slider.animating) { return; } else if (event.keyCode != 39 && event.keyCode != 37) { return; } else { if (event.keyCode == 39) { var target = slider.getTarget("next"); } else if (event.keyCode == 37) { var target = slider.getTarget("prev"); } if (slider.canAdvance(target)) { slider.flexAnimate(target, slider.vars.pauseOnAction); } } } $(document).bind("keyup", keyboardMove); } ////////////////////////////////////////////////////////////////// /////////////////////////////////////////////////////////////////// // FlexSlider: Mousewheel interaction if (slider.vars.mousewheel) { slider.mousewheelEvent = /Firefox/i.test(navigator.userAgent) ? "DOMMouseScroll" : "mousewheel"; slider.bind(slider.mousewheelEvent, function (e) { e.preventDefault(); e = e ? e : window.event; var wheelData = e.detail ? e.detail * -1 : e.originalEvent.wheelDelta / 40, target = wheelData < 0 ? 
slider.getTarget("next") : slider.getTarget("prev"); if (slider.canAdvance(target)) { slider.flexAnimate(target, slider.vars.pauseOnAction); } }); } /////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////// //FlexSlider: Slideshow Setup if (slider.vars.slideshow) { //pauseOnHover if (slider.vars.pauseOnHover && slider.vars.slideshow) { slider.hover( function () { slider.pause(); }, function () { if (!slider.manualPause) { slider.resume(); } }, ); } //Initialize animation slider.animatedSlides = setInterval( slider.animateSlides, slider.vars.slideshowSpeed, ); } ////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////// //FlexSlider: Pause/Play if (slider.vars.pausePlay) { var pausePlayScaffold = $(
'<div class="flex-pauseplay"><span></span></div>', ); if (slider.containerExists) { slider.controlsContainer.append(pausePlayScaffold); slider.pausePlay = $( ".flex-pauseplay span", slider.controlsContainer, ); } else { slider.append(pausePlayScaffold); slider.pausePlay = $(".flex-pauseplay span", slider); } var pausePlayState = slider.vars.slideshow ? "pause" : "play"; slider.pausePlay .addClass(pausePlayState) .text( pausePlayState == "pause" ? slider.vars.pauseText : slider.vars.playText, ); slider.pausePlay.bind(slider.eventType, function (event) { event.preventDefault(); if ($(this).hasClass("pause")) { slider.pause(); slider.manualPause = true; } else { slider.resume(); slider.manualPause = false; } }); } ////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////// //FlexSlider: Touch Swipe Gestures //Some brilliant concepts adapted from the following sources //Source: TouchSwipe - http://www.netcu.de/jquery-touchwipe-iphone-ipad-library //Source: SwipeJS - http://swipejs.com if ("ontouchstart" in document.documentElement && slider.vars.touch) { //For brevity, variables are named for x-axis scrolling //The variables are then swapped if vertical sliding is applied //This reduces redundant code...I think :) //If debugging, recognize variables are named for horizontal scrolling var startX, startY, offset, cwidth, dx, startT, scrolling = false; slider.each(function () { if ("ontouchstart" in document.documentElement) { this.addEventListener("touchstart", onTouchStart, false); } }); function onTouchStart(e) { if (slider.animating) { e.preventDefault(); } else if (e.touches.length == 1) { slider.pause(); cwidth = slider.vertical ? slider.height() : slider.width(); startT = Number(new Date()); offset = slider.vertical ? (slider.currentSlide + slider.cloneOffset) * slider.height() : (slider.currentSlide + slider.cloneOffset) * slider.width(); startX = slider.vertical ? e.touches[0].pageY : e.touches[0].pageX; startY = slider.vertical ?
e.touches[0].pageX : e.touches[0].pageY; slider.setTransition(0); this.addEventListener("touchmove", onTouchMove, false); this.addEventListener("touchend", onTouchEnd, false); } } function onTouchMove(e) { dx = slider.vertical ? startX - e.touches[0].pageY : startX - e.touches[0].pageX; scrolling = slider.vertical ? Math.abs(dx) < Math.abs(e.touches[0].pageX - startY) : Math.abs(dx) < Math.abs(e.touches[0].pageY - startY); if (!scrolling) { e.preventDefault(); if (slider.vars.animation == "slide" && slider.transitions) { if (!slider.vars.animationLoop) { dx = dx / ((slider.currentSlide == 0 && dx < 0) || (slider.currentSlide == slider.count - 1 && dx > 0) ? Math.abs(dx) / cwidth + 2 : 1); } slider.args[slider.prop] = slider.vertical ? "translate3d(0," + (-offset - dx) + "px,0)" : "translate3d(" + (-offset - dx) + "px,0,0)"; slider.container.css(slider.args); } } } function onTouchEnd(e) { slider.animating = false; if ( slider.animatingTo == slider.currentSlide && !scrolling && !(dx == null) ) { var target = dx > 0 ? 
slider.getTarget("next") : slider.getTarget("prev"); if ( (slider.canAdvance(target) && Number(new Date()) - startT < 550 && Math.abs(dx) > 20) || Math.abs(dx) > cwidth / 2 ) { slider.flexAnimate(target, slider.vars.pauseOnAction); } else if (slider.vars.animation !== "fade") { slider.flexAnimate( slider.currentSlide, slider.vars.pauseOnAction, ); } } //Finish the touch by undoing the touch session this.removeEventListener("touchmove", onTouchMove, false); this.removeEventListener("touchend", onTouchEnd, false); startX = null; startY = null; dx = null; offset = null; } } ////////////////////////////////////////////////////////////////// ////////////////////////////////////////////////////////////////// //FlexSlider: Resize Functions (If necessary) if (slider.vars.animation.toLowerCase() == "slide") { $(window).resize(function () { if (!slider.animating && slider.is(":visible")) { if (slider.vertical) { slider.height(slider.slides.filter(":first").height()); slider.args[slider.prop] = -1 * (slider.currentSlide + slider.cloneOffset) * slider.slides.filter(":first").height() + "px"; if (slider.transitions) { slider.setTransition(0); slider.args[slider.prop] = slider.vertical ? "translate3d(0," + slider.args[slider.prop] + ",0)" : "translate3d(" + slider.args[slider.prop] + ",0,0)"; } slider.container.css(slider.args); } else { slider.newSlides.width(slider.width()); slider.args[slider.prop] = -1 * (slider.currentSlide + slider.cloneOffset) * slider.width() + "px"; if (slider.transitions) { slider.setTransition(0); slider.args[slider.prop] = slider.vertical ? 
"translate3d(0," + slider.args[slider.prop] + ",0)" : "translate3d(" + slider.args[slider.prop] + ",0,0)"; } slider.container.css(slider.args); } } }); } ////////////////////////////////////////////////////////////////// //FlexSlider: start() Callback slider.vars.start(slider); }; //FlexSlider: Animation Actions slider.flexAnimate = function (target, pause) { if (!slider.animating && slider.is(":visible")) { //Animating flag slider.animating = true; //FlexSlider: before() animation Callback slider.animatingTo = target; slider.vars.before(slider); //Optional parameter to pause slider when making an animation call if (pause) { slider.pause(); } //Update controlNav if (slider.vars.controlNav) { slider.controlNav.removeClass("active").eq(target).addClass("active"); } //Is the slider at either end slider.atEnd = target == 0 || target == slider.count - 1 ? true : false; if (!slider.vars.animationLoop && slider.vars.directionNav) { if (target == 0) { slider.directionNav .removeClass("disabled") .filter(".prev") .addClass("disabled"); } else if (target == slider.count - 1) { slider.directionNav .removeClass("disabled") .filter(".next") .addClass("disabled"); } else { slider.directionNav.removeClass("disabled"); } } if (!slider.vars.animationLoop && target == slider.count - 1) { slider.pause(); //FlexSlider: end() of cycle Callback slider.vars.end(slider); } if (slider.vars.animation.toLowerCase() == "slide") { var dimension = slider.vertical ?
slider.slides.filter(":first").height() : slider.slides.filter(":first").width(); if ( slider.currentSlide == 0 && target == slider.count - 1 && slider.vars.animationLoop && slider.direction != "next" ) { slider.slideString = "0px"; } else if ( slider.currentSlide == slider.count - 1 && target == 0 && slider.vars.animationLoop && slider.direction != "prev" ) { slider.slideString = -1 * (slider.count + 1) * dimension + "px"; } else { slider.slideString = -1 * (target + slider.cloneOffset) * dimension + "px"; } slider.args[slider.prop] = slider.slideString; if (slider.transitions) { slider.setTransition(slider.vars.animationDuration); slider.args[slider.prop] = slider.vertical ? "translate3d(0," + slider.slideString + ",0)" : "translate3d(" + slider.slideString + ",0,0)"; slider.container .css(slider.args) .one("webkitTransitionEnd transitionend", function () { slider.wrapup(dimension); }); } else { slider.container.animate( slider.args, slider.vars.animationDuration, function () { slider.wrapup(dimension); }, ); } } else { //Default to Fade slider.slides .eq(slider.currentSlide) .fadeOut(slider.vars.animationDuration); slider.slides .eq(target) .fadeIn(slider.vars.animationDuration, function () { slider.wrapup(); }); } } }; //FlexSlider: Function to minify redundant animation actions slider.wrapup = function (dimension) { if (slider.vars.animation == "slide") { //Jump the slider if necessary if ( slider.currentSlide == 0 && slider.animatingTo == slider.count - 1 && slider.vars.animationLoop ) { slider.args[slider.prop] = -1 * slider.count * dimension + "px"; if (slider.transitions) { slider.setTransition(0); slider.args[slider.prop] = slider.vertical ? 
"translate3d(0," + slider.args[slider.prop] + ",0)" : "translate3d(" + slider.args[slider.prop] + ",0,0)"; } slider.container.css(slider.args); } else if ( slider.currentSlide == slider.count - 1 && slider.animatingTo == 0 && slider.vars.animationLoop ) { slider.args[slider.prop] = -1 * dimension + "px"; if (slider.transitions) { slider.setTransition(0); slider.args[slider.prop] = slider.vertical ? "translate3d(0," + slider.args[slider.prop] + ",0)" : "translate3d(" + slider.args[slider.prop] + ",0,0)"; } slider.container.css(slider.args); } } slider.animating = false; slider.currentSlide = slider.animatingTo; //FlexSlider: after() animation Callback slider.vars.after(slider); }; //FlexSlider: Automatic Slideshow slider.animateSlides = function () { if (!slider.animating) { slider.flexAnimate(slider.getTarget("next")); } }; //FlexSlider: Automatic Slideshow Pause slider.pause = function () { clearInterval(slider.animatedSlides); if (slider.vars.pausePlay) { slider.pausePlay .removeClass("pause") .addClass("play") .text(slider.vars.playText); } }; //FlexSlider: Automatic Slideshow Start/Resume slider.resume = function () { slider.animatedSlides = setInterval( slider.animateSlides, slider.vars.slideshowSpeed, ); if (slider.vars.pausePlay) { slider.pausePlay .removeClass("play") .addClass("pause") .text(slider.vars.pauseText); } }; //FlexSlider: Helper function for non-looping sliders slider.canAdvance = function (target) { if (!slider.vars.animationLoop && slider.atEnd) { if ( slider.currentSlide == 0 && target == slider.count - 1 && slider.direction != "next" ) { return false; } else if ( slider.currentSlide == slider.count - 1 && target == 0 && slider.direction == "next" ) { return false; } else { return true; } } else { return true; } }; //FlexSlider: Helper function to determine animation target slider.getTarget = function (dir) { slider.direction = dir; if (dir == "next") { return slider.currentSlide == slider.count - 1 ? 
0 : slider.currentSlide + 1; } else { return slider.currentSlide == 0 ? slider.count - 1 : slider.currentSlide - 1; } }; //FlexSlider: Helper function to set CSS3 transitions slider.setTransition = function (dur) { slider.container.css({ "-webkit-transition-duration": dur / 1000 + "s" }); }; //FlexSlider: Initialize slider.init(); }; //FlexSlider: Default Settings $.flexslider.defaults = { animation: "slide", //String: Select your animation type, "fade" or "slide" slideDirection: "horizontal", //String: Select the sliding direction, "horizontal" or "vertical" slideshow: true, //Boolean: Animate slider automatically slideshowSpeed: 7000, //Integer: Set the speed of the slideshow cycling, in milliseconds animationDuration: 600, //Integer: Set the speed of animations, in milliseconds directionNav: false, //Boolean: Create navigation for previous/next navigation? (true/false) controlNav: true, //Boolean: Create navigation for paging control of each slide? Note: Leave true for manualControls usage keyboardNav: true, //Boolean: Allow slider navigating via keyboard left/right keys mousewheel: false, //Boolean: Allow slider navigating via mousewheel prevText: "Previous", //String: Set the text for the "previous" directionNav item nextText: "Next", //String: Set the text for the "next" directionNav item pausePlay: false, //Boolean: Create pause/play dynamic element pauseText: "Pause", //String: Set the text for the "pause" pausePlay item playText: "Play", //String: Set the text for the "play" pausePlay item randomize: false, //Boolean: Randomize slide order slideToStart: 0, //Integer: The slide that the slider should start on. Array notation (0 = first slide) animationLoop: true, //Boolean: Should the animation loop? If false, directionNav will receive "disable" classes at either end pauseOnAction: true, //Boolean: Pause the slideshow when interacting with control elements, highly recommended.
pauseOnHover: false, //Boolean: Pause the slideshow when hovering over slider, then resume when no longer hovering useCSS: true, //Boolean: Override the use of CSS3 Translate3d animations touch: true, //Boolean: Disable touchswipe events controlsContainer: "", //Selector: Declare which container the navigation elements should be appended to. Default container is the flexSlider element. Example use would be ".flexslider-container", "#container", etc. If the given element is not found, the default action will be taken. manualControls: "", //Selector: Declare custom control navigation. Example would be ".flex-control-nav li" or "#tabs-nav li img", etc. The number of elements in your controlNav should match the number of slides/tabs. start: function () {}, //Callback: function(slider) - Fires when the slider loads the first slide before: function () {}, //Callback: function(slider) - Fires asynchronously with each slider animation after: function () {}, //Callback: function(slider) - Fires after each slider animation completes end: function () {}, //Callback: function(slider) - Fires when the slider reaches the last slide (asynchronous) }; //FlexSlider: Plugin Function $.fn.flexslider = function (options) { return this.each(function () { var $slides = $(this).find(".slides > li"); if ($slides.length === 1) { $slides.fadeIn(400); if (options && options.start) options.start($(this)); } else if ($(this).data("flexsliderInit") != true) { new $.flexslider(this, options); } }); }; })(jQuery); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/flexslider/setting.js ================================================ $(window).load(function () { $(".flexslider").flexslider(); }); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/google-code-prettify/prettify.css ================================================ .com { color: #93a1a1; } .lit { color: #195f91; } .pun,
.opn, .clo { color: #93a1a1; } .fun { color: #dc322f; } .str, .atv { color: #d14; } .kwd, .prettyprint .tag { color: #1e347b; } .typ, .atn, .dec, .var { color: teal; } .pln { color: #48484c; } .prettyprint { padding: 8px; background-color: #f7f7f9; border: 1px solid #e1e1e8; } .prettyprint.linenums { -webkit-box-shadow: inset 40px 0 0 #fbfbfc, inset 41px 0 0 #ececf0; -moz-box-shadow: inset 40px 0 0 #fbfbfc, inset 41px 0 0 #ececf0; box-shadow: inset 40px 0 0 #fbfbfc, inset 41px 0 0 #ececf0; } /* Specify class=linenums on a pre to get line numbering */ ol.linenums { margin: 0 0 0 33px; /* IE indents via margin-left */ } ol.linenums li { padding-left: 12px; color: #bebec5; line-height: 20px; text-shadow: 0 1px 0 #fff; } ================================================ FILE: backend/tests/integration/tests/pruning/website/js/google-code-prettify/prettify.js ================================================ var q = null; window.PR_SHOULD_USE_CONTINUATION = !0; (function () { function L(a) { function m(a) { var f = a.charCodeAt(0); if (f !== 92) return f; var b = a.charAt(1); return (f = r[b]) ? f : "0" <= b && b <= "7" ? parseInt(a.substring(1), 8) : b === "u" || b === "x" ? parseInt(a.substring(2), 16) : a.charCodeAt(1); } function e(a) { if (a < 32) return (a < 16 ? "\\x0" : "\\x") + a.toString(16); a = String.fromCharCode(a); if (a === "\\" || a === "-" || a === "[" || a === "]") a = "\\" + a; return a; } function h(a) { for ( var f = a .substring(1, a.length - 1) .match( /\\u[\dA-Fa-f]{4}|\\x[\dA-Fa-f]{2}|\\[0-3][0-7]{0,2}|\\[0-7]{1,2}|\\[\S\s]|[^\\]/g, ), a = [], b = [], o = f[0] === "^", c = o ? 1 : 0, i = f.length; c < i; ++c ) { var j = f[c]; if (/\\[bdsw]/i.test(j)) a.push(j); else { var j = m(j), d; c + 2 < i && "-" === f[c + 1] ? 
((d = m(f[c + 2])), (c += 2)) : (d = j); b.push([j, d]); d < 65 || j > 122 || (d < 65 || j > 90 || b.push([Math.max(65, j) | 32, Math.min(d, 90) | 32]), d < 97 || j > 122 || b.push([Math.max(97, j) & -33, Math.min(d, 122) & -33])); } } b.sort(function (a, f) { return a[0] - f[0] || f[1] - a[1]; }); f = []; j = [NaN, NaN]; for (c = 0; c < b.length; ++c) (i = b[c]), i[0] <= j[1] + 1 ? (j[1] = Math.max(j[1], i[1])) : f.push((j = i)); b = ["["]; o && b.push("^"); b.push.apply(b, a); for (c = 0; c < f.length; ++c) (i = f[c]), b.push(e(i[0])), i[1] > i[0] && (i[1] + 1 > i[0] && b.push("-"), b.push(e(i[1]))); b.push("]"); return b.join(""); } function y(a) { for ( var f = a.source.match( /\[(?:[^\\\]]|\\[\S\s])*]|\\u[\dA-Fa-f]{4}|\\x[\dA-Fa-f]{2}|\\\d+|\\[^\dux]|\(\?[!:=]|[()^]|[^()[\\^]+/g, ), b = f.length, d = [], c = 0, i = 0; c < b; ++c ) { var j = f[c]; j === "(" ? ++i : "\\" === j.charAt(0) && (j = +j.substring(1)) && j <= i && (d[j] = -1); } for (c = 1; c < d.length; ++c) -1 === d[c] && (d[c] = ++t); for (i = c = 0; c < b; ++c) (j = f[c]), j === "(" ? (++i, d[i] === void 0 && (f[c] = "(?:")) : "\\" === j.charAt(0) && (j = +j.substring(1)) && j <= i && (f[c] = "\\" + d[i]); for (i = c = 0; c < b; ++c) "^" === f[c] && "^" !== f[c + 1] && (f[c] = ""); if (a.ignoreCase && s) for (c = 0; c < b; ++c) (j = f[c]), (a = j.charAt(0)), j.length >= 2 && a === "[" ? 
(f[c] = h(j)) : a !== "\\" && (f[c] = j.replace(/[A-Za-z]/g, function (a) { a = a.charCodeAt(0); return "[" + String.fromCharCode(a & -33, a | 32) + "]"; })); return f.join(""); } for (var t = 0, s = !1, l = !1, p = 0, d = a.length; p < d; ++p) { var g = a[p]; if (g.ignoreCase) l = !0; else if ( /[a-z]/i.test( g.source.replace(/\\u[\da-f]{4}|\\x[\da-f]{2}|\\[^UXux]/gi, ""), ) ) { s = !0; l = !1; break; } } for ( var r = { b: 8, t: 9, n: 10, v: 11, f: 12, r: 13 }, n = [], p = 0, d = a.length; p < d; ++p ) { g = a[p]; if (g.global || g.multiline) throw Error("" + g); n.push("(?:" + y(g) + ")"); } return RegExp(n.join("|"), l ? "gi" : "g"); } function M(a) { function m(a) { switch (a.nodeType) { case 1: if (e.test(a.className)) break; for (var g = a.firstChild; g; g = g.nextSibling) m(g); g = a.nodeName; if ("BR" === g || "LI" === g) (h[s] = "\n"), (t[s << 1] = y++), (t[(s++ << 1) | 1] = a); break; case 3: case 4: (g = a.nodeValue), g.length && ((g = p ? g.replace(/\r\n?/g, "\n") : g.replace(/[\t\n\r ]+/g, " ")), (h[s] = g), (t[s << 1] = y), (y += g.length), (t[(s++ << 1) | 1] = a)); } } var e = /(?:^|\s)nocode(?:\s|$)/, h = [], y = 0, t = [], s = 0, l; a.currentStyle ? 
(l = a.currentStyle.whiteSpace) : window.getComputedStyle && (l = document.defaultView .getComputedStyle(a, q) .getPropertyValue("white-space")); var p = l && "pre" === l.substring(0, 3); m(a); return { a: h.join("").replace(/\n$/, ""), c: t }; } function B(a, m, e, h) { m && ((a = { a: m, d: a }), e(a), h.push.apply(h, a.e)); } function x(a, m) { function e(a) { for ( var l = a.d, p = [l, "pln"], d = 0, g = a.a.match(y) || [], r = {}, n = 0, z = g.length; n < z; ++n ) { var f = g[n], b = r[f], o = void 0, c; if (typeof b === "string") c = !1; else { var i = h[f.charAt(0)]; if (i) (o = f.match(i[1])), (b = i[0]); else { for (c = 0; c < t; ++c) if (((i = m[c]), (o = f.match(i[1])))) { b = i[0]; break; } o || (b = "pln"); } if ( (c = b.length >= 5 && "lang-" === b.substring(0, 5)) && !(o && typeof o[1] === "string") ) (c = !1), (b = "src"); c || (r[f] = b); } i = d; d += f.length; if (c) { c = o[1]; var j = f.indexOf(c), k = j + c.length; o[2] && ((k = f.length - o[2].length), (j = k - c.length)); b = b.substring(5); B(l + i, f.substring(0, j), e, p); B(l + i + j, c, C(b, c), p); B(l + i + k, f.substring(k), e, p); } else p.push(l + i, b); } a.e = p; } var h = {}, y; (function () { for ( var e = a.concat(m), l = [], p = {}, d = 0, g = e.length; d < g; ++d ) { var r = e[d], n = r[3]; if (n) for (var k = n.length; --k >= 0; ) h[n.charAt(k)] = r; r = r[1]; n = "" + r; p.hasOwnProperty(n) || (l.push(r), (p[n] = q)); } l.push(/[\S\s]/); y = L(l); })(); var t = m.length; return e; } function u(a) { var m = [], e = []; a.tripleQuotedStrings ? m.push([ "str", /^(?:'''(?:[^'\\]|\\[\S\s]|''?(?=[^']))*(?:'''|$)|"""(?:[^"\\]|\\[\S\s]|""?(?=[^"]))*(?:"""|$)|'(?:[^'\\]|\\[\S\s])*(?:'|$)|"(?:[^"\\]|\\[\S\s])*(?:"|$))/, q, "'\"", ]) : a.multiLineStrings ? 
m.push([ "str", /^(?:'(?:[^'\\]|\\[\S\s])*(?:'|$)|"(?:[^"\\]|\\[\S\s])*(?:"|$)|`(?:[^\\`]|\\[\S\s])*(?:`|$))/, q, "'\"`", ]) : m.push([ "str", /^(?:'(?:[^\n\r'\\]|\\.)*(?:'|$)|"(?:[^\n\r"\\]|\\.)*(?:"|$))/, q, "\"'", ]); a.verbatimStrings && e.push(["str", /^@"(?:[^"]|"")*(?:"|$)/, q]); var h = a.hashComments; h && (a.cStyleComments ? (h > 1 ? m.push(["com", /^#(?:##(?:[^#]|#(?!##))*(?:###|$)|.*)/, q, "#"]) : m.push([ "com", /^#(?:(?:define|elif|else|endif|error|ifdef|include|ifndef|line|pragma|undef|warning)\b|[^\n\r]*)/, q, "#", ]), e.push([ "str", /^<(?:(?:(?:\.\.\/)*|\/?)(?:[\w-]+(?:\/[\w-]+)+)?[\w-]+\.h|[a-z]\w*)>/, q, ])) : m.push(["com", /^#[^\n\r]*/, q, "#"])); a.cStyleComments && (e.push(["com", /^\/\/[^\n\r]*/, q]), e.push(["com", /^\/\*[\S\s]*?(?:\*\/|$)/, q])); a.regexLiterals && e.push([ "lang-regex", /^(?:^^\.?|[!+-]|!=|!==|#|%|%=|&|&&|&&=|&=|\(|\*|\*=|\+=|,|-=|->|\/|\/=|:|::|;|<|<<|<<=|<=|=|==|===|>|>=|>>|>>=|>>>|>>>=|[?@[^]|\^=|\^\^|\^\^=|{|\||\|=|\|\||\|\|=|~|break|case|continue|delete|do|else|finally|instanceof|return|throw|try|typeof)\s*(\/(?=[^*/])(?:[^/[\\]|\\[\S\s]|\[(?:[^\\\]]|\\[\S\s])*(?:]|$))+\/)/, ]); (h = a.types) && e.push(["typ", h]); a = ("" + a.keywords).replace(/^ | $/g, ""); a.length && e.push(["kwd", RegExp("^(?:" + a.replace(/[\s,]+/g, "|") + ")\\b"), q]); m.push(["pln", /^\s+/, q, " \r\n\t\xa0"]); e.push( ["lit", /^@[$_a-z][\w$@]*/i, q], ["typ", /^(?:[@_]?[A-Z]+[a-z][\w$@]*|\w+_t\b)/, q], ["pln", /^[$_a-z][\w$@]*/i, q], [ "lit", /^(?:0x[\da-f]+|(?:\d(?:_\d+)*\d*(?:\.\d*)?|\.\d\+)(?:e[+-]?\d+)?)[a-z]*/i, q, "0123456789", ], ["pln", /^\\[\S\s]?/, q], ["pun", /^.[^\s\w"-$'./@\\`]*/, q], ); return x(m, e); } function D(a, m) { function e(a) { switch (a.nodeType) { case 1: if (k.test(a.className)) break; if ("BR" === a.nodeName) h(a), a.parentNode && a.parentNode.removeChild(a); else for (a = a.firstChild; a; a = a.nextSibling) e(a); break; case 3: case 4: if (p) { var b = a.nodeValue, d = b.match(t); if (d) { var c = b.substring(0, 
d.index); a.nodeValue = c; (b = b.substring(d.index + d[0].length)) && a.parentNode.insertBefore(s.createTextNode(b), a.nextSibling); h(a); c || a.parentNode.removeChild(a); } } } } function h(a) { function b(a, d) { var e = d ? a.cloneNode(!1) : a, f = a.parentNode; if (f) { var f = b(f, 1), g = a.nextSibling; f.appendChild(e); for (var h = g; h; h = g) (g = h.nextSibling), f.appendChild(h); } return e; } for (; !a.nextSibling; ) if (((a = a.parentNode), !a)) return; for ( var a = b(a.nextSibling, 0), e; (e = a.parentNode) && e.nodeType === 1; ) a = e; d.push(a); } var k = /(?:^|\s)nocode(?:\s|$)/, t = /\r\n?|\n/, s = a.ownerDocument, l; a.currentStyle ? (l = a.currentStyle.whiteSpace) : window.getComputedStyle && (l = s.defaultView .getComputedStyle(a, q) .getPropertyValue("white-space")); var p = l && "pre" === l.substring(0, 3); for (l = s.createElement("LI"); a.firstChild; ) l.appendChild(a.firstChild); for (var d = [l], g = 0; g < d.length; ++g) e(d[g]); m === (m | 0) && d[0].setAttribute("value", m); var r = s.createElement("OL"); r.className = "linenums"; for (var n = Math.max(0, (m - 1) | 0) || 0, g = 0, z = d.length; g < z; ++g) (l = d[g]), (l.className = "L" + ((g + n) % 10)), l.firstChild || l.appendChild(s.createTextNode("\xa0")), r.appendChild(l); a.appendChild(r); } function k(a, m) { for (var e = m.length; --e >= 0; ) { var h = m[e]; A.hasOwnProperty(h) ? window.console && console.warn("cannot override language handler %s", h) : (A[h] = a); } } function C(a, m) { if (!a || !A.hasOwnProperty(a)) a = /^\s*= o && (h += 2); e >= c && (a += 2); } } catch (w) { "console" in window && console.log(w && w.stack ? 
w.stack : w); } } var v = ["break,continue,do,else,for,if,return,while"], w = [ [ v, "auto,case,char,const,default,double,enum,extern,float,goto,int,long,register,short,signed,sizeof,static,struct,switch,typedef,union,unsigned,void,volatile", ], "catch,class,delete,false,import,new,operator,private,protected,public,this,throw,true,try,typeof", ], F = [ w, "alignof,align_union,asm,axiom,bool,concept,concept_map,const_cast,constexpr,decltype,dynamic_cast,explicit,export,friend,inline,late_check,mutable,namespace,nullptr,reinterpret_cast,static_assert,static_cast,template,typeid,typename,using,virtual,where", ], G = [ w, "abstract,boolean,byte,extends,final,finally,implements,import,instanceof,null,native,package,strictfp,super,synchronized,throws,transient", ], H = [ G, "as,base,by,checked,decimal,delegate,descending,dynamic,event,fixed,foreach,from,group,implicit,in,interface,internal,into,is,lock,object,out,override,orderby,params,partial,readonly,ref,sbyte,sealed,stackalloc,string,select,uint,ulong,unchecked,unsafe,ushort,var", ], w = [ w, "debugger,eval,export,function,get,null,set,undefined,var,with,Infinity,NaN", ], I = [ v, "and,as,assert,class,def,del,elif,except,exec,finally,from,global,import,in,is,lambda,nonlocal,not,or,pass,print,raise,try,with,yield,False,True,None", ], J = [ v, "alias,and,begin,case,class,def,defined,elsif,end,ensure,false,in,module,next,nil,not,or,redo,rescue,retry,self,super,then,true,undef,unless,until,when,yield,BEGIN,END", ], v = [v, "case,done,elif,esac,eval,fi,function,in,local,set,then,until"], K = /^(DIR|FILE|vector|(de|priority_)?queue|list|stack|(const_)?iterator|(multi)?(set|map)|bitset|u?(int|float)\d*)/, N = /\S/, O = u({ keywords: [ F, H, w, "caller,delete,die,do,dump,elsif,eval,exit,foreach,for,goto,if,import,last,local,my,next,no,our,print,package,redo,require,sub,undef,unless,until,use,wantarray,while,BEGIN,END" + I, J, v, ], hashComments: !0, cStyleComments: !0, multiLineStrings: !0, regexLiterals: !0, }), A = {}; 
k(O, ["default-code"]); k( x( [], [ ["pln", /^[^]*(?:>|$)/], ["com", /^<\!--[\S\s]*?(?:--\>|$)/], ["lang-", /^<\?([\S\s]+?)(?:\?>|$)/], ["lang-", /^<%([\S\s]+?)(?:%>|$)/], ["pun", /^(?:<[%?]|[%?]>)/], ["lang-", /^]*>([\S\s]+?)<\/xmp\b[^>]*>/i], ["lang-js", /^]*>([\S\s]*?)(<\/script\b[^>]*>)/i], ["lang-css", /^]*>([\S\s]*?)(<\/style\b[^>]*>)/i], ["lang-in.tag", /^(<\/?[a-z][^<>]*>)/i], ], ), ["default-markup", "htm", "html", "mxml", "xhtml", "xml", "xsl"], ); k( x( [ ["pln", /^\s+/, q, " \t\r\n"], ["atv", /^(?:"[^"]*"?|'[^']*'?)/, q, "\"'"], ], [ ["tag", /^^<\/?[a-z](?:[\w-.:]*\w)?|\/?>$/i], ["atn", /^(?!style[\s=]|on)[a-z](?:[\w:-]*\w)?/i], ["lang-uq.val", /^=\s*([^\s"'>]*(?:[^\s"'/>]|\/(?=\s)))/], ["pun", /^[/<->]+/], ["lang-js", /^on\w+\s*=\s*"([^"]+)"/i], ["lang-js", /^on\w+\s*=\s*'([^']+)'/i], ["lang-js", /^on\w+\s*=\s*([^\s"'>]+)/i], ["lang-css", /^style\s*=\s*"([^"]+)"/i], ["lang-css", /^style\s*=\s*'([^']+)'/i], ["lang-css", /^style\s*=\s*([^\s"'>]+)/i], ], ), ["in.tag"], ); k(x([], [["atv", /^[\S\s]+/]]), ["uq.val"]); k(u({ keywords: F, hashComments: !0, cStyleComments: !0, types: K }), [ "c", "cc", "cpp", "cxx", "cyc", "m", ]); k(u({ keywords: "null,true,false" }), ["json"]); k( u({ keywords: H, hashComments: !0, cStyleComments: !0, verbatimStrings: !0, types: K, }), ["cs"], ); k(u({ keywords: G, cStyleComments: !0 }), ["java"]); k(u({ keywords: v, hashComments: !0, multiLineStrings: !0 }), [ "bsh", "csh", "sh", ]); k( u({ keywords: I, hashComments: !0, multiLineStrings: !0, tripleQuotedStrings: !0, }), ["cv", "py"], ); k( u({ keywords: "caller,delete,die,do,dump,elsif,eval,exit,foreach,for,goto,if,import,last,local,my,next,no,our,print,package,redo,require,sub,undef,unless,until,use,wantarray,while,BEGIN,END", hashComments: !0, multiLineStrings: !0, regexLiterals: !0, }), ["perl", "pl", "pm"], ); k( u({ keywords: J, hashComments: !0, multiLineStrings: !0, regexLiterals: !0, }), ["rb"], ); k(u({ keywords: w, cStyleComments: !0, regexLiterals: !0 }), 
["js"]); k( u({ keywords: "all,and,by,catch,class,else,extends,false,finally,for,if,in,is,isnt,loop,new,no,not,null,of,off,on,or,return,super,then,true,try,unless,until,when,while,yes", hashComments: 3, cStyleComments: !0, multilineStrings: !0, tripleQuotedStrings: !0, regexLiterals: !0, }), ["coffee"], ); k(x([], [["str", /^[\S\s]+/]]), ["regex"]); window.prettyPrintOne = function (a, m, e) { var h = document.createElement("PRE"); h.innerHTML = a; e && D(h, e); E({ g: m, i: e, h: h }); return h.innerHTML; }; window.prettyPrint = function (a) { function m() { for ( var e = window.PR_SHOULD_USE_CONTINUATION ? l.now() + 250 : Infinity; p < h.length && l.now() < e; p++ ) { var n = h[p], k = n.className; if (k.indexOf("prettyprint") >= 0) { var k = k.match(g), f, b; if ((b = !k)) { b = n; for (var o = void 0, c = b.firstChild; c; c = c.nextSibling) var i = c.nodeType, o = i === 1 ? o ? b : c : i === 3 ? N.test(c.nodeValue) ? b : o : o; b = (f = o === b ? void 0 : o) && "CODE" === f.tagName; } b && (k = f.className.match(g)); k && (k = k[1]); b = !1; for (o = n.parentNode; o; o = o.parentNode) if ( (o.tagName === "pre" || o.tagName === "code" || o.tagName === "xmp") && o.className && o.className.indexOf("prettyprint") >= 0 ) { b = !0; break; } b || ((b = (b = n.className.match(/\blinenums\b(?::(\d+))?/)) ? b[1] && b[1].length ? +b[1] : !0 : !1) && D(n, b), (d = { g: k, h: n, i: b }), E(d)); } } p < h.length ? 
setTimeout(m, 250) : a && a(); } for ( var e = [ document.getElementsByTagName("pre"), document.getElementsByTagName("code"), document.getElementsByTagName("xmp"), ], h = [], k = 0; k < e.length; ++k ) for (var t = 0, s = e[k].length; t < s; ++t) h.push(e[k][t]); var e = q, l = Date; l.now || (l = { now: function () { return +new Date(); }, }); var p = 0, d, g = /\blang(?:uage)?-([\w.]+)(?!\S)/; m(); }; window.PR = { createSimpleLexer: x, registerLangHandler: k, sourceDecorator: u, PR_ATTRIB_NAME: "atn", PR_ATTRIB_VALUE: "atv", PR_COMMENT: "com", PR_DECLARATION: "dec", PR_KEYWORD: "kwd", PR_LITERAL: "lit", PR_NOCODE: "nocode", PR_PLAIN: "pln", PR_PUNCTUATION: "pun", PR_SOURCE: "src", PR_STRING: "str", PR_TAG: "tag", PR_TYPE: "typ", }; })(); // make code pretty window.prettyPrint && prettyPrint(); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/jquery.easing.1.3.js ================================================ /* * jQuery Easing v1.3 - http://gsgd.co.uk/sandbox/jquery/easing/ * * Uses the built-in easing capabilities added in jQuery 1.1 * to offer multiple easing options * * TERMS OF USE - jQuery Easing * * Open source under the BSD License. * * Copyright © 2008 George McGinley Smith * All rights reserved. * * Redistribution and use in source and binary forms, with or without modification, * are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this list of * conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list * of conditions and the following disclaimer in the documentation and/or other materials * provided with the distribution. * * Neither the name of the author nor the names of contributors may be used to endorse * or promote products derived from this software without specific prior written permission.
* * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * */ // t: current time, b: beginning value, c: change in value, d: duration jQuery.easing["jswing"] = jQuery.easing["swing"]; jQuery.extend(jQuery.easing, { def: "easeOutQuad", swing: function (x, t, b, c, d) { //alert(jQuery.easing.default); return jQuery.easing[jQuery.easing.def](x, t, b, c, d); }, easeInQuad: function (x, t, b, c, d) { return c * (t /= d) * t + b; }, easeOutQuad: function (x, t, b, c, d) { return -c * (t /= d) * (t - 2) + b; }, easeInOutQuad: function (x, t, b, c, d) { if ((t /= d / 2) < 1) return (c / 2) * t * t + b; return (-c / 2) * (--t * (t - 2) - 1) + b; }, easeInCubic: function (x, t, b, c, d) { return c * (t /= d) * t * t + b; }, easeOutCubic: function (x, t, b, c, d) { return c * ((t = t / d - 1) * t * t + 1) + b; }, easeInOutCubic: function (x, t, b, c, d) { if ((t /= d / 2) < 1) return (c / 2) * t * t * t + b; return (c / 2) * ((t -= 2) * t * t + 2) + b; }, easeInQuart: function (x, t, b, c, d) { return c * (t /= d) * t * t * t + b; }, easeOutQuart: function (x, t, b, c, d) { return -c * ((t = t / d - 1) * t * t * t - 1) + b; }, easeInOutQuart: function (x, t, b, c, d) { if ((t /= d / 2) < 1) return (c / 2) * t * t * t * t + b; return (-c / 2) * ((t -= 2) * t * t * t - 2) +
b; }, easeInQuint: function (x, t, b, c, d) { return c * (t /= d) * t * t * t * t + b; }, easeOutQuint: function (x, t, b, c, d) { return c * ((t = t / d - 1) * t * t * t * t + 1) + b; }, easeInOutQuint: function (x, t, b, c, d) { if ((t /= d / 2) < 1) return (c / 2) * t * t * t * t * t + b; return (c / 2) * ((t -= 2) * t * t * t * t + 2) + b; }, easeInSine: function (x, t, b, c, d) { return -c * Math.cos((t / d) * (Math.PI / 2)) + c + b; }, easeOutSine: function (x, t, b, c, d) { return c * Math.sin((t / d) * (Math.PI / 2)) + b; }, easeInOutSine: function (x, t, b, c, d) { return (-c / 2) * (Math.cos((Math.PI * t) / d) - 1) + b; }, easeInExpo: function (x, t, b, c, d) { return t == 0 ? b : c * Math.pow(2, 10 * (t / d - 1)) + b; }, easeOutExpo: function (x, t, b, c, d) { return t == d ? b + c : c * (-Math.pow(2, (-10 * t) / d) + 1) + b; }, easeInOutExpo: function (x, t, b, c, d) { if (t == 0) return b; if (t == d) return b + c; if ((t /= d / 2) < 1) return (c / 2) * Math.pow(2, 10 * (t - 1)) + b; return (c / 2) * (-Math.pow(2, -10 * --t) + 2) + b; }, easeInCirc: function (x, t, b, c, d) { return -c * (Math.sqrt(1 - (t /= d) * t) - 1) + b; }, easeOutCirc: function (x, t, b, c, d) { return c * Math.sqrt(1 - (t = t / d - 1) * t) + b; }, easeInOutCirc: function (x, t, b, c, d) { if ((t /= d / 2) < 1) return (-c / 2) * (Math.sqrt(1 - t * t) - 1) + b; return (c / 2) * (Math.sqrt(1 - (t -= 2) * t) + 1) + b; }, easeInElastic: function (x, t, b, c, d) { var s = 1.70158; var p = 0; var a = c; if (t == 0) return b; if ((t /= d) == 1) return b + c; if (!p) p = d * 0.3; if (a < Math.abs(c)) { a = c; var s = p / 4; } else var s = (p / (2 * Math.PI)) * Math.asin(c / a); return ( -( a * Math.pow(2, 10 * (t -= 1)) * Math.sin(((t * d - s) * (2 * Math.PI)) / p) ) + b ); }, easeOutElastic: function (x, t, b, c, d) { var s = 1.70158; var p = 0; var a = c; if (t == 0) return b; if ((t /= d) == 1) return b + c; if (!p) p = d * 0.3; if (a < Math.abs(c)) { a = c; var s = p / 4; } else var 
s = (p / (2 * Math.PI)) * Math.asin(c / a); return ( a * Math.pow(2, -10 * t) * Math.sin(((t * d - s) * (2 * Math.PI)) / p) + c + b ); }, easeInOutElastic: function (x, t, b, c, d) { var s = 1.70158; var p = 0; var a = c; if (t == 0) return b; if ((t /= d / 2) == 2) return b + c; if (!p) p = d * (0.3 * 1.5); if (a < Math.abs(c)) { a = c; var s = p / 4; } else var s = (p / (2 * Math.PI)) * Math.asin(c / a); if (t < 1) return ( -0.5 * (a * Math.pow(2, 10 * (t -= 1)) * Math.sin(((t * d - s) * (2 * Math.PI)) / p)) + b ); return ( a * Math.pow(2, -10 * (t -= 1)) * Math.sin(((t * d - s) * (2 * Math.PI)) / p) * 0.5 + c + b ); }, easeInBack: function (x, t, b, c, d, s) { if (s == undefined) s = 1.70158; return c * (t /= d) * t * ((s + 1) * t - s) + b; }, easeOutBack: function (x, t, b, c, d, s) { if (s == undefined) s = 1.70158; return c * ((t = t / d - 1) * t * ((s + 1) * t + s) + 1) + b; }, easeInOutBack: function (x, t, b, c, d, s) { if (s == undefined) s = 1.70158; if ((t /= d / 2) < 1) return (c / 2) * (t * t * (((s *= 1.525) + 1) * t - s)) + b; return (c / 2) * ((t -= 2) * t * (((s *= 1.525) + 1) * t + s) + 2) + b; }, easeInBounce: function (x, t, b, c, d) { return c - jQuery.easing.easeOutBounce(x, d - t, 0, c, d) + b; }, easeOutBounce: function (x, t, b, c, d) { if ((t /= d) < 1 / 2.75) { return c * (7.5625 * t * t) + b; } else if (t < 2 / 2.75) { return c * (7.5625 * (t -= 1.5 / 2.75) * t + 0.75) + b; } else if (t < 2.5 / 2.75) { return c * (7.5625 * (t -= 2.25 / 2.75) * t + 0.9375) + b; } else { return c * (7.5625 * (t -= 2.625 / 2.75) * t + 0.984375) + b; } }, easeInOutBounce: function (x, t, b, c, d) { if (t < d / 2) return jQuery.easing.easeInBounce(x, t * 2, 0, c, d) * 0.5 + b; return ( jQuery.easing.easeOutBounce(x, t * 2 - d, 0, c, d) * 0.5 + c * 0.5 + b ); }, }); /* * * TERMS OF USE - EASING EQUATIONS * * Open source under the BSD License. * * Copyright © 2001 Robert Penner * All rights reserved. 
* * Redistribution and use in source and binary forms, with or without modification, * are permitted provided that the following conditions are met: * * Redistributions of source code must retain the above copyright notice, this list of * conditions and the following disclaimer. * Redistributions in binary form must reproduce the above copyright notice, this list * of conditions and the following disclaimer in the documentation and/or other materials * provided with the distribution. * * Neither the name of the author nor the names of contributors may be used to endorse * or promote products derived from this software without specific prior written permission. * * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF * MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE * COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, * EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE * GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED * AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED * OF THE POSSIBILITY OF SUCH DAMAGE. * */ ================================================ FILE: backend/tests/integration/tests/pruning/website/js/jquery.fancybox-media.js ================================================ /*! 
* Media helper for fancyBox * version: 1.0.5 (Tue, 23 Oct 2012) * @requires fancyBox v2.0 or later * * Usage: * $(".fancybox").fancybox({ * helpers : { * media: true * } * }); * * Set custom URL parameters: * $(".fancybox").fancybox({ * helpers : { * media: { * youtube : { * params : { * autoplay : 0 * } * } * } * } * }); * * Or: * $(".fancybox").fancybox({ * helpers : { * media: true * }, * youtube : { * autoplay: 0 * } * }); * * Supports: * * YouTube * http://www.youtube.com/watch?v=opj24KnzrWo * http://www.youtube.com/embed/opj24KnzrWo * http://youtu.be/opj24KnzrWo * Vimeo * http://vimeo.com/40648169 * http://vimeo.com/channels/staffpicks/38843628 * http://vimeo.com/groups/surrealism/videos/36516384 * http://player.vimeo.com/video/45074303 * Metacafe * http://www.metacafe.com/watch/7635964/dr_seuss_the_lorax_movie_trailer/ * http://www.metacafe.com/watch/7635964/ * Dailymotion * http://www.dailymotion.com/video/xoytqh_dr-seuss-the-lorax-premiere_people * Twitvid * http://twitvid.com/QY7MD * Twitpic * http://twitpic.com/7p93st * Instagram * http://instagr.am/p/IejkuUGxQn/ * http://instagram.com/p/IejkuUGxQn/ * Google maps * http://maps.google.com/maps?q=Eiffel+Tower,+Avenue+Gustave+Eiffel,+Paris,+France&t=h&z=17 * http://maps.google.com/?ll=48.857995,2.294297&spn=0.007666,0.021136&t=m&z=16 * http://maps.google.com/?ll=48.859463,2.292626&spn=0.000965,0.002642&t=m&z=19&layer=c&cbll=48.859524,2.292532&panoid=YJ0lq28OOy3VT2IqIuVY0g&cbp=12,151.58,,0,-15.56 */ (function ($) { "use strict"; //Shortcut for fancyBox object var F = $.fancybox, format = function (url, rez, params) { params = params || ""; if ($.type(params) === "object") { params = $.param(params, true); } $.each(rez, function (key, value) { url = url.replace("$" + key, value || ""); }); if (params.length) { url += (url.indexOf("?") > 0 ?
"&" : "?") + params; } return url; }; //Add helper object F.helpers.media = { defaults: { youtube: { matcher: /(youtube\.com|youtu\.be)\/(watch\?v=|v\/|u\/|embed\/?)?(videoseries\?list=(.*)|[\w-]{11}|\?listType=(.*)&list=(.*)).*/i, params: { autoplay: 1, autohide: 1, fs: 1, rel: 0, hd: 1, wmode: "opaque", enablejsapi: 1, }, type: "iframe", url: "//www.youtube.com/embed/$3", }, vimeo: { matcher: /(?:vimeo(?:pro)?.com)\/(?:[^\d]+)?(\d+)(?:.*)/, params: { autoplay: 1, hd: 1, show_title: 1, show_byline: 1, show_portrait: 0, fullscreen: 1, }, type: "iframe", url: "//player.vimeo.com/video/$1", }, metacafe: { matcher: /metacafe.com\/(?:watch|fplayer)\/([\w\-]{1,10})/, params: { autoPlay: "yes", }, type: "swf", url: function (rez, params, obj) { obj.swf.flashVars = "playerVars=" + $.param(params, true); return "//www.metacafe.com/fplayer/" + rez[1] + "/.swf"; }, }, dailymotion: { matcher: /dailymotion.com\/video\/(.*)\/?(.*)/, params: { additionalInfos: 0, autoStart: 1, }, type: "swf", url: "//www.dailymotion.com/swf/video/$1", }, twitvid: { matcher: /twitvid\.com\/([a-zA-Z0-9_\-\?\=]+)/i, params: { autoplay: 0, }, type: "iframe", url: "//www.twitvid.com/embed.php?guid=$1", }, twitpic: { matcher: /twitpic\.com\/(?!(?:place|photos|events)\/)([a-zA-Z0-9\?\=\-]+)/i, type: "image", url: "//twitpic.com/show/full/$1/", }, instagram: { matcher: /(instagr\.am|instagram\.com)\/p\/([a-zA-Z0-9_\-]+)\/?/i, type: "image", url: "//$1/p/$2/media/", }, google_maps: { matcher: /maps\.google\.([a-z]{2,3}(\.[a-z]{2})?)\/(\?ll=|maps\?)(.*)/i, type: "iframe", url: function (rez) { return ( "//maps.google." + rez[1] + "/" + rez[3] + "" + rez[4] + "&output=" + (rez[4].indexOf("layer=c") > 0 ? 
"svembed" : "embed") ); }, }, }, beforeLoad: function (opts, obj) { var url = obj.href || "", type = false, what, item, rez, params; for (what in opts) { item = opts[what]; rez = url.match(item.matcher); if (rez) { type = item.type; params = $.extend( true, {}, item.params, obj[what] || ($.isPlainObject(opts[what]) ? opts[what].params : null), ); url = $.type(item.url) === "function" ? item.url.call(this, rez, params, obj) : format(item.url, rez, params); break; } } if (type) { obj.href = url; obj.type = type; obj.autoHeight = false; } }, }; })(jQuery); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/jquery.fancybox.pack.js ================================================ /*! fancyBox v2.1.4 fancyapps.com | fancyapps.com/fancybox/#license */ (function (C, z, f, r) { var q = f(C), n = f(z), b = (f.fancybox = function () { b.open.apply(this, arguments); }), H = navigator.userAgent.match(/msie/), w = null, s = z.createTouch !== r, t = function (a) { return a && a.hasOwnProperty && a instanceof f; }, p = function (a) { return a && "string" === f.type(a); }, F = function (a) { return p(a) && 0 < a.indexOf("%"); }, l = function (a, d) { var e = parseInt(a, 10) || 0; d && F(a) && (e *= b.getViewport()[d] / 100); return Math.ceil(e); }, x = function (a, b) { return l(a, b) + "px"; }; f.extend(b, { version: "2.1.4", defaults: { padding: 15, margin: 20, width: 800, height: 600, minWidth: 100, minHeight: 100, maxWidth: 9999, maxHeight: 9999, autoSize: !0, autoHeight: !1, autoWidth: !1, autoResize: !0, autoCenter: !s, fitToView: !0, aspectRatio: !1, topRatio: 0.5, leftRatio: 0.5, scrolling: "auto", wrapCSS: "", arrows: !0, closeBtn: !0, closeClick: !1, nextClick: !1, mouseWheel: !0, autoPlay: !1, playSpeed: 3e3, preload: 3, modal: !1, loop: !0, ajax: { dataType: "html", headers: { "X-fancyBox": !0 } }, iframe: { scrolling: "auto", preload: !0 }, swf: { wmode: "transparent", allowfullscreen: "true", allowscriptaccess: 
"always", }, keys: { next: { 13: "left", 34: "up", 39: "left", 40: "up" }, prev: { 8: "right", 33: "down", 37: "right", 38: "down" }, close: [27], play: [32], toggle: [70], }, direction: { next: "left", prev: "right" }, scrollOutside: !0, index: 0, type: null, href: null, content: null, title: null, tpl: { wrap: '
      ', image: '', iframe: '", error: '

      The requested content cannot be loaded.
      Please try again later.

      ', closeBtn: '', next: '', prev: '', }, openEffect: "fade", openSpeed: 250, openEasing: "swing", openOpacity: !0, openMethod: "zoomIn", closeEffect: "fade", closeSpeed: 250, closeEasing: "swing", closeOpacity: !0, closeMethod: "zoomOut", nextEffect: "elastic", nextSpeed: 250, nextEasing: "swing", nextMethod: "changeIn", prevEffect: "elastic", prevSpeed: 250, prevEasing: "swing", prevMethod: "changeOut", helpers: { overlay: !0, title: !0 }, onCancel: f.noop, beforeLoad: f.noop, afterLoad: f.noop, beforeShow: f.noop, afterShow: f.noop, beforeChange: f.noop, beforeClose: f.noop, afterClose: f.noop, }, group: {}, opts: {}, previous: null, coming: null, current: null, isActive: !1, isOpen: !1, isOpened: !1, wrap: null, skin: null, outer: null, inner: null, player: { timer: null, isActive: !1 }, ajaxLoad: null, imgPreload: null, transitions: {}, helpers: {}, open: function (a, d) { if (a && (f.isPlainObject(d) || (d = {}), !1 !== b.close(!0))) return ( f.isArray(a) || (a = t(a) ? f(a).get() : [a]), f.each(a, function (e, c) { var k = {}, g, h, j, m, l; "object" === f.type(c) && (c.nodeType && (c = f(c)), t(c) ? ((k = { href: c.data("fancybox-href") || c.attr("href"), title: c.data("fancybox-title") || c.attr("title"), isDom: !0, element: c, }), f.metadata && f.extend(!0, k, c.metadata())) : (k = c)); g = d.href || k.href || (p(c) ? c : null); h = d.title !== r ? d.title : k.title || ""; m = (j = d.content || k.content) ? "html" : d.type || k.type; !m && k.isDom && ((m = c.data("fancybox-type")), m || (m = (m = c.prop("class").match(/fancybox\.(\w+)/)) ? m[1] : null)); p(g) && (m || (b.isImage(g) ? (m = "image") : b.isSWF(g) ? (m = "swf") : "#" === g.charAt(0) ? (m = "inline") : p(c) && ((m = "html"), (j = c))), "ajax" === m && ((l = g.split(/\s+/, 2)), (g = l.shift()), (l = l.shift()))); j || ("inline" === m ? g ? (j = f(p(g) ? g.replace(/.*(?=#[^\s]+$)/, "") : g)) : k.isDom && (j = c) : "html" === m ? 
(j = g) : !m && !g && k.isDom && ((m = "inline"), (j = c))); f.extend(k, { href: g, type: m, content: j, title: h, selector: l, }); a[e] = k; }), (b.opts = f.extend(!0, {}, b.defaults, d)), d.keys !== r && (b.opts.keys = d.keys ? f.extend({}, b.defaults.keys, d.keys) : !1), (b.group = a), b._start(b.opts.index) ); }, cancel: function () { var a = b.coming; a && !1 !== b.trigger("onCancel") && (b.hideLoading(), b.ajaxLoad && b.ajaxLoad.abort(), (b.ajaxLoad = null), b.imgPreload && (b.imgPreload.onload = b.imgPreload.onerror = null), a.wrap && a.wrap.stop(!0, !0).trigger("onReset").remove(), (b.coming = null), b.current || b._afterZoomOut(a)); }, close: function (a) { b.cancel(); !1 !== b.trigger("beforeClose") && (b.unbindEvents(), b.isActive && (!b.isOpen || !0 === a ? (f(".fancybox-wrap").stop(!0).trigger("onReset").remove(), b._afterZoomOut()) : ((b.isOpen = b.isOpened = !1), (b.isClosing = !0), f(".fancybox-item, .fancybox-nav").remove(), b.wrap.stop(!0, !0).removeClass("fancybox-opened"), b.transitions[b.current.closeMethod]()))); }, play: function (a) { var d = function () { clearTimeout(b.player.timer); }, e = function () { d(); b.current && b.player.isActive && (b.player.timer = setTimeout(b.next, b.current.playSpeed)); }, c = function () { d(); f("body").unbind(".player"); b.player.isActive = !1; b.trigger("onPlayEnd"); }; if (!0 === a || (!b.player.isActive && !1 !== a)) { if ( b.current && (b.current.loop || b.current.index < b.group.length - 1) ) (b.player.isActive = !0), f("body").bind({ "afterShow.player onUpdate.player": e, "onCancel.player beforeClose.player": c, "beforeLoad.player": d, }), e(), b.trigger("onPlayStart"); } else c(); }, next: function (a) { var d = b.current; d && (p(a) || (a = d.direction.next), b.jumpto(d.index + 1, a, "next")); }, prev: function (a) { var d = b.current; d && (p(a) || (a = d.direction.prev), b.jumpto(d.index - 1, a, "prev")); }, jumpto: function (a, d, e) { var c = b.current; c && ((a = l(a)), (b.direction = d || 
c.direction[a >= c.index ? "next" : "prev"]), (b.router = e || "jumpto"), c.loop && (0 > a && (a = c.group.length + (a % c.group.length)), (a %= c.group.length)), c.group[a] !== r && (b.cancel(), b._start(a))); }, reposition: function (a, d) { var e = b.current, c = e ? e.wrap : null, k; c && ((k = b._getPosition(d)), a && "scroll" === a.type ? (delete k.position, c.stop(!0, !0).animate(k, 200)) : (c.css(k), (e.pos = f.extend({}, e.dim, k)))); }, update: function (a) { var d = a && a.type, e = !d || "orientationchange" === d; e && (clearTimeout(w), (w = null)); b.isOpen && !w && (w = setTimeout( function () { var c = b.current; c && !b.isClosing && (b.wrap.removeClass("fancybox-tmp"), (e || "load" === d || ("resize" === d && c.autoResize)) && b._setDimension(), ("scroll" === d && c.canShrink) || b.reposition(a), b.trigger("onUpdate"), (w = null)); }, e && !s ? 0 : 300, )); }, toggle: function (a) { b.isOpen && ((b.current.fitToView = "boolean" === f.type(a) ? a : !b.current.fitToView), s && (b.wrap.removeAttr("style").addClass("fancybox-tmp"), b.trigger("onUpdate")), b.update()); }, hideLoading: function () { n.unbind(".loading"); f("#fancybox-loading").remove(); }, showLoading: function () { var a, d; b.hideLoading(); a = f('<div id="fancybox-loading"><div></div></div>') .click(b.cancel) .appendTo("body"); n.bind("keydown.loading", function (a) { if (27 === (a.which || a.keyCode)) a.preventDefault(), b.cancel(); }); b.defaults.fixed || ((d = b.getViewport()), a.css({ position: "absolute", top: 0.5 * d.h + d.y, left: 0.5 * d.w + d.x, })); }, getViewport: function () { var a = (b.current && b.current.locked) || !1, d = { x: q.scrollLeft(), y: q.scrollTop() }; a ? ((d.w = a[0].clientWidth), (d.h = a[0].clientHeight)) : ((d.w = s && C.innerWidth ? C.innerWidth : q.width()), (d.h = s && C.innerHeight ? C.innerHeight : q.height())); return d; }, unbindEvents: function () { b.wrap && t(b.wrap) && b.wrap.unbind(".fb"); n.unbind(".fb"); q.unbind(".fb"); }, bindEvents: function () { var a = b.current, d; a && (q.bind( "orientationchange.fb" + (s ? "" : " resize.fb") + (a.autoCenter && !a.locked ? " scroll.fb" : ""), b.update, ), (d = a.keys) && n.bind("keydown.fb", function (e) { var c = e.which || e.keyCode, k = e.target || e.srcElement; if (27 === c && b.coming) return !1; !e.ctrlKey && !e.altKey && !e.shiftKey && !e.metaKey && (!k || (!k.type && !f(k).is("[contenteditable]"))) && f.each(d, function (d, k) { if (1 < a.group.length && k[c] !== r) return b[d](k[c]), e.preventDefault(), !1; if (-1 < f.inArray(c, k)) return b[d](), e.preventDefault(), !1; }); }), f.fn.mousewheel && a.mouseWheel && b.wrap.bind("mousewheel.fb", function (d, c, k, g) { for ( var h = f(d.target || null), j = !1; h.length && !j && !h.is(".fancybox-skin") && !h.is(".fancybox-wrap"); ) (j = h[0] && !(h[0].style.overflow && "hidden" === h[0].style.overflow) && ((h[0].clientWidth && h[0].scrollWidth > h[0].clientWidth) || (h[0].clientHeight && h[0].scrollHeight > h[0].clientHeight))), (h = f(h).parent()); if (0 !== c && !j && 1 < b.group.length && !a.canShrink) { if (0 < g || 0 < k) b.prev(0 < g ? "down" : "left"); else if (0 > g || 0 > k) b.next(0 > g ? 
"up" : "right"); d.preventDefault(); } })); }, trigger: function (a, d) { var e, c = d || b.coming || b.current; if (c) { f.isFunction(c[a]) && (e = c[a].apply(c, Array.prototype.slice.call(arguments, 1))); if (!1 === e) return !1; c.helpers && f.each(c.helpers, function (d, e) { e && b.helpers[d] && f.isFunction(b.helpers[d][a]) && ((e = f.extend(!0, {}, b.helpers[d].defaults, e)), b.helpers[d][a](e, c)); }); f.event.trigger(a + ".fb"); } }, isImage: function (a) { return ( p(a) && a.match( /(^data:image\/.*,)|(\.(jp(e|g|eg)|gif|png|bmp|webp)((\?|#).*)?$)/i, ) ); }, isSWF: function (a) { return p(a) && a.match(/\.(swf)((\?|#).*)?$/i); }, _start: function (a) { var d = {}, e, c; a = l(a); e = b.group[a] || null; if (!e) return !1; d = f.extend(!0, {}, b.opts, e); e = d.margin; c = d.padding; "number" === f.type(e) && (d.margin = [e, e, e, e]); "number" === f.type(c) && (d.padding = [c, c, c, c]); d.modal && f.extend(!0, d, { closeBtn: !1, closeClick: !1, nextClick: !1, arrows: !1, mouseWheel: !1, keys: null, helpers: { overlay: { closeClick: !1 } }, }); d.autoSize && (d.autoWidth = d.autoHeight = !0); "auto" === d.width && (d.autoWidth = !0); "auto" === d.height && (d.autoHeight = !0); d.group = b.group; d.index = a; b.coming = d; if (!1 === b.trigger("beforeLoad")) b.coming = null; else { c = d.type; e = d.href; if (!c) return ( (b.coming = null), b.current && b.router && "jumpto" !== b.router ? ((b.current.index = a), b[b.router](b.direction)) : !1 ); b.isActive = !0; if ("image" === c || "swf" === c) (d.autoHeight = d.autoWidth = !1), (d.scrolling = "visible"); "image" === c && (d.aspectRatio = !0); "iframe" === c && s && (d.scrolling = "scroll"); d.wrap = f(d.tpl.wrap) .addClass( "fancybox-" + (s ? 
"mobile" : "desktop") + " fancybox-type-" + c + " fancybox-tmp " + d.wrapCSS, ) .appendTo(d.parent || "body"); f.extend(d, { skin: f(".fancybox-skin", d.wrap), outer: f(".fancybox-outer", d.wrap), inner: f(".fancybox-inner", d.wrap), }); f.each(["Top", "Right", "Bottom", "Left"], function (a, b) { d.skin.css("padding" + b, x(d.padding[a])); }); b.trigger("onReady"); if ("inline" === c || "html" === c) { if (!d.content || !d.content.length) return b._error("content"); } else if (!e) return b._error("href"); "image" === c ? b._loadImage() : "ajax" === c ? b._loadAjax() : "iframe" === c ? b._loadIframe() : b._afterLoad(); } }, _error: function (a) { f.extend(b.coming, { type: "html", autoWidth: !0, autoHeight: !0, minWidth: 0, minHeight: 0, scrolling: "no", hasError: a, content: b.coming.tpl.error, }); b._afterLoad(); }, _loadImage: function () { var a = (b.imgPreload = new Image()); a.onload = function () { this.onload = this.onerror = null; b.coming.width = this.width; b.coming.height = this.height; b._afterLoad(); }; a.onerror = function () { this.onload = this.onerror = null; b._error("image"); }; a.src = b.coming.href; !0 !== a.complete && b.showLoading(); }, _loadAjax: function () { var a = b.coming; b.showLoading(); b.ajaxLoad = f.ajax( f.extend({}, a.ajax, { url: a.href, error: function (a, e) { b.coming && "abort" !== e ? b._error("ajax", a) : b.hideLoading(); }, success: function (d, e) { "success" === e && ((a.content = d), b._afterLoad()); }, }), ); }, _loadIframe: function () { var a = b.coming, d = f(a.tpl.iframe.replace(/\{rnd\}/g, new Date().getTime())) .attr("scrolling", s ? 
"auto" : a.iframe.scrolling) .attr("src", a.href); f(a.wrap).bind("onReset", function () { try { f(this) .find("iframe") .hide() .attr("src", "//about:blank") .end() .empty(); } catch (a) {} }); a.iframe.preload && (b.showLoading(), d.one("load", function () { f(this).data("ready", 1); s || f(this).bind("load.fb", b.update); f(this) .parents(".fancybox-wrap") .width("100%") .removeClass("fancybox-tmp") .show(); b._afterLoad(); })); a.content = d.appendTo(a.inner); a.iframe.preload || b._afterLoad(); }, _preloadImages: function () { var a = b.group, d = b.current, e = a.length, c = d.preload ? Math.min(d.preload, e - 1) : 0, f, g; for (g = 1; g <= c; g += 1) (f = a[(d.index + g) % e]), "image" === f.type && f.href && (new Image().src = f.href); }, _afterLoad: function () { var a = b.coming, d = b.current, e, c, k, g, h; b.hideLoading(); if (a && !1 !== b.isActive) if (!1 === b.trigger("afterLoad", a, d)) a.wrap.stop(!0).trigger("onReset").remove(), (b.coming = null); else { d && (b.trigger("beforeChange", d), d.wrap .stop(!0) .removeClass("fancybox-opened") .find(".fancybox-item, .fancybox-nav") .remove()); b.unbindEvents(); e = a.content; c = a.type; k = a.scrolling; f.extend(b, { wrap: a.wrap, skin: a.skin, outer: a.outer, inner: a.inner, current: a, previous: d, }); g = a.href; switch (c) { case "inline": case "ajax": case "html": a.selector ? (e = f("<div>").html(e).find(a.selector)) : t(e) && (e.data("fancybox-placeholder") || e.data( "fancybox-placeholder", f('<div class="fancybox-placeholder"></div>') .insertAfter(e) .hide(), ), (e = e.show().detach()), a.wrap.bind("onReset", function () { f(this).find(e).length && e .hide() .replaceAll(e.data("fancybox-placeholder")) .data("fancybox-placeholder", !1); })); break; case "image": e = a.tpl.image.replace("{href}", g); break; case "swf": (e = '<object id="fancybox-swf" classid="clsid:D27CDB6E-AE6D-11cf-96B8-444553540000" width="100%" height="100%"><param name="movie" value="' + g + '"></param>'), (h = ""), f.each(a.swf, function (a, b) { e += '<param name="' + a + '" value="' + b + '"></param>'; h += " " + a + '="' + b + '"'; }), (e += '<embed src="' + g + '" type="application/x-shockwave-flash" width="100%" height="100%"' + h + "></embed></object>"); } (!t(e) || !e.parent().is(a.inner)) && a.inner.append(e); b.trigger("beforeShow"); a.inner.css( "overflow", "yes" === k ? "scroll" : "no" === k ? "hidden" : k, ); b._setDimension(); b.reposition(); b.isOpen = !1; b.coming = null; b.bindEvents(); if (b.isOpened) { if (d.prevMethod) b.transitions[d.prevMethod](); } else f(".fancybox-wrap") .not(a.wrap) .stop(!0) .trigger("onReset") .remove(); b.transitions[b.isOpened ? a.nextMethod : a.openMethod](); b._preloadImages(); } }, _setDimension: function () { var a = b.getViewport(), d = 0, e = !1, c = !1, e = b.wrap, k = b.skin, g = b.inner, h = b.current, c = h.width, j = h.height, m = h.minWidth, u = h.minHeight, n = h.maxWidth, v = h.maxHeight, s = h.scrolling, q = h.scrollOutside ? h.scrollbarWidth : 0, y = h.margin, p = l(y[1] + y[3]), r = l(y[0] + y[2]), z, A, t, D, B, G, C, E, w; e.add(k).add(g).width("auto").height("auto").removeClass("fancybox-tmp"); y = l(k.outerWidth(!0) - k.width()); z = l(k.outerHeight(!0) - k.height()); A = p + y; t = r + z; D = F(c) ? ((a.w - A) * l(c)) / 100 : c; B = F(j) ? 
((a.h - t) * l(j)) / 100 : j; if ("iframe" === h.type) { if (((w = h.content), h.autoHeight && 1 === w.data("ready"))) try { w[0].contentWindow.document.location && (g.width(D).height(9999), (G = w.contents().find("body")), q && G.css("overflow-x", "hidden"), (B = G.height())); } catch (H) {} } else if (h.autoWidth || h.autoHeight) g.addClass("fancybox-tmp"), h.autoWidth || g.width(D), h.autoHeight || g.height(B), h.autoWidth && (D = g.width()), h.autoHeight && (B = g.height()), g.removeClass("fancybox-tmp"); c = l(D); j = l(B); E = D / B; m = l(F(m) ? l(m, "w") - A : m); n = l(F(n) ? l(n, "w") - A : n); u = l(F(u) ? l(u, "h") - t : u); v = l(F(v) ? l(v, "h") - t : v); G = n; C = v; h.fitToView && ((n = Math.min(a.w - A, n)), (v = Math.min(a.h - t, v))); A = a.w - p; r = a.h - r; h.aspectRatio ? (c > n && ((c = n), (j = l(c / E))), j > v && ((j = v), (c = l(j * E))), c < m && ((c = m), (j = l(c / E))), j < u && ((j = u), (c = l(j * E)))) : ((c = Math.max(m, Math.min(c, n))), h.autoHeight && "iframe" !== h.type && (g.width(c), (j = g.height())), (j = Math.max(u, Math.min(j, v)))); if (h.fitToView) if ( (g.width(c).height(j), e.width(c + y), (a = e.width()), (p = e.height()), h.aspectRatio) ) for (; (a > A || p > r) && c > m && j > u && !(19 < d++); ) (j = Math.max(u, Math.min(v, j - 10))), (c = l(j * E)), c < m && ((c = m), (j = l(c / E))), c > n && ((c = n), (j = l(c / E))), g.width(c).height(j), e.width(c + y), (a = e.width()), (p = e.height()); else (c = Math.max(m, Math.min(c, c - (a - A)))), (j = Math.max(u, Math.min(j, j - (p - r)))); q && "auto" === s && j < B && c + y + q < A && (c += q); g.width(c).height(j); e.width(c + y); a = e.width(); p = e.height(); e = (a > A || p > r) && c > m && j > u; c = h.aspectRatio ? 
c < G && j < C && c < D && j < B : (c < G || j < C) && (c < D || j < B); f.extend(h, { dim: { width: x(a), height: x(p) }, origWidth: D, origHeight: B, canShrink: e, canExpand: c, wPadding: y, hPadding: z, wrapSpace: p - k.outerHeight(!0), skinSpace: k.height() - j, }); !w && h.autoHeight && j > u && j < v && !c && g.height("auto"); }, _getPosition: function (a) { var d = b.current, e = b.getViewport(), c = d.margin, f = b.wrap.width() + c[1] + c[3], g = b.wrap.height() + c[0] + c[2], c = { position: "absolute", top: c[0], left: c[3] }; d.autoCenter && d.fixed && !a && g <= e.h && f <= e.w ? (c.position = "fixed") : d.locked || ((c.top += e.y), (c.left += e.x)); c.top = x(Math.max(c.top, c.top + (e.h - g) * d.topRatio)); c.left = x(Math.max(c.left, c.left + (e.w - f) * d.leftRatio)); return c; }, _afterZoomIn: function () { var a = b.current; a && ((b.isOpen = b.isOpened = !0), b.wrap.css("overflow", "visible").addClass("fancybox-opened"), b.update(), (a.closeClick || (a.nextClick && 1 < b.group.length)) && b.inner.css("cursor", "pointer").bind("click.fb", function (d) { !f(d.target).is("a") && !f(d.target).parent().is("a") && (d.preventDefault(), b[a.closeClick ? "close" : "next"]()); }), a.closeBtn && f(a.tpl.closeBtn) .appendTo(b.skin) .bind("click.fb", function (a) { a.preventDefault(); b.close(); }), a.arrows && 1 < b.group.length && ((a.loop || 0 < a.index) && f(a.tpl.prev).appendTo(b.outer).bind("click.fb", b.prev), (a.loop || a.index < b.group.length - 1) && f(a.tpl.next).appendTo(b.outer).bind("click.fb", b.next)), b.trigger("afterShow"), !a.loop && a.index === a.group.length - 1 ? 
b.play(!1) : b.opts.autoPlay && !b.player.isActive && ((b.opts.autoPlay = !1), b.play())); }, _afterZoomOut: function (a) { a = a || b.current; f(".fancybox-wrap").trigger("onReset").remove(); f.extend(b, { group: {}, opts: {}, router: !1, current: null, isActive: !1, isOpened: !1, isOpen: !1, isClosing: !1, wrap: null, skin: null, outer: null, inner: null, }); b.trigger("afterClose", a); }, }); b.transitions = { getOrigPosition: function () { var a = b.current, d = a.element, e = a.orig, c = {}, f = 50, g = 50, h = a.hPadding, j = a.wPadding, m = b.getViewport(); !e && a.isDom && d.is(":visible") && ((e = d.find("img:first")), e.length || (e = d)); t(e) ? ((c = e.offset()), e.is("img") && ((f = e.outerWidth()), (g = e.outerHeight()))) : ((c.top = m.y + (m.h - g) * a.topRatio), (c.left = m.x + (m.w - f) * a.leftRatio)); if ("fixed" === b.wrap.css("position") || a.locked) (c.top -= m.y), (c.left -= m.x); return (c = { top: x(c.top - h * a.topRatio), left: x(c.left - j * a.leftRatio), width: x(f + j), height: x(g + h), }); }, step: function (a, d) { var e, c, f = d.prop; c = b.current; var g = c.wrapSpace, h = c.skinSpace; if ("width" === f || "height" === f) (e = d.end === d.start ? 1 : (a - d.start) / (d.end - d.start)), b.isClosing && (e = 1 - e), (c = "width" === f ? c.wPadding : c.hPadding), (c = a - c), b.skin[f](l("width" === f ? c : c - g * e)), b.inner[f](l("width" === f ? c : c - g * e - h * e)); }, zoomIn: function () { var a = b.current, d = a.pos, e = a.openEffect, c = "elastic" === e, k = f.extend({ opacity: 1 }, d); delete k.position; c ? ((d = this.getOrigPosition()), a.openOpacity && (d.opacity = 0.1)) : "fade" === e && (d.opacity = 0.1); b.wrap.css(d).animate(k, { duration: "none" === e ? 0 : a.openSpeed, easing: a.openEasing, step: c ? 
this.step : null, complete: b._afterZoomIn, }); }, zoomOut: function () { var a = b.current, d = a.closeEffect, e = "elastic" === d, c = { opacity: 0.1 }; e && ((c = this.getOrigPosition()), a.closeOpacity && (c.opacity = 0.1)); b.wrap.animate(c, { duration: "none" === d ? 0 : a.closeSpeed, easing: a.closeEasing, step: e ? this.step : null, complete: b._afterZoomOut, }); }, changeIn: function () { var a = b.current, d = a.nextEffect, e = a.pos, c = { opacity: 1 }, f = b.direction, g; e.opacity = 0.1; "elastic" === d && ((g = "down" === f || "up" === f ? "top" : "left"), "down" === f || "right" === f ? ((e[g] = x(l(e[g]) - 200)), (c[g] = "+=200px")) : ((e[g] = x(l(e[g]) + 200)), (c[g] = "-=200px"))); "none" === d ? b._afterZoomIn() : b.wrap.css(e).animate(c, { duration: a.nextSpeed, easing: a.nextEasing, complete: b._afterZoomIn, }); }, changeOut: function () { var a = b.previous, d = a.prevEffect, e = { opacity: 0.1 }, c = b.direction; "elastic" === d && (e["down" === c || "up" === c ? "top" : "left"] = ("up" === c || "left" === c ? "-" : "+") + "=200px"); a.wrap.animate(e, { duration: "none" === d ? 0 : a.prevSpeed, easing: a.prevEasing, complete: function () { f(this).trigger("onReset").remove(); }, }); }, }; b.helpers.overlay = { defaults: { closeClick: !0, speedOut: 200, showEarly: !0, css: {}, locked: !s, fixed: !0, }, overlay: null, fixed: !1, create: function (a) { a = f.extend({}, this.defaults, a); this.overlay && this.close(); this.overlay = f('<div class="fancybox-overlay"></div>').appendTo("body"); this.fixed = !1; a.fixed && b.defaults.fixed && (this.overlay.addClass("fancybox-overlay-fixed"), (this.fixed = !0)); }, open: function (a) { var d = this; a = f.extend({}, this.defaults, a); this.overlay ? this.overlay.unbind(".overlay").width("auto").height("auto") : this.create(a); this.fixed || (q.bind("resize.overlay", f.proxy(this.update, this)), this.update()); a.closeClick && this.overlay.bind("click.overlay", function (a) { f(a.target).hasClass("fancybox-overlay") && (b.isActive ? b.close() : d.close()); }); this.overlay.css(a.css).show(); }, close: function () { f(".fancybox-overlay").remove(); q.unbind("resize.overlay"); this.overlay = null; !1 !== this.margin && (f("body").css("margin-right", this.margin), (this.margin = !1)); this.el && this.el.removeClass("fancybox-lock"); }, update: function () { var a = "100%", b; this.overlay.width(a).height("100%"); H ? ((b = Math.max(z.documentElement.offsetWidth, z.body.offsetWidth)), n.width() > b && (a = n.width())) : n.width() > q.width() && (a = n.width()); this.overlay.width(a).height(n.height()); }, onReady: function (a, b) { f(".fancybox-overlay").stop(!0, !0); this.overlay || ((this.margin = n.height() > q.height() || "scroll" === f("body").css("overflow-y") ? f("body").css("margin-right") : !1), (this.el = z.all && !z.querySelector ? 
f("html") : f("body")), this.create(a)); a.locked && this.fixed && ((b.locked = this.overlay.append(b.wrap)), (b.fixed = !1)); !0 === a.showEarly && this.beforeShow.apply(this, arguments); }, beforeShow: function (a, b) { b.locked && (this.el.addClass("fancybox-lock"), !1 !== this.margin && f("body").css("margin-right", l(this.margin) + b.scrollbarWidth)); this.open(a); }, onUpdate: function () { this.fixed || this.update(); }, afterClose: function (a) { this.overlay && !b.isActive && this.overlay.fadeOut(a.speedOut, f.proxy(this.close, this)); }, }; b.helpers.title = { defaults: { type: "float", position: "bottom" }, beforeShow: function (a) { var d = b.current, e = d.title, c = a.type; f.isFunction(e) && (e = e.call(d.element, d)); if (p(e) && "" !== f.trim(e)) { d = f('<div class="fancybox-title fancybox-title-' + c + '-wrap">' + e + "</div>"); switch (c) { case "inside": c = b.skin; break; case "outside": c = b.wrap; break; case "over": c = b.inner; break; default: (c = b.skin), d.appendTo("body"), H && d.width(d.width()), d.wrapInner('<span class="child"></span>'), (b.current.margin[2] += Math.abs(l(d.css("margin-bottom")))); } d["top" === a.position ? "prependTo" : "appendTo"](c); } }, }; f.fn.fancybox = function (a) { var d, e = f(this), c = this.selector || "", k = function (g) { var h = f(this).blur(), j = d, k, l; !g.ctrlKey && !g.altKey && !g.shiftKey && !g.metaKey && !h.is(".fancybox-wrap") && ((k = a.groupAttr || "data-fancybox-group"), (l = h.attr(k)), l || ((k = "rel"), (l = h.get(0)[k])), l && "" !== l && "nofollow" !== l && ((h = c.length ? f(c) : e), (h = h.filter("[" + k + '="' + l + '"]')), (j = h.index(this))), (a.index = j), !1 !== b.open(h, a) && g.preventDefault()); }; a = a || {}; d = a.index || 0; !c || !1 === a.live ? e.unbind("click.fb-start").bind("click.fb-start", k) : n .undelegate(c, "click.fb-start") .delegate( c + ":not('.fancybox-item, .fancybox-nav')", "click.fb-start", k, ); this.filter("[data-fancybox-start=1]").trigger("click"); return this; }; n.ready(function () { f.scrollbarWidth === r && (f.scrollbarWidth = function () { var a = f('<div style="width:50px;height:50px;overflow:auto"><div/></div>').appendTo("body"), b = a.children(), b = b.innerWidth() - b.height(99).innerWidth(); a.remove(); return b; }); if (f.support.fixedPosition === r) { var a = f.support, d = f('<div style="position:fixed;top:20px;"></div>').appendTo("body"), e = 20 === d[0].offsetTop || 15 === d[0].offsetTop; d.remove(); a.fixedPosition = e; } f.extend(b.defaults, { scrollbarWidth: f.scrollbarWidth(), fixed: f.support.fixedPosition, parent: f("body"), }); }); })(window, document, jQuery);
================================================
FILE: backend/tests/integration/tests/pruning/website/js/jquery.flexslider.js
================================================
/*
 * jQuery FlexSlider v2.1
 * http://www.woothemes.com/flexslider/
 *
 * Copyright 2012 WooThemes
 * Free to use under the GPLv2 license.
 * http://www.gnu.org/licenses/gpl-2.0.html
 *
 * Contributing author: Tyler Smith (@mbmufffin)
 */
(function ($) {
//FlexSlider: Object Instance
$.flexslider = function (el, options) { var slider = $(el), vars = $.extend({}, $.flexslider.defaults, options), namespace = vars.namespace, touch = "ontouchstart" in window || (window.DocumentTouch && document instanceof DocumentTouch), eventType = touch ? "touchend" : "click", vertical = vars.direction === "vertical", reverse = vars.reverse, carousel = vars.itemWidth > 0, fade = vars.animation === "fade", asNav = vars.asNavFor !== "", methods = {};
// Store a reference to the slider object
$.data(el, "flexslider", slider);
// Private slider methods
methods = { init: function () { slider.animating = false; slider.currentSlide = vars.startAt; slider.animatingTo = slider.currentSlide; slider.atEnd = slider.currentSlide === 0 || slider.currentSlide === slider.last; slider.containerSelector = vars.selector.substr( 0, vars.selector.search(" "), ); slider.slides = $(vars.selector, slider); slider.container = $(slider.containerSelector, slider); slider.count = slider.slides.length;
// SYNC:
slider.syncExists = $(vars.sync).length > 0;
// SLIDE:
if (vars.animation === "slide") vars.animation = "swing"; slider.prop = vertical ? "top" : "marginLeft"; slider.args = {};
// SLIDESHOW:
slider.manualPause = false;
// TOUCH/USECSS:
slider.transitions = !vars.video && !fade && vars.useCSS && (function () { var obj = document.createElement("div"), props = [ "perspectiveProperty", "WebkitPerspective", "MozPerspective", "OPerspective", "msPerspective", ]; for (var i in props) { if (obj.style[props[i]] !== undefined) { slider.pfx = props[i].replace("Perspective", "").toLowerCase(); slider.prop = "-" + slider.pfx + "-transform"; return true; } } return false; })();
// CONTROLSCONTAINER:
if (vars.controlsContainer !== "") slider.controlsContainer = $(vars.controlsContainer).length > 0 && $(vars.controlsContainer);
// MANUAL:
if (vars.manualControls !== "") slider.manualControls = $(vars.manualControls).length > 0 && $(vars.manualControls);
// RANDOMIZE:
if (vars.randomize) { slider.slides.sort(function () { return Math.round(Math.random()) - 0.5; }); slider.container.empty().append(slider.slides); } slider.doMath();
// ASNAV:
if (asNav) methods.asNav.setup();
// INIT
slider.setup("init");
// CONTROLNAV:
if (vars.controlNav) methods.controlNav.setup();
// DIRECTIONNAV:
if (vars.directionNav) methods.directionNav.setup();
// KEYBOARD:
if ( vars.keyboard && ($(slider.containerSelector).length === 1 || vars.multipleKeyboard) ) { $(document).bind("keyup", function (event) { var keycode = event.keyCode; if (!slider.animating && (keycode === 39 || keycode === 37)) { var target = keycode === 39 ? slider.getTarget("next") : keycode === 37 ? slider.getTarget("prev") : false; slider.flexAnimate(target, vars.pauseOnAction); } }); }
// MOUSEWHEEL:
if (vars.mousewheel) { slider.bind("mousewheel", function (event, delta, deltaX, deltaY) { event.preventDefault(); var target = delta < 0 ? slider.getTarget("next") : slider.getTarget("prev"); slider.flexAnimate(target, vars.pauseOnAction); }); }
// PAUSEPLAY
if (vars.pausePlay) methods.pausePlay.setup();
// SLIDESHOW
if (vars.slideshow) { if (vars.pauseOnHover) { slider.hover( function () { if (!slider.manualPlay && !slider.manualPause) slider.pause(); }, function () { if (!slider.manualPause && !slider.manualPlay) slider.play(); }, ); }
// initialize animation
vars.initDelay > 0 ? setTimeout(slider.play, vars.initDelay) : slider.play(); }
// TOUCH
if (touch && vars.touch) methods.touch();
// FADE&&SMOOTHHEIGHT || SLIDE:
if (!fade || (fade && vars.smoothHeight)) $(window).bind("resize focus", methods.resize);
// API: start() Callback
setTimeout(function () { vars.start(slider); }, 200); }, asNav: { setup: function () { slider.asNav = true; slider.animatingTo = Math.floor(slider.currentSlide / slider.move); slider.currentItem = slider.currentSlide; slider.slides .removeClass(namespace + "active-slide") .eq(slider.currentItem) .addClass(namespace + "active-slide"); slider.slides.click(function (e) { e.preventDefault(); var $slide = $(this), target = $slide.index(); if ( !$(vars.asNavFor).data("flexslider").animating && !$slide.hasClass("active") ) { slider.direction = slider.currentItem < target ? "next" : "prev"; slider.flexAnimate(target, vars.pauseOnAction, false, true, true); } }); }, }, controlNav: { setup: function () { if (!slider.manualControls) { methods.controlNav.setupPaging(); } else {
// MANUALCONTROLS:
methods.controlNav.setupManual(); } }, setupPaging: function () { var type = vars.controlNav === "thumbnails" ? "control-thumbs" : "control-paging", j = 1, item; slider.controlNavScaffold = $('<ol class="' + namespace + "control-nav " + namespace + type + '"></ol>'); if (slider.pagingCount > 1) { for (var i = 0; i < slider.pagingCount; i++) { item = vars.controlNav === "thumbnails" ? '<img src="' + slider.slides.eq(i).attr("data-thumb") + '"/>' : "<a>" + j + "</a>"; slider.controlNavScaffold.append("<li>" + item + "</li>"); j++; } }
// CONTROLSCONTAINER:
slider.controlsContainer ? $(slider.controlsContainer).append(slider.controlNavScaffold) : slider.append(slider.controlNavScaffold); methods.controlNav.set(); methods.controlNav.active(); slider.controlNavScaffold.delegate( "a, img", eventType, function (event) { event.preventDefault(); var $this = $(this), target = slider.controlNav.index($this); if (!$this.hasClass(namespace + "active")) { slider.direction = target > slider.currentSlide ? "next" : "prev"; slider.flexAnimate(target, vars.pauseOnAction); } }, );
// Prevent iOS click event bug
if (touch) { slider.controlNavScaffold.delegate( "a", "click touchstart", function (event) { event.preventDefault(); }, ); } }, setupManual: function () { slider.controlNav = slider.manualControls; methods.controlNav.active(); slider.controlNav.live(eventType, function (event) { event.preventDefault(); var $this = $(this), target = slider.controlNav.index($this); if (!$this.hasClass(namespace + "active")) { target > slider.currentSlide ? (slider.direction = "next") : (slider.direction = "prev"); slider.flexAnimate(target, vars.pauseOnAction); } });
// Prevent iOS click event bug
if (touch) { slider.controlNav.live("click touchstart", function (event) { event.preventDefault(); }); } }, set: function () { var selector = vars.controlNav === "thumbnails" ? "img" : "a"; slider.controlNav = $( "." + namespace + "control-nav li " + selector, slider.controlsContainer ? slider.controlsContainer : slider, ); }, active: function () { slider.controlNav .removeClass(namespace + "active") .eq(slider.animatingTo) .addClass(namespace + "active"); }, update: function (action, pos) { if (slider.pagingCount > 1 && action === "add") { slider.controlNavScaffold.append( $("<li><a>" + slider.count + "</a></li>"), ); } else if (slider.pagingCount === 1) { slider.controlNavScaffold.find("li").remove(); } else { slider.controlNav.eq(pos).closest("li").remove(); } methods.controlNav.set(); slider.pagingCount > 1 && slider.pagingCount !== slider.controlNav.length ? slider.update(pos, action) : methods.controlNav.active(); }, }, directionNav: { setup: function () { var directionNavScaffold = $('<ul class="' + namespace + 'direction-nav"><li><a class="' + namespace + 'prev" href="#">' + vars.prevText + '</a></li><li><a class="' + namespace + 'next" href="#">' + vars.nextText + "</a></li></ul>");
// CONTROLSCONTAINER:
if (slider.controlsContainer) { $(slider.controlsContainer).append(directionNavScaffold); slider.directionNav = $( "." + namespace + "direction-nav li a", slider.controlsContainer, ); } else { slider.append(directionNavScaffold); slider.directionNav = $( "." + namespace + "direction-nav li a", slider, ); } methods.directionNav.update(); slider.directionNav.bind(eventType, function (event) { event.preventDefault(); var target = $(this).hasClass(namespace + "next") ? slider.getTarget("next") : slider.getTarget("prev"); slider.flexAnimate(target, vars.pauseOnAction); });
// Prevent iOS click event bug
if (touch) { slider.directionNav.bind("click touchstart", function (event) { event.preventDefault(); }); } }, update: function () { var disabledClass = namespace + "disabled"; if (slider.pagingCount === 1) { slider.directionNav.addClass(disabledClass); } else if (!vars.animationLoop) { if (slider.animatingTo === 0) { slider.directionNav .removeClass(disabledClass) .filter("." + namespace + "prev") .addClass(disabledClass); } else if (slider.animatingTo === slider.last) { slider.directionNav .removeClass(disabledClass) .filter("." + namespace + "next") .addClass(disabledClass); } else { slider.directionNav.removeClass(disabledClass); } } else { slider.directionNav.removeClass(disabledClass); } }, }, pausePlay: { setup: function () { var pausePlayScaffold = $('<div class="' + namespace + 'pauseplay"><a></a></div>');
// CONTROLSCONTAINER:
if (slider.controlsContainer) { slider.controlsContainer.append(pausePlayScaffold); slider.pausePlay = $( "." + namespace + "pauseplay a", slider.controlsContainer, ); } else { slider.append(pausePlayScaffold); slider.pausePlay = $("." + namespace + "pauseplay a", slider); } methods.pausePlay.update( vars.slideshow ? namespace + "pause" : namespace + "play", ); slider.pausePlay.bind(eventType, function (event) { event.preventDefault(); if ($(this).hasClass(namespace + "pause")) { slider.manualPause = true; slider.manualPlay = false; slider.pause(); } else { slider.manualPause = false; slider.manualPlay = true; slider.play(); } });
// Prevent iOS click event bug
if (touch) { slider.pausePlay.bind("click touchstart", function (event) { event.preventDefault(); }); } }, update: function (state) { state === "play" ? slider.pausePlay .removeClass(namespace + "pause") .addClass(namespace + "play") .text(vars.playText) : slider.pausePlay .removeClass(namespace + "play") .addClass(namespace + "pause") .text(vars.pauseText); }, }, touch: function () { var startX, startY, offset, cwidth, dx, startT, scrolling = false; el.addEventListener("touchstart", onTouchStart, false); function onTouchStart(e) { if (slider.animating) { e.preventDefault(); } else if (e.touches.length === 1) { slider.pause();
// CAROUSEL:
cwidth = vertical ? slider.h : slider.w; startT = Number(new Date());
// CAROUSEL:
offset = carousel && reverse && slider.animatingTo === slider.last ? 0 : carousel && reverse ? slider.limit - (slider.itemW + vars.itemMargin) * slider.move * slider.animatingTo : carousel && slider.currentSlide === slider.last ? slider.limit : carousel ? (slider.itemW + vars.itemMargin) * slider.move * slider.currentSlide : reverse ? (slider.last - slider.currentSlide + slider.cloneOffset) * cwidth : (slider.currentSlide + slider.cloneOffset) * cwidth; startX = vertical ? e.touches[0].pageY : e.touches[0].pageX; startY = vertical ? 
e.touches[0].pageX : e.touches[0].pageY; el.addEventListener("touchmove", onTouchMove, false); el.addEventListener("touchend", onTouchEnd, false); } } function onTouchMove(e) { dx = vertical ? startX - e.touches[0].pageY : startX - e.touches[0].pageX; scrolling = vertical ? Math.abs(dx) < Math.abs(e.touches[0].pageX - startY) : Math.abs(dx) < Math.abs(e.touches[0].pageY - startY); if (!scrolling || Number(new Date()) - startT > 500) { e.preventDefault(); if (!fade && slider.transitions) { if (!vars.animationLoop) { dx = dx / ((slider.currentSlide === 0 && dx < 0) || (slider.currentSlide === slider.last && dx > 0) ? Math.abs(dx) / cwidth + 2 : 1); } slider.setProps(offset + dx, "setTouch"); } } } function onTouchEnd(e) {
// finish the touch by undoing the touch session
el.removeEventListener("touchmove", onTouchMove, false); if ( slider.animatingTo === slider.currentSlide && !scrolling && !(dx === null) ) { var updateDx = reverse ? -dx : dx, target = updateDx > 0 ? slider.getTarget("next") : slider.getTarget("prev"); if ( slider.canAdvance(target) && ((Number(new Date()) - startT < 550 && Math.abs(updateDx) > 50) || Math.abs(updateDx) > cwidth / 2) ) { slider.flexAnimate(target, vars.pauseOnAction); } else { if (!fade) slider.flexAnimate( slider.currentSlide, vars.pauseOnAction, true, ); } } el.removeEventListener("touchend", onTouchEnd, false); startX = null; startY = null; dx = null; offset = null; } }, resize: function () { if (!slider.animating && slider.is(":visible")) { if (!carousel) slider.doMath(); if (fade) {
// SMOOTH HEIGHT:
methods.smoothHeight(); } else if (carousel) {
//CAROUSEL:
slider.slides.width(slider.computedW); slider.update(slider.pagingCount); slider.setProps(); } else if (vertical) {
//VERTICAL:
slider.viewport.height(slider.h); slider.setProps(slider.h, "setTotal"); } else {
// SMOOTH HEIGHT:
if (vars.smoothHeight) methods.smoothHeight(); slider.newSlides.width(slider.computedW); slider.setProps(slider.computedW, "setTotal"); } } }, 
smoothHeight: function (dur) { if (!vertical || fade) { var $obj = fade ? slider : slider.viewport; dur ? $obj.animate( { height: slider.slides.eq(slider.animatingTo).height() }, dur, ) : $obj.height(slider.slides.eq(slider.animatingTo).height()); } }, sync: function (action) { var $obj = $(vars.sync).data("flexslider"), target = slider.animatingTo; switch (action) { case "animate": $obj.flexAnimate(target, vars.pauseOnAction, false, true); break; case "play": if (!$obj.playing && !$obj.asNav) { $obj.play(); } break; case "pause": $obj.pause(); break; } }, };
// public methods
slider.flexAnimate = function (target, pause, override, withSync, fromNav) { if (asNav && slider.pagingCount === 1) slider.direction = slider.currentItem < target ? "next" : "prev"; if ( !slider.animating && (slider.canAdvance(target, fromNav) || override) && slider.is(":visible") ) { if (asNav && withSync) { var master = $(vars.asNavFor).data("flexslider"); slider.atEnd = target === 0 || target === slider.count - 1; master.flexAnimate(target, true, false, true, fromNav); slider.direction = slider.currentItem < target ? "next" : "prev"; master.direction = slider.direction; if ( Math.ceil((target + 1) / slider.visible) - 1 !== slider.currentSlide && target !== 0 ) { slider.currentItem = target; slider.slides .removeClass(namespace + "active-slide") .eq(target) .addClass(namespace + "active-slide"); target = Math.floor(target / slider.visible); } else { slider.currentItem = target; slider.slides .removeClass(namespace + "active-slide") .eq(target) .addClass(namespace + "active-slide"); return false; } } slider.animating = true; slider.animatingTo = target;
// API: before() animation Callback
vars.before(slider);
// SLIDESHOW:
if (pause) slider.pause();
// SYNC:
if (slider.syncExists && !fromNav) methods.sync("animate");
// CONTROLNAV
if (vars.controlNav) methods.controlNav.active();
// !CAROUSEL:
// CANDIDATE: slide active class (for add/remove slide)
if (!carousel) slider.slides .removeClass(namespace + "active-slide") .eq(target) .addClass(namespace + "active-slide");
// INFINITE LOOP:
// CANDIDATE: atEnd
slider.atEnd = target === 0 || target === slider.last;
// DIRECTIONNAV:
if (vars.directionNav) methods.directionNav.update(); if (target === slider.last) {
// API: end() of cycle Callback
vars.end(slider);
// SLIDESHOW && !INFINITE LOOP:
if (!vars.animationLoop) slider.pause(); }
// SLIDE:
if (!fade) { var dimension = vertical ? slider.slides.filter(":first").height() : slider.computedW, margin, slideString, calcNext;
// INFINITE LOOP / REVERSE:
if (carousel) { margin = vars.itemWidth > slider.w ? vars.itemMargin * 2 : vars.itemMargin; calcNext = (slider.itemW + margin) * slider.move * slider.animatingTo; slideString = calcNext > slider.limit && slider.visible !== 1 ? slider.limit : calcNext; } else if ( slider.currentSlide === 0 && target === slider.count - 1 && vars.animationLoop && slider.direction !== "next" ) { slideString = reverse ? 
(slider.count + slider.cloneOffset) * dimension : 0; } else if ( slider.currentSlide === slider.last && target === 0 && vars.animationLoop && slider.direction !== "prev" ) { slideString = reverse ? 0 : (slider.count + 1) * dimension; } else { slideString = reverse ? (slider.count - 1 - target + slider.cloneOffset) * dimension : (target + slider.cloneOffset) * dimension; } slider.setProps(slideString, "", vars.animationSpeed); if (slider.transitions) { if (!vars.animationLoop || !slider.atEnd) { slider.animating = false; slider.currentSlide = slider.animatingTo; } slider.container.unbind("webkitTransitionEnd transitionend"); slider.container.bind( "webkitTransitionEnd transitionend", function () { slider.wrapup(dimension); }, ); } else { slider.container.animate( slider.args, vars.animationSpeed, vars.easing, function () { slider.wrapup(dimension); }, ); } } else { // FADE: if (!touch) { slider.slides .eq(slider.currentSlide) .fadeOut(vars.animationSpeed, vars.easing); slider.slides .eq(target) .fadeIn(vars.animationSpeed, vars.easing, slider.wrapup); } else { slider.slides .eq(slider.currentSlide) .css({ opacity: 0, zIndex: 1 }); slider.slides.eq(target).css({ opacity: 1, zIndex: 2 }); slider.slides.unbind("webkitTransitionEnd transitionend"); slider.slides .eq(slider.currentSlide) .bind("webkitTransitionEnd transitionend", function () { // API: after() animation Callback vars.after(slider); }); slider.animating = false; slider.currentSlide = slider.animatingTo; } } // SMOOTH HEIGHT: if (vars.smoothHeight) methods.smoothHeight(vars.animationSpeed); } }; slider.wrapup = function (dimension) { // SLIDE: if (!fade && !carousel) { if ( slider.currentSlide === 0 && slider.animatingTo === slider.last && vars.animationLoop ) { slider.setProps(dimension, "jumpEnd"); } else if ( slider.currentSlide === slider.last && slider.animatingTo === 0 && vars.animationLoop ) { slider.setProps(dimension, "jumpStart"); } } slider.animating = false; slider.currentSlide = 
slider.animatingTo; // API: after() animation Callback vars.after(slider); }; // SLIDESHOW: slider.animateSlides = function () { if (!slider.animating) slider.flexAnimate(slider.getTarget("next")); }; // SLIDESHOW: slider.pause = function () { clearInterval(slider.animatedSlides); slider.playing = false; // PAUSEPLAY: if (vars.pausePlay) methods.pausePlay.update("play"); // SYNC: if (slider.syncExists) methods.sync("pause"); }; // SLIDESHOW: slider.play = function () { slider.animatedSlides = setInterval( slider.animateSlides, vars.slideshowSpeed, ); slider.playing = true; // PAUSEPLAY: if (vars.pausePlay) methods.pausePlay.update("pause"); // SYNC: if (slider.syncExists) methods.sync("play"); }; slider.canAdvance = function (target, fromNav) { // ASNAV: var last = asNav ? slider.pagingCount - 1 : slider.last; return fromNav ? true : asNav && slider.currentItem === slider.count - 1 && target === 0 && slider.direction === "prev" ? true : asNav && slider.currentItem === 0 && target === slider.pagingCount - 1 && slider.direction !== "next" ? false : target === slider.currentSlide && !asNav ? false : vars.animationLoop ? true : slider.atEnd && slider.currentSlide === 0 && target === last && slider.direction !== "next" ? false : slider.atEnd && slider.currentSlide === last && target === 0 && slider.direction === "next" ? false : true; }; slider.getTarget = function (dir) { slider.direction = dir; if (dir === "next") { return slider.currentSlide === slider.last ? 0 : slider.currentSlide + 1; } else { return slider.currentSlide === 0 ? slider.last : slider.currentSlide - 1; } }; // SLIDE: slider.setProps = function (pos, special, dur) { var target = (function () { var posCheck = pos ? pos : (slider.itemW + vars.itemMargin) * slider.move * slider.animatingTo, posCalc = (function () { if (carousel) { return special === "setTouch" ? pos : reverse && slider.animatingTo === slider.last ? 0 : reverse ? 
slider.limit - (slider.itemW + vars.itemMargin) * slider.move * slider.animatingTo : slider.animatingTo === slider.last ? slider.limit : posCheck; } else { switch (special) { case "setTotal": return reverse ? (slider.count - 1 - slider.currentSlide + slider.cloneOffset) * pos : (slider.currentSlide + slider.cloneOffset) * pos; case "setTouch": return reverse ? pos : pos; case "jumpEnd": return reverse ? pos : slider.count * pos; case "jumpStart": return reverse ? slider.count * pos : pos; default: return pos; } } })(); return posCalc * -1 + "px"; })(); if (slider.transitions) { target = vertical ? "translate3d(0," + target + ",0)" : "translate3d(" + target + ",0,0)"; dur = dur !== undefined ? dur / 1000 + "s" : "0s"; slider.container.css("-" + slider.pfx + "-transition-duration", dur); } slider.args[slider.prop] = target; if (slider.transitions || dur === undefined) slider.container.css(slider.args); }; slider.setup = function (type) { // SLIDE: if (!fade) { var sliderOffset, arr; if (type === "init") { slider.viewport = $('<div class="' + namespace + 'viewport"></div>')
.css({ overflow: "hidden", position: "relative" }) .appendTo(slider) .append(slider.container); // INFINITE LOOP: slider.cloneCount = 0; slider.cloneOffset = 0; // REVERSE: if (reverse) { arr = $.makeArray(slider.slides).reverse(); slider.slides = $(arr); slider.container.empty().append(slider.slides); } } // INFINITE LOOP && !CAROUSEL: if (vars.animationLoop && !carousel) { slider.cloneCount = 2; slider.cloneOffset = 1; // clear out old clones if (type !== "init") slider.container.find(".clone").remove(); slider.container .append(slider.slides.first().clone().addClass("clone")) .prepend(slider.slides.last().clone().addClass("clone")); } slider.newSlides = $(vars.selector, slider); sliderOffset = reverse ? slider.count - 1 - slider.currentSlide + slider.cloneOffset : slider.currentSlide + slider.cloneOffset; // VERTICAL: if (vertical && !carousel) { slider.container .height((slider.count + slider.cloneCount) * 200 + "%") .css("position", "absolute") .width("100%"); setTimeout( function () { slider.newSlides.css({ display: "block" }); slider.doMath(); slider.viewport.height(slider.h); slider.setProps(sliderOffset * slider.h, "init"); }, type === "init" ? 100 : 0, ); } else { slider.container.width( (slider.count + slider.cloneCount) * 200 + "%", ); slider.setProps(sliderOffset * slider.computedW, "init"); setTimeout( function () { slider.doMath(); slider.newSlides.css({ width: slider.computedW, float: "left", display: "block", }); // SMOOTH HEIGHT: if (vars.smoothHeight) methods.smoothHeight(); }, type === "init" ? 
100 : 0, ); } } else { // FADE: slider.slides.css({ width: "100%", float: "left", marginRight: "-100%", position: "relative", }); if (type === "init") { if (!touch) { slider.slides .eq(slider.currentSlide) .fadeIn(vars.animationSpeed, vars.easing); } else { slider.slides .css({ opacity: 0, display: "block", webkitTransition: "opacity " + vars.animationSpeed / 1000 + "s ease", zIndex: 1, }) .eq(slider.currentSlide) .css({ opacity: 1, zIndex: 2 }); } } // SMOOTH HEIGHT: if (vars.smoothHeight) methods.smoothHeight(); } // !CAROUSEL: // CANDIDATE: active slide if (!carousel) slider.slides .removeClass(namespace + "active-slide") .eq(slider.currentSlide) .addClass(namespace + "active-slide"); }; slider.doMath = function () { var slide = slider.slides.first(), slideMargin = vars.itemMargin, minItems = vars.minItems, maxItems = vars.maxItems; slider.w = slider.width(); slider.h = slide.height(); slider.boxPadding = slide.outerWidth() - slide.width(); // CAROUSEL: if (carousel) { slider.itemT = vars.itemWidth + slideMargin; slider.minW = minItems ? minItems * slider.itemT : slider.w; slider.maxW = maxItems ? maxItems * slider.itemT : slider.w; slider.itemW = slider.minW > slider.w ? (slider.w - slideMargin * minItems) / minItems : slider.maxW < slider.w ? (slider.w - slideMargin * maxItems) / maxItems : vars.itemWidth > slider.w ? slider.w : vars.itemWidth; slider.visible = Math.floor(slider.w / (slider.itemW + slideMargin)); slider.move = vars.move > 0 && vars.move < slider.visible ? vars.move : slider.visible; slider.pagingCount = Math.ceil( (slider.count - slider.visible) / slider.move + 1, ); slider.last = slider.pagingCount - 1; slider.limit = slider.pagingCount === 1 ? 0 : vars.itemWidth > slider.w ? 
(slider.itemW + slideMargin * 2) * slider.count - slider.w - slideMargin : (slider.itemW + slideMargin) * slider.count - slider.w - slideMargin; } else { slider.itemW = slider.w; slider.pagingCount = slider.count; slider.last = slider.count - 1; } slider.computedW = slider.itemW - slider.boxPadding; }; slider.update = function (pos, action) { slider.doMath(); // update currentSlide and slider.animatingTo if necessary if (!carousel) { if (pos < slider.currentSlide) { slider.currentSlide += 1; } else if (pos <= slider.currentSlide && pos !== 0) { slider.currentSlide -= 1; } slider.animatingTo = slider.currentSlide; } // update controlNav if (vars.controlNav && !slider.manualControls) { if ( (action === "add" && !carousel) || slider.pagingCount > slider.controlNav.length ) { methods.controlNav.update("add"); } else if ( (action === "remove" && !carousel) || slider.pagingCount < slider.controlNav.length ) { if (carousel && slider.currentSlide > slider.last) { slider.currentSlide -= 1; slider.animatingTo -= 1; } methods.controlNav.update("remove", slider.last); } } // update directionNav if (vars.directionNav) methods.directionNav.update(); }; slider.addSlide = function (obj, pos) { var $obj = $(obj); slider.count += 1; slider.last = slider.count - 1; // append new slide if (vertical && reverse) { pos !== undefined ? slider.slides.eq(slider.count - pos).after($obj) : slider.container.prepend($obj); } else { pos !== undefined ? slider.slides.eq(pos).before($obj) : slider.container.append($obj); } // update currentSlide, animatingTo, controlNav, and directionNav slider.update(pos, "add"); // update slider.slides slider.slides = $(vars.selector + ":not(.clone)", slider); // re-setup the slider to accommodate the new slide slider.setup(); //FlexSlider: added() Callback vars.added(slider); }; slider.removeSlide = function (obj) { var pos = isNaN(obj) ?
slider.slides.index($(obj)) : obj; // update count slider.count -= 1; slider.last = slider.count - 1; // remove slide if (isNaN(obj)) { $(obj, slider.slides).remove(); } else { vertical && reverse ? slider.slides.eq(slider.last).remove() : slider.slides.eq(obj).remove(); } // update currentSlide, animatingTo, controlNav, and directionNav slider.doMath(); slider.update(pos, "remove"); // update slider.slides slider.slides = $(vars.selector + ":not(.clone)", slider); // re-setup the slider to accommodate the removed slide slider.setup(); // FlexSlider: removed() Callback vars.removed(slider); }; //FlexSlider: Initialize methods.init(); }; //FlexSlider: Default Settings $.flexslider.defaults = { namespace: "flex-", //{NEW} String: Prefix string attached to the class of every element generated by the plugin selector: ".slides > li", //{NEW} Selector: Must match a simple pattern. '{container} > {slide}' -- Ignore pattern at your own peril animation: "fade", //String: Select your animation type, "fade" or "slide" easing: "swing", //{NEW} String: Determines the easing method used in jQuery transitions. jQuery easing plugin is supported! direction: "horizontal", //String: Select the sliding direction, "horizontal" or "vertical" reverse: false, //{NEW} Boolean: Reverse the animation direction animationLoop: true, //Boolean: Should the animation loop? If false, directionNav will receive "disable" classes at either end smoothHeight: false, //{NEW} Boolean: Allow height of the slider to animate smoothly in horizontal mode startAt: 0, //Integer: The slide that the slider should start on.
Array notation (0 = first slide) slideshow: true, //Boolean: Animate slider automatically slideshowSpeed: 7000, //Integer: Set the speed of the slideshow cycling, in milliseconds animationSpeed: 600, //Integer: Set the speed of animations, in milliseconds initDelay: 0, //{NEW} Integer: Set an initialization delay, in milliseconds randomize: false, //Boolean: Randomize slide order // Usability features pauseOnAction: true, //Boolean: Pause the slideshow when interacting with control elements, highly recommended. pauseOnHover: false, //Boolean: Pause the slideshow when hovering over slider, then resume when no longer hovering useCSS: true, //{NEW} Boolean: Slider will use CSS3 transitions if available touch: true, //{NEW} Boolean: Allow touch swipe navigation of the slider on touch-enabled devices video: false, //{NEW} Boolean: If using video in the slider, will prevent CSS3 3D Transforms to avoid graphical glitches // Primary Controls controlNav: true, //Boolean: Create navigation for paging control of each slide? Note: Leave true for manualControls usage directionNav: true, //Boolean: Create navigation for previous/next navigation? (true/false) prevText: "Previous", //String: Set the text for the "previous" directionNav item nextText: "Next", //String: Set the text for the "next" directionNav item // Secondary Navigation keyboard: true, //Boolean: Allow slider navigating via keyboard left/right keys multipleKeyboard: false, //{NEW} Boolean: Allow keyboard navigation to affect multiple sliders. Default behavior cuts out keyboard navigation with more than one slider present.
mousewheel: false, //{UPDATED} Boolean: Requires jquery.mousewheel.js (https://github.com/brandonaaron/jquery-mousewheel) - Allows slider navigating via mousewheel pausePlay: false, //Boolean: Create pause/play dynamic element pauseText: "Pause", //String: Set the text for the "pause" pausePlay item playText: "Play", //String: Set the text for the "play" pausePlay item // Special properties controlsContainer: "", //{UPDATED} jQuery Object/Selector: Declare which container the navigation elements should be appended to. Default container is the FlexSlider element. Example use would be $(".flexslider-container"). Property is ignored if given element is not found. manualControls: "", //{UPDATED} jQuery Object/Selector: Declare custom control navigation. Examples would be $(".flex-control-nav li") or "#tabs-nav li img", etc. The number of elements in your controlNav should match the number of slides/tabs. sync: "", //{NEW} Selector: Mirror the actions performed on this slider with another slider. Use with care. asNavFor: "", //{NEW} Selector: Internal property exposed for turning the slider into a thumbnail navigation for another slider // Carousel Options itemWidth: 0, //{NEW} Integer: Box-model width of individual carousel items, including horizontal borders and padding. itemMargin: 0, //{NEW} Integer: Margin between carousel items. minItems: 0, //{NEW} Integer: Minimum number of carousel items that should be visible. Items will resize fluidly when below this. maxItems: 0, //{NEW} Integer: Maximum number of carousel items that should be visible. Items will resize fluidly when above this limit. move: 0, //{NEW} Integer: Number of carousel items that should move on animation. If 0, slider will move all visible items.
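// Worked example (illustrative values only, not plugin defaults): how doMath()
// turns the carousel options above into paging state. Assumes a 1000px-wide
// slider with count = 12 slides, itemWidth: 200, itemMargin: 10,
// minItems: 0, maxItems: 0, move: 0.
//   itemW       = 200 (minW and maxW both fall back to slider.w, and itemWidth < slider.w)
//   visible     = Math.floor(1000 / (200 + 10)) = 4
//   move        = visible = 4 (because vars.move is 0)
//   pagingCount = Math.ceil((12 - 4) / 4 + 1) = 3, so slider.last = 2
//   limit       = (200 + 10) * 12 - 1000 - 10 = 1510
// flexAnimate()/setProps() then place the container at -840px for page 1
// ((200 + 10) * 4 * 1) and clamp the last page (2) to -limit = -1510px.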
// Callback API start: function () {}, //Callback: function(slider) - Fires when the slider loads the first slide before: function () {}, //Callback: function(slider) - Fires asynchronously with each slider animation after: function () {}, //Callback: function(slider) - Fires after each slider animation completes end: function () {}, //Callback: function(slider) - Fires when the slider reaches the last slide (asynchronous) added: function () {}, //{NEW} Callback: function(slider) - Fires after a slide is added removed: function () {}, //{NEW} Callback: function(slider) - Fires after a slide is removed }; //FlexSlider: Plugin Function $.fn.flexslider = function (options) { if (options === undefined) options = {}; if (typeof options === "object") { return this.each(function () { var $this = $(this), selector = options.selector ? options.selector : ".slides > li", $slides = $this.find(selector); if ($slides.length === 1) { $slides.fadeIn(400); if (options.start) options.start($this); } else if ($this.data("flexslider") == undefined) { new $.flexslider(this, options); } }); } else { // Helper strings to quickly perform functions on the slider var $slider = $(this).data("flexslider"); switch (options) { case "play": $slider.play(); break; case "pause": $slider.pause(); break; case "next": $slider.flexAnimate($slider.getTarget("next"), true); break; case "prev": case "previous": $slider.flexAnimate($slider.getTarget("prev"), true); break; default: if (typeof options === "number") $slider.flexAnimate(options, true); } } }; })(jQuery); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/jquery.js ================================================ /*! jQuery v1.9.1 | (c) 2005, 2012 jQuery Foundation, Inc. 
| jquery.org/license //@ sourceMappingURL=jquery.min.map */ (function (a, b) { function G(a) { var b = (F[a] = {}); return ( p.each(a.split(s), function (a, c) { b[c] = !0; }), b ); } function J(a, c, d) { if (d === b && a.nodeType === 1) { var e = "data-" + c.replace(I, "-$1").toLowerCase(); d = a.getAttribute(e); if (typeof d == "string") { try { d = d === "true" ? !0 : d === "false" ? !1 : d === "null" ? null : +d + "" === d ? +d : H.test(d) ? p.parseJSON(d) : d; } catch (f) {} p.data(a, c, d); } else d = b; } return d; } function K(a) { var b; for (b in a) { if (b === "data" && p.isEmptyObject(a[b])) continue; if (b !== "toJSON") return !1; } return !0; } function ba() { return !1; } function bb() { return !0; } function bh(a) { return !a || !a.parentNode || a.parentNode.nodeType === 11; } function bi(a, b) { do a = a[b]; while (a && a.nodeType !== 1); return a; } function bj(a, b, c) { b = b || 0; if (p.isFunction(b)) return p.grep(a, function (a, d) { var e = !!b.call(a, d, a); return e === c; }); if (b.nodeType) return p.grep(a, function (a, d) { return (a === b) === c; }); if (typeof b == "string") { var d = p.grep(a, function (a) { return a.nodeType === 1; }); if (be.test(b)) return p.filter(b, d, !c); b = p.filter(b, d); } return p.grep(a, function (a, d) { return p.inArray(a, b) >= 0 === c; }); } function bk(a) { var b = bl.split("|"), c = a.createDocumentFragment(); if (c.createElement) while (b.length) c.createElement(b.pop()); return c; } function bC(a, b) { return ( a.getElementsByTagName(b)[0] || a.appendChild(a.ownerDocument.createElement(b)) ); } function bD(a, b) { if (b.nodeType !== 1 || !p.hasData(a)) return; var c, d, e, f = p._data(a), g = p._data(b, f), h = f.events; if (h) { delete g.handle, (g.events = {}); for (c in h) for (d = 0, e = h[c].length; d < e; d++) p.event.add(b, c, h[c][d]); } g.data && (g.data = p.extend({}, g.data)); } function bE(a, b) { var c; if (b.nodeType !== 1) return; b.clearAttributes && b.clearAttributes(), 
b.mergeAttributes && b.mergeAttributes(a), (c = b.nodeName.toLowerCase()), c === "object" ? (b.parentNode && (b.outerHTML = a.outerHTML), p.support.html5Clone && a.innerHTML && !p.trim(b.innerHTML) && (b.innerHTML = a.innerHTML)) : c === "input" && bv.test(a.type) ? ((b.defaultChecked = b.checked = a.checked), b.value !== a.value && (b.value = a.value)) : c === "option" ? (b.selected = a.defaultSelected) : c === "input" || c === "textarea" ? (b.defaultValue = a.defaultValue) : c === "script" && b.text !== a.text && (b.text = a.text), b.removeAttribute(p.expando); } function bF(a) { return typeof a.getElementsByTagName != "undefined" ? a.getElementsByTagName("*") : typeof a.querySelectorAll != "undefined" ? a.querySelectorAll("*") : []; } function bG(a) { bv.test(a.type) && (a.defaultChecked = a.checked); } function bY(a, b) { if (b in a) return b; var c = b.charAt(0).toUpperCase() + b.slice(1), d = b, e = bW.length; while (e--) { b = bW[e] + c; if (b in a) return b; } return d; } function bZ(a, b) { return ( (a = b || a), p.css(a, "display") === "none" || !p.contains(a.ownerDocument, a) ); } function b$(a, b) { var c, d, e = [], f = 0, g = a.length; for (; f < g; f++) { c = a[f]; if (!c.style) continue; (e[f] = p._data(c, "olddisplay")), b ? (!e[f] && c.style.display === "none" && (c.style.display = ""), c.style.display === "" && bZ(c) && (e[f] = p._data(c, "olddisplay", cc(c.nodeName)))) : ((d = bH(c, "display")), !e[f] && d !== "none" && p._data(c, "olddisplay", d)); } for (f = 0; f < g; f++) { c = a[f]; if (!c.style) continue; if (!b || c.style.display === "none" || c.style.display === "") c.style.display = b ? e[f] || "" : "none"; } return a; } function b_(a, b, c) { var d = bP.exec(b); return d ? Math.max(0, d[1] - (c || 0)) + (d[2] || "px") : b; } function ca(a, b, c, d) { var e = c === (d ? "border" : "content") ? 4 : b === "width" ? 1 : 0, f = 0; for (; e < 4; e += 2) c === "margin" && (f += p.css(a, c + bV[e], !0)), d ? 
(c === "content" && (f -= parseFloat(bH(a, "padding" + bV[e])) || 0), c !== "margin" && (f -= parseFloat(bH(a, "border" + bV[e] + "Width")) || 0)) : ((f += parseFloat(bH(a, "padding" + bV[e])) || 0), c !== "padding" && (f += parseFloat(bH(a, "border" + bV[e] + "Width")) || 0)); return f; } function cb(a, b, c) { var d = b === "width" ? a.offsetWidth : a.offsetHeight, e = !0, f = p.support.boxSizing && p.css(a, "boxSizing") === "border-box"; if (d <= 0 || d == null) { d = bH(a, b); if (d < 0 || d == null) d = a.style[b]; if (bQ.test(d)) return d; (e = f && (p.support.boxSizingReliable || d === a.style[b])), (d = parseFloat(d) || 0); } return d + ca(a, b, c || (f ? "border" : "content"), e) + "px"; } function cc(a) { if (bS[a]) return bS[a]; var b = p("<" + a + ">").appendTo(e.body), c = b.css("display"); b.remove(); if (c === "none" || c === "") { bI = e.body.appendChild( bI || p.extend(e.createElement("iframe"), { frameBorder: 0, width: 0, height: 0, }), ); if (!bJ || !bI.createElement) (bJ = (bI.contentWindow || bI.contentDocument).document), bJ.write(""), bJ.close(); (b = bJ.body.appendChild(bJ.createElement(a))), (c = bH(b, "display")), e.body.removeChild(bI); } return (bS[a] = c), c; } function ci(a, b, c, d) { var e; if (p.isArray(b)) p.each(b, function (b, e) { c || ce.test(a) ? d(a, e) : ci(a + "[" + (typeof e == "object" ? b : "") + "]", e, c, d); }); else if (!c && p.type(b) === "object") for (e in b) ci(a + "[" + e + "]", b[e], c, d); else d(a, b); } function cz(a) { return function (b, c) { typeof b != "string" && ((c = b), (b = "*")); var d, e, f, g = b.toLowerCase().split(s), h = 0, i = g.length; if (p.isFunction(c)) for (; h < i; h++) (d = g[h]), (f = /^\+/.test(d)), f && (d = d.substr(1) || "*"), (e = a[d] = a[d] || []), e[f ? "unshift" : "push"](c); }; } function cA(a, c, d, e, f, g) { (f = f || c.dataTypes[0]), (g = g || {}), (g[f] = !0); var h, i = a[f], j = 0, k = i ? 
i.length : 0, l = a === cv; for (; j < k && (l || !h); j++) (h = i[j](c, d, e)), typeof h == "string" && (!l || g[h] ? (h = b) : (c.dataTypes.unshift(h), (h = cA(a, c, d, e, h, g)))); return (l || !h) && !g["*"] && (h = cA(a, c, d, e, "*", g)), h; } function cB(a, c) { var d, e, f = p.ajaxSettings.flatOptions || {}; for (d in c) c[d] !== b && ((f[d] ? a : e || (e = {}))[d] = c[d]); e && p.extend(!0, a, e); } function cC(a, c, d) { var e, f, g, h, i = a.contents, j = a.dataTypes, k = a.responseFields; for (f in k) f in d && (c[k[f]] = d[f]); while (j[0] === "*") j.shift(), e === b && (e = a.mimeType || c.getResponseHeader("content-type")); if (e) for (f in i) if (i[f] && i[f].test(e)) { j.unshift(f); break; } if (j[0] in d) g = j[0]; else { for (f in d) { if (!j[0] || a.converters[f + " " + j[0]]) { g = f; break; } h || (h = f); } g = g || h; } if (g) return g !== j[0] && j.unshift(g), d[g]; } function cD(a, b) { var c, d, e, f, g = a.dataTypes.slice(), h = g[0], i = {}, j = 0; a.dataFilter && (b = a.dataFilter(b, a.dataType)); if (g[1]) for (c in a.converters) i[c.toLowerCase()] = a.converters[c]; for (; (e = g[++j]); ) if (e !== "*") { if (h !== "*" && h !== e) { c = i[h + " " + e] || i["* " + e]; if (!c) for (d in i) { f = d.split(" "); if (f[1] === e) { c = i[h + " " + f[0]] || i["* " + f[0]]; if (c) { c === !0 ? (c = i[d]) : i[d] !== !0 && ((e = f[0]), g.splice(j--, 0, e)); break; } } } if (c !== !0) if (c && a["throws"]) b = c(b); else try { b = c(b); } catch (k) { return { state: "parsererror", error: c ? 
k : "No conversion from " + h + " to " + e, }; } } h = e; } return { state: "success", data: b }; } function cL() { try { return new a.XMLHttpRequest(); } catch (b) {} } function cM() { try { return new a.ActiveXObject("Microsoft.XMLHTTP"); } catch (b) {} } function cU() { return ( setTimeout(function () { cN = b; }, 0), (cN = p.now()) ); } function cV(a, b) { p.each(b, function (b, c) { var d = (cT[b] || []).concat(cT["*"]), e = 0, f = d.length; for (; e < f; e++) if (d[e].call(a, b, c)) return; }); } function cW(a, b, c) { var d, e = 0, f = 0, g = cS.length, h = p.Deferred().always(function () { delete i.elem; }), i = function () { var b = cN || cU(), c = Math.max(0, j.startTime + j.duration - b), d = 1 - (c / j.duration || 0), e = 0, f = j.tweens.length; for (; e < f; e++) j.tweens[e].run(d); return ( h.notifyWith(a, [j, d, c]), d < 1 && f ? c : (h.resolveWith(a, [j]), !1) ); }, j = h.promise({ elem: a, props: p.extend({}, b), opts: p.extend(!0, { specialEasing: {} }, c), originalProperties: b, originalOptions: c, startTime: cN || cU(), duration: c.duration, tweens: [], createTween: function (b, c, d) { var e = p.Tween( a, j.opts, b, c, j.opts.specialEasing[b] || j.opts.easing, ); return j.tweens.push(e), e; }, stop: function (b) { var c = 0, d = b ? j.tweens.length : 0; for (; c < d; c++) j.tweens[c].run(1); return b ? 
h.resolveWith(a, [j, b]) : h.rejectWith(a, [j, b]), this; }, }), k = j.props; cX(k, j.opts.specialEasing); for (; e < g; e++) { d = cS[e].call(j, a, k, j.opts); if (d) return d; } return ( cV(j, k), p.isFunction(j.opts.start) && j.opts.start.call(a, j), p.fx.timer(p.extend(i, { anim: j, queue: j.opts.queue, elem: a })), j .progress(j.opts.progress) .done(j.opts.done, j.opts.complete) .fail(j.opts.fail) .always(j.opts.always) ); } function cX(a, b) { var c, d, e, f, g; for (c in a) { (d = p.camelCase(c)), (e = b[d]), (f = a[c]), p.isArray(f) && ((e = f[1]), (f = a[c] = f[0])), c !== d && ((a[d] = f), delete a[c]), (g = p.cssHooks[d]); if (g && "expand" in g) { (f = g.expand(f)), delete a[d]; for (c in f) c in a || ((a[c] = f[c]), (b[c] = e)); } else b[d] = e; } } function cY(a, b, c) { var d, e, f, g, h, i, j, k, l = this, m = a.style, n = {}, o = [], q = a.nodeType && bZ(a); c.queue || ((j = p._queueHooks(a, "fx")), j.unqueued == null && ((j.unqueued = 0), (k = j.empty.fire), (j.empty.fire = function () { j.unqueued || k(); })), j.unqueued++, l.always(function () { l.always(function () { j.unqueued--, p.queue(a, "fx").length || j.empty.fire(); }); })), a.nodeType === 1 && ("height" in b || "width" in b) && ((c.overflow = [m.overflow, m.overflowX, m.overflowY]), p.css(a, "display") === "inline" && p.css(a, "float") === "none" && (!p.support.inlineBlockNeedsLayout || cc(a.nodeName) === "inline" ? (m.display = "inline-block") : (m.zoom = 1))), c.overflow && ((m.overflow = "hidden"), p.support.shrinkWrapBlocks || l.done(function () { (m.overflow = c.overflow[0]), (m.overflowX = c.overflow[1]), (m.overflowY = c.overflow[2]); })); for (d in b) { f = b[d]; if (cP.exec(f)) { delete b[d]; if (f === (q ? "hide" : "show")) continue; o.push(d); } } g = o.length; if (g) { (h = p._data(a, "fxshow") || p._data(a, "fxshow", {})), q ? 
p(a).show() : l.done(function () { p(a).hide(); }), l.done(function () { var b; p.removeData(a, "fxshow", !0); for (b in n) p.style(a, b, n[b]); }); for (d = 0; d < g; d++) (e = o[d]), (i = l.createTween(e, q ? h[e] : 0)), (n[e] = h[e] || p.style(a, e)), e in h || ((h[e] = i.start), q && ((i.end = i.start), (i.start = e === "width" || e === "height" ? 1 : 0))); } } function cZ(a, b, c, d, e) { return new cZ.prototype.init(a, b, c, d, e); } function c$(a, b) { var c, d = { height: a }, e = 0; b = b ? 1 : 0; for (; e < 4; e += 2 - b) (c = bV[e]), (d["margin" + c] = d["padding" + c] = a); return b && (d.opacity = d.width = a), d; } function da(a) { return p.isWindow(a) ? a : a.nodeType === 9 ? a.defaultView || a.parentWindow : !1; } var c, d, e = a.document, f = a.location, g = a.navigator, h = a.jQuery, i = a.$, j = Array.prototype.push, k = Array.prototype.slice, l = Array.prototype.indexOf, m = Object.prototype.toString, n = Object.prototype.hasOwnProperty, o = String.prototype.trim, p = function (a, b) { return new p.fn.init(a, b, c); }, q = /[\-+]?(?:\d*\.|)\d+(?:[eE][\-+]?\d+|)/.source, r = /\S/, s = /\s+/, t = /^[\s\uFEFF\xA0]+|[\s\uFEFF\xA0]+$/g, u = /^(?:[^#<]*(<[\w\W]+>)[^>]*$|#([\w\-]*)$)/, v = /^<(\w+)\s*\/?>(?:<\/\1>|)$/, w = /^[\],:{}\s]*$/, x = /(?:^|:|,)(?:\s*\[)+/g, y = /\\(?:["\\\/bfnrt]|u[\da-fA-F]{4})/g, z = /"[^"\\\r\n]*"|true|false|null|-?(?:\d\d*\.|)\d+(?:[eE][\-+]?\d+|)/g, A = /^-ms-/, B = /-([\da-z])/gi, C = function (a, b) { return (b + "").toUpperCase(); }, D = function () { e.addEventListener ? (e.removeEventListener("DOMContentLoaded", D, !1), p.ready()) : e.readyState === "complete" && (e.detachEvent("onreadystatechange", D), p.ready()); }, E = {}; (p.fn = p.prototype = { constructor: p, init: function (a, c, d) { var f, g, h, i; if (!a) return this; if (a.nodeType) return (this.context = this[0] = a), (this.length = 1), this; if (typeof a == "string") { a.charAt(0) === "<" && a.charAt(a.length - 1) === ">" && a.length >= 3 ? 
(f = [null, a, null]) : (f = u.exec(a)); if (f && (f[1] || !c)) { if (f[1]) return ( (c = c instanceof p ? c[0] : c), (i = c && c.nodeType ? c.ownerDocument || c : e), (a = p.parseHTML(f[1], i, !0)), v.test(f[1]) && p.isPlainObject(c) && this.attr.call(a, c, !0), p.merge(this, a) ); g = e.getElementById(f[2]); if (g && g.parentNode) { if (g.id !== f[2]) return d.find(a); (this.length = 1), (this[0] = g); } return (this.context = e), (this.selector = a), this; } return !c || c.jquery ? (c || d).find(a) : this.constructor(c).find(a); } return p.isFunction(a) ? d.ready(a) : (a.selector !== b && ((this.selector = a.selector), (this.context = a.context)), p.makeArray(a, this)); }, selector: "", jquery: "1.8.1", length: 0, size: function () { return this.length; }, toArray: function () { return k.call(this); }, get: function (a) { return a == null ? this.toArray() : a < 0 ? this[this.length + a] : this[a]; }, pushStack: function (a, b, c) { var d = p.merge(this.constructor(), a); return ( (d.prevObject = this), (d.context = this.context), b === "find" ? (d.selector = this.selector + (this.selector ? " " : "") + c) : b && (d.selector = this.selector + "." + b + "(" + c + ")"), d ); }, each: function (a, b) { return p.each(this, a, b); }, ready: function (a) { return p.ready.promise().done(a), this; }, eq: function (a) { return (a = +a), a === -1 ? 
this.slice(a) : this.slice(a, a + 1); }, first: function () { return this.eq(0); }, last: function () { return this.eq(-1); }, slice: function () { return this.pushStack( k.apply(this, arguments), "slice", k.call(arguments).join(","), ); }, map: function (a) { return this.pushStack( p.map(this, function (b, c) { return a.call(b, c, b); }), ); }, end: function () { return this.prevObject || this.constructor(null); }, push: j, sort: [].sort, splice: [].splice, }), (p.fn.init.prototype = p.fn), (p.extend = p.fn.extend = function () { var a, c, d, e, f, g, h = arguments[0] || {}, i = 1, j = arguments.length, k = !1; typeof h == "boolean" && ((k = h), (h = arguments[1] || {}), (i = 2)), typeof h != "object" && !p.isFunction(h) && (h = {}), j === i && ((h = this), --i); for (; i < j; i++) if ((a = arguments[i]) != null) for (c in a) { (d = h[c]), (e = a[c]); if (h === e) continue; k && e && (p.isPlainObject(e) || (f = p.isArray(e))) ? (f ? ((f = !1), (g = d && p.isArray(d) ? d : [])) : (g = d && p.isPlainObject(d) ? d : {}), (h[c] = p.extend(k, g, e))) : e !== b && (h[c] = e); } return h; }), p.extend({ noConflict: function (b) { return a.$ === p && (a.$ = i), b && a.jQuery === p && (a.jQuery = h), p; }, isReady: !1, readyWait: 1, holdReady: function (a) { a ? p.readyWait++ : p.ready(!0); }, ready: function (a) { if (a === !0 ? --p.readyWait : p.isReady) return; if (!e.body) return setTimeout(p.ready, 1); p.isReady = !0; if (a !== !0 && --p.readyWait > 0) return; d.resolveWith(e, [p]), p.fn.trigger && p(e).trigger("ready").off("ready"); }, isFunction: function (a) { return p.type(a) === "function"; }, isArray: Array.isArray || function (a) { return p.type(a) === "array"; }, isWindow: function (a) { return a != null && a == a.window; }, isNumeric: function (a) { return !isNaN(parseFloat(a)) && isFinite(a); }, type: function (a) { return a == null ? 
String(a) : E[m.call(a)] || "object"; }, isPlainObject: function (a) { if (!a || p.type(a) !== "object" || a.nodeType || p.isWindow(a)) return !1; try { if ( a.constructor && !n.call(a, "constructor") && !n.call(a.constructor.prototype, "isPrototypeOf") ) return !1; } catch (c) { return !1; } var d; for (d in a); return d === b || n.call(a, d); }, isEmptyObject: function (a) { var b; for (b in a) return !1; return !0; }, error: function (a) { throw new Error(a); }, parseHTML: function (a, b, c) { var d; return !a || typeof a != "string" ? null : (typeof b == "boolean" && ((c = b), (b = 0)), (b = b || e), (d = v.exec(a)) ? [b.createElement(d[1])] : ((d = p.buildFragment([a], b, c ? null : [])), p.merge( [], (d.cacheable ? p.clone(d.fragment) : d.fragment).childNodes, ))); }, parseJSON: function (b) { if (!b || typeof b != "string") return null; b = p.trim(b); if (a.JSON && a.JSON.parse) return a.JSON.parse(b); if (w.test(b.replace(y, "@").replace(z, "]").replace(x, ""))) return new Function("return " + b)(); p.error("Invalid JSON: " + b); }, parseXML: function (c) { var d, e; if (!c || typeof c != "string") return null; try { a.DOMParser ? 
((e = new DOMParser()), (d = e.parseFromString(c, "text/xml"))) : ((d = new ActiveXObject("Microsoft.XMLDOM")), (d.async = "false"), d.loadXML(c)); } catch (f) { d = b; } return ( (!d || !d.documentElement || d.getElementsByTagName("parsererror").length) && p.error("Invalid XML: " + c), d ); }, noop: function () {}, globalEval: function (b) { b && r.test(b) && ( a.execScript || function (b) { a.eval.call(a, b); } )(b); }, camelCase: function (a) { return a.replace(A, "ms-").replace(B, C); }, nodeName: function (a, b) { return a.nodeName && a.nodeName.toUpperCase() === b.toUpperCase(); }, each: function (a, c, d) { var e, f = 0, g = a.length, h = g === b || p.isFunction(a); if (d) { if (h) { for (e in a) if (c.apply(a[e], d) === !1) break; } else for (; f < g; ) if (c.apply(a[f++], d) === !1) break; } else if (h) { for (e in a) if (c.call(a[e], e, a[e]) === !1) break; } else for (; f < g; ) if (c.call(a[f], f, a[f++]) === !1) break; return a; }, trim: o && !o.call(" ") ? function (a) { return a == null ? "" : o.call(a); } : function (a) { return a == null ? "" : a.toString().replace(t, ""); }, makeArray: function (a, b) { var c, d = b || []; return ( a != null && ((c = p.type(a)), a.length == null || c === "string" || c === "function" || c === "regexp" || p.isWindow(a) ? j.call(d, a) : p.merge(d, a)), d ); }, inArray: function (a, b, c) { var d; if (b) { if (l) return l.call(b, a, c); (d = b.length), (c = c ? (c < 0 ? 
Math.max(0, d + c) : c) : 0); for (; c < d; c++) if (c in b && b[c] === a) return c; } return -1; }, merge: function (a, c) { var d = c.length, e = a.length, f = 0; if (typeof d == "number") for (; f < d; f++) a[e++] = c[f]; else while (c[f] !== b) a[e++] = c[f++]; return (a.length = e), a; }, grep: function (a, b, c) { var d, e = [], f = 0, g = a.length; c = !!c; for (; f < g; f++) (d = !!b(a[f], f)), c !== d && e.push(a[f]); return e; }, map: function (a, c, d) { var e, f, g = [], h = 0, i = a.length, j = a instanceof p || (i !== b && typeof i == "number" && ((i > 0 && a[0] && a[i - 1]) || i === 0 || p.isArray(a))); if (j) for (; h < i; h++) (e = c(a[h], h, d)), e != null && (g[g.length] = e); else for (f in a) (e = c(a[f], f, d)), e != null && (g[g.length] = e); return g.concat.apply([], g); }, guid: 1, proxy: function (a, c) { var d, e, f; return ( typeof c == "string" && ((d = a[c]), (c = a), (a = d)), p.isFunction(a) ? ((e = k.call(arguments, 2)), (f = function () { return a.apply(c, e.concat(k.call(arguments))); }), (f.guid = a.guid = a.guid || f.guid || p.guid++), f) : b ); }, access: function (a, c, d, e, f, g, h) { var i, j = d == null, k = 0, l = a.length; if (d && typeof d == "object") { for (k in d) p.access(a, c, k, d[k], 1, g, e); f = 1; } else if (e !== b) { (i = h === b && p.isFunction(e)), j && (i ? ((i = c), (c = function (a, b, c) { return i.call(p(a), c); })) : (c.call(a, e), (c = null))); if (c) for (; k < l; k++) c(a[k], d, i ? e.call(a[k], k, c(a[k], d)) : e, h); f = 1; } return f ? a : j ? c.call(a) : l ? 
c(a[0], d) : g; }, now: function () { return new Date().getTime(); }, }), (p.ready.promise = function (b) { if (!d) { d = p.Deferred(); if (e.readyState === "complete") setTimeout(p.ready, 1); else if (e.addEventListener) e.addEventListener("DOMContentLoaded", D, !1), a.addEventListener("load", p.ready, !1); else { e.attachEvent("onreadystatechange", D), a.attachEvent("onload", p.ready); var c = !1; try { c = a.frameElement == null && e.documentElement; } catch (f) {} c && c.doScroll && (function g() { if (!p.isReady) { try { c.doScroll("left"); } catch (a) { return setTimeout(g, 50); } p.ready(); } })(); } } return d.promise(b); }), p.each( "Boolean Number String Function Array Date RegExp Object".split(" "), function (a, b) { E["[object " + b + "]"] = b.toLowerCase(); }, ), (c = p(e)); var F = {}; (p.Callbacks = function (a) { a = typeof a == "string" ? F[a] || G(a) : p.extend({}, a); var c, d, e, f, g, h, i = [], j = !a.once && [], k = function (b) { (c = a.memory && b), (d = !0), (h = f || 0), (f = 0), (g = i.length), (e = !0); for (; i && h < g; h++) if (i[h].apply(b[0], b[1]) === !1 && a.stopOnFalse) { c = !1; break; } (e = !1), i && (j ? j.length && k(j.shift()) : c ? (i = []) : l.disable()); }, l = { add: function () { if (i) { var b = i.length; (function d(b) { p.each(b, function (b, c) { var e = p.type(c); e === "function" && (!a.unique || !l.has(c)) ? i.push(c) : c && c.length && e !== "string" && d(c); }); })(arguments), e ? 
(g = i.length) : c && ((f = b), k(c)); } return this; }, remove: function () { return ( i && p.each(arguments, function (a, b) { var c; while ((c = p.inArray(b, i, c)) > -1) i.splice(c, 1), e && (c <= g && g--, c <= h && h--); }), this ); }, has: function (a) { return p.inArray(a, i) > -1; }, empty: function () { return (i = []), this; }, disable: function () { return (i = j = c = b), this; }, disabled: function () { return !i; }, lock: function () { return (j = b), c || l.disable(), this; }, locked: function () { return !j; }, fireWith: function (a, b) { return ( (b = b || []), (b = [a, b.slice ? b.slice() : b]), i && (!d || j) && (e ? j.push(b) : k(b)), this ); }, fire: function () { return l.fireWith(this, arguments), this; }, fired: function () { return !!d; }, }; return l; }), p.extend({ Deferred: function (a) { var b = [ ["resolve", "done", p.Callbacks("once memory"), "resolved"], ["reject", "fail", p.Callbacks("once memory"), "rejected"], ["notify", "progress", p.Callbacks("memory")], ], c = "pending", d = { state: function () { return c; }, always: function () { return e.done(arguments).fail(arguments), this; }, then: function () { var a = arguments; return p .Deferred(function (c) { p.each(b, function (b, d) { var f = d[0], g = a[b]; e[d[1]]( p.isFunction(g) ? function () { var a = g.apply(this, arguments); a && p.isFunction(a.promise) ? a .promise() .done(c.resolve) .fail(c.reject) .progress(c.notify) : c[f + "With"](this === e ? c : this, [a]); } : c[f], ); }), (a = null); }) .promise(); }, promise: function (a) { return typeof a == "object" ? 
p.extend(a, d) : d; }, }, e = {}; return ( (d.pipe = d.then), p.each(b, function (a, f) { var g = f[2], h = f[3]; (d[f[1]] = g.add), h && g.add( function () { c = h; }, b[a ^ 1][2].disable, b[2][2].lock, ), (e[f[0]] = g.fire), (e[f[0] + "With"] = g.fireWith); }), d.promise(e), a && a.call(e, e), e ); }, when: function (a) { var b = 0, c = k.call(arguments), d = c.length, e = d !== 1 || (a && p.isFunction(a.promise)) ? d : 0, f = e === 1 ? a : p.Deferred(), g = function (a, b, c) { return function (d) { (b[a] = this), (c[a] = arguments.length > 1 ? k.call(arguments) : d), c === h ? f.notifyWith(b, c) : --e || f.resolveWith(b, c); }; }, h, i, j; if (d > 1) { (h = new Array(d)), (i = new Array(d)), (j = new Array(d)); for (; b < d; b++) c[b] && p.isFunction(c[b].promise) ? c[b] .promise() .done(g(b, j, c)) .fail(f.reject) .progress(g(b, i, h)) : --e; } return e || f.resolveWith(j, c), f.promise(); }, }), (p.support = (function () { var b, c, d, f, g, h, i, j, k, l, m, n = e.createElement("div"); n.setAttribute("className", "t"),
(n.innerHTML = "  <link/><table></table><a href='/a'>a</a><input type='checkbox'/>"), (c = n.getElementsByTagName("*")), (d = n.getElementsByTagName("a")[0]), (d.style.cssText = "top:1px;float:left;opacity:.5"); if (!c || !c.length || !d) return {}; (f = e.createElement("select")), (g = f.appendChild(e.createElement("option"))), (h = n.getElementsByTagName("input")[0]), (b = { leadingWhitespace: n.firstChild.nodeType === 3, tbody: !n.getElementsByTagName("tbody").length, htmlSerialize: !!n.getElementsByTagName("link").length, style: /top/.test(d.getAttribute("style")), hrefNormalized: d.getAttribute("href") === "/a", opacity: /^0.5/.test(d.style.opacity), cssFloat: !!d.style.cssFloat, checkOn: h.value === "on", optSelected: g.selected, getSetAttribute: n.className !== "t", enctype: !!e.createElement("form").enctype, html5Clone: e.createElement("nav").cloneNode(!0).outerHTML !== "<:nav></:nav>", boxModel: e.compatMode === "CSS1Compat", submitBubbles: !0, changeBubbles: !0, focusinBubbles: !1, deleteExpando: !0, noCloneEvent: !0, inlineBlockNeedsLayout: !1, shrinkWrapBlocks: !1, reliableMarginRight: !0, boxSizingReliable: !0, pixelPosition: !1, }), (h.checked = !0), (b.noCloneChecked = h.cloneNode(!0).checked), (f.disabled = !0), (b.optDisabled = !g.disabled); try { delete n.test; } catch (o) { b.deleteExpando = !1; } !n.addEventListener && n.attachEvent && n.fireEvent && (n.attachEvent( "onclick", (m = function () { b.noCloneEvent = !1; }), ), n.cloneNode(!0).fireEvent("onclick"), n.detachEvent("onclick", m)), (h = e.createElement("input")), (h.value = "t"), h.setAttribute("type", "radio"), (b.radioValue = h.value === "t"), h.setAttribute("checked", "checked"), h.setAttribute("name", "t"), n.appendChild(h), (i = e.createDocumentFragment()), i.appendChild(n.lastChild), (b.checkClone = i.cloneNode(!0).cloneNode(!0).lastChild.checked), (b.appendChecked = h.checked), i.removeChild(h), i.appendChild(n); if (n.attachEvent) for (k in { submit: !0, change: !0, focusin: !0 }) (j = "on" + k), (l = j in n), l || (n.setAttribute(j, "return;"), (l = typeof
n[j] == "function")), (b[k + "Bubbles"] = l); return ( p(function () { var c, d, f, g, h = "padding:0;margin:0;border:0;display:block;overflow:hidden;", i = e.getElementsByTagName("body")[0]; if (!i) return; (c = e.createElement("div")), (c.style.cssText = "visibility:hidden;border:0;width:0;height:0;position:static;top:0;margin-top:1px"), i.insertBefore(c, i.firstChild), (d = e.createElement("div")), c.appendChild(d),
(d.innerHTML = "<table><tr><td></td><td>t</td></tr></table>"), (f = d.getElementsByTagName("td")), (f[0].style.cssText = "padding:0;margin:0;border:0;display:none"), (l = f[0].offsetHeight === 0), (f[0].style.display = ""), (f[1].style.display = "none"), (b.reliableHiddenOffsets = l && f[0].offsetHeight === 0), (d.innerHTML = ""), (d.style.cssText = "box-sizing:border-box;-moz-box-sizing:border-box;-webkit-box-sizing:border-box;padding:1px;border:1px;display:block;width:4px;margin-top:1%;position:absolute;top:1%;"), (b.boxSizing = d.offsetWidth === 4), (b.doesNotIncludeMarginInBodyOffset = i.offsetTop !== 1), a.getComputedStyle && ((b.pixelPosition = (a.getComputedStyle(d, null) || {}).top !== "1%"), (b.boxSizingReliable = (a.getComputedStyle(d, null) || { width: "4px" }).width === "4px"), (g = e.createElement("div")), (g.style.cssText = d.style.cssText = h), (g.style.marginRight = g.style.width = "0"), (d.style.width = "1px"), d.appendChild(g), (b.reliableMarginRight = !parseFloat( (a.getComputedStyle(g, null) || {}).marginRight, ))), typeof d.style.zoom != "undefined" && ((d.innerHTML = ""), (d.style.cssText = h + "width:1px;padding:1px;display:inline;zoom:1"), (b.inlineBlockNeedsLayout = d.offsetWidth === 3), (d.style.display = "block"), (d.style.overflow = "visible"),
(d.innerHTML = "<div></div>"), (d.firstChild.style.width = "5px"), (b.shrinkWrapBlocks = d.offsetWidth !== 3), (c.style.zoom = 1)), i.removeChild(c), (c = d = f = g = null); }), i.removeChild(n), (c = d = f = g = h = i = n = null), b ); })()); var H = /(?:\{[\s\S]*\}|\[[\s\S]*\])$/, I = /([A-Z])/g; p.extend({ cache: {}, deletedIds: [], uuid: 0, expando: "jQuery" + (p.fn.jquery + Math.random()).replace(/\D/g, ""), noData: { embed: !0, object: "clsid:D27CDB6E-AE6D-11cf-96B8-444553540000", applet: !0, }, hasData: function (a) { return ( (a = a.nodeType ? p.cache[a[p.expando]] : a[p.expando]), !!a && !K(a) ); }, data: function (a, c, d, e) { if (!p.acceptData(a)) return; var f, g, h = p.expando, i = typeof c == "string", j = a.nodeType, k = j ? p.cache : a, l = j ? a[h] : a[h] && h; if ((!l || !k[l] || (!e && !k[l].data)) && i && d === b) return; l || (j ? (a[h] = l = p.deletedIds.pop() || ++p.uuid) : (l = h)), k[l] || ((k[l] = {}), j || (k[l].toJSON = p.noop)); if (typeof c == "object" || typeof c == "function") e ? (k[l] = p.extend(k[l], c)) : (k[l].data = p.extend(k[l].data, c)); return ( (f = k[l]), e || (f.data || (f.data = {}), (f = f.data)), d !== b && (f[p.camelCase(c)] = d), i ? ((g = f[c]), g == null && (g = f[p.camelCase(c)])) : (g = f), g ); }, removeData: function (a, b, c) { if (!p.acceptData(a)) return; var d, e, f, g = a.nodeType, h = g ? p.cache : a, i = g ? a[p.expando] : p.expando; if (!h[i]) return; if (b) { d = c ? h[i] : h[i].data; if (d) { p.isArray(b) || (b in d ? (b = [b]) : ((b = p.camelCase(b)), b in d ? (b = [b]) : (b = b.split(" ")))); for (e = 0, f = b.length; e < f; e++) delete d[b[e]]; if (!(c ? K : p.isEmptyObject)(d)) return; } } if (!c) { delete h[i].data; if (!K(h[i])) return; } g ? p.cleanData([a], !0) : p.support.deleteExpando || h != h.window ?
delete h[i] : (h[i] = null); }, _data: function (a, b, c) { return p.data(a, b, c, !0); }, acceptData: function (a) { var b = a.nodeName && p.noData[a.nodeName.toLowerCase()]; return !b || (b !== !0 && a.getAttribute("classid") === b); }, }), p.fn.extend({ data: function (a, c) { var d, e, f, g, h, i = this[0], j = 0, k = null; if (a === b) { if (this.length) { k = p.data(i); if (i.nodeType === 1 && !p._data(i, "parsedAttrs")) { f = i.attributes; for (h = f.length; j < h; j++) (g = f[j].name), g.indexOf("data-") === 0 && ((g = p.camelCase(g.substring(5))), J(i, g, k[g])); p._data(i, "parsedAttrs", !0); } } return k; } return typeof a == "object" ? this.each(function () { p.data(this, a); }) : ((d = a.split(".", 2)), (d[1] = d[1] ? "." + d[1] : ""), (e = d[1] + "!"), p.access( this, function (c) { if (c === b) return ( (k = this.triggerHandler("getData" + e, [d[0]])), k === b && i && ((k = p.data(i, a)), (k = J(i, a, k))), k === b && d[1] ? this.data(d[0]) : k ); (d[1] = c), this.each(function () { var b = p(this); b.triggerHandler("setData" + e, d), p.data(this, a, c), b.triggerHandler("changeData" + e, d); }); }, null, c, arguments.length > 1, null, !1, )); }, removeData: function (a) { return this.each(function () { p.removeData(this, a); }); }, }), p.extend({ queue: function (a, b, c) { var d; if (a) return ( (b = (b || "fx") + "queue"), (d = p._data(a, b)), c && (!d || p.isArray(c) ? 
(d = p._data(a, b, p.makeArray(c))) : d.push(c)), d || [] ); }, dequeue: function (a, b) { b = b || "fx"; var c = p.queue(a, b), d = c.length, e = c.shift(), f = p._queueHooks(a, b), g = function () { p.dequeue(a, b); }; e === "inprogress" && ((e = c.shift()), d--), e && (b === "fx" && c.unshift("inprogress"), delete f.stop, e.call(a, g, f)), !d && f && f.empty.fire(); }, _queueHooks: function (a, b) { var c = b + "queueHooks"; return ( p._data(a, c) || p._data(a, c, { empty: p.Callbacks("once memory").add(function () { p.removeData(a, b + "queue", !0), p.removeData(a, c, !0); }), }) ); }, }), p.fn.extend({ queue: function (a, c) { var d = 2; return ( typeof a != "string" && ((c = a), (a = "fx"), d--), arguments.length < d ? p.queue(this[0], a) : c === b ? this : this.each(function () { var b = p.queue(this, a, c); p._queueHooks(this, a), a === "fx" && b[0] !== "inprogress" && p.dequeue(this, a); }) ); }, dequeue: function (a) { return this.each(function () { p.dequeue(this, a); }); }, delay: function (a, b) { return ( (a = p.fx ? 
p.fx.speeds[a] || a : a), (b = b || "fx"), this.queue(b, function (b, c) { var d = setTimeout(b, a); c.stop = function () { clearTimeout(d); }; }) ); }, clearQueue: function (a) { return this.queue(a || "fx", []); }, promise: function (a, c) { var d, e = 1, f = p.Deferred(), g = this, h = this.length, i = function () { --e || f.resolveWith(g, [g]); }; typeof a != "string" && ((c = a), (a = b)), (a = a || "fx"); while (h--) (d = p._data(g[h], a + "queueHooks")), d && d.empty && (e++, d.empty.add(i)); return i(), f.promise(c); }, }); var L, M, N, O = /[\t\r\n]/g, P = /\r/g, Q = /^(?:button|input)$/i, R = /^(?:button|input|object|select|textarea)$/i, S = /^a(?:rea|)$/i, T = /^(?:autofocus|autoplay|async|checked|controls|defer|disabled|hidden|loop|multiple|open|readonly|required|scoped|selected)$/i, U = p.support.getSetAttribute; p.fn.extend({ attr: function (a, b) { return p.access(this, p.attr, a, b, arguments.length > 1); }, removeAttr: function (a) { return this.each(function () { p.removeAttr(this, a); }); }, prop: function (a, b) { return p.access(this, p.prop, a, b, arguments.length > 1); }, removeProp: function (a) { return ( (a = p.propFix[a] || a), this.each(function () { try { (this[a] = b), delete this[a]; } catch (c) {} }) ); }, addClass: function (a) { var b, c, d, e, f, g, h; if (p.isFunction(a)) return this.each(function (b) { p(this).addClass(a.call(this, b, this.className)); }); if (a && typeof a == "string") { b = a.split(s); for (c = 0, d = this.length; c < d; c++) { e = this[c]; if (e.nodeType === 1) if (!e.className && b.length === 1) e.className = a; else { f = " " + e.className + " "; for (g = 0, h = b.length; g < h; g++) ~f.indexOf(" " + b[g] + " ") || (f += b[g] + " "); e.className = p.trim(f); } } } return this; }, removeClass: function (a) { var c, d, e, f, g, h, i; if (p.isFunction(a)) return this.each(function (b) { p(this).removeClass(a.call(this, b, this.className)); }); if ((a && typeof a == "string") || a === b) { c = (a || 
"").split(s); for (h = 0, i = this.length; h < i; h++) { e = this[h]; if (e.nodeType === 1 && e.className) { d = (" " + e.className + " ").replace(O, " "); for (f = 0, g = c.length; f < g; f++) while (d.indexOf(" " + c[f] + " ") > -1) d = d.replace(" " + c[f] + " ", " "); e.className = a ? p.trim(d) : ""; } } } return this; }, toggleClass: function (a, b) { var c = typeof a, d = typeof b == "boolean"; return p.isFunction(a) ? this.each(function (c) { p(this).toggleClass(a.call(this, c, this.className, b), b); }) : this.each(function () { if (c === "string") { var e, f = 0, g = p(this), h = b, i = a.split(s); while ((e = i[f++])) (h = d ? h : !g.hasClass(e)), g[h ? "addClass" : "removeClass"](e); } else if (c === "undefined" || c === "boolean") this.className && p._data(this, "__className__", this.className), (this.className = this.className || a === !1 ? "" : p._data(this, "__className__") || ""); }); }, hasClass: function (a) { var b = " " + a + " ", c = 0, d = this.length; for (; c < d; c++) if ( this[c].nodeType === 1 && (" " + this[c].className + " ").replace(O, " ").indexOf(b) > -1 ) return !0; return !1; }, val: function (a) { var c, d, e, f = this[0]; if (!arguments.length) { if (f) return ( (c = p.valHooks[f.type] || p.valHooks[f.nodeName.toLowerCase()]), c && "get" in c && (d = c.get(f, "value")) !== b ? d : ((d = f.value), typeof d == "string" ? d.replace(P, "") : d == null ? "" : d) ); return; } return ( (e = p.isFunction(a)), this.each(function (d) { var f, g = p(this); if (this.nodeType !== 1) return; e ? (f = a.call(this, d, g.val())) : (f = a), f == null ? (f = "") : typeof f == "number" ? (f += "") : p.isArray(f) && (f = p.map(f, function (a) { return a == null ? "" : a + ""; })), (c = p.valHooks[this.type] || p.valHooks[this.nodeName.toLowerCase()]); if (!c || !("set" in c) || c.set(this, f, "value") === b) this.value = f; }) ); }, }), p.extend({ valHooks: { option: { get: function (a) { var b = a.attributes.value; return !b || b.specified ? 
a.value : a.text; }, }, select: { get: function (a) { var b, c, d, e, f = a.selectedIndex, g = [], h = a.options, i = a.type === "select-one"; if (f < 0) return null; (c = i ? f : 0), (d = i ? f + 1 : h.length); for (; c < d; c++) { e = h[c]; if ( e.selected && (p.support.optDisabled ? !e.disabled : e.getAttribute("disabled") === null) && (!e.parentNode.disabled || !p.nodeName(e.parentNode, "optgroup")) ) { b = p(e).val(); if (i) return b; g.push(b); } } return i && !g.length && h.length ? p(h[f]).val() : g; }, set: function (a, b) { var c = p.makeArray(b); return ( p(a) .find("option") .each(function () { this.selected = p.inArray(p(this).val(), c) >= 0; }), c.length || (a.selectedIndex = -1), c ); }, }, }, attrFn: {}, attr: function (a, c, d, e) { var f, g, h, i = a.nodeType; if (!a || i === 3 || i === 8 || i === 2) return; if (e && p.isFunction(p.fn[c])) return p(a)[c](d); if (typeof a.getAttribute == "undefined") return p.prop(a, c, d); (h = i !== 1 || !p.isXMLDoc(a)), h && ((c = c.toLowerCase()), (g = p.attrHooks[c] || (T.test(c) ? M : L))); if (d !== b) { if (d === null) { p.removeAttr(a, c); return; } return g && "set" in g && h && (f = g.set(a, d, c)) !== b ? f : (a.setAttribute(c, "" + d), d); } return g && "get" in g && h && (f = g.get(a, c)) !== null ? f : ((f = a.getAttribute(c)), f === null ? b : f); }, removeAttr: function (a, b) { var c, d, e, f, g = 0; if (b && a.nodeType === 1) { d = b.split(s); for (; g < d.length; g++) (e = d[g]), e && ((c = p.propFix[e] || e), (f = T.test(e)), f || p.attr(a, e, ""), a.removeAttribute(U ? e : c), f && c in a && (a[c] = !1)); } }, attrHooks: { type: { set: function (a, b) { if (Q.test(a.nodeName) && a.parentNode) p.error("type property can't be changed"); else if ( !p.support.radioValue && b === "radio" && p.nodeName(a, "input") ) { var c = a.value; return a.setAttribute("type", b), c && (a.value = c), b; } }, }, value: { get: function (a, b) { return L && p.nodeName(a, "button") ? L.get(a, b) : b in a ? 
a.value : null; }, set: function (a, b, c) { if (L && p.nodeName(a, "button")) return L.set(a, b, c); a.value = b; }, }, }, propFix: { tabindex: "tabIndex", readonly: "readOnly", for: "htmlFor", class: "className", maxlength: "maxLength", cellspacing: "cellSpacing", cellpadding: "cellPadding", rowspan: "rowSpan", colspan: "colSpan", usemap: "useMap", frameborder: "frameBorder", contenteditable: "contentEditable", }, prop: function (a, c, d) { var e, f, g, h = a.nodeType; if (!a || h === 3 || h === 8 || h === 2) return; return ( (g = h !== 1 || !p.isXMLDoc(a)), g && ((c = p.propFix[c] || c), (f = p.propHooks[c])), d !== b ? f && "set" in f && (e = f.set(a, d, c)) !== b ? e : (a[c] = d) : f && "get" in f && (e = f.get(a, c)) !== null ? e : a[c] ); }, propHooks: { tabIndex: { get: function (a) { var c = a.getAttributeNode("tabindex"); return c && c.specified ? parseInt(c.value, 10) : R.test(a.nodeName) || (S.test(a.nodeName) && a.href) ? 0 : b; }, }, }, }), (M = { get: function (a, c) { var d, e = p.prop(a, c); return e === !0 || (typeof e != "boolean" && (d = a.getAttributeNode(c)) && d.nodeValue !== !1) ? c.toLowerCase() : b; }, set: function (a, b, c) { var d; return ( b === !1 ? p.removeAttr(a, c) : ((d = p.propFix[c] || c), d in a && (a[d] = !0), a.setAttribute(c, c.toLowerCase())), c ); }, }), U || ((N = { name: !0, id: !0, coords: !0 }), (L = p.valHooks.button = { get: function (a, c) { var d; return ( (d = a.getAttributeNode(c)), d && (N[c] ? d.value !== "" : d.specified) ? 
d.value : b ); }, set: function (a, b, c) { var d = a.getAttributeNode(c); return ( d || ((d = e.createAttribute(c)), a.setAttributeNode(d)), (d.value = b + "") ); }, }), p.each(["width", "height"], function (a, b) { p.attrHooks[b] = p.extend(p.attrHooks[b], { set: function (a, c) { if (c === "") return a.setAttribute(b, "auto"), c; }, }); }), (p.attrHooks.contenteditable = { get: L.get, set: function (a, b, c) { b === "" && (b = "false"), L.set(a, b, c); }, })), p.support.hrefNormalized || p.each(["href", "src", "width", "height"], function (a, c) { p.attrHooks[c] = p.extend(p.attrHooks[c], { get: function (a) { var d = a.getAttribute(c, 2); return d === null ? b : d; }, }); }), p.support.style || (p.attrHooks.style = { get: function (a) { return a.style.cssText.toLowerCase() || b; }, set: function (a, b) { return (a.style.cssText = "" + b); }, }), p.support.optSelected || (p.propHooks.selected = p.extend(p.propHooks.selected, { get: function (a) { var b = a.parentNode; return ( b && (b.selectedIndex, b.parentNode && b.parentNode.selectedIndex), null ); }, })), p.support.enctype || (p.propFix.enctype = "encoding"), p.support.checkOn || p.each(["radio", "checkbox"], function () { p.valHooks[this] = { get: function (a) { return a.getAttribute("value") === null ? "on" : a.value; }, }; }), p.each(["radio", "checkbox"], function () { p.valHooks[this] = p.extend(p.valHooks[this], { set: function (a, b) { if (p.isArray(b)) return (a.checked = p.inArray(p(a).val(), b) >= 0); }, }); }); var V = /^(?:textarea|input|select)$/i, W = /^([^\.]*|)(?:\.(.+)|)$/, X = /(?:^|\s)hover(\.\S+|)\b/, Y = /^key/, Z = /^(?:mouse|contextmenu)|click/, $ = /^(?:focusinfocus|focusoutblur)$/, _ = function (a) { return p.event.special.hover ? 
a : a.replace(X, "mouseenter$1 mouseleave$1"); }; (p.event = { add: function (a, c, d, e, f) { var g, h, i, j, k, l, m, n, o, q, r; if (a.nodeType === 3 || a.nodeType === 8 || !c || !d || !(g = p._data(a))) return; d.handler && ((o = d), (d = o.handler), (f = o.selector)), d.guid || (d.guid = p.guid++), (i = g.events), i || (g.events = i = {}), (h = g.handle), h || ((g.handle = h = function (a) { return typeof p != "undefined" && (!a || p.event.triggered !== a.type) ? p.event.dispatch.apply(h.elem, arguments) : b; }), (h.elem = a)), (c = p.trim(_(c)).split(" ")); for (j = 0; j < c.length; j++) { (k = W.exec(c[j]) || []), (l = k[1]), (m = (k[2] || "").split(".").sort()), (r = p.event.special[l] || {}), (l = (f ? r.delegateType : r.bindType) || l), (r = p.event.special[l] || {}), (n = p.extend( { type: l, origType: k[1], data: e, handler: d, guid: d.guid, selector: f, namespace: m.join("."), }, o, )), (q = i[l]); if (!q) { (q = i[l] = []), (q.delegateCount = 0); if (!r.setup || r.setup.call(a, e, m, h) === !1) a.addEventListener ? a.addEventListener(l, h, !1) : a.attachEvent && a.attachEvent("on" + l, h); } r.add && (r.add.call(a, n), n.handler.guid || (n.handler.guid = d.guid)), f ? q.splice(q.delegateCount++, 0, n) : q.push(n), (p.event.global[l] = !0); } a = null; }, global: {}, remove: function (a, b, c, d, e) { var f, g, h, i, j, k, l, m, n, o, q, r = p.hasData(a) && p._data(a); if (!r || !(m = r.events)) return; b = p.trim(_(b || "")).split(" "); for (f = 0; f < b.length; f++) { (g = W.exec(b[f]) || []), (h = i = g[1]), (j = g[2]); if (!h) { for (h in m) p.event.remove(a, h + b[f], c, d, !0); continue; } (n = p.event.special[h] || {}), (h = (d ? n.delegateType : n.bindType) || h), (o = m[h] || []), (k = o.length), (j = j ? 
new RegExp( "(^|\\.)" + j.split(".").sort().join("\\.(?:.*\\.|)") + "(\\.|$)", ) : null); for (l = 0; l < o.length; l++) (q = o[l]), (e || i === q.origType) && (!c || c.guid === q.guid) && (!j || j.test(q.namespace)) && (!d || d === q.selector || (d === "**" && q.selector)) && (o.splice(l--, 1), q.selector && o.delegateCount--, n.remove && n.remove.call(a, q)); o.length === 0 && k !== o.length && ((!n.teardown || n.teardown.call(a, j, r.handle) === !1) && p.removeEvent(a, h, r.handle), delete m[h]); } p.isEmptyObject(m) && (delete r.handle, p.removeData(a, "events", !0)); }, customEvent: { getData: !0, setData: !0, changeData: !0 }, trigger: function (c, d, f, g) { if (!f || (f.nodeType !== 3 && f.nodeType !== 8)) { var h, i, j, k, l, m, n, o, q, r, s = c.type || c, t = []; if ($.test(s + p.event.triggered)) return; s.indexOf("!") >= 0 && ((s = s.slice(0, -1)), (i = !0)), s.indexOf(".") >= 0 && ((t = s.split(".")), (s = t.shift()), t.sort()); if ((!f || p.event.customEvent[s]) && !p.event.global[s]) return; (c = typeof c == "object" ? c[p.expando] ? c : new p.Event(s, c) : new p.Event(s)), (c.type = s), (c.isTrigger = !0), (c.exclusive = i), (c.namespace = t.join(".")), (c.namespace_re = c.namespace ? new RegExp("(^|\\.)" + t.join("\\.(?:.*\\.|)") + "(\\.|$)") : null), (m = s.indexOf(":") < 0 ? "on" + s : ""); if (!f) { h = p.cache; for (j in h) h[j].events && h[j].events[s] && p.event.trigger(c, d, h[j].handle.elem, !0); return; } (c.result = b), c.target || (c.target = f), (d = d != null ? p.makeArray(d) : []), d.unshift(c), (n = p.event.special[s] || {}); if (n.trigger && n.trigger.apply(f, d) === !1) return; q = [[f, n.bindType || s]]; if (!g && !n.noBubble && !p.isWindow(f)) { (r = n.delegateType || s), (k = $.test(r + s) ? 
f : f.parentNode); for (l = f; k; k = k.parentNode) q.push([k, r]), (l = k); l === (f.ownerDocument || e) && q.push([l.defaultView || l.parentWindow || a, r]); } for (j = 0; j < q.length && !c.isPropagationStopped(); j++) (k = q[j][0]), (c.type = q[j][1]), (o = (p._data(k, "events") || {})[c.type] && p._data(k, "handle")), o && o.apply(k, d), (o = m && k[m]), o && p.acceptData(k) && o.apply(k, d) === !1 && c.preventDefault(); return ( (c.type = s), !g && !c.isDefaultPrevented() && (!n._default || n._default.apply(f.ownerDocument, d) === !1) && (s !== "click" || !p.nodeName(f, "a")) && p.acceptData(f) && m && f[s] && ((s !== "focus" && s !== "blur") || c.target.offsetWidth !== 0) && !p.isWindow(f) && ((l = f[m]), l && (f[m] = null), (p.event.triggered = s), f[s](), (p.event.triggered = b), l && (f[m] = l)), c.result ); } return; }, dispatch: function (c) { c = p.event.fix(c || a.event); var d, e, f, g, h, i, j, k, l, m, n = (p._data(this, "events") || {})[c.type] || [], o = n.delegateCount, q = [].slice.call(arguments), r = !c.exclusive && !c.namespace, s = p.event.special[c.type] || {}, t = []; (q[0] = c), (c.delegateTarget = this); if (s.preDispatch && s.preDispatch.call(this, c) === !1) return; if (o && (!c.button || c.type !== "click")) for (f = c.target; f != this; f = f.parentNode || this) if (f.disabled !== !0 || c.type !== "click") { (h = {}), (j = []); for (d = 0; d < o; d++) (k = n[d]), (l = k.selector), h[l] === b && (h[l] = p(l, this).index(f) >= 0), h[l] && j.push(k); j.length && t.push({ elem: f, matches: j }); } n.length > o && t.push({ elem: this, matches: n.slice(o) }); for (d = 0; d < t.length && !c.isPropagationStopped(); d++) { (i = t[d]), (c.currentTarget = i.elem); for ( e = 0; e < i.matches.length && !c.isImmediatePropagationStopped(); e++ ) { k = i.matches[e]; if ( r || (!c.namespace && !k.namespace) || (c.namespace_re && c.namespace_re.test(k.namespace)) ) (c.data = k.data), (c.handleObj = k), (g = ( (p.event.special[k.origType] || 
{}).handle || k.handler ).apply(i.elem, q)), g !== b && ((c.result = g), g === !1 && (c.preventDefault(), c.stopPropagation())); } } return s.postDispatch && s.postDispatch.call(this, c), c.result; }, props: "attrChange attrName relatedNode srcElement altKey bubbles cancelable ctrlKey currentTarget eventPhase metaKey relatedTarget shiftKey target timeStamp view which".split( " ", ), fixHooks: {}, keyHooks: { props: "char charCode key keyCode".split(" "), filter: function (a, b) { return ( a.which == null && (a.which = b.charCode != null ? b.charCode : b.keyCode), a ); }, }, mouseHooks: { props: "button buttons clientX clientY fromElement offsetX offsetY pageX pageY screenX screenY toElement".split( " ", ), filter: function (a, c) { var d, f, g, h = c.button, i = c.fromElement; return ( a.pageX == null && c.clientX != null && ((d = a.target.ownerDocument || e), (f = d.documentElement), (g = d.body), (a.pageX = c.clientX + ((f && f.scrollLeft) || (g && g.scrollLeft) || 0) - ((f && f.clientLeft) || (g && g.clientLeft) || 0)), (a.pageY = c.clientY + ((f && f.scrollTop) || (g && g.scrollTop) || 0) - ((f && f.clientTop) || (g && g.clientTop) || 0))), !a.relatedTarget && i && (a.relatedTarget = i === a.target ? c.toElement : i), !a.which && h !== b && (a.which = h & 1 ? 1 : h & 2 ? 3 : h & 4 ? 2 : 0), a ); }, }, fix: function (a) { if (a[p.expando]) return a; var b, c, d = a, f = p.event.fixHooks[a.type] || {}, g = f.props ? this.props.concat(f.props) : this.props; a = p.Event(d); for (b = g.length; b; ) (c = g[--b]), (a[c] = d[c]); return ( a.target || (a.target = d.srcElement || e), a.target.nodeType === 3 && (a.target = a.target.parentNode), (a.metaKey = !!a.metaKey), f.filter ? 
f.filter(a, d) : a ); }, special: { load: { noBubble: !0 }, focus: { delegateType: "focusin" }, blur: { delegateType: "focusout" }, beforeunload: { setup: function (a, b, c) { p.isWindow(this) && (this.onbeforeunload = c); }, teardown: function (a, b) { this.onbeforeunload === b && (this.onbeforeunload = null); }, }, }, simulate: function (a, b, c, d) { var e = p.extend(new p.Event(), c, { type: a, isSimulated: !0, originalEvent: {}, }); d ? p.event.trigger(e, null, b) : p.event.dispatch.call(b, e), e.isDefaultPrevented() && c.preventDefault(); }, }), (p.event.handle = p.event.dispatch), (p.removeEvent = e.removeEventListener ? function (a, b, c) { a.removeEventListener && a.removeEventListener(b, c, !1); } : function (a, b, c) { var d = "on" + b; a.detachEvent && (typeof a[d] == "undefined" && (a[d] = null), a.detachEvent(d, c)); }), (p.Event = function (a, b) { if (this instanceof p.Event) a && a.type ? ((this.originalEvent = a), (this.type = a.type), (this.isDefaultPrevented = a.defaultPrevented || a.returnValue === !1 || (a.getPreventDefault && a.getPreventDefault()) ? bb : ba)) : (this.type = a), b && p.extend(this, b), (this.timeStamp = (a && a.timeStamp) || p.now()), (this[p.expando] = !0); else return new p.Event(a, b); }), (p.Event.prototype = { preventDefault: function () { this.isDefaultPrevented = bb; var a = this.originalEvent; if (!a) return; a.preventDefault ? 
a.preventDefault() : (a.returnValue = !1); }, stopPropagation: function () { this.isPropagationStopped = bb; var a = this.originalEvent; if (!a) return; a.stopPropagation && a.stopPropagation(), (a.cancelBubble = !0); }, stopImmediatePropagation: function () { (this.isImmediatePropagationStopped = bb), this.stopPropagation(); }, isDefaultPrevented: ba, isPropagationStopped: ba, isImmediatePropagationStopped: ba, }), p.each( { mouseenter: "mouseover", mouseleave: "mouseout" }, function (a, b) { p.event.special[a] = { delegateType: b, bindType: b, handle: function (a) { var c, d = this, e = a.relatedTarget, f = a.handleObj, g = f.selector; if (!e || (e !== d && !p.contains(d, e))) (a.type = f.origType), (c = f.handler.apply(this, arguments)), (a.type = b); return c; }, }; }, ), p.support.submitBubbles || (p.event.special.submit = { setup: function () { if (p.nodeName(this, "form")) return !1; p.event.add(this, "click._submit keypress._submit", function (a) { var c = a.target, d = p.nodeName(c, "input") || p.nodeName(c, "button") ? 
c.form : b; d && !p._data(d, "_submit_attached") && (p.event.add(d, "submit._submit", function (a) { a._submit_bubble = !0; }), p._data(d, "_submit_attached", !0)); }); }, postDispatch: function (a) { a._submit_bubble && (delete a._submit_bubble, this.parentNode && !a.isTrigger && p.event.simulate("submit", this.parentNode, a, !0)); }, teardown: function () { if (p.nodeName(this, "form")) return !1; p.event.remove(this, "._submit"); }, }), p.support.changeBubbles || (p.event.special.change = { setup: function () { if (V.test(this.nodeName)) { if (this.type === "checkbox" || this.type === "radio") p.event.add(this, "propertychange._change", function (a) { a.originalEvent.propertyName === "checked" && (this._just_changed = !0); }), p.event.add(this, "click._change", function (a) { this._just_changed && !a.isTrigger && (this._just_changed = !1), p.event.simulate("change", this, a, !0); }); return !1; } p.event.add(this, "beforeactivate._change", function (a) { var b = a.target; V.test(b.nodeName) && !p._data(b, "_change_attached") && (p.event.add(b, "change._change", function (a) { this.parentNode && !a.isSimulated && !a.isTrigger && p.event.simulate("change", this.parentNode, a, !0); }), p._data(b, "_change_attached", !0)); }); }, handle: function (a) { var b = a.target; if ( this !== b || a.isSimulated || a.isTrigger || (b.type !== "radio" && b.type !== "checkbox") ) return a.handleObj.handler.apply(this, arguments); }, teardown: function () { return p.event.remove(this, "._change"), !V.test(this.nodeName); }, }), p.support.focusinBubbles || p.each({ focus: "focusin", blur: "focusout" }, function (a, b) { var c = 0, d = function (a) { p.event.simulate(b, a.target, p.event.fix(a), !0); }; p.event.special[b] = { setup: function () { c++ === 0 && e.addEventListener(a, d, !0); }, teardown: function () { --c === 0 && e.removeEventListener(a, d, !0); }, }; }), p.fn.extend({ on: function (a, c, d, e, f) { var g, h; if (typeof a == "object") { typeof c != "string" && ((d = 
d || c), (c = b)); for (h in a) this.on(h, c, d, a[h], f); return this; } d == null && e == null ? ((e = c), (d = c = b)) : e == null && (typeof c == "string" ? ((e = d), (d = b)) : ((e = d), (d = c), (c = b))); if (e === !1) e = ba; else if (!e) return this; return ( f === 1 && ((g = e), (e = function (a) { return p().off(a), g.apply(this, arguments); }), (e.guid = g.guid || (g.guid = p.guid++))), this.each(function () { p.event.add(this, a, e, d, c); }) ); }, one: function (a, b, c, d) { return this.on(a, b, c, d, 1); }, off: function (a, c, d) { var e, f; if (a && a.preventDefault && a.handleObj) return ( (e = a.handleObj), p(a.delegateTarget).off( e.namespace ? e.origType + "." + e.namespace : e.origType, e.selector, e.handler, ), this ); if (typeof a == "object") { for (f in a) this.off(f, c, a[f]); return this; } if (c === !1 || typeof c == "function") (d = c), (c = b); return ( d === !1 && (d = ba), this.each(function () { p.event.remove(this, a, d, c); }) ); }, bind: function (a, b, c) { return this.on(a, null, b, c); }, unbind: function (a, b) { return this.off(a, null, b); }, live: function (a, b, c) { return p(this.context).on(a, this.selector, b, c), this; }, die: function (a, b) { return p(this.context).off(a, this.selector || "**", b), this; }, delegate: function (a, b, c, d) { return this.on(b, a, c, d); }, undelegate: function (a, b, c) { return arguments.length == 1 ? 
this.off(a, "**") : this.off(b, a || "**", c); }, trigger: function (a, b) { return this.each(function () { p.event.trigger(a, b, this); }); }, triggerHandler: function (a, b) { if (this[0]) return p.event.trigger(a, b, this[0], !0); }, toggle: function (a) { var b = arguments, c = a.guid || p.guid++, d = 0, e = function (c) { var e = (p._data(this, "lastToggle" + a.guid) || 0) % d; return ( p._data(this, "lastToggle" + a.guid, e + 1), c.preventDefault(), b[e].apply(this, arguments) || !1 ); }; e.guid = c; while (d < b.length) b[d++].guid = c; return this.click(e); }, hover: function (a, b) { return this.mouseenter(a).mouseleave(b || a); }, }), p.each( "blur focus focusin focusout load resize scroll unload click dblclick mousedown mouseup mousemove mouseover mouseout mouseenter mouseleave change select submit keydown keypress keyup error contextmenu".split( " ", ), function (a, b) { (p.fn[b] = function (a, c) { return ( c == null && ((c = a), (a = null)), arguments.length > 0 ? this.on(b, null, a, c) : this.trigger(b) ); }), Y.test(b) && (p.event.fixHooks[b] = p.event.keyHooks), Z.test(b) && (p.event.fixHooks[b] = p.event.mouseHooks); }, ), (function (a, b) { function $(a, b, c, d) { (c = c || []), (b = b || q); var e, f, g, j, k = b.nodeType; if (k !== 1 && k !== 9) return []; if (!a || typeof a != "string") return c; g = h(b); if (!g && !d) if ((e = L.exec(a))) if ((j = e[1])) { if (k === 9) { f = b.getElementById(j); if (!f || !f.parentNode) return c; if (f.id === j) return c.push(f), c; } else if ( b.ownerDocument && (f = b.ownerDocument.getElementById(j)) && i(b, f) && f.id === j ) return c.push(f), c; } else { if (e[2]) return u.apply(c, t.call(b.getElementsByTagName(a), 0)), c; if ((j = e[3]) && X && b.getElementsByClassName) return u.apply(c, t.call(b.getElementsByClassName(j), 0)), c; } return bk(a, b, c, d, g); } function _(a) { return function (b) { var c = b.nodeName.toLowerCase(); return c === "input" && b.type === a; }; } function ba(a) { return 
function (b) { var c = b.nodeName.toLowerCase(); return (c === "input" || c === "button") && b.type === a; }; } function bb(a, b, c) { if (a === b) return c; var d = a.nextSibling; while (d) { if (d === b) return -1; d = d.nextSibling; } return 1; } function bc(a, b, c, d) { var e, g, h, i, j, k, l, m, n, p, r = !c && b !== q, s = (r ? "" : "") + a.replace(H, "$1"), u = y[o][s]; if (u) return d ? 0 : t.call(u, 0); (j = a), (k = []), (m = 0), (n = f.preFilter), (p = f.filter); while (j) { if (!e || (g = I.exec(j))) g && ((j = j.slice(g[0].length)), (h.selector = l)), k.push((h = [])), (l = ""), r && (j = " " + j); e = !1; if ((g = J.exec(j))) (l += g[0]), (j = j.slice(g[0].length)), (e = h.push({ part: g.pop().replace(H, " "), string: g[0], captures: g, })); for (i in p) (g = S[i].exec(j)) && (!n[i] || (g = n[i](g, b, c))) && ((l += g[0]), (j = j.slice(g[0].length)), (e = h.push({ part: i, string: g.shift(), captures: g }))); if (!e) break; } return ( l && (h.selector = l), d ? j.length : j ? $.error(a) : t.call(y(s, k), 0) ); } function bd(a, b, e, f) { var g = b.dir, h = s++; return ( a || (a = function (a) { return a === e; }), b.first ? function (b) { while ((b = b[g])) if (b.nodeType === 1) return a(b) && b; } : f ? function (b) { while ((b = b[g])) if (b.nodeType === 1 && a(b)) return b; } : function (b) { var e, f = h + "." + c, i = f + "." + d; while ((b = b[g])) if (b.nodeType === 1) { if ((e = b[o]) === i) return b.sizset; if (typeof e == "string" && e.indexOf(f) === 0) { if (b.sizset) return b; } else { b[o] = i; if (a(b)) return (b.sizset = !0), b; b.sizset = !1; } } } ); } function be(a, b) { return a ? function (c) { var d = b(c); return d && a(d === !0 ? c : d); } : b; } function bf(a, b, c) { var d, e, g = 0; for (; (d = a[g]); g++) f.relative[d.part] ? 
(e = bd(e, f.relative[d.part], b, c)) : (e = be( e, f.filter[d.part].apply(null, d.captures.concat(b, c)), )); return e; } function bg(a) { return function (b) { var c, d = 0; for (; (c = a[d]); d++) if (c(b)) return !0; return !1; }; } function bh(a, b, c, d) { var e = 0, f = b.length; for (; e < f; e++) $(a, b[e], c, d); } function bi(a, b, c, d, e, g) { var h, i = f.setFilters[b.toLowerCase()]; return ( i || $.error(b), (a || !(h = e)) && bh(a || "*", d, (h = []), e), h.length > 0 ? i(h, c, g) : [] ); } function bj(a, c, d, e) { var f, g, h, i, j, k, l, m, n, o, p, q, r, s = 0, t = a.length, v = S.POS, w = new RegExp("^" + v.source + "(?!" + A + ")", "i"), x = function () { var a = 1, c = arguments.length - 2; for (; a < c; a++) arguments[a] === b && (n[a] = b); }; for (; s < t; s++) { (f = a[s]), (g = ""), (m = e); for (h = 0, i = f.length; h < i; h++) { (j = f[h]), (k = j.string); if (j.part === "PSEUDO") { v.exec(""), (l = 0); while ((n = v.exec(k))) { (o = !0), (p = v.lastIndex = n.index + n[0].length); if (p > l) { (g += k.slice(l, n.index)), (l = p), (q = [c]), J.test(g) && (m && (q = m), (m = e)); if ((r = O.test(g))) (g = g.slice(0, -5).replace(J, "$&*")), l++; n.length > 1 && n[0].replace(w, x), (m = bi(g, n[1], n[2], q, m, r)); } g = ""; } } o || (g += k), (o = !1); } g ? J.test(g) ? bh(g, m || [c], d, e) : $(g, c, d, e ? e.concat(m) : m) : u.apply(d, m); } return t === 1 ? 
d : $.uniqueSort(d); } function bk(a, b, e, g, h) { a = a.replace(H, "$1"); var i, k, l, m, n, o, p, q, r, s, v = bc(a, b, h), w = b.nodeType; if (S.POS.test(a)) return bj(v, b, e, g); if (g) i = t.call(g, 0); else if (v.length === 1) { if ( (o = t.call(v[0], 0)).length > 2 && (p = o[0]).part === "ID" && w === 9 && !h && f.relative[o[1].part] ) { b = f.find.ID(p.captures[0].replace(R, ""), b, h)[0]; if (!b) return e; a = a.slice(o.shift().string.length); } (r = ((v = N.exec(o[0].string)) && !v.index && b.parentNode) || b), (q = ""); for (n = o.length - 1; n >= 0; n--) { (p = o[n]), (s = p.part), (q = p.string + q); if (f.relative[s]) break; if (f.order.test(s)) { i = f.find[s](p.captures[0].replace(R, ""), r, h); if (i == null) continue; (a = a.slice(0, a.length - q.length) + q.replace(S[s], "")), a || u.apply(e, t.call(i, 0)); break; } } } if (a) { (k = j(a, b, h)), (c = k.dirruns++), i == null && (i = f.find.TAG("*", (N.test(a) && b.parentNode) || b)); for (n = 0; (m = i[n]); n++) (d = k.runs++), k(m) && e.push(m); } return e; } var c, d, e, f, g, h, i, j, k, l, m = !0, n = "undefined", o = ("sizcache" + Math.random()).replace(".", ""), q = a.document, r = q.documentElement, s = 0, t = [].slice, u = [].push, v = function (a, b) { return (a[o] = b || !0), a; }, w = function () { var a = {}, b = []; return v(function (c, d) { return b.push(c) > f.cacheLength && delete a[b.shift()], (a[c] = d); }, a); }, x = w(), y = w(), z = w(), A = "[\\x20\\t\\r\\n\\f]", B = "(?:\\\\.|[-\\w]|[^\\x00-\\xa0])+", C = B.replace("w", "w#"), D = "([*^$|!~]?=)", E = "\\[" + A + "*(" + B + ")" + A + "*(?:" + D + A + "*(?:(['\"])((?:\\\\.|[^\\\\])*?)\\3|(" + C + ")|)|)" + A + "*\\]", F = ":(" + B + ")(?:\\((?:(['\"])((?:\\\\.|[^\\\\])*?)\\2|([^()[\\]]*|(?:(?:" + E + ")|[^:]|\\\\.)*|.*))\\)|)", G = ":(nth|eq|gt|lt|first|last|even|odd)(?:\\(((?:-\\d)?\\d*)\\)|)(?=[^-]|$)", H = new RegExp("^" + A + "+|((?:^|[^\\\\])(?:\\\\.)*)" + A + "+$", "g"), I = new RegExp("^" + A + "*," + A + "*"), J = 
new RegExp("^" + A + "*([\\x20\\t\\r\\n\\f>+~])" + A + "*"), K = new RegExp(F), L = /^(?:#([\w\-]+)|(\w+)|\.([\w\-]+))$/, M = /^:not/, N = /[\x20\t\r\n\f]*[+~]/, O = /:not\($/, P = /h\d/i, Q = /input|select|textarea|button/i, R = /\\(?!\\)/g, S = { ID: new RegExp("^#(" + B + ")"), CLASS: new RegExp("^\\.(" + B + ")"), NAME: new RegExp("^\\[name=['\"]?(" + B + ")['\"]?\\]"), TAG: new RegExp("^(" + B.replace("w", "w*") + ")"), ATTR: new RegExp("^" + E), PSEUDO: new RegExp("^" + F), CHILD: new RegExp( "^:(only|nth|last|first)-child(?:\\(" + A + "*(even|odd|(([+-]|)(\\d*)n|)" + A + "*(?:([+-]|)" + A + "*(\\d+)|))" + A + "*\\)|)", "i", ), POS: new RegExp(G, "ig"), needsContext: new RegExp("^" + A + "*[>+~]|" + G, "i"), }, T = function (a) { var b = q.createElement("div"); try { return a(b); } catch (c) { return !1; } finally { b = null; } }, U = T(function (a) { return ( a.appendChild(q.createComment("")), !a.getElementsByTagName("*").length ); }), V = T(function (a) { return ( (a.innerHTML = ""), a.firstChild && typeof a.firstChild.getAttribute !== n && a.firstChild.getAttribute("href") === "#" ); }), W = T(function (a) { a.innerHTML = ""; var b = typeof a.lastChild.getAttribute("multiple"); return b !== "boolean" && b !== "string"; }), X = T(function (a) { return ( (a.innerHTML = ""), !a.getElementsByClassName || !a.getElementsByClassName("e").length ? !1 : ((a.lastChild.className = "e"), a.getElementsByClassName("e").length === 2) ); }), Y = T(function (a) { (a.id = o + 0), (a.innerHTML = "
        "), r.insertBefore(a, r.firstChild); var b = q.getElementsByName && q.getElementsByName(o).length === 2 + q.getElementsByName(o + 0).length; return (e = !q.getElementById(o)), r.removeChild(a), b; }); try { t.call(r.childNodes, 0)[0].nodeType; } catch (Z) { t = function (a) { var b, c = []; for (; (b = this[a]); a++) c.push(b); return c; }; } ($.matches = function (a, b) { return $(a, null, null, b); }), ($.matchesSelector = function (a, b) { return $(b, null, null, [a]).length > 0; }), (g = $.getText = function (a) { var b, c = "", d = 0, e = a.nodeType; if (e) { if (e === 1 || e === 9 || e === 11) { if (typeof a.textContent == "string") return a.textContent; for (a = a.firstChild; a; a = a.nextSibling) c += g(a); } else if (e === 3 || e === 4) return a.nodeValue; } else for (; (b = a[d]); d++) c += g(b); return c; }), (h = $.isXML = function (a) { var b = a && (a.ownerDocument || a).documentElement; return b ? b.nodeName !== "HTML" : !1; }), (i = $.contains = r.contains ? function (a, b) { var c = a.nodeType === 9 ? a.documentElement : a, d = b && b.parentNode; return ( a === d || !!(d && d.nodeType === 1 && c.contains && c.contains(d)) ); } : r.compareDocumentPosition ? function (a, b) { return b && !!(a.compareDocumentPosition(b) & 16); } : function (a, b) { while ((b = b.parentNode)) if (b === a) return !0; return !1; }), ($.attr = function (a, b) { var c, d = h(a); return ( d || (b = b.toLowerCase()), f.attrHandle[b] ? f.attrHandle[b](a) : W || d ? a.getAttribute(b) : ((c = a.getAttributeNode(b)), c ? typeof a[b] == "boolean" ? a[b] ? b : null : c.specified ? c.value : null : null) ); }), (f = $.selectors = { cacheLength: 50, createPseudo: v, match: S, order: new RegExp( "ID|TAG" + (Y ? "|NAME" : "") + (X ? "|CLASS" : ""), ), attrHandle: V ? {} : { href: function (a) { return a.getAttribute("href", 2); }, type: function (a) { return a.getAttribute("type"); }, }, find: { ID: e ? 
function (a, b, c) { if (typeof b.getElementById !== n && !c) { var d = b.getElementById(a); return d && d.parentNode ? [d] : []; } } : function (a, c, d) { if (typeof c.getElementById !== n && !d) { var e = c.getElementById(a); return e ? e.id === a || (typeof e.getAttributeNode !== n && e.getAttributeNode("id").value === a) ? [e] : b : []; } }, TAG: U ? function (a, b) { if (typeof b.getElementsByTagName !== n) return b.getElementsByTagName(a); } : function (a, b) { var c = b.getElementsByTagName(a); if (a === "*") { var d, e = [], f = 0; for (; (d = c[f]); f++) d.nodeType === 1 && e.push(d); return e; } return c; }, NAME: function (a, b) { if (typeof b.getElementsByName !== n) return b.getElementsByName(name); }, CLASS: function (a, b, c) { if (typeof b.getElementsByClassName !== n && !c) return b.getElementsByClassName(a); }, }, relative: { ">": { dir: "parentNode", first: !0 }, " ": { dir: "parentNode" }, "+": { dir: "previousSibling", first: !0 }, "~": { dir: "previousSibling" }, }, preFilter: { ATTR: function (a) { return ( (a[1] = a[1].replace(R, "")), (a[3] = (a[4] || a[5] || "").replace(R, "")), a[2] === "~=" && (a[3] = " " + a[3] + " "), a.slice(0, 4) ); }, CHILD: function (a) { return ( (a[1] = a[1].toLowerCase()), a[1] === "nth" ? (a[2] || $.error(a[0]), (a[3] = +(a[3] ? a[4] + (a[5] || 1) : 2 * (a[2] === "even" || a[2] === "odd"))), (a[4] = +(a[6] + a[7] || a[2] === "odd"))) : a[2] && $.error(a[0]), a ); }, PSEUDO: function (a, b, c) { var d, e; if (S.CHILD.test(a[0])) return null; if (a[3]) a[2] = a[3]; else if ((d = a[4])) K.test(d) && (e = bc(d, b, c, !0)) && (e = d.indexOf(")", d.length - e) - d.length) && ((d = d.slice(0, e)), (a[0] = a[0].slice(0, e))), (a[2] = d); return a.slice(0, 3); }, }, filter: { ID: e ? 
function (a) { return ( (a = a.replace(R, "")), function (b) { return b.getAttribute("id") === a; } ); } : function (a) { return ( (a = a.replace(R, "")), function (b) { var c = typeof b.getAttributeNode !== n && b.getAttributeNode("id"); return c && c.value === a; } ); }, TAG: function (a) { return a === "*" ? function () { return !0; } : ((a = a.replace(R, "").toLowerCase()), function (b) { return b.nodeName && b.nodeName.toLowerCase() === a; }); }, CLASS: function (a) { var b = x[o][a]; return ( b || (b = x( a, new RegExp("(^|" + A + ")" + a + "(" + A + "|$)"), )), function (a) { return b.test( a.className || (typeof a.getAttribute !== n && a.getAttribute("class")) || "", ); } ); }, ATTR: function (a, b, c) { return b ? function (d) { var e = $.attr(d, a), f = e + ""; if (e == null) return b === "!="; switch (b) { case "=": return f === c; case "!=": return f !== c; case "^=": return c && f.indexOf(c) === 0; case "*=": return c && f.indexOf(c) > -1; case "$=": return c && f.substr(f.length - c.length) === c; case "~=": return (" " + f + " ").indexOf(c) > -1; case "|=": return ( f === c || f.substr(0, c.length + 1) === c + "-" ); } } : function (b) { return $.attr(b, a) != null; }; }, CHILD: function (a, b, c, d) { if (a === "nth") { var e = s++; return function (a) { var b, f, g = 0, h = a; if (c === 1 && d === 0) return !0; b = a.parentNode; if (b && (b[o] !== e || !a.sizset)) { for (h = b.firstChild; h; h = h.nextSibling) if (h.nodeType === 1) { h.sizset = ++g; if (h === a) break; } b[o] = e; } return ( (f = a.sizset - d), c === 0 ? 
f === 0 : f % c === 0 && f / c >= 0 ); }; } return function (b) { var c = b; switch (a) { case "only": case "first": while ((c = c.previousSibling)) if (c.nodeType === 1) return !1; if (a === "first") return !0; c = b; case "last": while ((c = c.nextSibling)) if (c.nodeType === 1) return !1; return !0; } }; }, PSEUDO: function (a, b, c, d) { var e, g = f.pseudos[a] || f.pseudos[a.toLowerCase()]; return ( g || $.error("unsupported pseudo: " + a), g[o] ? g(b, c, d) : g.length > 1 ? ((e = [a, a, "", b]), function (a) { return g(a, 0, e); }) : g ); }, }, pseudos: { not: v(function (a, b, c) { var d = j(a.replace(H, "$1"), b, c); return function (a) { return !d(a); }; }), enabled: function (a) { return a.disabled === !1; }, disabled: function (a) { return a.disabled === !0; }, checked: function (a) { var b = a.nodeName.toLowerCase(); return ( (b === "input" && !!a.checked) || (b === "option" && !!a.selected) ); }, selected: function (a) { return ( a.parentNode && a.parentNode.selectedIndex, a.selected === !0 ); }, parent: function (a) { return !f.pseudos.empty(a); }, empty: function (a) { var b; a = a.firstChild; while (a) { if (a.nodeName > "@" || (b = a.nodeType) === 3 || b === 4) return !1; a = a.nextSibling; } return !0; }, contains: v(function (a) { return function (b) { return (b.textContent || b.innerText || g(b)).indexOf(a) > -1; }; }), has: v(function (a) { return function (b) { return $(a, b).length > 0; }; }), header: function (a) { return P.test(a.nodeName); }, text: function (a) { var b, c; return ( a.nodeName.toLowerCase() === "input" && (b = a.type) === "text" && ((c = a.getAttribute("type")) == null || c.toLowerCase() === b) ); }, radio: _("radio"), checkbox: _("checkbox"), file: _("file"), password: _("password"), image: _("image"), submit: ba("submit"), reset: ba("reset"), button: function (a) { var b = a.nodeName.toLowerCase(); return (b === "input" && a.type === "button") || b === "button"; }, input: function (a) { return Q.test(a.nodeName); }, 
focus: function (a) { var b = a.ownerDocument; return ( a === b.activeElement && (!b.hasFocus || b.hasFocus()) && (!!a.type || !!a.href) ); }, active: function (a) { return a === a.ownerDocument.activeElement; }, }, setFilters: { first: function (a, b, c) { return c ? a.slice(1) : [a[0]]; }, last: function (a, b, c) { var d = a.pop(); return c ? a : [d]; }, even: function (a, b, c) { var d = [], e = c ? 1 : 0, f = a.length; for (; e < f; e = e + 2) d.push(a[e]); return d; }, odd: function (a, b, c) { var d = [], e = c ? 0 : 1, f = a.length; for (; e < f; e = e + 2) d.push(a[e]); return d; }, lt: function (a, b, c) { return c ? a.slice(+b) : a.slice(0, +b); }, gt: function (a, b, c) { return c ? a.slice(0, +b + 1) : a.slice(+b + 1); }, eq: function (a, b, c) { var d = a.splice(+b, 1); return c ? a : d; }, }, }), (k = r.compareDocumentPosition ? function (a, b) { return a === b ? ((l = !0), 0) : ( !a.compareDocumentPosition || !b.compareDocumentPosition ? a.compareDocumentPosition : a.compareDocumentPosition(b) & 4 ) ? -1 : 1; } : function (a, b) { if (a === b) return (l = !0), 0; if (a.sourceIndex && b.sourceIndex) return a.sourceIndex - b.sourceIndex; var c, d, e = [], f = [], g = a.parentNode, h = b.parentNode, i = g; if (g === h) return bb(a, b); if (!g) return -1; if (!h) return 1; while (i) e.unshift(i), (i = i.parentNode); i = h; while (i) f.unshift(i), (i = i.parentNode); (c = e.length), (d = f.length); for (var j = 0; j < c && j < d; j++) if (e[j] !== f[j]) return bb(e[j], f[j]); return j === c ? 
bb(a, f[j], -1) : bb(e[j], b, 1); }), [0, 0].sort(k), (m = !l), ($.uniqueSort = function (a) { var b, c = 1; (l = m), a.sort(k); if (l) for (; (b = a[c]); c++) b === a[c - 1] && a.splice(c--, 1); return a; }), ($.error = function (a) { throw new Error("Syntax error, unrecognized expression: " + a); }), (j = $.compile = function (a, b, c) { var d, e, f, g = z[o][a]; if (g && g.context === b) return g; d = bc(a, b, c); for (e = 0, f = d.length; e < f; e++) d[e] = bf(d[e], b, c); return ( (g = z(a, bg(d))), (g.context = b), (g.runs = g.dirruns = 0), g ); }), q.querySelectorAll && (function () { var a, b = bk, c = /'|\\/g, d = /\=[\x20\t\r\n\f]*([^'"\]]*)[\x20\t\r\n\f]*\]/g, e = [], f = [":active"], g = r.matchesSelector || r.mozMatchesSelector || r.webkitMatchesSelector || r.oMatchesSelector || r.msMatchesSelector; T(function (a) { (a.innerHTML = ""), a.querySelectorAll("[selected]").length || e.push( "\\[" + A + "*(?:checked|disabled|ismap|multiple|readonly|selected|value)", ), a.querySelectorAll(":checked").length || e.push(":checked"); }), T(function (a) { (a.innerHTML = "

        "), a.querySelectorAll("[test^='']").length && e.push("[*^$]=" + A + "*(?:\"\"|'')"), (a.innerHTML = ""), a.querySelectorAll(":enabled").length || e.push(":enabled", ":disabled"); }), (e = e.length && new RegExp(e.join("|"))), (bk = function (a, d, f, g, h) { if (!g && !h && (!e || !e.test(a))) if (d.nodeType === 9) try { return u.apply(f, t.call(d.querySelectorAll(a), 0)), f; } catch (i) {} else if ( d.nodeType === 1 && d.nodeName.toLowerCase() !== "object" ) { var j, k, l, m = d.getAttribute("id"), n = m || o, p = (N.test(a) && d.parentNode) || d; m ? (n = n.replace(c, "\\$&")) : d.setAttribute("id", n), (j = bc(a, d, h)), (n = "[id='" + n + "']"); for (k = 0, l = j.length; k < l; k++) j[k] = n + j[k].selector; try { return ( u.apply(f, t.call(p.querySelectorAll(j.join(",")), 0)), f ); } catch (i) { } finally { m || d.removeAttribute("id"); } } return b(a, d, f, g, h); }), g && (T(function (b) { a = g.call(b, "div"); try { g.call(b, "[test!='']:sizzle"), f.push(S.PSEUDO.source, S.POS.source, "!="); } catch (c) {} }), (f = new RegExp(f.join("|"))), ($.matchesSelector = function (b, c) { c = c.replace(d, "='$1']"); if (!h(b) && !f.test(c) && (!e || !e.test(c))) try { var i = g.call(b, c); if (i || a || (b.document && b.document.nodeType !== 11)) return i; } catch (j) {} return $(c, null, null, [b]).length > 0; })); })(), (f.setFilters.nth = f.setFilters.eq), (f.filters = f.pseudos), ($.attr = p.attr), (p.find = $), (p.expr = $.selectors), (p.expr[":"] = p.expr.pseudos), (p.unique = $.uniqueSort), (p.text = $.getText), (p.isXMLDoc = $.isXML), (p.contains = $.contains); })(a); var bc = /Until$/, bd = /^(?:parents|prev(?:Until|All))/, be = /^.[^:#\[\.,]*$/, bf = p.expr.match.needsContext, bg = { children: !0, contents: !0, next: !0, prev: !0 }; p.fn.extend({ find: function (a) { var b, c, d, e, f, g, h = this; if (typeof a != "string") return p(a).filter(function () { for (b = 0, c = h.length; b < c; b++) if (p.contains(h[b], this)) return !0; }); g = 
this.pushStack("", "find", a); for (b = 0, c = this.length; b < c; b++) { (d = g.length), p.find(a, this[b], g); if (b > 0) for (e = d; e < g.length; e++) for (f = 0; f < d; f++) if (g[f] === g[e]) { g.splice(e--, 1); break; } } return g; }, has: function (a) { var b, c = p(a, this), d = c.length; return this.filter(function () { for (b = 0; b < d; b++) if (p.contains(this, c[b])) return !0; }); }, not: function (a) { return this.pushStack(bj(this, a, !1), "not", a); }, filter: function (a) { return this.pushStack(bj(this, a, !0), "filter", a); }, is: function (a) { return ( !!a && (typeof a == "string" ? bf.test(a) ? p(a, this.context).index(this[0]) >= 0 : p.filter(a, this).length > 0 : this.filter(a).length > 0) ); }, closest: function (a, b) { var c, d = 0, e = this.length, f = [], g = bf.test(a) || typeof a != "string" ? p(a, b || this.context) : 0; for (; d < e; d++) { c = this[d]; while (c && c.ownerDocument && c !== b && c.nodeType !== 11) { if (g ? g.index(c) > -1 : p.find.matchesSelector(c, a)) { f.push(c); break; } c = c.parentNode; } } return ( (f = f.length > 1 ? p.unique(f) : f), this.pushStack(f, "closest", a) ); }, index: function (a) { return a ? typeof a == "string" ? p.inArray(this[0], p(a)) : p.inArray(a.jquery ? a[0] : a, this) : this[0] && this[0].parentNode ? this.prevAll().length : -1; }, add: function (a, b) { var c = typeof a == "string" ? p(a, b) : p.makeArray(a && a.nodeType ? [a] : a), d = p.merge(this.get(), c); return this.pushStack(bh(c[0]) || bh(d[0]) ? d : p.unique(d)); }, addBack: function (a) { return this.add(a == null ? this.prevObject : this.prevObject.filter(a)); }, }), (p.fn.andSelf = p.fn.addBack), p.each( { parent: function (a) { var b = a.parentNode; return b && b.nodeType !== 11 ? 
b : null; }, parents: function (a) { return p.dir(a, "parentNode"); }, parentsUntil: function (a, b, c) { return p.dir(a, "parentNode", c); }, next: function (a) { return bi(a, "nextSibling"); }, prev: function (a) { return bi(a, "previousSibling"); }, nextAll: function (a) { return p.dir(a, "nextSibling"); }, prevAll: function (a) { return p.dir(a, "previousSibling"); }, nextUntil: function (a, b, c) { return p.dir(a, "nextSibling", c); }, prevUntil: function (a, b, c) { return p.dir(a, "previousSibling", c); }, siblings: function (a) { return p.sibling((a.parentNode || {}).firstChild, a); }, children: function (a) { return p.sibling(a.firstChild); }, contents: function (a) { return p.nodeName(a, "iframe") ? a.contentDocument || a.contentWindow.document : p.merge([], a.childNodes); }, }, function (a, b) { p.fn[a] = function (c, d) { var e = p.map(this, b, c); return ( bc.test(a) || (d = c), d && typeof d == "string" && (e = p.filter(d, e)), (e = this.length > 1 && !bg[a] ? p.unique(e) : e), this.length > 1 && bd.test(a) && (e = e.reverse()), this.pushStack(e, a, k.call(arguments).join(",")) ); }; }, ), p.extend({ filter: function (a, b, c) { return ( c && (a = ":not(" + a + ")"), b.length === 1 ? p.find.matchesSelector(b[0], a) ? 
[b[0]] : [] : p.find.matches(a, b) ); }, dir: function (a, c, d) { var e = [], f = a[c]; while ( f && f.nodeType !== 9 && (d === b || f.nodeType !== 1 || !p(f).is(d)) ) f.nodeType === 1 && e.push(f), (f = f[c]); return e; }, sibling: function (a, b) { var c = []; for (; a; a = a.nextSibling) a.nodeType === 1 && a !== b && c.push(a); return c; }, }); var bl = "abbr|article|aside|audio|bdi|canvas|data|datalist|details|figcaption|figure|footer|header|hgroup|mark|meter|nav|output|progress|section|summary|time|video", bm = / jQuery\d+="(?:null|\d+)"/g, bn = /^\s+/, bo = /<(?!area|br|col|embed|hr|img|input|link|meta|param)(([\w:]+)[^>]*)\/>/gi, bp = /<([\w:]+)/, bq = /]", "i"), bv = /^(?:checkbox|radio)$/, bw = /checked\s*(?:[^=]|=\s*.checked.)/i, bx = /\/(java|ecma)script/i, by = /^\s*\s*$/g, bz = { option: [1, ""], legend: [1, "
        ", "
        "], thead: [1, "", "
        "], tr: [2, "", "
        "], td: [3, "", "
        "], col: [2, "", "
        "], area: [1, "", ""], _default: [0, "", ""], }, bA = bk(e), bB = bA.appendChild(e.createElement("div")); (bz.optgroup = bz.option), (bz.tbody = bz.tfoot = bz.colgroup = bz.caption = bz.thead), (bz.th = bz.td), p.support.htmlSerialize || (bz._default = [1, "X
        ", "
        "]), p.fn.extend({ text: function (a) { return p.access( this, function (a) { return a === b ? p.text(this) : this.empty().append( ((this[0] && this[0].ownerDocument) || e).createTextNode(a), ); }, null, a, arguments.length, ); }, wrapAll: function (a) { if (p.isFunction(a)) return this.each(function (b) { p(this).wrapAll(a.call(this, b)); }); if (this[0]) { var b = p(a, this[0].ownerDocument).eq(0).clone(!0); this[0].parentNode && b.insertBefore(this[0]), b .map(function () { var a = this; while (a.firstChild && a.firstChild.nodeType === 1) a = a.firstChild; return a; }) .append(this); } return this; }, wrapInner: function (a) { return p.isFunction(a) ? this.each(function (b) { p(this).wrapInner(a.call(this, b)); }) : this.each(function () { var b = p(this), c = b.contents(); c.length ? c.wrapAll(a) : b.append(a); }); }, wrap: function (a) { var b = p.isFunction(a); return this.each(function (c) { p(this).wrapAll(b ? a.call(this, c) : a); }); }, unwrap: function () { return this.parent() .each(function () { p.nodeName(this, "body") || p(this).replaceWith(this.childNodes); }) .end(); }, append: function () { return this.domManip(arguments, !0, function (a) { (this.nodeType === 1 || this.nodeType === 11) && this.appendChild(a); }); }, prepend: function () { return this.domManip(arguments, !0, function (a) { (this.nodeType === 1 || this.nodeType === 11) && this.insertBefore(a, this.firstChild); }); }, before: function () { if (!bh(this[0])) return this.domManip(arguments, !1, function (a) { this.parentNode.insertBefore(a, this); }); if (arguments.length) { var a = p.clean(arguments); return this.pushStack(p.merge(a, this), "before", this.selector); } }, after: function () { if (!bh(this[0])) return this.domManip(arguments, !1, function (a) { this.parentNode.insertBefore(a, this.nextSibling); }); if (arguments.length) { var a = p.clean(arguments); return this.pushStack(p.merge(this, a), "after", this.selector); } }, remove: function (a, b) { var c, d = 0; for 
(; (c = this[d]) != null; d++) if (!a || p.filter(a, [c]).length) !b && c.nodeType === 1 && (p.cleanData(c.getElementsByTagName("*")), p.cleanData([c])), c.parentNode && c.parentNode.removeChild(c); return this; }, empty: function () { var a, b = 0; for (; (a = this[b]) != null; b++) { a.nodeType === 1 && p.cleanData(a.getElementsByTagName("*")); while (a.firstChild) a.removeChild(a.firstChild); } return this; }, clone: function (a, b) { return ( (a = a == null ? !1 : a), (b = b == null ? a : b), this.map(function () { return p.clone(this, a, b); }) ); }, html: function (a) { return p.access( this, function (a) { var c = this[0] || {}, d = 0, e = this.length; if (a === b) return c.nodeType === 1 ? c.innerHTML.replace(bm, "") : b; if ( typeof a == "string" && !bs.test(a) && (p.support.htmlSerialize || !bu.test(a)) && (p.support.leadingWhitespace || !bn.test(a)) && !bz[(bp.exec(a) || ["", ""])[1].toLowerCase()] ) { a = a.replace(bo, "<$1></$2>"); try { for (; d < e; d++) (c = this[d] || {}), c.nodeType === 1 && (p.cleanData(c.getElementsByTagName("*")), (c.innerHTML = a)); c = 0; } catch (f) {} } c && this.empty().append(a); }, null, a, arguments.length, ); }, replaceWith: function (a) { return bh(this[0]) ? this.length ? this.pushStack(p(p.isFunction(a) ? a() : a), "replaceWith", a) : this : p.isFunction(a) ? this.each(function (b) { var c = p(this), d = c.html(); c.replaceWith(a.call(this, b, d)); }) : (typeof a != "string" && (a = p(a).detach()), this.each(function () { var b = this.nextSibling, c = this.parentNode; p(this).remove(), b ?
p(b).before(a) : p(c).append(a); })); }, detach: function (a) { return this.remove(a, !0); }, domManip: function (a, c, d) { a = [].concat.apply([], a); var e, f, g, h, i = 0, j = a[0], k = [], l = this.length; if ( !p.support.checkClone && l > 1 && typeof j == "string" && bw.test(j) ) return this.each(function () { p(this).domManip(a, c, d); }); if (p.isFunction(j)) return this.each(function (e) { var f = p(this); (a[0] = j.call(this, e, c ? f.html() : b)), f.domManip(a, c, d); }); if (this[0]) { (e = p.buildFragment(a, this, k)), (g = e.fragment), (f = g.firstChild), g.childNodes.length === 1 && (g = f); if (f) { c = c && p.nodeName(f, "tr"); for (h = e.cacheable || l - 1; i < l; i++) d.call( c && p.nodeName(this[i], "table") ? bC(this[i], "tbody") : this[i], i === h ? g : p.clone(g, !0, !0), ); } (g = f = null), k.length && p.each(k, function (a, b) { b.src ? p.ajax ? p.ajax({ url: b.src, type: "GET", dataType: "script", async: !1, global: !1, throws: !0, }) : p.error("no ajax") : p.globalEval( (b.text || b.textContent || b.innerHTML || "").replace( by, "", ), ), b.parentNode && b.parentNode.removeChild(b); }); } return this; }, }), (p.buildFragment = function (a, c, d) { var f, g, h, i = a[0]; return ( (c = c || e), (c = (!c.nodeType && c[0]) || c), (c = c.ownerDocument || c), a.length === 1 && typeof i == "string" && i.length < 512 && c === e && i.charAt(0) === "<" && !bt.test(i) && (p.support.checkClone || !bw.test(i)) && (p.support.html5Clone || !bu.test(i)) && ((g = !0), (f = p.fragments[i]), (h = f !== b)), f || ((f = c.createDocumentFragment()), p.clean(a, c, f, d), g && (p.fragments[i] = h && f)), { fragment: f, cacheable: g } ); }), (p.fragments = {}), p.each( { appendTo: "append", prependTo: "prepend", insertBefore: "before", insertAfter: "after", replaceAll: "replaceWith", }, function (a, b) { p.fn[a] = function (c) { var d, e = 0, f = [], g = p(c), h = g.length, i = this.length === 1 && this[0].parentNode; if ( (i == null || (i && i.nodeType === 11 
&& i.childNodes.length === 1)) && h === 1 ) return g[b](this[0]), this; for (; e < h; e++) (d = (e > 0 ? this.clone(!0) : this).get()), p(g[e])[b](d), (f = f.concat(d)); return this.pushStack(f, a, g.selector); }; }, ), p.extend({ clone: function (a, b, c) { var d, e, f, g; p.support.html5Clone || p.isXMLDoc(a) || !bu.test("<" + a.nodeName + ">") ? (g = a.cloneNode(!0)) : ((bB.innerHTML = a.outerHTML), bB.removeChild((g = bB.firstChild))); if ( (!p.support.noCloneEvent || !p.support.noCloneChecked) && (a.nodeType === 1 || a.nodeType === 11) && !p.isXMLDoc(a) ) { bE(a, g), (d = bF(a)), (e = bF(g)); for (f = 0; d[f]; ++f) e[f] && bE(d[f], e[f]); } if (b) { bD(a, g); if (c) { (d = bF(a)), (e = bF(g)); for (f = 0; d[f]; ++f) bD(d[f], e[f]); } } return (d = e = null), g; }, clean: function (a, b, c, d) { var f, g, h, i, j, k, l, m, n, o, q, r, s = b === e && bA, t = []; if (!b || typeof b.createDocumentFragment == "undefined") b = e; for (f = 0; (h = a[f]) != null; f++) { typeof h == "number" && (h += ""); if (!h) continue; if (typeof h == "string") if (!br.test(h)) h = b.createTextNode(h); else { (s = s || bk(b)), (l = b.createElement("div")), s.appendChild(l), (h = h.replace(bo, "<$1></$2>")), (i = (bp.exec(h) || ["", ""])[1].toLowerCase()), (j = bz[i] || bz._default), (k = j[0]), (l.innerHTML = j[1] + h + j[2]); while (k--) l = l.lastChild; if (!p.support.tbody) { (m = bq.test(h)), (n = i === "table" && !m ? l.firstChild && l.firstChild.childNodes : j[1] === "" && !m ? l.childNodes : []); for (g = n.length - 1; g >= 0; --g) p.nodeName(n[g], "tbody") && !n[g].childNodes.length && n[g].parentNode.removeChild(n[g]); } !p.support.leadingWhitespace && bn.test(h) && l.insertBefore(b.createTextNode(bn.exec(h)[0]), l.firstChild), (h = l.childNodes), l.parentNode.removeChild(l); } h.nodeType ? t.push(h) : p.merge(t, h); } l && (h = l = s = null); if (!p.support.appendChecked) for (f = 0; (h = t[f]) != null; f++) p.nodeName(h, "input") ?
bG(h) : typeof h.getElementsByTagName != "undefined" && p.grep(h.getElementsByTagName("input"), bG); if (c) { q = function (a) { if (!a.type || bx.test(a.type)) return d ? d.push(a.parentNode ? a.parentNode.removeChild(a) : a) : c.appendChild(a); }; for (f = 0; (h = t[f]) != null; f++) if (!p.nodeName(h, "script") || !q(h)) c.appendChild(h), typeof h.getElementsByTagName != "undefined" && ((r = p.grep( p.merge([], h.getElementsByTagName("script")), q, )), t.splice.apply(t, [f + 1, 0].concat(r)), (f += r.length)); } return t; }, cleanData: function (a, b) { var c, d, e, f, g = 0, h = p.expando, i = p.cache, j = p.support.deleteExpando, k = p.event.special; for (; (e = a[g]) != null; g++) if (b || p.acceptData(e)) { (d = e[h]), (c = d && i[d]); if (c) { if (c.events) for (f in c.events) k[f] ? p.event.remove(e, f) : p.removeEvent(e, f, c.handle); i[d] && (delete i[d], j ? delete e[h] : e.removeAttribute ? e.removeAttribute(h) : (e[h] = null), p.deletedIds.push(d)); } } }, }), (function () { var a, b; (p.uaMatch = function (a) { a = a.toLowerCase(); var b = /(chrome)[ \/]([\w.]+)/.exec(a) || /(webkit)[ \/]([\w.]+)/.exec(a) || /(opera)(?:.*version|)[ \/]([\w.]+)/.exec(a) || /(msie) ([\w.]+)/.exec(a) || (a.indexOf("compatible") < 0 && /(mozilla)(?:.*? rv:([\w.]+)|)/.exec(a)) || []; return { browser: b[1] || "", version: b[2] || "0" }; }), (a = p.uaMatch(g.userAgent)), (b = {}), a.browser && ((b[a.browser] = !0), (b.version = a.version)), b.chrome ? 
(b.webkit = !0) : b.webkit && (b.safari = !0), (p.browser = b), (p.sub = function () { function a(b, c) { return new a.fn.init(b, c); } p.extend(!0, a, this), (a.superclass = this), (a.fn = a.prototype = this()), (a.fn.constructor = a), (a.sub = this.sub), (a.fn.init = function c(c, d) { return ( d && d instanceof p && !(d instanceof a) && (d = a(d)), p.fn.init.call(this, c, d, b) ); }), (a.fn.init.prototype = a.fn); var b = a(e); return a; }); })(); var bH, bI, bJ, bK = /alpha\([^)]*\)/i, bL = /opacity=([^)]*)/, bM = /^(top|right|bottom|left)$/, bN = /^(none|table(?!-c[ea]).+)/, bO = /^margin/, bP = new RegExp("^(" + q + ")(.*)$", "i"), bQ = new RegExp("^(" + q + ")(?!px)[a-z%]+$", "i"), bR = new RegExp("^([-+])=(" + q + ")", "i"), bS = {}, bT = { position: "absolute", visibility: "hidden", display: "block" }, bU = { letterSpacing: 0, fontWeight: 400 }, bV = ["Top", "Right", "Bottom", "Left"], bW = ["Webkit", "O", "Moz", "ms"], bX = p.fn.toggle; p.fn.extend({ css: function (a, c) { return p.access( this, function (a, c, d) { return d !== b ? p.style(a, c, d) : p.css(a, c); }, a, c, arguments.length > 1, ); }, show: function () { return b$(this, !0); }, hide: function () { return b$(this); }, toggle: function (a, b) { var c = typeof a == "boolean"; return p.isFunction(a) && p.isFunction(b) ? bX.apply(this, arguments) : this.each(function () { (c ? a : bZ(this)) ? p(this).show() : p(this).hide(); }); }, }), p.extend({ cssHooks: { opacity: { get: function (a, b) { if (b) { var c = bH(a, "opacity"); return c === "" ? "1" : c; } }, }, }, cssNumber: { fillOpacity: !0, fontWeight: !0, lineHeight: !0, opacity: !0, orphans: !0, widows: !0, zIndex: !0, zoom: !0, }, cssProps: { float: p.support.cssFloat ? 
"cssFloat" : "styleFloat" }, style: function (a, c, d, e) { if (!a || a.nodeType === 3 || a.nodeType === 8 || !a.style) return; var f, g, h, i = p.camelCase(c), j = a.style; (c = p.cssProps[i] || (p.cssProps[i] = bY(j, i))), (h = p.cssHooks[c] || p.cssHooks[i]); if (d === b) return h && "get" in h && (f = h.get(a, !1, e)) !== b ? f : j[c]; (g = typeof d), g === "string" && (f = bR.exec(d)) && ((d = (f[1] + 1) * f[2] + parseFloat(p.css(a, c))), (g = "number")); if (d == null || (g === "number" && isNaN(d))) return; g === "number" && !p.cssNumber[i] && (d += "px"); if (!h || !("set" in h) || (d = h.set(a, d, e)) !== b) try { j[c] = d; } catch (k) {} }, css: function (a, c, d, e) { var f, g, h, i = p.camelCase(c); return ( (c = p.cssProps[i] || (p.cssProps[i] = bY(a.style, i))), (h = p.cssHooks[c] || p.cssHooks[i]), h && "get" in h && (f = h.get(a, !0, e)), f === b && (f = bH(a, c)), f === "normal" && c in bU && (f = bU[c]), d || e !== b ? ((g = parseFloat(f)), d || p.isNumeric(g) ? g || 0 : f) : f ); }, swap: function (a, b, c) { var d, e, f = {}; for (e in b) (f[e] = a.style[e]), (a.style[e] = b[e]); d = c.call(a); for (e in b) a.style[e] = f[e]; return d; }, }), a.getComputedStyle ? (bH = function (b, c) { var d, e, f, g, h = a.getComputedStyle(b, null), i = b.style; return ( h && ((d = h[c]), d === "" && !p.contains(b.ownerDocument, b) && (d = p.style(b, c)), bQ.test(d) && bO.test(c) && ((e = i.width), (f = i.minWidth), (g = i.maxWidth), (i.minWidth = i.maxWidth = i.width = d), (d = h.width), (i.width = e), (i.minWidth = f), (i.maxWidth = g))), d ); }) : e.documentElement.currentStyle && (bH = function (a, b) { var c, d, e = a.currentStyle && a.currentStyle[b], f = a.style; return ( e == null && f && f[b] && (e = f[b]), bQ.test(e) && !bM.test(b) && ((c = f.left), (d = a.runtimeStyle && a.runtimeStyle.left), d && (a.runtimeStyle.left = a.currentStyle.left), (f.left = b === "fontSize" ? 
"1em" : e), (e = f.pixelLeft + "px"), (f.left = c), d && (a.runtimeStyle.left = d)), e === "" ? "auto" : e ); }), p.each(["height", "width"], function (a, b) { p.cssHooks[b] = { get: function (a, c, d) { if (c) return a.offsetWidth === 0 && bN.test(bH(a, "display")) ? p.swap(a, bT, function () { return cb(a, b, d); }) : cb(a, b, d); }, set: function (a, c, d) { return b_( a, c, d ? ca( a, b, d, p.support.boxSizing && p.css(a, "boxSizing") === "border-box", ) : 0, ); }, }; }), p.support.opacity || (p.cssHooks.opacity = { get: function (a, b) { return bL.test( (b && a.currentStyle ? a.currentStyle.filter : a.style.filter) || "", ) ? 0.01 * parseFloat(RegExp.$1) + "" : b ? "1" : ""; }, set: function (a, b) { var c = a.style, d = a.currentStyle, e = p.isNumeric(b) ? "alpha(opacity=" + b * 100 + ")" : "", f = (d && d.filter) || c.filter || ""; c.zoom = 1; if (b >= 1 && p.trim(f.replace(bK, "")) === "" && c.removeAttribute) { c.removeAttribute("filter"); if (d && !d.filter) return; } c.filter = bK.test(f) ? f.replace(bK, e) : f + " " + e; }, }), p(function () { p.support.reliableMarginRight || (p.cssHooks.marginRight = { get: function (a, b) { return p.swap(a, { display: "inline-block" }, function () { if (b) return bH(a, "marginRight"); }); }, }), !p.support.pixelPosition && p.fn.position && p.each(["top", "left"], function (a, b) { p.cssHooks[b] = { get: function (a, c) { if (c) { var d = bH(a, b); return bQ.test(d) ? p(a).position()[b] + "px" : d; } }, }; }); }), p.expr && p.expr.filters && ((p.expr.filters.hidden = function (a) { return ( (a.offsetWidth === 0 && a.offsetHeight === 0) || (!p.support.reliableHiddenOffsets && ((a.style && a.style.display) || bH(a, "display")) === "none") ); }), (p.expr.filters.visible = function (a) { return !p.expr.filters.hidden(a); })), p.each({ margin: "", padding: "", border: "Width" }, function (a, b) { (p.cssHooks[a + b] = { expand: function (c) { var d, e = typeof c == "string" ? 
c.split(" ") : [c], f = {}; for (d = 0; d < 4; d++) f[a + bV[d] + b] = e[d] || e[d - 2] || e[0]; return f; }, }), bO.test(a) || (p.cssHooks[a + b].set = b_); }); var cd = /%20/g, ce = /\[\]$/, cf = /\r?\n/g, cg = /^(?:color|date|datetime|datetime-local|email|hidden|month|number|password|range|search|tel|text|time|url|week)$/i, ch = /^(?:select|textarea)/i; p.fn.extend({ serialize: function () { return p.param(this.serializeArray()); }, serializeArray: function () { return this.map(function () { return this.elements ? p.makeArray(this.elements) : this; }) .filter(function () { return ( this.name && !this.disabled && (this.checked || ch.test(this.nodeName) || cg.test(this.type)) ); }) .map(function (a, b) { var c = p(this).val(); return c == null ? null : p.isArray(c) ? p.map(c, function (a, c) { return { name: b.name, value: a.replace(cf, "\r\n") }; }) : { name: b.name, value: c.replace(cf, "\r\n") }; }) .get(); }, }), (p.param = function (a, c) { var d, e = [], f = function (a, b) { (b = p.isFunction(b) ? b() : b == null ?
"" : b), (e[e.length] = encodeURIComponent(a) + "=" + encodeURIComponent(b)); }; c === b && (c = p.ajaxSettings && p.ajaxSettings.traditional); if (p.isArray(a) || (a.jquery && !p.isPlainObject(a))) p.each(a, function () { f(this.name, this.value); }); else for (d in a) ci(d, a[d], c, f); return e.join("&").replace(cd, "+"); }); var cj, ck, cl = /#.*$/, cm = /^(.*?):[ \t]*([^\r\n]*)\r?$/gm, cn = /^(?:about|app|app\-storage|.+\-extension|file|res|widget):$/, co = /^(?:GET|HEAD)$/, cp = /^\/\//, cq = /\?/, cr = /<script\b[^<]*(?:(?!<\/script>)<[^<]*)*<\/script>/gi, cs = /([?&])_=[^&]*/, ct = /^([\w\+\.\-]+:)(?:\/\/([^\/?#:]*)(?::(\d+)|)|)/, cu = p.fn.load, cv = {}, cw = {}, cx = ["*/"] + ["*"]; try { cj = f.href; } catch (cy) { (cj = e.createElement("a")), (cj.href = ""), (cj = cj.href); } (ck = ct.exec(cj.toLowerCase()) || []), (p.fn.load = function (a, c, d) { if (typeof a != "string" && cu) return cu.apply(this, arguments); if (!this.length) return this; var e, f, g, h = this, i = a.indexOf(" "); return ( i >= 0 && ((e = a.slice(i, a.length)), (a = a.slice(0, i))), p.isFunction(c) ? ((d = c), (c = b)) : c && typeof c == "object" && (f = "POST"), p .ajax({ url: a, type: f, dataType: "html", data: c, complete: function (a, b) { d && h.each(d, g || [a.responseText, b, a]); }, }) .done(function (a) { (g = arguments), h.html(e ? p("<div>")
.append(a.replace(cr, "")).find(e) : a); }), this ); }), p.each( "ajaxStart ajaxStop ajaxComplete ajaxError ajaxSuccess ajaxSend".split( " ", ), function (a, b) { p.fn[b] = function (a) { return this.on(b, a); }; }, ), p.each(["get", "post"], function (a, c) { p[c] = function (a, d, e, f) { return ( p.isFunction(d) && ((f = f || e), (e = d), (d = b)), p.ajax({ type: c, url: a, data: d, success: e, dataType: f }) ); }; }), p.extend({ getScript: function (a, c) { return p.get(a, b, c, "script"); }, getJSON: function (a, b, c) { return p.get(a, b, c, "json"); }, ajaxSetup: function (a, b) { return ( b ? cB(a, p.ajaxSettings) : ((b = a), (a = p.ajaxSettings)), cB(a, b), a ); }, ajaxSettings: { url: cj, isLocal: cn.test(ck[1]), global: !0, type: "GET", contentType: "application/x-www-form-urlencoded; charset=UTF-8", processData: !0, async: !0, accepts: { xml: "application/xml, text/xml", html: "text/html", text: "text/plain", json: "application/json, text/javascript", "*": cx, }, contents: { xml: /xml/, html: /html/, json: /json/ }, responseFields: { xml: "responseXML", text: "responseText" }, converters: { "* text": a.String, "text html": !0, "text json": p.parseJSON, "text xml": p.parseXML, }, flatOptions: { context: !0, url: !0 }, }, ajaxPrefilter: cz(cv), ajaxTransport: cz(cw), ajax: function (a, c) { function y(a, c, f, i) { var k, s, t, u, w, y = c; if (v === 2) return; (v = 2), h && clearTimeout(h), (g = b), (e = i || ""), (x.readyState = a > 0 ? 4 : 0), f && (u = cC(l, x, f)); if ((a >= 200 && a < 300) || a === 304) l.ifModified && ((w = x.getResponseHeader("Last-Modified")), w && (p.lastModified[d] = w), (w = x.getResponseHeader("Etag")), w && (p.etag[d] = w)), a === 304 ? ((y = "notmodified"), (k = !0)) : ((k = cD(l, u)), (y = k.state), (s = k.data), (t = k.error), (k = !t)); else { t = y; if (!y || a) (y = "error"), a < 0 && (a = 0); } (x.status = a), (x.statusText = "" + (c || y)), k ?
o.resolveWith(m, [s, y, x]) : o.rejectWith(m, [x, y, t]), x.statusCode(r), (r = b), j && n.trigger("ajax" + (k ? "Success" : "Error"), [x, l, k ? s : t]), q.fireWith(m, [x, y]), j && (n.trigger("ajaxComplete", [x, l]), --p.active || p.event.trigger("ajaxStop")); } typeof a == "object" && ((c = a), (a = b)), (c = c || {}); var d, e, f, g, h, i, j, k, l = p.ajaxSetup({}, c), m = l.context || l, n = m !== l && (m.nodeType || m instanceof p) ? p(m) : p.event, o = p.Deferred(), q = p.Callbacks("once memory"), r = l.statusCode || {}, t = {}, u = {}, v = 0, w = "canceled", x = { readyState: 0, setRequestHeader: function (a, b) { if (!v) { var c = a.toLowerCase(); (a = u[c] = u[c] || a), (t[a] = b); } return this; }, getAllResponseHeaders: function () { return v === 2 ? e : null; }, getResponseHeader: function (a) { var c; if (v === 2) { if (!f) { f = {}; while ((c = cm.exec(e))) f[c[1].toLowerCase()] = c[2]; } c = f[a.toLowerCase()]; } return c === b ? null : c; }, overrideMimeType: function (a) { return v || (l.mimeType = a), this; }, abort: function (a) { return (a = a || w), g && g.abort(a), y(0, a), this; }, }; o.promise(x), (x.success = x.done), (x.error = x.fail), (x.complete = q.add), (x.statusCode = function (a) { if (a) { var b; if (v < 2) for (b in a) r[b] = [r[b], a[b]]; else (b = a[x.status]), x.always(b); } return this; }), (l.url = ((a || l.url) + "") .replace(cl, "") .replace(cp, ck[1] + "//")), (l.dataTypes = p .trim(l.dataType || "*") .toLowerCase() .split(s)), l.crossDomain == null && ((i = ct.exec(l.url.toLowerCase())), (l.crossDomain = !( !i || (i[1] == ck[1] && i[2] == ck[2] && (i[3] || (i[1] === "http:" ? 80 : 443)) == (ck[3] || (ck[1] === "http:" ? 
80 : 443))) ))), l.data && l.processData && typeof l.data != "string" && (l.data = p.param(l.data, l.traditional)), cA(cv, l, c, x); if (v === 2) return x; (j = l.global), (l.type = l.type.toUpperCase()), (l.hasContent = !co.test(l.type)), j && p.active++ === 0 && p.event.trigger("ajaxStart"); if (!l.hasContent) { l.data && ((l.url += (cq.test(l.url) ? "&" : "?") + l.data), delete l.data), (d = l.url); if (l.cache === !1) { var z = p.now(), A = l.url.replace(cs, "$1_=" + z); l.url = A + (A === l.url ? (cq.test(l.url) ? "&" : "?") + "_=" + z : ""); } } ((l.data && l.hasContent && l.contentType !== !1) || c.contentType) && x.setRequestHeader("Content-Type", l.contentType), l.ifModified && ((d = d || l.url), p.lastModified[d] && x.setRequestHeader("If-Modified-Since", p.lastModified[d]), p.etag[d] && x.setRequestHeader("If-None-Match", p.etag[d])), x.setRequestHeader( "Accept", l.dataTypes[0] && l.accepts[l.dataTypes[0]] ? l.accepts[l.dataTypes[0]] + (l.dataTypes[0] !== "*" ? ", " + cx + "; q=0.01" : "") : l.accepts["*"], ); for (k in l.headers) x.setRequestHeader(k, l.headers[k]); if (!l.beforeSend || (l.beforeSend.call(m, x, l) !== !1 && v !== 2)) { w = "abort"; for (k in { success: 1, error: 1, complete: 1 }) x[k](l[k]); g = cA(cw, l, c, x); if (!g) y(-1, "No Transport"); else { (x.readyState = 1), j && n.trigger("ajaxSend", [x, l]), l.async && l.timeout > 0 && (h = setTimeout(function () { x.abort("timeout"); }, l.timeout)); try { (v = 1), g.send(t, y); } catch (B) { if (v < 2) y(-1, B); else throw B; } } return x; } return x.abort(); }, active: 0, lastModified: {}, etag: {}, }); var cE = [], cF = /\?/, cG = /(=)\?(?=&|$)|\?\?/, cH = p.now(); p.ajaxSetup({ jsonp: "callback", jsonpCallback: function () { var a = cE.pop() || p.expando + "_" + cH++; return (this[a] = !0), a; }, }), p.ajaxPrefilter("json jsonp", function (c, d, e) { var f, g, h, i = c.data, j = c.url, k = c.jsonp !== !1, l = k && cG.test(j), m = k && !l && typeof i == "string" && !(c.contentType || 
"").indexOf("application/x-www-form-urlencoded") && cG.test(i); if (c.dataTypes[0] === "jsonp" || l || m) return ( (f = c.jsonpCallback = p.isFunction(c.jsonpCallback) ? c.jsonpCallback() : c.jsonpCallback), (g = a[f]), l ? (c.url = j.replace(cG, "$1" + f)) : m ? (c.data = i.replace(cG, "$1" + f)) : k && (c.url += (cF.test(j) ? "&" : "?") + c.jsonp + "=" + f), (c.converters["script json"] = function () { return h || p.error(f + " was not called"), h[0]; }), (c.dataTypes[0] = "json"), (a[f] = function () { h = arguments; }), e.always(function () { (a[f] = g), c[f] && ((c.jsonpCallback = d.jsonpCallback), cE.push(f)), h && p.isFunction(g) && g(h[0]), (h = g = b); }), "script" ); }), p.ajaxSetup({ accepts: { script: "text/javascript, application/javascript, application/ecmascript, application/x-ecmascript", }, contents: { script: /javascript|ecmascript/ }, converters: { "text script": function (a) { return p.globalEval(a), a; }, }, }), p.ajaxPrefilter("script", function (a) { a.cache === b && (a.cache = !1), a.crossDomain && ((a.type = "GET"), (a.global = !1)); }), p.ajaxTransport("script", function (a) { if (a.crossDomain) { var c, d = e.head || e.getElementsByTagName("head")[0] || e.documentElement; return { send: function (f, g) { (c = e.createElement("script")), (c.async = "async"), a.scriptCharset && (c.charset = a.scriptCharset), (c.src = a.url), (c.onload = c.onreadystatechange = function (a, e) { if ( e || !c.readyState || /loaded|complete/.test(c.readyState) ) (c.onload = c.onreadystatechange = null), d && c.parentNode && d.removeChild(c), (c = b), e || g(200, "success"); }), d.insertBefore(c, d.firstChild); }, abort: function () { c && c.onload(0, 1); }, }; } }); var cI, cJ = a.ActiveXObject ? function () { for (var a in cI) cI[a](0, 1); } : !1, cK = 0; (p.ajaxSettings.xhr = a.ActiveXObject ? 
function () { return (!this.isLocal && cL()) || cM(); } : cL), (function (a) { p.extend(p.support, { ajax: !!a, cors: !!a && "withCredentials" in a }); })(p.ajaxSettings.xhr()), p.support.ajax && p.ajaxTransport(function (c) { if (!c.crossDomain || p.support.cors) { var d; return { send: function (e, f) { var g, h, i = c.xhr(); c.username ? i.open(c.type, c.url, c.async, c.username, c.password) : i.open(c.type, c.url, c.async); if (c.xhrFields) for (h in c.xhrFields) i[h] = c.xhrFields[h]; c.mimeType && i.overrideMimeType && i.overrideMimeType(c.mimeType), !c.crossDomain && !e["X-Requested-With"] && (e["X-Requested-With"] = "XMLHttpRequest"); try { for (h in e) i.setRequestHeader(h, e[h]); } catch (j) {} i.send((c.hasContent && c.data) || null), (d = function (a, e) { var h, j, k, l, m; try { if (d && (e || i.readyState === 4)) { (d = b), g && ((i.onreadystatechange = p.noop), cJ && delete cI[g]); if (e) i.readyState !== 4 && i.abort(); else { (h = i.status), (k = i.getAllResponseHeaders()), (l = {}), (m = i.responseXML), m && m.documentElement && (l.xml = m); try { l.text = i.responseText; } catch (a) {} try { j = i.statusText; } catch (n) { j = ""; } !h && c.isLocal && !c.crossDomain ? (h = l.text ? 200 : 404) : h === 1223 && (h = 204); } } } catch (o) { e || f(-1, o); } l && f(h, j, l, k); }), c.async ? i.readyState === 4 ? setTimeout(d, 0) : ((g = ++cK), cJ && (cI || ((cI = {}), p(a).unload(cJ)), (cI[g] = d)), (i.onreadystatechange = d)) : d(); }, abort: function () { d && d(0, 1); }, }; } }); var cN, cO, cP = /^(?:toggle|show|hide)$/, cQ = new RegExp("^(?:([-+])=|)(" + q + ")([a-z%]*)$", "i"), cR = /queueHooks$/, cS = [cY], cT = { "*": [ function (a, b) { var c, d, e, f = this.createTween(a, b), g = cQ.exec(b), h = f.cur(), i = +h || 0, j = 1; if (g) { (c = +g[2]), (d = g[3] || (p.cssNumber[a] ? 
"" : "px")); if (d !== "px" && i) { i = p.css(f.elem, a, !0) || c || 1; do (e = j = j || ".5"), (i = i / j), p.style(f.elem, a, i + d), (j = f.cur() / h); while (j !== 1 && j !== e); } (f.unit = d), (f.start = i), (f.end = g[1] ? i + (g[1] + 1) * c : c); } return f; }, ], }; (p.Animation = p.extend(cW, { tweener: function (a, b) { p.isFunction(a) ? ((b = a), (a = ["*"])) : (a = a.split(" ")); var c, d = 0, e = a.length; for (; d < e; d++) (c = a[d]), (cT[c] = cT[c] || []), cT[c].unshift(b); }, prefilter: function (a, b) { b ? cS.unshift(a) : cS.push(a); }, })), (p.Tween = cZ), (cZ.prototype = { constructor: cZ, init: function (a, b, c, d, e, f) { (this.elem = a), (this.prop = c), (this.easing = e || "swing"), (this.options = b), (this.start = this.now = this.cur()), (this.end = d), (this.unit = f || (p.cssNumber[c] ? "" : "px")); }, cur: function () { var a = cZ.propHooks[this.prop]; return a && a.get ? a.get(this) : cZ.propHooks._default.get(this); }, run: function (a) { var b, c = cZ.propHooks[this.prop]; return ( this.options.duration ? (this.pos = b = p.easing[this.easing]( a, this.options.duration * a, 0, 1, this.options.duration, )) : (this.pos = b = a), (this.now = (this.end - this.start) * b + this.start), this.options.step && this.options.step.call(this.elem, this.now, this), c && c.set ? c.set(this) : cZ.propHooks._default.set(this), this ); }, }), (cZ.prototype.init.prototype = cZ.prototype), (cZ.propHooks = { _default: { get: function (a) { var b; return a.elem[a.prop] == null || (!!a.elem.style && a.elem.style[a.prop] != null) ? ((b = p.css(a.elem, a.prop, !1, "")), !b || b === "auto" ? 0 : b) : a.elem[a.prop]; }, set: function (a) { p.fx.step[a.prop] ? p.fx.step[a.prop](a) : a.elem.style && (a.elem.style[p.cssProps[a.prop]] != null || p.cssHooks[a.prop]) ? 
p.style(a.elem, a.prop, a.now + a.unit) : (a.elem[a.prop] = a.now); }, }, }), (cZ.propHooks.scrollTop = cZ.propHooks.scrollLeft = { set: function (a) { a.elem.nodeType && a.elem.parentNode && (a.elem[a.prop] = a.now); }, }), p.each(["toggle", "show", "hide"], function (a, b) { var c = p.fn[b]; p.fn[b] = function (d, e, f) { return d == null || typeof d == "boolean" || (!a && p.isFunction(d) && p.isFunction(e)) ? c.apply(this, arguments) : this.animate(c$(b, !0), d, e, f); }; }), p.fn.extend({ fadeTo: function (a, b, c, d) { return this.filter(bZ) .css("opacity", 0) .show() .end() .animate({ opacity: b }, a, c, d); }, animate: function (a, b, c, d) { var e = p.isEmptyObject(a), f = p.speed(b, c, d), g = function () { var b = cW(this, p.extend({}, a), f); e && b.stop(!0); }; return e || f.queue === !1 ? this.each(g) : this.queue(f.queue, g); }, stop: function (a, c, d) { var e = function (a) { var b = a.stop; delete a.stop, b(d); }; return ( typeof a != "string" && ((d = c), (c = a), (a = b)), c && a !== !1 && this.queue(a || "fx", []), this.each(function () { var b = !0, c = a != null && a + "queueHooks", f = p.timers, g = p._data(this); if (c) g[c] && g[c].stop && e(g[c]); else for (c in g) g[c] && g[c].stop && cR.test(c) && e(g[c]); for (c = f.length; c--; ) f[c].elem === this && (a == null || f[c].queue === a) && (f[c].anim.stop(d), (b = !1), f.splice(c, 1)); (b || !d) && p.dequeue(this, a); }) ); }, }), p.each( { slideDown: c$("show"), slideUp: c$("hide"), slideToggle: c$("toggle"), fadeIn: { opacity: "show" }, fadeOut: { opacity: "hide" }, fadeToggle: { opacity: "toggle" }, }, function (a, b) { p.fn[a] = function (a, c, d) { return this.animate(b, a, c, d); }; }, ), (p.speed = function (a, b, c) { var d = a && typeof a == "object" ? p.extend({}, a) : { complete: c || (!c && b) || (p.isFunction(a) && a), duration: a, easing: (c && b) || (b && !p.isFunction(b) && b), }; d.duration = p.fx.off ? 0 : typeof d.duration == "number" ? 
d.duration : d.duration in p.fx.speeds ? p.fx.speeds[d.duration] : p.fx.speeds._default; if (d.queue == null || d.queue === !0) d.queue = "fx"; return ( (d.old = d.complete), (d.complete = function () { p.isFunction(d.old) && d.old.call(this), d.queue && p.dequeue(this, d.queue); }), d ); }), (p.easing = { linear: function (a) { return a; }, swing: function (a) { return 0.5 - Math.cos(a * Math.PI) / 2; }, }), (p.timers = []), (p.fx = cZ.prototype.init), (p.fx.tick = function () { var a, b = p.timers, c = 0; for (; c < b.length; c++) (a = b[c]), !a() && b[c] === a && b.splice(c--, 1); b.length || p.fx.stop(); }), (p.fx.timer = function (a) { a() && p.timers.push(a) && !cO && (cO = setInterval(p.fx.tick, p.fx.interval)); }), (p.fx.interval = 13), (p.fx.stop = function () { clearInterval(cO), (cO = null); }), (p.fx.speeds = { slow: 600, fast: 200, _default: 400 }), (p.fx.step = {}), p.expr && p.expr.filters && (p.expr.filters.animated = function (a) { return p.grep(p.timers, function (b) { return a === b.elem; }).length; }); var c_ = /^(?:body|html)$/i; (p.fn.offset = function (a) { if (arguments.length) return a === b ? this : this.each(function (b) { p.offset.setOffset(this, a, b); }); var c, d, e, f, g, h, i, j, k, l, m = this[0], n = m && m.ownerDocument; if (!n) return; return (e = n.body) === m ? p.offset.bodyOffset(m) : ((d = n.documentElement), p.contains(d, m) ? 
((c = m.getBoundingClientRect()), (f = da(n)), (g = d.clientTop || e.clientTop || 0), (h = d.clientLeft || e.clientLeft || 0), (i = f.pageYOffset || d.scrollTop), (j = f.pageXOffset || d.scrollLeft), (k = c.top + i - g), (l = c.left + j - h), { top: k, left: l }) : { top: 0, left: 0 }); }), (p.offset = { bodyOffset: function (a) { var b = a.offsetTop, c = a.offsetLeft; return ( p.support.doesNotIncludeMarginInBodyOffset && ((b += parseFloat(p.css(a, "marginTop")) || 0), (c += parseFloat(p.css(a, "marginLeft")) || 0)), { top: b, left: c } ); }, setOffset: function (a, b, c) { var d = p.css(a, "position"); d === "static" && (a.style.position = "relative"); var e = p(a), f = e.offset(), g = p.css(a, "top"), h = p.css(a, "left"), i = (d === "absolute" || d === "fixed") && p.inArray("auto", [g, h]) > -1, j = {}, k = {}, l, m; i ? ((k = e.position()), (l = k.top), (m = k.left)) : ((l = parseFloat(g) || 0), (m = parseFloat(h) || 0)), p.isFunction(b) && (b = b.call(a, c, f)), b.top != null && (j.top = b.top - f.top + l), b.left != null && (j.left = b.left - f.left + m), "using" in b ? b.using.call(a, j) : e.css(j); }, }), p.fn.extend({ position: function () { if (!this[0]) return; var a = this[0], b = this.offsetParent(), c = this.offset(), d = c_.test(b[0].nodeName) ? 
{ top: 0, left: 0 } : b.offset(); return ( (c.top -= parseFloat(p.css(a, "marginTop")) || 0), (c.left -= parseFloat(p.css(a, "marginLeft")) || 0), (d.top += parseFloat(p.css(b[0], "borderTopWidth")) || 0), (d.left += parseFloat(p.css(b[0], "borderLeftWidth")) || 0), { top: c.top - d.top, left: c.left - d.left } ); }, offsetParent: function () { return this.map(function () { var a = this.offsetParent || e.body; while (a && !c_.test(a.nodeName) && p.css(a, "position") === "static") a = a.offsetParent; return a || e.body; }); }, }), p.each( { scrollLeft: "pageXOffset", scrollTop: "pageYOffset" }, function (a, c) { var d = /Y/.test(c); p.fn[a] = function (e) { return p.access( this, function (a, e, f) { var g = da(a); if (f === b) return g ? c in g ? g[c] : g.document.documentElement[e] : a[e]; g ? g.scrollTo( d ? p(g).scrollLeft() : f, d ? f : p(g).scrollTop(), ) : (a[e] = f); }, a, e, arguments.length, null, ); }; }, ), p.each({ Height: "height", Width: "width" }, function (a, c) { p.each( { padding: "inner" + a, content: c, "": "outer" + a }, function (d, e) { p.fn[e] = function (e, f) { var g = arguments.length && (d || typeof e != "boolean"), h = d || (e === !0 || f === !0 ? "margin" : "border"); return p.access( this, function (c, d, e) { var f; return p.isWindow(c) ? c.document.documentElement["client" + a] : c.nodeType === 9 ? ((f = c.documentElement), Math.max( c.body["scroll" + a], f["scroll" + a], c.body["offset" + a], f["offset" + a], f["client" + a], )) : e === b ? p.css(c, d, e, h) : p.style(c, d, e, h); }, c, g ? e : b, g, null, ); }; }, ); }), (a.jQuery = a.$ = p), typeof define == "function" && define.amd && define.amd.jQuery && define("jquery", [], function () { return p; }); })(window); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/portfolio/jquery.quicksand.js ================================================ /* Quicksand 1.2.2 Reorder and filter items with a nice shuffling animation. 
Copyright (c) 2010 Jacek Galanciak (razorjack.net) and agilope.com Big thanks for Piotr Petrus (riddle.pl) for deep code review and wonderful docs & demos. Dual licensed under the MIT and GPL version 2 licenses. http://github.com/jquery/jquery/blob/master/MIT-LICENSE.txt http://github.com/jquery/jquery/blob/master/GPL-LICENSE.txt Project site: http://razorjack.net/quicksand Github site: http://github.com/razorjack/quicksand */ (function ($) { $.fn.quicksand = function (collection, customOptions) { var options = { duration: 750, easing: "swing", attribute: "data-id", // attribute to recognize same items within source and dest adjustHeight: "auto", // 'dynamic' animates height during shuffling (slow), 'auto' adjusts it before or after the animation, false leaves height constant useScaling: true, // disable it if you're not using scaling effect or want to improve performance enhancement: function (c) {}, // Visual enhancement (e.g. font replacement) function for cloned elements selector: "> *", dx: 0, dy: 0, }; $.extend(options, customOptions); if ($.browser.msie || typeof $.fn.scale == "undefined") { // Got IE and want scaling effect? Kiss my ass.
options.useScaling = false; } var callbackFunction; if (typeof arguments[1] == "function") { callbackFunction = arguments[1]; } else if (typeof arguments[2] == "function") { callbackFunction = arguments[2]; } return this.each(function (i) { var val; var animationQueue = []; // used to store all the animation params before starting the animation; solves initial animation slowdowns var $collection = $(collection).clone(); // destination (target) collection var $sourceParent = $(this); // source, the visible container of source collection var sourceHeight = $(this).css("height"); // used to keep height and document flow during the animation var destHeight; var adjustHeightOnCallback = false; var offset = $($sourceParent).offset(); // offset of visible container, used in animation calculations var offsets = []; // coordinates of every source collection item var $source = $(this).find(options.selector); // source collection items // Replace the collection and quit if IE6 if ($.browser.msie && $.browser.version.substr(0, 1) < 7) { $sourceParent.html("").append($collection); return; } // Gets called when any animation is finished var postCallbackPerformed = 0; // prevents the function from being called more than once var postCallback = function () { if (!postCallbackPerformed) { postCallbackPerformed = 1; // hack: // used to be: $sourceParent.html($dest.html()); // put target HTML into visible source container // but new webkit builds cause flickering when replacing the collections var $toDelete = $sourceParent.find("> *"); $sourceParent.prepend($dest.find("> *")); $toDelete.remove(); if (adjustHeightOnCallback) { $sourceParent.css("height", destHeight); } options.enhancement($sourceParent); // Perform custom visual enhancements on a newly replaced collection if (typeof callbackFunction == "function") { callbackFunction.call(this); } } }; // Position: relative situations var $correctionParent = $sourceParent.offsetParent(); var correctionOffset =
$correctionParent.offset(); if ($correctionParent.css("position") == "relative") { if ($correctionParent.get(0).nodeName.toLowerCase() == "body") { } else { correctionOffset.top += parseFloat($correctionParent.css("border-top-width")) || 0; correctionOffset.left += parseFloat($correctionParent.css("border-left-width")) || 0; } } else { correctionOffset.top -= parseFloat($correctionParent.css("border-top-width")) || 0; correctionOffset.left -= parseFloat($correctionParent.css("border-left-width")) || 0; correctionOffset.top -= parseFloat($correctionParent.css("margin-top")) || 0; correctionOffset.left -= parseFloat($correctionParent.css("margin-left")) || 0; } // perform custom corrections from options (use when Quicksand fails to detect proper correction) if (isNaN(correctionOffset.left)) { correctionOffset.left = 0; } if (isNaN(correctionOffset.top)) { correctionOffset.top = 0; } correctionOffset.left -= options.dx; correctionOffset.top -= options.dy; // keeps nodes after source container, holding their position $sourceParent.css("height", $(this).height()); // get positions of source collections $source.each(function (i) { offsets[i] = $(this).offset(); }); // stops previous animations on source container $(this).stop(); var dx = 0; var dy = 0; $source.each(function (i) { $(this).stop(); // stop animation of collection items var rawObj = $(this).get(0); if (rawObj.style.position == "absolute") { dx = -options.dx; dy = -options.dy; } else { dx = options.dx; dy = options.dy; } rawObj.style.position = "absolute"; rawObj.style.margin = "0"; rawObj.style.top = offsets[i].top - parseFloat(rawObj.style.marginTop) - correctionOffset.top + dy + "px"; rawObj.style.left = offsets[i].left - parseFloat(rawObj.style.marginLeft) - correctionOffset.left + dx + "px"; }); // create temporary container with destination collection var $dest = $($sourceParent).clone(); var rawDest = $dest.get(0); rawDest.innerHTML = ""; rawDest.setAttribute("id", ""); rawDest.style.height = "auto"; 
rawDest.style.width = $sourceParent.width() + "px"; $dest.append($collection); // insert node into HTML // Note that the node is placed under the visible source container in exactly the same position // The browser renders all the items without showing them (opacity: 0.0) // No offset calculations are needed, the browser just extracts position from underlayered destination items // and sets animation to destination positions. $dest.insertBefore($sourceParent); $dest.css("opacity", 0.0); rawDest.style.zIndex = -1; rawDest.style.margin = "0"; rawDest.style.position = "absolute"; rawDest.style.top = offset.top - correctionOffset.top + "px"; rawDest.style.left = offset.left - correctionOffset.left + "px"; if (options.adjustHeight === "dynamic") { // If the destination container has a different height than the source container // the height can be animated, adjusting it to the destination height $sourceParent.animate( { height: $dest.height() }, options.duration, options.easing, ); } else if (options.adjustHeight === "auto") { destHeight = $dest.height(); if (parseFloat(sourceHeight) < parseFloat(destHeight)) { // Adjust the height now so that the items don't move out of the container $sourceParent.css("height", destHeight); } else { // Adjust later, on callback adjustHeightOnCallback = true; } } // Now it's time to do the shuffling animation // First of all, we need to identify the same elements within source and destination collections $source.each(function (i) { var destElement = []; if (typeof options.attribute == "function") { val = options.attribute($(this)); $collection.each(function () { if (options.attribute(this) == val) { destElement = $(this); return false; } }); } else { destElement = $collection.filter( "[" + options.attribute + "=" + $(this).attr(options.attribute) + "]", ); } if (destElement.length) { // The item is in both the source and destination collections // If it's in a different position, let's move it if (!options.useScaling) { animationQueue.push({ element: $(this), animation: {
top: destElement.offset().top - correctionOffset.top, left: destElement.offset().left - correctionOffset.left, opacity: 1.0, }, }); } else { animationQueue.push({ element: $(this), animation: { top: destElement.offset().top - correctionOffset.top, left: destElement.offset().left - correctionOffset.left, opacity: 1.0, scale: "1.0", }, }); } } else { // The item from source collection is not present in destination collections // Let's remove it if (!options.useScaling) { animationQueue.push({ element: $(this), animation: { opacity: "0.0" }, }); } else { animationQueue.push({ element: $(this), animation: { opacity: "0.0", scale: "0.0" }, }); } } }); $collection.each(function (i) { // Grab all items from target collection not present in visible source collection var sourceElement = []; var destElement = []; if (typeof options.attribute == "function") { val = options.attribute($(this)); $source.each(function () { if (options.attribute(this) == val) { sourceElement = $(this); return false; } }); $collection.each(function () { if (options.attribute(this) == val) { destElement = $(this); return false; } }); } else { sourceElement = $source.filter( "[" + options.attribute + "=" + $(this).attr(options.attribute) + "]", ); destElement = $collection.filter( "[" + options.attribute + "=" + $(this).attr(options.attribute) + "]", ); } var animationOptions; if (sourceElement.length === 0) { // No such element in source collection... 
if (!options.useScaling) { animationOptions = { opacity: "1.0", }; } else { animationOptions = { opacity: "1.0", scale: "1.0", }; } // Let's create it d = destElement.clone(); var rawDestElement = d.get(0); rawDestElement.style.position = "absolute"; rawDestElement.style.margin = "0"; rawDestElement.style.top = destElement.offset().top - correctionOffset.top + "px"; rawDestElement.style.left = destElement.offset().left - correctionOffset.left + "px"; d.css("opacity", 0.0); // IE if (options.useScaling) { d.css("transform", "scale(0.0)"); } d.appendTo($sourceParent); animationQueue.push({ element: $(d), animation: animationOptions }); } }); $dest.remove(); options.enhancement($sourceParent); // Perform custom visual enhancements during the animation for (i = 0; i < animationQueue.length; i++) { animationQueue[i].element.animate( animationQueue[i].animation, options.duration, options.easing, postCallback, ); } }); }; })(jQuery); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/portfolio/setting.js ================================================ jQuery(document).ready(function ($) { if (jQuery().quicksand) { // Clone applications to get a second collection var $data = $(".portfolio").clone(); //NOTE: Only filter on the main portfolio page, not on the subcategory pages $(".filter li").click(function (e) { $(".filter li").removeClass("active"); // Use the last category class as the category to filter by. 
This means that multiple categories are not supported (yet) var filterClass = $(this).attr("class").split(" ").slice(-1)[0]; if (filterClass == "all") { var $filteredData = $data.find(".item-thumbs"); } else { var $filteredData = $data.find( ".item-thumbs[data-type=" + filterClass + "]", ); } $(".portfolio").quicksand( $filteredData, { duration: 600, adjustHeight: "auto", }, function () { // Portfolio fancybox $(".fancybox").fancybox({ padding: 0, beforeShow: function () { this.title = $(this.element).attr("title"); this.title = "" + this.title + "" + "" + $(this.element).parent().find("img").attr("alt") + ""; }, helpers: { title: { type: "inside" }, }, }); }, ); $(this).addClass("active"); return false; }); } //if quicksand }); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/quicksand/jquery.quicksand.js ================================================ /* Quicksand 1.2.2 Reorder and filter items with a nice shuffling animation. Copyright (c) 2010 Jacek Galanciak (razorjack.net) and agilope.com Big thanks for Piotr Petrus (riddle.pl) for deep code review and wonderful docs & demos. Dual licensed under the MIT and GPL version 2 licenses. http://github.com/jquery/jquery/blob/master/MIT-LICENSE.txt http://github.com/jquery/jquery/blob/master/GPL-LICENSE.txt Project site: http://razorjack.net/quicksand Github site: http://github.com/razorjack/quicksand */ (function ($) { $.fn.quicksand = function (collection, customOptions) { var options = { duration: 750, easing: "swing", attribute: "data-id", // attribute to recognize same items within source and dest adjustHeight: "auto", // 'dynamic' animates height during shuffling (slow), 'auto' adjusts it before or after the animation, false leaves height constant useScaling: true, // disable it if you're not using scaling effect or want to improve performance enhancement: function (c) {}, // Visual enhancement (e.g. font replacement) function for cloned elements selector: "> *", dx: 0, dy: 0, }; $.extend(options, customOptions); if ($.browser.msie || typeof $.fn.scale == "undefined") { // Got IE and want scaling effect? Kiss my ass.
options.useScaling = false; } var callbackFunction; if (typeof arguments[1] == "function") { callbackFunction = arguments[1]; } else if (typeof arguments[2] == "function") { callbackFunction = arguments[2]; } return this.each(function (i) { var val; var animationQueue = []; // used to store all the animation params before starting the animation; solves initial animation slowdowns var $collection = $(collection).clone(); // destination (target) collection var $sourceParent = $(this); // source, the visible container of source collection var sourceHeight = $(this).css("height"); // used to keep height and document flow during the animation var destHeight; var adjustHeightOnCallback = false; var offset = $($sourceParent).offset(); // offset of visible container, used in animation calculations var offsets = []; // coordinates of every source collection item var $source = $(this).find(options.selector); // source collection items // Replace the collection and quit if IE6 if ($.browser.msie && $.browser.version.substr(0, 1) < 7) { $sourceParent.html("").append($collection); return; } // Gets called when any animation is finished var postCallbackPerformed = 0; // prevents the function from being called more than once var postCallback = function () { if (!postCallbackPerformed) { postCallbackPerformed = 1; // hack: // used to be: $sourceParent.html($dest.html()); // put target HTML into visible source container // but new webkit builds cause flickering when replacing the collections var $toDelete = $sourceParent.find("> *"); $sourceParent.prepend($dest.find("> *")); $toDelete.remove(); if (adjustHeightOnCallback) { $sourceParent.css("height", destHeight); } options.enhancement($sourceParent); // Perform custom visual enhancements on a newly replaced collection if (typeof callbackFunction == "function") { callbackFunction.call(this); } } }; // Position: relative situations var $correctionParent = $sourceParent.offsetParent(); var correctionOffset =
$correctionParent.offset(); if ($correctionParent.css("position") == "relative") { if ($correctionParent.get(0).nodeName.toLowerCase() == "body") { } else { correctionOffset.top += parseFloat($correctionParent.css("border-top-width")) || 0; correctionOffset.left += parseFloat($correctionParent.css("border-left-width")) || 0; } } else { correctionOffset.top -= parseFloat($correctionParent.css("border-top-width")) || 0; correctionOffset.left -= parseFloat($correctionParent.css("border-left-width")) || 0; correctionOffset.top -= parseFloat($correctionParent.css("margin-top")) || 0; correctionOffset.left -= parseFloat($correctionParent.css("margin-left")) || 0; } // perform custom corrections from options (use when Quicksand fails to detect proper correction) if (isNaN(correctionOffset.left)) { correctionOffset.left = 0; } if (isNaN(correctionOffset.top)) { correctionOffset.top = 0; } correctionOffset.left -= options.dx; correctionOffset.top -= options.dy; // keeps nodes after source container, holding their position $sourceParent.css("height", $(this).height()); // get positions of source collections $source.each(function (i) { offsets[i] = $(this).offset(); }); // stops previous animations on source container $(this).stop(); var dx = 0; var dy = 0; $source.each(function (i) { $(this).stop(); // stop animation of collection items var rawObj = $(this).get(0); if (rawObj.style.position == "absolute") { dx = -options.dx; dy = -options.dy; } else { dx = options.dx; dy = options.dy; } rawObj.style.position = "absolute"; rawObj.style.margin = "0"; rawObj.style.top = offsets[i].top - parseFloat(rawObj.style.marginTop) - correctionOffset.top + dy + "px"; rawObj.style.left = offsets[i].left - parseFloat(rawObj.style.marginLeft) - correctionOffset.left + dx + "px"; }); // create temporary container with destination collection var $dest = $($sourceParent).clone(); var rawDest = $dest.get(0); rawDest.innerHTML = ""; rawDest.setAttribute("id", ""); rawDest.style.height = "auto"; 
rawDest.style.width = $sourceParent.width() + "px"; $dest.append($collection); // insert node into HTML // Note that the node is placed under the visible source container in exactly the same position // The browser renders all the items without showing them (opacity: 0.0) // No offset calculations are needed, the browser just extracts position from underlayered destination items // and sets animation to destination positions. $dest.insertBefore($sourceParent); $dest.css("opacity", 0.0); rawDest.style.zIndex = -1; rawDest.style.margin = "0"; rawDest.style.position = "absolute"; rawDest.style.top = offset.top - correctionOffset.top + "px"; rawDest.style.left = offset.left - correctionOffset.left + "px"; if (options.adjustHeight === "dynamic") { // If the destination container has a different height than the source container // the height can be animated, adjusting it to the destination height $sourceParent.animate( { height: $dest.height() }, options.duration, options.easing, ); } else if (options.adjustHeight === "auto") { destHeight = $dest.height(); if (parseFloat(sourceHeight) < parseFloat(destHeight)) { // Adjust the height now so that the items don't move out of the container $sourceParent.css("height", destHeight); } else { // Adjust later, on callback adjustHeightOnCallback = true; } } // Now it's time to do the shuffling animation // First of all, we need to identify the same elements within source and destination collections $source.each(function (i) { var destElement = []; if (typeof options.attribute == "function") { val = options.attribute($(this)); $collection.each(function () { if (options.attribute(this) == val) { destElement = $(this); return false; } }); } else { destElement = $collection.filter( "[" + options.attribute + "=" + $(this).attr(options.attribute) + "]", ); } if (destElement.length) { // The item is in both the source and destination collections // If it's in a different position, let's move it if (!options.useScaling) { animationQueue.push({ element: $(this), animation: {
top: destElement.offset().top - correctionOffset.top, left: destElement.offset().left - correctionOffset.left, opacity: 1.0, }, }); } else { animationQueue.push({ element: $(this), animation: { top: destElement.offset().top - correctionOffset.top, left: destElement.offset().left - correctionOffset.left, opacity: 1.0, scale: "1.0", }, }); } } else { // The item from source collection is not present in destination collections // Let's remove it if (!options.useScaling) { animationQueue.push({ element: $(this), animation: { opacity: "0.0" }, }); } else { animationQueue.push({ element: $(this), animation: { opacity: "0.0", scale: "0.0" }, }); } } }); $collection.each(function (i) { // Grab all items from target collection not present in visible source collection var sourceElement = []; var destElement = []; if (typeof options.attribute == "function") { val = options.attribute($(this)); $source.each(function () { if (options.attribute(this) == val) { sourceElement = $(this); return false; } }); $collection.each(function () { if (options.attribute(this) == val) { destElement = $(this); return false; } }); } else { sourceElement = $source.filter( "[" + options.attribute + "=" + $(this).attr(options.attribute) + "]", ); destElement = $collection.filter( "[" + options.attribute + "=" + $(this).attr(options.attribute) + "]", ); } var animationOptions; if (sourceElement.length === 0) { // No such element in source collection... 
if (!options.useScaling) { animationOptions = { opacity: "1.0", }; } else { animationOptions = { opacity: "1.0", scale: "1.0", }; } // Let's create it d = destElement.clone(); var rawDestElement = d.get(0); rawDestElement.style.position = "absolute"; rawDestElement.style.margin = "0"; rawDestElement.style.top = destElement.offset().top - correctionOffset.top + "px"; rawDestElement.style.left = destElement.offset().left - correctionOffset.left + "px"; d.css("opacity", 0.0); // IE if (options.useScaling) { d.css("transform", "scale(0.0)"); } d.appendTo($sourceParent); animationQueue.push({ element: $(d), animation: animationOptions }); } }); $dest.remove(); options.enhancement($sourceParent); // Perform custom visual enhancements during the animation for (i = 0; i < animationQueue.length; i++) { animationQueue[i].element.animate( animationQueue[i].animation, options.duration, options.easing, postCallback, ); } }); }; })(jQuery); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/quicksand/setting.js ================================================ jQuery.noConflict(); jQuery(document).ready(function($){ if (jQuery().quicksand) { // Clone applications to get a second collection var $data = $(".portfolio-area").clone(); //NOTE: Only filter on the main portfolio page, not on the subcategory pages $('.portfolio-categ li').click(function(e) { $(".filter li").removeClass("active"); // Use the last category class as the category to filter by. 
This means that multiple categories are not supported (yet) var filterClass=$(this).attr('class').split(' ').slice(-1)[0]; if (filterClass == 'all') { var $filteredData = $data.find('.item-thumbs'); } else { var $filteredData = $data.find('.item-thumbs[data-type=' + filterClass + ']'); } $(".portfolio-area").quicksand($filteredData, { duration: 600, adjustHeight: 'auto' }); $(this).addClass("active"); return false; }); }//if quicksand }); ================================================ FILE: backend/tests/integration/tests/pruning/website/js/validate.js ================================================ /*global jQuery:false */ jQuery(document).ready(function ($) { "use strict"; //Contact $("form.validateform").submit(function () { var f = $(this).find(".field"), ferror = false, emailExp = /^[^\s()<>@,;:\/]+@\w[\w\.-]+\.[a-z]{2,}$/i; f.children("input").each(function () { // run all inputs var i = $(this); // current input var rule = i.attr("data-rule"); if (rule != undefined) { var ierror = false; // error flag for current input var pos = rule.indexOf(":", 0); if (pos >= 0) { var exp = rule.substr(pos + 1, rule.length); rule = rule.substr(0, pos); } else { rule = rule.substr(pos + 1, rule.length); } switch (rule) { case "required": if (i.val() == "") { ferror = ierror = true; } break; case "maxlen": if (i.val().length < parseInt(exp)) { ferror = ierror = true; } break; case "email": if (!emailExp.test(i.val())) { ferror = ierror = true; } break; case "checked": if (!i.attr("checked")) { ferror = ierror = true; } break; case "regexp": exp = new RegExp(exp); if (!exp.test(i.val())) { ferror = ierror = true; } break; } i.next(".validation") .html( ierror ? i.attr("data-msg") != undefined ?
i.attr("data-msg") : "wrong Input" : "", ) .show("blind"); } }); f.children("textarea").each(function () { // run all textareas var i = $(this); // current textarea var rule = i.attr("data-rule"); if (rule != undefined) { var ierror = false; // error flag for current textarea var pos = rule.indexOf(":", 0); if (pos >= 0) { var exp = rule.substr(pos + 1, rule.length); rule = rule.substr(0, pos); } else { rule = rule.substr(pos + 1, rule.length); } switch (rule) { case "required": if (i.val() == "") { ferror = ierror = true; } break; case "maxlen": if (i.val().length < parseInt(exp)) { ferror = ierror = true; } break; } i.next(".validation") .html( ierror ? i.attr("data-msg") != undefined ? i.attr("data-msg") : "wrong Input" : "", ) .show("blind"); } }); if (ferror) return false; var str = $(this).serialize(); $.ajax({ type: "POST", url: "contact/contact.php", data: str, success: function (msg) { $("#sendmessage").addClass("show"); $("#errormessage").ajaxComplete(function (event, request, settings) { var result = ""; if (msg == "OK") { $("#sendmessage").addClass("show"); } else { $("#sendmessage").removeClass("show"); result = msg; } $(this).html(result); }); }, }); return false; }); }); ================================================ FILE: backend/tests/integration/tests/pruning/website/portfolio.html ================================================ Above Multi-purpose Free Bootstrap Responsive Template ================================================ FILE: backend/tests/integration/tests/pruning/website/pricing.html ================================================ Above Multi-purpose Free Bootstrap Responsive Template
Pricing

Basic
$15.00 / Year
• Responsive Design
• Bootstrap Design
• Unlimited Support
• Free Trial version
• HTML5 CSS3 jQuery

Standard
$20.00 / Year
• Responsive Design
• Bootstrap Design
• Unlimited Support
• Free Trial version
• HTML5 CSS3 jQuery

Advanced
$15.00 / Year
• Responsive Design
• Bootstrap Design
• Unlimited Support
• Free Trial version
• HTML5 CSS3 jQuery

Mighty
$15.00 / Year
• Responsive Design
• Bootstrap Design
• Unlimited Support
• Free Trial version
• HTML5 CSS3 jQuery
================================================ FILE: backend/tests/integration/tests/pruning/website/readme.txt ================================================ Free Responsive HTML5 Template Above Educational Bootstrap Responsive template is a modern, clean, multi-purpose HTML5 template built with valid HTML5 & CSS3. It's built on top of the latest Bootstrap framework (3.3.1) and is fully responsive and compatible with multiple browsers and devices. This template can be used for multi-purpose needs like Educational Institutes, colleges, Schools, e-Learning, Training centre, Tutors, Charity, Primary School, business, consultancy, agency, personal portfolio, profile and mobile website. Key features ------------- Twitter Bootstrap 3.3.1 Clean & Developer-friendly HTML5 and CSS3 code 100% Responsive Layout Design Multi-purpose theme Google Fonts Support Font Awesome Smooth Scrolling Fully Customizable Contact Form Credits : ------- => Designed and developed: "WebThemez" http://webthemez.com => Photos used in template: **Unsplash** - http://unsplash.com => For more free web themes: http://webthemez.com => Framework : http://getbootstrap.com License : ------- **Creative Commons Attribution 3.0** - http://creativecommons.org/licenses/by/3.0/ Note: All images used here are for demo purposes only; we are not responsible for any copyrights.
================================================ FILE: backend/tests/integration/tests/query_history/test_query_history.py ================================================ import csv import io import os from datetime import datetime from datetime import timedelta from datetime import timezone import pytest from onyx.configs.constants import QAFeedbackType from onyx.configs.constants import SessionType from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.chat import ChatSessionManager from tests.integration.common_utils.managers.document import DocumentManager from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.query_history import QueryHistoryManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser @pytest.fixture def setup_chat_session(reset: None) -> tuple[DATestUser, str]: # noqa: ARG001 # Create admin user and required resources admin_user: DATestUser = UserManager.create(name="admin_user") cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin_user) api_key = APIKeyManager.create(user_performing_action=admin_user) LLMProviderManager.create(user_performing_action=admin_user) # Seed a document cc_pair.documents = [] cc_pair.documents.append( DocumentManager.seed_doc_with_content( cc_pair=cc_pair, content="The company's revenue in Q1 was $1M", api_key=api_key, ) ) # Create chat session and send a message chat_session = ChatSessionManager.create( persona_id=0, description="Test chat session", user_performing_action=admin_user, ) ChatSessionManager.send_message( chat_session_id=chat_session.id, message="What was the Q1 revenue?", user_performing_action=admin_user, ) messages = ChatSessionManager.get_chat_history( chat_session=chat_session, 
user_performing_action=admin_user, ) # Add another message to the chat session ChatSessionManager.send_message( chat_session_id=chat_session.id, message="What about Q2 revenue?", user_performing_action=admin_user, parent_message_id=messages[-1].id, ) return admin_user, str(chat_session.id) @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="Chat history tests are enterprise only", ) def test_chat_history_endpoints( reset: None, # noqa: ARG001 setup_chat_session: tuple[DATestUser, str], ) -> None: admin_user, first_chat_id = setup_chat_session # Get chat history history_response = QueryHistoryManager.get_query_history_page( user_performing_action=admin_user ) # Verify we got back the one chat session we created assert len(history_response.items) == 1 # Verify the first chat session details first_session = history_response.items[0] assert first_session.user_email == admin_user.email assert first_session.name == "Test chat session" assert first_session.first_user_message == "What was the Q1 revenue?" 
assert first_session.first_ai_message is not None assert first_session.assistant_id == 0 assert first_session.feedback_type is None assert first_session.flow_type == SessionType.CHAT assert first_session.conversation_length == 4 # 2 User messages + 2 AI responses # Test date filtering - should return no results past_end = datetime.now(tz=timezone.utc) - timedelta(days=1) past_start = past_end - timedelta(days=1) history_response = QueryHistoryManager.get_query_history_page( start_time=past_start, end_time=past_end, user_performing_action=admin_user, ) assert len(history_response.items) == 0 # Test get specific chat session endpoint session_details = QueryHistoryManager.get_chat_session_admin( chat_session_id=first_chat_id, user_performing_action=admin_user, ) # Verify the session details assert str(session_details.id) == first_chat_id assert len(session_details.messages) > 0 assert session_details.flow_type == SessionType.CHAT # Test filtering by feedback history_response = QueryHistoryManager.get_query_history_page( feedback_type=QAFeedbackType.LIKE, user_performing_action=admin_user, ) assert len(history_response.items) == 0 @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="Chat history tests are enterprise only", ) def test_chat_history_csv_export( reset: None, # noqa: ARG001 setup_chat_session: tuple[DATestUser, str], ) -> None: admin_user, _ = setup_chat_session # Test CSV export endpoint without date filtering headers, csv_content = QueryHistoryManager.get_query_history_as_csv( user_performing_action=admin_user, ) assert headers["Content-Type"] == "text/csv; charset=utf-8" assert "Content-Disposition" in headers # Use csv.reader to properly handle newlines inside quoted fields csv_rows = list(csv.reader(io.StringIO(csv_content))) assert len(csv_rows) == 3 # Header + 2 QA pairs assert csv_rows[0][0] == "chat_session_id" assert "user_message" in csv_rows[0] assert "ai_response" in csv_rows[0] assert "What
was the Q1 revenue?" in csv_content assert "What about Q2 revenue?" in csv_content # Test CSV export with date filtering - should return no results past_end = datetime.now(tz=timezone.utc) - timedelta(days=1) past_start = past_end - timedelta(days=1) headers, csv_content = QueryHistoryManager.get_query_history_as_csv( start_time=past_start, end_time=past_end, user_performing_action=admin_user, ) csv_rows = list(csv.reader(io.StringIO(csv_content))) assert len(csv_rows) == 1 # Only header, no data rows ================================================ FILE: backend/tests/integration/tests/query_history/test_query_history_pagination.py ================================================ import os from datetime import datetime import pytest from onyx.configs.constants import QAFeedbackType from tests.integration.common_utils.managers.query_history import QueryHistoryManager from tests.integration.common_utils.test_models import DAQueryHistoryEntry from tests.integration.common_utils.test_models import DATestUser from tests.integration.tests.query_history.utils import ( setup_chat_sessions_with_different_feedback, ) def _verify_query_history_pagination( chat_sessions: list[DAQueryHistoryEntry], user_performing_action: DATestUser, page_size: int = 5, feedback_type: QAFeedbackType | None = None, start_time: datetime | None = None, end_time: datetime | None = None, ) -> None: retrieved_sessions: list[str] = [] for i in range(0, len(chat_sessions), page_size): paginated_result = QueryHistoryManager.get_query_history_page( page_num=i // page_size, page_size=page_size, feedback_type=feedback_type, start_time=start_time, end_time=end_time, user_performing_action=user_performing_action, ) # Verify that the total items is equal to the length of the chat sessions list assert paginated_result.total_items == len(chat_sessions) # Verify that the number of items in the page is equal to the page size assert len(paginated_result.items) == min(page_size, len(chat_sessions) - i) # Add the 
retrieved chat sessions to the list of retrieved sessions retrieved_sessions.extend( [str(session.id) for session in paginated_result.items] ) # Create a set of all the expected chat session IDs all_expected_sessions = set(str(session.id) for session in chat_sessions) # Create a set of all the retrieved chat session IDs all_retrieved_sessions = set(retrieved_sessions) # Verify that the set of retrieved sessions is equal to the set of expected sessions assert all_expected_sessions == all_retrieved_sessions @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="Query history tests are enterprise only", ) def test_query_history_pagination(reset: None) -> None: # noqa: ARG001 ( admin_user, chat_sessions_by_feedback_type, ) = setup_chat_sessions_with_different_feedback() all_chat_sessions = [] for _, chat_sessions in chat_sessions_by_feedback_type.items(): all_chat_sessions.extend(chat_sessions) # Verify basic pagination with different page sizes print("Verifying basic pagination with page size 5") _verify_query_history_pagination( chat_sessions=all_chat_sessions, page_size=5, user_performing_action=admin_user, ) print("Verifying basic pagination with page size 10") _verify_query_history_pagination( chat_sessions=all_chat_sessions, page_size=10, user_performing_action=admin_user, ) print("Verifying pagination with feedback type LIKE") liked_sessions = chat_sessions_by_feedback_type[QAFeedbackType.LIKE] _verify_query_history_pagination( chat_sessions=liked_sessions, feedback_type=QAFeedbackType.LIKE, user_performing_action=admin_user, ) print("Verifying pagination with feedback type DISLIKE") disliked_sessions = chat_sessions_by_feedback_type[QAFeedbackType.DISLIKE] _verify_query_history_pagination( chat_sessions=disliked_sessions, feedback_type=QAFeedbackType.DISLIKE, user_performing_action=admin_user, ) print("Verifying pagination with feedback type MIXED") mixed_sessions = 
chat_sessions_by_feedback_type[QAFeedbackType.MIXED] _verify_query_history_pagination( chat_sessions=mixed_sessions, feedback_type=QAFeedbackType.MIXED, user_performing_action=admin_user, ) # Test with a small page size to verify handling of partial pages print("Verifying pagination with page size 3") _verify_query_history_pagination( chat_sessions=all_chat_sessions, page_size=3, user_performing_action=admin_user, ) # Test with a page size larger than the total number of items print("Verifying pagination with page size 50") _verify_query_history_pagination( chat_sessions=all_chat_sessions, page_size=50, user_performing_action=admin_user, ) ================================================ FILE: backend/tests/integration/tests/query_history/test_usage_reports.py ================================================ from datetime import datetime from datetime import timedelta from datetime import timezone from ee.onyx.db.usage_export import get_all_empty_chat_message_entries from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.seeding.chat_history_seeding import seed_chat_history def test_usage_reports(reset: None) -> None: # noqa: ARG001 EXPECTED_SESSIONS = 2048 MESSAGES_PER_SESSION = 4 # divide by 2 because only messages of type USER are returned EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION / 2 seed_chat_history(EXPECTED_SESSIONS, MESSAGES_PER_SESSION, 90) with get_session_with_current_tenant() as db_session: # count of all entries should be exact period = ( datetime.fromtimestamp(0, tz=timezone.utc), datetime.now(tz=timezone.utc), ) count = 0 for entry_batch in get_all_empty_chat_message_entries(db_session, period): for entry in entry_batch: count += 1 assert count == EXPECTED_MESSAGES # count in a one month time range should be within a certain range statistically # this can be improved if we seed the chat history data deterministically period = ( datetime.now(tz=timezone.utc) - timedelta(days=30), 
datetime.now(tz=timezone.utc), ) count = 0 for entry_batch in get_all_empty_chat_message_entries(db_session, period): for entry in entry_batch: count += 1 lower = EXPECTED_MESSAGES // 3 - (EXPECTED_MESSAGES // (3 * 3)) upper = EXPECTED_MESSAGES // 3 + (EXPECTED_MESSAGES // (3 * 3)) assert count > lower assert count < upper ================================================ FILE: backend/tests/integration/tests/query_history/utils.py ================================================ from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor from onyx.configs.constants import QAFeedbackType from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.chat import ChatSessionManager from tests.integration.common_utils.managers.document import DocumentManager from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DAQueryHistoryEntry from tests.integration.common_utils.test_models import DATestUser def _create_chat_session_with_feedback( admin_user: DATestUser, i: int, feedback_type: QAFeedbackType | None, ) -> tuple[QAFeedbackType | None, DAQueryHistoryEntry]: print(f"Creating chat session {i} with feedback type {feedback_type}") # Create the chat session chat_session = ChatSessionManager.create( persona_id=0, description=f"Test chat session {i}", user_performing_action=admin_user, ) test_session = DAQueryHistoryEntry( id=chat_session.id, persona_id=0, description=f"Test chat session {i}", feedback_type=feedback_type, ) # First message in chat ChatSessionManager.send_message( chat_session_id=chat_session.id, message=f"Question {i}?", user_performing_action=admin_user, ) messages = ChatSessionManager.get_chat_history(
chat_session=chat_session, user_performing_action=admin_user, ) if feedback_type == QAFeedbackType.MIXED or feedback_type == QAFeedbackType.DISLIKE: ChatSessionManager.create_chat_message_feedback( message_id=messages[-1].id, is_positive=False, user_performing_action=admin_user, ) # Second message with different feedback types ChatSessionManager.send_message( chat_session_id=chat_session.id, message=f"Follow up {i}?", user_performing_action=admin_user, parent_message_id=messages[-1].id, ) # Get updated messages to get the ID of the second message messages = ChatSessionManager.get_chat_history( chat_session=chat_session, user_performing_action=admin_user, ) if feedback_type == QAFeedbackType.MIXED or feedback_type == QAFeedbackType.LIKE: ChatSessionManager.create_chat_message_feedback( message_id=messages[-1].id, is_positive=True, user_performing_action=admin_user, ) return feedback_type, test_session def setup_chat_sessions_with_different_feedback() -> ( tuple[DATestUser, dict[QAFeedbackType | None, list[DAQueryHistoryEntry]]] ): # Create admin user and required resources admin_user: DATestUser = UserManager.create(name="admin_user") cc_pair = CCPairManager.create_from_scratch(user_performing_action=admin_user) api_key = APIKeyManager.create(user_performing_action=admin_user) LLMProviderManager.create(user_performing_action=admin_user) # Seed a document cc_pair.documents = [] cc_pair.documents.append( DocumentManager.seed_doc_with_content( cc_pair=cc_pair, content="The company's revenue in Q1 was $1M", api_key=api_key, ) ) chat_sessions_by_feedback_type: dict[ QAFeedbackType | None, list[DAQueryHistoryEntry] ] = {} # Use ThreadPoolExecutor to create chat sessions in parallel with ThreadPoolExecutor(max_workers=5) as executor: # Submit all tasks and store futures j = 0 # Will result in 40 sessions number_of_sessions = 10 futures = [] for feedback_type in [ QAFeedbackType.MIXED, QAFeedbackType.LIKE, QAFeedbackType.DISLIKE, None, ]: futures.extend( [ executor.submit( 
_create_chat_session_with_feedback, admin_user, (j * number_of_sessions) + i, feedback_type, ) for i in range(number_of_sessions) ] ) j += 1 # Collect results as they complete (completion order, not submission order) for future in as_completed(futures): feedback_type, chat_session = future.result() chat_sessions_by_feedback_type.setdefault(feedback_type, []).append( chat_session ) return admin_user, chat_sessions_by_feedback_type ================================================ FILE: backend/tests/integration/tests/reporting/test_usage_export_api.py ================================================ import csv import os import time from datetime import datetime from datetime import timedelta from datetime import timezone from io import BytesIO from io import StringIO from uuid import UUID from zipfile import ZipFile import pytest import requests from ee.onyx.db.usage_export import UsageReportMetadata from onyx.configs.constants import DEFAULT_PERSONA_ID from onyx.db.seeding.chat_history_seeding import seed_chat_history from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.test_models import DATestUser @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="Usage export is an enterprise feature", ) class TestUsageExportAPI: def test_generate_usage_report( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # Seed some chat history data for the report seed_chat_history( num_sessions=10, num_messages=4, days=30, user_id=UUID(admin_user.id), persona_id=DEFAULT_PERSONA_ID, ) # Get initial list of reports initial_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert initial_response.status_code == 200 initial_reports = initial_response.json() initial_count = len(initial_reports) # Test generating a report without date filters (all time) response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={}, 
headers=admin_user.headers, )
assert response.status_code == 204 # Wait for the new report to appear (with timeout) max_wait_time = 60 # seconds start_time = time.time() current_reports = initial_reports while time.time() - start_time < max_wait_time: check_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert check_response.status_code == 200 current_reports = check_response.json() if len(current_reports) > initial_count: # New report has been generated break time.sleep(2) # Verify a new report was created assert len(current_reports) > initial_count # Find the new report (should be the first one since they're ordered by time) new_report = current_reports[0] assert "report_name" in new_report assert new_report["report_name"].endswith(".zip") def test_generate_usage_report_with_date_range( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # Seed some chat history data seed_chat_history( num_sessions=20, num_messages=4, days=60, user_id=UUID(admin_user.id), persona_id=DEFAULT_PERSONA_ID, ) # Get initial list of reports initial_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert initial_response.status_code == 200 initial_reports = initial_response.json() initial_count = len(initial_reports) # Generate report for the last 30 days period_to = datetime.now(tz=timezone.utc) period_from = period_to - timedelta(days=30) response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={ "period_from": period_from.isoformat(), "period_to": period_to.isoformat(), }, headers=admin_user.headers, ) assert response.status_code == 204 # Wait for the new report to appear max_wait_time = 60 start_time = time.time() current_reports = initial_reports while time.time() - start_time < max_wait_time: check_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert check_response.status_code == 200 current_reports = check_response.json() if 
len(current_reports) > initial_count: break time.sleep(2) assert len(current_reports) > initial_count # Find the new report (the one that wasn't in initial_reports) new_reports = [r for r in current_reports if r not in initial_reports] assert len(new_reports) > 0 new_report = new_reports[0] # Verify the new report has the expected date range assert new_report["period_from"] is not None assert new_report["period_to"] is not None def test_generate_usage_report_invalid_dates( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # Test with invalid date format response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={ "period_from": "not-a-date", "period_to": datetime.now(tz=timezone.utc).isoformat(), }, headers=admin_user.headers, ) assert response.status_code == 400 def test_fetch_usage_reports( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # First generate a report to ensure we have at least one seed_chat_history( num_sessions=5, num_messages=4, days=30, user_id=UUID(admin_user.id), persona_id=DEFAULT_PERSONA_ID, ) # Get initial count initial_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert initial_response.status_code == 200 initial_count = len(initial_response.json()) # Generate a report generate_response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={}, headers=admin_user.headers, ) assert generate_response.status_code == 204 # Wait for the new report to appear max_wait_time = 15 start_time = time.time() reports = [] while time.time() - start_time < max_wait_time: response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert response.status_code == 200 reports = response.json() if len(reports) > initial_count: break time.sleep(2) # Verify we have at least one report assert isinstance(reports, list) assert len(reports) > initial_count # Validate the structure of the first 
report first_report = reports[0] assert "report_name" in first_report assert "requestor" in first_report assert "time_created" in first_report assert "period_from" in first_report assert "period_to" in first_report # Verify it's a valid UsageReportMetadata object report_metadata = UsageReportMetadata(**first_report) assert report_metadata.report_name.endswith(".zip") def test_read_usage_report( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # First generate a report seed_chat_history( num_sessions=5, num_messages=4, days=30, user_id=UUID(admin_user.id), persona_id=DEFAULT_PERSONA_ID, ) # Get initial reports count initial_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert initial_response.status_code == 200 initial_count = len(initial_response.json()) generate_response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={}, headers=admin_user.headers, ) assert generate_response.status_code == 204 # Wait for the new report to appear max_wait_time = 15 start_time = time.time() reports = [] while time.time() - start_time < max_wait_time: list_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert list_response.status_code == 200 reports = list_response.json() if len(reports) > initial_count: break time.sleep(2) assert len(reports) > initial_count report_name = reports[0]["report_name"] # Download the report download_response = requests.get( f"{API_SERVER_URL}/admin/usage-report/{report_name}", headers=admin_user.headers, stream=True, ) assert download_response.status_code == 200 assert download_response.headers["Content-Type"] == "application/zip" assert "Content-Disposition" in download_response.headers assert ( f"filename={report_name}" in download_response.headers["Content-Disposition"] ) # Verify it's a valid zip file zip_content = BytesIO(download_response.content) with ZipFile(zip_content, "r") as zip_file: # Check that 
the zip contains expected files file_names = zip_file.namelist() assert "chat_messages.csv" in file_names assert "users.csv" in file_names # Verify chat_messages.csv has the expected columns with zip_file.open("chat_messages.csv") as csv_file: csv_content = csv_file.read().decode("utf-8") csv_reader = csv.DictReader(StringIO(csv_content)) # Check that all expected columns are present expected_columns = { "session_id", "user_id", "flow_type", "time_sent", "assistant_name", "user_email", "number_of_tokens", } actual_columns = set(csv_reader.fieldnames or []) assert ( expected_columns == actual_columns ), f"Expected columns {expected_columns}, but got {actual_columns}" # Verify there's at least one row of data rows = list(csv_reader) assert len(rows) > 0, "Expected at least one message in the report" # Verify the first row has non-empty values for all columns first_row = rows[0] for column in expected_columns: assert column in first_row, f"Column {column} not found in row" assert first_row[ column ], f"Column {column} has empty value in first row" # Verify specific new fields have appropriate values assert first_row["assistant_name"], "assistant_name should not be empty" assert first_row["user_email"], "user_email should not be empty" assert first_row[ "number_of_tokens" ].isdigit(), "number_of_tokens should be a numeric value" assert ( int(first_row["number_of_tokens"]) >= 0 ), "number_of_tokens should be non-negative" def test_read_nonexistent_report( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # Try to download a report that doesn't exist response = requests.get( f"{API_SERVER_URL}/admin/usage-report/nonexistent_report.zip", headers=admin_user.headers, ) assert response.status_code == 404 def test_non_admin_cannot_generate_report( self, reset: None, # noqa: ARG002 basic_user: DATestUser, # noqa: ARG002 ) -> None: # Try to generate a report as non-admin response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={}, 
headers=basic_user.headers, ) assert response.status_code == 403 def test_non_admin_cannot_fetch_reports( self, reset: None, # noqa: ARG002 basic_user: DATestUser, # noqa: ARG002 ) -> None: # Try to fetch reports as non-admin response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=basic_user.headers, ) assert response.status_code == 403 def test_non_admin_cannot_download_report( self, reset: None, # noqa: ARG002 basic_user: DATestUser, # noqa: ARG002 ) -> None: # Try to download a report as non-admin response = requests.get( f"{API_SERVER_URL}/admin/usage-report/some_report.zip", headers=basic_user.headers, ) assert response.status_code == 403 def test_concurrent_report_generation( self, reset: None, # noqa: ARG002 admin_user: DATestUser, # noqa: ARG002 ) -> None: # Seed some data seed_chat_history( num_sessions=10, num_messages=4, days=30, user_id=UUID(admin_user.id), persona_id=DEFAULT_PERSONA_ID, ) # Get initial count of reports initial_response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert initial_response.status_code == 200 initial_count = len(initial_response.json()) # Generate multiple reports concurrently num_reports = 3 for i in range(num_reports): response = requests.post( f"{API_SERVER_URL}/admin/usage-report", json={}, headers=admin_user.headers, ) assert response.status_code == 204 # Wait for all reports to be generated max_wait_time = 120 start_time = time.time() reports = [] while time.time() - start_time < max_wait_time: response = requests.get( f"{API_SERVER_URL}/admin/usage-report", headers=admin_user.headers, ) assert response.status_code == 200 reports = response.json() if len(reports) >= initial_count + num_reports: break time.sleep(2) # Verify we have at least 3 new reports assert len(reports) >= initial_count + num_reports ================================================ FILE: backend/tests/integration/tests/scim/test_scim_groups.py 
================================================ """Integration tests for SCIM group provisioning endpoints. Covers the full group lifecycle as driven by an IdP (Okta / Azure AD): 1. Create a group via POST /Groups 2. Retrieve a group via GET /Groups/{id} 3. List, filter, and paginate groups via GET /Groups 4. Replace a group via PUT /Groups/{id} 5. Patch a group (add/remove members, rename) via PATCH /Groups/{id} 6. Delete a group via DELETE /Groups/{id} 7. Error cases: duplicate name, not-found, invalid member IDs All tests are parameterized across IdP request styles (Okta sends lowercase PATCH ops; Entra sends capitalized ops like ``"Replace"``). The server normalizes both — these tests verify that. Auth tests live in test_scim_tokens.py. User lifecycle tests live in test_scim_users.py. """ import pytest import requests from onyx.auth.schemas import UserRole from tests.integration.common_utils.constants import ADMIN_USER_NAME from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.scim_client import ScimClient from tests.integration.common_utils.managers.scim_token import ScimTokenManager from tests.integration.common_utils.managers.user import build_email from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group" SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp" @pytest.fixture(scope="module", params=["okta", "entra"]) def idp_style(request: pytest.FixtureRequest) -> str: """Parameterized IdP style — runs every test with both Okta and Entra request formats.""" return request.param @pytest.fixture(scope="module") def scim_token(idp_style: str) -> 
str: """Create a single SCIM token shared across all tests in this module. Creating a new token revokes the previous one, so we create exactly once per IdP-style run and reuse. Uses UserManager directly to avoid fixture-scope conflicts with the function-scoped admin_user fixture. """ try: admin = UserManager.create(name=ADMIN_USER_NAME) except Exception: admin = UserManager.login_as_user( DATestUser( id="", email=build_email(ADMIN_USER_NAME), password=DEFAULT_PASSWORD, headers=GENERAL_HEADERS, role=UserRole.ADMIN, is_active=True, ) ) token = ScimTokenManager.create( name=f"scim-group-tests-{idp_style}", user_performing_action=admin, ).raw_token assert token is not None return token def _make_group_resource( display_name: str, external_id: str | None = None, members: list[dict] | None = None, ) -> dict: """Build a minimal SCIM GroupResource payload.""" resource: dict = { "schemas": [SCIM_GROUP_SCHEMA], "displayName": display_name, } if external_id is not None: resource["externalId"] = external_id if members is not None: resource["members"] = members return resource def _make_user_resource(email: str, external_id: str) -> dict: """Build a minimal SCIM UserResource payload for member creation.""" return { "schemas": [SCIM_USER_SCHEMA], "userName": email, "externalId": external_id, "name": {"givenName": "Test", "familyName": "User"}, "active": True, } def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict: """Build a SCIM PatchOp payload, applying IdP-specific operation casing. Entra sends capitalized operations (e.g. ``"Replace"`` instead of ``"replace"``). The server's ``normalize_operation`` validator lowercases them — these tests verify that both casings are accepted. 
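Illustrative sketch (hypothetical names, not the server's code): ``entra_case`` mirrors the capitalization this helper applies for Entra, and ``normalize_op`` assumes that ``normalize_operation`` simply lowercases the op and checks it against the three PATCH operations defined by RFC 7644 (add, remove, replace).

```python
# Hypothetical helpers illustrating the casing round trip; not Onyx's implementation.
VALID_OPS = {"add", "remove", "replace"}


def entra_case(op: str) -> str:
    # Mirrors _make_patch_request: Entra-style requests capitalize the op.
    return op.capitalize()


def normalize_op(op: str) -> str:
    # Assumed server-side behavior: lowercase first, then validate.
    lowered = op.lower()
    if lowered not in VALID_OPS:
        raise ValueError(f"unsupported PATCH op: {op!r}")
    return lowered


# Every op should survive the Entra-casing -> normalization round trip.
for op in sorted(VALID_OPS):
    assert normalize_op(entra_case(op)) == op
```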
""" cased_operations = [] for operation in operations: cased = dict(operation) if idp_style == "entra": cased["op"] = operation["op"].capitalize() cased_operations.append(cased) return { "schemas": [SCIM_PATCH_SCHEMA], "Operations": cased_operations, } def _create_scim_user(token: str, email: str, external_id: str) -> requests.Response: return ScimClient.post( "/Users", token, json=_make_user_resource(email, external_id) ) def _create_scim_group( token: str, display_name: str, external_id: str | None = None, members: list[dict] | None = None, ) -> requests.Response: return ScimClient.post( "/Groups", token, json=_make_group_resource(display_name, external_id, members), ) # ------------------------------------------------------------------ # Lifecycle: create → get → list → replace → patch → delete # ------------------------------------------------------------------ def test_create_group(scim_token: str, idp_style: str) -> None: """POST /Groups creates a group and returns 201.""" name = f"Engineering {idp_style}" resp = _create_scim_group(scim_token, name, external_id=f"ext-eng-{idp_style}") assert resp.status_code == 201 body = resp.json() assert body["displayName"] == name assert body["externalId"] == f"ext-eng-{idp_style}" assert body["id"] # integer ID assigned by server assert body["meta"]["resourceType"] == "Group" def test_create_group_with_members(scim_token: str, idp_style: str) -> None: """POST /Groups with members populates the member list.""" user = _create_scim_user( scim_token, f"grp_member1_{idp_style}@example.com", f"ext-gm-{idp_style}" ).json() resp = _create_scim_group( scim_token, f"Backend Team {idp_style}", external_id=f"ext-backend-{idp_style}", members=[{"value": user["id"]}], ) assert resp.status_code == 201 body = resp.json() member_ids = [m["value"] for m in body["members"]] assert user["id"] in member_ids def test_get_group(scim_token: str, idp_style: str) -> None: """GET /Groups/{id} returns the group resource including members.""" user = 
_create_scim_user( scim_token, f"grp_get_m_{idp_style}@example.com", f"ext-ggm-{idp_style}" ).json() created = _create_scim_group( scim_token, f"Frontend Team {idp_style}", external_id=f"ext-fe-{idp_style}", members=[{"value": user["id"]}], ).json() resp = ScimClient.get(f"/Groups/{created['id']}", scim_token) assert resp.status_code == 200 body = resp.json() assert body["id"] == created["id"] assert body["displayName"] == f"Frontend Team {idp_style}" assert body["externalId"] == f"ext-fe-{idp_style}" member_ids = [m["value"] for m in body["members"]] assert user["id"] in member_ids def test_list_groups(scim_token: str, idp_style: str) -> None: """GET /Groups returns a ListResponse containing provisioned groups.""" name = f"DevOps Team {idp_style}" _create_scim_group(scim_token, name, external_id=f"ext-devops-{idp_style}") resp = ScimClient.get("/Groups", scim_token) assert resp.status_code == 200 body = resp.json() assert body["totalResults"] >= 1 names = [r["displayName"] for r in body["Resources"]] assert name in names def test_list_groups_pagination(scim_token: str, idp_style: str) -> None: """GET /Groups with startIndex and count returns correct pagination.""" _create_scim_group( scim_token, f"Page Group A {idp_style}", external_id=f"ext-page-a-{idp_style}" ) _create_scim_group( scim_token, f"Page Group B {idp_style}", external_id=f"ext-page-b-{idp_style}" ) resp = ScimClient.get("/Groups?startIndex=1&count=1", scim_token) assert resp.status_code == 200 body = resp.json() assert body["startIndex"] == 1 assert body["itemsPerPage"] == 1 assert body["totalResults"] >= 2 assert len(body["Resources"]) == 1 def test_filter_groups_by_display_name(scim_token: str, idp_style: str) -> None: """GET /Groups?filter=displayName eq '...' 
returns only matching groups.""" name = f"Unique QA Team {idp_style}" _create_scim_group(scim_token, name, external_id=f"ext-qa-filter-{idp_style}") resp = ScimClient.get(f'/Groups?filter=displayName eq "{name}"', scim_token) assert resp.status_code == 200 body = resp.json() assert body["totalResults"] == 1 assert body["Resources"][0]["displayName"] == name def test_filter_groups_by_external_id(scim_token: str, idp_style: str) -> None: """GET /Groups?filter=externalId eq '...' returns the matching group.""" ext_id = f"ext-unique-group-id-{idp_style}" _create_scim_group( scim_token, f"ExtId Filter Group {idp_style}", external_id=ext_id ) resp = ScimClient.get(f'/Groups?filter=externalId eq "{ext_id}"', scim_token) assert resp.status_code == 200 body = resp.json() assert body["totalResults"] == 1 assert body["Resources"][0]["externalId"] == ext_id def test_replace_group(scim_token: str, idp_style: str) -> None: """PUT /Groups/{id} replaces the group resource.""" created = _create_scim_group( scim_token, f"Original Name {idp_style}", external_id=f"ext-replace-g-{idp_style}", ).json() user = _create_scim_user( scim_token, f"grp_replace_m_{idp_style}@example.com", f"ext-grm-{idp_style}" ).json() updated_resource = _make_group_resource( display_name=f"Renamed Group {idp_style}", external_id=f"ext-replace-g-{idp_style}", members=[{"value": user["id"]}], ) resp = ScimClient.put(f"/Groups/{created['id']}", scim_token, json=updated_resource) assert resp.status_code == 200 body = resp.json() assert body["displayName"] == f"Renamed Group {idp_style}" member_ids = [m["value"] for m in body["members"]] assert user["id"] in member_ids def test_replace_group_clears_members(scim_token: str, idp_style: str) -> None: """PUT /Groups/{id} with empty members removes all memberships.""" user = _create_scim_user( scim_token, f"grp_clear_m_{idp_style}@example.com", f"ext-gcm-{idp_style}" ).json() created = _create_scim_group( scim_token, f"Clear Members Group {idp_style}", 
external_id=f"ext-clear-g-{idp_style}", members=[{"value": user["id"]}], ).json() assert len(created["members"]) == 1 resp = ScimClient.put( f"/Groups/{created['id']}", scim_token, json=_make_group_resource( f"Clear Members Group {idp_style}", f"ext-clear-g-{idp_style}", members=[] ), ) assert resp.status_code == 200 assert resp.json()["members"] == [] def test_patch_add_member(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} with op=add adds a member.""" created = _create_scim_group( scim_token, f"Patch Add Group {idp_style}", external_id=f"ext-patch-add-{idp_style}", ).json() user = _create_scim_user( scim_token, f"grp_patch_add_{idp_style}@example.com", f"ext-gpa-{idp_style}" ).json() resp = ScimClient.patch( f"/Groups/{created['id']}", scim_token, json=_make_patch_request( [{"op": "add", "path": "members", "value": [{"value": user["id"]}]}], idp_style, ), ) assert resp.status_code == 200 member_ids = [m["value"] for m in resp.json()["members"]] assert user["id"] in member_ids def test_patch_remove_member(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} with op=remove removes a specific member.""" user = _create_scim_user( scim_token, f"grp_patch_rm_{idp_style}@example.com", f"ext-gpr-{idp_style}" ).json() created = _create_scim_group( scim_token, f"Patch Remove Group {idp_style}", external_id=f"ext-patch-rm-{idp_style}", members=[{"value": user["id"]}], ).json() assert len(created["members"]) == 1 resp = ScimClient.patch( f"/Groups/{created['id']}", scim_token, json=_make_patch_request( [ { "op": "remove", "path": f'members[value eq "{user["id"]}"]', } ], idp_style, ), ) assert resp.status_code == 200 assert resp.json()["members"] == [] def test_patch_replace_members(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} with op=replace on members swaps the entire list.""" user_a = _create_scim_user( scim_token, f"grp_repl_a_{idp_style}@example.com", f"ext-gra-{idp_style}" ).json() user_b = _create_scim_user( scim_token, 
f"grp_repl_b_{idp_style}@example.com", f"ext-grb-{idp_style}" ).json() created = _create_scim_group( scim_token, f"Patch Replace Group {idp_style}", external_id=f"ext-patch-repl-{idp_style}", members=[{"value": user_a["id"]}], ).json() # Replace member list: swap A for B resp = ScimClient.patch( f"/Groups/{created['id']}", scim_token, json=_make_patch_request( [ { "op": "replace", "path": "members", "value": [{"value": user_b["id"]}], } ], idp_style, ), ) assert resp.status_code == 200 member_ids = [m["value"] for m in resp.json()["members"]] assert user_b["id"] in member_ids assert user_a["id"] not in member_ids def test_patch_rename_group(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} with op=replace on displayName renames the group.""" created = _create_scim_group( scim_token, f"Old Group Name {idp_style}", external_id=f"ext-rename-g-{idp_style}", ).json() new_name = f"New Group Name {idp_style}" resp = ScimClient.patch( f"/Groups/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "displayName", "value": new_name}], idp_style, ), ) assert resp.status_code == 200 assert resp.json()["displayName"] == new_name # Confirm via GET get_resp = ScimClient.get(f"/Groups/{created['id']}", scim_token) assert get_resp.json()["displayName"] == new_name def test_delete_group(scim_token: str, idp_style: str) -> None: """DELETE /Groups/{id} removes the group.""" created = _create_scim_group( scim_token, f"Delete Me Group {idp_style}", external_id=f"ext-del-g-{idp_style}", ).json() resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token) assert resp.status_code == 204 # Second DELETE returns 404 (group hard-deleted) resp2 = ScimClient.delete(f"/Groups/{created['id']}", scim_token) assert resp2.status_code == 404 def test_delete_group_preserves_members(scim_token: str, idp_style: str) -> None: """DELETE /Groups/{id} removes memberships but does not deactivate users.""" user = _create_scim_user( scim_token, 
f"grp_del_member_{idp_style}@example.com", f"ext-gdm-{idp_style}" ).json() created = _create_scim_group( scim_token, f"Delete With Members {idp_style}", external_id=f"ext-del-wm-{idp_style}", members=[{"value": user["id"]}], ).json() resp = ScimClient.delete(f"/Groups/{created['id']}", scim_token) assert resp.status_code == 204 # User should still be active and retrievable user_resp = ScimClient.get(f"/Users/{user['id']}", scim_token) assert user_resp.status_code == 200 assert user_resp.json()["active"] is True # ------------------------------------------------------------------ # Error cases # ------------------------------------------------------------------ def test_create_group_duplicate_name(scim_token: str, idp_style: str) -> None: """POST /Groups with an already-taken displayName returns 409.""" name = f"Dup Name Group {idp_style}" resp1 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g1-{idp_style}") assert resp1.status_code == 201 resp2 = _create_scim_group(scim_token, name, external_id=f"ext-dup-g2-{idp_style}") assert resp2.status_code == 409 def test_get_nonexistent_group(scim_token: str) -> None: """GET /Groups/{bad-id} returns 404.""" resp = ScimClient.get("/Groups/999999999", scim_token) assert resp.status_code == 404 def test_create_group_with_invalid_member(scim_token: str, idp_style: str) -> None: """POST /Groups with a non-existent member UUID returns 400.""" resp = _create_scim_group( scim_token, f"Bad Member Group {idp_style}", external_id=f"ext-bad-m-{idp_style}", members=[{"value": "00000000-0000-0000-0000-000000000000"}], ) assert resp.status_code == 400 assert "not found" in resp.json()["detail"].lower() def test_patch_add_nonexistent_member(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} adding a non-existent member returns 400.""" created = _create_scim_group( scim_token, f"Patch Bad Member Group {idp_style}", external_id=f"ext-pbm-{idp_style}", ).json() resp = ScimClient.patch( f"/Groups/{created['id']}", 
scim_token, json=_make_patch_request( [ { "op": "add", "path": "members", "value": [{"value": "00000000-0000-0000-0000-000000000000"}], } ], idp_style, ), ) assert resp.status_code == 400 assert "not found" in resp.json()["detail"].lower() def test_patch_add_duplicate_member_is_idempotent( scim_token: str, idp_style: str ) -> None: """PATCH /Groups/{id} adding an already-present member succeeds silently.""" user = _create_scim_user( scim_token, f"grp_dup_add_{idp_style}@example.com", f"ext-gda-{idp_style}" ).json() created = _create_scim_group( scim_token, f"Idempotent Add Group {idp_style}", external_id=f"ext-idem-g-{idp_style}", members=[{"value": user["id"]}], ).json() assert len(created["members"]) == 1 # Add same member again resp = ScimClient.patch( f"/Groups/{created['id']}", scim_token, json=_make_patch_request( [{"op": "add", "path": "members", "value": [{"value": user["id"]}]}], idp_style, ), ) assert resp.status_code == 200 assert len(resp.json()["members"]) == 1 # still just one member def test_create_group_reserved_name_admin(scim_token: str) -> None: """POST /Groups with reserved name 'Admin' returns 409.""" resp = _create_scim_group(scim_token, "Admin", external_id="ext-reserved-admin") assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() def test_create_group_reserved_name_basic(scim_token: str) -> None: """POST /Groups with reserved name 'Basic' returns 409.""" resp = _create_scim_group(scim_token, "Basic", external_id="ext-reserved-basic") assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() def test_replace_group_cannot_rename_to_reserved( scim_token: str, idp_style: str ) -> None: """PUT /Groups/{id} renaming a group to 'Admin' returns 409.""" created = _create_scim_group( scim_token, f"Rename To Reserved {idp_style}", external_id=f"ext-rtr-{idp_style}", ).json() resp = ScimClient.put( f"/Groups/{created['id']}", scim_token, json=_make_group_resource( display_name="Admin", 
external_id=f"ext-rtr-{idp_style}" ), ) assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() def test_patch_rename_to_reserved_name(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} renaming a group to 'Basic' returns 409.""" created = _create_scim_group( scim_token, f"Patch Rename Reserved {idp_style}", external_id=f"ext-prr-{idp_style}", ).json() resp = ScimClient.patch( f"/Groups/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "displayName", "value": "Basic"}], idp_style, ), ) assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() def test_delete_reserved_group_rejected(scim_token: str) -> None: """DELETE /Groups/{id} on a reserved group ('Admin') returns 409.""" # Look up the reserved 'Admin' group via SCIM filter resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token) assert resp.status_code == 200 resources = resp.json()["Resources"] assert len(resources) >= 1, "Expected reserved 'Admin' group to exist" admin_group_id = resources[0]["id"] resp = ScimClient.delete(f"/Groups/{admin_group_id}", scim_token) assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() def test_scim_created_group_has_basic_permission( scim_token: str, idp_style: str ) -> None: """POST /Groups assigns the 'basic' permission to the group itself.""" # Create a SCIM group (no members needed — we check the group's permissions) resp = _create_scim_group( scim_token, f"Basic Perm Group {idp_style}", external_id=f"ext-basic-perm-{idp_style}", ) assert resp.status_code == 201 group_id = resp.json()["id"] # Log in as the admin user (created by the scim_token fixture). 
admin = DATestUser( id="", email=build_email(ADMIN_USER_NAME), password=DEFAULT_PASSWORD, headers=GENERAL_HEADERS, role=UserRole.ADMIN, is_active=True, ) admin = UserManager.login_as_user(admin) # Verify the group itself was granted the basic permission perms_resp = requests.get( f"{API_SERVER_URL}/manage/admin/user-group/{group_id}/permissions", headers=admin.headers, ) perms_resp.raise_for_status() perms = perms_resp.json() assert "basic" in perms, f"SCIM group should have 'basic' permission, got: {perms}" def test_replace_group_cannot_rename_from_reserved(scim_token: str) -> None: """PUT /Groups/{id} renaming a reserved group ('Admin') to a non-reserved name returns 409.""" resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token) assert resp.status_code == 200 resources = resp.json()["Resources"] assert len(resources) >= 1, "Expected reserved 'Admin' group to exist" admin_group_id = resources[0]["id"] resp = ScimClient.put( f"/Groups/{admin_group_id}", scim_token, json=_make_group_resource( display_name="RenamedAdmin", external_id="ext-rename-from-reserved" ), ) assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() def test_patch_rename_from_reserved_name(scim_token: str, idp_style: str) -> None: """PATCH /Groups/{id} renaming a reserved group ('Admin') returns 409.""" resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token) assert resp.status_code == 200 resources = resp.json()["Resources"] assert len(resources) >= 1, "Expected reserved 'Admin' group to exist" admin_group_id = resources[0]["id"] resp = ScimClient.patch( f"/Groups/{admin_group_id}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "displayName", "value": "RenamedAdmin"}], idp_style, ), ) assert resp.status_code == 409 assert "reserved" in resp.json()["detail"].lower() ================================================ FILE: backend/tests/integration/tests/scim/test_scim_tokens.py 
================================================ """Integration tests for SCIM token management. Covers the admin token API and SCIM bearer-token authentication: 1. Token lifecycle: create, retrieve metadata, use for SCIM requests 2. Token rotation: creating a new token revokes previous tokens 3. Revoked tokens are rejected by SCIM endpoints 4. Non-admin users cannot manage SCIM tokens 5. SCIM requests without a token are rejected 6. Service discovery endpoints work without authentication 7. last_used_at is updated after a SCIM request """ import time import requests from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.managers.scim_client import ScimClient from tests.integration.common_utils.managers.scim_token import ScimTokenManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser def test_scim_token_lifecycle(admin_user: DATestUser) -> None: """Create token → retrieve metadata → use for SCIM request.""" token = ScimTokenManager.create( name="Test SCIM Token", user_performing_action=admin_user, ) assert token.raw_token is not None assert token.raw_token.startswith("onyx_scim_") assert token.is_active is True assert "****" in token.token_display # GET returns the same metadata but raw_token is None because the # server only reveals the raw token once at creation time (it stores # only the SHA-256 hash). 
active = ScimTokenManager.get_active(user_performing_action=admin_user) assert active == token.model_copy(update={"raw_token": None}) # Token works for SCIM requests response = ScimClient.get("/Users", token.raw_token) assert response.status_code == 200 body = response.json() assert "Resources" in body assert body["totalResults"] >= 0 def test_scim_token_rotation_revokes_previous(admin_user: DATestUser) -> None: """Creating a new token automatically revokes the previous one.""" first = ScimTokenManager.create( name="First Token", user_performing_action=admin_user, ) assert first.raw_token is not None response = ScimClient.get("/Users", first.raw_token) assert response.status_code == 200 # Create second token — should revoke first second = ScimTokenManager.create( name="Second Token", user_performing_action=admin_user, ) assert second.raw_token is not None # Active token should now be the second one active = ScimTokenManager.get_active(user_performing_action=admin_user) assert active == second.model_copy(update={"raw_token": None}) # First token rejected, second works assert ScimClient.get("/Users", first.raw_token).status_code == 401 assert ScimClient.get("/Users", second.raw_token).status_code == 200 def test_scim_request_without_token_rejected( admin_user: DATestUser, # noqa: ARG001 ) -> None: """SCIM endpoints reject requests with no Authorization header.""" assert ScimClient.get_no_auth("/Users").status_code == 401 def test_scim_request_with_bad_token_rejected( admin_user: DATestUser, # noqa: ARG001 ) -> None: """SCIM endpoints reject requests with an invalid token.""" assert ScimClient.get("/Users", "onyx_scim_bogus_token_value").status_code == 401 def test_non_admin_cannot_create_token( admin_user: DATestUser, # noqa: ARG001 ) -> None: """Non-admin users get 403 when trying to create a SCIM token.""" basic_user = UserManager.create(name="scim_basic_user") response = requests.post( f"{API_SERVER_URL}/admin/enterprise-settings/scim/token", json={"name": "Should 
Fail"}, headers=basic_user.headers, timeout=60, ) assert response.status_code == 403 def test_non_admin_cannot_get_token( admin_user: DATestUser, # noqa: ARG001 ) -> None: """Non-admin users get 403 when trying to retrieve SCIM token metadata.""" basic_user = UserManager.create(name="scim_basic_user2") response = requests.get( f"{API_SERVER_URL}/admin/enterprise-settings/scim/token", headers=basic_user.headers, timeout=60, ) assert response.status_code == 403 def test_no_active_token_returns_404(new_admin_user: DATestUser) -> None: """GET active token returns 404 when no token exists.""" # new_admin_user depends on the reset fixture, ensuring a clean DB # with no active SCIM tokens. active = ScimTokenManager.get_active(user_performing_action=new_admin_user) assert active is None response = requests.get( f"{API_SERVER_URL}/admin/enterprise-settings/scim/token", headers=new_admin_user.headers, timeout=60, ) assert response.status_code == 404 def test_service_discovery_no_auth_required( admin_user: DATestUser, # noqa: ARG001 ) -> None: """Service discovery endpoints work without any authentication.""" for path in ["/ServiceProviderConfig", "/ResourceTypes", "/Schemas"]: response = ScimClient.get_no_auth(path) assert response.status_code == 200, f"{path} returned {response.status_code}" def test_last_used_at_updated_after_scim_request( admin_user: DATestUser, ) -> None: """last_used_at timestamp is updated after using the token.""" token = ScimTokenManager.create( name="Last Used Token", user_performing_action=admin_user, ) assert token.raw_token is not None active = ScimTokenManager.get_active(user_performing_action=admin_user) assert active is not None assert active.last_used_at is None # Make a SCIM request, then verify last_used_at is set assert ScimClient.get("/Users", token.raw_token).status_code == 200 time.sleep(0.5) active_after = ScimTokenManager.get_active(user_performing_action=admin_user) assert active_after is not None assert active_after.last_used_at is 
not None ================================================ FILE: backend/tests/integration/tests/scim/test_scim_users.py ================================================ """Integration tests for SCIM user provisioning endpoints. Covers the full user lifecycle as driven by an IdP (Okta / Azure AD): 1. Create a user via POST /Users 2. Retrieve a user via GET /Users/{id} 3. List, filter, and paginate users via GET /Users 4. Replace a user via PUT /Users/{id} 5. Patch a user (deactivate/reactivate) via PATCH /Users/{id} 6. Delete a user via DELETE /Users/{id} 7. Edge and error cases: missing externalId (accepted, since externalId is optional), duplicate email, not-found, seat limit All tests are parameterized across IdP request styles: - **Okta**: lowercase PATCH ops, minimal payloads (core schema only). - **Entra**: capitalized ops (``"Replace"``), enterprise extension data (department, manager), and structured email arrays. The server normalizes both; these tests verify that all IdP-specific fields are accepted and round-tripped correctly. Auth, revoked-token, and service-discovery tests live in test_scim_tokens.py.
""" from datetime import datetime from datetime import timedelta from datetime import timezone import pytest import redis import requests from ee.onyx.server.license.models import LicenseMetadata from ee.onyx.server.license.models import LicenseSource from ee.onyx.server.license.models import PlanType from onyx.auth.schemas import UserRole from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PORT from onyx.db.enums import AccountType from onyx.server.settings.models import ApplicationStatus from tests.integration.common_utils.constants import ADMIN_USER_NAME from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.scim_client import ScimClient from tests.integration.common_utils.managers.scim_token import ScimTokenManager from tests.integration.common_utils.managers.user import build_email from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User" SCIM_ENTERPRISE_USER_SCHEMA = ( "urn:ietf:params:scim:schemas:extension:enterprise:2.0:User" ) SCIM_PATCH_SCHEMA = "urn:ietf:params:scim:api:messages:2.0:PatchOp" _LICENSE_REDIS_KEY = "public:license:metadata" @pytest.fixture(scope="module", params=["okta", "entra"]) def idp_style(request: pytest.FixtureRequest) -> str: """Parameterized IdP style — runs every test with both Okta and Entra request formats.""" return request.param @pytest.fixture(scope="module") def scim_token(idp_style: str) -> str: """Create a single SCIM token shared across all tests in this module. Creating a new token revokes the previous one, so we create exactly once per IdP-style run and reuse. 
Uses UserManager directly to avoid fixture-scope conflicts with the function-scoped admin_user fixture. """ from tests.integration.common_utils.constants import ADMIN_USER_NAME from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.user import build_email from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser try: admin = UserManager.create(name=ADMIN_USER_NAME) except Exception: admin = UserManager.login_as_user( DATestUser( id="", email=build_email(ADMIN_USER_NAME), password=DEFAULT_PASSWORD, headers=GENERAL_HEADERS, role=UserRole.ADMIN, is_active=True, ) ) token = ScimTokenManager.create( name=f"scim-user-tests-{idp_style}", user_performing_action=admin, ).raw_token assert token is not None return token def _make_user_resource( email: str, external_id: str, given_name: str = "Test", family_name: str = "User", active: bool = True, idp_style: str = "okta", department: str | None = None, manager_id: str | None = None, ) -> dict: """Build a SCIM UserResource payload appropriate for the IdP style. Entra sends richer payloads including enterprise extension data (department, manager), structured email arrays, and the enterprise schema URN. Okta sends minimal payloads with just core user fields. 
""" resource: dict = { "schemas": [SCIM_USER_SCHEMA], "userName": email, "externalId": external_id, "name": { "givenName": given_name, "familyName": family_name, }, "active": active, } if idp_style == "entra": dept = department or "Engineering" mgr = manager_id or "mgr-ext-001" resource["schemas"].append(SCIM_ENTERPRISE_USER_SCHEMA) resource[SCIM_ENTERPRISE_USER_SCHEMA] = { "department": dept, "manager": {"value": mgr}, } resource["emails"] = [ {"value": email, "type": "work", "primary": True}, ] return resource def _make_patch_request(operations: list[dict], idp_style: str = "okta") -> dict: """Build a SCIM PatchOp payload, applying IdP-specific operation casing. Entra sends capitalized operations (e.g. ``"Replace"`` instead of ``"replace"``). The server's ``normalize_operation`` validator lowercases them — these tests verify that both casings are accepted. """ cased_operations = [] for operation in operations: cased = dict(operation) if idp_style == "entra": cased["op"] = operation["op"].capitalize() cased_operations.append(cased) return { "schemas": [SCIM_PATCH_SCHEMA], "Operations": cased_operations, } def _create_scim_user( token: str, email: str, external_id: str, idp_style: str = "okta", ) -> requests.Response: return ScimClient.post( "/Users", token, json=_make_user_resource(email, external_id, idp_style=idp_style), ) def _assert_entra_extension( body: dict, expected_department: str = "Engineering", expected_manager: str = "mgr-ext-001", ) -> None: """Assert that Entra enterprise extension fields round-tripped correctly.""" assert SCIM_ENTERPRISE_USER_SCHEMA in body["schemas"] ext = body[SCIM_ENTERPRISE_USER_SCHEMA] assert ext["department"] == expected_department assert ext["manager"]["value"] == expected_manager def _assert_entra_emails(body: dict, expected_email: str) -> None: """Assert that structured email metadata round-tripped correctly.""" emails = body["emails"] assert len(emails) >= 1 work_email = next(e for e in emails if e.get("type") == "work") 
assert work_email["value"] == expected_email assert work_email["primary"] is True # ------------------------------------------------------------------ # Lifecycle: create -> get -> list -> replace -> patch -> delete # ------------------------------------------------------------------ def test_create_user(scim_token: str, idp_style: str) -> None: """POST /Users creates a provisioned user and returns 201.""" email = f"scim_create_{idp_style}@example.com" ext_id = f"ext-create-{idp_style}" resp = _create_scim_user(scim_token, email, ext_id, idp_style) assert resp.status_code == 201 body = resp.json() assert body["userName"] == email assert body["externalId"] == ext_id assert body["active"] is True assert body["id"] # UUID assigned by server assert body["meta"]["resourceType"] == "User" assert body["name"]["givenName"] == "Test" assert body["name"]["familyName"] == "User" if idp_style == "entra": _assert_entra_extension(body) _assert_entra_emails(body, email) def test_create_user_default_group_and_account_type( scim_token: str, idp_style: str ) -> None: """SCIM-provisioned users get Basic default group and STANDARD account_type.""" email = f"scim_defaults_{idp_style}@example.com" ext_id = f"ext-defaults-{idp_style}" resp = _create_scim_user(scim_token, email, ext_id, idp_style) assert resp.status_code == 201 user_id = resp.json()["id"] # --- Verify group assignment via SCIM GET --- get_resp = ScimClient.get(f"/Users/{user_id}", scim_token) assert get_resp.status_code == 200 groups = get_resp.json().get("groups", []) group_names = {g["display"] for g in groups} assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}" assert "Admin" not in group_names, "SCIM user should not be in Admin group" # --- Verify account_type via admin API --- admin = UserManager.login_as_user( DATestUser( id="", email=build_email(ADMIN_USER_NAME), password=DEFAULT_PASSWORD, headers=GENERAL_HEADERS, role=UserRole.ADMIN, is_active=True, ) ) page = UserManager.get_user_page( 
user_performing_action=admin, search_query=email, ) assert page.total_items >= 1 scim_user_snapshot = next((u for u in page.items if u.email == email), None) assert ( scim_user_snapshot is not None ), f"SCIM user {email} not found in user listing" assert ( scim_user_snapshot.account_type == AccountType.STANDARD ), f"Expected STANDARD, got {scim_user_snapshot.account_type}" def test_get_user(scim_token: str, idp_style: str) -> None: """GET /Users/{id} returns the user resource with all stored fields.""" email = f"scim_get_{idp_style}@example.com" ext_id = f"ext-get-{idp_style}" created = _create_scim_user(scim_token, email, ext_id, idp_style).json() resp = ScimClient.get(f"/Users/{created['id']}", scim_token) assert resp.status_code == 200 body = resp.json() assert body["id"] == created["id"] assert body["userName"] == email assert body["externalId"] == ext_id assert body["name"]["givenName"] == "Test" assert body["name"]["familyName"] == "User" if idp_style == "entra": _assert_entra_extension(body) _assert_entra_emails(body, email) def test_list_users(scim_token: str, idp_style: str) -> None: """GET /Users returns a ListResponse containing provisioned users.""" email = f"scim_list_{idp_style}@example.com" _create_scim_user(scim_token, email, f"ext-list-{idp_style}", idp_style) resp = ScimClient.get("/Users", scim_token) assert resp.status_code == 200 body = resp.json() assert body["totalResults"] >= 1 emails = [r["userName"] for r in body["Resources"]] assert email in emails def test_list_users_pagination(scim_token: str, idp_style: str) -> None: """GET /Users with startIndex and count returns correct pagination.""" _create_scim_user( scim_token, f"scim_page1_{idp_style}@example.com", f"ext-page-1-{idp_style}", idp_style, ) _create_scim_user( scim_token, f"scim_page2_{idp_style}@example.com", f"ext-page-2-{idp_style}", idp_style, ) resp = ScimClient.get("/Users?startIndex=1&count=1", scim_token) assert resp.status_code == 200 body = resp.json() assert 
body["startIndex"] == 1 assert body["itemsPerPage"] == 1 assert body["totalResults"] >= 2 assert len(body["Resources"]) == 1 def test_filter_users_by_username(scim_token: str, idp_style: str) -> None: """GET /Users?filter=userName eq '...' returns only matching users.""" email = f"scim_filter_{idp_style}@example.com" _create_scim_user(scim_token, email, f"ext-filter-{idp_style}", idp_style) resp = ScimClient.get(f'/Users?filter=userName eq "{email}"', scim_token) assert resp.status_code == 200 body = resp.json() assert body["totalResults"] == 1 assert body["Resources"][0]["userName"] == email def test_replace_user(scim_token: str, idp_style: str) -> None: """PUT /Users/{id} replaces the user resource including enterprise fields.""" email = f"scim_replace_{idp_style}@example.com" ext_id = f"ext-replace-{idp_style}" created = _create_scim_user(scim_token, email, ext_id, idp_style).json() updated_resource = _make_user_resource( email=email, external_id=ext_id, given_name="Updated", family_name="Name", idp_style=idp_style, department="Product", ) resp = ScimClient.put(f"/Users/{created['id']}", scim_token, json=updated_resource) assert resp.status_code == 200 body = resp.json() assert body["name"]["givenName"] == "Updated" assert body["name"]["familyName"] == "Name" if idp_style == "entra": _assert_entra_extension(body, expected_department="Product") _assert_entra_emails(body, email) def test_patch_deactivate_user(scim_token: str, idp_style: str) -> None: """PATCH /Users/{id} with active=false deactivates the user.""" created = _create_scim_user( scim_token, f"scim_deactivate_{idp_style}@example.com", f"ext-deactivate-{idp_style}", idp_style, ).json() assert created["active"] is True resp = ScimClient.patch( f"/Users/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "active", "value": False}], idp_style ), ) assert resp.status_code == 200 assert resp.json()["active"] is False # Confirm via GET get_resp = 
ScimClient.get(f"/Users/{created['id']}", scim_token) assert get_resp.json()["active"] is False def test_patch_reactivate_user(scim_token: str, idp_style: str) -> None: """PATCH active=true reactivates a previously deactivated user.""" created = _create_scim_user( scim_token, f"scim_reactivate_{idp_style}@example.com", f"ext-reactivate-{idp_style}", idp_style, ).json() # Deactivate deactivate_resp = ScimClient.patch( f"/Users/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "active", "value": False}], idp_style ), ) assert deactivate_resp.status_code == 200 assert deactivate_resp.json()["active"] is False # Reactivate resp = ScimClient.patch( f"/Users/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "active", "value": True}], idp_style ), ) assert resp.status_code == 200 assert resp.json()["active"] is True def test_delete_user(scim_token: str, idp_style: str) -> None: """DELETE /Users/{id} deactivates and removes the SCIM mapping.""" created = _create_scim_user( scim_token, f"scim_delete_{idp_style}@example.com", f"ext-delete-{idp_style}", idp_style, ).json() resp = ScimClient.delete(f"/Users/{created['id']}", scim_token) assert resp.status_code == 204 # Second DELETE returns 404 per RFC 7644 §3.6 (mapping removed) resp2 = ScimClient.delete(f"/Users/{created['id']}", scim_token) assert resp2.status_code == 404 # ------------------------------------------------------------------ # Error cases # ------------------------------------------------------------------ def test_create_user_missing_external_id(scim_token: str, idp_style: str) -> None: """POST /Users without externalId succeeds (RFC 7643: externalId is optional).""" email = f"scim_no_extid_{idp_style}@example.com" resp = ScimClient.post( "/Users", scim_token, json={ "schemas": [SCIM_USER_SCHEMA], "userName": email, "active": True, }, ) assert resp.status_code == 201 body = resp.json() assert body["userName"] == email assert 
body.get("externalId") is None def test_create_user_duplicate_email(scim_token: str, idp_style: str) -> None: """POST /Users with an already-taken email returns 409.""" email = f"scim_dup_{idp_style}@example.com" resp1 = _create_scim_user(scim_token, email, f"ext-dup-1-{idp_style}", idp_style) assert resp1.status_code == 201 resp2 = _create_scim_user(scim_token, email, f"ext-dup-2-{idp_style}", idp_style) assert resp2.status_code == 409 def test_get_nonexistent_user(scim_token: str) -> None: """GET /Users/{bad-id} returns 404.""" resp = ScimClient.get("/Users/00000000-0000-0000-0000-000000000000", scim_token) assert resp.status_code == 404 def test_filter_users_by_external_id(scim_token: str, idp_style: str) -> None: """GET /Users?filter=externalId eq '...' returns the matching user.""" ext_id = f"ext-unique-filter-id-{idp_style}" _create_scim_user( scim_token, f"scim_extfilter_{idp_style}@example.com", ext_id, idp_style ) resp = ScimClient.get(f'/Users?filter=externalId eq "{ext_id}"', scim_token) assert resp.status_code == 200 body = resp.json() assert body["totalResults"] == 1 assert body["Resources"][0]["externalId"] == ext_id # ------------------------------------------------------------------ # Seat-limit enforcement # ------------------------------------------------------------------ def _seed_license(r: redis.Redis, seats: int) -> None: """Write a LicenseMetadata entry into Redis with the given seat cap.""" now = datetime.now(timezone.utc) metadata = LicenseMetadata( tenant_id="public", organization_name="Test Org", seats=seats, used_seats=0, # check_seat_availability recalculates from DB plan_type=PlanType.ANNUAL, issued_at=now, expires_at=now + timedelta(days=365), status=ApplicationStatus.ACTIVE, source=LicenseSource.MANUAL_UPLOAD, ) r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300) def test_create_user_seat_limit(scim_token: str, idp_style: str) -> None: """POST /Users returns 403 when the seat limit is reached.""" r = 
redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) # admin_user already occupies 1 seat; cap at 1 -> full _seed_license(r, seats=1) try: resp = _create_scim_user( scim_token, f"scim_blocked_{idp_style}@example.com", f"ext-blocked-{idp_style}", idp_style, ) assert resp.status_code == 403 assert "seat" in resp.json()["detail"].lower() finally: r.delete(_LICENSE_REDIS_KEY) def test_reactivate_user_seat_limit(scim_token: str, idp_style: str) -> None: """PATCH active=true returns 403 when the seat limit is reached.""" # Create and deactivate a user (before license is seeded) created = _create_scim_user( scim_token, f"scim_reactivate_blocked_{idp_style}@example.com", f"ext-reactivate-blocked-{idp_style}", idp_style, ).json() assert created["active"] is True deactivate_resp = ScimClient.patch( f"/Users/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "active", "value": False}], idp_style ), ) assert deactivate_resp.status_code == 200 assert deactivate_resp.json()["active"] is False r = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) # Seed a 1-seat license; existing active users already fill it -> reactivation should fail _seed_license(r, seats=1) try: resp = ScimClient.patch( f"/Users/{created['id']}", scim_token, json=_make_patch_request( [{"op": "replace", "path": "active", "value": True}], idp_style ), ) assert resp.status_code == 403 assert "seat" in resp.json()["detail"].lower() finally: r.delete(_LICENSE_REDIS_KEY) ================================================ FILE: backend/tests/integration/tests/search_settings/test_search_settings.py ================================================ import requests from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.test_models import DATestLLMProvider from tests.integration.common_utils.test_models import DATestUser SEARCH_SETTINGS_URL =
f"{API_SERVER_URL}/search-settings" def _get_current_search_settings(user: DATestUser) -> dict: response = requests.get( f"{SEARCH_SETTINGS_URL}/get-current-search-settings", headers=user.headers, ) response.raise_for_status() return response.json() def _get_all_search_settings(user: DATestUser) -> dict: response = requests.get( f"{SEARCH_SETTINGS_URL}/get-all-search-settings", headers=user.headers, ) response.raise_for_status() return response.json() def _get_secondary_search_settings(user: DATestUser) -> dict | None: response = requests.get( f"{SEARCH_SETTINGS_URL}/get-secondary-search-settings", headers=user.headers, ) response.raise_for_status() return response.json() def _update_inference_settings(user: DATestUser, settings: dict) -> None: response = requests.post( f"{SEARCH_SETTINGS_URL}/update-inference-settings", json=settings, headers=user.headers, ) response.raise_for_status() def _set_new_search_settings( user: DATestUser, current_settings: dict, enable_contextual_rag: bool = False, contextual_rag_llm_name: str | None = None, contextual_rag_llm_provider: str | None = None, ) -> requests.Response: """POST to set-new-search-settings, deriving the payload from current settings.""" payload = { "model_name": current_settings["model_name"], "model_dim": current_settings["model_dim"], "normalize": current_settings["normalize"], "query_prefix": current_settings.get("query_prefix") or "", "passage_prefix": current_settings.get("passage_prefix") or "", "provider_type": current_settings.get("provider_type"), "index_name": None, "multipass_indexing": current_settings.get("multipass_indexing", False), "embedding_precision": current_settings["embedding_precision"], "reduced_dimension": current_settings.get("reduced_dimension"), "enable_contextual_rag": enable_contextual_rag, "contextual_rag_llm_name": contextual_rag_llm_name, "contextual_rag_llm_provider": contextual_rag_llm_provider, } return requests.post( f"{SEARCH_SETTINGS_URL}/set-new-search-settings", 
json=payload, headers=user.headers, ) def _cancel_new_embedding(user: DATestUser) -> None: response = requests.post( f"{SEARCH_SETTINGS_URL}/cancel-new-embedding", headers=user.headers, ) response.raise_for_status() def test_get_current_search_settings( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: """Verify that GET current search settings returns expected fields.""" settings = _get_current_search_settings(admin_user) assert "model_name" in settings assert "model_dim" in settings assert "enable_contextual_rag" in settings assert "contextual_rag_llm_name" in settings assert "contextual_rag_llm_provider" in settings assert "index_name" in settings assert "embedding_precision" in settings def test_get_all_search_settings( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: """Verify that GET all search settings returns current and secondary.""" all_settings = _get_all_search_settings(admin_user) assert "current_settings" in all_settings assert "secondary_settings" in all_settings assert all_settings["current_settings"] is not None assert "model_name" in all_settings["current_settings"] def test_get_secondary_search_settings_none_by_default( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: """Verify that no secondary search settings exist by default.""" secondary = _get_secondary_search_settings(admin_user) assert secondary is None def test_set_contextual_rag_model( reset: None, # noqa: ARG001 admin_user: DATestUser, llm_provider: DATestLLMProvider, ) -> None: """Set contextual RAG LLM model and verify it persists.""" settings = _get_current_search_settings(admin_user) settings["enable_contextual_rag"] = True settings["contextual_rag_llm_name"] = llm_provider.default_model_name settings["contextual_rag_llm_provider"] = llm_provider.name _update_inference_settings(admin_user, settings) updated = _get_current_search_settings(admin_user) assert updated["contextual_rag_llm_name"] == llm_provider.default_model_name assert 
updated["contextual_rag_llm_provider"] == llm_provider.name def test_unset_contextual_rag_model( reset: None, # noqa: ARG001 admin_user: DATestUser, llm_provider: DATestLLMProvider, ) -> None: """Set a contextual RAG model, then unset it and verify it becomes None.""" settings = _get_current_search_settings(admin_user) settings["enable_contextual_rag"] = True settings["contextual_rag_llm_name"] = llm_provider.default_model_name settings["contextual_rag_llm_provider"] = llm_provider.name _update_inference_settings(admin_user, settings) # Verify it's set updated = _get_current_search_settings(admin_user) assert updated["contextual_rag_llm_name"] == llm_provider.default_model_name assert updated["contextual_rag_llm_provider"] == llm_provider.name # Unset by disabling contextual RAG updated["enable_contextual_rag"] = False updated["contextual_rag_llm_name"] = None updated["contextual_rag_llm_provider"] = None _update_inference_settings(admin_user, updated) # Verify it's unset final = _get_current_search_settings(admin_user) assert final["contextual_rag_llm_name"] is None assert final["contextual_rag_llm_provider"] is None def test_change_contextual_rag_model( reset: None, # noqa: ARG001 admin_user: DATestUser, llm_provider: DATestLLMProvider, ) -> None: """Change contextual RAG from one model to another and verify the switch.""" second_provider = LLMProviderManager.create( name="second-provider", default_model_name="gpt-4o", user_performing_action=admin_user, ) settings = _get_current_search_settings(admin_user) settings["enable_contextual_rag"] = True settings["contextual_rag_llm_name"] = llm_provider.default_model_name settings["contextual_rag_llm_provider"] = llm_provider.name _update_inference_settings(admin_user, settings) updated = _get_current_search_settings(admin_user) assert updated["contextual_rag_llm_name"] == llm_provider.default_model_name assert updated["contextual_rag_llm_provider"] == llm_provider.name # Switch to a different model and provider 
updated["enable_contextual_rag"] = True updated["contextual_rag_llm_name"] = second_provider.default_model_name updated["contextual_rag_llm_provider"] = second_provider.name _update_inference_settings(admin_user, updated) final = _get_current_search_settings(admin_user) assert final["contextual_rag_llm_name"] == second_provider.default_model_name assert final["contextual_rag_llm_provider"] == second_provider.name def test_change_contextual_rag_provider_only( reset: None, # noqa: ARG001 admin_user: DATestUser, llm_provider: DATestLLMProvider, ) -> None: """Change only the provider while keeping the same model name.""" shared_model_name = llm_provider.default_model_name second_provider = LLMProviderManager.create( name="second-provider", default_model_name=shared_model_name, user_performing_action=admin_user, ) settings = _get_current_search_settings(admin_user) settings["enable_contextual_rag"] = True settings["contextual_rag_llm_name"] = shared_model_name settings["contextual_rag_llm_provider"] = llm_provider.name _update_inference_settings(admin_user, settings) updated = _get_current_search_settings(admin_user) updated["enable_contextual_rag"] = True updated["contextual_rag_llm_provider"] = second_provider.name _update_inference_settings(admin_user, updated) final = _get_current_search_settings(admin_user) assert final["contextual_rag_llm_name"] == shared_model_name assert final["contextual_rag_llm_provider"] == second_provider.name def test_enable_contextual_rag_preserved_on_inference_update( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: """Verify that enable_contextual_rag cannot be toggled via update-inference-settings because it is a preserved field.""" settings = _get_current_search_settings(admin_user) original_enable = settings["enable_contextual_rag"] # Attempt to flip the flag settings["enable_contextual_rag"] = not original_enable settings["contextual_rag_llm_name"] = None settings["contextual_rag_llm_provider"] = None 
    _update_inference_settings(admin_user, settings)

    updated = _get_current_search_settings(admin_user)
    assert updated["enable_contextual_rag"] == original_enable


def test_model_name_preserved_on_inference_update(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
) -> None:
    """Verify that model_name cannot be changed via update-inference-settings
    because it is a preserved field."""
    settings = _get_current_search_settings(admin_user)
    original_model_name = settings["model_name"]

    settings["model_name"] = "some-other-model"
    _update_inference_settings(admin_user, settings)

    updated = _get_current_search_settings(admin_user)
    assert updated["model_name"] == original_model_name


def test_contextual_rag_settings_reflected_in_get_all(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """Verify that contextual RAG updates appear in get-all-search-settings."""
    settings = _get_current_search_settings(admin_user)
    settings["enable_contextual_rag"] = True
    settings["contextual_rag_llm_name"] = llm_provider.default_model_name
    settings["contextual_rag_llm_provider"] = llm_provider.name
    _update_inference_settings(admin_user, settings)

    all_settings = _get_all_search_settings(admin_user)
    current = all_settings["current_settings"]
    assert current["contextual_rag_llm_name"] == llm_provider.default_model_name
    assert current["contextual_rag_llm_provider"] == llm_provider.name


def test_update_contextual_rag_nonexistent_provider(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
) -> None:
    """Updating with a provider that does not exist should return 400."""
    settings = _get_current_search_settings(admin_user)
    settings["enable_contextual_rag"] = True
    settings["contextual_rag_llm_name"] = "some-model"
    settings["contextual_rag_llm_provider"] = "nonexistent-provider"

    response = requests.post(
        f"{SEARCH_SETTINGS_URL}/update-inference-settings",
        json=settings,
        headers=admin_user.headers,
    )
    assert response.status_code == 400
    assert "Provider nonexistent-provider not found" in response.json()["detail"]


def test_update_contextual_rag_nonexistent_model(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """Updating with a valid provider but a model not in that provider should return 400."""
    settings = _get_current_search_settings(admin_user)
    settings["enable_contextual_rag"] = True
    settings["contextual_rag_llm_name"] = "nonexistent-model"
    settings["contextual_rag_llm_provider"] = llm_provider.name

    response = requests.post(
        f"{SEARCH_SETTINGS_URL}/update-inference-settings",
        json=settings,
        headers=admin_user.headers,
    )
    assert response.status_code == 400
    assert (
        f"Model nonexistent-model not found in provider {llm_provider.name}"
        in response.json()["detail"]
    )


def test_update_contextual_rag_missing_provider_name(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
) -> None:
    """Providing a model name without a provider name should return 400."""
    settings = _get_current_search_settings(admin_user)
    settings["enable_contextual_rag"] = True
    settings["contextual_rag_llm_name"] = "some-model"
    settings["contextual_rag_llm_provider"] = None

    response = requests.post(
        f"{SEARCH_SETTINGS_URL}/update-inference-settings",
        json=settings,
        headers=admin_user.headers,
    )
    assert response.status_code == 400
    assert "Provider name and model name are required" in response.json()["detail"]


def test_update_contextual_rag_missing_model_name(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """Providing a provider name without a model name should return 400."""
    settings = _get_current_search_settings(admin_user)
    settings["enable_contextual_rag"] = True
    settings["contextual_rag_llm_name"] = None
    settings["contextual_rag_llm_provider"] = llm_provider.name

    response = requests.post(
        f"{SEARCH_SETTINGS_URL}/update-inference-settings",
        json=settings,
        headers=admin_user.headers,
    )
    assert response.status_code == 400
    assert "Provider name and model name are required" in response.json()["detail"]


def test_set_new_search_settings_with_contextual_rag(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """Create new search settings with contextual RAG enabled and verify the
    secondary settings contain the correct provider and model."""
    current = _get_current_search_settings(admin_user)

    response = _set_new_search_settings(
        user=admin_user,
        current_settings=current,
        enable_contextual_rag=True,
        contextual_rag_llm_name=llm_provider.default_model_name,
        contextual_rag_llm_provider=llm_provider.name,
    )
    response.raise_for_status()
    assert "id" in response.json()

    secondary = _get_secondary_search_settings(admin_user)
    assert secondary is not None
    assert secondary["enable_contextual_rag"] is True
    assert secondary["contextual_rag_llm_name"] == llm_provider.default_model_name
    assert secondary["contextual_rag_llm_provider"] == llm_provider.name

    _cancel_new_embedding(admin_user)


def test_set_new_search_settings_without_contextual_rag(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
) -> None:
    """Create new search settings with contextual RAG disabled and verify
    the secondary settings have no RAG provider."""
    current = _get_current_search_settings(admin_user)

    response = _set_new_search_settings(
        user=admin_user,
        current_settings=current,
        enable_contextual_rag=False,
    )
    response.raise_for_status()

    secondary = _get_secondary_search_settings(admin_user)
    assert secondary is not None
    assert secondary["enable_contextual_rag"] is False
    assert secondary["contextual_rag_llm_name"] is None
    assert secondary["contextual_rag_llm_provider"] is None

    _cancel_new_embedding(admin_user)


def test_set_new_then_update_inference_settings(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """Create new secondary settings, then update the current (primary) settings
    with contextual RAG and verify both are visible through get-all."""
    current = _get_current_search_settings(admin_user)

    # Create secondary settings without contextual RAG
    response = _set_new_search_settings(
        user=admin_user,
        current_settings=current,
        enable_contextual_rag=False,
    )
    response.raise_for_status()

    # Update the *current* (primary) settings with a contextual RAG provider
    current["enable_contextual_rag"] = True
    current["contextual_rag_llm_name"] = llm_provider.default_model_name
    current["contextual_rag_llm_provider"] = llm_provider.name
    _update_inference_settings(admin_user, current)

    all_settings = _get_all_search_settings(admin_user)

    primary = all_settings["current_settings"]
    assert primary["contextual_rag_llm_name"] == llm_provider.default_model_name
    assert primary["contextual_rag_llm_provider"] == llm_provider.name

    secondary = all_settings["secondary_settings"]
    assert secondary is not None
    assert secondary["contextual_rag_llm_name"] is None
    assert secondary["contextual_rag_llm_provider"] is None

    _cancel_new_embedding(admin_user)


def test_set_new_search_settings_replaces_previous_secondary(
    reset: None,  # noqa: ARG001
    admin_user: DATestUser,
    llm_provider: DATestLLMProvider,
) -> None:
    """Calling set-new-search-settings twice should retire the first secondary
    and replace it with the second."""
    current = _get_current_search_settings(admin_user)

    # First: no contextual RAG
    resp1 = _set_new_search_settings(
        user=admin_user,
        current_settings=current,
        enable_contextual_rag=False,
    )
    resp1.raise_for_status()
    first_id = resp1.json()["id"]

    # Second: with contextual RAG
    resp2 = _set_new_search_settings(
        user=admin_user,
        current_settings=current,
        enable_contextual_rag=True,
        contextual_rag_llm_name=llm_provider.default_model_name,
        contextual_rag_llm_provider=llm_provider.name,
    )
    resp2.raise_for_status()
    second_id = resp2.json()["id"]

    assert second_id != first_id

    secondary = _get_secondary_search_settings(admin_user)
    assert secondary is not None
    assert secondary["enable_contextual_rag"] is True
    assert secondary["contextual_rag_llm_name"] == llm_provider.default_model_name
    assert secondary["contextual_rag_llm_provider"] == llm_provider.name

    _cancel_new_embedding(admin_user)

================================================
FILE: backend/tests/integration/tests/streaming_endpoints/test_chat_file_attachment.py
================================================
import mimetypes
from typing import Any

import requests

from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.file import FileManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.test_file_utils import create_test_image
from tests.integration.common_utils.test_file_utils import create_test_text_file
from tests.integration.common_utils.test_models import DATestUser


def test_send_message_with_image_attachment(admin_user: DATestUser) -> None:
    """Test sending a chat message with an attached image file."""
    LLMProviderManager.create(user_performing_action=admin_user)

    # Create a simple test image
    image_file = create_test_image(width=100, height=100, color="blue")

    # Upload the image file
    file_descriptors, error = FileManager.upload_files(
        files=[("test_image.png", image_file)],
        user_performing_action=admin_user,
    )

    assert not error, f"File upload should succeed, but got error: {error}"
    assert len(file_descriptors) == 1, "Should have uploaded one file"
    assert file_descriptors[0]["type"] == "image", "File should be identified as image"

    # Create a chat session
    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)

    # Send a message with the image attachment
    response = ChatSessionManager.send_message(
        chat_session_id=test_chat_session.id,
        message="What color is this image?",
        user_performing_action=admin_user,
        file_descriptors=file_descriptors,
    )

    # Verify that the message was processed successfully
    assert response.error is None, "Chat response should not have an error"
    assert (
        "blue" in response.full_message.lower()
    ), "Chat response should contain the color of the image"


def test_send_message_with_text_file_attachment(admin_user: DATestUser) -> None:
    """Test sending a chat message with an attached text file."""
    LLMProviderManager.create(user_performing_action=admin_user)

    # Create a simple test text file
    text_file = create_test_text_file(
        "This is a test document.\nIt has multiple lines.\nThis is the third line."
    )

    # Upload the text file
    file_descriptors, error = FileManager.upload_files(
        files=[("test_document.txt", text_file)],
        user_performing_action=admin_user,
    )

    assert not error, f"File upload should succeed, but got error: {error}"
    assert len(file_descriptors) == 1, "Should have uploaded one file"
    assert file_descriptors[0]["type"] in [
        "plain_text",
        "document",
    ], "File should be identified as text or document"

    # Create a chat session
    test_chat_session = ChatSessionManager.create(user_performing_action=admin_user)

    # Send a message with the text file attachment
    response = ChatSessionManager.send_message(
        chat_session_id=test_chat_session.id,
        message="Repeat the contents of this file word for word.",
        user_performing_action=admin_user,
        file_descriptors=file_descriptors,
    )

    # Verify that the message was processed successfully
    assert response.error is None, "Chat response should not have an error"
    assert (
        "third line" in response.full_message.lower()
    ), "Chat response should contain the contents of the file"


def _set_token_threshold(admin_user: DATestUser, threshold_k: int) -> None:
    """Set the file token count threshold via admin settings API."""
    response = requests.put(
        f"{API_SERVER_URL}/admin/settings",
        json={"file_token_count_threshold_k": threshold_k},
        headers=admin_user.headers,
    )
    response.raise_for_status()


def _upload_raw(
    filename: str,
    content: bytes,
    user: DATestUser,
) -> dict[str, Any]:
    """Upload a file and return the full JSON response (user_files + rejected_files)."""
    mime_type, _ = mimetypes.guess_type(filename)
    headers = user.headers.copy()
    headers.pop("Content-Type", None)

    response = requests.post(
        f"{API_SERVER_URL}/user/projects/file/upload",
        files=[("files", (filename, content, mime_type or "application/octet-stream"))],
        headers=headers,
    )
    response.raise_for_status()
    return response.json()


def test_csv_over_token_threshold_uploaded_not_indexed(
    admin_user: DATestUser,
) -> None:
    """CSV exceeding token threshold is uploaded (accepted) but skips indexing."""
    _set_token_threshold(admin_user, threshold_k=1)
    try:
        # ~2000 tokens with default tokenizer, well over 1K threshold
        content = ("x " * 100 + "\n") * 20
        result = _upload_raw("large.csv", content.encode(), admin_user)

        assert len(result["user_files"]) == 1, "CSV should be accepted"
        assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
        assert (
            result["user_files"][0]["status"] == "SKIPPED"
        ), "CSV over threshold should be SKIPPED (uploaded but not indexed)"
        assert (
            result["user_files"][0]["chunk_count"] is None
        ), "Skipped file should have no chunks"
    finally:
        _set_token_threshold(admin_user, threshold_k=200)


def test_csv_under_token_threshold_uploaded_and_indexed(
    admin_user: DATestUser,
) -> None:
    """CSV under token threshold is uploaded and queued for indexing."""
    _set_token_threshold(admin_user, threshold_k=200)
    try:
        content = "col1,col2\na,b\n"
        result = _upload_raw("small.csv", content.encode(), admin_user)

        assert len(result["user_files"]) == 1, "CSV should be accepted"
        assert len(result["rejected_files"]) == 0, "CSV should not be rejected"
        assert (
            result["user_files"][0]["status"] == "PROCESSING"
        ), "CSV under threshold should be PROCESSING (queued for indexing)"
    finally:
        _set_token_threshold(admin_user, threshold_k=200)


def test_txt_over_token_threshold_rejected(
    admin_user: DATestUser,
) -> None:
    """Non-exempt file exceeding token threshold is rejected entirely."""
    _set_token_threshold(admin_user, threshold_k=1)
    try:
        # ~2000 tokens, well over 1K threshold. Unlike CSV, .txt is not
        # exempt from the threshold so the file should be rejected.
        content = ("x " * 100 + "\n") * 20
        result = _upload_raw("big.txt", content.encode(), admin_user)

        assert len(result["user_files"]) == 0, "File should not be accepted"
        assert len(result["rejected_files"]) == 1, "File should be rejected"
        assert "token limit" in result["rejected_files"][0]["reason"].lower()
    finally:
        _set_token_threshold(admin_user, threshold_k=200)

================================================
FILE: backend/tests/integration/tests/streaming_endpoints/test_chat_stream.py
================================================
import time

from onyx.configs.constants import MessageType
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.conftest import DocumentBuilderType

TERMINATED_RESPONSE_MESSAGE = (
    "Response was terminated prior to completion, try regenerating."
)

LOADING_RESPONSE_MESSAGE = "Message is loading... Please refresh the page soon."
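

# Hypothetical helper, not part of the original test module: a minimal sketch of
# the poll-until-done pattern that test_send_message_disconnect_and_cleanup
# implements inline. The name `_poll_until`, its parameters, and the use of
# time.monotonic() for the deadline are illustrative assumptions.
import time
from typing import Callable, TypeVar

T = TypeVar("T")


def _poll_until(
    fetch: Callable[[], T],
    is_done: Callable[[T], bool],
    timeout_s: float = 60.0,
    interval_s: float = 1.0,
) -> T:
    """Call `fetch` until `is_done` accepts its result or `timeout_s` elapses.

    Returns the last fetched value so the caller can assert on it either way,
    mirroring how the disconnect test asserts on `msg` after its loop.
    """
    deadline = time.monotonic() + timeout_s
    value = fetch()
    while not is_done(value) and time.monotonic() < deadline:
        time.sleep(interval_s)
        value = fetch()
    return value


# Usage sketch (assumption): is_done would reject the two placeholder strings,
# e.g. lambda m: m not in (TERMINATED_RESPONSE_MESSAGE, LOADING_RESPONSE_MESSAGE)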
def test_send_two_messages(basic_user: DATestUser) -> None: # Create a chat session test_chat_session = ChatSessionManager.create( persona_id=0, # Use default persona description="Test chat session for multiple messages", user_performing_action=basic_user, ) # Send a message to create some data response = ChatSessionManager.send_message( chat_session_id=test_chat_session.id, message="hello", user_performing_action=basic_user, ) # Verify that the message was processed successfully assert response.error is None, "Chat response should not have an error" assert len(response.full_message) > 0, "Chat response should not be empty" # Verify that the chat session can be retrieved before deletion chat_history = ChatSessionManager.get_chat_history( chat_session=test_chat_session, user_performing_action=basic_user, ) assert ( len(chat_history) == 3 ), "Chat session should have 1 system message, 1 user message, and 1 assistant message" response2 = ChatSessionManager.send_message( chat_session_id=test_chat_session.id, message="hello again", user_performing_action=basic_user, parent_message_id=response.assistant_message_id, ) assert response2.error is None, "Chat response should not have an error" assert len(response2.full_message) > 0, "Chat response should not be empty" # Verify that the chat session can be retrieved before deletion chat_history2 = ChatSessionManager.get_chat_history( chat_session=test_chat_session, user_performing_action=basic_user, ) assert ( len(chat_history2) == 5 ), "Chat session should have 1 system message, 2 user messages, and 2 assistant messages" def test_send_message_simple_with_history( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: LLMProviderManager.create(user_performing_action=admin_user) test_chat_session = ChatSessionManager.create(user_performing_action=admin_user) response = ChatSessionManager.send_message( chat_session_id=test_chat_session.id, message="this is a test message", user_performing_action=admin_user, ) assert 
response.error is None, "Chat response should not have an error" assert len(response.full_message) > 0 def test_send_message__basic_searches( reset: None, # noqa: ARG001 admin_user: DATestUser, document_builder: DocumentBuilderType, ) -> None: MESSAGE = "run a search for 'test'. Use the internal search tool." SHORT_DOC_CONTENT = "test" LONG_DOC_CONTENT = "blah blah blah blah" * 100 LLMProviderManager.create(user_performing_action=admin_user) short_doc = document_builder([SHORT_DOC_CONTENT])[0] test_chat_session = ChatSessionManager.create(user_performing_action=admin_user) response = ChatSessionManager.send_message( chat_session_id=test_chat_session.id, message=MESSAGE, user_performing_action=admin_user, ) assert response.error is None, "Chat response should not have an error" assert response.top_documents is not None assert len(response.top_documents) == 1 assert response.top_documents[0].document_id == short_doc.id # make sure this doc is really long so that it will be split into multiple chunks long_doc = document_builder([LONG_DOC_CONTENT])[0] # new chat session for simplicity test_chat_session = ChatSessionManager.create(user_performing_action=admin_user) response = ChatSessionManager.send_message( chat_session_id=test_chat_session.id, message=MESSAGE, user_performing_action=admin_user, ) assert response.error is None, "Chat response should not have an error" assert response.top_documents is not None assert len(response.top_documents) == 2 # short doc should be more relevant and thus first assert response.top_documents[0].document_id == short_doc.id assert response.top_documents[1].document_id == long_doc.id def test_send_message_disconnect_and_cleanup( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: """ Test that when a client disconnects mid-stream: 1. Client sends a message and disconnects after receiving just 1 packet 2. 
Client checks to see that their message ends up completed Note: There is an interim period (between disconnect and checkup) where we expect to see some sort of 'loading' message. """ LLMProviderManager.create(user_performing_action=admin_user) test_chat_session = ChatSessionManager.create(user_performing_action=admin_user) # Send a message and disconnect after receiving just 1 packet ChatSessionManager.send_message_with_disconnect( chat_session_id=test_chat_session.id, message="What are some important events that happened today?", user_performing_action=admin_user, disconnect_after_packets=1, ) # Every 5 seconds, check if we have the latest state of the chat session up to a minute increment_seconds = 1 max_seconds = 60 msg = TERMINATED_RESPONSE_MESSAGE for _ in range(max_seconds // increment_seconds): time.sleep(increment_seconds) # Get the chat history chat_history = ChatSessionManager.get_chat_history( chat_session=test_chat_session, user_performing_action=admin_user, ) # Find the assistant message assistant_message = None for chat_obj in chat_history: if chat_obj.message_type == MessageType.ASSISTANT: assistant_message = chat_obj break assert assistant_message is not None, "Assistant message should exist" msg = assistant_message.message if msg != TERMINATED_RESPONSE_MESSAGE and msg != LOADING_RESPONSE_MESSAGE: break assert ( msg != TERMINATED_RESPONSE_MESSAGE and msg != LOADING_RESPONSE_MESSAGE ), f"Assistant message should no longer be the terminated response message after cleanup, got: {msg}" ================================================ FILE: backend/tests/integration/tests/tags/test_tags.py ================================================ from onyx.configs.constants import DocumentSource from onyx.connectors.models import InputType from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import Document from onyx.db.tag import get_structured_tags_for_document from tests.integration.common_utils.managers.api_key import 
APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.document import DocumentManager from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser def test_tag_creation_and_update(reset: None) -> None: # noqa: ARG001 # create admin user admin_user: DATestUser = UserManager.create(email="admin@onyx.app") # create a minimal file connector cc_pair = CCPairManager.create_from_scratch( name="KG-Test-FileConnector", source=DocumentSource.FILE, input_type=InputType.LOAD_STATE, connector_specific_config={ "file_locations": [], "file_names": [], "zip_metadata_file_id": None, }, user_performing_action=admin_user, ) api_key = APIKeyManager.create(user_performing_action=admin_user) api_key.headers.update(admin_user.headers) LLMProviderManager.create(user_performing_action=admin_user) # create document doc1_expected_metadata: dict[str, str | list[str]] = { "value": "val", "multiple_list": ["a", "b", "c"], "single_list": ["x"], } doc1_expected_tags: set[tuple[str, str, bool]] = { ("value", "val", False), ("multiple_list", "a", True), ("multiple_list", "b", True), ("multiple_list", "c", True), ("single_list", "x", True), } doc1 = DocumentManager.seed_doc_with_content( cc_pair=cc_pair, content="Dummy content", document_id="doc1", metadata=doc1_expected_metadata, api_key=api_key, ) # these are added by the connector doc1_expected_metadata["document_id"] = "doc1" doc1_expected_tags.add(("document_id", "doc1", False)) # get document from db with get_session_with_current_tenant() as db_session: doc1_db = db_session.query(Document).filter(Document.id == doc1.id).first() assert doc1_db is not None assert doc1_db.id == doc1.id doc1_tags = doc1_db.tags # check tags doc1_tags_data: set[tuple[str, str, bool]] = { (tag.tag_key, tag.tag_value, tag.is_list) for 
tag in doc1_tags } assert doc1_tags_data == doc1_expected_tags # check structured tags with get_session_with_current_tenant() as db_session: doc1_metadata = get_structured_tags_for_document(doc1.id, db_session) assert doc1_metadata == doc1_expected_metadata # update metadata doc1_new_expected_metadata: dict[str, str | list[str]] = { "value": "val2", "multiple_list": ["a", "d"], "new_value": "new_val", } doc1_new_expected_tags: set[tuple[str, str, bool]] = { ("value", "val2", False), ("multiple_list", "a", True), ("multiple_list", "d", True), ("new_value", "new_val", False), } doc1_new = DocumentManager.seed_doc_with_content( cc_pair=cc_pair, content="Dummy content", document_id="doc1", metadata=doc1_new_expected_metadata, api_key=api_key, ) assert doc1_new.id == doc1.id # these are added by the connector doc1_new_expected_metadata["document_id"] = "doc1" doc1_new_expected_tags.add(("document_id", "doc1", False)) # get new document from db with get_session_with_current_tenant() as db_session: doc1_new_db = db_session.query(Document).filter(Document.id == doc1.id).first() assert doc1_new_db is not None assert doc1_new_db.id == doc1.id doc1_new_tags = doc1_new_db.tags # check tags doc1_new_tags_data: set[tuple[str, str, bool]] = { (tag.tag_key, tag.tag_value, tag.is_list) for tag in doc1_new_tags } assert doc1_new_tags_data == doc1_new_expected_tags # check structured tags with get_session_with_current_tenant() as db_session: doc1_new_metadata = get_structured_tags_for_document(doc1.id, db_session) assert doc1_new_metadata == doc1_new_expected_metadata def test_tag_sharing(reset: None) -> None: # noqa: ARG001 # create admin user admin_user: DATestUser = UserManager.create(email="admin@onyx.app") # create a minimal file connector cc_pair = CCPairManager.create_from_scratch( name="KG-Test-FileConnector", source=DocumentSource.FILE, input_type=InputType.LOAD_STATE, connector_specific_config={ "file_locations": [], "file_names": [], "zip_metadata_file_id": None, }, 
user_performing_action=admin_user, ) api_key = APIKeyManager.create(user_performing_action=admin_user) api_key.headers.update(admin_user.headers) LLMProviderManager.create(user_performing_action=admin_user) # create documents doc1_expected_metadata: dict[str, str | list[str]] = { "value": "val", "list": ["a", "b"], "same_key": "x", } doc1_expected_tags: set[tuple[str, str, bool]] = { ("value", "val", False), ("list", "a", True), ("list", "b", True), ("same_key", "x", False), } doc1 = DocumentManager.seed_doc_with_content( cc_pair=cc_pair, content="Dummy content", document_id="doc1", metadata=doc1_expected_metadata, api_key=api_key, ) doc2_expected_metadata: dict[str, str | list[str]] = { "value": "val", "list": ["a", "c"], "same_key": ["x"], } doc2_expected_tags: set[tuple[str, str, bool]] = { ("value", "val", False), ("list", "a", True), ("list", "c", True), ("same_key", "x", True), } doc2 = DocumentManager.seed_doc_with_content( cc_pair=cc_pair, content="Dummy content", document_id="doc2", metadata=doc2_expected_metadata, api_key=api_key, ) # these are added by the connector doc1_expected_metadata["document_id"] = "doc1" doc1_expected_tags.add(("document_id", "doc1", False)) doc2_expected_metadata["document_id"] = "doc2" doc2_expected_tags.add(("document_id", "doc2", False)) # get documents from db with get_session_with_current_tenant() as db_session: doc1_db = db_session.query(Document).filter(Document.id == doc1.id).first() doc2_db = db_session.query(Document).filter(Document.id == doc2.id).first() assert doc1_db is not None assert doc1_db.id == doc1.id assert doc2_db is not None assert doc2_db.id == doc2.id doc1_tags = doc1_db.tags doc2_tags = doc2_db.tags # check tags doc1_tags_data: set[tuple[str, str, bool]] = { (tag.tag_key, tag.tag_value, tag.is_list) for tag in doc1_tags } assert doc1_tags_data == doc1_expected_tags doc2_tags_data: set[tuple[str, str, bool]] = { (tag.tag_key, tag.tag_value, tag.is_list) for tag in doc2_tags } assert doc2_tags_data == 
doc2_expected_tags # check tag sharing doc1_tagkv_id: dict[tuple[str, str], int] = { (tag.tag_key, tag.tag_value): tag.id for tag in doc1_tags } doc2_tagkv_id: dict[tuple[str, str], int] = { (tag.tag_key, tag.tag_value): tag.id for tag in doc2_tags } assert doc1_tagkv_id[("value", "val")] == doc2_tagkv_id[("value", "val")] assert doc1_tagkv_id[("list", "a")] == doc2_tagkv_id[("list", "a")] assert doc1_tagkv_id[("same_key", "x")] != doc2_tagkv_id[("same_key", "x")] ================================================ FILE: backend/tests/integration/tests/tools/test_force_tool_use.py ================================================ """ Integration test for forced tool use, verifying that a tool (here, the image generation tool) can be forced. This test verifies that forcing tool use works through the complete API flow. """ import pytest from sqlalchemy import select from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.models import Tool from tests.integration.common_utils.managers.chat import ChatSessionManager from tests.integration.common_utils.test_models import DATestImageGenerationConfig from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.test_models import ToolName def test_force_tool_use( basic_user: DATestUser, image_generation_config: DATestImageGenerationConfig, # noqa: ARG001 ) -> None: with get_session_with_current_tenant() as db_session: image_generation_tool = db_session.execute( select(Tool).where(Tool.in_code_tool_id == "ImageGenerationTool") ).scalar_one_or_none() assert image_generation_tool is not None, "ImageGenerationTool must exist" image_generation_tool_id = image_generation_tool.id # Create a chat session chat_session = ChatSessionManager.create(user_performing_action=basic_user) # Send a simple message that wouldn't normally trigger image generation # but force the image generation tool to be used message = "hi" analyzed_response = ChatSessionManager.send_message( chat_session_id=chat_session.id,
message=message, user_performing_action=basic_user, forced_tool_ids=[image_generation_tool_id], ) assert analyzed_response.error is None, "Chat response should not have an error" image_generation_tool_used = any( tool.tool_name == ToolName.IMAGE_GENERATION for tool in analyzed_response.used_tools ) assert ( image_generation_tool_used ), "Image generation tool should have been forced to run" if __name__ == "__main__": # Run with: python -m dotenv -f .vscode/.env run -- # python -m pytest backend/tests/integration/tests/tools/test_force_tool_use.py -v -s pytest.main([__file__, "-v", "-s"]) ================================================ FILE: backend/tests/integration/tests/tools/test_image_generation_streaming.py ================================================ """ Integration test for image generation heartbeat streaming through the /send-message API. This test verifies that heartbeat packets are properly streamed through the complete API flow. """ import time import pytest from onyx.server.query_and_chat.streaming_models import StreamingType from onyx.tools.tool_implementations.images.image_generation_tool import ( HEARTBEAT_INTERVAL, ) from tests.integration.common_utils.managers.chat import ChatSessionManager from tests.integration.common_utils.test_models import DATestImageGenerationConfig from tests.integration.common_utils.test_models import DATestLLMProvider from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.test_models import ToolName ART_PERSONA_ID = -3 def test_image_generation_streaming( basic_user: DATestUser, llm_provider: DATestLLMProvider, # noqa: ARG001 image_generation_config: DATestImageGenerationConfig, # noqa: ARG001 ) -> None: """ Test image generation to verify: 1. The image generation tool is invoked successfully 2. Heartbeat packets are streamed during generation 3. The response contains the generated image information This test uses the actual API without any mocking. 
""" # Create a chat session chat_session = ChatSessionManager.create(user_performing_action=basic_user) # Send a message that should trigger image generation # Use explicit instructions to ensure the image generation tool is used message = "Please generate an image of a beautiful sunset over the ocean. Use the image generation tool to create this image." start_time = time.monotonic() analyzed_response = ChatSessionManager.send_message( chat_session_id=chat_session.id, message=message, user_performing_action=basic_user, ) total_time = time.monotonic() - start_time assert analyzed_response.error is None, "Chat response should not have an error" # 1. Check if image generation tool was used image_gen_used = any( tool.tool_name == ToolName.IMAGE_GENERATION for tool in analyzed_response.used_tools ) assert image_gen_used # 2. Verify we received heartbeat packets during image generation # Image generation typically takes a few seconds and sends heartbeats # every HEARTBEAT_INTERVAL seconds expected_heartbeat_packets = max(1, int(total_time / HEARTBEAT_INTERVAL) - 1) assert len(analyzed_response.heartbeat_packets) >= expected_heartbeat_packets, ( f"Expected at least {expected_heartbeat_packets} heartbeats for {total_time:.2f}s execution, " f"but got {len(analyzed_response.heartbeat_packets)}" ) # Verify the heartbeat packets have the expected structure for packet in analyzed_response.heartbeat_packets: assert "obj" in packet, "Heartbeat packet should have 'obj' field" assert ( packet["obj"].get("type") == StreamingType.IMAGE_GENERATION_HEARTBEAT.value ), f"Expected heartbeat type to be {StreamingType.IMAGE_GENERATION_HEARTBEAT.value}, got {packet['obj'].get('type')}" # 3.
Verify image generation tool results contain actual image data image_tool_results = [ tool for tool in analyzed_response.used_tools if tool.tool_name == ToolName.IMAGE_GENERATION ] assert len(image_tool_results) > 0, "Should have image generation tool results" image_tool = image_tool_results[0] assert len(image_tool.images) > 0, "Should have generated at least one image" if __name__ == "__main__": # Run with: python -m dotenv -f .vscode/.env run -- # python -m pytest tests/integration/tests/tools/test_image_generation_streaming.py -v -s pytest.main([__file__, "-v", "-s"]) ================================================ FILE: backend/tests/integration/tests/usergroup/test_add_users_to_group.py ================================================ import os from uuid import uuid4 import pytest import requests from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.test_models import DATestUserGroup @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="User group tests are enterprise only", ) def test_add_users_to_group(reset: None) -> None: # noqa: ARG001 admin_user: DATestUser = UserManager.create(name="admin_for_add_user") user_to_add: DATestUser = UserManager.create(name="basic_user_to_add") user_group: DATestUserGroup = UserGroupManager.create( name="add-user-test-group", user_ids=[admin_user.id], user_performing_action=admin_user, ) UserGroupManager.wait_for_sync( user_performing_action=admin_user, user_groups_to_check=[user_group], ) updated_user_group = UserGroupManager.add_users( user_group=user_group, user_ids=[user_to_add.id], user_performing_action=admin_user, ) fetched_user_groups = UserGroupManager.get_all(user_performing_action=admin_user)
fetched_user_group = next( group for group in fetched_user_groups if group.id == updated_user_group.id ) fetched_user_ids = {user.id for user in fetched_user_group.users} assert admin_user.id in fetched_user_ids assert user_to_add.id in fetched_user_ids @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="User group tests are enterprise only", ) def test_add_users_to_group_invalid_user(reset: None) -> None: # noqa: ARG001 admin_user: DATestUser = UserManager.create(name="admin_for_add_user_invalid") user_group: DATestUserGroup = UserGroupManager.create( name="add-user-invalid-test-group", user_ids=[admin_user.id], user_performing_action=admin_user, ) invalid_user_id = str(uuid4()) response = requests.post( f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/add-users", json={"user_ids": [invalid_user_id]}, headers=admin_user.headers, ) assert response.status_code == 404 assert "not found" in response.text.lower() ================================================ FILE: backend/tests/integration/tests/usergroup/test_group_membership_updates_user_permissions.py ================================================ import os import pytest from onyx.db.engine.sql_engine import get_session_with_current_tenant from onyx.db.enums import Permission from onyx.db.models import PermissionGrant from onyx.db.models import UserGroup as UserGroupModel from onyx.db.permissions import recompute_permissions_for_group__no_commit from onyx.db.permissions import recompute_user_permissions__no_commit from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestUser @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="User group tests are enterprise only", ) def test_user_gets_permissions_when_added_to_group( reset: None, # noqa: 
ARG001 ) -> None: admin_user: DATestUser = UserManager.create(name="admin_for_perm_test") basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test") # basic_user starts with only "basic" from the default group initial_permissions = UserManager.get_permissions(basic_user) assert "basic" in initial_permissions assert "add:agents" not in initial_permissions # Create a new group and add basic_user group = UserGroupManager.create( name="perm-test-group", user_ids=[admin_user.id, basic_user.id], user_performing_action=admin_user, ) # Grant a non-basic permission to the group and recompute with get_session_with_current_tenant() as db_session: db_group = db_session.get(UserGroupModel, group.id) assert db_group is not None db_session.add( PermissionGrant( group_id=db_group.id, permission=Permission.ADD_AGENTS, grant_source="SYSTEM", ) ) db_session.flush() recompute_user_permissions__no_commit(basic_user.id, db_session) db_session.commit() # Verify the user gained the new permission (expanded includes read:agents) updated_permissions = UserManager.get_permissions(basic_user) assert ( "add:agents" in updated_permissions ), f"User should have 'add:agents' after group grant, got: {updated_permissions}" assert ( "read:agents" in updated_permissions ), f"User should have implied 'read:agents', got: {updated_permissions}" assert "basic" in updated_permissions @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="User group tests are enterprise only", ) def test_group_permission_change_propagates_to_all_members( reset: None, # noqa: ARG001 ) -> None: admin_user: DATestUser = UserManager.create(name="admin_propagate") user_a: DATestUser = UserManager.create(name="user_a_propagate") user_b: DATestUser = UserManager.create(name="user_b_propagate") group = UserGroupManager.create( name="propagate-test-group", user_ids=[admin_user.id, user_a.id, user_b.id], user_performing_action=admin_user, ) # Neither user 
should have add:agents yet for u in (user_a, user_b): assert "add:agents" not in UserManager.get_permissions(u) # Grant add:agents to the group, then batch-recompute with get_session_with_current_tenant() as db_session: grant = PermissionGrant( group_id=group.id, permission=Permission.ADD_AGENTS, grant_source="SYSTEM", ) db_session.add(grant) db_session.flush() recompute_permissions_for_group__no_commit(group.id, db_session) db_session.commit() # Both users should now have the permission (plus implied read:agents) for u in (user_a, user_b): perms = UserManager.get_permissions(u) assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}" assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}" # Soft-delete the grant and recompute — permission should be removed with get_session_with_current_tenant() as db_session: db_grant = ( db_session.query(PermissionGrant) .filter_by(group_id=group.id, permission=Permission.ADD_AGENTS) .first() ) assert db_grant is not None db_grant.is_deleted = True db_session.flush() recompute_permissions_for_group__no_commit(group.id, db_session) db_session.commit() for u in (user_a, user_b): perms = UserManager.get_permissions(u) assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}" ================================================ FILE: backend/tests/integration/tests/usergroup/test_new_group_gets_basic_permission.py ================================================ import os import pytest from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestUser @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="User group tests are enterprise only", ) def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001 admin_user: DATestUser = 
UserManager.create(name="admin_for_basic_perm") user_group = UserGroupManager.create( name="basic-perm-test-group", user_ids=[admin_user.id], user_performing_action=admin_user, ) permissions = UserGroupManager.get_permissions( user_group=user_group, user_performing_action=admin_user, ) assert ( "basic" in permissions ), f"New group should have 'basic' permission, got: {permissions}" ================================================ FILE: backend/tests/integration/tests/usergroup/test_user_group_deletion.py ================================================ """ This tests the deletion of a user group with the following foreign key constraints: - connector_credential_pair - user - credential - llm_provider - document_set - token_rate_limit (Not Implemented) - persona """ import os import pytest from onyx.server.documents.models import DocumentSource from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.credential import CredentialManager from tests.integration.common_utils.managers.document_set import DocumentSetManager from tests.integration.common_utils.managers.llm_provider import LLMProviderManager from tests.integration.common_utils.managers.persona import PersonaManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestCredential from tests.integration.common_utils.test_models import DATestDocumentSet from tests.integration.common_utils.test_models import DATestLLMProvider from tests.integration.common_utils.test_models import DATestPersona from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.test_models import DATestUserGroup from tests.integration.common_utils.vespa import vespa_fixture @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", 
reason="User group tests are enterprise only", ) def test_user_group_deletion( reset: None, # noqa: ARG001 vespa_client: vespa_fixture, # noqa: ARG001 ) -> None: # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") # create connectors cc_pair = CCPairManager.create_from_scratch( source=DocumentSource.INGESTION_API, user_performing_action=admin_user, ) # Create user group with a cc_pair and a user user_group: DATestUserGroup = UserGroupManager.create( user_ids=[admin_user.id], cc_pair_ids=[cc_pair.id], user_performing_action=admin_user, ) cc_pair.groups = [user_group.id] UserGroupManager.wait_for_sync( user_groups_to_check=[user_group], user_performing_action=admin_user ) UserGroupManager.verify( user_group=user_group, user_performing_action=admin_user, ) CCPairManager.verify( cc_pair=cc_pair, user_performing_action=admin_user, ) # Create other objects that are related to the user group credential: DATestCredential = CredentialManager.create( groups=[user_group.id], user_performing_action=admin_user, ) document_set: DATestDocumentSet = DocumentSetManager.create( cc_pair_ids=[cc_pair.id], groups=[user_group.id], user_performing_action=admin_user, ) llm_provider: DATestLLMProvider = LLMProviderManager.create( groups=[user_group.id], user_performing_action=admin_user, ) persona: DATestPersona = PersonaManager.create( groups=[user_group.id], user_performing_action=admin_user, ) UserGroupManager.wait_for_sync( user_groups_to_check=[user_group], user_performing_action=admin_user ) UserGroupManager.verify( user_group=user_group, user_performing_action=admin_user, ) # Delete the user group UserGroupManager.delete( user_group=user_group, user_performing_action=admin_user, ) UserGroupManager.wait_for_deletion_completion( user_groups_to_check=[user_group], user_performing_action=admin_user ) # Set our expected local representations to empty credential.groups = [] document_set.groups = [] 
llm_provider.groups = [] persona.groups = [] # Verify that the local representations were updated CredentialManager.verify( credential=credential, user_performing_action=admin_user, ) DocumentSetManager.verify( document_set=document_set, user_performing_action=admin_user, ) LLMProviderManager.verify( llm_provider=llm_provider, user_performing_action=admin_user, ) PersonaManager.verify( persona=persona, user_performing_action=admin_user, ) ================================================ FILE: backend/tests/integration/tests/usergroup/test_usergroup_syncing.py ================================================ import os import pytest from onyx.server.documents.models import DocumentSource from tests.integration.common_utils.constants import NUM_DOCS from tests.integration.common_utils.managers.api_key import APIKeyManager from tests.integration.common_utils.managers.cc_pair import CCPairManager from tests.integration.common_utils.managers.document import DocumentManager from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestAPIKey from tests.integration.common_utils.test_models import DATestUser from tests.integration.common_utils.test_models import DATestUserGroup from tests.integration.common_utils.vespa import vespa_fixture @pytest.mark.skipif( os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true", reason="User group tests are enterprise only", ) def test_removing_connector( reset: None, # noqa: ARG001 vespa_client: vespa_fixture, ) -> None: # Creating an admin user (first user created is automatically an admin) admin_user: DATestUser = UserManager.create(name="admin_user") # create api key api_key: DATestAPIKey = APIKeyManager.create( user_performing_action=admin_user, ) # create connectors cc_pair_1 = CCPairManager.create_from_scratch( source=DocumentSource.INGESTION_API, 
user_performing_action=admin_user, ) cc_pair_2 = CCPairManager.create_from_scratch( source=DocumentSource.INGESTION_API, user_performing_action=admin_user, ) # seed documents cc_pair_1.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_1, num_docs=NUM_DOCS, api_key=api_key, ) cc_pair_2.documents = DocumentManager.seed_dummy_docs( cc_pair=cc_pair_2, num_docs=NUM_DOCS, api_key=api_key, ) # Create user group user_group_1: DATestUserGroup = UserGroupManager.create( cc_pair_ids=[cc_pair_1.id, cc_pair_2.id], user_performing_action=admin_user, ) UserGroupManager.wait_for_sync( user_groups_to_check=[user_group_1], user_performing_action=admin_user ) UserGroupManager.verify( user_group=user_group_1, user_performing_action=admin_user, ) # make sure cc_pair_1 docs are user_group_1 only DocumentManager.verify( vespa_client=vespa_client, cc_pair=cc_pair_1, group_names=[user_group_1.name], doc_creating_user=admin_user, ) # make sure cc_pair_2 docs are user_group_1 only DocumentManager.verify( vespa_client=vespa_client, cc_pair=cc_pair_2, group_names=[user_group_1.name], doc_creating_user=admin_user, ) # remove cc_pair_2 from user group user_group_1.cc_pair_ids = [cc_pair_1.id] UserGroupManager.edit( user_group_1, user_performing_action=admin_user, ) UserGroupManager.wait_for_sync( user_performing_action=admin_user, ) # make sure cc_pair_1 docs are user_group_1 only DocumentManager.verify( vespa_client=vespa_client, cc_pair=cc_pair_1, group_names=[user_group_1.name], doc_creating_user=admin_user, ) # make sure cc_pair_2 docs have no user group DocumentManager.verify( vespa_client=vespa_client, cc_pair=cc_pair_2, group_names=[], doc_creating_user=admin_user, ) ================================================ FILE: backend/tests/integration/tests/users/test_default_group_assignment.py ================================================ """Integration tests for default group assignment on user registration.
Verifies that: - The first registered user is assigned to the Admin default group - Subsequent registered users are assigned to the Basic default group - account_type is set to STANDARD for email/password registrations """ from onyx.auth.schemas import UserRole from onyx.db.enums import AccountType from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestUser def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001 # Register first user — should become admin admin_user: DATestUser = UserManager.create(name="first_user") assert admin_user.role == UserRole.ADMIN # Register second user — should become basic basic_user: DATestUser = UserManager.create(name="second_user") assert basic_user.role == UserRole.BASIC # Fetch all groups including default ones all_groups = UserGroupManager.get_all( user_performing_action=admin_user, include_default=True, ) # Find the default Admin and Basic groups admin_group = next( (g for g in all_groups if g.name == "Admin" and g.is_default), None ) basic_group = next( (g for g in all_groups if g.name == "Basic" and g.is_default), None ) assert admin_group is not None, "Admin default group not found" assert basic_group is not None, "Basic default group not found" # Verify admin user is in Admin group and NOT in Basic group admin_group_user_ids = {str(u.id) for u in admin_group.users} basic_group_user_ids = {str(u.id) for u in basic_group.users} assert ( admin_user.id in admin_group_user_ids ), "First user should be in Admin default group" assert ( admin_user.id not in basic_group_user_ids ), "First user should NOT be in Basic default group" # Verify basic user is in Basic group and NOT in Admin group assert ( basic_user.id in basic_group_user_ids ), "Second user should be in Basic default group" assert ( basic_user.id not in admin_group_user_ids ), "Second user 
should NOT be in Admin default group" # Verify account_type is STANDARD for both users via user listing API paginated_result = UserManager.get_user_page( user_performing_action=admin_user, page_num=0, page_size=10, ) users_by_id = {str(u.id): u for u in paginated_result.items} admin_snapshot = users_by_id.get(admin_user.id) basic_snapshot = users_by_id.get(basic_user.id) assert admin_snapshot is not None, "Admin user not found in user listing" assert basic_snapshot is not None, "Basic user not found in user listing" assert ( admin_snapshot.account_type == AccountType.STANDARD ), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}" assert ( basic_snapshot.account_type == AccountType.STANDARD ), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}" ================================================ FILE: backend/tests/integration/tests/users/test_password_signup_upgrade.py ================================================ """Integration tests for password signup upgrade paths. 
Verifies that when an EXT_PERM_USER or SLACK_USER signs up via email/password: - Their account_type is upgraded to STANDARD - They are assigned to the Basic default group - They gain the correct effective permissions """ import pytest from onyx.auth.schemas import UserRole from onyx.db.enums import AccountType from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestUser def _get_default_group_member_emails( admin_user: DATestUser, group_name: str, ) -> set[str]: """Get the set of emails of all members in a named default group.""" all_groups = UserGroupManager.get_all(admin_user, include_default=True) matched = [g for g in all_groups if g.is_default and g.name == group_name] assert matched, f"Default group '{group_name}' not found" return {u.email for u in matched[0].users} @pytest.mark.parametrize( "target_role", [UserRole.EXT_PERM_USER, UserRole.SLACK_USER], ids=["ext_perm_user", "slack_user"], ) def test_password_signup_upgrade( reset: None, # noqa: ARG001 target_role: UserRole, ) -> None: """When a non-web user signs up via email/password, they should be upgraded to STANDARD account_type and assigned to the Basic default group.""" admin_user: DATestUser = UserManager.create(email="admin@example.com") test_email = f"{target_role.value}_upgrade@example.com" test_user = UserManager.create(email=test_email) test_user = UserManager.set_role( user_to_set=test_user, target_role=target_role, user_performing_action=admin_user, explicit_override=True, ) # Verify user was removed from Basic group after downgrade basic_emails = _get_default_group_member_emails(admin_user, "Basic") assert ( test_email not in basic_emails ), f"{target_role.value} should not be in Basic default group" # Re-register with the same email — triggers the password signup upgrade upgraded_user = UserManager.create(email=test_email) assert
upgraded_user.role == UserRole.BASIC paginated = UserManager.get_user_page( user_performing_action=admin_user, page_num=0, page_size=10, ) user_snapshot = next( (u for u in paginated.items if str(u.id) == upgraded_user.id), None ) assert user_snapshot is not None assert ( user_snapshot.account_type == AccountType.STANDARD ), f"Expected STANDARD, got {user_snapshot.account_type}" # Verify user is now in the Basic default group basic_emails = _get_default_group_member_emails(admin_user, "Basic") assert ( test_email in basic_emails ), f"Upgraded user '{test_email}' not found in Basic default group" def test_password_signup_upgrade_propagates_permissions( reset: None, # noqa: ARG001 ) -> None: """When an EXT_PERM_USER or SLACK_USER signs up via password, they should gain the 'basic' permission through the Basic default group assignment.""" admin_user: DATestUser = UserManager.create(email="admin@example.com") # --- EXT_PERM_USER path --- ext_email = "ext_perms_check@example.com" ext_user = UserManager.create(email=ext_email) initial_perms = UserManager.get_permissions(ext_user) assert "basic" in initial_perms ext_user = UserManager.set_role( user_to_set=ext_user, target_role=UserRole.EXT_PERM_USER, user_performing_action=admin_user, explicit_override=True, ) basic_emails = _get_default_group_member_emails(admin_user, "Basic") assert ext_email not in basic_emails upgraded = UserManager.create(email=ext_email) assert upgraded.role == UserRole.BASIC perms = UserManager.get_permissions(upgraded) assert ( "basic" in perms ), f"Upgraded EXT_PERM_USER should have 'basic' permission, got: {perms}" # --- SLACK_USER path --- slack_email = "slack_perms_check@example.com" slack_user = UserManager.create(email=slack_email) slack_user = UserManager.set_role( user_to_set=slack_user, target_role=UserRole.SLACK_USER, user_performing_action=admin_user, explicit_override=True, ) basic_emails = _get_default_group_member_emails(admin_user, "Basic") assert slack_email not in basic_emails 
upgraded = UserManager.create(email=slack_email) assert upgraded.role == UserRole.BASIC perms = UserManager.get_permissions(upgraded) assert ( "basic" in perms ), f"Upgraded SLACK_USER should have 'basic' permission, got: {perms}" ================================================ FILE: backend/tests/integration/tests/users/test_reactivation_groups.py ================================================ """Integration tests for default group reconciliation on user reactivation. Verifies that: - A deactivated user retains default group membership after reactivation - Reactivation via the admin API reconciles missing group membership """ from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.managers.user_group import UserGroupManager from tests.integration.common_utils.test_models import DATestUser def _get_default_group_member_emails( admin_user: DATestUser, group_name: str, ) -> set[str]: """Get the set of emails of all members in a named default group.""" all_groups = UserGroupManager.get_all(admin_user, include_default=True) matched = [g for g in all_groups if g.is_default and g.name == group_name] assert matched, f"Default group '{group_name}' not found" return {u.email for u in matched[0].users} def test_reactivated_user_retains_default_group( reset: None, # noqa: ARG001 ) -> None: """Deactivating and reactivating a user should preserve their default group membership.""" admin_user: DATestUser = UserManager.create(name="admin_user") basic_user: DATestUser = UserManager.create(name="basic_user") # Verify user is in Basic group initially basic_emails = _get_default_group_member_emails(admin_user, "Basic") assert basic_user.email in basic_emails # Deactivate the user UserManager.set_status( user_to_set=basic_user, target_status=False, user_performing_action=admin_user, ) # Reactivate the user UserManager.set_status( user_to_set=basic_user, target_status=True, user_performing_action=admin_user, ) # Verify user is still 
in Basic group after reactivation basic_emails = _get_default_group_member_emails(admin_user, "Basic") assert ( basic_user.email in basic_emails ), "Reactivated user should still be in Basic default group" ================================================ FILE: backend/tests/integration/tests/users/test_seat_limit.py ================================================ """Integration tests for seat limit enforcement on user creation paths. Verifies that when a license with a seat limit is active, new user creation (registration, invite, reactivation) is blocked with HTTP 402. """ from datetime import datetime from datetime import timedelta import redis import requests from ee.onyx.server.license.models import LicenseMetadata from ee.onyx.server.license.models import LicenseSource from ee.onyx.server.license.models import PlanType from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PORT from onyx.server.settings.models import ApplicationStatus from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.constants import GENERAL_HEADERS from tests.integration.common_utils.managers.user import UserManager # TenantRedis prefixes every key with "{tenant_id}:". # Single-tenant deployments use "public" as the tenant id. 
_LICENSE_REDIS_KEY = "public:license:metadata" def _seed_license(r: redis.Redis, seats: int) -> None: """Write a LicenseMetadata entry into Redis with the given seat cap.""" now = datetime.utcnow() metadata = LicenseMetadata( tenant_id="public", organization_name="Test Org", seats=seats, used_seats=0, # check_seat_availability recalculates from DB plan_type=PlanType.ANNUAL, issued_at=now, expires_at=now + timedelta(days=365), status=ApplicationStatus.ACTIVE, source=LicenseSource.MANUAL_UPLOAD, ) r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300) def _clear_license(r: redis.Redis) -> None: r.delete(_LICENSE_REDIS_KEY) def _redis() -> redis.Redis: return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) # ------------------------------------------------------------------ # Registration # ------------------------------------------------------------------ def test_registration_blocked_when_seats_full( reset: None, # noqa: ARG001 ) -> None: # noqa: ARG001 """POST /auth/register returns 402 when the seat limit is reached.""" r = _redis() # First user is admin — occupies 1 seat UserManager.create(name="admin_user") # License allows exactly 1 seat → already full _seed_license(r, seats=1) try: response = requests.post( url=f"{API_SERVER_URL}/auth/register", json={ "email": "blocked@example.com", "username": "blocked@example.com", "password": "TestPassword123!", }, headers=GENERAL_HEADERS, ) assert response.status_code == 402 finally: _clear_license(r) # ------------------------------------------------------------------ # Invitation # ------------------------------------------------------------------ def test_invite_blocked_when_seats_full(reset: None) -> None: # noqa: ARG001 """PUT /manage/admin/users returns 402 when the seat limit is reached.""" r = _redis() admin_user = UserManager.create(name="admin_user") _seed_license(r, seats=1) try: response = requests.put( url=f"{API_SERVER_URL}/manage/admin/users", json={"emails": 
["newuser@example.com"]}, headers=admin_user.headers, ) assert response.status_code == 402 finally: _clear_license(r) # ------------------------------------------------------------------ # Reactivation # ------------------------------------------------------------------ def test_reactivation_blocked_when_seats_full( reset: None, # noqa: ARG001 ) -> None: # noqa: ARG001 """PATCH /manage/admin/activate-user returns 402 when seats are full.""" r = _redis() admin_user = UserManager.create(name="admin_user") basic_user = UserManager.create(name="basic_user") # Deactivate the basic user (frees a seat in the DB count) UserManager.set_status( basic_user, target_status=False, user_performing_action=admin_user ) # Set license to 1 seat — only admin counts now _seed_license(r, seats=1) try: response = requests.patch( url=f"{API_SERVER_URL}/manage/admin/activate-user", json={"user_email": basic_user.email}, headers=admin_user.headers, ) assert response.status_code == 402 finally: _clear_license(r) # ------------------------------------------------------------------ # No license → no enforcement # ------------------------------------------------------------------ def test_registration_allowed_without_license( reset: None, # noqa: ARG001 ) -> None: # noqa: ARG001 """Without a license in Redis, registration is unrestricted.""" r = _redis() # Make sure there is no cached license _clear_license(r) UserManager.create(name="admin_user") # Second user should register without issue second_user = UserManager.create(name="second_user") assert second_user is not None ================================================ FILE: backend/tests/integration/tests/users/test_slack_user_deactivation.py ================================================ """Integration tests for Slack user deactivation and reactivation via admin endpoints. 
Verifies that: - Slack users can be deactivated by admins - Deactivated Slack users can be reactivated by admins - Reactivation is blocked when the seat limit is reached """ from datetime import datetime from datetime import timedelta import redis import requests from ee.onyx.server.license.models import LicenseMetadata from ee.onyx.server.license.models import LicenseSource from ee.onyx.server.license.models import PlanType from onyx.auth.schemas import UserRole from onyx.configs.app_configs import REDIS_DB_NUMBER from onyx.configs.app_configs import REDIS_HOST from onyx.configs.app_configs import REDIS_PORT from onyx.server.settings.models import ApplicationStatus from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser _LICENSE_REDIS_KEY = "public:license:metadata" def _seed_license(r: redis.Redis, seats: int) -> None: now = datetime.utcnow() metadata = LicenseMetadata( tenant_id="public", organization_name="Test Org", seats=seats, used_seats=0, plan_type=PlanType.ANNUAL, issued_at=now, expires_at=now + timedelta(days=365), status=ApplicationStatus.ACTIVE, source=LicenseSource.MANUAL_UPLOAD, ) r.set(_LICENSE_REDIS_KEY, metadata.model_dump_json(), ex=300) def _clear_license(r: redis.Redis) -> None: r.delete(_LICENSE_REDIS_KEY) def _redis() -> redis.Redis: return redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=REDIS_DB_NUMBER) def _get_user_is_active(email: str, admin_user: DATestUser) -> bool: """Look up a user's is_active flag via the admin users list endpoint.""" result = UserManager.get_user_page( user_performing_action=admin_user, search_query=email, ) matching = [u for u in result.items if u.email == email] assert len(matching) == 1, f"Expected exactly 1 user with email {email}" return matching[0].is_active def test_slack_user_deactivate_and_reactivate( reset: None, # noqa: ARG001 ) -> None: # noqa: ARG001 
"""Admin can deactivate and then reactivate a Slack user.""" admin_user = UserManager.create(name="admin_user") slack_user = UserManager.create(name="slack_test_user") slack_user = UserManager.set_role( user_to_set=slack_user, target_role=UserRole.SLACK_USER, user_performing_action=admin_user, explicit_override=True, ) # Deactivate the Slack user UserManager.set_status( slack_user, target_status=False, user_performing_action=admin_user ) assert _get_user_is_active(slack_user.email, admin_user) is False # Reactivate the Slack user UserManager.set_status( slack_user, target_status=True, user_performing_action=admin_user ) assert _get_user_is_active(slack_user.email, admin_user) is True def test_slack_user_reactivation_blocked_by_seat_limit( reset: None, # noqa: ARG001 ) -> None: """Reactivating a deactivated Slack user returns 402 when seats are full.""" r = _redis() admin_user = UserManager.create(name="admin_user") slack_user = UserManager.create(name="slack_test_user") slack_user = UserManager.set_role( user_to_set=slack_user, target_role=UserRole.SLACK_USER, user_performing_action=admin_user, explicit_override=True, ) UserManager.set_status( slack_user, target_status=False, user_performing_action=admin_user ) # License allows 1 seat — only admin counts _seed_license(r, seats=1) try: response = requests.patch( url=f"{API_SERVER_URL}/manage/admin/activate-user", json={"user_email": slack_user.email}, headers=admin_user.headers, ) assert response.status_code == 402 finally: _clear_license(r) ================================================ FILE: backend/tests/integration/tests/users/test_user_pagination.py ================================================ from onyx.auth.schemas import UserRole from onyx.server.models import FullUserSnapshot from tests.integration.common_utils.managers.user import UserManager from tests.integration.common_utils.test_models import DATestUser # Gets a page of users from the db that match the given parameters and then # compares that 
returned page to the list of users passed into the function # to verify that the pagination and filtering work as expected. def _verify_user_pagination( users: list[DATestUser], user_performing_action: DATestUser, page_size: int = 5, search_query: str | None = None, role_filter: list[UserRole] | None = None, is_active_filter: bool | None = None, ) -> None: retrieved_users: list[FullUserSnapshot] = [] for i in range(0, len(users), page_size): paginated_result = UserManager.get_user_page( page_num=i // page_size, page_size=page_size, search_query=search_query, role_filter=role_filter, is_active_filter=is_active_filter, user_performing_action=user_performing_action, ) # Verify that the total items is equal to the length of the users list assert paginated_result.total_items == len(users) # Verify that the number of items in the page is equal to the page size assert len(paginated_result.items) == page_size # Add the retrieved users to the list of retrieved users retrieved_users.extend(paginated_result.items) # Create a set of all the expected emails all_expected_emails = set([user.email for user in users]) # Create a set of all the retrieved emails all_retrieved_emails = set([user.email for user in retrieved_users]) # Verify that the set of retrieved emails is equal to the set of expected emails assert all_expected_emails == all_retrieved_emails def test_user_pagination(reset: None) -> None: # noqa: ARG001 # Create an admin user to perform actions user_performing_action: DATestUser = UserManager.create( name="admin_performing_action" ) # Create 9 admin users admin_users: list[DATestUser] = UserManager.create_test_users( user_name_prefix="admin", count=9, role=UserRole.ADMIN, user_performing_action=user_performing_action, ) # Add the user_performing_action to the list of admins admin_users.append(user_performing_action) # Create 10 basic users basic_users: list[DATestUser] = UserManager.create_test_users( user_name_prefix="basic", count=10, role=UserRole.BASIC,
user_performing_action=user_performing_action, ) # Create 10 global curators global_curators: list[DATestUser] = UserManager.create_test_users( user_name_prefix="global_curator", count=10, role=UserRole.GLOBAL_CURATOR, user_performing_action=user_performing_action, ) # Create 10 inactive admins inactive_admins: list[DATestUser] = UserManager.create_test_users( user_name_prefix="inactive_admin", count=10, role=UserRole.ADMIN, is_active=False, user_performing_action=user_performing_action, ) # Create 10 global curator users with an email containing "search" searchable_curators: list[DATestUser] = UserManager.create_test_users( user_name_prefix="search_curator", count=10, role=UserRole.GLOBAL_CURATOR, user_performing_action=user_performing_action, ) # Combine all the users lists into the all_users list all_users: list[DATestUser] = ( admin_users + basic_users + global_curators + inactive_admins + searchable_curators ) for user in all_users: # Verify that the user's role in the db matches # the role in the user object assert UserManager.is_role(user, user.role) # Verify that the user's status in the db matches # the status in the user object assert UserManager.is_status(user, user.is_active) # Verify pagination _verify_user_pagination( users=all_users, user_performing_action=user_performing_action, ) # Verify filtering by role _verify_user_pagination( users=admin_users + inactive_admins, role_filter=[UserRole.ADMIN], user_performing_action=user_performing_action, ) # Verify filtering by status _verify_user_pagination( users=inactive_admins, is_active_filter=False, user_performing_action=user_performing_action, ) # Verify filtering by search query _verify_user_pagination( users=searchable_curators, search_query="search", user_performing_action=user_performing_action, ) # Verify filtering by role and status _verify_user_pagination( users=inactive_admins, role_filter=[UserRole.ADMIN], is_active_filter=False, user_performing_action=user_performing_action, ) # Verify 
filtering by role and search query _verify_user_pagination( users=searchable_curators, role_filter=[UserRole.GLOBAL_CURATOR], search_query="search", user_performing_action=user_performing_action, ) # Verify filtering by role and status and search query _verify_user_pagination( users=inactive_admins, role_filter=[UserRole.ADMIN], is_active_filter=False, search_query="inactive_ad", user_performing_action=user_performing_action, ) # Verify filtering by multiple roles (admin and global curator) _verify_user_pagination( users=admin_users + global_curators + inactive_admins + searchable_curators, role_filter=[UserRole.ADMIN, UserRole.GLOBAL_CURATOR], user_performing_action=user_performing_action, ) ================================================ FILE: backend/tests/integration/tests/web_search/test_web_search_api.py ================================================ import os import pytest import requests from shared_configs.enums import WebContentProviderType from shared_configs.enums import WebSearchProviderType from tests.integration.common_utils.constants import API_SERVER_URL from tests.integration.common_utils.test_models import DATestUser class TestOnyxWebCrawler: """ Integration tests for the Onyx web crawler functionality. These tests verify that the built-in crawler can fetch and parse content from public websites correctly. 
""" @pytest.mark.skip(reason="Temporarily disabled") def test_fetches_public_url_successfully(self, admin_user: DATestUser) -> None: """Test that the crawler can fetch content from a public URL.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["https://example.com/"]}, headers=admin_user.headers, ) assert response.status_code == 200, response.text data = response.json() assert data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value assert len(data["results"]) == 1 result = data["results"][0] assert "content" in result content = result["content"] # example.com is a static page maintained by IANA with known content # Verify exact expected text from the page assert "Example Domain" in content assert "This domain is for use in" in content assert "documentation" in content or "illustrative" in content @pytest.mark.skip(reason="Temporarily disabled") def test_fetches_multiple_urls(self, admin_user: DATestUser) -> None: """Test that the crawler can fetch multiple URLs in one request.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={ "urls": [ "https://example.com/", "https://www.iana.org/domains/reserved", ] }, headers=admin_user.headers, ) assert response.status_code == 200, response.text data = response.json() assert data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value assert len(data["results"]) == 2 for result in data["results"]: assert "content" in result def test_handles_nonexistent_domain(self, admin_user: DATestUser) -> None: """Test that the crawler handles non-existent domains gracefully.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["https://this-domain-definitely-does-not-exist-12345.com/"]}, headers=admin_user.headers, ) assert response.status_code == 200, response.text data = response.json() assert data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value # The API filters out docs with no title/content, so 
unreachable domains return no results assert data["results"] == [] def test_handles_404_page(self, admin_user: DATestUser) -> None: """Test that the crawler handles 404 responses gracefully.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["https://example.com/this-page-does-not-exist-12345"]}, headers=admin_user.headers, ) assert response.status_code == 200, response.text data = response.json() assert data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value # Non-200 responses are treated as non-content and filtered out assert data["results"] == [] def test_https_url_with_path(self, admin_user: DATestUser) -> None: """Test that the crawler handles HTTPS URLs with paths correctly.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["https://www.iana.org/about"]}, headers=admin_user.headers, ) assert response.status_code == 200, response.text data = response.json() assert len(data["results"]) == 1 result = data["results"][0] assert "content" in result class TestSsrfProtection: """ Integration tests for SSRF protection on the /open-urls endpoint. These tests verify that the endpoint correctly blocks requests to: - Internal/private IP addresses - Cloud metadata endpoints - Blocked hostnames (Kubernetes, cloud metadata, etc.) 
""" def test_blocks_localhost_ip(self, admin_user: DATestUser) -> None: """Test that requests to localhost (127.0.0.1) are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://127.0.0.1/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() # URL should be processed but return empty content (blocked by SSRF protection) assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_private_ip_10_network(self, admin_user: DATestUser) -> None: """Test that requests to 10.x.x.x private network are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://10.0.0.1/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_private_ip_192_168_network(self, admin_user: DATestUser) -> None: """Test that requests to 192.168.x.x private network are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://192.168.1.1/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_private_ip_172_network(self, admin_user: DATestUser) -> None: """Test that requests to 172.16-31.x.x private network are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://172.16.0.1/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_aws_metadata_endpoint(self, admin_user: DATestUser) -> None: """Test that requests to AWS metadata endpoint (169.254.169.254) are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": 
["http://169.254.169.254/latest/meta-data/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_kubernetes_metadata_hostname(self, admin_user: DATestUser) -> None: """Test that requests to Kubernetes internal hostname are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://kubernetes.default.svc.cluster.local/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_google_metadata_hostname(self, admin_user: DATestUser) -> None: """Test that requests to Google Cloud metadata hostname are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://metadata.google.internal/"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_blocks_localhost_with_port(self, admin_user: DATestUser) -> None: """Test that requests to localhost with custom port are blocked.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": ["http://127.0.0.1:8080/metrics"]}, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() assert len(data["results"]) == 0 or data["results"][0]["content"] == "" def test_multiple_urls_filters_internal(self, admin_user: DATestUser) -> None: """Test that internal URLs are filtered while external URLs are processed.""" response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={ "urls": [ "http://127.0.0.1/", # Should be blocked "http://192.168.1.1/", # Should be blocked "https://example.com/", # Should be allowed (if reachable) ] }, headers=admin_user.headers, ) assert response.status_code == 200 data = response.json() # Internal 
URLs should return empty content # The exact behavior depends on whether example.com is reachable # but internal URLs should definitely not return sensitive data for result in data["results"]: # Ensure no result contains internal network data content = result.get("content", "") # These patterns would indicate SSRF vulnerability assert "metrics" not in content.lower() or "example" in content.lower() assert "token" not in content.lower() or "example" in content.lower() # Mark the Exa-dependent tests to skip if no API key pytestmark_exa = pytest.mark.skipif( not os.environ.get("EXA_API_KEY"), reason="EXA_API_KEY not set; live web search tests require real credentials", ) def _activate_exa_provider(admin_user: DATestUser) -> int: response = requests.post( f"{API_SERVER_URL}/admin/web-search/search-providers", json={ "id": None, "name": "integration-exa-provider", "provider_type": WebSearchProviderType.EXA.value, "config": {}, "api_key": os.environ["EXA_API_KEY"], "api_key_changed": True, "activate": True, }, headers=admin_user.headers, ) assert response.status_code == 200, response.text provider = response.json() assert provider["provider_type"] == WebSearchProviderType.EXA.value assert provider["is_active"] is True assert provider["has_api_key"] is True return provider["id"] @pytestmark_exa @pytest.mark.skip(reason="Temporarily disabled") def test_web_search_endpoints_with_exa( reset: None, # noqa: ARG001 admin_user: DATestUser, ) -> None: provider_id = _activate_exa_provider(admin_user) assert isinstance(provider_id, int) search_request = {"queries": ["wikipedia python programming"], "max_results": 3} lite_response = requests.post( f"{API_SERVER_URL}/web-search/search-lite", json=search_request, headers=admin_user.headers, ) assert lite_response.status_code == 200, lite_response.text lite_data = lite_response.json() assert lite_data["provider_type"] == WebSearchProviderType.EXA.value assert lite_data["results"], "Expected web search results from Exa" urls = 
[result["url"] for result in lite_data["results"] if result.get("url")][:2] assert urls, "Web search should return at least one URL" open_response = requests.post( f"{API_SERVER_URL}/web-search/open-urls", json={"urls": urls}, headers=admin_user.headers, ) assert open_response.status_code == 200, open_response.text open_data = open_response.json() assert open_data["provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value assert len(open_data["results"]) == len(urls) assert all("content" in result for result in open_data["results"]) combined_response = requests.post( f"{API_SERVER_URL}/web-search/search", json=search_request, headers=admin_user.headers, ) assert combined_response.status_code == 200, combined_response.text combined_data = combined_response.json() assert combined_data["search_provider_type"] == WebSearchProviderType.EXA.value assert ( combined_data["content_provider_type"] == WebContentProviderType.ONYX_WEB_CRAWLER.value ) assert combined_data["search_results"] unique_urls = list( dict.fromkeys( result["url"] for result in combined_data["search_results"] if result.get("url") ) ) assert len(combined_data["full_content_results"]) == len(unique_urls) ================================================ FILE: backend/tests/load_env_vars.py ================================================ import os def load_env_vars(env_file: str = ".env") -> None: current_dir = os.path.dirname(os.path.abspath(__file__)) env_path = os.path.join(current_dir, env_file) try: with open(env_path, "r") as f: for line in f: line = line.strip() if line and not line.startswith("#"): key, value = line.split("=", 1) os.environ[key] = value.strip() print("Successfully loaded environment variables") except FileNotFoundError: print(f"File {env_file} not found") ================================================ FILE: backend/tests/regression/answer_quality/README.md ================================================ # Search Quality Test Script This Python script automates the process 
of running search quality tests for a backend system.

## Features

- Loads configuration from a YAML file
- Sets up the Docker environment
- Manages environment variables
- Switches to a specified Git commit
- Uploads test documents
- Runs search quality tests
- Cleans up Docker containers (optional)

## Usage

1. Ensure you have the required dependencies installed.
2. Configure the `search_test_config.yaml` file based on the `search_test_config.yaml.template` file.
3. Configure the `.env_eval` file in `deployment/docker_compose` with the correct environment variables.
4. Set up the PYTHONPATH permanently:
   Add the following line to your shell configuration file (e.g., `~/.bashrc`, `~/.zshrc`, or `~/.bash_profile`):
   ```
   export PYTHONPATH=$PYTHONPATH:/path/to/onyx/backend
   ```
   Replace `/path/to/onyx` with the actual path to your Onyx repository.
   After adding this line, restart your terminal or run `source ~/.bashrc` (or the appropriate config file) to apply the changes.
5. Navigate to the Onyx repo:
   ```
   cd path/to/onyx
   ```
6. Navigate to the answer_quality folder:
   ```
   cd backend/tests/regression/answer_quality
   ```
7. To launch the evaluation environment, run the launch_eval_env.py script. You can skip this step if you are running the environment outside of Docker; in that case, leave "environment_name" blank.
   ```
   python launch_eval_env.py
   ```
8. Run the file_uploader.py script to upload the zip files located at the "zipped_documents_file" path:
   ```
   python file_uploader.py
   ```
9. Run the run_qa.py script to ask the questions from the JSONL file located at the "questions_file" path. This will hit the "query/answer-with-quote" API endpoint.
   ```
   python run_qa.py
   ```

Note: All data is preserved after the containers are shut down. Instructions for relaunching Docker containers with this data are given below.

If you run multiple UIs at the same time, the ports will increment upwards from 3000 (e.g. http://localhost:3001).
To see which port the desired instance is on, check the ports on the nginx container by running `docker ps` or using Docker Desktop. The Docker daemon must be running for this to work.

## Configuration

Edit `search_test_config.yaml` to set:

- output_folder
  - This is the folder where the folders for each test will go
  - These folders will contain the Postgres/Vespa data as well as the results for each test
- zipped_documents_file
  - The path to the zip file containing the files you'd like to test against
- questions_file
  - The path to the JSONL file containing the questions you'd like to test with
- commit_sha
  - Set this to the SHA of the commit you want to run the test against
  - You must clear all local changes if you want to use this option
  - Set this to null if you want it to just use the code as is
- clean_up_docker_containers
  - Set this to true to automatically delete all Docker containers, networks and volumes after the test
- launch_web_ui
  - Set this to true if you want to use the UI during/after the testing process
- only_state
  - Whether to only run Vespa and Postgres
- only_retrieve_docs
  - Set this to true to only retrieve documents, without generating an LLM response
  - This is to save on API costs
- use_cloud_gpu
  - Set to true or false depending on whether you want to use the remote GPU
- model_server_ip
  - This is the IP of the remote model server
  - Only need to set this if use_cloud_gpu is true
- model_server_port
  - This is the port of the remote model server
  - Only need to set this if use_cloud_gpu is true
- environment_name
  - Use this if you would like to relaunch a previous test instance
  - Input the env_name of the test you'd like to relaunch
  - Leave empty to launch referencing local default network locations
- limit
  - Max number of questions you'd like to ask against the dataset
  - Set to null for no limit
- llm
  - Fill this out according to the normal LLM seeding

## Relaunching From Existing Data

To launch an existing set of containers that
has already completed indexing, set the environment_name variable. This will launch the docker containers mounted on the volumes of the indicated env_name and will not automatically index any documents or run any QA. Once these containers are launched you can run file_uploader.py or run_qa.py (assuming you have run the steps in the Usage section above). - file_uploader.py will upload and index additional zipped files located at the zipped_documents_file path. - run_qa.py will ask questions located at the questions_file path against the indexed documents. ================================================ FILE: backend/tests/regression/answer_quality/__init__.py ================================================ ================================================ FILE: backend/tests/regression/answer_quality/api_utils.py ================================================ import requests from retry import retry from onyx.configs.constants import DocumentSource from onyx.connectors.models import InputType from onyx.db.enums import IndexingStatus from onyx.server.documents.models import ConnectorBase from tests.regression.answer_quality.cli_utils import get_api_server_host_port GENERAL_HEADERS = {"Content-Type": "application/json"} def _api_url_builder(env_name: str, api_path: str) -> str: if env_name: return f"http://localhost:{get_api_server_host_port(env_name)}" + api_path else: return "http://localhost:8080" + api_path @retry(tries=10, delay=10) def check_indexing_status(env_name: str) -> tuple[int, bool]: url = _api_url_builder(env_name, "/manage/admin/connector/indexing-status/") try: indexing_status_dict = requests.post( url, headers=GENERAL_HEADERS, json={"get_all_connectors": True} ).json() except Exception as e: print("Failed to check indexing status, API server is likely starting up:") print(f"\t {str(e)}") print("trying again") raise e ongoing_index_attempts = False doc_count = 0 for connectors_by_source in indexing_status_dict: connectors = 
connectors_by_source["indexing_statuses"] for connector in connectors: status = connector["last_status"] if ( status == IndexingStatus.IN_PROGRESS or status == IndexingStatus.NOT_STARTED ): ongoing_index_attempts = True elif status == IndexingStatus.SUCCESS: doc_count += 16 doc_count += connector["docs_indexed"] doc_count -= 16 # all the +16 and -16 are to account for the fact that the indexing status # is only updated every 16 documents and will tell us how many are # chunked, not indexed. This should probably be fixed in the future. if doc_count: doc_count += 16 return doc_count, ongoing_index_attempts def run_cc_once(env_name: str, connector_id: int, credential_id: int) -> None: url = _api_url_builder(env_name, "/manage/admin/connector/run-once/") body = { "connector_id": connector_id, "credential_ids": [credential_id], "from_beginning": True, } print("body:", body) response = requests.post(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: print("Connector created successfully:", response.json()) else: print("Failed status_code:", response.status_code) print("Failed text:", response.text) def create_cc_pair(env_name: str, connector_id: int, credential_id: int) -> None: url = _api_url_builder( env_name, f"/manage/connector/{connector_id}/credential/{credential_id}" ) body = {"name": "zip_folder_contents", "is_public": True, "groups": []} print("body:", body) response = requests.put(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: print("Connector created successfully:", response.json()) else: print("Failed status_code:", response.status_code) print("Failed text:", response.text) def _get_existing_connector_names(env_name: str) -> list[str]: url = _api_url_builder(env_name, "/manage/connector") body = { "credential_json": {}, "admin_public": True, } response = requests.get(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: connectors = response.json() return [connector["name"] for connector in
connectors] else: raise RuntimeError(response.__dict__) def create_connector(env_name: str, file_paths: list[str]) -> int: url = _api_url_builder(env_name, "/manage/admin/connector") connector_name = base_connector_name = "search_eval_connector" existing_connector_names = _get_existing_connector_names(env_name) count = 1 while connector_name in existing_connector_names: connector_name = base_connector_name + "_" + str(count) count += 1 connector = ConnectorBase( name=connector_name, source=DocumentSource.FILE, input_type=InputType.LOAD_STATE, connector_specific_config={ "file_locations": file_paths, "file_names": [], # For regression tests, no need for file_names "zip_metadata_file_id": None, }, refresh_freq=None, prune_freq=None, indexing_start=None, ) body = connector.model_dump() response = requests.post(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: return response.json()["id"] else: raise RuntimeError(response.__dict__) def create_credential(env_name: str) -> int: url = _api_url_builder(env_name, "/manage/credential") body = { "credential_json": {}, "admin_public": True, "source": DocumentSource.FILE, } response = requests.post(url, headers=GENERAL_HEADERS, json=body) if response.status_code == 200: print("credential created successfully:", response.json()) return response.json()["id"] else: raise RuntimeError(response.__dict__) @retry(tries=10, delay=2, backoff=2) def upload_file(env_name: str, zip_file_path: str) -> list[str]: files = [ ("files", open(zip_file_path, "rb")), ] api_path = _api_url_builder(env_name, "/manage/admin/connector/file/upload") try: response = requests.post(api_path, files=files) response.raise_for_status() # Raises an HTTPError for bad responses print("file uploaded successfully:", response.json()) return response.json()["file_paths"] except Exception as e: print("File upload failed, waiting for API server to come up and trying again") raise e ================================================ FILE: 
backend/tests/regression/answer_quality/cli_utils.py ================================================ import json import os import socket import subprocess import sys import time from datetime import datetime from threading import Thread from typing import IO import yaml from retry import retry def _run_command(command: str, stream_output: bool = False) -> tuple[str, str]: process = subprocess.Popen( command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True, bufsize=1, ) stdout_lines: list[str] = [] stderr_lines: list[str] = [] def process_stream(stream: IO[str], lines: list[str]) -> None: for line in stream: lines.append(line) if stream_output: print( line, end="", file=sys.stdout if stream == process.stdout else sys.stderr, ) stdout_thread = Thread(target=process_stream, args=(process.stdout, stdout_lines)) stderr_thread = Thread(target=process_stream, args=(process.stderr, stderr_lines)) stdout_thread.start() stderr_thread.start() stdout_thread.join() stderr_thread.join() process.wait() if process.returncode != 0: raise RuntimeError(f"Command failed with error: {''.join(stderr_lines)}") return "".join(stdout_lines), "".join(stderr_lines) def get_current_commit_sha() -> str: print("Getting current commit SHA...") stdout, _ = _run_command("git rev-parse HEAD") sha = stdout.strip() print(f"Current commit SHA: {sha}") return sha def switch_to_commit(commit_sha: str) -> None: print(f"Switching to commit: {commit_sha}...") _run_command(f"git checkout {commit_sha}") print(f"Successfully switched to commit: {commit_sha}") print("Repository updated successfully.") def get_docker_container_env_vars(env_name: str) -> dict: """ Retrieves environment variables from "background" and "api_server" Docker containers. 
""" print(f"Getting environment variables for containers with env_name: {env_name}") combined_env_vars = {} for container_type in ["background", "api_server"]: container_name = _run_command( f"docker ps -a --format '{{{{.Names}}}}' | awk '/{container_type}/ && /{env_name}/'" )[0].strip() if not container_name: raise RuntimeError( f"No {container_type} container found with env_name: {env_name}" ) env_vars_json = _run_command( f"docker inspect --format='{{{{json .Config.Env}}}}' {container_name}" )[0] env_vars_list = json.loads(env_vars_json.strip()) for env_var in env_vars_list: key, value = env_var.split("=", 1) combined_env_vars[key] = value return combined_env_vars def manage_data_directories(env_name: str, base_path: str, use_cloud_gpu: bool) -> None: # Use the user's home directory as the base path target_path = os.path.join(os.path.expanduser(base_path), env_name) directories = { "DANSWER_POSTGRES_DATA_DIR": os.path.join(target_path, "postgres/"), "DANSWER_VESPA_DATA_DIR": os.path.join(target_path, "vespa/"), } if not use_cloud_gpu: directories["DANSWER_INDEX_MODEL_CACHE_DIR"] = os.path.join( target_path, "index_model_cache/" ) directories["DANSWER_INFERENCE_MODEL_CACHE_DIR"] = os.path.join( target_path, "inference_model_cache/" ) # Create directories if they don't exist for env_var, directory in directories.items(): os.makedirs(directory, exist_ok=True) os.environ[env_var] = directory print(f"Set {env_var} to: {directory}") results_output_path = os.path.join(target_path, "evaluations_output/") os.makedirs(results_output_path, exist_ok=True) def set_env_variables( remote_server_ip: str, remote_server_port: str, use_cloud_gpu: bool, llm_config: dict, ) -> None: env_vars: dict = {} env_vars["ENV_SEED_CONFIGURATION"] = json.dumps({"llms": [llm_config]}) env_vars["ENABLE_PAID_ENTERPRISE_EDITION_FEATURES"] = "true" if use_cloud_gpu: env_vars["MODEL_SERVER_HOST"] = remote_server_ip env_vars["MODEL_SERVER_PORT"] = remote_server_port 
env_vars["INDEXING_MODEL_SERVER_HOST"] = remote_server_ip for env_var_name, env_var in env_vars.items(): os.environ[env_var_name] = env_var print(f"Set {env_var_name} to: {env_var}") def _is_port_in_use(port: int) -> bool: with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: return s.connect_ex(("localhost", port)) == 0 def start_docker_compose( env_name: str, launch_web_ui: bool, use_cloud_gpu: bool, only_state: bool = False ) -> None: print("Starting Docker Compose...") os.chdir(os.path.dirname(__file__)) os.chdir("../../../../deployment/docker_compose/") command = ( f"docker compose -f docker-compose.search-testing.yml -p onyx-{env_name} up -d" ) command += " --build" command += " --force-recreate" if only_state: command += " index relational_db" else: if use_cloud_gpu: command += " --scale indexing_model_server=0" command += " --scale inference_model_server=0" if launch_web_ui: web_ui_port = 3000 while _is_port_in_use(web_ui_port): web_ui_port += 1 print(f"UI will be launched at http://localhost:{web_ui_port}") os.environ["NGINX_PORT"] = str(web_ui_port) else: command += " --scale web_server=0" command += " --scale nginx=0" print("Docker Command:\n", command) _run_command(command, stream_output=True) print("Containers have been launched") def cleanup_docker(env_name: str) -> None: print( f"Deleting Docker containers, volumes, and networks for project env_name: {env_name}" ) stdout, _ = _run_command("docker ps -a --format '{{json .}}'") containers = [json.loads(line) for line in stdout.splitlines()] if not env_name: env_name = datetime.now().strftime("-%Y") project_name = f"onyx{env_name}" containers_to_delete = [ c for c in containers if c["Names"].startswith(project_name) ] if not containers_to_delete: print(f"No containers found for project: {project_name}") else: container_ids = " ".join([c["ID"] for c in containers_to_delete]) _run_command(f"docker rm -f {container_ids}") print( f"Successfully deleted {len(containers_to_delete)} containers for 
project: {project_name}" ) stdout, _ = _run_command("docker volume ls --format '{{.Name}}'") volumes = stdout.splitlines() volumes_to_delete = [v for v in volumes if v.startswith(project_name)] if not volumes_to_delete: print(f"No volumes found for project: {project_name}") else: # Delete filtered volumes volume_names = " ".join(volumes_to_delete) _run_command(f"docker volume rm {volume_names}") print( f"Successfully deleted {len(volumes_to_delete)} volumes for project: {project_name}" ) stdout, _ = _run_command("docker network ls --format '{{.Name}}'") networks = stdout.splitlines() networks_to_delete = [n for n in networks if env_name in n] if not networks_to_delete: print(f"No networks found containing env_name: {env_name}") else: network_names = " ".join(networks_to_delete) _run_command(f"docker network rm {network_names}") print( f"Successfully deleted {len(networks_to_delete)} networks containing env_name: {env_name}" ) @retry(tries=5, delay=5, backoff=2) def get_api_server_host_port(env_name: str) -> str: """ Finds the single container whose name contains both "api_server" and env_name, parses the "Ports" field of its `docker ps` JSON, and returns the host port mapped to the container's port 8080. """ container_name = "api_server" stdout, _ = _run_command("docker ps -a --format '{{json .}}'") containers = [json.loads(line) for line in stdout.splitlines()] server_jsons = [] for container in containers: if container_name in container["Names"] and env_name in container["Names"]: server_jsons.append(container) if not server_jsons: raise RuntimeError( f"No container found containing: {container_name} and {env_name}" ) elif len(server_jsons) > 1: raise RuntimeError( f"Too many containers matching {container_name} found, please provide a more specific env_name" ) server_json = server_jsons[0] # This is in case the api_server has multiple ports client_port = "8080" ports = server_json.get("Ports", "") port_infos = ports.split(",") if ports else [] port_dict = {}
for port_info in port_infos: port_arr = port_info.split(":")[-1].split("->") if port_info else [] if len(port_arr) == 2: port_dict[port_arr[1]] = port_arr[0] # Find the host port where client_port is in the key matching_ports = [value for key, value in port_dict.items() if client_port in key] if len(matching_ports) > 1: raise RuntimeError(f"Too many ports matching {client_port} found") if not matching_ports: raise RuntimeError( f"No port found containing: {client_port} for container: {container_name} and env_name: {env_name}" ) return matching_ports[0] # Added function to restart Vespa container def restart_vespa_container(env_name: str) -> None: print(f"Restarting Vespa container for env_name: {env_name}") # Find the Vespa container stdout, _ = _run_command( f"docker ps -a --format '{{{{.Names}}}}' | awk '/index-1/ && /{env_name}/'" ) container_name = stdout.strip() if not container_name: raise RuntimeError(f"No Vespa container found with env_name: {env_name}") # Restart the container _run_command(f"docker restart {container_name}") print(f"Vespa container '{container_name}' has begun restarting") time.sleep(30) print(f"Vespa container '{container_name}' has been restarted") if __name__ == "__main__": """ Running this just cleans up the docker environment for the container indicated by environment_name If no environment_name is indicated, will just clean up all onyx docker containers/volumes/networks Note: vespa/postgres mounts are not deleted """ current_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(current_dir, "search_test_config.yaml") with open(config_path, "r") as file: config = yaml.safe_load(file) if not isinstance(config, dict): raise TypeError("config must be a dictionary") cleanup_docker(config["environment_name"]) ================================================ FILE: backend/tests/regression/answer_quality/file_uploader.py ================================================ import csv import os import tempfile import time 
import zipfile from pathlib import Path from types import SimpleNamespace import yaml from tests.regression.answer_quality.api_utils import check_indexing_status from tests.regression.answer_quality.api_utils import create_cc_pair from tests.regression.answer_quality.api_utils import create_connector from tests.regression.answer_quality.api_utils import create_credential from tests.regression.answer_quality.api_utils import run_cc_once from tests.regression.answer_quality.api_utils import upload_file def unzip_and_get_file_paths(zip_file_path: str) -> list[str]: persistent_dir = tempfile.mkdtemp() with zipfile.ZipFile(zip_file_path, "r") as zip_ref: zip_ref.extractall(persistent_dir) file_paths = [] for root, _, files in os.walk(persistent_dir): for file in sorted(files): file_paths.append(os.path.join(root, file)) return file_paths def create_temp_zip_from_files(file_paths: list[str]) -> str: persistent_dir = tempfile.mkdtemp() zip_file_path = os.path.join(persistent_dir, "temp.zip") with zipfile.ZipFile(zip_file_path, "w") as zip_file: for file_path in file_paths: zip_file.write(file_path, Path(file_path).name) return zip_file_path def upload_test_files(zip_file_path: str, env_name: str) -> None: print("zip:", zip_file_path) file_paths = upload_file(env_name, zip_file_path) conn_id = create_connector(env_name, file_paths) cred_id = create_credential(env_name) create_cc_pair(env_name, conn_id, cred_id) run_cc_once(env_name, conn_id, cred_id) def manage_file_upload(zip_file_path: str, env_name: str) -> None: start_time = time.time() unzipped_file_paths = unzip_and_get_file_paths(zip_file_path) total_file_count = len(unzipped_file_paths) problem_file_list: list[str] = [] while True: doc_count, ongoing_index_attempts = check_indexing_status(env_name) if ongoing_index_attempts: print( f"{doc_count} docs indexed but waiting for ongoing indexing jobs to finish..." 
) elif not doc_count: print("No docs indexed, waiting for indexing to start") temp_zip_file_path = create_temp_zip_from_files(unzipped_file_paths) upload_test_files(temp_zip_file_path, env_name) os.unlink(temp_zip_file_path) elif (doc_count + len(problem_file_list)) < total_file_count: print(f"No ongoing indexing attempts but only {doc_count} docs indexed") remaining_files = unzipped_file_paths[doc_count + len(problem_file_list) :] problem_file_list.append(remaining_files.pop(0)) print( f"Marking the first remaining doc as problematic and retrying the other {len(remaining_files)} docs" ) temp_zip_file_path = create_temp_zip_from_files(remaining_files) upload_test_files(temp_zip_file_path, env_name) os.unlink(temp_zip_file_path) else: print(f"Successfully uploaded {doc_count} docs!") break time.sleep(10) if problem_file_list: # Resolve the output directory here instead of relying on a global set only under __main__ problem_file_csv_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "problem_files.csv" ) with open(problem_file_csv_path, "w", newline="") as csvfile: csvwriter = csv.writer(csvfile) csvwriter.writerow(["Problematic File Paths"]) for problem_file in problem_file_list: csvwriter.writerow([problem_file]) for file in unzipped_file_paths: os.unlink(file) print(f"Total time taken: {(time.time() - start_time) / 60:.2f} minutes") if __name__ == "__main__": current_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(current_dir, "search_test_config.yaml") with open(config_path, "r") as file: config = SimpleNamespace(**yaml.safe_load(file)) file_location = config.zipped_documents_file env_name = config.environment_name manage_file_upload(file_location, env_name) ================================================ FILE: backend/tests/regression/answer_quality/launch_eval_env.py ================================================ import os from types import SimpleNamespace import yaml from tests.regression.answer_quality.cli_utils import manage_data_directories from tests.regression.answer_quality.cli_utils import set_env_variables from tests.regression.answer_quality.cli_utils
import start_docker_compose from tests.regression.answer_quality.cli_utils import switch_to_commit def load_config(config_filename: str) -> SimpleNamespace: current_dir = os.path.dirname(os.path.abspath(__file__)) config_path = os.path.join(current_dir, config_filename) with open(config_path, "r") as file: return SimpleNamespace(**yaml.safe_load(file)) def main() -> None: config = load_config("search_test_config.yaml") if config.environment_name: env_name = config.environment_name print("launching onyx with environment name:", env_name) else: print("No env name defined. Not launching docker.") print( "Please define a name in the config yaml to start a new env or use an existing env" ) return set_env_variables( config.model_server_ip, config.model_server_port, config.use_cloud_gpu, config.llm, ) manage_data_directories(env_name, config.output_folder, config.use_cloud_gpu) if config.commit_sha: switch_to_commit(config.commit_sha) start_docker_compose( env_name, config.launch_web_ui, config.use_cloud_gpu, config.only_state ) if __name__ == "__main__": main() ================================================ FILE: backend/tests/regression/answer_quality/search_test_config.yaml.template ================================================ # Copy this to search_test_config.yaml and fill in the values to run the eval pipeline # Don't forget to also update the .env_eval file with the correct values # Directory where test results will be saved output_folder: "~/onyx_test_results" # Path to the zip file containing sample documents zipped_documents_file: "~/sampledocs.zip" # Path to the YAML file containing sample questions questions_file: "~/sample_questions.yaml" # Git commit SHA to use (null means use current code as is) commit_sha: null # Whether to launch a web UI for the test launch_web_ui: false # Only retrieve documents, not LLM response only_retrieve_docs: false # Whether to use a cloud GPU for processing use_cloud_gpu: false # IP address of the model server (placeholder) 
model_server_ip: "PUT_PUBLIC_CLOUD_IP_HERE" # Port of the model server (placeholder) model_server_port: "PUT_PUBLIC_CLOUD_PORT_HERE" # Name for existing testing env (empty string uses default ports) environment_name: "" # Limit on number of tests to run (null means no limit) limit: null # LLM configuration llm: # Name of the LLM name: "default_test_llm" # Provider of the LLM (e.g., OpenAI) provider: "openai" # API key api_key: "PUT_API_KEY_HERE" # Default model name to use default_model_name: "gpt-4o" # List of model names to use for testing model_names: ["gpt-4o"] ================================================ FILE: backend/tests/regression/search_quality/README.md ================================================ # Search Quality Test Script This Python script evaluates the search and answer quality for a list of queries, against a ground truth. It will use the currently ingested documents for the search, answer generation, and ground truth comparisons. ## Usage 1. Ensure you have the required dependencies installed and onyx running. 2. Ensure you have `OPENAI_API_KEY` set if you intend to do answer evaluation (enabled by default, unless you run the script with the `-s` flag). Go to the API Keys page in the admin panel, generate a basic api token, and add it to the env file as `ONYX_API_KEY=on_...`. 3. Navigate to Onyx repo, **search_quality** folder: ``` cd path/to/onyx/backend/tests/regression/search_quality ``` 4. Copy `test_queries.json.template` to `test_queries.json` and add/remove test queries in it. 
The fields for each query are: - `question: str` the query - `ground_truth: list[GroundTruth]` an un-ranked list of expected search results with fields: - `doc_source: str` document source (e.g., web, google_drive, linear), used to normalize the links in some cases - `doc_link: str` link associated with document, used to find corresponding document in local index - `ground_truth_response: Optional[str]` a response with clauses the ideal answer should include - `categories: Optional[list[str]]` list of categories, used to aggregate evaluation results 5. Run `run_search_eval.py` to evaluate the queries. All parameters are optional and have sensible defaults: ``` python run_search_eval.py -d --dataset # Path to the test-set JSON file (default: ./test_queries.json) -n --num_search # Maximum number of documents to retrieve per search (default: 50) -a --num_answer # Maximum number of documents to use for answer evaluation (default: 25) -w --max_workers # Maximum number of concurrent search requests (0 = unlimited, default: 10). -r --max_req_rate # Maximum number of search requests per minute (0 = unlimited, default: 0). -q --timeout # Request timeout in seconds (default: 120) -e --api_endpoint # Base URL of the Onyx API server (default: http://127.0.0.1:8080) -s --search_only # Only perform search and not answer evaluation (default: false) -t --tenant_id # Tenant ID to use for the evaluation (default: None) ``` Note: If you only care about search quality, you should run with the `-s` flag for a significantly faster evaluation. Furthermore, you should set `-r` to 1 if running with federated search enabled to avoid hitting rate limits. 6. After the run, an `eval-YYYY-MM-DD-HH-MM-SS` folder is created containing: * `test_queries.json` – the dataset used with the list of valid queries and corresponding indexed ground truth. * `search_results.json` – per-query search and answer details. * `results_by_category.csv` – aggregated metrics per category and for "all". 
* `search_position_chart.png` – bar-chart of ground-truth ranks. You can replace `test_queries.json` with the generated one for a slightly faster loading of the queries the next time around. ================================================ FILE: backend/tests/regression/search_quality/models.py ================================================ from pydantic import BaseModel from onyx.configs.constants import DocumentSource from onyx.context.search.models import SavedSearchDoc class GroundTruth(BaseModel): doc_source: DocumentSource doc_link: str class TestQuery(BaseModel): question: str ground_truth: list[GroundTruth] = [] ground_truth_response: str | None = None categories: list[str] = [] # autogenerated ground_truth_docids: list[str] = [] class EvalConfig(BaseModel): max_search_results: int max_answer_context: int num_workers: int # 0 = unlimited max_request_rate: int # 0 = unlimited request_timeout: int api_url: str search_only: bool class OneshotQAResult(BaseModel): time_taken: float top_documents: list[SavedSearchDoc] answer: str | None class RetrievedDocument(BaseModel): document_id: str chunk_id: int content: str class AnalysisSummary(BaseModel): question: str categories: list[str] found: bool rank: int | None total_results: int ground_truth_count: int response_relevancy: float | None = None faithfulness: float | None = None factual_correctness: float | None = None answer: str | None = None retrieved: list[RetrievedDocument] = [] time_taken: float class SearchMetrics(BaseModel): total_queries: int found_count: int # for found results best_rank: int worst_rank: int average_rank: float top_k_accuracy: dict[int, float] class AnswerMetrics(BaseModel): response_relevancy: float faithfulness: float factual_correctness: float # only for metric computation n_response_relevancy: int n_faithfulness: int n_factual_correctness: int class CombinedMetrics(SearchMetrics, AnswerMetrics): average_time_taken: float ================================================ FILE: 
backend/tests/regression/search_quality/run_search_eval.py ================================================ import csv import json import os import sys import time from collections import defaultdict from concurrent.futures import as_completed from concurrent.futures import ThreadPoolExecutor from datetime import datetime from pathlib import Path from threading import Event from threading import Lock from threading import Semaphore from typing import cast import matplotlib.pyplot as plt import requests from dotenv import load_dotenv from matplotlib.patches import Patch from pydantic import ValidationError from requests.exceptions import RequestException from retry import retry # add onyx/backend to path (since this isn't done automatically when running as a script) current_dir = Path(__file__).parent onyx_dir = current_dir.parent.parent.parent.parent sys.path.append(str(onyx_dir / "backend")) # load env before app_config loads (since env doesn't get loaded when running as a script) env_path = onyx_dir / ".vscode" / ".env" if not env_path.exists(): raise RuntimeError( "Could not find .env file. Please create one in the root .vscode directory." 
) load_dotenv(env_path) # pylint: disable=E402 # flake8: noqa: E402 from ee.onyx.server.query_and_chat.models import SearchFullResponse from ee.onyx.server.query_and_chat.models import SendSearchQueryRequest from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_OVERFLOW from onyx.configs.app_configs import POSTGRES_API_SERVER_POOL_SIZE from onyx.context.search.models import BaseFilters from onyx.context.search.models import SavedSearchDoc from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.engine.sql_engine import SqlEngine from onyx.utils.logger import setup_logger from shared_configs.configs import MULTI_TENANT from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE from tests.regression.search_quality.models import AnalysisSummary from tests.regression.search_quality.models import CombinedMetrics from tests.regression.search_quality.models import EvalConfig from tests.regression.search_quality.models import OneshotQAResult from tests.regression.search_quality.models import TestQuery from tests.regression.search_quality.utils import compute_overall_scores from tests.regression.search_quality.utils import find_document_id from tests.regression.search_quality.utils import get_federated_sources from tests.regression.search_quality.utils import LazyJsonWriter from tests.regression.search_quality.utils import ragas_evaluate from tests.regression.search_quality.utils import search_docs_to_doc_contexts logger = setup_logger(__name__) GENERAL_HEADERS = {"Content-Type": "application/json"} TOP_K_LIST = [1, 3, 5, 10] class SearchAnswerAnalyzer: def __init__( self, config: EvalConfig, tenant_id: str | None = None, ): if not MULTI_TENANT: logger.info("Running in single-tenant mode") tenant_id = POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE elif tenant_id is None: raise ValueError("Tenant ID is required for multi-tenant") self.config = config self.tenant_id = tenant_id # shared analysis results self._lock = Lock() 
self._progress_counter = 0 self._result_writer: LazyJsonWriter | None = None self.ranks: list[int | None] = [] self.metrics: dict[str, CombinedMetrics] = defaultdict( lambda: CombinedMetrics( total_queries=0, found_count=0, best_rank=config.max_search_results, worst_rank=1, average_rank=0.0, top_k_accuracy={k: 0.0 for k in TOP_K_LIST}, response_relevancy=0.0, faithfulness=0.0, factual_correctness=0.0, n_response_relevancy=0, n_faithfulness=0, n_factual_correctness=0, average_time_taken=0.0, ) ) def run_analysis(self, dataset_path: Path, export_path: Path) -> None: # load and save the dataset dataset = self._load_dataset(dataset_path) dataset_size = len(dataset) dataset_export_path = export_path / "test_queries.json" with dataset_export_path.open("w") as f: dataset_serializable = [q.model_dump(mode="json") for q in dataset] json.dump(dataset_serializable, f, indent=4) result_export_path = export_path / "search_results.json" self._result_writer = LazyJsonWriter(result_export_path) # set up rate limiting and threading primitives interval = ( 60.0 / self.config.max_request_rate if self.config.max_request_rate > 0 else 0.0 ) # num_workers == 0 means unlimited; Semaphore(0) would block forever, so size it to the dataset instead available_workers = Semaphore(self.config.num_workers or max(dataset_size, 1)) stop_event = Event() def _submit_wrapper(tc: TestQuery) -> AnalysisSummary: try: return self._run_and_analyze_one(tc, dataset_size) except Exception as e: logger.error("Error during analysis: %s", e) stop_event.set() raise finally: available_workers.release() # run the analysis logger.info("Starting analysis of %d queries", dataset_size) logger.info("Using %d parallel workers", self.config.num_workers) logger.info("Exporting search results to %s", result_export_path) with ThreadPoolExecutor( max_workers=self.config.num_workers or None ) as executor: # submit requests at configured rate, break early if any error occurs futures = [] for tc in dataset: if stop_event.is_set(): break available_workers.acquire() fut = executor.submit(_submit_wrapper, tc) futures.append(fut) if ( len(futures) !=
dataset_size and interval > 0 and not stop_event.is_set() ): time.sleep(interval) # ensure all tasks finish and surface any exceptions for fut in as_completed(futures): fut.result() if self._result_writer: self._result_writer.close() self._aggregate_metrics() def generate_detailed_report(self, export_path: Path) -> None: logger.info("Generating detailed report...") csv_path = export_path / "results_by_category.csv" with csv_path.open("w", newline="") as csv_file: csv_writer = csv.writer(csv_file) csv_writer.writerow( [ "category", "total_queries", "found", "percent_found", "best_rank", "worst_rank", "avg_rank", *[f"top_{k}_accuracy" for k in TOP_K_LIST], *( [ "avg_response_relevancy", "avg_faithfulness", "avg_factual_correctness", ] if not self.config.search_only else [] ), "search_score", *(["answer_score"] if not self.config.search_only else []), "avg_time_taken", ] ) for category, metrics in sorted( self.metrics.items(), key=lambda c: (0 if c[0] == "all" else 1, c[0]) ): found_count = metrics.found_count total_count = metrics.total_queries accuracy = found_count / total_count * 100 if total_count > 0 else 0 print( f"\n{category.upper()}: total queries: {total_count}\n found: {found_count} ({accuracy:.1f}%)" ) best_rank = metrics.best_rank if metrics.found_count > 0 else None worst_rank = metrics.worst_rank if metrics.found_count > 0 else None avg_rank = metrics.average_rank if metrics.found_count > 0 else None if metrics.found_count > 0: print( f" average rank (for found results): {avg_rank:.2f}\n" f" best rank (for found results): {best_rank:.2f}\n" f" worst rank (for found results): {worst_rank:.2f}" ) for k, acc in metrics.top_k_accuracy.items(): print(f" top-{k} accuracy: {acc:.1f}%") if not self.config.search_only: if metrics.n_response_relevancy > 0: print( f" average response relevancy: {metrics.response_relevancy:.2f}" ) if metrics.n_faithfulness > 0: print(f" average faithfulness: {metrics.faithfulness:.2f}") if metrics.n_factual_correctness > 0: print( 
f" average factual correctness: {metrics.factual_correctness:.2f}" ) search_score, answer_score = compute_overall_scores(metrics) print(f" search score: {search_score:.1f}") if not self.config.search_only: print(f" answer score: {answer_score:.1f}") print(f" average time taken: {metrics.average_time_taken:.2f}s") csv_writer.writerow( [ category, total_count, found_count, f"{accuracy:.1f}", best_rank or "", worst_rank or "", f"{avg_rank:.2f}" if avg_rank is not None else "", *[f"{acc:.1f}" for acc in metrics.top_k_accuracy.values()], *( [ ( f"{metrics.response_relevancy:.2f}" if metrics.n_response_relevancy > 0 else "" ), ( f"{metrics.faithfulness:.2f}" if metrics.n_faithfulness > 0 else "" ), ( f"{metrics.factual_correctness:.2f}" if metrics.n_factual_correctness > 0 else "" ), ] if not self.config.search_only else [] ), f"{search_score:.1f}", *( [f"{answer_score:.1f}"] if not self.config.search_only else [] ), f"{metrics.average_time_taken:.2f}", ] ) logger.info("Saved category breakdown csv to %s", csv_path) def generate_chart(self, export_path: Path) -> None: logger.info("Generating search position chart...") if len(self.ranks) == 0: logger.warning("No results to chart") return found_count = 0 not_found_count = 0 rank_counts: dict[int, int] = defaultdict(int) for rank in self.ranks: if rank is None: not_found_count += 1 else: found_count += 1 rank_counts[rank] += 1 # create the data for plotting if found_count: max_rank = max(rank_counts.keys()) positions = list(range(1, max_rank + 1)) counts = [rank_counts.get(pos, 0) for pos in positions] else: positions = [] counts = [] # add the "not found" bar on the far right if not_found_count: # add some spacing between found positions and "not found" not_found_position = (max(positions) + 2) if positions else 1 positions.append(not_found_position) counts.append(not_found_count) # create labels for x-axis x_labels = [str(pos) for pos in positions[:-1]] + [ f"not found\n(>{self.config.max_search_results})" ] else: 
x_labels = [str(pos) for pos in positions] # create the figure and bar chart plt.figure(figsize=(14, 6)) # use different colors for found vs not found colors = ( ["#3498db"] * (len(positions) - 1) + ["#e74c3c"] if not_found_count > 0 else ["#3498db"] * len(positions) ) bars = plt.bar( positions, counts, color=colors, alpha=0.7, edgecolor="black", linewidth=0.5 ) # customize the chart plt.xlabel("Position in Search Results", fontsize=12) plt.ylabel("Number of Ground Truth Documents", fontsize=12) plt.title( "Ground Truth Document Positions in Search Results", fontsize=14, fontweight="bold", ) plt.grid(axis="y", alpha=0.3) # add value labels on top of each bar for bar, count in zip(bars, counts): if count > 0: plt.text( bar.get_x() + bar.get_width() / 2, bar.get_height() + 0.1, str(count), ha="center", va="bottom", fontweight="bold", ) # set x-axis labels plt.xticks(positions, x_labels, rotation=45 if not_found_count > 0 else 0) # add legend if we have both found and not found if not_found_count and found_count: legend_elements = [ Patch(facecolor="#3498db", alpha=0.7, label="Found in Results"), Patch(facecolor="#e74c3c", alpha=0.7, label="Not Found"), ] plt.legend(handles=legend_elements, loc="upper right") # make layout tight and save plt.tight_layout() chart_file = export_path / "search_position_chart.png" plt.savefig(chart_file, dpi=300, bbox_inches="tight") logger.info("Search position chart saved to: %s", chart_file) plt.show() def _load_dataset(self, dataset_path: Path) -> list[TestQuery]: """Load the test dataset from a JSON file and validate the ground truth documents.""" with dataset_path.open("r") as f: dataset_raw: list[dict] = json.load(f) with get_session_with_tenant(tenant_id=self.tenant_id) as db_session: federated_sources = get_federated_sources(db_session) dataset: list[TestQuery] = [] for datum in dataset_raw: # validate the raw datum try: test_query = TestQuery(**datum) except ValidationError as e: logger.error("Incorrectly formatted query %s: 
%s", datum, e) continue # in case the dataset was copied from the previous run export if test_query.ground_truth_docids: dataset.append(test_query) continue # validate and get the ground truth documents with get_session_with_tenant(tenant_id=self.tenant_id) as db_session: for ground_truth in test_query.ground_truth: if ( doc_id := find_document_id( ground_truth, federated_sources, db_session ) ) is not None: test_query.ground_truth_docids.append(doc_id) if len(test_query.ground_truth_docids) == 0: logger.warning( "No ground truth documents found for query: %s, skipping...", test_query.question, ) continue dataset.append(test_query) return dataset @retry(tries=3, delay=1, backoff=2) def _perform_search(self, query: str) -> OneshotQAResult: """Perform a document search query against the Onyx API and time it.""" # create the search request filters = BaseFilters() search_request = SendSearchQueryRequest( search_query=query, filters=filters, num_docs_fed_to_llm_selection=self.config.max_search_results, run_query_expansion=False, stream=False, ) # send the request response = None try: request_data = search_request.model_dump() headers = GENERAL_HEADERS.copy() # Add API key if present if os.environ.get("ONYX_API_KEY"): headers["Authorization"] = f"Bearer {os.environ.get('ONYX_API_KEY')}" start_time = time.monotonic() response = requests.post( url=f"{self.config.api_url}/search/send-search-message", json=request_data, headers=headers, timeout=self.config.request_timeout, ) time_taken = time.monotonic() - start_time response.raise_for_status() result = SearchFullResponse.model_validate(response.json()) # extract documents from the search response if result.search_docs: top_documents = [ SavedSearchDoc.from_search_doc(doc) for doc in result.search_docs[: self.config.max_search_results] ] return OneshotQAResult( time_taken=time_taken, top_documents=top_documents, answer=None, # search endpoint doesn't generate answers ) except RequestException as e: raise RuntimeError( 
f"Search failed for query '{query}': {e}. Response: {response.text if response is not None else 'no response'}" ) from e raise RuntimeError(f"Search returned no documents for query {query}") def _run_and_analyze_one(self, test_case: TestQuery, total: int) -> AnalysisSummary: result = self._perform_search(test_case.question) # compute rank rank = None found = False ground_truths = set(test_case.ground_truth_docids) for i, doc in enumerate(result.top_documents, 1): if doc.document_id in ground_truths: rank = i found = True break # print search progress and result with self._lock: self._progress_counter += 1 completed = self._progress_counter status = "✓ Found" if found else "✗ Not found" rank_info = f" (rank {rank})" if found else "" question_snippet = ( test_case.question[:50] + "..." if len(test_case.question) > 50 else test_case.question ) print(f"[{completed}/{total}] {status}{rank_info}: {question_snippet}") # get the search contents retrieved = search_docs_to_doc_contexts(result.top_documents, self.tenant_id) # do answer evaluation response_relevancy: float | None = None faithfulness: float | None = None factual_correctness: float | None = None contexts = [c.content for c in retrieved[: self.config.max_answer_context]] if not self.config.search_only: if result.answer is None: logger.error( "No answer found for query: %s, skipping answer evaluation", test_case.question, ) else: try: ragas_result = ragas_evaluate( question=test_case.question, answer=result.answer, contexts=contexts, reference_answer=test_case.ground_truth_response, ).scores[0] response_relevancy = ragas_result["answer_relevancy"] faithfulness = ragas_result["faithfulness"] factual_correctness = ragas_result.get( "factual_correctness(mode=recall)" ) except Exception as e: logger.error( "Error evaluating answer for query %s: %s", test_case.question, e, ) # save results analysis = AnalysisSummary( question=test_case.question, categories=test_case.categories, found=found, rank=rank, total_results=len(result.top_documents),
ground_truth_count=len(test_case.ground_truth_docids), answer=result.answer, response_relevancy=response_relevancy, faithfulness=faithfulness, factual_correctness=factual_correctness, retrieved=retrieved, time_taken=result.time_taken, ) with self._lock: self.ranks.append(analysis.rank) if self._result_writer: self._result_writer.append(analysis.model_dump(mode="json")) self._update_metrics(analysis) return analysis def _update_metrics(self, result: AnalysisSummary) -> None: for cat in result.categories + ["all"]: self.metrics[cat].total_queries += 1 self.metrics[cat].average_time_taken += result.time_taken if result.found: self.metrics[cat].found_count += 1 rank = cast(int, result.rank) self.metrics[cat].best_rank = min(self.metrics[cat].best_rank, rank) self.metrics[cat].worst_rank = max(self.metrics[cat].worst_rank, rank) self.metrics[cat].average_rank += rank for k in TOP_K_LIST: self.metrics[cat].top_k_accuracy[k] += int(rank <= k) if self.config.search_only: continue if result.response_relevancy is not None: self.metrics[cat].response_relevancy += result.response_relevancy self.metrics[cat].n_response_relevancy += 1 if result.faithfulness is not None: self.metrics[cat].faithfulness += result.faithfulness self.metrics[cat].n_faithfulness += 1 if result.factual_correctness is not None: self.metrics[cat].factual_correctness += result.factual_correctness self.metrics[cat].n_factual_correctness += 1 def _aggregate_metrics(self) -> None: for cat in self.metrics: total = self.metrics[cat].total_queries self.metrics[cat].average_time_taken /= total if self.metrics[cat].found_count > 0: self.metrics[cat].average_rank /= self.metrics[cat].found_count for k in TOP_K_LIST: self.metrics[cat].top_k_accuracy[k] /= total self.metrics[cat].top_k_accuracy[k] *= 100 if self.config.search_only: continue if (n := self.metrics[cat].n_response_relevancy) > 0: self.metrics[cat].response_relevancy /= n if (n := self.metrics[cat].n_faithfulness) > 0: self.metrics[cat].faithfulness /= n 
if (n := self.metrics[cat].n_factual_correctness) > 0: self.metrics[cat].factual_correctness /= n def run_search_eval( dataset_path: Path, config: EvalConfig, tenant_id: str | None, ) -> None: # check openai api key is set if doing answer eval (ragas reads this exact variable name) if not config.search_only and not os.environ.get("OPENAI_API_KEY"): raise RuntimeError( "OPENAI_API_KEY is required for answer evaluation. Please add it to the root .vscode/.env file." ) # check onyx api key is set (auth is always required) if not os.environ.get("ONYX_API_KEY"): raise RuntimeError( "ONYX_API_KEY is required. Please create one in the admin panel and add it to the root .vscode/.env file." ) # check onyx is running try: response = requests.get( f"{config.api_url}/health", timeout=config.request_timeout ) response.raise_for_status() except RequestException as e: raise RuntimeError(f"Could not connect to Onyx API: {e}") from e # create the export folder next to this script (current_dir only exists when run as __main__) export_path = Path(__file__).parent / datetime.now().strftime("eval-%Y-%m-%d-%H-%M-%S") export_path.mkdir(parents=True, exist_ok=True) logger.info("Created export folder: %s", export_path) # run the search eval analyzer = SearchAnswerAnalyzer(config=config, tenant_id=tenant_id) analyzer.run_analysis(dataset_path, export_path) analyzer.generate_detailed_report(export_path) analyzer.generate_chart(export_path) if __name__ == "__main__": import argparse current_dir = Path(__file__).parent parser = argparse.ArgumentParser(description="Run search quality evaluation.") parser.add_argument( "-d", "--dataset", type=Path, default=current_dir / "test_queries.json", help="Path to the test-set JSON file (default: %(default)s).", ) parser.add_argument( "-n", "--num_search", type=int, default=50, help="Maximum number of documents to retrieve per search (default: %(default)s).", ) parser.add_argument( "-a", "--num_answer", type=int, default=25, help="Maximum number of documents to use for answer evaluation (default:
%(default)s).", ) parser.add_argument( "-w", "--max_workers", type=int, default=10, help="Maximum number of concurrent search requests (0 = unlimited, default: %(default)s).", ) parser.add_argument( "-r", "--max_req_rate", type=int, default=0, help="Maximum number of search requests per minute (0 = unlimited, default: %(default)s).", ) parser.add_argument( "-q", "--timeout", type=int, default=120, help="Request timeout in seconds (default: %(default)s).", ) parser.add_argument( "-e", "--api_endpoint", type=str, default="http://127.0.0.1:8080", help="Base URL of the Onyx API server (default: %(default)s).", ) parser.add_argument( "-s", "--search_only", action="store_true", default=False, help="Only perform search and not answer evaluation (default: %(default)s).", ) parser.add_argument( "-t", "--tenant_id", type=str, default=None, help="Tenant ID to use for the evaluation (default: %(default)s).", ) args = parser.parse_args() SqlEngine.init_engine( pool_size=POSTGRES_API_SERVER_POOL_SIZE, max_overflow=POSTGRES_API_SERVER_POOL_OVERFLOW, ) try: run_search_eval( args.dataset, EvalConfig( max_search_results=args.num_search, max_answer_context=args.num_answer, num_workers=args.max_workers, max_request_rate=args.max_req_rate, request_timeout=args.timeout, api_url=args.api_endpoint, search_only=args.search_only, ), args.tenant_id, ) except Exception as e: logger.error("Unexpected error during search evaluation: %s", e) raise finally: SqlEngine.reset_engine() ================================================ FILE: backend/tests/regression/search_quality/test_queries.json.template ================================================ [ { "question": "What is Onyx?", "ground_truth": [ { "doc_source": "web", "doc_link": "https://docs.onyx.app/welcome" } ], "categories": [ "keyword", "broad", "easy" ] } ] ================================================ FILE: backend/tests/regression/search_quality/utils.py ================================================ import json import re from 
pathlib import Path from textwrap import indent from typing import Any from typing import cast from typing import TextIO from ragas import evaluate # type: ignore[import-not-found,unused-ignore] from ragas import EvaluationDataset # type: ignore[import-not-found,unused-ignore] from ragas import SingleTurnSample # type: ignore[import-not-found,unused-ignore] from ragas.dataset_schema import EvaluationResult # type: ignore[import-not-found,unused-ignore] from ragas.metrics import FactualCorrectness # type: ignore[import-not-found,unused-ignore] from ragas.metrics import Faithfulness # type: ignore[import-not-found,unused-ignore] from ragas.metrics import ResponseRelevancy # type: ignore[import-not-found,unused-ignore] from sqlalchemy.orm import Session from onyx.configs.constants import DocumentSource from onyx.context.search.models import IndexFilters from onyx.context.search.models import SavedSearchDoc from onyx.db.engine.sql_engine import get_session_with_tenant from onyx.db.models import Document from onyx.db.models import FederatedConnector from onyx.db.search_settings import get_current_search_settings from onyx.document_index.factory import get_default_document_index from onyx.document_index.interfaces import VespaChunkRequest from onyx.prompts.prompt_utils import build_doc_context_str from onyx.utils.logger import setup_logger from tests.regression.search_quality.models import CombinedMetrics from tests.regression.search_quality.models import GroundTruth from tests.regression.search_quality.models import RetrievedDocument logger = setup_logger(__name__) def get_federated_sources(db_session: Session) -> set[DocumentSource]: """Get all federated sources from the database.""" return { source for connector in db_session.query(FederatedConnector).all() if (source := connector.source.to_non_federated_source()) is not None } def find_document_id( ground_truth: GroundTruth, federated_sources: set[DocumentSource], db_session: Session, ) -> str | None: """Find a 
document by its link and return its id if found.""" # handle federated sources TODO: maybe make handler dictionary by source if this gets complex if ground_truth.doc_source in federated_sources: if ground_truth.doc_source == DocumentSource.SLACK: groups = re.search( r"archives\/([A-Z0-9]+)\/p([0-9]+)", ground_truth.doc_link ) if groups: channel_id = groups.group(1) message_id = groups.group(2) return f"{channel_id}__{message_id[:-6]}.{message_id[-6:]}" # preprocess links doc_link = ground_truth.doc_link if ground_truth.doc_source == DocumentSource.GOOGLE_DRIVE: if "/edit" in doc_link: doc_link = doc_link.split("/edit", 1)[0] elif "/view" in doc_link: doc_link = doc_link.split("/view", 1)[0] elif ground_truth.doc_source == DocumentSource.FIREFLIES: doc_link = doc_link.split("?", 1)[0] docs = db_session.query(Document).filter(Document.link.ilike(f"{doc_link}%")).all() if len(docs) == 0: logger.warning("Could not find ground truth document: %s", doc_link) return None elif len(docs) > 1: logger.warning( "Found multiple ground truth documents: %s, using the first one: %s", doc_link, docs[0].id, ) return docs[0].id def get_doc_contents( docs: list[SavedSearchDoc], tenant_id: str ) -> dict[tuple[str, int], str]: with get_session_with_tenant(tenant_id=tenant_id) as db_session: search_settings = get_current_search_settings(db_session) document_index = get_default_document_index(search_settings, None, db_session) filters = IndexFilters(access_control_list=None, tenant_id=tenant_id) reqs: list[VespaChunkRequest] = [ VespaChunkRequest( document_id=doc.document_id, min_chunk_ind=doc.chunk_ind, max_chunk_ind=doc.chunk_ind, ) for doc in docs ] results = document_index.id_based_retrieval(chunk_requests=reqs, filters=filters) return {(doc.document_id, doc.chunk_id): doc.content for doc in results} def search_docs_to_doc_contexts( docs: list[SavedSearchDoc], tenant_id: str ) -> list[RetrievedDocument]: try: doc_contents = get_doc_contents(docs, tenant_id) except Exception as e: 
logger.error("Error getting doc contents: %s", e) doc_contents = {} return [ RetrievedDocument( document_id=doc.document_id, chunk_id=doc.chunk_ind, content=build_doc_context_str( semantic_identifier=doc.semantic_identifier, source_type=doc.source_type, content=doc_contents.get( (doc.document_id, doc.chunk_ind), f"Blurb: {doc.blurb}" ), metadata_dict=doc.metadata, updated_at=doc.updated_at, ind=ind, include_metadata=True, ), ) for ind, doc in enumerate(docs) ] def ragas_evaluate( question: str, answer: str, contexts: list[str], reference_answer: str | None = None ) -> EvaluationResult: sample = SingleTurnSample( user_input=question, retrieved_contexts=contexts, response=answer, reference=reference_answer, ) dataset = EvaluationDataset([sample]) return cast( EvaluationResult, evaluate( dataset, metrics=[ ResponseRelevancy(), Faithfulness(), *( [FactualCorrectness(mode="recall")] if reference_answer is not None else [] ), ], ), ) def compute_overall_scores(metrics: CombinedMetrics) -> tuple[float, float]: """Compute the overall search and answer quality scores. The scores are subjective and may require tuning.""" # search score FOUND_RATIO_WEIGHT = 0.4 TOP_IMPORTANCE = 0.7 # 0-inf, how important is it to be no. 
1 over other ranks found_ratio = metrics.found_count / metrics.total_queries sum_k = sum(1.0 / pow(k, TOP_IMPORTANCE) for k in metrics.top_k_accuracy) weighted_topk = sum( acc / (pow(k, TOP_IMPORTANCE) * sum_k * 100) for k, acc in metrics.top_k_accuracy.items() ) search_score = 100 * ( FOUND_RATIO_WEIGHT * found_ratio + (1.0 - FOUND_RATIO_WEIGHT) * weighted_topk ) # answer score mets = [ *([metrics.response_relevancy] if metrics.n_response_relevancy > 0 else []), *([metrics.faithfulness] if metrics.n_faithfulness > 0 else []), *([metrics.factual_correctness] if metrics.n_factual_correctness > 0 else []), ] answer_score = 100 * sum(mets) / len(mets) if mets else 0.0 return search_score, answer_score class LazyJsonWriter: def __init__(self, filepath: Path, indent: int = 4) -> None: self.filepath = filepath self.file: TextIO | None = None self.indent = indent def append(self, serializable_item: dict[str, Any]) -> None: if not self.file: # "w" so a pre-existing file is overwritten instead of getting a second "[" appended self.file = open(self.filepath, "w") self.file.write("[\n") else: self.file.write(",\n") data = json.dumps(serializable_item, indent=self.indent) self.file.write(indent(data, " " * self.indent)) def close(self) -> None: if not self.file: return self.file.write("\n]") self.file.close() self.file = None ================================================ FILE: backend/tests/unit/__init__.py ================================================ ================================================ FILE: backend/tests/unit/build/test_rewrite_asset_paths.py ================================================ """Unit tests for webapp proxy path rewriting/injection.""" from types import SimpleNamespace from typing import cast from typing import Literal from uuid import UUID import httpx import pytest from fastapi import Request from sqlalchemy.orm import Session from onyx.server.features.build.api import api from onyx.server.features.build.api.api import _inject_hmr_fixer from 
onyx.server.features.build.api.api import _rewrite_asset_paths from
onyx.server.features.build.api.api import _rewrite_proxy_response_headers SESSION_ID = "aaaaaaaa-bbbb-cccc-dddd-eeeeeeeeeeee" BASE = f"/api/build/sessions/{SESSION_ID}/webapp" def rewrite(html: str) -> str: return _rewrite_asset_paths(html.encode(), SESSION_ID).decode() def inject(html: str) -> str: return _inject_hmr_fixer(html.encode(), SESSION_ID).decode() class TestNextjsPathRewriting: def test_rewrites_bare_next_script_src(self) -> None: html = ' ================================================ FILE: desktop/src/titlebar.js ================================================ // Custom title bar for Onyx Desktop // This script injects a draggable title bar that matches Onyx design system (function () { const TITLEBAR_ID = "onyx-desktop-titlebar"; const TITLEBAR_HEIGHT = 36; const STYLE_ID = "onyx-desktop-titlebar-style"; const VIEWPORT_VAR = "--onyx-desktop-viewport-height"; // Wait for DOM to be ready if (document.readyState === "loading") { document.addEventListener("DOMContentLoaded", init); } else { init(); } function getInvoke() { if (window.__TAURI__?.core?.invoke) return window.__TAURI__.core.invoke; if (window.__TAURI__?.invoke) return window.__TAURI__.invoke; if (window.__TAURI_INTERNALS__?.invoke) return window.__TAURI_INTERNALS__.invoke; return null; } async function startWindowDrag() { const invoke = getInvoke(); if (invoke) { try { await invoke("start_drag_window"); return; } catch (err) {} } const appWindow = window.__TAURI__?.window?.getCurrent?.() ?? 
window.__TAURI__?.window?.appWindow; if (appWindow?.startDragging) { try { await appWindow.startDragging(); } catch (err) {} } } function injectStyles() { if (document.getElementById(STYLE_ID)) return; const style = document.createElement("style"); style.id = STYLE_ID; style.textContent = ` :root { --onyx-desktop-titlebar-height: ${TITLEBAR_HEIGHT}px; --onyx-desktop-viewport-height: 100dvh; --onyx-desktop-safe-height: calc(var(--onyx-desktop-viewport-height) - var(--onyx-desktop-titlebar-height)); } @supports not (height: 100dvh) { :root { --onyx-desktop-viewport-height: 100vh; } } html, body { height: var(--onyx-desktop-viewport-height); min-height: var(--onyx-desktop-viewport-height); margin: 0; padding: 0; overflow: hidden; } body { padding-top: var(--onyx-desktop-titlebar-height) !important; box-sizing: border-box; } body > div#__next, body > div#root, body > main { height: var(--onyx-desktop-safe-height); min-height: var(--onyx-desktop-safe-height); overflow: auto; } /* Override common Tailwind viewport helpers so content fits under the titlebar */ .h-screen { height: var(--onyx-desktop-safe-height) !important; } .min-h-screen { min-height: var(--onyx-desktop-safe-height) !important; } .max-h-screen { max-height: var(--onyx-desktop-safe-height) !important; } #${TITLEBAR_ID} { cursor: default !important; -webkit-user-select: none !important; user-select: none !important; -webkit-app-region: drag; background: rgba(255, 255, 255, 0.85); height: var(--onyx-desktop-titlebar-height); } /* Dark mode support */ .dark #${TITLEBAR_ID} { background: linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%); border-bottom-color: rgba(255, 255, 255, 0.08); } `; document.head.appendChild(style); } function updateTitleBarTheme(isDark) { const titleBar = document.getElementById(TITLEBAR_ID); if (!titleBar) return; if (isDark) { titleBar.style.background = "linear-gradient(180deg, rgba(18, 18, 18, 0.82) 0%, rgba(18, 18, 18, 0.72) 100%)"; 
titleBar.style.borderBottom = "1px solid rgba(255, 255, 255, 0.08)"; titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.2)"; } else { titleBar.style.background = "linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%)"; titleBar.style.borderBottom = "1px solid rgba(0, 0, 0, 0.06)"; titleBar.style.boxShadow = "0 8px 28px rgba(0, 0, 0, 0.04)"; } } function buildTitleBar() { const titleBar = document.createElement("div"); titleBar.id = TITLEBAR_ID; titleBar.setAttribute("data-tauri-drag-region", ""); titleBar.addEventListener("mousedown", (e) => { // Only start drag on left click and not on buttons/inputs const nonDraggable = [ "BUTTON", "INPUT", "TEXTAREA", "A", "SELECT", "OPTION", ]; if (e.button === 0 && !nonDraggable.includes(e.target.tagName)) { e.preventDefault(); startWindowDrag(); } }); // Apply initial styles matching current theme const htmlHasDark = document.documentElement.classList.contains("dark"); const bodyHasDark = document.body?.classList.contains("dark"); const isDark = htmlHasDark || bodyHasDark; // Apply styles matching Onyx design system with translucent glass effect titleBar.style.cssText = ` position: fixed; top: 0; left: 0; right: 0; height: ${TITLEBAR_HEIGHT}px; background: linear-gradient(180deg, rgba(255, 255, 255, 0.94) 0%, rgba(255, 255, 255, 0.78) 100%); border-bottom: 1px solid rgba(0, 0, 0, 0.06); box-shadow: 0 8px 28px rgba(0, 0, 0, 0.04); z-index: 999999; display: flex; align-items: center; justify-content: center; cursor: default; user-select: none; -webkit-user-select: none; font-family: 'Hanken Grotesk', -apple-system, BlinkMacSystemFont, sans-serif; backdrop-filter: blur(18px) saturate(180%); -webkit-backdrop-filter: blur(18px) saturate(180%); -webkit-app-region: drag; padding: 0 12px; transition: background 0.3s ease, border-bottom 0.3s ease, box-shadow 0.3s ease; `; // Apply correct theme updateTitleBarTheme(isDark); return titleBar; } function mountTitleBar() { if (!document.body) { return; } 
const existing = document.getElementById(TITLEBAR_ID); if (existing?.parentElement === document.body) { // Update theme on existing titlebar const htmlHasDark = document.documentElement.classList.contains("dark"); const bodyHasDark = document.body?.classList.contains("dark"); const isDark = htmlHasDark || bodyHasDark; updateTitleBarTheme(isDark); return; } if (existing) { existing.remove(); } const titleBar = buildTitleBar(); document.body.insertBefore(titleBar, document.body.firstChild); injectStyles(); // Ensure theme is applied immediately after mount setTimeout(() => { const htmlHasDark = document.documentElement.classList.contains("dark"); const bodyHasDark = document.body?.classList.contains("dark"); const isDark = htmlHasDark || bodyHasDark; updateTitleBarTheme(isDark); }, 0); } function syncViewportHeight() { const viewportHeight = window.visualViewport?.height ?? document.documentElement?.clientHeight ?? window.innerHeight; if (viewportHeight) { document.documentElement.style.setProperty( VIEWPORT_VAR, `${viewportHeight}px`, ); } } function observeThemeChanges() { let lastKnownTheme = null; function checkAndUpdateTheme() { // Check both html and body for dark class (some apps use body) const htmlHasDark = document.documentElement.classList.contains("dark"); const bodyHasDark = document.body?.classList.contains("dark"); const isDark = htmlHasDark || bodyHasDark; if (lastKnownTheme !== isDark) { lastKnownTheme = isDark; updateTitleBarTheme(isDark); } } // Immediate check on setup checkAndUpdateTheme(); // Watch for theme changes on the HTML element const themeObserver = new MutationObserver(() => { checkAndUpdateTheme(); }); themeObserver.observe(document.documentElement, { attributes: true, attributeFilter: ["class"], }); // Also observe body if it exists if (document.body) { const bodyObserver = new MutationObserver(() => { checkAndUpdateTheme(); }); bodyObserver.observe(document.body, { attributes: true, attributeFilter: ["class"], }); } // Also check 
periodically in case classList is manipulated directly // or the theme loads asynchronously after page load const intervalId = setInterval(() => { checkAndUpdateTheme(); }, 300); // Clean up after 30 seconds once theme should be stable setTimeout(() => { clearInterval(intervalId); // But keep checking every 2 seconds for manual theme changes setInterval(() => { checkAndUpdateTheme(); }, 2000); }, 30000); } function init() { mountTitleBar(); syncViewportHeight(); observeThemeChanges(); window.addEventListener("resize", syncViewportHeight, { passive: true }); window.visualViewport?.addEventListener("resize", syncViewportHeight, { passive: true, }); // Keep it around even if the app DOM re-renders const observer = new MutationObserver(() => { if (!document.getElementById(TITLEBAR_ID)) { mountTitleBar(); } }); observer.observe(document.documentElement, { childList: true, subtree: true, }); // Fallback keep-alive check setInterval(() => { if (!document.getElementById(TITLEBAR_ID)) { mountTitleBar(); } }, 1500); } })(); ================================================ FILE: desktop/src-tauri/Cargo.toml ================================================ [package] name = "onyx" version = "0.0.0-dev" description = "Lightweight desktop app for Onyx Cloud" authors = ["you"] edition = "2021" [build-dependencies] tauri-build = { version = "2.5", features = [] } [dependencies] tauri = { version = "2.10", features = ["macos-private-api", "tray-icon", "image-png"] } tauri-plugin-shell = "2.3.5" tauri-plugin-window-state = "2.4.1" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" uuid = { version = "1.0", features = ["v4"] } directories = "5.0" tokio = { version = "1", features = ["time"] } window-vibrancy = "0.7.1" url = "2.5" [features] default = ["custom-protocol"] custom-protocol = ["tauri/custom-protocol"] devtools = ["tauri/devtools"] ================================================ FILE: desktop/src-tauri/build.rs 
================================================ fn main() { tauri_build::build() } ================================================ FILE: desktop/src-tauri/gen/schemas/acl-manifests.json ================================================ {"core":{"default_permission":{"identifier":"default","description":"Default core plugins set.","permissions":["core:path:default","core:event:default","core:window:default","core:webview:default","core:app:default","core:image:default","core:resources:default","core:menu:default","core:tray:default"]},"permissions":{},"permission_sets":{},"global_scope_schema":null},"core:app":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin.","permissions":["allow-version","allow-name","allow-tauri-version","allow-identifier","allow-bundle-type","allow-register-listener","allow-remove-listener"]},"permissions":{"allow-app-hide":{"identifier":"allow-app-hide","description":"Enables the app_hide command without any pre-configured scope.","commands":{"allow":["app_hide"],"deny":[]}},"allow-app-show":{"identifier":"allow-app-show","description":"Enables the app_show command without any pre-configured scope.","commands":{"allow":["app_show"],"deny":[]}},"allow-bundle-type":{"identifier":"allow-bundle-type","description":"Enables the bundle_type command without any pre-configured scope.","commands":{"allow":["bundle_type"],"deny":[]}},"allow-default-window-icon":{"identifier":"allow-default-window-icon","description":"Enables the default_window_icon command without any pre-configured scope.","commands":{"allow":["default_window_icon"],"deny":[]}},"allow-fetch-data-store-identifiers":{"identifier":"allow-fetch-data-store-identifiers","description":"Enables the fetch_data_store_identifiers command without any pre-configured scope.","commands":{"allow":["fetch_data_store_identifiers"],"deny":[]}},"allow-identifier":{"identifier":"allow-identifier","description":"Enables the identifier command without any 
pre-configured scope.","commands":{"allow":["identifier"],"deny":[]}},"allow-name":{"identifier":"allow-name","description":"Enables the name command without any pre-configured scope.","commands":{"allow":["name"],"deny":[]}},"allow-register-listener":{"identifier":"allow-register-listener","description":"Enables the register_listener command without any pre-configured scope.","commands":{"allow":["register_listener"],"deny":[]}},"allow-remove-data-store":{"identifier":"allow-remove-data-store","description":"Enables the remove_data_store command without any pre-configured scope.","commands":{"allow":["remove_data_store"],"deny":[]}},"allow-remove-listener":{"identifier":"allow-remove-listener","description":"Enables the remove_listener command without any pre-configured scope.","commands":{"allow":["remove_listener"],"deny":[]}},"allow-set-app-theme":{"identifier":"allow-set-app-theme","description":"Enables the set_app_theme command without any pre-configured scope.","commands":{"allow":["set_app_theme"],"deny":[]}},"allow-set-dock-visibility":{"identifier":"allow-set-dock-visibility","description":"Enables the set_dock_visibility command without any pre-configured scope.","commands":{"allow":["set_dock_visibility"],"deny":[]}},"allow-tauri-version":{"identifier":"allow-tauri-version","description":"Enables the tauri_version command without any pre-configured scope.","commands":{"allow":["tauri_version"],"deny":[]}},"allow-version":{"identifier":"allow-version","description":"Enables the version command without any pre-configured scope.","commands":{"allow":["version"],"deny":[]}},"deny-app-hide":{"identifier":"deny-app-hide","description":"Denies the app_hide command without any pre-configured scope.","commands":{"allow":[],"deny":["app_hide"]}},"deny-app-show":{"identifier":"deny-app-show","description":"Denies the app_show command without any pre-configured 
scope.","commands":{"allow":[],"deny":["app_show"]}},"deny-bundle-type":{"identifier":"deny-bundle-type","description":"Denies the bundle_type command without any pre-configured scope.","commands":{"allow":[],"deny":["bundle_type"]}},"deny-default-window-icon":{"identifier":"deny-default-window-icon","description":"Denies the default_window_icon command without any pre-configured scope.","commands":{"allow":[],"deny":["default_window_icon"]}},"deny-fetch-data-store-identifiers":{"identifier":"deny-fetch-data-store-identifiers","description":"Denies the fetch_data_store_identifiers command without any pre-configured scope.","commands":{"allow":[],"deny":["fetch_data_store_identifiers"]}},"deny-identifier":{"identifier":"deny-identifier","description":"Denies the identifier command without any pre-configured scope.","commands":{"allow":[],"deny":["identifier"]}},"deny-name":{"identifier":"deny-name","description":"Denies the name command without any pre-configured scope.","commands":{"allow":[],"deny":["name"]}},"deny-register-listener":{"identifier":"deny-register-listener","description":"Denies the register_listener command without any pre-configured scope.","commands":{"allow":[],"deny":["register_listener"]}},"deny-remove-data-store":{"identifier":"deny-remove-data-store","description":"Denies the remove_data_store command without any pre-configured scope.","commands":{"allow":[],"deny":["remove_data_store"]}},"deny-remove-listener":{"identifier":"deny-remove-listener","description":"Denies the remove_listener command without any pre-configured scope.","commands":{"allow":[],"deny":["remove_listener"]}},"deny-set-app-theme":{"identifier":"deny-set-app-theme","description":"Denies the set_app_theme command without any pre-configured scope.","commands":{"allow":[],"deny":["set_app_theme"]}},"deny-set-dock-visibility":{"identifier":"deny-set-dock-visibility","description":"Denies the set_dock_visibility command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_dock_visibility"]}},"deny-tauri-version":{"identifier":"deny-tauri-version","description":"Denies the tauri_version command without any pre-configured scope.","commands":{"allow":[],"deny":["tauri_version"]}},"deny-version":{"identifier":"deny-version","description":"Denies the version command without any pre-configured scope.","commands":{"allow":[],"deny":["version"]}}},"permission_sets":{},"global_scope_schema":null},"core:event":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin, which enables all commands.","permissions":["allow-listen","allow-unlisten","allow-emit","allow-emit-to"]},"permissions":{"allow-emit":{"identifier":"allow-emit","description":"Enables the emit command without any pre-configured scope.","commands":{"allow":["emit"],"deny":[]}},"allow-emit-to":{"identifier":"allow-emit-to","description":"Enables the emit_to command without any pre-configured scope.","commands":{"allow":["emit_to"],"deny":[]}},"allow-listen":{"identifier":"allow-listen","description":"Enables the listen command without any pre-configured scope.","commands":{"allow":["listen"],"deny":[]}},"allow-unlisten":{"identifier":"allow-unlisten","description":"Enables the unlisten command without any pre-configured scope.","commands":{"allow":["unlisten"],"deny":[]}},"deny-emit":{"identifier":"deny-emit","description":"Denies the emit command without any pre-configured scope.","commands":{"allow":[],"deny":["emit"]}},"deny-emit-to":{"identifier":"deny-emit-to","description":"Denies the emit_to command without any pre-configured scope.","commands":{"allow":[],"deny":["emit_to"]}},"deny-listen":{"identifier":"deny-listen","description":"Denies the listen command without any pre-configured scope.","commands":{"allow":[],"deny":["listen"]}},"deny-unlisten":{"identifier":"deny-unlisten","description":"Denies the unlisten command without any pre-configured 
scope.","commands":{"allow":[],"deny":["unlisten"]}}},"permission_sets":{},"global_scope_schema":null},"core:image":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin, which enables all commands.","permissions":["allow-new","allow-from-bytes","allow-from-path","allow-rgba","allow-size"]},"permissions":{"allow-from-bytes":{"identifier":"allow-from-bytes","description":"Enables the from_bytes command without any pre-configured scope.","commands":{"allow":["from_bytes"],"deny":[]}},"allow-from-path":{"identifier":"allow-from-path","description":"Enables the from_path command without any pre-configured scope.","commands":{"allow":["from_path"],"deny":[]}},"allow-new":{"identifier":"allow-new","description":"Enables the new command without any pre-configured scope.","commands":{"allow":["new"],"deny":[]}},"allow-rgba":{"identifier":"allow-rgba","description":"Enables the rgba command without any pre-configured scope.","commands":{"allow":["rgba"],"deny":[]}},"allow-size":{"identifier":"allow-size","description":"Enables the size command without any pre-configured scope.","commands":{"allow":["size"],"deny":[]}},"deny-from-bytes":{"identifier":"deny-from-bytes","description":"Denies the from_bytes command without any pre-configured scope.","commands":{"allow":[],"deny":["from_bytes"]}},"deny-from-path":{"identifier":"deny-from-path","description":"Denies the from_path command without any pre-configured scope.","commands":{"allow":[],"deny":["from_path"]}},"deny-new":{"identifier":"deny-new","description":"Denies the new command without any pre-configured scope.","commands":{"allow":[],"deny":["new"]}},"deny-rgba":{"identifier":"deny-rgba","description":"Denies the rgba command without any pre-configured scope.","commands":{"allow":[],"deny":["rgba"]}},"deny-size":{"identifier":"deny-size","description":"Denies the size command without any pre-configured 
scope.","commands":{"allow":[],"deny":["size"]}}},"permission_sets":{},"global_scope_schema":null},"core:menu":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin, which enables all commands.","permissions":["allow-new","allow-append","allow-prepend","allow-insert","allow-remove","allow-remove-at","allow-items","allow-get","allow-popup","allow-create-default","allow-set-as-app-menu","allow-set-as-window-menu","allow-text","allow-set-text","allow-is-enabled","allow-set-enabled","allow-set-accelerator","allow-set-as-windows-menu-for-nsapp","allow-set-as-help-menu-for-nsapp","allow-is-checked","allow-set-checked","allow-set-icon"]},"permissions":{"allow-append":{"identifier":"allow-append","description":"Enables the append command without any pre-configured scope.","commands":{"allow":["append"],"deny":[]}},"allow-create-default":{"identifier":"allow-create-default","description":"Enables the create_default command without any pre-configured scope.","commands":{"allow":["create_default"],"deny":[]}},"allow-get":{"identifier":"allow-get","description":"Enables the get command without any pre-configured scope.","commands":{"allow":["get"],"deny":[]}},"allow-insert":{"identifier":"allow-insert","description":"Enables the insert command without any pre-configured scope.","commands":{"allow":["insert"],"deny":[]}},"allow-is-checked":{"identifier":"allow-is-checked","description":"Enables the is_checked command without any pre-configured scope.","commands":{"allow":["is_checked"],"deny":[]}},"allow-is-enabled":{"identifier":"allow-is-enabled","description":"Enables the is_enabled command without any pre-configured scope.","commands":{"allow":["is_enabled"],"deny":[]}},"allow-items":{"identifier":"allow-items","description":"Enables the items command without any pre-configured scope.","commands":{"allow":["items"],"deny":[]}},"allow-new":{"identifier":"allow-new","description":"Enables the new command without any pre-configured 
scope.","commands":{"allow":["new"],"deny":[]}},"allow-popup":{"identifier":"allow-popup","description":"Enables the popup command without any pre-configured scope.","commands":{"allow":["popup"],"deny":[]}},"allow-prepend":{"identifier":"allow-prepend","description":"Enables the prepend command without any pre-configured scope.","commands":{"allow":["prepend"],"deny":[]}},"allow-remove":{"identifier":"allow-remove","description":"Enables the remove command without any pre-configured scope.","commands":{"allow":["remove"],"deny":[]}},"allow-remove-at":{"identifier":"allow-remove-at","description":"Enables the remove_at command without any pre-configured scope.","commands":{"allow":["remove_at"],"deny":[]}},"allow-set-accelerator":{"identifier":"allow-set-accelerator","description":"Enables the set_accelerator command without any pre-configured scope.","commands":{"allow":["set_accelerator"],"deny":[]}},"allow-set-as-app-menu":{"identifier":"allow-set-as-app-menu","description":"Enables the set_as_app_menu command without any pre-configured scope.","commands":{"allow":["set_as_app_menu"],"deny":[]}},"allow-set-as-help-menu-for-nsapp":{"identifier":"allow-set-as-help-menu-for-nsapp","description":"Enables the set_as_help_menu_for_nsapp command without any pre-configured scope.","commands":{"allow":["set_as_help_menu_for_nsapp"],"deny":[]}},"allow-set-as-window-menu":{"identifier":"allow-set-as-window-menu","description":"Enables the set_as_window_menu command without any pre-configured scope.","commands":{"allow":["set_as_window_menu"],"deny":[]}},"allow-set-as-windows-menu-for-nsapp":{"identifier":"allow-set-as-windows-menu-for-nsapp","description":"Enables the set_as_windows_menu_for_nsapp command without any pre-configured scope.","commands":{"allow":["set_as_windows_menu_for_nsapp"],"deny":[]}},"allow-set-checked":{"identifier":"allow-set-checked","description":"Enables the set_checked command without any pre-configured 
scope.","commands":{"allow":["set_checked"],"deny":[]}},"allow-set-enabled":{"identifier":"allow-set-enabled","description":"Enables the set_enabled command without any pre-configured scope.","commands":{"allow":["set_enabled"],"deny":[]}},"allow-set-icon":{"identifier":"allow-set-icon","description":"Enables the set_icon command without any pre-configured scope.","commands":{"allow":["set_icon"],"deny":[]}},"allow-set-text":{"identifier":"allow-set-text","description":"Enables the set_text command without any pre-configured scope.","commands":{"allow":["set_text"],"deny":[]}},"allow-text":{"identifier":"allow-text","description":"Enables the text command without any pre-configured scope.","commands":{"allow":["text"],"deny":[]}},"deny-append":{"identifier":"deny-append","description":"Denies the append command without any pre-configured scope.","commands":{"allow":[],"deny":["append"]}},"deny-create-default":{"identifier":"deny-create-default","description":"Denies the create_default command without any pre-configured scope.","commands":{"allow":[],"deny":["create_default"]}},"deny-get":{"identifier":"deny-get","description":"Denies the get command without any pre-configured scope.","commands":{"allow":[],"deny":["get"]}},"deny-insert":{"identifier":"deny-insert","description":"Denies the insert command without any pre-configured scope.","commands":{"allow":[],"deny":["insert"]}},"deny-is-checked":{"identifier":"deny-is-checked","description":"Denies the is_checked command without any pre-configured scope.","commands":{"allow":[],"deny":["is_checked"]}},"deny-is-enabled":{"identifier":"deny-is-enabled","description":"Denies the is_enabled command without any pre-configured scope.","commands":{"allow":[],"deny":["is_enabled"]}},"deny-items":{"identifier":"deny-items","description":"Denies the items command without any pre-configured scope.","commands":{"allow":[],"deny":["items"]}},"deny-new":{"identifier":"deny-new","description":"Denies the new command without 
any pre-configured scope.","commands":{"allow":[],"deny":["new"]}},"deny-popup":{"identifier":"deny-popup","description":"Denies the popup command without any pre-configured scope.","commands":{"allow":[],"deny":["popup"]}},"deny-prepend":{"identifier":"deny-prepend","description":"Denies the prepend command without any pre-configured scope.","commands":{"allow":[],"deny":["prepend"]}},"deny-remove":{"identifier":"deny-remove","description":"Denies the remove command without any pre-configured scope.","commands":{"allow":[],"deny":["remove"]}},"deny-remove-at":{"identifier":"deny-remove-at","description":"Denies the remove_at command without any pre-configured scope.","commands":{"allow":[],"deny":["remove_at"]}},"deny-set-accelerator":{"identifier":"deny-set-accelerator","description":"Denies the set_accelerator command without any pre-configured scope.","commands":{"allow":[],"deny":["set_accelerator"]}},"deny-set-as-app-menu":{"identifier":"deny-set-as-app-menu","description":"Denies the set_as_app_menu command without any pre-configured scope.","commands":{"allow":[],"deny":["set_as_app_menu"]}},"deny-set-as-help-menu-for-nsapp":{"identifier":"deny-set-as-help-menu-for-nsapp","description":"Denies the set_as_help_menu_for_nsapp command without any pre-configured scope.","commands":{"allow":[],"deny":["set_as_help_menu_for_nsapp"]}},"deny-set-as-window-menu":{"identifier":"deny-set-as-window-menu","description":"Denies the set_as_window_menu command without any pre-configured scope.","commands":{"allow":[],"deny":["set_as_window_menu"]}},"deny-set-as-windows-menu-for-nsapp":{"identifier":"deny-set-as-windows-menu-for-nsapp","description":"Denies the set_as_windows_menu_for_nsapp command without any pre-configured scope.","commands":{"allow":[],"deny":["set_as_windows_menu_for_nsapp"]}},"deny-set-checked":{"identifier":"deny-set-checked","description":"Denies the set_checked command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_checked"]}},"deny-set-enabled":{"identifier":"deny-set-enabled","description":"Denies the set_enabled command without any pre-configured scope.","commands":{"allow":[],"deny":["set_enabled"]}},"deny-set-icon":{"identifier":"deny-set-icon","description":"Denies the set_icon command without any pre-configured scope.","commands":{"allow":[],"deny":["set_icon"]}},"deny-set-text":{"identifier":"deny-set-text","description":"Denies the set_text command without any pre-configured scope.","commands":{"allow":[],"deny":["set_text"]}},"deny-text":{"identifier":"deny-text","description":"Denies the text command without any pre-configured scope.","commands":{"allow":[],"deny":["text"]}}},"permission_sets":{},"global_scope_schema":null},"core:path":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin, which enables all commands.","permissions":["allow-resolve-directory","allow-resolve","allow-normalize","allow-join","allow-dirname","allow-extname","allow-basename","allow-is-absolute"]},"permissions":{"allow-basename":{"identifier":"allow-basename","description":"Enables the basename command without any pre-configured scope.","commands":{"allow":["basename"],"deny":[]}},"allow-dirname":{"identifier":"allow-dirname","description":"Enables the dirname command without any pre-configured scope.","commands":{"allow":["dirname"],"deny":[]}},"allow-extname":{"identifier":"allow-extname","description":"Enables the extname command without any pre-configured scope.","commands":{"allow":["extname"],"deny":[]}},"allow-is-absolute":{"identifier":"allow-is-absolute","description":"Enables the is_absolute command without any pre-configured scope.","commands":{"allow":["is_absolute"],"deny":[]}},"allow-join":{"identifier":"allow-join","description":"Enables the join command without any pre-configured 
scope.","commands":{"allow":["join"],"deny":[]}},"allow-normalize":{"identifier":"allow-normalize","description":"Enables the normalize command without any pre-configured scope.","commands":{"allow":["normalize"],"deny":[]}},"allow-resolve":{"identifier":"allow-resolve","description":"Enables the resolve command without any pre-configured scope.","commands":{"allow":["resolve"],"deny":[]}},"allow-resolve-directory":{"identifier":"allow-resolve-directory","description":"Enables the resolve_directory command without any pre-configured scope.","commands":{"allow":["resolve_directory"],"deny":[]}},"deny-basename":{"identifier":"deny-basename","description":"Denies the basename command without any pre-configured scope.","commands":{"allow":[],"deny":["basename"]}},"deny-dirname":{"identifier":"deny-dirname","description":"Denies the dirname command without any pre-configured scope.","commands":{"allow":[],"deny":["dirname"]}},"deny-extname":{"identifier":"deny-extname","description":"Denies the extname command without any pre-configured scope.","commands":{"allow":[],"deny":["extname"]}},"deny-is-absolute":{"identifier":"deny-is-absolute","description":"Denies the is_absolute command without any pre-configured scope.","commands":{"allow":[],"deny":["is_absolute"]}},"deny-join":{"identifier":"deny-join","description":"Denies the join command without any pre-configured scope.","commands":{"allow":[],"deny":["join"]}},"deny-normalize":{"identifier":"deny-normalize","description":"Denies the normalize command without any pre-configured scope.","commands":{"allow":[],"deny":["normalize"]}},"deny-resolve":{"identifier":"deny-resolve","description":"Denies the resolve command without any pre-configured scope.","commands":{"allow":[],"deny":["resolve"]}},"deny-resolve-directory":{"identifier":"deny-resolve-directory","description":"Denies the resolve_directory command without any pre-configured 
scope.","commands":{"allow":[],"deny":["resolve_directory"]}}},"permission_sets":{},"global_scope_schema":null},"core:resources":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin, which enables all commands.","permissions":["allow-close"]},"permissions":{"allow-close":{"identifier":"allow-close","description":"Enables the close command without any pre-configured scope.","commands":{"allow":["close"],"deny":[]}},"deny-close":{"identifier":"deny-close","description":"Denies the close command without any pre-configured scope.","commands":{"allow":[],"deny":["close"]}}},"permission_sets":{},"global_scope_schema":null},"core:tray":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin, which enables all commands.","permissions":["allow-new","allow-get-by-id","allow-remove-by-id","allow-set-icon","allow-set-menu","allow-set-tooltip","allow-set-title","allow-set-visible","allow-set-temp-dir-path","allow-set-icon-as-template","allow-set-show-menu-on-left-click"]},"permissions":{"allow-get-by-id":{"identifier":"allow-get-by-id","description":"Enables the get_by_id command without any pre-configured scope.","commands":{"allow":["get_by_id"],"deny":[]}},"allow-new":{"identifier":"allow-new","description":"Enables the new command without any pre-configured scope.","commands":{"allow":["new"],"deny":[]}},"allow-remove-by-id":{"identifier":"allow-remove-by-id","description":"Enables the remove_by_id command without any pre-configured scope.","commands":{"allow":["remove_by_id"],"deny":[]}},"allow-set-icon":{"identifier":"allow-set-icon","description":"Enables the set_icon command without any pre-configured scope.","commands":{"allow":["set_icon"],"deny":[]}},"allow-set-icon-as-template":{"identifier":"allow-set-icon-as-template","description":"Enables the set_icon_as_template command without any pre-configured 
scope.","commands":{"allow":["set_icon_as_template"],"deny":[]}},"allow-set-menu":{"identifier":"allow-set-menu","description":"Enables the set_menu command without any pre-configured scope.","commands":{"allow":["set_menu"],"deny":[]}},"allow-set-show-menu-on-left-click":{"identifier":"allow-set-show-menu-on-left-click","description":"Enables the set_show_menu_on_left_click command without any pre-configured scope.","commands":{"allow":["set_show_menu_on_left_click"],"deny":[]}},"allow-set-temp-dir-path":{"identifier":"allow-set-temp-dir-path","description":"Enables the set_temp_dir_path command without any pre-configured scope.","commands":{"allow":["set_temp_dir_path"],"deny":[]}},"allow-set-title":{"identifier":"allow-set-title","description":"Enables the set_title command without any pre-configured scope.","commands":{"allow":["set_title"],"deny":[]}},"allow-set-tooltip":{"identifier":"allow-set-tooltip","description":"Enables the set_tooltip command without any pre-configured scope.","commands":{"allow":["set_tooltip"],"deny":[]}},"allow-set-visible":{"identifier":"allow-set-visible","description":"Enables the set_visible command without any pre-configured scope.","commands":{"allow":["set_visible"],"deny":[]}},"deny-get-by-id":{"identifier":"deny-get-by-id","description":"Denies the get_by_id command without any pre-configured scope.","commands":{"allow":[],"deny":["get_by_id"]}},"deny-new":{"identifier":"deny-new","description":"Denies the new command without any pre-configured scope.","commands":{"allow":[],"deny":["new"]}},"deny-remove-by-id":{"identifier":"deny-remove-by-id","description":"Denies the remove_by_id command without any pre-configured scope.","commands":{"allow":[],"deny":["remove_by_id"]}},"deny-set-icon":{"identifier":"deny-set-icon","description":"Denies the set_icon command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_icon"]}},"deny-set-icon-as-template":{"identifier":"deny-set-icon-as-template","description":"Denies the set_icon_as_template command without any pre-configured scope.","commands":{"allow":[],"deny":["set_icon_as_template"]}},"deny-set-menu":{"identifier":"deny-set-menu","description":"Denies the set_menu command without any pre-configured scope.","commands":{"allow":[],"deny":["set_menu"]}},"deny-set-show-menu-on-left-click":{"identifier":"deny-set-show-menu-on-left-click","description":"Denies the set_show_menu_on_left_click command without any pre-configured scope.","commands":{"allow":[],"deny":["set_show_menu_on_left_click"]}},"deny-set-temp-dir-path":{"identifier":"deny-set-temp-dir-path","description":"Denies the set_temp_dir_path command without any pre-configured scope.","commands":{"allow":[],"deny":["set_temp_dir_path"]}},"deny-set-title":{"identifier":"deny-set-title","description":"Denies the set_title command without any pre-configured scope.","commands":{"allow":[],"deny":["set_title"]}},"deny-set-tooltip":{"identifier":"deny-set-tooltip","description":"Denies the set_tooltip command without any pre-configured scope.","commands":{"allow":[],"deny":["set_tooltip"]}},"deny-set-visible":{"identifier":"deny-set-visible","description":"Denies the set_visible command without any pre-configured scope.","commands":{"allow":[],"deny":["set_visible"]}}},"permission_sets":{},"global_scope_schema":null},"core:webview":{"default_permission":{"identifier":"default","description":"Default permissions for the plugin.","permissions":["allow-get-all-webviews","allow-webview-position","allow-webview-size","allow-internal-toggle-devtools"]},"permissions":{"allow-clear-all-browsing-data":{"identifier":"allow-clear-all-browsing-data","description":"Enables the clear_all_browsing_data command without any pre-configured 
scope.","commands":{"allow":["clear_all_browsing_data"],"deny":[]}},"allow-create-webview":{"identifier":"allow-create-webview","description":"Enables the create_webview command without any pre-configured scope.","commands":{"allow":["create_webview"],"deny":[]}},"allow-create-webview-window":{"identifier":"allow-create-webview-window","description":"Enables the create_webview_window command without any pre-configured scope.","commands":{"allow":["create_webview_window"],"deny":[]}},"allow-get-all-webviews":{"identifier":"allow-get-all-webviews","description":"Enables the get_all_webviews command without any pre-configured scope.","commands":{"allow":["get_all_webviews"],"deny":[]}},"allow-internal-toggle-devtools":{"identifier":"allow-internal-toggle-devtools","description":"Enables the internal_toggle_devtools command without any pre-configured scope.","commands":{"allow":["internal_toggle_devtools"],"deny":[]}},"allow-print":{"identifier":"allow-print","description":"Enables the print command without any pre-configured scope.","commands":{"allow":["print"],"deny":[]}},"allow-reparent":{"identifier":"allow-reparent","description":"Enables the reparent command without any pre-configured scope.","commands":{"allow":["reparent"],"deny":[]}},"allow-set-webview-auto-resize":{"identifier":"allow-set-webview-auto-resize","description":"Enables the set_webview_auto_resize command without any pre-configured scope.","commands":{"allow":["set_webview_auto_resize"],"deny":[]}},"allow-set-webview-background-color":{"identifier":"allow-set-webview-background-color","description":"Enables the set_webview_background_color command without any pre-configured scope.","commands":{"allow":["set_webview_background_color"],"deny":[]}},"allow-set-webview-focus":{"identifier":"allow-set-webview-focus","description":"Enables the set_webview_focus command without any pre-configured 
scope.","commands":{"allow":["set_webview_focus"],"deny":[]}},"allow-set-webview-position":{"identifier":"allow-set-webview-position","description":"Enables the set_webview_position command without any pre-configured scope.","commands":{"allow":["set_webview_position"],"deny":[]}},"allow-set-webview-size":{"identifier":"allow-set-webview-size","description":"Enables the set_webview_size command without any pre-configured scope.","commands":{"allow":["set_webview_size"],"deny":[]}},"allow-set-webview-zoom":{"identifier":"allow-set-webview-zoom","description":"Enables the set_webview_zoom command without any pre-configured scope.","commands":{"allow":["set_webview_zoom"],"deny":[]}},"allow-webview-close":{"identifier":"allow-webview-close","description":"Enables the webview_close command without any pre-configured scope.","commands":{"allow":["webview_close"],"deny":[]}},"allow-webview-hide":{"identifier":"allow-webview-hide","description":"Enables the webview_hide command without any pre-configured scope.","commands":{"allow":["webview_hide"],"deny":[]}},"allow-webview-position":{"identifier":"allow-webview-position","description":"Enables the webview_position command without any pre-configured scope.","commands":{"allow":["webview_position"],"deny":[]}},"allow-webview-show":{"identifier":"allow-webview-show","description":"Enables the webview_show command without any pre-configured scope.","commands":{"allow":["webview_show"],"deny":[]}},"allow-webview-size":{"identifier":"allow-webview-size","description":"Enables the webview_size command without any pre-configured scope.","commands":{"allow":["webview_size"],"deny":[]}},"deny-clear-all-browsing-data":{"identifier":"deny-clear-all-browsing-data","description":"Denies the clear_all_browsing_data command without any pre-configured scope.","commands":{"allow":[],"deny":["clear_all_browsing_data"]}},"deny-create-webview":{"identifier":"deny-create-webview","description":"Denies the create_webview command without any 
pre-configured scope.","commands":{"allow":[],"deny":["create_webview"]}},"deny-create-webview-window":{"identifier":"deny-create-webview-window","description":"Denies the create_webview_window command without any pre-configured scope.","commands":{"allow":[],"deny":["create_webview_window"]}},"deny-get-all-webviews":{"identifier":"deny-get-all-webviews","description":"Denies the get_all_webviews command without any pre-configured scope.","commands":{"allow":[],"deny":["get_all_webviews"]}},"deny-internal-toggle-devtools":{"identifier":"deny-internal-toggle-devtools","description":"Denies the internal_toggle_devtools command without any pre-configured scope.","commands":{"allow":[],"deny":["internal_toggle_devtools"]}},"deny-print":{"identifier":"deny-print","description":"Denies the print command without any pre-configured scope.","commands":{"allow":[],"deny":["print"]}},"deny-reparent":{"identifier":"deny-reparent","description":"Denies the reparent command without any pre-configured scope.","commands":{"allow":[],"deny":["reparent"]}},"deny-set-webview-auto-resize":{"identifier":"deny-set-webview-auto-resize","description":"Denies the set_webview_auto_resize command without any pre-configured scope.","commands":{"allow":[],"deny":["set_webview_auto_resize"]}},"deny-set-webview-background-color":{"identifier":"deny-set-webview-background-color","description":"Denies the set_webview_background_color command without any pre-configured scope.","commands":{"allow":[],"deny":["set_webview_background_color"]}},"deny-set-webview-focus":{"identifier":"deny-set-webview-focus","description":"Denies the set_webview_focus command without any pre-configured scope.","commands":{"allow":[],"deny":["set_webview_focus"]}},"deny-set-webview-position":{"identifier":"deny-set-webview-position","description":"Denies the set_webview_position command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_webview_position"]}},"deny-set-webview-size":{"identifier":"deny-set-webview-size","description":"Denies the set_webview_size command without any pre-configured scope.","commands":{"allow":[],"deny":["set_webview_size"]}},"deny-set-webview-zoom":{"identifier":"deny-set-webview-zoom","description":"Denies the set_webview_zoom command without any pre-configured scope.","commands":{"allow":[],"deny":["set_webview_zoom"]}},"deny-webview-close":{"identifier":"deny-webview-close","description":"Denies the webview_close command without any pre-configured scope.","commands":{"allow":[],"deny":["webview_close"]}},"deny-webview-hide":{"identifier":"deny-webview-hide","description":"Denies the webview_hide command without any pre-configured scope.","commands":{"allow":[],"deny":["webview_hide"]}},"deny-webview-position":{"identifier":"deny-webview-position","description":"Denies the webview_position command without any pre-configured scope.","commands":{"allow":[],"deny":["webview_position"]}},"deny-webview-show":{"identifier":"deny-webview-show","description":"Denies the webview_show command without any pre-configured scope.","commands":{"allow":[],"deny":["webview_show"]}},"deny-webview-size":{"identifier":"deny-webview-size","description":"Denies the webview_size command without any pre-configured scope.","commands":{"allow":[],"deny":["webview_size"]}}},"permission_sets":{},"global_scope_schema":null},"core:window":{"default_permission":{"identifier":"default","description":"Default permissions for the 
plugin.","permissions":["allow-get-all-windows","allow-scale-factor","allow-inner-position","allow-outer-position","allow-inner-size","allow-outer-size","allow-is-fullscreen","allow-is-minimized","allow-is-maximized","allow-is-focused","allow-is-decorated","allow-is-resizable","allow-is-maximizable","allow-is-minimizable","allow-is-closable","allow-is-visible","allow-is-enabled","allow-title","allow-current-monitor","allow-primary-monitor","allow-monitor-from-point","allow-available-monitors","allow-cursor-position","allow-theme","allow-is-always-on-top","allow-internal-toggle-maximize"]},"permissions":{"allow-available-monitors":{"identifier":"allow-available-monitors","description":"Enables the available_monitors command without any pre-configured scope.","commands":{"allow":["available_monitors"],"deny":[]}},"allow-center":{"identifier":"allow-center","description":"Enables the center command without any pre-configured scope.","commands":{"allow":["center"],"deny":[]}},"allow-close":{"identifier":"allow-close","description":"Enables the close command without any pre-configured scope.","commands":{"allow":["close"],"deny":[]}},"allow-create":{"identifier":"allow-create","description":"Enables the create command without any pre-configured scope.","commands":{"allow":["create"],"deny":[]}},"allow-current-monitor":{"identifier":"allow-current-monitor","description":"Enables the current_monitor command without any pre-configured scope.","commands":{"allow":["current_monitor"],"deny":[]}},"allow-cursor-position":{"identifier":"allow-cursor-position","description":"Enables the cursor_position command without any pre-configured scope.","commands":{"allow":["cursor_position"],"deny":[]}},"allow-destroy":{"identifier":"allow-destroy","description":"Enables the destroy command without any pre-configured scope.","commands":{"allow":["destroy"],"deny":[]}},"allow-get-all-windows":{"identifier":"allow-get-all-windows","description":"Enables the get_all_windows command without 
any pre-configured scope.","commands":{"allow":["get_all_windows"],"deny":[]}},"allow-hide":{"identifier":"allow-hide","description":"Enables the hide command without any pre-configured scope.","commands":{"allow":["hide"],"deny":[]}},"allow-inner-position":{"identifier":"allow-inner-position","description":"Enables the inner_position command without any pre-configured scope.","commands":{"allow":["inner_position"],"deny":[]}},"allow-inner-size":{"identifier":"allow-inner-size","description":"Enables the inner_size command without any pre-configured scope.","commands":{"allow":["inner_size"],"deny":[]}},"allow-internal-toggle-maximize":{"identifier":"allow-internal-toggle-maximize","description":"Enables the internal_toggle_maximize command without any pre-configured scope.","commands":{"allow":["internal_toggle_maximize"],"deny":[]}},"allow-is-always-on-top":{"identifier":"allow-is-always-on-top","description":"Enables the is_always_on_top command without any pre-configured scope.","commands":{"allow":["is_always_on_top"],"deny":[]}},"allow-is-closable":{"identifier":"allow-is-closable","description":"Enables the is_closable command without any pre-configured scope.","commands":{"allow":["is_closable"],"deny":[]}},"allow-is-decorated":{"identifier":"allow-is-decorated","description":"Enables the is_decorated command without any pre-configured scope.","commands":{"allow":["is_decorated"],"deny":[]}},"allow-is-enabled":{"identifier":"allow-is-enabled","description":"Enables the is_enabled command without any pre-configured scope.","commands":{"allow":["is_enabled"],"deny":[]}},"allow-is-focused":{"identifier":"allow-is-focused","description":"Enables the is_focused command without any pre-configured scope.","commands":{"allow":["is_focused"],"deny":[]}},"allow-is-fullscreen":{"identifier":"allow-is-fullscreen","description":"Enables the is_fullscreen command without any pre-configured 
scope.","commands":{"allow":["is_fullscreen"],"deny":[]}},"allow-is-maximizable":{"identifier":"allow-is-maximizable","description":"Enables the is_maximizable command without any pre-configured scope.","commands":{"allow":["is_maximizable"],"deny":[]}},"allow-is-maximized":{"identifier":"allow-is-maximized","description":"Enables the is_maximized command without any pre-configured scope.","commands":{"allow":["is_maximized"],"deny":[]}},"allow-is-minimizable":{"identifier":"allow-is-minimizable","description":"Enables the is_minimizable command without any pre-configured scope.","commands":{"allow":["is_minimizable"],"deny":[]}},"allow-is-minimized":{"identifier":"allow-is-minimized","description":"Enables the is_minimized command without any pre-configured scope.","commands":{"allow":["is_minimized"],"deny":[]}},"allow-is-resizable":{"identifier":"allow-is-resizable","description":"Enables the is_resizable command without any pre-configured scope.","commands":{"allow":["is_resizable"],"deny":[]}},"allow-is-visible":{"identifier":"allow-is-visible","description":"Enables the is_visible command without any pre-configured scope.","commands":{"allow":["is_visible"],"deny":[]}},"allow-maximize":{"identifier":"allow-maximize","description":"Enables the maximize command without any pre-configured scope.","commands":{"allow":["maximize"],"deny":[]}},"allow-minimize":{"identifier":"allow-minimize","description":"Enables the minimize command without any pre-configured scope.","commands":{"allow":["minimize"],"deny":[]}},"allow-monitor-from-point":{"identifier":"allow-monitor-from-point","description":"Enables the monitor_from_point command without any pre-configured scope.","commands":{"allow":["monitor_from_point"],"deny":[]}},"allow-outer-position":{"identifier":"allow-outer-position","description":"Enables the outer_position command without any pre-configured 
scope.","commands":{"allow":["outer_position"],"deny":[]}},"allow-outer-size":{"identifier":"allow-outer-size","description":"Enables the outer_size command without any pre-configured scope.","commands":{"allow":["outer_size"],"deny":[]}},"allow-primary-monitor":{"identifier":"allow-primary-monitor","description":"Enables the primary_monitor command without any pre-configured scope.","commands":{"allow":["primary_monitor"],"deny":[]}},"allow-request-user-attention":{"identifier":"allow-request-user-attention","description":"Enables the request_user_attention command without any pre-configured scope.","commands":{"allow":["request_user_attention"],"deny":[]}},"allow-scale-factor":{"identifier":"allow-scale-factor","description":"Enables the scale_factor command without any pre-configured scope.","commands":{"allow":["scale_factor"],"deny":[]}},"allow-set-always-on-bottom":{"identifier":"allow-set-always-on-bottom","description":"Enables the set_always_on_bottom command without any pre-configured scope.","commands":{"allow":["set_always_on_bottom"],"deny":[]}},"allow-set-always-on-top":{"identifier":"allow-set-always-on-top","description":"Enables the set_always_on_top command without any pre-configured scope.","commands":{"allow":["set_always_on_top"],"deny":[]}},"allow-set-background-color":{"identifier":"allow-set-background-color","description":"Enables the set_background_color command without any pre-configured scope.","commands":{"allow":["set_background_color"],"deny":[]}},"allow-set-badge-count":{"identifier":"allow-set-badge-count","description":"Enables the set_badge_count command without any pre-configured scope.","commands":{"allow":["set_badge_count"],"deny":[]}},"allow-set-badge-label":{"identifier":"allow-set-badge-label","description":"Enables the set_badge_label command without any pre-configured scope.","commands":{"allow":["set_badge_label"],"deny":[]}},"allow-set-closable":{"identifier":"allow-set-closable","description":"Enables the set_closable 
command without any pre-configured scope.","commands":{"allow":["set_closable"],"deny":[]}},"allow-set-content-protected":{"identifier":"allow-set-content-protected","description":"Enables the set_content_protected command without any pre-configured scope.","commands":{"allow":["set_content_protected"],"deny":[]}},"allow-set-cursor-grab":{"identifier":"allow-set-cursor-grab","description":"Enables the set_cursor_grab command without any pre-configured scope.","commands":{"allow":["set_cursor_grab"],"deny":[]}},"allow-set-cursor-icon":{"identifier":"allow-set-cursor-icon","description":"Enables the set_cursor_icon command without any pre-configured scope.","commands":{"allow":["set_cursor_icon"],"deny":[]}},"allow-set-cursor-position":{"identifier":"allow-set-cursor-position","description":"Enables the set_cursor_position command without any pre-configured scope.","commands":{"allow":["set_cursor_position"],"deny":[]}},"allow-set-cursor-visible":{"identifier":"allow-set-cursor-visible","description":"Enables the set_cursor_visible command without any pre-configured scope.","commands":{"allow":["set_cursor_visible"],"deny":[]}},"allow-set-decorations":{"identifier":"allow-set-decorations","description":"Enables the set_decorations command without any pre-configured scope.","commands":{"allow":["set_decorations"],"deny":[]}},"allow-set-effects":{"identifier":"allow-set-effects","description":"Enables the set_effects command without any pre-configured scope.","commands":{"allow":["set_effects"],"deny":[]}},"allow-set-enabled":{"identifier":"allow-set-enabled","description":"Enables the set_enabled command without any pre-configured scope.","commands":{"allow":["set_enabled"],"deny":[]}},"allow-set-focus":{"identifier":"allow-set-focus","description":"Enables the set_focus command without any pre-configured scope.","commands":{"allow":["set_focus"],"deny":[]}},"allow-set-focusable":{"identifier":"allow-set-focusable","description":"Enables the set_focusable command 
without any pre-configured scope.","commands":{"allow":["set_focusable"],"deny":[]}},"allow-set-fullscreen":{"identifier":"allow-set-fullscreen","description":"Enables the set_fullscreen command without any pre-configured scope.","commands":{"allow":["set_fullscreen"],"deny":[]}},"allow-set-icon":{"identifier":"allow-set-icon","description":"Enables the set_icon command without any pre-configured scope.","commands":{"allow":["set_icon"],"deny":[]}},"allow-set-ignore-cursor-events":{"identifier":"allow-set-ignore-cursor-events","description":"Enables the set_ignore_cursor_events command without any pre-configured scope.","commands":{"allow":["set_ignore_cursor_events"],"deny":[]}},"allow-set-max-size":{"identifier":"allow-set-max-size","description":"Enables the set_max_size command without any pre-configured scope.","commands":{"allow":["set_max_size"],"deny":[]}},"allow-set-maximizable":{"identifier":"allow-set-maximizable","description":"Enables the set_maximizable command without any pre-configured scope.","commands":{"allow":["set_maximizable"],"deny":[]}},"allow-set-min-size":{"identifier":"allow-set-min-size","description":"Enables the set_min_size command without any pre-configured scope.","commands":{"allow":["set_min_size"],"deny":[]}},"allow-set-minimizable":{"identifier":"allow-set-minimizable","description":"Enables the set_minimizable command without any pre-configured scope.","commands":{"allow":["set_minimizable"],"deny":[]}},"allow-set-overlay-icon":{"identifier":"allow-set-overlay-icon","description":"Enables the set_overlay_icon command without any pre-configured scope.","commands":{"allow":["set_overlay_icon"],"deny":[]}},"allow-set-position":{"identifier":"allow-set-position","description":"Enables the set_position command without any pre-configured scope.","commands":{"allow":["set_position"],"deny":[]}},"allow-set-progress-bar":{"identifier":"allow-set-progress-bar","description":"Enables the set_progress_bar command without any pre-configured 
scope.","commands":{"allow":["set_progress_bar"],"deny":[]}},"allow-set-resizable":{"identifier":"allow-set-resizable","description":"Enables the set_resizable command without any pre-configured scope.","commands":{"allow":["set_resizable"],"deny":[]}},"allow-set-shadow":{"identifier":"allow-set-shadow","description":"Enables the set_shadow command without any pre-configured scope.","commands":{"allow":["set_shadow"],"deny":[]}},"allow-set-simple-fullscreen":{"identifier":"allow-set-simple-fullscreen","description":"Enables the set_simple_fullscreen command without any pre-configured scope.","commands":{"allow":["set_simple_fullscreen"],"deny":[]}},"allow-set-size":{"identifier":"allow-set-size","description":"Enables the set_size command without any pre-configured scope.","commands":{"allow":["set_size"],"deny":[]}},"allow-set-size-constraints":{"identifier":"allow-set-size-constraints","description":"Enables the set_size_constraints command without any pre-configured scope.","commands":{"allow":["set_size_constraints"],"deny":[]}},"allow-set-skip-taskbar":{"identifier":"allow-set-skip-taskbar","description":"Enables the set_skip_taskbar command without any pre-configured scope.","commands":{"allow":["set_skip_taskbar"],"deny":[]}},"allow-set-theme":{"identifier":"allow-set-theme","description":"Enables the set_theme command without any pre-configured scope.","commands":{"allow":["set_theme"],"deny":[]}},"allow-set-title":{"identifier":"allow-set-title","description":"Enables the set_title command without any pre-configured scope.","commands":{"allow":["set_title"],"deny":[]}},"allow-set-title-bar-style":{"identifier":"allow-set-title-bar-style","description":"Enables the set_title_bar_style command without any pre-configured scope.","commands":{"allow":["set_title_bar_style"],"deny":[]}},"allow-set-visible-on-all-workspaces":{"identifier":"allow-set-visible-on-all-workspaces","description":"Enables the set_visible_on_all_workspaces command without any 
pre-configured scope.","commands":{"allow":["set_visible_on_all_workspaces"],"deny":[]}},"allow-show":{"identifier":"allow-show","description":"Enables the show command without any pre-configured scope.","commands":{"allow":["show"],"deny":[]}},"allow-start-dragging":{"identifier":"allow-start-dragging","description":"Enables the start_dragging command without any pre-configured scope.","commands":{"allow":["start_dragging"],"deny":[]}},"allow-start-resize-dragging":{"identifier":"allow-start-resize-dragging","description":"Enables the start_resize_dragging command without any pre-configured scope.","commands":{"allow":["start_resize_dragging"],"deny":[]}},"allow-theme":{"identifier":"allow-theme","description":"Enables the theme command without any pre-configured scope.","commands":{"allow":["theme"],"deny":[]}},"allow-title":{"identifier":"allow-title","description":"Enables the title command without any pre-configured scope.","commands":{"allow":["title"],"deny":[]}},"allow-toggle-maximize":{"identifier":"allow-toggle-maximize","description":"Enables the toggle_maximize command without any pre-configured scope.","commands":{"allow":["toggle_maximize"],"deny":[]}},"allow-unmaximize":{"identifier":"allow-unmaximize","description":"Enables the unmaximize command without any pre-configured scope.","commands":{"allow":["unmaximize"],"deny":[]}},"allow-unminimize":{"identifier":"allow-unminimize","description":"Enables the unminimize command without any pre-configured scope.","commands":{"allow":["unminimize"],"deny":[]}},"deny-available-monitors":{"identifier":"deny-available-monitors","description":"Denies the available_monitors command without any pre-configured scope.","commands":{"allow":[],"deny":["available_monitors"]}},"deny-center":{"identifier":"deny-center","description":"Denies the center command without any pre-configured scope.","commands":{"allow":[],"deny":["center"]}},"deny-close":{"identifier":"deny-close","description":"Denies the close command 
without any pre-configured scope.","commands":{"allow":[],"deny":["close"]}},"deny-create":{"identifier":"deny-create","description":"Denies the create command without any pre-configured scope.","commands":{"allow":[],"deny":["create"]}},"deny-current-monitor":{"identifier":"deny-current-monitor","description":"Denies the current_monitor command without any pre-configured scope.","commands":{"allow":[],"deny":["current_monitor"]}},"deny-cursor-position":{"identifier":"deny-cursor-position","description":"Denies the cursor_position command without any pre-configured scope.","commands":{"allow":[],"deny":["cursor_position"]}},"deny-destroy":{"identifier":"deny-destroy","description":"Denies the destroy command without any pre-configured scope.","commands":{"allow":[],"deny":["destroy"]}},"deny-get-all-windows":{"identifier":"deny-get-all-windows","description":"Denies the get_all_windows command without any pre-configured scope.","commands":{"allow":[],"deny":["get_all_windows"]}},"deny-hide":{"identifier":"deny-hide","description":"Denies the hide command without any pre-configured scope.","commands":{"allow":[],"deny":["hide"]}},"deny-inner-position":{"identifier":"deny-inner-position","description":"Denies the inner_position command without any pre-configured scope.","commands":{"allow":[],"deny":["inner_position"]}},"deny-inner-size":{"identifier":"deny-inner-size","description":"Denies the inner_size command without any pre-configured scope.","commands":{"allow":[],"deny":["inner_size"]}},"deny-internal-toggle-maximize":{"identifier":"deny-internal-toggle-maximize","description":"Denies the internal_toggle_maximize command without any pre-configured scope.","commands":{"allow":[],"deny":["internal_toggle_maximize"]}},"deny-is-always-on-top":{"identifier":"deny-is-always-on-top","description":"Denies the is_always_on_top command without any pre-configured 
scope.","commands":{"allow":[],"deny":["is_always_on_top"]}},"deny-is-closable":{"identifier":"deny-is-closable","description":"Denies the is_closable command without any pre-configured scope.","commands":{"allow":[],"deny":["is_closable"]}},"deny-is-decorated":{"identifier":"deny-is-decorated","description":"Denies the is_decorated command without any pre-configured scope.","commands":{"allow":[],"deny":["is_decorated"]}},"deny-is-enabled":{"identifier":"deny-is-enabled","description":"Denies the is_enabled command without any pre-configured scope.","commands":{"allow":[],"deny":["is_enabled"]}},"deny-is-focused":{"identifier":"deny-is-focused","description":"Denies the is_focused command without any pre-configured scope.","commands":{"allow":[],"deny":["is_focused"]}},"deny-is-fullscreen":{"identifier":"deny-is-fullscreen","description":"Denies the is_fullscreen command without any pre-configured scope.","commands":{"allow":[],"deny":["is_fullscreen"]}},"deny-is-maximizable":{"identifier":"deny-is-maximizable","description":"Denies the is_maximizable command without any pre-configured scope.","commands":{"allow":[],"deny":["is_maximizable"]}},"deny-is-maximized":{"identifier":"deny-is-maximized","description":"Denies the is_maximized command without any pre-configured scope.","commands":{"allow":[],"deny":["is_maximized"]}},"deny-is-minimizable":{"identifier":"deny-is-minimizable","description":"Denies the is_minimizable command without any pre-configured scope.","commands":{"allow":[],"deny":["is_minimizable"]}},"deny-is-minimized":{"identifier":"deny-is-minimized","description":"Denies the is_minimized command without any pre-configured scope.","commands":{"allow":[],"deny":["is_minimized"]}},"deny-is-resizable":{"identifier":"deny-is-resizable","description":"Denies the is_resizable command without any pre-configured scope.","commands":{"allow":[],"deny":["is_resizable"]}},"deny-is-visible":{"identifier":"deny-is-visible","description":"Denies the is_visible 
command without any pre-configured scope.","commands":{"allow":[],"deny":["is_visible"]}},"deny-maximize":{"identifier":"deny-maximize","description":"Denies the maximize command without any pre-configured scope.","commands":{"allow":[],"deny":["maximize"]}},"deny-minimize":{"identifier":"deny-minimize","description":"Denies the minimize command without any pre-configured scope.","commands":{"allow":[],"deny":["minimize"]}},"deny-monitor-from-point":{"identifier":"deny-monitor-from-point","description":"Denies the monitor_from_point command without any pre-configured scope.","commands":{"allow":[],"deny":["monitor_from_point"]}},"deny-outer-position":{"identifier":"deny-outer-position","description":"Denies the outer_position command without any pre-configured scope.","commands":{"allow":[],"deny":["outer_position"]}},"deny-outer-size":{"identifier":"deny-outer-size","description":"Denies the outer_size command without any pre-configured scope.","commands":{"allow":[],"deny":["outer_size"]}},"deny-primary-monitor":{"identifier":"deny-primary-monitor","description":"Denies the primary_monitor command without any pre-configured scope.","commands":{"allow":[],"deny":["primary_monitor"]}},"deny-request-user-attention":{"identifier":"deny-request-user-attention","description":"Denies the request_user_attention command without any pre-configured scope.","commands":{"allow":[],"deny":["request_user_attention"]}},"deny-scale-factor":{"identifier":"deny-scale-factor","description":"Denies the scale_factor command without any pre-configured scope.","commands":{"allow":[],"deny":["scale_factor"]}},"deny-set-always-on-bottom":{"identifier":"deny-set-always-on-bottom","description":"Denies the set_always_on_bottom command without any pre-configured scope.","commands":{"allow":[],"deny":["set_always_on_bottom"]}},"deny-set-always-on-top":{"identifier":"deny-set-always-on-top","description":"Denies the set_always_on_top command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_always_on_top"]}},"deny-set-background-color":{"identifier":"deny-set-background-color","description":"Denies the set_background_color command without any pre-configured scope.","commands":{"allow":[],"deny":["set_background_color"]}},"deny-set-badge-count":{"identifier":"deny-set-badge-count","description":"Denies the set_badge_count command without any pre-configured scope.","commands":{"allow":[],"deny":["set_badge_count"]}},"deny-set-badge-label":{"identifier":"deny-set-badge-label","description":"Denies the set_badge_label command without any pre-configured scope.","commands":{"allow":[],"deny":["set_badge_label"]}},"deny-set-closable":{"identifier":"deny-set-closable","description":"Denies the set_closable command without any pre-configured scope.","commands":{"allow":[],"deny":["set_closable"]}},"deny-set-content-protected":{"identifier":"deny-set-content-protected","description":"Denies the set_content_protected command without any pre-configured scope.","commands":{"allow":[],"deny":["set_content_protected"]}},"deny-set-cursor-grab":{"identifier":"deny-set-cursor-grab","description":"Denies the set_cursor_grab command without any pre-configured scope.","commands":{"allow":[],"deny":["set_cursor_grab"]}},"deny-set-cursor-icon":{"identifier":"deny-set-cursor-icon","description":"Denies the set_cursor_icon command without any pre-configured scope.","commands":{"allow":[],"deny":["set_cursor_icon"]}},"deny-set-cursor-position":{"identifier":"deny-set-cursor-position","description":"Denies the set_cursor_position command without any pre-configured scope.","commands":{"allow":[],"deny":["set_cursor_position"]}},"deny-set-cursor-visible":{"identifier":"deny-set-cursor-visible","description":"Denies the set_cursor_visible command without any pre-configured scope.","commands":{"allow":[],"deny":["set_cursor_visible"]}},"deny-set-decorations":{"identifier":"deny-set-decorations","description":"Denies the set_decorations 
command without any pre-configured scope.","commands":{"allow":[],"deny":["set_decorations"]}},"deny-set-effects":{"identifier":"deny-set-effects","description":"Denies the set_effects command without any pre-configured scope.","commands":{"allow":[],"deny":["set_effects"]}},"deny-set-enabled":{"identifier":"deny-set-enabled","description":"Denies the set_enabled command without any pre-configured scope.","commands":{"allow":[],"deny":["set_enabled"]}},"deny-set-focus":{"identifier":"deny-set-focus","description":"Denies the set_focus command without any pre-configured scope.","commands":{"allow":[],"deny":["set_focus"]}},"deny-set-focusable":{"identifier":"deny-set-focusable","description":"Denies the set_focusable command without any pre-configured scope.","commands":{"allow":[],"deny":["set_focusable"]}},"deny-set-fullscreen":{"identifier":"deny-set-fullscreen","description":"Denies the set_fullscreen command without any pre-configured scope.","commands":{"allow":[],"deny":["set_fullscreen"]}},"deny-set-icon":{"identifier":"deny-set-icon","description":"Denies the set_icon command without any pre-configured scope.","commands":{"allow":[],"deny":["set_icon"]}},"deny-set-ignore-cursor-events":{"identifier":"deny-set-ignore-cursor-events","description":"Denies the set_ignore_cursor_events command without any pre-configured scope.","commands":{"allow":[],"deny":["set_ignore_cursor_events"]}},"deny-set-max-size":{"identifier":"deny-set-max-size","description":"Denies the set_max_size command without any pre-configured scope.","commands":{"allow":[],"deny":["set_max_size"]}},"deny-set-maximizable":{"identifier":"deny-set-maximizable","description":"Denies the set_maximizable command without any pre-configured scope.","commands":{"allow":[],"deny":["set_maximizable"]}},"deny-set-min-size":{"identifier":"deny-set-min-size","description":"Denies the set_min_size command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_min_size"]}},"deny-set-minimizable":{"identifier":"deny-set-minimizable","description":"Denies the set_minimizable command without any pre-configured scope.","commands":{"allow":[],"deny":["set_minimizable"]}},"deny-set-overlay-icon":{"identifier":"deny-set-overlay-icon","description":"Denies the set_overlay_icon command without any pre-configured scope.","commands":{"allow":[],"deny":["set_overlay_icon"]}},"deny-set-position":{"identifier":"deny-set-position","description":"Denies the set_position command without any pre-configured scope.","commands":{"allow":[],"deny":["set_position"]}},"deny-set-progress-bar":{"identifier":"deny-set-progress-bar","description":"Denies the set_progress_bar command without any pre-configured scope.","commands":{"allow":[],"deny":["set_progress_bar"]}},"deny-set-resizable":{"identifier":"deny-set-resizable","description":"Denies the set_resizable command without any pre-configured scope.","commands":{"allow":[],"deny":["set_resizable"]}},"deny-set-shadow":{"identifier":"deny-set-shadow","description":"Denies the set_shadow command without any pre-configured scope.","commands":{"allow":[],"deny":["set_shadow"]}},"deny-set-simple-fullscreen":{"identifier":"deny-set-simple-fullscreen","description":"Denies the set_simple_fullscreen command without any pre-configured scope.","commands":{"allow":[],"deny":["set_simple_fullscreen"]}},"deny-set-size":{"identifier":"deny-set-size","description":"Denies the set_size command without any pre-configured scope.","commands":{"allow":[],"deny":["set_size"]}},"deny-set-size-constraints":{"identifier":"deny-set-size-constraints","description":"Denies the set_size_constraints command without any pre-configured scope.","commands":{"allow":[],"deny":["set_size_constraints"]}},"deny-set-skip-taskbar":{"identifier":"deny-set-skip-taskbar","description":"Denies the set_skip_taskbar command without any pre-configured 
scope.","commands":{"allow":[],"deny":["set_skip_taskbar"]}},"deny-set-theme":{"identifier":"deny-set-theme","description":"Denies the set_theme command without any pre-configured scope.","commands":{"allow":[],"deny":["set_theme"]}},"deny-set-title":{"identifier":"deny-set-title","description":"Denies the set_title command without any pre-configured scope.","commands":{"allow":[],"deny":["set_title"]}},"deny-set-title-bar-style":{"identifier":"deny-set-title-bar-style","description":"Denies the set_title_bar_style command without any pre-configured scope.","commands":{"allow":[],"deny":["set_title_bar_style"]}},"deny-set-visible-on-all-workspaces":{"identifier":"deny-set-visible-on-all-workspaces","description":"Denies the set_visible_on_all_workspaces command without any pre-configured scope.","commands":{"allow":[],"deny":["set_visible_on_all_workspaces"]}},"deny-show":{"identifier":"deny-show","description":"Denies the show command without any pre-configured scope.","commands":{"allow":[],"deny":["show"]}},"deny-start-dragging":{"identifier":"deny-start-dragging","description":"Denies the start_dragging command without any pre-configured scope.","commands":{"allow":[],"deny":["start_dragging"]}},"deny-start-resize-dragging":{"identifier":"deny-start-resize-dragging","description":"Denies the start_resize_dragging command without any pre-configured scope.","commands":{"allow":[],"deny":["start_resize_dragging"]}},"deny-theme":{"identifier":"deny-theme","description":"Denies the theme command without any pre-configured scope.","commands":{"allow":[],"deny":["theme"]}},"deny-title":{"identifier":"deny-title","description":"Denies the title command without any pre-configured scope.","commands":{"allow":[],"deny":["title"]}},"deny-toggle-maximize":{"identifier":"deny-toggle-maximize","description":"Denies the toggle_maximize command without any pre-configured 
scope.","commands":{"allow":[],"deny":["toggle_maximize"]}},"deny-unmaximize":{"identifier":"deny-unmaximize","description":"Denies the unmaximize command without any pre-configured scope.","commands":{"allow":[],"deny":["unmaximize"]}},"deny-unminimize":{"identifier":"deny-unminimize","description":"Denies the unminimize command without any pre-configured scope.","commands":{"allow":[],"deny":["unminimize"]}}},"permission_sets":{},"global_scope_schema":null},"shell":{"default_permission":{"identifier":"default","description":"This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n","permissions":["allow-open"]},"permissions":{"allow-execute":{"identifier":"allow-execute","description":"Enables the execute command without any pre-configured scope.","commands":{"allow":["execute"],"deny":[]}},"allow-kill":{"identifier":"allow-kill","description":"Enables the kill command without any pre-configured scope.","commands":{"allow":["kill"],"deny":[]}},"allow-open":{"identifier":"allow-open","description":"Enables the open command without any pre-configured scope.","commands":{"allow":["open"],"deny":[]}},"allow-spawn":{"identifier":"allow-spawn","description":"Enables the spawn command without any pre-configured scope.","commands":{"allow":["spawn"],"deny":[]}},"allow-stdin-write":{"identifier":"allow-stdin-write","description":"Enables the stdin_write command without any pre-configured scope.","commands":{"allow":["stdin_write"],"deny":[]}},"deny-execute":{"identifier":"deny-execute","description":"Denies the execute command without any pre-configured scope.","commands":{"allow":[],"deny":["execute"]}},"deny-kill":{"identifier":"deny-kill","description":"Denies the kill command without any pre-configured 
scope.","commands":{"allow":[],"deny":["kill"]}},"deny-open":{"identifier":"deny-open","description":"Denies the open command without any pre-configured scope.","commands":{"allow":[],"deny":["open"]}},"deny-spawn":{"identifier":"deny-spawn","description":"Denies the spawn command without any pre-configured scope.","commands":{"allow":[],"deny":["spawn"]}},"deny-stdin-write":{"identifier":"deny-stdin-write","description":"Denies the stdin_write command without any pre-configured scope.","commands":{"allow":[],"deny":["stdin_write"]}}},"permission_sets":{},"global_scope_schema":{"$schema":"http://json-schema.org/draft-07/schema#","anyOf":[{"additionalProperties":false,"properties":{"args":{"allOf":[{"$ref":"#/definitions/ShellScopeEntryAllowedArgs"}],"description":"The allowed arguments for the command execution."},"cmd":{"description":"The command name. It can start with a variable that resolves to a system base directory. The variables are: `$AUDIO`, `$CACHE`, `$CONFIG`, `$DATA`, `$LOCALDATA`, `$DESKTOP`, `$DOCUMENT`, `$DOWNLOAD`, `$EXE`, `$FONT`, `$HOME`, `$PICTURE`, `$PUBLIC`, `$RUNTIME`, `$TEMPLATE`, `$VIDEO`, `$RESOURCE`, `$LOG`, `$TEMP`, `$APPCONFIG`, `$APPDATA`, `$APPLOCALDATA`, `$APPCACHE`, `$APPLOG`.","type":"string"},"name":{"description":"The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.","type":"string"}},"required":["cmd","name"],"type":"object"},{"additionalProperties":false,"properties":{"args":{"allOf":[{"$ref":"#/definitions/ShellScopeEntryAllowedArgs"}],"description":"The allowed arguments for the command execution."},"name":{"description":"The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.","type":"string"},"sidecar":{"description":"If this command is a sidecar 
command.","type":"boolean"}},"required":["name","sidecar"],"type":"object"}],"definitions":{"ShellScopeEntryAllowedArg":{"anyOf":[{"description":"A non-configurable argument that is passed to the command in the order it was specified.","type":"string"},{"additionalProperties":false,"description":"A variable that is set while calling the command from the webview API.","properties":{"raw":{"default":false,"description":"Marks the validator as a raw regex, meaning the plugin should not make any modification at runtime.\n\nThis means the regex will not match on the entire string by default, which might be exploited if your regex allow unexpected input to be considered valid. When using this option, make sure your regex is correct.","type":"boolean"},"validator":{"description":"[regex] validator to require passed values to conform to an expected input.\n\nThis will require the argument value passed to this variable to match the `validator` regex before it will be executed.\n\nThe regex string is by default surrounded by `^...$` to match the full string. For example the `https?://\\w+` regex would be registered as `^https?://\\w+$`.\n\n[regex]: ","type":"string"}},"required":["validator"],"type":"object"}],"description":"A command argument allowed to be executed by the webview API."},"ShellScopeEntryAllowedArgs":{"anyOf":[{"description":"Use a simple boolean to allow all or disable all arguments to this command configuration.","type":"boolean"},{"description":"A specific set of [`ShellScopeEntryAllowedArg`] that are valid to call for the command configuration.","items":{"$ref":"#/definitions/ShellScopeEntryAllowedArg"},"type":"array"}],"description":"A set of command arguments allowed to be executed by the webview API.\n\nA value of `true` will allow any arguments to be passed to the command. `false` will disable all arguments. 
A list of [`ShellScopeEntryAllowedArg`] will set those arguments as the only valid arguments to be passed to the attached command configuration."}},"description":"Shell scope entry.","title":"ShellScopeEntry"}},"window-state":{"default_permission":{"identifier":"default","description":"This permission set configures what kind of\noperations are available from the window state plugin.\n\n#### Granted Permissions\n\nAll operations are enabled by default.\n\n","permissions":["allow-filename","allow-restore-state","allow-save-window-state"]},"permissions":{"allow-filename":{"identifier":"allow-filename","description":"Enables the filename command without any pre-configured scope.","commands":{"allow":["filename"],"deny":[]}},"allow-restore-state":{"identifier":"allow-restore-state","description":"Enables the restore_state command without any pre-configured scope.","commands":{"allow":["restore_state"],"deny":[]}},"allow-save-window-state":{"identifier":"allow-save-window-state","description":"Enables the save_window_state command without any pre-configured scope.","commands":{"allow":["save_window_state"],"deny":[]}},"deny-filename":{"identifier":"deny-filename","description":"Denies the filename command without any pre-configured scope.","commands":{"allow":[],"deny":["filename"]}},"deny-restore-state":{"identifier":"deny-restore-state","description":"Denies the restore_state command without any pre-configured scope.","commands":{"allow":[],"deny":["restore_state"]}},"deny-save-window-state":{"identifier":"deny-save-window-state","description":"Denies the save_window_state command without any pre-configured scope.","commands":{"allow":[],"deny":["save_window_state"]}}},"permission_sets":{},"global_scope_schema":null}} ================================================ FILE: desktop/src-tauri/gen/schemas/capabilities.json ================================================ {} ================================================ FILE: 
desktop/src-tauri/gen/schemas/desktop-schema.json ================================================ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "CapabilityFile", "description": "Capability formats accepted in a capability file.", "anyOf": [ { "description": "A single capability.", "allOf": [ { "$ref": "#/definitions/Capability" } ] }, { "description": "A list of capabilities.", "type": "array", "items": { "$ref": "#/definitions/Capability" } }, { "description": "A list of capabilities.", "type": "object", "required": [ "capabilities" ], "properties": { "capabilities": { "description": "The list of capabilities.", "type": "array", "items": { "$ref": "#/definitions/Capability" } } } } ], "definitions": { "Capability": { "description": "A grouping and boundary mechanism developers can use to isolate access to the IPC layer.\n\nIt controls application windows' and webviews' fine grained access to the Tauri core, application, or plugin commands. If a webview or its window is not matching any capability then it has no access to the IPC layer at all.\n\nThis can be done to create groups of windows, based on their required system access, which can reduce impact of frontend vulnerabilities in less privileged windows. Windows can be added to a capability by exact name (e.g. `main-window`) or glob patterns like `*` or `admin-*`. 
A Window can have none, one, or multiple associated capabilities.\n\n## Example\n\n```json { \"identifier\": \"main-user-files-write\", \"description\": \"This capability allows the `main` window on macOS and Windows access to `filesystem` write related commands and `dialog` commands to enable programmatic access to files selected by the user.\", \"windows\": [ \"main\" ], \"permissions\": [ \"core:default\", \"dialog:open\", { \"identifier\": \"fs:allow-write-text-file\", \"allow\": [{ \"path\": \"$HOME/test.txt\" }] }, ], \"platforms\": [\"macOS\",\"windows\"] } ```", "type": "object", "required": [ "identifier", "permissions" ], "properties": { "identifier": { "description": "Identifier of the capability.\n\n## Example\n\n`main-user-files-write`", "type": "string" }, "description": { "description": "Description of what the capability is intended to allow on associated windows.\n\nIt should contain a description of what the grouped permissions should allow.\n\n## Example\n\nThis capability allows the `main` window access to `filesystem` write related commands and `dialog` commands to enable programmatic access to files selected by the user.", "default": "", "type": "string" }, "remote": { "description": "Configure remote URLs that can use the capability permissions.\n\nThis setting is optional and defaults to not being set, as our default use case is that the content is served from our local application.\n\n:::caution Make sure you understand the security implications of providing remote sources with local system access. :::\n\n## Example\n\n```json { \"urls\": [\"https://*.mydomain.dev\"] } ```", "anyOf": [ { "$ref": "#/definitions/CapabilityRemote" }, { "type": "null" } ] }, "local": { "description": "Whether this capability is enabled for local app URLs or not. Defaults to `true`.", "default": true, "type": "boolean" }, "windows": { "description": "List of windows that are affected by this capability. 
Can be a glob pattern.\n\nIf a window label matches any of the patterns in this list, the capability will be enabled on all the webviews of that window, regardless of the value of [`Self::webviews`].\n\nOn multiwebview windows, prefer specifying [`Self::webviews`] and omitting [`Self::windows`] for a fine grained access control.\n\n## Example\n\n`[\"main\"]`", "type": "array", "items": { "type": "string" } }, "webviews": { "description": "List of webviews that are affected by this capability. Can be a glob pattern.\n\nThe capability will be enabled on all the webviews whose label matches any of the patterns in this list, regardless of whether the webview's window label matches a pattern in [`Self::windows`].\n\n## Example\n\n`[\"sub-webview-one\", \"sub-webview-two\"]`", "type": "array", "items": { "type": "string" } }, "permissions": { "description": "List of permissions attached to this capability.\n\nMust include the plugin name as prefix in the form of `${plugin-name}:${permission-name}`. 
For commands directly implemented in the application itself only `${permission-name}` is required.\n\n## Example\n\n```json [ \"core:default\", \"shell:allow-open\", \"dialog:open\", { \"identifier\": \"fs:allow-write-text-file\", \"allow\": [{ \"path\": \"$HOME/test.txt\" }] } ] ```", "type": "array", "items": { "$ref": "#/definitions/PermissionEntry" }, "uniqueItems": true }, "platforms": { "description": "Limit which target platforms this capability applies to.\n\nBy default all platforms are targeted.\n\n## Example\n\n`[\"macOS\",\"windows\"]`", "type": [ "array", "null" ], "items": { "$ref": "#/definitions/Target" } } } }, "CapabilityRemote": { "description": "Configuration for remote URLs that are associated with the capability.", "type": "object", "required": [ "urls" ], "properties": { "urls": { "description": "Remote domains this capability refers to using the [URLPattern standard](https://urlpattern.spec.whatwg.org/).\n\n## Examples\n\n- \"https://*.mydomain.dev\": allows subdomains of mydomain.dev - \"https://mydomain.dev/api/*\": allows any subpath of mydomain.dev/api", "type": "array", "items": { "type": "string" } } } }, "PermissionEntry": { "description": "An entry for a permission value in a [`Capability`] can be either a raw permission [`Identifier`] or an object that references a permission and extends its scope.", "anyOf": [ { "description": "Reference a permission or permission set by identifier.", "allOf": [ { "$ref": "#/definitions/Identifier" } ] }, { "description": "Reference a permission or permission set by identifier and extends its scope.", "type": "object", "allOf": [ { "if": { "properties": { "identifier": { "anyOf": [ { "description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. 
It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`", "type": "string", "const": "shell:default", "markdownDescription": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`" }, { "description": "Enables the execute command without any pre-configured scope.", "type": "string", "const": "shell:allow-execute", "markdownDescription": "Enables the execute command without any pre-configured scope." }, { "description": "Enables the kill command without any pre-configured scope.", "type": "string", "const": "shell:allow-kill", "markdownDescription": "Enables the kill command without any pre-configured scope." }, { "description": "Enables the open command without any pre-configured scope.", "type": "string", "const": "shell:allow-open", "markdownDescription": "Enables the open command without any pre-configured scope." }, { "description": "Enables the spawn command without any pre-configured scope.", "type": "string", "const": "shell:allow-spawn", "markdownDescription": "Enables the spawn command without any pre-configured scope." }, { "description": "Enables the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:allow-stdin-write", "markdownDescription": "Enables the stdin_write command without any pre-configured scope." }, { "description": "Denies the execute command without any pre-configured scope.", "type": "string", "const": "shell:deny-execute", "markdownDescription": "Denies the execute command without any pre-configured scope." 
}, { "description": "Denies the kill command without any pre-configured scope.", "type": "string", "const": "shell:deny-kill", "markdownDescription": "Denies the kill command without any pre-configured scope." }, { "description": "Denies the open command without any pre-configured scope.", "type": "string", "const": "shell:deny-open", "markdownDescription": "Denies the open command without any pre-configured scope." }, { "description": "Denies the spawn command without any pre-configured scope.", "type": "string", "const": "shell:deny-spawn", "markdownDescription": "Denies the spawn command without any pre-configured scope." }, { "description": "Denies the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:deny-stdin-write", "markdownDescription": "Denies the stdin_write command without any pre-configured scope." } ] } } }, "then": { "properties": { "allow": { "items": { "title": "ShellScopeEntry", "description": "Shell scope entry.", "anyOf": [ { "type": "object", "required": [ "cmd", "name" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "cmd": { "description": "The command name. It can start with a variable that resolves to a system base directory. 
The variables are: `$AUDIO`, `$CACHE`, `$CONFIG`, `$DATA`, `$LOCALDATA`, `$DESKTOP`, `$DOCUMENT`, `$DOWNLOAD`, `$EXE`, `$FONT`, `$HOME`, `$PICTURE`, `$PUBLIC`, `$RUNTIME`, `$TEMPLATE`, `$VIDEO`, `$RESOURCE`, `$LOG`, `$TEMP`, `$APPCONFIG`, `$APPDATA`, `$APPLOCALDATA`, `$APPCACHE`, `$APPLOG`.", "type": "string" }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" } }, "additionalProperties": false }, { "type": "object", "required": [ "name", "sidecar" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" }, "sidecar": { "description": "If this command is a sidecar command.", "type": "boolean" } }, "additionalProperties": false } ] } }, "deny": { "items": { "title": "ShellScopeEntry", "description": "Shell scope entry.", "anyOf": [ { "type": "object", "required": [ "cmd", "name" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "cmd": { "description": "The command name. It can start with a variable that resolves to a system base directory. 
The variables are: `$AUDIO`, `$CACHE`, `$CONFIG`, `$DATA`, `$LOCALDATA`, `$DESKTOP`, `$DOCUMENT`, `$DOWNLOAD`, `$EXE`, `$FONT`, `$HOME`, `$PICTURE`, `$PUBLIC`, `$RUNTIME`, `$TEMPLATE`, `$VIDEO`, `$RESOURCE`, `$LOG`, `$TEMP`, `$APPCONFIG`, `$APPDATA`, `$APPLOCALDATA`, `$APPCACHE`, `$APPLOG`.", "type": "string" }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" } }, "additionalProperties": false }, { "type": "object", "required": [ "name", "sidecar" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" }, "sidecar": { "description": "If this command is a sidecar command.", "type": "boolean" } }, "additionalProperties": false } ] } } } }, "properties": { "identifier": { "description": "Identifier of the permission or permission set.", "allOf": [ { "$ref": "#/definitions/Identifier" } ] } } }, { "properties": { "identifier": { "description": "Identifier of the permission or permission set.", "allOf": [ { "$ref": "#/definitions/Identifier" } ] }, "allow": { "description": "Data that defines what is allowed by the scope.", "type": [ "array", "null" ], "items": { "$ref": "#/definitions/Value" } }, "deny": { "description": "Data that defines what is denied by the scope. 
This should be prioritized by validation logic.", "type": [ "array", "null" ], "items": { "$ref": "#/definitions/Value" } } } } ], "required": [ "identifier" ] } ] }, "Identifier": { "description": "Permission identifier", "oneOf": [ { "description": "Default core plugins set.\n#### This default permission set includes:\n\n- `core:path:default`\n- `core:event:default`\n- `core:window:default`\n- `core:webview:default`\n- `core:app:default`\n- `core:image:default`\n- `core:resources:default`\n- `core:menu:default`\n- `core:tray:default`", "type": "string", "const": "core:default", "markdownDescription": "Default core plugins set.\n#### This default permission set includes:\n\n- `core:path:default`\n- `core:event:default`\n- `core:window:default`\n- `core:webview:default`\n- `core:app:default`\n- `core:image:default`\n- `core:resources:default`\n- `core:menu:default`\n- `core:tray:default`" }, { "description": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-version`\n- `allow-name`\n- `allow-tauri-version`\n- `allow-identifier`\n- `allow-bundle-type`\n- `allow-register-listener`\n- `allow-remove-listener`", "type": "string", "const": "core:app:default", "markdownDescription": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-version`\n- `allow-name`\n- `allow-tauri-version`\n- `allow-identifier`\n- `allow-bundle-type`\n- `allow-register-listener`\n- `allow-remove-listener`" }, { "description": "Enables the app_hide command without any pre-configured scope.", "type": "string", "const": "core:app:allow-app-hide", "markdownDescription": "Enables the app_hide command without any pre-configured scope." }, { "description": "Enables the app_show command without any pre-configured scope.", "type": "string", "const": "core:app:allow-app-show", "markdownDescription": "Enables the app_show command without any pre-configured scope." 
}, { "description": "Enables the bundle_type command without any pre-configured scope.", "type": "string", "const": "core:app:allow-bundle-type", "markdownDescription": "Enables the bundle_type command without any pre-configured scope." }, { "description": "Enables the default_window_icon command without any pre-configured scope.", "type": "string", "const": "core:app:allow-default-window-icon", "markdownDescription": "Enables the default_window_icon command without any pre-configured scope." }, { "description": "Enables the fetch_data_store_identifiers command without any pre-configured scope.", "type": "string", "const": "core:app:allow-fetch-data-store-identifiers", "markdownDescription": "Enables the fetch_data_store_identifiers command without any pre-configured scope." }, { "description": "Enables the identifier command without any pre-configured scope.", "type": "string", "const": "core:app:allow-identifier", "markdownDescription": "Enables the identifier command without any pre-configured scope." }, { "description": "Enables the name command without any pre-configured scope.", "type": "string", "const": "core:app:allow-name", "markdownDescription": "Enables the name command without any pre-configured scope." }, { "description": "Enables the register_listener command without any pre-configured scope.", "type": "string", "const": "core:app:allow-register-listener", "markdownDescription": "Enables the register_listener command without any pre-configured scope." }, { "description": "Enables the remove_data_store command without any pre-configured scope.", "type": "string", "const": "core:app:allow-remove-data-store", "markdownDescription": "Enables the remove_data_store command without any pre-configured scope." }, { "description": "Enables the remove_listener command without any pre-configured scope.", "type": "string", "const": "core:app:allow-remove-listener", "markdownDescription": "Enables the remove_listener command without any pre-configured scope." 
}, { "description": "Enables the set_app_theme command without any pre-configured scope.", "type": "string", "const": "core:app:allow-set-app-theme", "markdownDescription": "Enables the set_app_theme command without any pre-configured scope." }, { "description": "Enables the set_dock_visibility command without any pre-configured scope.", "type": "string", "const": "core:app:allow-set-dock-visibility", "markdownDescription": "Enables the set_dock_visibility command without any pre-configured scope." }, { "description": "Enables the tauri_version command without any pre-configured scope.", "type": "string", "const": "core:app:allow-tauri-version", "markdownDescription": "Enables the tauri_version command without any pre-configured scope." }, { "description": "Enables the version command without any pre-configured scope.", "type": "string", "const": "core:app:allow-version", "markdownDescription": "Enables the version command without any pre-configured scope." }, { "description": "Denies the app_hide command without any pre-configured scope.", "type": "string", "const": "core:app:deny-app-hide", "markdownDescription": "Denies the app_hide command without any pre-configured scope." }, { "description": "Denies the app_show command without any pre-configured scope.", "type": "string", "const": "core:app:deny-app-show", "markdownDescription": "Denies the app_show command without any pre-configured scope." }, { "description": "Denies the bundle_type command without any pre-configured scope.", "type": "string", "const": "core:app:deny-bundle-type", "markdownDescription": "Denies the bundle_type command without any pre-configured scope." }, { "description": "Denies the default_window_icon command without any pre-configured scope.", "type": "string", "const": "core:app:deny-default-window-icon", "markdownDescription": "Denies the default_window_icon command without any pre-configured scope." 
}, { "description": "Denies the fetch_data_store_identifiers command without any pre-configured scope.", "type": "string", "const": "core:app:deny-fetch-data-store-identifiers", "markdownDescription": "Denies the fetch_data_store_identifiers command without any pre-configured scope." }, { "description": "Denies the identifier command without any pre-configured scope.", "type": "string", "const": "core:app:deny-identifier", "markdownDescription": "Denies the identifier command without any pre-configured scope." }, { "description": "Denies the name command without any pre-configured scope.", "type": "string", "const": "core:app:deny-name", "markdownDescription": "Denies the name command without any pre-configured scope." }, { "description": "Denies the register_listener command without any pre-configured scope.", "type": "string", "const": "core:app:deny-register-listener", "markdownDescription": "Denies the register_listener command without any pre-configured scope." }, { "description": "Denies the remove_data_store command without any pre-configured scope.", "type": "string", "const": "core:app:deny-remove-data-store", "markdownDescription": "Denies the remove_data_store command without any pre-configured scope." }, { "description": "Denies the remove_listener command without any pre-configured scope.", "type": "string", "const": "core:app:deny-remove-listener", "markdownDescription": "Denies the remove_listener command without any pre-configured scope." }, { "description": "Denies the set_app_theme command without any pre-configured scope.", "type": "string", "const": "core:app:deny-set-app-theme", "markdownDescription": "Denies the set_app_theme command without any pre-configured scope." }, { "description": "Denies the set_dock_visibility command without any pre-configured scope.", "type": "string", "const": "core:app:deny-set-dock-visibility", "markdownDescription": "Denies the set_dock_visibility command without any pre-configured scope." 
}, { "description": "Denies the tauri_version command without any pre-configured scope.", "type": "string", "const": "core:app:deny-tauri-version", "markdownDescription": "Denies the tauri_version command without any pre-configured scope." }, { "description": "Denies the version command without any pre-configured scope.", "type": "string", "const": "core:app:deny-version", "markdownDescription": "Denies the version command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-listen`\n- `allow-unlisten`\n- `allow-emit`\n- `allow-emit-to`", "type": "string", "const": "core:event:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-listen`\n- `allow-unlisten`\n- `allow-emit`\n- `allow-emit-to`" }, { "description": "Enables the emit command without any pre-configured scope.", "type": "string", "const": "core:event:allow-emit", "markdownDescription": "Enables the emit command without any pre-configured scope." }, { "description": "Enables the emit_to command without any pre-configured scope.", "type": "string", "const": "core:event:allow-emit-to", "markdownDescription": "Enables the emit_to command without any pre-configured scope." }, { "description": "Enables the listen command without any pre-configured scope.", "type": "string", "const": "core:event:allow-listen", "markdownDescription": "Enables the listen command without any pre-configured scope." }, { "description": "Enables the unlisten command without any pre-configured scope.", "type": "string", "const": "core:event:allow-unlisten", "markdownDescription": "Enables the unlisten command without any pre-configured scope." 
}, { "description": "Denies the emit command without any pre-configured scope.", "type": "string", "const": "core:event:deny-emit", "markdownDescription": "Denies the emit command without any pre-configured scope." }, { "description": "Denies the emit_to command without any pre-configured scope.", "type": "string", "const": "core:event:deny-emit-to", "markdownDescription": "Denies the emit_to command without any pre-configured scope." }, { "description": "Denies the listen command without any pre-configured scope.", "type": "string", "const": "core:event:deny-listen", "markdownDescription": "Denies the listen command without any pre-configured scope." }, { "description": "Denies the unlisten command without any pre-configured scope.", "type": "string", "const": "core:event:deny-unlisten", "markdownDescription": "Denies the unlisten command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-from-bytes`\n- `allow-from-path`\n- `allow-rgba`\n- `allow-size`", "type": "string", "const": "core:image:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-from-bytes`\n- `allow-from-path`\n- `allow-rgba`\n- `allow-size`" }, { "description": "Enables the from_bytes command without any pre-configured scope.", "type": "string", "const": "core:image:allow-from-bytes", "markdownDescription": "Enables the from_bytes command without any pre-configured scope." }, { "description": "Enables the from_path command without any pre-configured scope.", "type": "string", "const": "core:image:allow-from-path", "markdownDescription": "Enables the from_path command without any pre-configured scope." 
}, { "description": "Enables the new command without any pre-configured scope.", "type": "string", "const": "core:image:allow-new", "markdownDescription": "Enables the new command without any pre-configured scope." }, { "description": "Enables the rgba command without any pre-configured scope.", "type": "string", "const": "core:image:allow-rgba", "markdownDescription": "Enables the rgba command without any pre-configured scope." }, { "description": "Enables the size command without any pre-configured scope.", "type": "string", "const": "core:image:allow-size", "markdownDescription": "Enables the size command without any pre-configured scope." }, { "description": "Denies the from_bytes command without any pre-configured scope.", "type": "string", "const": "core:image:deny-from-bytes", "markdownDescription": "Denies the from_bytes command without any pre-configured scope." }, { "description": "Denies the from_path command without any pre-configured scope.", "type": "string", "const": "core:image:deny-from-path", "markdownDescription": "Denies the from_path command without any pre-configured scope." }, { "description": "Denies the new command without any pre-configured scope.", "type": "string", "const": "core:image:deny-new", "markdownDescription": "Denies the new command without any pre-configured scope." }, { "description": "Denies the rgba command without any pre-configured scope.", "type": "string", "const": "core:image:deny-rgba", "markdownDescription": "Denies the rgba command without any pre-configured scope." }, { "description": "Denies the size command without any pre-configured scope.", "type": "string", "const": "core:image:deny-size", "markdownDescription": "Denies the size command without any pre-configured scope." 
}, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-append`\n- `allow-prepend`\n- `allow-insert`\n- `allow-remove`\n- `allow-remove-at`\n- `allow-items`\n- `allow-get`\n- `allow-popup`\n- `allow-create-default`\n- `allow-set-as-app-menu`\n- `allow-set-as-window-menu`\n- `allow-text`\n- `allow-set-text`\n- `allow-is-enabled`\n- `allow-set-enabled`\n- `allow-set-accelerator`\n- `allow-set-as-windows-menu-for-nsapp`\n- `allow-set-as-help-menu-for-nsapp`\n- `allow-is-checked`\n- `allow-set-checked`\n- `allow-set-icon`", "type": "string", "const": "core:menu:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-append`\n- `allow-prepend`\n- `allow-insert`\n- `allow-remove`\n- `allow-remove-at`\n- `allow-items`\n- `allow-get`\n- `allow-popup`\n- `allow-create-default`\n- `allow-set-as-app-menu`\n- `allow-set-as-window-menu`\n- `allow-text`\n- `allow-set-text`\n- `allow-is-enabled`\n- `allow-set-enabled`\n- `allow-set-accelerator`\n- `allow-set-as-windows-menu-for-nsapp`\n- `allow-set-as-help-menu-for-nsapp`\n- `allow-is-checked`\n- `allow-set-checked`\n- `allow-set-icon`" }, { "description": "Enables the append command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-append", "markdownDescription": "Enables the append command without any pre-configured scope." }, { "description": "Enables the create_default command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-create-default", "markdownDescription": "Enables the create_default command without any pre-configured scope." }, { "description": "Enables the get command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-get", "markdownDescription": "Enables the get command without any pre-configured scope." 
}, { "description": "Enables the insert command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-insert", "markdownDescription": "Enables the insert command without any pre-configured scope." }, { "description": "Enables the is_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-is-checked", "markdownDescription": "Enables the is_checked command without any pre-configured scope." }, { "description": "Enables the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-is-enabled", "markdownDescription": "Enables the is_enabled command without any pre-configured scope." }, { "description": "Enables the items command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-items", "markdownDescription": "Enables the items command without any pre-configured scope." }, { "description": "Enables the new command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-new", "markdownDescription": "Enables the new command without any pre-configured scope." }, { "description": "Enables the popup command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-popup", "markdownDescription": "Enables the popup command without any pre-configured scope." }, { "description": "Enables the prepend command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-prepend", "markdownDescription": "Enables the prepend command without any pre-configured scope." }, { "description": "Enables the remove command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-remove", "markdownDescription": "Enables the remove command without any pre-configured scope." 
}, { "description": "Enables the remove_at command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-remove-at", "markdownDescription": "Enables the remove_at command without any pre-configured scope." }, { "description": "Enables the set_accelerator command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-accelerator", "markdownDescription": "Enables the set_accelerator command without any pre-configured scope." }, { "description": "Enables the set_as_app_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-app-menu", "markdownDescription": "Enables the set_as_app_menu command without any pre-configured scope." }, { "description": "Enables the set_as_help_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-help-menu-for-nsapp", "markdownDescription": "Enables the set_as_help_menu_for_nsapp command without any pre-configured scope." }, { "description": "Enables the set_as_window_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-window-menu", "markdownDescription": "Enables the set_as_window_menu command without any pre-configured scope." }, { "description": "Enables the set_as_windows_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-windows-menu-for-nsapp", "markdownDescription": "Enables the set_as_windows_menu_for_nsapp command without any pre-configured scope." }, { "description": "Enables the set_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-checked", "markdownDescription": "Enables the set_checked command without any pre-configured scope." 
}, { "description": "Enables the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-enabled", "markdownDescription": "Enables the set_enabled command without any pre-configured scope." }, { "description": "Enables the set_icon command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-icon", "markdownDescription": "Enables the set_icon command without any pre-configured scope." }, { "description": "Enables the set_text command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-text", "markdownDescription": "Enables the set_text command without any pre-configured scope." }, { "description": "Enables the text command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-text", "markdownDescription": "Enables the text command without any pre-configured scope." }, { "description": "Denies the append command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-append", "markdownDescription": "Denies the append command without any pre-configured scope." }, { "description": "Denies the create_default command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-create-default", "markdownDescription": "Denies the create_default command without any pre-configured scope." }, { "description": "Denies the get command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-get", "markdownDescription": "Denies the get command without any pre-configured scope." }, { "description": "Denies the insert command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-insert", "markdownDescription": "Denies the insert command without any pre-configured scope." 
}, { "description": "Denies the is_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-is-checked", "markdownDescription": "Denies the is_checked command without any pre-configured scope." }, { "description": "Denies the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-is-enabled", "markdownDescription": "Denies the is_enabled command without any pre-configured scope." }, { "description": "Denies the items command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-items", "markdownDescription": "Denies the items command without any pre-configured scope." }, { "description": "Denies the new command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-new", "markdownDescription": "Denies the new command without any pre-configured scope." }, { "description": "Denies the popup command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-popup", "markdownDescription": "Denies the popup command without any pre-configured scope." }, { "description": "Denies the prepend command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-prepend", "markdownDescription": "Denies the prepend command without any pre-configured scope." }, { "description": "Denies the remove command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-remove", "markdownDescription": "Denies the remove command without any pre-configured scope." }, { "description": "Denies the remove_at command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-remove-at", "markdownDescription": "Denies the remove_at command without any pre-configured scope." 
}, { "description": "Denies the set_accelerator command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-accelerator", "markdownDescription": "Denies the set_accelerator command without any pre-configured scope." }, { "description": "Denies the set_as_app_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-app-menu", "markdownDescription": "Denies the set_as_app_menu command without any pre-configured scope." }, { "description": "Denies the set_as_help_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-help-menu-for-nsapp", "markdownDescription": "Denies the set_as_help_menu_for_nsapp command without any pre-configured scope." }, { "description": "Denies the set_as_window_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-window-menu", "markdownDescription": "Denies the set_as_window_menu command without any pre-configured scope." }, { "description": "Denies the set_as_windows_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-windows-menu-for-nsapp", "markdownDescription": "Denies the set_as_windows_menu_for_nsapp command without any pre-configured scope." }, { "description": "Denies the set_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-checked", "markdownDescription": "Denies the set_checked command without any pre-configured scope." }, { "description": "Denies the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-enabled", "markdownDescription": "Denies the set_enabled command without any pre-configured scope." 
}, { "description": "Denies the set_icon command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-icon", "markdownDescription": "Denies the set_icon command without any pre-configured scope." }, { "description": "Denies the set_text command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-text", "markdownDescription": "Denies the set_text command without any pre-configured scope." }, { "description": "Denies the text command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-text", "markdownDescription": "Denies the text command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-resolve-directory`\n- `allow-resolve`\n- `allow-normalize`\n- `allow-join`\n- `allow-dirname`\n- `allow-extname`\n- `allow-basename`\n- `allow-is-absolute`", "type": "string", "const": "core:path:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-resolve-directory`\n- `allow-resolve`\n- `allow-normalize`\n- `allow-join`\n- `allow-dirname`\n- `allow-extname`\n- `allow-basename`\n- `allow-is-absolute`" }, { "description": "Enables the basename command without any pre-configured scope.", "type": "string", "const": "core:path:allow-basename", "markdownDescription": "Enables the basename command without any pre-configured scope." }, { "description": "Enables the dirname command without any pre-configured scope.", "type": "string", "const": "core:path:allow-dirname", "markdownDescription": "Enables the dirname command without any pre-configured scope." }, { "description": "Enables the extname command without any pre-configured scope.", "type": "string", "const": "core:path:allow-extname", "markdownDescription": "Enables the extname command without any pre-configured scope." 
}, { "description": "Enables the is_absolute command without any pre-configured scope.", "type": "string", "const": "core:path:allow-is-absolute", "markdownDescription": "Enables the is_absolute command without any pre-configured scope." }, { "description": "Enables the join command without any pre-configured scope.", "type": "string", "const": "core:path:allow-join", "markdownDescription": "Enables the join command without any pre-configured scope." }, { "description": "Enables the normalize command without any pre-configured scope.", "type": "string", "const": "core:path:allow-normalize", "markdownDescription": "Enables the normalize command without any pre-configured scope." }, { "description": "Enables the resolve command without any pre-configured scope.", "type": "string", "const": "core:path:allow-resolve", "markdownDescription": "Enables the resolve command without any pre-configured scope." }, { "description": "Enables the resolve_directory command without any pre-configured scope.", "type": "string", "const": "core:path:allow-resolve-directory", "markdownDescription": "Enables the resolve_directory command without any pre-configured scope." }, { "description": "Denies the basename command without any pre-configured scope.", "type": "string", "const": "core:path:deny-basename", "markdownDescription": "Denies the basename command without any pre-configured scope." }, { "description": "Denies the dirname command without any pre-configured scope.", "type": "string", "const": "core:path:deny-dirname", "markdownDescription": "Denies the dirname command without any pre-configured scope." }, { "description": "Denies the extname command without any pre-configured scope.", "type": "string", "const": "core:path:deny-extname", "markdownDescription": "Denies the extname command without any pre-configured scope." 
}, { "description": "Denies the is_absolute command without any pre-configured scope.", "type": "string", "const": "core:path:deny-is-absolute", "markdownDescription": "Denies the is_absolute command without any pre-configured scope." }, { "description": "Denies the join command without any pre-configured scope.", "type": "string", "const": "core:path:deny-join", "markdownDescription": "Denies the join command without any pre-configured scope." }, { "description": "Denies the normalize command without any pre-configured scope.", "type": "string", "const": "core:path:deny-normalize", "markdownDescription": "Denies the normalize command without any pre-configured scope." }, { "description": "Denies the resolve command without any pre-configured scope.", "type": "string", "const": "core:path:deny-resolve", "markdownDescription": "Denies the resolve command without any pre-configured scope." }, { "description": "Denies the resolve_directory command without any pre-configured scope.", "type": "string", "const": "core:path:deny-resolve-directory", "markdownDescription": "Denies the resolve_directory command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-close`", "type": "string", "const": "core:resources:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-close`" }, { "description": "Enables the close command without any pre-configured scope.", "type": "string", "const": "core:resources:allow-close", "markdownDescription": "Enables the close command without any pre-configured scope." }, { "description": "Denies the close command without any pre-configured scope.", "type": "string", "const": "core:resources:deny-close", "markdownDescription": "Denies the close command without any pre-configured scope." 
}, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-get-by-id`\n- `allow-remove-by-id`\n- `allow-set-icon`\n- `allow-set-menu`\n- `allow-set-tooltip`\n- `allow-set-title`\n- `allow-set-visible`\n- `allow-set-temp-dir-path`\n- `allow-set-icon-as-template`\n- `allow-set-show-menu-on-left-click`", "type": "string", "const": "core:tray:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-get-by-id`\n- `allow-remove-by-id`\n- `allow-set-icon`\n- `allow-set-menu`\n- `allow-set-tooltip`\n- `allow-set-title`\n- `allow-set-visible`\n- `allow-set-temp-dir-path`\n- `allow-set-icon-as-template`\n- `allow-set-show-menu-on-left-click`" }, { "description": "Enables the get_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-get-by-id", "markdownDescription": "Enables the get_by_id command without any pre-configured scope." }, { "description": "Enables the new command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-new", "markdownDescription": "Enables the new command without any pre-configured scope." }, { "description": "Enables the remove_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-remove-by-id", "markdownDescription": "Enables the remove_by_id command without any pre-configured scope." }, { "description": "Enables the set_icon command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-icon", "markdownDescription": "Enables the set_icon command without any pre-configured scope." 
}, { "description": "Enables the set_icon_as_template command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-icon-as-template", "markdownDescription": "Enables the set_icon_as_template command without any pre-configured scope." }, { "description": "Enables the set_menu command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-menu", "markdownDescription": "Enables the set_menu command without any pre-configured scope." }, { "description": "Enables the set_show_menu_on_left_click command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-show-menu-on-left-click", "markdownDescription": "Enables the set_show_menu_on_left_click command without any pre-configured scope." }, { "description": "Enables the set_temp_dir_path command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-temp-dir-path", "markdownDescription": "Enables the set_temp_dir_path command without any pre-configured scope." }, { "description": "Enables the set_title command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-title", "markdownDescription": "Enables the set_title command without any pre-configured scope." }, { "description": "Enables the set_tooltip command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-tooltip", "markdownDescription": "Enables the set_tooltip command without any pre-configured scope." }, { "description": "Enables the set_visible command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-visible", "markdownDescription": "Enables the set_visible command without any pre-configured scope." }, { "description": "Denies the get_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-get-by-id", "markdownDescription": "Denies the get_by_id command without any pre-configured scope." 
}, { "description": "Denies the new command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-new", "markdownDescription": "Denies the new command without any pre-configured scope." }, { "description": "Denies the remove_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-remove-by-id", "markdownDescription": "Denies the remove_by_id command without any pre-configured scope." }, { "description": "Denies the set_icon command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-icon", "markdownDescription": "Denies the set_icon command without any pre-configured scope." }, { "description": "Denies the set_icon_as_template command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-icon-as-template", "markdownDescription": "Denies the set_icon_as_template command without any pre-configured scope." }, { "description": "Denies the set_menu command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-menu", "markdownDescription": "Denies the set_menu command without any pre-configured scope." }, { "description": "Denies the set_show_menu_on_left_click command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-show-menu-on-left-click", "markdownDescription": "Denies the set_show_menu_on_left_click command without any pre-configured scope." }, { "description": "Denies the set_temp_dir_path command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-temp-dir-path", "markdownDescription": "Denies the set_temp_dir_path command without any pre-configured scope." }, { "description": "Denies the set_title command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-title", "markdownDescription": "Denies the set_title command without any pre-configured scope." 
}, { "description": "Denies the set_tooltip command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-tooltip", "markdownDescription": "Denies the set_tooltip command without any pre-configured scope." }, { "description": "Denies the set_visible command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-visible", "markdownDescription": "Denies the set_visible command without any pre-configured scope." }, { "description": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-webviews`\n- `allow-webview-position`\n- `allow-webview-size`\n- `allow-internal-toggle-devtools`", "type": "string", "const": "core:webview:default", "markdownDescription": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-webviews`\n- `allow-webview-position`\n- `allow-webview-size`\n- `allow-internal-toggle-devtools`" }, { "description": "Enables the clear_all_browsing_data command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-clear-all-browsing-data", "markdownDescription": "Enables the clear_all_browsing_data command without any pre-configured scope." }, { "description": "Enables the create_webview command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-create-webview", "markdownDescription": "Enables the create_webview command without any pre-configured scope." }, { "description": "Enables the create_webview_window command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-create-webview-window", "markdownDescription": "Enables the create_webview_window command without any pre-configured scope." 
}, { "description": "Enables the get_all_webviews command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-get-all-webviews", "markdownDescription": "Enables the get_all_webviews command without any pre-configured scope." }, { "description": "Enables the internal_toggle_devtools command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-internal-toggle-devtools", "markdownDescription": "Enables the internal_toggle_devtools command without any pre-configured scope." }, { "description": "Enables the print command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-print", "markdownDescription": "Enables the print command without any pre-configured scope." }, { "description": "Enables the reparent command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-reparent", "markdownDescription": "Enables the reparent command without any pre-configured scope." }, { "description": "Enables the set_webview_auto_resize command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-auto-resize", "markdownDescription": "Enables the set_webview_auto_resize command without any pre-configured scope." }, { "description": "Enables the set_webview_background_color command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-background-color", "markdownDescription": "Enables the set_webview_background_color command without any pre-configured scope." }, { "description": "Enables the set_webview_focus command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-focus", "markdownDescription": "Enables the set_webview_focus command without any pre-configured scope." 
}, { "description": "Enables the set_webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-position", "markdownDescription": "Enables the set_webview_position command without any pre-configured scope." }, { "description": "Enables the set_webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-size", "markdownDescription": "Enables the set_webview_size command without any pre-configured scope." }, { "description": "Enables the set_webview_zoom command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-zoom", "markdownDescription": "Enables the set_webview_zoom command without any pre-configured scope." }, { "description": "Enables the webview_close command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-close", "markdownDescription": "Enables the webview_close command without any pre-configured scope." }, { "description": "Enables the webview_hide command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-hide", "markdownDescription": "Enables the webview_hide command without any pre-configured scope." }, { "description": "Enables the webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-position", "markdownDescription": "Enables the webview_position command without any pre-configured scope." }, { "description": "Enables the webview_show command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-show", "markdownDescription": "Enables the webview_show command without any pre-configured scope." 
}, { "description": "Enables the webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-size", "markdownDescription": "Enables the webview_size command without any pre-configured scope." }, { "description": "Denies the clear_all_browsing_data command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-clear-all-browsing-data", "markdownDescription": "Denies the clear_all_browsing_data command without any pre-configured scope." }, { "description": "Denies the create_webview command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-create-webview", "markdownDescription": "Denies the create_webview command without any pre-configured scope." }, { "description": "Denies the create_webview_window command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-create-webview-window", "markdownDescription": "Denies the create_webview_window command without any pre-configured scope." }, { "description": "Denies the get_all_webviews command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-get-all-webviews", "markdownDescription": "Denies the get_all_webviews command without any pre-configured scope." }, { "description": "Denies the internal_toggle_devtools command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-internal-toggle-devtools", "markdownDescription": "Denies the internal_toggle_devtools command without any pre-configured scope." }, { "description": "Denies the print command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-print", "markdownDescription": "Denies the print command without any pre-configured scope." 
}, { "description": "Denies the reparent command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-reparent", "markdownDescription": "Denies the reparent command without any pre-configured scope." }, { "description": "Denies the set_webview_auto_resize command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-auto-resize", "markdownDescription": "Denies the set_webview_auto_resize command without any pre-configured scope." }, { "description": "Denies the set_webview_background_color command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-background-color", "markdownDescription": "Denies the set_webview_background_color command without any pre-configured scope." }, { "description": "Denies the set_webview_focus command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-focus", "markdownDescription": "Denies the set_webview_focus command without any pre-configured scope." }, { "description": "Denies the set_webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-position", "markdownDescription": "Denies the set_webview_position command without any pre-configured scope." }, { "description": "Denies the set_webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-size", "markdownDescription": "Denies the set_webview_size command without any pre-configured scope." }, { "description": "Denies the set_webview_zoom command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-zoom", "markdownDescription": "Denies the set_webview_zoom command without any pre-configured scope." 
}, { "description": "Denies the webview_close command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-close", "markdownDescription": "Denies the webview_close command without any pre-configured scope." }, { "description": "Denies the webview_hide command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-hide", "markdownDescription": "Denies the webview_hide command without any pre-configured scope." }, { "description": "Denies the webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-position", "markdownDescription": "Denies the webview_position command without any pre-configured scope." }, { "description": "Denies the webview_show command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-show", "markdownDescription": "Denies the webview_show command without any pre-configured scope." }, { "description": "Denies the webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-size", "markdownDescription": "Denies the webview_size command without any pre-configured scope." 
}, { "description": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-windows`\n- `allow-scale-factor`\n- `allow-inner-position`\n- `allow-outer-position`\n- `allow-inner-size`\n- `allow-outer-size`\n- `allow-is-fullscreen`\n- `allow-is-minimized`\n- `allow-is-maximized`\n- `allow-is-focused`\n- `allow-is-decorated`\n- `allow-is-resizable`\n- `allow-is-maximizable`\n- `allow-is-minimizable`\n- `allow-is-closable`\n- `allow-is-visible`\n- `allow-is-enabled`\n- `allow-title`\n- `allow-current-monitor`\n- `allow-primary-monitor`\n- `allow-monitor-from-point`\n- `allow-available-monitors`\n- `allow-cursor-position`\n- `allow-theme`\n- `allow-is-always-on-top`\n- `allow-internal-toggle-maximize`", "type": "string", "const": "core:window:default", "markdownDescription": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-windows`\n- `allow-scale-factor`\n- `allow-inner-position`\n- `allow-outer-position`\n- `allow-inner-size`\n- `allow-outer-size`\n- `allow-is-fullscreen`\n- `allow-is-minimized`\n- `allow-is-maximized`\n- `allow-is-focused`\n- `allow-is-decorated`\n- `allow-is-resizable`\n- `allow-is-maximizable`\n- `allow-is-minimizable`\n- `allow-is-closable`\n- `allow-is-visible`\n- `allow-is-enabled`\n- `allow-title`\n- `allow-current-monitor`\n- `allow-primary-monitor`\n- `allow-monitor-from-point`\n- `allow-available-monitors`\n- `allow-cursor-position`\n- `allow-theme`\n- `allow-is-always-on-top`\n- `allow-internal-toggle-maximize`" }, { "description": "Enables the available_monitors command without any pre-configured scope.", "type": "string", "const": "core:window:allow-available-monitors", "markdownDescription": "Enables the available_monitors command without any pre-configured scope." 
}, { "description": "Enables the center command without any pre-configured scope.", "type": "string", "const": "core:window:allow-center", "markdownDescription": "Enables the center command without any pre-configured scope." }, { "description": "Enables the close command without any pre-configured scope.", "type": "string", "const": "core:window:allow-close", "markdownDescription": "Enables the close command without any pre-configured scope." }, { "description": "Enables the create command without any pre-configured scope.", "type": "string", "const": "core:window:allow-create", "markdownDescription": "Enables the create command without any pre-configured scope." }, { "description": "Enables the current_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:allow-current-monitor", "markdownDescription": "Enables the current_monitor command without any pre-configured scope." }, { "description": "Enables the cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-cursor-position", "markdownDescription": "Enables the cursor_position command without any pre-configured scope." }, { "description": "Enables the destroy command without any pre-configured scope.", "type": "string", "const": "core:window:allow-destroy", "markdownDescription": "Enables the destroy command without any pre-configured scope." }, { "description": "Enables the get_all_windows command without any pre-configured scope.", "type": "string", "const": "core:window:allow-get-all-windows", "markdownDescription": "Enables the get_all_windows command without any pre-configured scope." }, { "description": "Enables the hide command without any pre-configured scope.", "type": "string", "const": "core:window:allow-hide", "markdownDescription": "Enables the hide command without any pre-configured scope." 
}, { "description": "Enables the inner_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-inner-position", "markdownDescription": "Enables the inner_position command without any pre-configured scope." }, { "description": "Enables the inner_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-inner-size", "markdownDescription": "Enables the inner_size command without any pre-configured scope." }, { "description": "Enables the internal_toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-internal-toggle-maximize", "markdownDescription": "Enables the internal_toggle_maximize command without any pre-configured scope." }, { "description": "Enables the is_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-always-on-top", "markdownDescription": "Enables the is_always_on_top command without any pre-configured scope." }, { "description": "Enables the is_closable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-closable", "markdownDescription": "Enables the is_closable command without any pre-configured scope." }, { "description": "Enables the is_decorated command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-decorated", "markdownDescription": "Enables the is_decorated command without any pre-configured scope." }, { "description": "Enables the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-enabled", "markdownDescription": "Enables the is_enabled command without any pre-configured scope." }, { "description": "Enables the is_focused command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-focused", "markdownDescription": "Enables the is_focused command without any pre-configured scope." 
}, { "description": "Enables the is_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-fullscreen", "markdownDescription": "Enables the is_fullscreen command without any pre-configured scope." }, { "description": "Enables the is_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-maximizable", "markdownDescription": "Enables the is_maximizable command without any pre-configured scope." }, { "description": "Enables the is_maximized command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-maximized", "markdownDescription": "Enables the is_maximized command without any pre-configured scope." }, { "description": "Enables the is_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-minimizable", "markdownDescription": "Enables the is_minimizable command without any pre-configured scope." }, { "description": "Enables the is_minimized command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-minimized", "markdownDescription": "Enables the is_minimized command without any pre-configured scope." }, { "description": "Enables the is_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-resizable", "markdownDescription": "Enables the is_resizable command without any pre-configured scope." }, { "description": "Enables the is_visible command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-visible", "markdownDescription": "Enables the is_visible command without any pre-configured scope." }, { "description": "Enables the maximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-maximize", "markdownDescription": "Enables the maximize command without any pre-configured scope." 
}, { "description": "Enables the minimize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-minimize", "markdownDescription": "Enables the minimize command without any pre-configured scope." }, { "description": "Enables the monitor_from_point command without any pre-configured scope.", "type": "string", "const": "core:window:allow-monitor-from-point", "markdownDescription": "Enables the monitor_from_point command without any pre-configured scope." }, { "description": "Enables the outer_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-outer-position", "markdownDescription": "Enables the outer_position command without any pre-configured scope." }, { "description": "Enables the outer_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-outer-size", "markdownDescription": "Enables the outer_size command without any pre-configured scope." }, { "description": "Enables the primary_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:allow-primary-monitor", "markdownDescription": "Enables the primary_monitor command without any pre-configured scope." }, { "description": "Enables the request_user_attention command without any pre-configured scope.", "type": "string", "const": "core:window:allow-request-user-attention", "markdownDescription": "Enables the request_user_attention command without any pre-configured scope." }, { "description": "Enables the scale_factor command without any pre-configured scope.", "type": "string", "const": "core:window:allow-scale-factor", "markdownDescription": "Enables the scale_factor command without any pre-configured scope." 
}, { "description": "Enables the set_always_on_bottom command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-always-on-bottom", "markdownDescription": "Enables the set_always_on_bottom command without any pre-configured scope." }, { "description": "Enables the set_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-always-on-top", "markdownDescription": "Enables the set_always_on_top command without any pre-configured scope." }, { "description": "Enables the set_background_color command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-background-color", "markdownDescription": "Enables the set_background_color command without any pre-configured scope." }, { "description": "Enables the set_badge_count command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-badge-count", "markdownDescription": "Enables the set_badge_count command without any pre-configured scope." }, { "description": "Enables the set_badge_label command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-badge-label", "markdownDescription": "Enables the set_badge_label command without any pre-configured scope." }, { "description": "Enables the set_closable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-closable", "markdownDescription": "Enables the set_closable command without any pre-configured scope." }, { "description": "Enables the set_content_protected command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-content-protected", "markdownDescription": "Enables the set_content_protected command without any pre-configured scope." 
}, { "description": "Enables the set_cursor_grab command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-grab", "markdownDescription": "Enables the set_cursor_grab command without any pre-configured scope." }, { "description": "Enables the set_cursor_icon command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-icon", "markdownDescription": "Enables the set_cursor_icon command without any pre-configured scope." }, { "description": "Enables the set_cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-position", "markdownDescription": "Enables the set_cursor_position command without any pre-configured scope." }, { "description": "Enables the set_cursor_visible command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-visible", "markdownDescription": "Enables the set_cursor_visible command without any pre-configured scope." }, { "description": "Enables the set_decorations command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-decorations", "markdownDescription": "Enables the set_decorations command without any pre-configured scope." }, { "description": "Enables the set_effects command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-effects", "markdownDescription": "Enables the set_effects command without any pre-configured scope." }, { "description": "Enables the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-enabled", "markdownDescription": "Enables the set_enabled command without any pre-configured scope." }, { "description": "Enables the set_focus command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-focus", "markdownDescription": "Enables the set_focus command without any pre-configured scope." 
}, { "description": "Enables the set_focusable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-focusable", "markdownDescription": "Enables the set_focusable command without any pre-configured scope." }, { "description": "Enables the set_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-fullscreen", "markdownDescription": "Enables the set_fullscreen command without any pre-configured scope." }, { "description": "Enables the set_icon command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-icon", "markdownDescription": "Enables the set_icon command without any pre-configured scope." }, { "description": "Enables the set_ignore_cursor_events command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-ignore-cursor-events", "markdownDescription": "Enables the set_ignore_cursor_events command without any pre-configured scope." }, { "description": "Enables the set_max_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-max-size", "markdownDescription": "Enables the set_max_size command without any pre-configured scope." }, { "description": "Enables the set_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-maximizable", "markdownDescription": "Enables the set_maximizable command without any pre-configured scope." }, { "description": "Enables the set_min_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-min-size", "markdownDescription": "Enables the set_min_size command without any pre-configured scope." }, { "description": "Enables the set_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-minimizable", "markdownDescription": "Enables the set_minimizable command without any pre-configured scope." 
}, { "description": "Enables the set_overlay_icon command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-overlay-icon", "markdownDescription": "Enables the set_overlay_icon command without any pre-configured scope." }, { "description": "Enables the set_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-position", "markdownDescription": "Enables the set_position command without any pre-configured scope." }, { "description": "Enables the set_progress_bar command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-progress-bar", "markdownDescription": "Enables the set_progress_bar command without any pre-configured scope." }, { "description": "Enables the set_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-resizable", "markdownDescription": "Enables the set_resizable command without any pre-configured scope." }, { "description": "Enables the set_shadow command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-shadow", "markdownDescription": "Enables the set_shadow command without any pre-configured scope." }, { "description": "Enables the set_simple_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-simple-fullscreen", "markdownDescription": "Enables the set_simple_fullscreen command without any pre-configured scope." }, { "description": "Enables the set_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-size", "markdownDescription": "Enables the set_size command without any pre-configured scope." 
}, { "description": "Enables the set_size_constraints command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-size-constraints", "markdownDescription": "Enables the set_size_constraints command without any pre-configured scope." }, { "description": "Enables the set_skip_taskbar command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-skip-taskbar", "markdownDescription": "Enables the set_skip_taskbar command without any pre-configured scope." }, { "description": "Enables the set_theme command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-theme", "markdownDescription": "Enables the set_theme command without any pre-configured scope." }, { "description": "Enables the set_title command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-title", "markdownDescription": "Enables the set_title command without any pre-configured scope." }, { "description": "Enables the set_title_bar_style command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-title-bar-style", "markdownDescription": "Enables the set_title_bar_style command without any pre-configured scope." }, { "description": "Enables the set_visible_on_all_workspaces command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-visible-on-all-workspaces", "markdownDescription": "Enables the set_visible_on_all_workspaces command without any pre-configured scope." }, { "description": "Enables the show command without any pre-configured scope.", "type": "string", "const": "core:window:allow-show", "markdownDescription": "Enables the show command without any pre-configured scope." 
}, { "description": "Enables the start_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:allow-start-dragging", "markdownDescription": "Enables the start_dragging command without any pre-configured scope." }, { "description": "Enables the start_resize_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:allow-start-resize-dragging", "markdownDescription": "Enables the start_resize_dragging command without any pre-configured scope." }, { "description": "Enables the theme command without any pre-configured scope.", "type": "string", "const": "core:window:allow-theme", "markdownDescription": "Enables the theme command without any pre-configured scope." }, { "description": "Enables the title command without any pre-configured scope.", "type": "string", "const": "core:window:allow-title", "markdownDescription": "Enables the title command without any pre-configured scope." }, { "description": "Enables the toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-toggle-maximize", "markdownDescription": "Enables the toggle_maximize command without any pre-configured scope." }, { "description": "Enables the unmaximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-unmaximize", "markdownDescription": "Enables the unmaximize command without any pre-configured scope." }, { "description": "Enables the unminimize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-unminimize", "markdownDescription": "Enables the unminimize command without any pre-configured scope." }, { "description": "Denies the available_monitors command without any pre-configured scope.", "type": "string", "const": "core:window:deny-available-monitors", "markdownDescription": "Denies the available_monitors command without any pre-configured scope." 
}, { "description": "Denies the center command without any pre-configured scope.", "type": "string", "const": "core:window:deny-center", "markdownDescription": "Denies the center command without any pre-configured scope." }, { "description": "Denies the close command without any pre-configured scope.", "type": "string", "const": "core:window:deny-close", "markdownDescription": "Denies the close command without any pre-configured scope." }, { "description": "Denies the create command without any pre-configured scope.", "type": "string", "const": "core:window:deny-create", "markdownDescription": "Denies the create command without any pre-configured scope." }, { "description": "Denies the current_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:deny-current-monitor", "markdownDescription": "Denies the current_monitor command without any pre-configured scope." }, { "description": "Denies the cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-cursor-position", "markdownDescription": "Denies the cursor_position command without any pre-configured scope." }, { "description": "Denies the destroy command without any pre-configured scope.", "type": "string", "const": "core:window:deny-destroy", "markdownDescription": "Denies the destroy command without any pre-configured scope." }, { "description": "Denies the get_all_windows command without any pre-configured scope.", "type": "string", "const": "core:window:deny-get-all-windows", "markdownDescription": "Denies the get_all_windows command without any pre-configured scope." }, { "description": "Denies the hide command without any pre-configured scope.", "type": "string", "const": "core:window:deny-hide", "markdownDescription": "Denies the hide command without any pre-configured scope." 
}, { "description": "Denies the inner_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-inner-position", "markdownDescription": "Denies the inner_position command without any pre-configured scope." }, { "description": "Denies the inner_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-inner-size", "markdownDescription": "Denies the inner_size command without any pre-configured scope." }, { "description": "Denies the internal_toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-internal-toggle-maximize", "markdownDescription": "Denies the internal_toggle_maximize command without any pre-configured scope." }, { "description": "Denies the is_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-always-on-top", "markdownDescription": "Denies the is_always_on_top command without any pre-configured scope." }, { "description": "Denies the is_closable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-closable", "markdownDescription": "Denies the is_closable command without any pre-configured scope." }, { "description": "Denies the is_decorated command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-decorated", "markdownDescription": "Denies the is_decorated command without any pre-configured scope." }, { "description": "Denies the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-enabled", "markdownDescription": "Denies the is_enabled command without any pre-configured scope." }, { "description": "Denies the is_focused command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-focused", "markdownDescription": "Denies the is_focused command without any pre-configured scope." 
}, { "description": "Denies the is_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-fullscreen", "markdownDescription": "Denies the is_fullscreen command without any pre-configured scope." }, { "description": "Denies the is_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-maximizable", "markdownDescription": "Denies the is_maximizable command without any pre-configured scope." }, { "description": "Denies the is_maximized command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-maximized", "markdownDescription": "Denies the is_maximized command without any pre-configured scope." }, { "description": "Denies the is_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-minimizable", "markdownDescription": "Denies the is_minimizable command without any pre-configured scope." }, { "description": "Denies the is_minimized command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-minimized", "markdownDescription": "Denies the is_minimized command without any pre-configured scope." }, { "description": "Denies the is_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-resizable", "markdownDescription": "Denies the is_resizable command without any pre-configured scope." }, { "description": "Denies the is_visible command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-visible", "markdownDescription": "Denies the is_visible command without any pre-configured scope." }, { "description": "Denies the maximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-maximize", "markdownDescription": "Denies the maximize command without any pre-configured scope." 
}, { "description": "Denies the minimize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-minimize", "markdownDescription": "Denies the minimize command without any pre-configured scope." }, { "description": "Denies the monitor_from_point command without any pre-configured scope.", "type": "string", "const": "core:window:deny-monitor-from-point", "markdownDescription": "Denies the monitor_from_point command without any pre-configured scope." }, { "description": "Denies the outer_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-outer-position", "markdownDescription": "Denies the outer_position command without any pre-configured scope." }, { "description": "Denies the outer_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-outer-size", "markdownDescription": "Denies the outer_size command without any pre-configured scope." }, { "description": "Denies the primary_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:deny-primary-monitor", "markdownDescription": "Denies the primary_monitor command without any pre-configured scope." }, { "description": "Denies the request_user_attention command without any pre-configured scope.", "type": "string", "const": "core:window:deny-request-user-attention", "markdownDescription": "Denies the request_user_attention command without any pre-configured scope." }, { "description": "Denies the scale_factor command without any pre-configured scope.", "type": "string", "const": "core:window:deny-scale-factor", "markdownDescription": "Denies the scale_factor command without any pre-configured scope." }, { "description": "Denies the set_always_on_bottom command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-always-on-bottom", "markdownDescription": "Denies the set_always_on_bottom command without any pre-configured scope." 
}, { "description": "Denies the set_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-always-on-top", "markdownDescription": "Denies the set_always_on_top command without any pre-configured scope." }, { "description": "Denies the set_background_color command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-background-color", "markdownDescription": "Denies the set_background_color command without any pre-configured scope." }, { "description": "Denies the set_badge_count command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-badge-count", "markdownDescription": "Denies the set_badge_count command without any pre-configured scope." }, { "description": "Denies the set_badge_label command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-badge-label", "markdownDescription": "Denies the set_badge_label command without any pre-configured scope." }, { "description": "Denies the set_closable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-closable", "markdownDescription": "Denies the set_closable command without any pre-configured scope." }, { "description": "Denies the set_content_protected command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-content-protected", "markdownDescription": "Denies the set_content_protected command without any pre-configured scope." }, { "description": "Denies the set_cursor_grab command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-grab", "markdownDescription": "Denies the set_cursor_grab command without any pre-configured scope." 
}, { "description": "Denies the set_cursor_icon command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-icon", "markdownDescription": "Denies the set_cursor_icon command without any pre-configured scope." }, { "description": "Denies the set_cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-position", "markdownDescription": "Denies the set_cursor_position command without any pre-configured scope." }, { "description": "Denies the set_cursor_visible command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-visible", "markdownDescription": "Denies the set_cursor_visible command without any pre-configured scope." }, { "description": "Denies the set_decorations command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-decorations", "markdownDescription": "Denies the set_decorations command without any pre-configured scope." }, { "description": "Denies the set_effects command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-effects", "markdownDescription": "Denies the set_effects command without any pre-configured scope." }, { "description": "Denies the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-enabled", "markdownDescription": "Denies the set_enabled command without any pre-configured scope." }, { "description": "Denies the set_focus command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-focus", "markdownDescription": "Denies the set_focus command without any pre-configured scope." }, { "description": "Denies the set_focusable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-focusable", "markdownDescription": "Denies the set_focusable command without any pre-configured scope." 
}, { "description": "Denies the set_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-fullscreen", "markdownDescription": "Denies the set_fullscreen command without any pre-configured scope." }, { "description": "Denies the set_icon command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-icon", "markdownDescription": "Denies the set_icon command without any pre-configured scope." }, { "description": "Denies the set_ignore_cursor_events command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-ignore-cursor-events", "markdownDescription": "Denies the set_ignore_cursor_events command without any pre-configured scope." }, { "description": "Denies the set_max_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-max-size", "markdownDescription": "Denies the set_max_size command without any pre-configured scope." }, { "description": "Denies the set_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-maximizable", "markdownDescription": "Denies the set_maximizable command without any pre-configured scope." }, { "description": "Denies the set_min_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-min-size", "markdownDescription": "Denies the set_min_size command without any pre-configured scope." }, { "description": "Denies the set_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-minimizable", "markdownDescription": "Denies the set_minimizable command without any pre-configured scope." }, { "description": "Denies the set_overlay_icon command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-overlay-icon", "markdownDescription": "Denies the set_overlay_icon command without any pre-configured scope." 
}, { "description": "Denies the set_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-position", "markdownDescription": "Denies the set_position command without any pre-configured scope." }, { "description": "Denies the set_progress_bar command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-progress-bar", "markdownDescription": "Denies the set_progress_bar command without any pre-configured scope." }, { "description": "Denies the set_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-resizable", "markdownDescription": "Denies the set_resizable command without any pre-configured scope." }, { "description": "Denies the set_shadow command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-shadow", "markdownDescription": "Denies the set_shadow command without any pre-configured scope." }, { "description": "Denies the set_simple_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-simple-fullscreen", "markdownDescription": "Denies the set_simple_fullscreen command without any pre-configured scope." }, { "description": "Denies the set_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-size", "markdownDescription": "Denies the set_size command without any pre-configured scope." }, { "description": "Denies the set_size_constraints command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-size-constraints", "markdownDescription": "Denies the set_size_constraints command without any pre-configured scope." }, { "description": "Denies the set_skip_taskbar command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-skip-taskbar", "markdownDescription": "Denies the set_skip_taskbar command without any pre-configured scope." 
}, { "description": "Denies the set_theme command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-theme", "markdownDescription": "Denies the set_theme command without any pre-configured scope." }, { "description": "Denies the set_title command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-title", "markdownDescription": "Denies the set_title command without any pre-configured scope." }, { "description": "Denies the set_title_bar_style command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-title-bar-style", "markdownDescription": "Denies the set_title_bar_style command without any pre-configured scope." }, { "description": "Denies the set_visible_on_all_workspaces command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-visible-on-all-workspaces", "markdownDescription": "Denies the set_visible_on_all_workspaces command without any pre-configured scope." }, { "description": "Denies the show command without any pre-configured scope.", "type": "string", "const": "core:window:deny-show", "markdownDescription": "Denies the show command without any pre-configured scope." }, { "description": "Denies the start_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:deny-start-dragging", "markdownDescription": "Denies the start_dragging command without any pre-configured scope." }, { "description": "Denies the start_resize_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:deny-start-resize-dragging", "markdownDescription": "Denies the start_resize_dragging command without any pre-configured scope." }, { "description": "Denies the theme command without any pre-configured scope.", "type": "string", "const": "core:window:deny-theme", "markdownDescription": "Denies the theme command without any pre-configured scope." 
}, { "description": "Denies the title command without any pre-configured scope.", "type": "string", "const": "core:window:deny-title", "markdownDescription": "Denies the title command without any pre-configured scope." }, { "description": "Denies the toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-toggle-maximize", "markdownDescription": "Denies the toggle_maximize command without any pre-configured scope." }, { "description": "Denies the unmaximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-unmaximize", "markdownDescription": "Denies the unmaximize command without any pre-configured scope." }, { "description": "Denies the unminimize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-unminimize", "markdownDescription": "Denies the unminimize command without any pre-configured scope." }, { "description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`", "type": "string", "const": "shell:default", "markdownDescription": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`" }, { "description": "Enables the execute command without any pre-configured scope.", "type": "string", "const": "shell:allow-execute", "markdownDescription": "Enables the execute command without any pre-configured scope." 
}, { "description": "Enables the kill command without any pre-configured scope.", "type": "string", "const": "shell:allow-kill", "markdownDescription": "Enables the kill command without any pre-configured scope." }, { "description": "Enables the open command without any pre-configured scope.", "type": "string", "const": "shell:allow-open", "markdownDescription": "Enables the open command without any pre-configured scope." }, { "description": "Enables the spawn command without any pre-configured scope.", "type": "string", "const": "shell:allow-spawn", "markdownDescription": "Enables the spawn command without any pre-configured scope." }, { "description": "Enables the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:allow-stdin-write", "markdownDescription": "Enables the stdin_write command without any pre-configured scope." }, { "description": "Denies the execute command without any pre-configured scope.", "type": "string", "const": "shell:deny-execute", "markdownDescription": "Denies the execute command without any pre-configured scope." }, { "description": "Denies the kill command without any pre-configured scope.", "type": "string", "const": "shell:deny-kill", "markdownDescription": "Denies the kill command without any pre-configured scope." }, { "description": "Denies the open command without any pre-configured scope.", "type": "string", "const": "shell:deny-open", "markdownDescription": "Denies the open command without any pre-configured scope." }, { "description": "Denies the spawn command without any pre-configured scope.", "type": "string", "const": "shell:deny-spawn", "markdownDescription": "Denies the spawn command without any pre-configured scope." }, { "description": "Denies the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:deny-stdin-write", "markdownDescription": "Denies the stdin_write command without any pre-configured scope." 
}, { "description": "This permission set configures what kind of\noperations are available from the window state plugin.\n\n#### Granted Permissions\n\nAll operations are enabled by default.\n\n\n#### This default permission set includes:\n\n- `allow-filename`\n- `allow-restore-state`\n- `allow-save-window-state`", "type": "string", "const": "window-state:default", "markdownDescription": "This permission set configures what kind of\noperations are available from the window state plugin.\n\n#### Granted Permissions\n\nAll operations are enabled by default.\n\n\n#### This default permission set includes:\n\n- `allow-filename`\n- `allow-restore-state`\n- `allow-save-window-state`" }, { "description": "Enables the filename command without any pre-configured scope.", "type": "string", "const": "window-state:allow-filename", "markdownDescription": "Enables the filename command without any pre-configured scope." }, { "description": "Enables the restore_state command without any pre-configured scope.", "type": "string", "const": "window-state:allow-restore-state", "markdownDescription": "Enables the restore_state command without any pre-configured scope." }, { "description": "Enables the save_window_state command without any pre-configured scope.", "type": "string", "const": "window-state:allow-save-window-state", "markdownDescription": "Enables the save_window_state command without any pre-configured scope." }, { "description": "Denies the filename command without any pre-configured scope.", "type": "string", "const": "window-state:deny-filename", "markdownDescription": "Denies the filename command without any pre-configured scope." }, { "description": "Denies the restore_state command without any pre-configured scope.", "type": "string", "const": "window-state:deny-restore-state", "markdownDescription": "Denies the restore_state command without any pre-configured scope." 
}, { "description": "Denies the save_window_state command without any pre-configured scope.", "type": "string", "const": "window-state:deny-save-window-state", "markdownDescription": "Denies the save_window_state command without any pre-configured scope." } ] }, "Value": { "description": "All supported ACL values.", "anyOf": [ { "description": "Represents a null JSON value.", "type": "null" }, { "description": "Represents a [`bool`].", "type": "boolean" }, { "description": "Represents a valid ACL [`Number`].", "allOf": [ { "$ref": "#/definitions/Number" } ] }, { "description": "Represents a [`String`].", "type": "string" }, { "description": "Represents a list of other [`Value`]s.", "type": "array", "items": { "$ref": "#/definitions/Value" } }, { "description": "Represents a map of [`String`] keys to [`Value`]s.", "type": "object", "additionalProperties": { "$ref": "#/definitions/Value" } } ] }, "Number": { "description": "A valid ACL number.", "anyOf": [ { "description": "Represents an [`i64`].", "type": "integer", "format": "int64" }, { "description": "Represents a [`f64`].", "type": "number", "format": "double" } ] }, "Target": { "description": "Platform target.", "oneOf": [ { "description": "MacOS.", "type": "string", "enum": [ "macOS" ] }, { "description": "Windows.", "type": "string", "enum": [ "windows" ] }, { "description": "Linux.", "type": "string", "enum": [ "linux" ] }, { "description": "Android.", "type": "string", "enum": [ "android" ] }, { "description": "iOS.", "type": "string", "enum": [ "iOS" ] } ] }, "ShellScopeEntryAllowedArg": { "description": "A command argument allowed to be executed by the webview API.", "anyOf": [ { "description": "A non-configurable argument that is passed to the command in the order it was specified.", "type": "string" }, { "description": "A variable that is set while calling the command from the webview API.", "type": "object", "required": [ "validator" ], "properties": { "raw": { "description": "Marks the validator as a 
raw regex, meaning the plugin should not make any modification at runtime.\n\nThis means the regex will not match on the entire string by default, which might be exploited if your regex allow unexpected input to be considered valid. When using this option, make sure your regex is correct.", "default": false, "type": "boolean" }, "validator": { "description": "[regex] validator to require passed values to conform to an expected input.\n\nThis will require the argument value passed to this variable to match the `validator` regex before it will be executed.\n\nThe regex string is by default surrounded by `^...$` to match the full string. For example the `https?://\\w+` regex would be registered as `^https?://\\w+$`.\n\n[regex]: ", "type": "string" } }, "additionalProperties": false } ] }, "ShellScopeEntryAllowedArgs": { "description": "A set of command arguments allowed to be executed by the webview API.\n\nA value of `true` will allow any arguments to be passed to the command. `false` will disable all arguments. 
A list of [`ShellScopeEntryAllowedArg`] will set those arguments as the only valid arguments to be passed to the attached command configuration.", "anyOf": [ { "description": "Use a simple boolean to allow all or disable all arguments to this command configuration.", "type": "boolean" }, { "description": "A specific set of [`ShellScopeEntryAllowedArg`] that are valid to call for the command configuration.", "type": "array", "items": { "$ref": "#/definitions/ShellScopeEntryAllowedArg" } } ] } } } ================================================ FILE: desktop/src-tauri/gen/schemas/macOS-schema.json ================================================ { "$schema": "http://json-schema.org/draft-07/schema#", "title": "CapabilityFile", "description": "Capability formats accepted in a capability file.", "anyOf": [ { "description": "A single capability.", "allOf": [ { "$ref": "#/definitions/Capability" } ] }, { "description": "A list of capabilities.", "type": "array", "items": { "$ref": "#/definitions/Capability" } }, { "description": "A list of capabilities.", "type": "object", "required": [ "capabilities" ], "properties": { "capabilities": { "description": "The list of capabilities.", "type": "array", "items": { "$ref": "#/definitions/Capability" } } } } ], "definitions": { "Capability": { "description": "A grouping and boundary mechanism developers can use to isolate access to the IPC layer.\n\nIt controls application windows' and webviews' fine grained access to the Tauri core, application, or plugin commands. If a webview or its window is not matching any capability then it has no access to the IPC layer at all.\n\nThis can be done to create groups of windows, based on their required system access, which can reduce impact of frontend vulnerabilities in less privileged windows. Windows can be added to a capability by exact name (e.g. `main-window`) or glob patterns like `*` or `admin-*`. 
A Window can have none, one, or multiple associated capabilities.\n\n## Example\n\n```json { \"identifier\": \"main-user-files-write\", \"description\": \"This capability allows the `main` window on macOS and Windows access to `filesystem` write related commands and `dialog` commands to enable programmatic access to files selected by the user.\", \"windows\": [ \"main\" ], \"permissions\": [ \"core:default\", \"dialog:open\", { \"identifier\": \"fs:allow-write-text-file\", \"allow\": [{ \"path\": \"$HOME/test.txt\" }] }, ], \"platforms\": [\"macOS\",\"windows\"] } ```", "type": "object", "required": [ "identifier", "permissions" ], "properties": { "identifier": { "description": "Identifier of the capability.\n\n## Example\n\n`main-user-files-write`", "type": "string" }, "description": { "description": "Description of what the capability is intended to allow on associated windows.\n\nIt should contain a description of what the grouped permissions should allow.\n\n## Example\n\nThis capability allows the `main` window access to `filesystem` write related commands and `dialog` commands to enable programmatic access to files selected by the user.", "default": "", "type": "string" }, "remote": { "description": "Configure remote URLs that can use the capability permissions.\n\nThis setting is optional and defaults to not being set, as our default use case is that the content is served from our local application.\n\n:::caution Make sure you understand the security implications of providing remote sources with local system access. :::\n\n## Example\n\n```json { \"urls\": [\"https://*.mydomain.dev\"] } ```", "anyOf": [ { "$ref": "#/definitions/CapabilityRemote" }, { "type": "null" } ] }, "local": { "description": "Whether this capability is enabled for local app URLs or not. Defaults to `true`.", "default": true, "type": "boolean" }, "windows": { "description": "List of windows that are affected by this capability. 
Can be a glob pattern.\n\nIf a window label matches any of the patterns in this list, the capability will be enabled on all the webviews of that window, regardless of the value of [`Self::webviews`].\n\nOn multiwebview windows, prefer specifying [`Self::webviews`] and omitting [`Self::windows`] for a fine grained access control.\n\n## Example\n\n`[\"main\"]`", "type": "array", "items": { "type": "string" } }, "webviews": { "description": "List of webviews that are affected by this capability. Can be a glob pattern.\n\nThe capability will be enabled on all the webviews whose label matches any of the patterns in this list, regardless of whether the webview's window label matches a pattern in [`Self::windows`].\n\n## Example\n\n`[\"sub-webview-one\", \"sub-webview-two\"]`", "type": "array", "items": { "type": "string" } }, "permissions": { "description": "List of permissions attached to this capability.\n\nMust include the plugin name as prefix in the form of `${plugin-name}:${permission-name}`. 
For commands directly implemented in the application itself only `${permission-name}` is required.\n\n## Example\n\n```json [ \"core:default\", \"shell:allow-open\", \"dialog:open\", { \"identifier\": \"fs:allow-write-text-file\", \"allow\": [{ \"path\": \"$HOME/test.txt\" }] } ] ```", "type": "array", "items": { "$ref": "#/definitions/PermissionEntry" }, "uniqueItems": true }, "platforms": { "description": "Limit which target platforms this capability applies to.\n\nBy default all platforms are targeted.\n\n## Example\n\n`[\"macOS\",\"windows\"]`", "type": [ "array", "null" ], "items": { "$ref": "#/definitions/Target" } } } }, "CapabilityRemote": { "description": "Configuration for remote URLs that are associated with the capability.", "type": "object", "required": [ "urls" ], "properties": { "urls": { "description": "Remote domains this capability refers to using the [URLPattern standard](https://urlpattern.spec.whatwg.org/).\n\n## Examples\n\n- \"https://*.mydomain.dev\": allows subdomains of mydomain.dev - \"https://mydomain.dev/api/*\": allows any subpath of mydomain.dev/api", "type": "array", "items": { "type": "string" } } } }, "PermissionEntry": { "description": "An entry for a permission value in a [`Capability`] can be either a raw permission [`Identifier`] or an object that references a permission and extends its scope.", "anyOf": [ { "description": "Reference a permission or permission set by identifier.", "allOf": [ { "$ref": "#/definitions/Identifier" } ] }, { "description": "Reference a permission or permission set by identifier and extends its scope.", "type": "object", "allOf": [ { "if": { "properties": { "identifier": { "anyOf": [ { "description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. 
It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`", "type": "string", "const": "shell:default", "markdownDescription": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`" }, { "description": "Enables the execute command without any pre-configured scope.", "type": "string", "const": "shell:allow-execute", "markdownDescription": "Enables the execute command without any pre-configured scope." }, { "description": "Enables the kill command without any pre-configured scope.", "type": "string", "const": "shell:allow-kill", "markdownDescription": "Enables the kill command without any pre-configured scope." }, { "description": "Enables the open command without any pre-configured scope.", "type": "string", "const": "shell:allow-open", "markdownDescription": "Enables the open command without any pre-configured scope." }, { "description": "Enables the spawn command without any pre-configured scope.", "type": "string", "const": "shell:allow-spawn", "markdownDescription": "Enables the spawn command without any pre-configured scope." }, { "description": "Enables the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:allow-stdin-write", "markdownDescription": "Enables the stdin_write command without any pre-configured scope." }, { "description": "Denies the execute command without any pre-configured scope.", "type": "string", "const": "shell:deny-execute", "markdownDescription": "Denies the execute command without any pre-configured scope." 
}, { "description": "Denies the kill command without any pre-configured scope.", "type": "string", "const": "shell:deny-kill", "markdownDescription": "Denies the kill command without any pre-configured scope." }, { "description": "Denies the open command without any pre-configured scope.", "type": "string", "const": "shell:deny-open", "markdownDescription": "Denies the open command without any pre-configured scope." }, { "description": "Denies the spawn command without any pre-configured scope.", "type": "string", "const": "shell:deny-spawn", "markdownDescription": "Denies the spawn command without any pre-configured scope." }, { "description": "Denies the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:deny-stdin-write", "markdownDescription": "Denies the stdin_write command without any pre-configured scope." } ] } } }, "then": { "properties": { "allow": { "items": { "title": "ShellScopeEntry", "description": "Shell scope entry.", "anyOf": [ { "type": "object", "required": [ "cmd", "name" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "cmd": { "description": "The command name. It can start with a variable that resolves to a system base directory. 
The variables are: `$AUDIO`, `$CACHE`, `$CONFIG`, `$DATA`, `$LOCALDATA`, `$DESKTOP`, `$DOCUMENT`, `$DOWNLOAD`, `$EXE`, `$FONT`, `$HOME`, `$PICTURE`, `$PUBLIC`, `$RUNTIME`, `$TEMPLATE`, `$VIDEO`, `$RESOURCE`, `$LOG`, `$TEMP`, `$APPCONFIG`, `$APPDATA`, `$APPLOCALDATA`, `$APPCACHE`, `$APPLOG`.", "type": "string" }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" } }, "additionalProperties": false }, { "type": "object", "required": [ "name", "sidecar" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" }, "sidecar": { "description": "If this command is a sidecar command.", "type": "boolean" } }, "additionalProperties": false } ] } }, "deny": { "items": { "title": "ShellScopeEntry", "description": "Shell scope entry.", "anyOf": [ { "type": "object", "required": [ "cmd", "name" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "cmd": { "description": "The command name. It can start with a variable that resolves to a system base directory. 
The variables are: `$AUDIO`, `$CACHE`, `$CONFIG`, `$DATA`, `$LOCALDATA`, `$DESKTOP`, `$DOCUMENT`, `$DOWNLOAD`, `$EXE`, `$FONT`, `$HOME`, `$PICTURE`, `$PUBLIC`, `$RUNTIME`, `$TEMPLATE`, `$VIDEO`, `$RESOURCE`, `$LOG`, `$TEMP`, `$APPCONFIG`, `$APPDATA`, `$APPLOCALDATA`, `$APPCACHE`, `$APPLOG`.", "type": "string" }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" } }, "additionalProperties": false }, { "type": "object", "required": [ "name", "sidecar" ], "properties": { "args": { "description": "The allowed arguments for the command execution.", "allOf": [ { "$ref": "#/definitions/ShellScopeEntryAllowedArgs" } ] }, "name": { "description": "The name for this allowed shell command configuration.\n\nThis name will be used inside of the webview API to call this command along with any specified arguments.", "type": "string" }, "sidecar": { "description": "If this command is a sidecar command.", "type": "boolean" } }, "additionalProperties": false } ] } } } }, "properties": { "identifier": { "description": "Identifier of the permission or permission set.", "allOf": [ { "$ref": "#/definitions/Identifier" } ] } } }, { "properties": { "identifier": { "description": "Identifier of the permission or permission set.", "allOf": [ { "$ref": "#/definitions/Identifier" } ] }, "allow": { "description": "Data that defines what is allowed by the scope.", "type": [ "array", "null" ], "items": { "$ref": "#/definitions/Value" } }, "deny": { "description": "Data that defines what is denied by the scope. 
This should be prioritized by validation logic.", "type": [ "array", "null" ], "items": { "$ref": "#/definitions/Value" } } } } ], "required": [ "identifier" ] } ] }, "Identifier": { "description": "Permission identifier", "oneOf": [ { "description": "Default core plugins set.\n#### This default permission set includes:\n\n- `core:path:default`\n- `core:event:default`\n- `core:window:default`\n- `core:webview:default`\n- `core:app:default`\n- `core:image:default`\n- `core:resources:default`\n- `core:menu:default`\n- `core:tray:default`", "type": "string", "const": "core:default", "markdownDescription": "Default core plugins set.\n#### This default permission set includes:\n\n- `core:path:default`\n- `core:event:default`\n- `core:window:default`\n- `core:webview:default`\n- `core:app:default`\n- `core:image:default`\n- `core:resources:default`\n- `core:menu:default`\n- `core:tray:default`" }, { "description": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-version`\n- `allow-name`\n- `allow-tauri-version`\n- `allow-identifier`\n- `allow-bundle-type`\n- `allow-register-listener`\n- `allow-remove-listener`", "type": "string", "const": "core:app:default", "markdownDescription": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-version`\n- `allow-name`\n- `allow-tauri-version`\n- `allow-identifier`\n- `allow-bundle-type`\n- `allow-register-listener`\n- `allow-remove-listener`" }, { "description": "Enables the app_hide command without any pre-configured scope.", "type": "string", "const": "core:app:allow-app-hide", "markdownDescription": "Enables the app_hide command without any pre-configured scope." }, { "description": "Enables the app_show command without any pre-configured scope.", "type": "string", "const": "core:app:allow-app-show", "markdownDescription": "Enables the app_show command without any pre-configured scope." 
}, { "description": "Enables the bundle_type command without any pre-configured scope.", "type": "string", "const": "core:app:allow-bundle-type", "markdownDescription": "Enables the bundle_type command without any pre-configured scope." }, { "description": "Enables the default_window_icon command without any pre-configured scope.", "type": "string", "const": "core:app:allow-default-window-icon", "markdownDescription": "Enables the default_window_icon command without any pre-configured scope." }, { "description": "Enables the fetch_data_store_identifiers command without any pre-configured scope.", "type": "string", "const": "core:app:allow-fetch-data-store-identifiers", "markdownDescription": "Enables the fetch_data_store_identifiers command without any pre-configured scope." }, { "description": "Enables the identifier command without any pre-configured scope.", "type": "string", "const": "core:app:allow-identifier", "markdownDescription": "Enables the identifier command without any pre-configured scope." }, { "description": "Enables the name command without any pre-configured scope.", "type": "string", "const": "core:app:allow-name", "markdownDescription": "Enables the name command without any pre-configured scope." }, { "description": "Enables the register_listener command without any pre-configured scope.", "type": "string", "const": "core:app:allow-register-listener", "markdownDescription": "Enables the register_listener command without any pre-configured scope." }, { "description": "Enables the remove_data_store command without any pre-configured scope.", "type": "string", "const": "core:app:allow-remove-data-store", "markdownDescription": "Enables the remove_data_store command without any pre-configured scope." }, { "description": "Enables the remove_listener command without any pre-configured scope.", "type": "string", "const": "core:app:allow-remove-listener", "markdownDescription": "Enables the remove_listener command without any pre-configured scope." 
}, { "description": "Enables the set_app_theme command without any pre-configured scope.", "type": "string", "const": "core:app:allow-set-app-theme", "markdownDescription": "Enables the set_app_theme command without any pre-configured scope." }, { "description": "Enables the set_dock_visibility command without any pre-configured scope.", "type": "string", "const": "core:app:allow-set-dock-visibility", "markdownDescription": "Enables the set_dock_visibility command without any pre-configured scope." }, { "description": "Enables the tauri_version command without any pre-configured scope.", "type": "string", "const": "core:app:allow-tauri-version", "markdownDescription": "Enables the tauri_version command without any pre-configured scope." }, { "description": "Enables the version command without any pre-configured scope.", "type": "string", "const": "core:app:allow-version", "markdownDescription": "Enables the version command without any pre-configured scope." }, { "description": "Denies the app_hide command without any pre-configured scope.", "type": "string", "const": "core:app:deny-app-hide", "markdownDescription": "Denies the app_hide command without any pre-configured scope." }, { "description": "Denies the app_show command without any pre-configured scope.", "type": "string", "const": "core:app:deny-app-show", "markdownDescription": "Denies the app_show command without any pre-configured scope." }, { "description": "Denies the bundle_type command without any pre-configured scope.", "type": "string", "const": "core:app:deny-bundle-type", "markdownDescription": "Denies the bundle_type command without any pre-configured scope." }, { "description": "Denies the default_window_icon command without any pre-configured scope.", "type": "string", "const": "core:app:deny-default-window-icon", "markdownDescription": "Denies the default_window_icon command without any pre-configured scope." 
}, { "description": "Denies the fetch_data_store_identifiers command without any pre-configured scope.", "type": "string", "const": "core:app:deny-fetch-data-store-identifiers", "markdownDescription": "Denies the fetch_data_store_identifiers command without any pre-configured scope." }, { "description": "Denies the identifier command without any pre-configured scope.", "type": "string", "const": "core:app:deny-identifier", "markdownDescription": "Denies the identifier command without any pre-configured scope." }, { "description": "Denies the name command without any pre-configured scope.", "type": "string", "const": "core:app:deny-name", "markdownDescription": "Denies the name command without any pre-configured scope." }, { "description": "Denies the register_listener command without any pre-configured scope.", "type": "string", "const": "core:app:deny-register-listener", "markdownDescription": "Denies the register_listener command without any pre-configured scope." }, { "description": "Denies the remove_data_store command without any pre-configured scope.", "type": "string", "const": "core:app:deny-remove-data-store", "markdownDescription": "Denies the remove_data_store command without any pre-configured scope." }, { "description": "Denies the remove_listener command without any pre-configured scope.", "type": "string", "const": "core:app:deny-remove-listener", "markdownDescription": "Denies the remove_listener command without any pre-configured scope." }, { "description": "Denies the set_app_theme command without any pre-configured scope.", "type": "string", "const": "core:app:deny-set-app-theme", "markdownDescription": "Denies the set_app_theme command without any pre-configured scope." }, { "description": "Denies the set_dock_visibility command without any pre-configured scope.", "type": "string", "const": "core:app:deny-set-dock-visibility", "markdownDescription": "Denies the set_dock_visibility command without any pre-configured scope." 
}, { "description": "Denies the tauri_version command without any pre-configured scope.", "type": "string", "const": "core:app:deny-tauri-version", "markdownDescription": "Denies the tauri_version command without any pre-configured scope." }, { "description": "Denies the version command without any pre-configured scope.", "type": "string", "const": "core:app:deny-version", "markdownDescription": "Denies the version command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-listen`\n- `allow-unlisten`\n- `allow-emit`\n- `allow-emit-to`", "type": "string", "const": "core:event:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-listen`\n- `allow-unlisten`\n- `allow-emit`\n- `allow-emit-to`" }, { "description": "Enables the emit command without any pre-configured scope.", "type": "string", "const": "core:event:allow-emit", "markdownDescription": "Enables the emit command without any pre-configured scope." }, { "description": "Enables the emit_to command without any pre-configured scope.", "type": "string", "const": "core:event:allow-emit-to", "markdownDescription": "Enables the emit_to command without any pre-configured scope." }, { "description": "Enables the listen command without any pre-configured scope.", "type": "string", "const": "core:event:allow-listen", "markdownDescription": "Enables the listen command without any pre-configured scope." }, { "description": "Enables the unlisten command without any pre-configured scope.", "type": "string", "const": "core:event:allow-unlisten", "markdownDescription": "Enables the unlisten command without any pre-configured scope." 
}, { "description": "Denies the emit command without any pre-configured scope.", "type": "string", "const": "core:event:deny-emit", "markdownDescription": "Denies the emit command without any pre-configured scope." }, { "description": "Denies the emit_to command without any pre-configured scope.", "type": "string", "const": "core:event:deny-emit-to", "markdownDescription": "Denies the emit_to command without any pre-configured scope." }, { "description": "Denies the listen command without any pre-configured scope.", "type": "string", "const": "core:event:deny-listen", "markdownDescription": "Denies the listen command without any pre-configured scope." }, { "description": "Denies the unlisten command without any pre-configured scope.", "type": "string", "const": "core:event:deny-unlisten", "markdownDescription": "Denies the unlisten command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-from-bytes`\n- `allow-from-path`\n- `allow-rgba`\n- `allow-size`", "type": "string", "const": "core:image:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-from-bytes`\n- `allow-from-path`\n- `allow-rgba`\n- `allow-size`" }, { "description": "Enables the from_bytes command without any pre-configured scope.", "type": "string", "const": "core:image:allow-from-bytes", "markdownDescription": "Enables the from_bytes command without any pre-configured scope." }, { "description": "Enables the from_path command without any pre-configured scope.", "type": "string", "const": "core:image:allow-from-path", "markdownDescription": "Enables the from_path command without any pre-configured scope." 
}, { "description": "Enables the new command without any pre-configured scope.", "type": "string", "const": "core:image:allow-new", "markdownDescription": "Enables the new command without any pre-configured scope." }, { "description": "Enables the rgba command without any pre-configured scope.", "type": "string", "const": "core:image:allow-rgba", "markdownDescription": "Enables the rgba command without any pre-configured scope." }, { "description": "Enables the size command without any pre-configured scope.", "type": "string", "const": "core:image:allow-size", "markdownDescription": "Enables the size command without any pre-configured scope." }, { "description": "Denies the from_bytes command without any pre-configured scope.", "type": "string", "const": "core:image:deny-from-bytes", "markdownDescription": "Denies the from_bytes command without any pre-configured scope." }, { "description": "Denies the from_path command without any pre-configured scope.", "type": "string", "const": "core:image:deny-from-path", "markdownDescription": "Denies the from_path command without any pre-configured scope." }, { "description": "Denies the new command without any pre-configured scope.", "type": "string", "const": "core:image:deny-new", "markdownDescription": "Denies the new command without any pre-configured scope." }, { "description": "Denies the rgba command without any pre-configured scope.", "type": "string", "const": "core:image:deny-rgba", "markdownDescription": "Denies the rgba command without any pre-configured scope." }, { "description": "Denies the size command without any pre-configured scope.", "type": "string", "const": "core:image:deny-size", "markdownDescription": "Denies the size command without any pre-configured scope." 
}, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-append`\n- `allow-prepend`\n- `allow-insert`\n- `allow-remove`\n- `allow-remove-at`\n- `allow-items`\n- `allow-get`\n- `allow-popup`\n- `allow-create-default`\n- `allow-set-as-app-menu`\n- `allow-set-as-window-menu`\n- `allow-text`\n- `allow-set-text`\n- `allow-is-enabled`\n- `allow-set-enabled`\n- `allow-set-accelerator`\n- `allow-set-as-windows-menu-for-nsapp`\n- `allow-set-as-help-menu-for-nsapp`\n- `allow-is-checked`\n- `allow-set-checked`\n- `allow-set-icon`", "type": "string", "const": "core:menu:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-append`\n- `allow-prepend`\n- `allow-insert`\n- `allow-remove`\n- `allow-remove-at`\n- `allow-items`\n- `allow-get`\n- `allow-popup`\n- `allow-create-default`\n- `allow-set-as-app-menu`\n- `allow-set-as-window-menu`\n- `allow-text`\n- `allow-set-text`\n- `allow-is-enabled`\n- `allow-set-enabled`\n- `allow-set-accelerator`\n- `allow-set-as-windows-menu-for-nsapp`\n- `allow-set-as-help-menu-for-nsapp`\n- `allow-is-checked`\n- `allow-set-checked`\n- `allow-set-icon`" }, { "description": "Enables the append command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-append", "markdownDescription": "Enables the append command without any pre-configured scope." }, { "description": "Enables the create_default command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-create-default", "markdownDescription": "Enables the create_default command without any pre-configured scope." }, { "description": "Enables the get command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-get", "markdownDescription": "Enables the get command without any pre-configured scope." 
}, { "description": "Enables the insert command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-insert", "markdownDescription": "Enables the insert command without any pre-configured scope." }, { "description": "Enables the is_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-is-checked", "markdownDescription": "Enables the is_checked command without any pre-configured scope." }, { "description": "Enables the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-is-enabled", "markdownDescription": "Enables the is_enabled command without any pre-configured scope." }, { "description": "Enables the items command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-items", "markdownDescription": "Enables the items command without any pre-configured scope." }, { "description": "Enables the new command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-new", "markdownDescription": "Enables the new command without any pre-configured scope." }, { "description": "Enables the popup command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-popup", "markdownDescription": "Enables the popup command without any pre-configured scope." }, { "description": "Enables the prepend command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-prepend", "markdownDescription": "Enables the prepend command without any pre-configured scope." }, { "description": "Enables the remove command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-remove", "markdownDescription": "Enables the remove command without any pre-configured scope." 
}, { "description": "Enables the remove_at command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-remove-at", "markdownDescription": "Enables the remove_at command without any pre-configured scope." }, { "description": "Enables the set_accelerator command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-accelerator", "markdownDescription": "Enables the set_accelerator command without any pre-configured scope." }, { "description": "Enables the set_as_app_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-app-menu", "markdownDescription": "Enables the set_as_app_menu command without any pre-configured scope." }, { "description": "Enables the set_as_help_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-help-menu-for-nsapp", "markdownDescription": "Enables the set_as_help_menu_for_nsapp command without any pre-configured scope." }, { "description": "Enables the set_as_window_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-window-menu", "markdownDescription": "Enables the set_as_window_menu command without any pre-configured scope." }, { "description": "Enables the set_as_windows_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-as-windows-menu-for-nsapp", "markdownDescription": "Enables the set_as_windows_menu_for_nsapp command without any pre-configured scope." }, { "description": "Enables the set_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-checked", "markdownDescription": "Enables the set_checked command without any pre-configured scope." 
}, { "description": "Enables the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-enabled", "markdownDescription": "Enables the set_enabled command without any pre-configured scope." }, { "description": "Enables the set_icon command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-icon", "markdownDescription": "Enables the set_icon command without any pre-configured scope." }, { "description": "Enables the set_text command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-set-text", "markdownDescription": "Enables the set_text command without any pre-configured scope." }, { "description": "Enables the text command without any pre-configured scope.", "type": "string", "const": "core:menu:allow-text", "markdownDescription": "Enables the text command without any pre-configured scope." }, { "description": "Denies the append command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-append", "markdownDescription": "Denies the append command without any pre-configured scope." }, { "description": "Denies the create_default command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-create-default", "markdownDescription": "Denies the create_default command without any pre-configured scope." }, { "description": "Denies the get command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-get", "markdownDescription": "Denies the get command without any pre-configured scope." }, { "description": "Denies the insert command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-insert", "markdownDescription": "Denies the insert command without any pre-configured scope." 
}, { "description": "Denies the is_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-is-checked", "markdownDescription": "Denies the is_checked command without any pre-configured scope." }, { "description": "Denies the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-is-enabled", "markdownDescription": "Denies the is_enabled command without any pre-configured scope." }, { "description": "Denies the items command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-items", "markdownDescription": "Denies the items command without any pre-configured scope." }, { "description": "Denies the new command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-new", "markdownDescription": "Denies the new command without any pre-configured scope." }, { "description": "Denies the popup command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-popup", "markdownDescription": "Denies the popup command without any pre-configured scope." }, { "description": "Denies the prepend command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-prepend", "markdownDescription": "Denies the prepend command without any pre-configured scope." }, { "description": "Denies the remove command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-remove", "markdownDescription": "Denies the remove command without any pre-configured scope." }, { "description": "Denies the remove_at command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-remove-at", "markdownDescription": "Denies the remove_at command without any pre-configured scope." 
}, { "description": "Denies the set_accelerator command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-accelerator", "markdownDescription": "Denies the set_accelerator command without any pre-configured scope." }, { "description": "Denies the set_as_app_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-app-menu", "markdownDescription": "Denies the set_as_app_menu command without any pre-configured scope." }, { "description": "Denies the set_as_help_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-help-menu-for-nsapp", "markdownDescription": "Denies the set_as_help_menu_for_nsapp command without any pre-configured scope." }, { "description": "Denies the set_as_window_menu command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-window-menu", "markdownDescription": "Denies the set_as_window_menu command without any pre-configured scope." }, { "description": "Denies the set_as_windows_menu_for_nsapp command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-as-windows-menu-for-nsapp", "markdownDescription": "Denies the set_as_windows_menu_for_nsapp command without any pre-configured scope." }, { "description": "Denies the set_checked command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-checked", "markdownDescription": "Denies the set_checked command without any pre-configured scope." }, { "description": "Denies the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-enabled", "markdownDescription": "Denies the set_enabled command without any pre-configured scope." 
}, { "description": "Denies the set_icon command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-icon", "markdownDescription": "Denies the set_icon command without any pre-configured scope." }, { "description": "Denies the set_text command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-set-text", "markdownDescription": "Denies the set_text command without any pre-configured scope." }, { "description": "Denies the text command without any pre-configured scope.", "type": "string", "const": "core:menu:deny-text", "markdownDescription": "Denies the text command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-resolve-directory`\n- `allow-resolve`\n- `allow-normalize`\n- `allow-join`\n- `allow-dirname`\n- `allow-extname`\n- `allow-basename`\n- `allow-is-absolute`", "type": "string", "const": "core:path:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-resolve-directory`\n- `allow-resolve`\n- `allow-normalize`\n- `allow-join`\n- `allow-dirname`\n- `allow-extname`\n- `allow-basename`\n- `allow-is-absolute`" }, { "description": "Enables the basename command without any pre-configured scope.", "type": "string", "const": "core:path:allow-basename", "markdownDescription": "Enables the basename command without any pre-configured scope." }, { "description": "Enables the dirname command without any pre-configured scope.", "type": "string", "const": "core:path:allow-dirname", "markdownDescription": "Enables the dirname command without any pre-configured scope." }, { "description": "Enables the extname command without any pre-configured scope.", "type": "string", "const": "core:path:allow-extname", "markdownDescription": "Enables the extname command without any pre-configured scope." 
}, { "description": "Enables the is_absolute command without any pre-configured scope.", "type": "string", "const": "core:path:allow-is-absolute", "markdownDescription": "Enables the is_absolute command without any pre-configured scope." }, { "description": "Enables the join command without any pre-configured scope.", "type": "string", "const": "core:path:allow-join", "markdownDescription": "Enables the join command without any pre-configured scope." }, { "description": "Enables the normalize command without any pre-configured scope.", "type": "string", "const": "core:path:allow-normalize", "markdownDescription": "Enables the normalize command without any pre-configured scope." }, { "description": "Enables the resolve command without any pre-configured scope.", "type": "string", "const": "core:path:allow-resolve", "markdownDescription": "Enables the resolve command without any pre-configured scope." }, { "description": "Enables the resolve_directory command without any pre-configured scope.", "type": "string", "const": "core:path:allow-resolve-directory", "markdownDescription": "Enables the resolve_directory command without any pre-configured scope." }, { "description": "Denies the basename command without any pre-configured scope.", "type": "string", "const": "core:path:deny-basename", "markdownDescription": "Denies the basename command without any pre-configured scope." }, { "description": "Denies the dirname command without any pre-configured scope.", "type": "string", "const": "core:path:deny-dirname", "markdownDescription": "Denies the dirname command without any pre-configured scope." }, { "description": "Denies the extname command without any pre-configured scope.", "type": "string", "const": "core:path:deny-extname", "markdownDescription": "Denies the extname command without any pre-configured scope." 
}, { "description": "Denies the is_absolute command without any pre-configured scope.", "type": "string", "const": "core:path:deny-is-absolute", "markdownDescription": "Denies the is_absolute command without any pre-configured scope." }, { "description": "Denies the join command without any pre-configured scope.", "type": "string", "const": "core:path:deny-join", "markdownDescription": "Denies the join command without any pre-configured scope." }, { "description": "Denies the normalize command without any pre-configured scope.", "type": "string", "const": "core:path:deny-normalize", "markdownDescription": "Denies the normalize command without any pre-configured scope." }, { "description": "Denies the resolve command without any pre-configured scope.", "type": "string", "const": "core:path:deny-resolve", "markdownDescription": "Denies the resolve command without any pre-configured scope." }, { "description": "Denies the resolve_directory command without any pre-configured scope.", "type": "string", "const": "core:path:deny-resolve-directory", "markdownDescription": "Denies the resolve_directory command without any pre-configured scope." }, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-close`", "type": "string", "const": "core:resources:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-close`" }, { "description": "Enables the close command without any pre-configured scope.", "type": "string", "const": "core:resources:allow-close", "markdownDescription": "Enables the close command without any pre-configured scope." }, { "description": "Denies the close command without any pre-configured scope.", "type": "string", "const": "core:resources:deny-close", "markdownDescription": "Denies the close command without any pre-configured scope." 
}, { "description": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-get-by-id`\n- `allow-remove-by-id`\n- `allow-set-icon`\n- `allow-set-menu`\n- `allow-set-tooltip`\n- `allow-set-title`\n- `allow-set-visible`\n- `allow-set-temp-dir-path`\n- `allow-set-icon-as-template`\n- `allow-set-show-menu-on-left-click`", "type": "string", "const": "core:tray:default", "markdownDescription": "Default permissions for the plugin, which enables all commands.\n#### This default permission set includes:\n\n- `allow-new`\n- `allow-get-by-id`\n- `allow-remove-by-id`\n- `allow-set-icon`\n- `allow-set-menu`\n- `allow-set-tooltip`\n- `allow-set-title`\n- `allow-set-visible`\n- `allow-set-temp-dir-path`\n- `allow-set-icon-as-template`\n- `allow-set-show-menu-on-left-click`" }, { "description": "Enables the get_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-get-by-id", "markdownDescription": "Enables the get_by_id command without any pre-configured scope." }, { "description": "Enables the new command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-new", "markdownDescription": "Enables the new command without any pre-configured scope." }, { "description": "Enables the remove_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-remove-by-id", "markdownDescription": "Enables the remove_by_id command without any pre-configured scope." }, { "description": "Enables the set_icon command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-icon", "markdownDescription": "Enables the set_icon command without any pre-configured scope." 
}, { "description": "Enables the set_icon_as_template command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-icon-as-template", "markdownDescription": "Enables the set_icon_as_template command without any pre-configured scope." }, { "description": "Enables the set_menu command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-menu", "markdownDescription": "Enables the set_menu command without any pre-configured scope." }, { "description": "Enables the set_show_menu_on_left_click command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-show-menu-on-left-click", "markdownDescription": "Enables the set_show_menu_on_left_click command without any pre-configured scope." }, { "description": "Enables the set_temp_dir_path command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-temp-dir-path", "markdownDescription": "Enables the set_temp_dir_path command without any pre-configured scope." }, { "description": "Enables the set_title command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-title", "markdownDescription": "Enables the set_title command without any pre-configured scope." }, { "description": "Enables the set_tooltip command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-tooltip", "markdownDescription": "Enables the set_tooltip command without any pre-configured scope." }, { "description": "Enables the set_visible command without any pre-configured scope.", "type": "string", "const": "core:tray:allow-set-visible", "markdownDescription": "Enables the set_visible command without any pre-configured scope." }, { "description": "Denies the get_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-get-by-id", "markdownDescription": "Denies the get_by_id command without any pre-configured scope." 
}, { "description": "Denies the new command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-new", "markdownDescription": "Denies the new command without any pre-configured scope." }, { "description": "Denies the remove_by_id command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-remove-by-id", "markdownDescription": "Denies the remove_by_id command without any pre-configured scope." }, { "description": "Denies the set_icon command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-icon", "markdownDescription": "Denies the set_icon command without any pre-configured scope." }, { "description": "Denies the set_icon_as_template command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-icon-as-template", "markdownDescription": "Denies the set_icon_as_template command without any pre-configured scope." }, { "description": "Denies the set_menu command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-menu", "markdownDescription": "Denies the set_menu command without any pre-configured scope." }, { "description": "Denies the set_show_menu_on_left_click command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-show-menu-on-left-click", "markdownDescription": "Denies the set_show_menu_on_left_click command without any pre-configured scope." }, { "description": "Denies the set_temp_dir_path command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-temp-dir-path", "markdownDescription": "Denies the set_temp_dir_path command without any pre-configured scope." }, { "description": "Denies the set_title command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-title", "markdownDescription": "Denies the set_title command without any pre-configured scope." 
}, { "description": "Denies the set_tooltip command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-tooltip", "markdownDescription": "Denies the set_tooltip command without any pre-configured scope." }, { "description": "Denies the set_visible command without any pre-configured scope.", "type": "string", "const": "core:tray:deny-set-visible", "markdownDescription": "Denies the set_visible command without any pre-configured scope." }, { "description": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-webviews`\n- `allow-webview-position`\n- `allow-webview-size`\n- `allow-internal-toggle-devtools`", "type": "string", "const": "core:webview:default", "markdownDescription": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-webviews`\n- `allow-webview-position`\n- `allow-webview-size`\n- `allow-internal-toggle-devtools`" }, { "description": "Enables the clear_all_browsing_data command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-clear-all-browsing-data", "markdownDescription": "Enables the clear_all_browsing_data command without any pre-configured scope." }, { "description": "Enables the create_webview command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-create-webview", "markdownDescription": "Enables the create_webview command without any pre-configured scope." }, { "description": "Enables the create_webview_window command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-create-webview-window", "markdownDescription": "Enables the create_webview_window command without any pre-configured scope." 
}, { "description": "Enables the get_all_webviews command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-get-all-webviews", "markdownDescription": "Enables the get_all_webviews command without any pre-configured scope." }, { "description": "Enables the internal_toggle_devtools command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-internal-toggle-devtools", "markdownDescription": "Enables the internal_toggle_devtools command without any pre-configured scope." }, { "description": "Enables the print command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-print", "markdownDescription": "Enables the print command without any pre-configured scope." }, { "description": "Enables the reparent command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-reparent", "markdownDescription": "Enables the reparent command without any pre-configured scope." }, { "description": "Enables the set_webview_auto_resize command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-auto-resize", "markdownDescription": "Enables the set_webview_auto_resize command without any pre-configured scope." }, { "description": "Enables the set_webview_background_color command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-background-color", "markdownDescription": "Enables the set_webview_background_color command without any pre-configured scope." }, { "description": "Enables the set_webview_focus command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-focus", "markdownDescription": "Enables the set_webview_focus command without any pre-configured scope." 
}, { "description": "Enables the set_webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-position", "markdownDescription": "Enables the set_webview_position command without any pre-configured scope." }, { "description": "Enables the set_webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-size", "markdownDescription": "Enables the set_webview_size command without any pre-configured scope." }, { "description": "Enables the set_webview_zoom command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-set-webview-zoom", "markdownDescription": "Enables the set_webview_zoom command without any pre-configured scope." }, { "description": "Enables the webview_close command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-close", "markdownDescription": "Enables the webview_close command without any pre-configured scope." }, { "description": "Enables the webview_hide command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-hide", "markdownDescription": "Enables the webview_hide command without any pre-configured scope." }, { "description": "Enables the webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-position", "markdownDescription": "Enables the webview_position command without any pre-configured scope." }, { "description": "Enables the webview_show command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-show", "markdownDescription": "Enables the webview_show command without any pre-configured scope." 
}, { "description": "Enables the webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:allow-webview-size", "markdownDescription": "Enables the webview_size command without any pre-configured scope." }, { "description": "Denies the clear_all_browsing_data command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-clear-all-browsing-data", "markdownDescription": "Denies the clear_all_browsing_data command without any pre-configured scope." }, { "description": "Denies the create_webview command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-create-webview", "markdownDescription": "Denies the create_webview command without any pre-configured scope." }, { "description": "Denies the create_webview_window command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-create-webview-window", "markdownDescription": "Denies the create_webview_window command without any pre-configured scope." }, { "description": "Denies the get_all_webviews command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-get-all-webviews", "markdownDescription": "Denies the get_all_webviews command without any pre-configured scope." }, { "description": "Denies the internal_toggle_devtools command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-internal-toggle-devtools", "markdownDescription": "Denies the internal_toggle_devtools command without any pre-configured scope." }, { "description": "Denies the print command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-print", "markdownDescription": "Denies the print command without any pre-configured scope." 
}, { "description": "Denies the reparent command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-reparent", "markdownDescription": "Denies the reparent command without any pre-configured scope." }, { "description": "Denies the set_webview_auto_resize command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-auto-resize", "markdownDescription": "Denies the set_webview_auto_resize command without any pre-configured scope." }, { "description": "Denies the set_webview_background_color command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-background-color", "markdownDescription": "Denies the set_webview_background_color command without any pre-configured scope." }, { "description": "Denies the set_webview_focus command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-focus", "markdownDescription": "Denies the set_webview_focus command without any pre-configured scope." }, { "description": "Denies the set_webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-position", "markdownDescription": "Denies the set_webview_position command without any pre-configured scope." }, { "description": "Denies the set_webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-size", "markdownDescription": "Denies the set_webview_size command without any pre-configured scope." }, { "description": "Denies the set_webview_zoom command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-set-webview-zoom", "markdownDescription": "Denies the set_webview_zoom command without any pre-configured scope." 
}, { "description": "Denies the webview_close command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-close", "markdownDescription": "Denies the webview_close command without any pre-configured scope." }, { "description": "Denies the webview_hide command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-hide", "markdownDescription": "Denies the webview_hide command without any pre-configured scope." }, { "description": "Denies the webview_position command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-position", "markdownDescription": "Denies the webview_position command without any pre-configured scope." }, { "description": "Denies the webview_show command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-show", "markdownDescription": "Denies the webview_show command without any pre-configured scope." }, { "description": "Denies the webview_size command without any pre-configured scope.", "type": "string", "const": "core:webview:deny-webview-size", "markdownDescription": "Denies the webview_size command without any pre-configured scope." 
}, { "description": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-windows`\n- `allow-scale-factor`\n- `allow-inner-position`\n- `allow-outer-position`\n- `allow-inner-size`\n- `allow-outer-size`\n- `allow-is-fullscreen`\n- `allow-is-minimized`\n- `allow-is-maximized`\n- `allow-is-focused`\n- `allow-is-decorated`\n- `allow-is-resizable`\n- `allow-is-maximizable`\n- `allow-is-minimizable`\n- `allow-is-closable`\n- `allow-is-visible`\n- `allow-is-enabled`\n- `allow-title`\n- `allow-current-monitor`\n- `allow-primary-monitor`\n- `allow-monitor-from-point`\n- `allow-available-monitors`\n- `allow-cursor-position`\n- `allow-theme`\n- `allow-is-always-on-top`\n- `allow-internal-toggle-maximize`", "type": "string", "const": "core:window:default", "markdownDescription": "Default permissions for the plugin.\n#### This default permission set includes:\n\n- `allow-get-all-windows`\n- `allow-scale-factor`\n- `allow-inner-position`\n- `allow-outer-position`\n- `allow-inner-size`\n- `allow-outer-size`\n- `allow-is-fullscreen`\n- `allow-is-minimized`\n- `allow-is-maximized`\n- `allow-is-focused`\n- `allow-is-decorated`\n- `allow-is-resizable`\n- `allow-is-maximizable`\n- `allow-is-minimizable`\n- `allow-is-closable`\n- `allow-is-visible`\n- `allow-is-enabled`\n- `allow-title`\n- `allow-current-monitor`\n- `allow-primary-monitor`\n- `allow-monitor-from-point`\n- `allow-available-monitors`\n- `allow-cursor-position`\n- `allow-theme`\n- `allow-is-always-on-top`\n- `allow-internal-toggle-maximize`" }, { "description": "Enables the available_monitors command without any pre-configured scope.", "type": "string", "const": "core:window:allow-available-monitors", "markdownDescription": "Enables the available_monitors command without any pre-configured scope." 
}, { "description": "Enables the center command without any pre-configured scope.", "type": "string", "const": "core:window:allow-center", "markdownDescription": "Enables the center command without any pre-configured scope." }, { "description": "Enables the close command without any pre-configured scope.", "type": "string", "const": "core:window:allow-close", "markdownDescription": "Enables the close command without any pre-configured scope." }, { "description": "Enables the create command without any pre-configured scope.", "type": "string", "const": "core:window:allow-create", "markdownDescription": "Enables the create command without any pre-configured scope." }, { "description": "Enables the current_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:allow-current-monitor", "markdownDescription": "Enables the current_monitor command without any pre-configured scope." }, { "description": "Enables the cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-cursor-position", "markdownDescription": "Enables the cursor_position command without any pre-configured scope." }, { "description": "Enables the destroy command without any pre-configured scope.", "type": "string", "const": "core:window:allow-destroy", "markdownDescription": "Enables the destroy command without any pre-configured scope." }, { "description": "Enables the get_all_windows command without any pre-configured scope.", "type": "string", "const": "core:window:allow-get-all-windows", "markdownDescription": "Enables the get_all_windows command without any pre-configured scope." }, { "description": "Enables the hide command without any pre-configured scope.", "type": "string", "const": "core:window:allow-hide", "markdownDescription": "Enables the hide command without any pre-configured scope." 
}, { "description": "Enables the inner_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-inner-position", "markdownDescription": "Enables the inner_position command without any pre-configured scope." }, { "description": "Enables the inner_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-inner-size", "markdownDescription": "Enables the inner_size command without any pre-configured scope." }, { "description": "Enables the internal_toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-internal-toggle-maximize", "markdownDescription": "Enables the internal_toggle_maximize command without any pre-configured scope." }, { "description": "Enables the is_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-always-on-top", "markdownDescription": "Enables the is_always_on_top command without any pre-configured scope." }, { "description": "Enables the is_closable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-closable", "markdownDescription": "Enables the is_closable command without any pre-configured scope." }, { "description": "Enables the is_decorated command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-decorated", "markdownDescription": "Enables the is_decorated command without any pre-configured scope." }, { "description": "Enables the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-enabled", "markdownDescription": "Enables the is_enabled command without any pre-configured scope." }, { "description": "Enables the is_focused command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-focused", "markdownDescription": "Enables the is_focused command without any pre-configured scope." 
}, { "description": "Enables the is_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-fullscreen", "markdownDescription": "Enables the is_fullscreen command without any pre-configured scope." }, { "description": "Enables the is_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-maximizable", "markdownDescription": "Enables the is_maximizable command without any pre-configured scope." }, { "description": "Enables the is_maximized command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-maximized", "markdownDescription": "Enables the is_maximized command without any pre-configured scope." }, { "description": "Enables the is_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-minimizable", "markdownDescription": "Enables the is_minimizable command without any pre-configured scope." }, { "description": "Enables the is_minimized command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-minimized", "markdownDescription": "Enables the is_minimized command without any pre-configured scope." }, { "description": "Enables the is_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-resizable", "markdownDescription": "Enables the is_resizable command without any pre-configured scope." }, { "description": "Enables the is_visible command without any pre-configured scope.", "type": "string", "const": "core:window:allow-is-visible", "markdownDescription": "Enables the is_visible command without any pre-configured scope." }, { "description": "Enables the maximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-maximize", "markdownDescription": "Enables the maximize command without any pre-configured scope." 
}, { "description": "Enables the minimize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-minimize", "markdownDescription": "Enables the minimize command without any pre-configured scope." }, { "description": "Enables the monitor_from_point command without any pre-configured scope.", "type": "string", "const": "core:window:allow-monitor-from-point", "markdownDescription": "Enables the monitor_from_point command without any pre-configured scope." }, { "description": "Enables the outer_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-outer-position", "markdownDescription": "Enables the outer_position command without any pre-configured scope." }, { "description": "Enables the outer_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-outer-size", "markdownDescription": "Enables the outer_size command without any pre-configured scope." }, { "description": "Enables the primary_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:allow-primary-monitor", "markdownDescription": "Enables the primary_monitor command without any pre-configured scope." }, { "description": "Enables the request_user_attention command without any pre-configured scope.", "type": "string", "const": "core:window:allow-request-user-attention", "markdownDescription": "Enables the request_user_attention command without any pre-configured scope." }, { "description": "Enables the scale_factor command without any pre-configured scope.", "type": "string", "const": "core:window:allow-scale-factor", "markdownDescription": "Enables the scale_factor command without any pre-configured scope." 
}, { "description": "Enables the set_always_on_bottom command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-always-on-bottom", "markdownDescription": "Enables the set_always_on_bottom command without any pre-configured scope." }, { "description": "Enables the set_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-always-on-top", "markdownDescription": "Enables the set_always_on_top command without any pre-configured scope." }, { "description": "Enables the set_background_color command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-background-color", "markdownDescription": "Enables the set_background_color command without any pre-configured scope." }, { "description": "Enables the set_badge_count command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-badge-count", "markdownDescription": "Enables the set_badge_count command without any pre-configured scope." }, { "description": "Enables the set_badge_label command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-badge-label", "markdownDescription": "Enables the set_badge_label command without any pre-configured scope." }, { "description": "Enables the set_closable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-closable", "markdownDescription": "Enables the set_closable command without any pre-configured scope." }, { "description": "Enables the set_content_protected command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-content-protected", "markdownDescription": "Enables the set_content_protected command without any pre-configured scope." 
}, { "description": "Enables the set_cursor_grab command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-grab", "markdownDescription": "Enables the set_cursor_grab command without any pre-configured scope." }, { "description": "Enables the set_cursor_icon command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-icon", "markdownDescription": "Enables the set_cursor_icon command without any pre-configured scope." }, { "description": "Enables the set_cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-position", "markdownDescription": "Enables the set_cursor_position command without any pre-configured scope." }, { "description": "Enables the set_cursor_visible command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-cursor-visible", "markdownDescription": "Enables the set_cursor_visible command without any pre-configured scope." }, { "description": "Enables the set_decorations command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-decorations", "markdownDescription": "Enables the set_decorations command without any pre-configured scope." }, { "description": "Enables the set_effects command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-effects", "markdownDescription": "Enables the set_effects command without any pre-configured scope." }, { "description": "Enables the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-enabled", "markdownDescription": "Enables the set_enabled command without any pre-configured scope." }, { "description": "Enables the set_focus command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-focus", "markdownDescription": "Enables the set_focus command without any pre-configured scope." 
}, { "description": "Enables the set_focusable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-focusable", "markdownDescription": "Enables the set_focusable command without any pre-configured scope." }, { "description": "Enables the set_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-fullscreen", "markdownDescription": "Enables the set_fullscreen command without any pre-configured scope." }, { "description": "Enables the set_icon command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-icon", "markdownDescription": "Enables the set_icon command without any pre-configured scope." }, { "description": "Enables the set_ignore_cursor_events command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-ignore-cursor-events", "markdownDescription": "Enables the set_ignore_cursor_events command without any pre-configured scope." }, { "description": "Enables the set_max_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-max-size", "markdownDescription": "Enables the set_max_size command without any pre-configured scope." }, { "description": "Enables the set_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-maximizable", "markdownDescription": "Enables the set_maximizable command without any pre-configured scope." }, { "description": "Enables the set_min_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-min-size", "markdownDescription": "Enables the set_min_size command without any pre-configured scope." }, { "description": "Enables the set_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-minimizable", "markdownDescription": "Enables the set_minimizable command without any pre-configured scope." 
}, { "description": "Enables the set_overlay_icon command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-overlay-icon", "markdownDescription": "Enables the set_overlay_icon command without any pre-configured scope." }, { "description": "Enables the set_position command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-position", "markdownDescription": "Enables the set_position command without any pre-configured scope." }, { "description": "Enables the set_progress_bar command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-progress-bar", "markdownDescription": "Enables the set_progress_bar command without any pre-configured scope." }, { "description": "Enables the set_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-resizable", "markdownDescription": "Enables the set_resizable command without any pre-configured scope." }, { "description": "Enables the set_shadow command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-shadow", "markdownDescription": "Enables the set_shadow command without any pre-configured scope." }, { "description": "Enables the set_simple_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-simple-fullscreen", "markdownDescription": "Enables the set_simple_fullscreen command without any pre-configured scope." }, { "description": "Enables the set_size command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-size", "markdownDescription": "Enables the set_size command without any pre-configured scope." 
}, { "description": "Enables the set_size_constraints command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-size-constraints", "markdownDescription": "Enables the set_size_constraints command without any pre-configured scope." }, { "description": "Enables the set_skip_taskbar command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-skip-taskbar", "markdownDescription": "Enables the set_skip_taskbar command without any pre-configured scope." }, { "description": "Enables the set_theme command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-theme", "markdownDescription": "Enables the set_theme command without any pre-configured scope." }, { "description": "Enables the set_title command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-title", "markdownDescription": "Enables the set_title command without any pre-configured scope." }, { "description": "Enables the set_title_bar_style command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-title-bar-style", "markdownDescription": "Enables the set_title_bar_style command without any pre-configured scope." }, { "description": "Enables the set_visible_on_all_workspaces command without any pre-configured scope.", "type": "string", "const": "core:window:allow-set-visible-on-all-workspaces", "markdownDescription": "Enables the set_visible_on_all_workspaces command without any pre-configured scope." }, { "description": "Enables the show command without any pre-configured scope.", "type": "string", "const": "core:window:allow-show", "markdownDescription": "Enables the show command without any pre-configured scope." 
}, { "description": "Enables the start_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:allow-start-dragging", "markdownDescription": "Enables the start_dragging command without any pre-configured scope." }, { "description": "Enables the start_resize_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:allow-start-resize-dragging", "markdownDescription": "Enables the start_resize_dragging command without any pre-configured scope." }, { "description": "Enables the theme command without any pre-configured scope.", "type": "string", "const": "core:window:allow-theme", "markdownDescription": "Enables the theme command without any pre-configured scope." }, { "description": "Enables the title command without any pre-configured scope.", "type": "string", "const": "core:window:allow-title", "markdownDescription": "Enables the title command without any pre-configured scope." }, { "description": "Enables the toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-toggle-maximize", "markdownDescription": "Enables the toggle_maximize command without any pre-configured scope." }, { "description": "Enables the unmaximize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-unmaximize", "markdownDescription": "Enables the unmaximize command without any pre-configured scope." }, { "description": "Enables the unminimize command without any pre-configured scope.", "type": "string", "const": "core:window:allow-unminimize", "markdownDescription": "Enables the unminimize command without any pre-configured scope." }, { "description": "Denies the available_monitors command without any pre-configured scope.", "type": "string", "const": "core:window:deny-available-monitors", "markdownDescription": "Denies the available_monitors command without any pre-configured scope." 
}, { "description": "Denies the center command without any pre-configured scope.", "type": "string", "const": "core:window:deny-center", "markdownDescription": "Denies the center command without any pre-configured scope." }, { "description": "Denies the close command without any pre-configured scope.", "type": "string", "const": "core:window:deny-close", "markdownDescription": "Denies the close command without any pre-configured scope." }, { "description": "Denies the create command without any pre-configured scope.", "type": "string", "const": "core:window:deny-create", "markdownDescription": "Denies the create command without any pre-configured scope." }, { "description": "Denies the current_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:deny-current-monitor", "markdownDescription": "Denies the current_monitor command without any pre-configured scope." }, { "description": "Denies the cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-cursor-position", "markdownDescription": "Denies the cursor_position command without any pre-configured scope." }, { "description": "Denies the destroy command without any pre-configured scope.", "type": "string", "const": "core:window:deny-destroy", "markdownDescription": "Denies the destroy command without any pre-configured scope." }, { "description": "Denies the get_all_windows command without any pre-configured scope.", "type": "string", "const": "core:window:deny-get-all-windows", "markdownDescription": "Denies the get_all_windows command without any pre-configured scope." }, { "description": "Denies the hide command without any pre-configured scope.", "type": "string", "const": "core:window:deny-hide", "markdownDescription": "Denies the hide command without any pre-configured scope." 
}, { "description": "Denies the inner_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-inner-position", "markdownDescription": "Denies the inner_position command without any pre-configured scope." }, { "description": "Denies the inner_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-inner-size", "markdownDescription": "Denies the inner_size command without any pre-configured scope." }, { "description": "Denies the internal_toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-internal-toggle-maximize", "markdownDescription": "Denies the internal_toggle_maximize command without any pre-configured scope." }, { "description": "Denies the is_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-always-on-top", "markdownDescription": "Denies the is_always_on_top command without any pre-configured scope." }, { "description": "Denies the is_closable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-closable", "markdownDescription": "Denies the is_closable command without any pre-configured scope." }, { "description": "Denies the is_decorated command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-decorated", "markdownDescription": "Denies the is_decorated command without any pre-configured scope." }, { "description": "Denies the is_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-enabled", "markdownDescription": "Denies the is_enabled command without any pre-configured scope." }, { "description": "Denies the is_focused command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-focused", "markdownDescription": "Denies the is_focused command without any pre-configured scope." 
}, { "description": "Denies the is_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-fullscreen", "markdownDescription": "Denies the is_fullscreen command without any pre-configured scope." }, { "description": "Denies the is_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-maximizable", "markdownDescription": "Denies the is_maximizable command without any pre-configured scope." }, { "description": "Denies the is_maximized command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-maximized", "markdownDescription": "Denies the is_maximized command without any pre-configured scope." }, { "description": "Denies the is_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-minimizable", "markdownDescription": "Denies the is_minimizable command without any pre-configured scope." }, { "description": "Denies the is_minimized command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-minimized", "markdownDescription": "Denies the is_minimized command without any pre-configured scope." }, { "description": "Denies the is_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-resizable", "markdownDescription": "Denies the is_resizable command without any pre-configured scope." }, { "description": "Denies the is_visible command without any pre-configured scope.", "type": "string", "const": "core:window:deny-is-visible", "markdownDescription": "Denies the is_visible command without any pre-configured scope." }, { "description": "Denies the maximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-maximize", "markdownDescription": "Denies the maximize command without any pre-configured scope." 
}, { "description": "Denies the minimize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-minimize", "markdownDescription": "Denies the minimize command without any pre-configured scope." }, { "description": "Denies the monitor_from_point command without any pre-configured scope.", "type": "string", "const": "core:window:deny-monitor-from-point", "markdownDescription": "Denies the monitor_from_point command without any pre-configured scope." }, { "description": "Denies the outer_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-outer-position", "markdownDescription": "Denies the outer_position command without any pre-configured scope." }, { "description": "Denies the outer_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-outer-size", "markdownDescription": "Denies the outer_size command without any pre-configured scope." }, { "description": "Denies the primary_monitor command without any pre-configured scope.", "type": "string", "const": "core:window:deny-primary-monitor", "markdownDescription": "Denies the primary_monitor command without any pre-configured scope." }, { "description": "Denies the request_user_attention command without any pre-configured scope.", "type": "string", "const": "core:window:deny-request-user-attention", "markdownDescription": "Denies the request_user_attention command without any pre-configured scope." }, { "description": "Denies the scale_factor command without any pre-configured scope.", "type": "string", "const": "core:window:deny-scale-factor", "markdownDescription": "Denies the scale_factor command without any pre-configured scope." }, { "description": "Denies the set_always_on_bottom command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-always-on-bottom", "markdownDescription": "Denies the set_always_on_bottom command without any pre-configured scope." 
}, { "description": "Denies the set_always_on_top command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-always-on-top", "markdownDescription": "Denies the set_always_on_top command without any pre-configured scope." }, { "description": "Denies the set_background_color command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-background-color", "markdownDescription": "Denies the set_background_color command without any pre-configured scope." }, { "description": "Denies the set_badge_count command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-badge-count", "markdownDescription": "Denies the set_badge_count command without any pre-configured scope." }, { "description": "Denies the set_badge_label command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-badge-label", "markdownDescription": "Denies the set_badge_label command without any pre-configured scope." }, { "description": "Denies the set_closable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-closable", "markdownDescription": "Denies the set_closable command without any pre-configured scope." }, { "description": "Denies the set_content_protected command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-content-protected", "markdownDescription": "Denies the set_content_protected command without any pre-configured scope." }, { "description": "Denies the set_cursor_grab command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-grab", "markdownDescription": "Denies the set_cursor_grab command without any pre-configured scope." 
}, { "description": "Denies the set_cursor_icon command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-icon", "markdownDescription": "Denies the set_cursor_icon command without any pre-configured scope." }, { "description": "Denies the set_cursor_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-position", "markdownDescription": "Denies the set_cursor_position command without any pre-configured scope." }, { "description": "Denies the set_cursor_visible command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-cursor-visible", "markdownDescription": "Denies the set_cursor_visible command without any pre-configured scope." }, { "description": "Denies the set_decorations command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-decorations", "markdownDescription": "Denies the set_decorations command without any pre-configured scope." }, { "description": "Denies the set_effects command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-effects", "markdownDescription": "Denies the set_effects command without any pre-configured scope." }, { "description": "Denies the set_enabled command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-enabled", "markdownDescription": "Denies the set_enabled command without any pre-configured scope." }, { "description": "Denies the set_focus command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-focus", "markdownDescription": "Denies the set_focus command without any pre-configured scope." }, { "description": "Denies the set_focusable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-focusable", "markdownDescription": "Denies the set_focusable command without any pre-configured scope." 
}, { "description": "Denies the set_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-fullscreen", "markdownDescription": "Denies the set_fullscreen command without any pre-configured scope." }, { "description": "Denies the set_icon command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-icon", "markdownDescription": "Denies the set_icon command without any pre-configured scope." }, { "description": "Denies the set_ignore_cursor_events command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-ignore-cursor-events", "markdownDescription": "Denies the set_ignore_cursor_events command without any pre-configured scope." }, { "description": "Denies the set_max_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-max-size", "markdownDescription": "Denies the set_max_size command without any pre-configured scope." }, { "description": "Denies the set_maximizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-maximizable", "markdownDescription": "Denies the set_maximizable command without any pre-configured scope." }, { "description": "Denies the set_min_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-min-size", "markdownDescription": "Denies the set_min_size command without any pre-configured scope." }, { "description": "Denies the set_minimizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-minimizable", "markdownDescription": "Denies the set_minimizable command without any pre-configured scope." }, { "description": "Denies the set_overlay_icon command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-overlay-icon", "markdownDescription": "Denies the set_overlay_icon command without any pre-configured scope." 
}, { "description": "Denies the set_position command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-position", "markdownDescription": "Denies the set_position command without any pre-configured scope." }, { "description": "Denies the set_progress_bar command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-progress-bar", "markdownDescription": "Denies the set_progress_bar command without any pre-configured scope." }, { "description": "Denies the set_resizable command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-resizable", "markdownDescription": "Denies the set_resizable command without any pre-configured scope." }, { "description": "Denies the set_shadow command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-shadow", "markdownDescription": "Denies the set_shadow command without any pre-configured scope." }, { "description": "Denies the set_simple_fullscreen command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-simple-fullscreen", "markdownDescription": "Denies the set_simple_fullscreen command without any pre-configured scope." }, { "description": "Denies the set_size command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-size", "markdownDescription": "Denies the set_size command without any pre-configured scope." }, { "description": "Denies the set_size_constraints command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-size-constraints", "markdownDescription": "Denies the set_size_constraints command without any pre-configured scope." }, { "description": "Denies the set_skip_taskbar command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-skip-taskbar", "markdownDescription": "Denies the set_skip_taskbar command without any pre-configured scope." 
}, { "description": "Denies the set_theme command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-theme", "markdownDescription": "Denies the set_theme command without any pre-configured scope." }, { "description": "Denies the set_title command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-title", "markdownDescription": "Denies the set_title command without any pre-configured scope." }, { "description": "Denies the set_title_bar_style command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-title-bar-style", "markdownDescription": "Denies the set_title_bar_style command without any pre-configured scope." }, { "description": "Denies the set_visible_on_all_workspaces command without any pre-configured scope.", "type": "string", "const": "core:window:deny-set-visible-on-all-workspaces", "markdownDescription": "Denies the set_visible_on_all_workspaces command without any pre-configured scope." }, { "description": "Denies the show command without any pre-configured scope.", "type": "string", "const": "core:window:deny-show", "markdownDescription": "Denies the show command without any pre-configured scope." }, { "description": "Denies the start_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:deny-start-dragging", "markdownDescription": "Denies the start_dragging command without any pre-configured scope." }, { "description": "Denies the start_resize_dragging command without any pre-configured scope.", "type": "string", "const": "core:window:deny-start-resize-dragging", "markdownDescription": "Denies the start_resize_dragging command without any pre-configured scope." }, { "description": "Denies the theme command without any pre-configured scope.", "type": "string", "const": "core:window:deny-theme", "markdownDescription": "Denies the theme command without any pre-configured scope." 
}, { "description": "Denies the title command without any pre-configured scope.", "type": "string", "const": "core:window:deny-title", "markdownDescription": "Denies the title command without any pre-configured scope." }, { "description": "Denies the toggle_maximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-toggle-maximize", "markdownDescription": "Denies the toggle_maximize command without any pre-configured scope." }, { "description": "Denies the unmaximize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-unmaximize", "markdownDescription": "Denies the unmaximize command without any pre-configured scope." }, { "description": "Denies the unminimize command without any pre-configured scope.", "type": "string", "const": "core:window:deny-unminimize", "markdownDescription": "Denies the unminimize command without any pre-configured scope." }, { "description": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`", "type": "string", "const": "shell:default", "markdownDescription": "This permission set configures which\nshell functionality is exposed by default.\n\n#### Granted Permissions\n\nIt allows to use the `open` functionality with a reasonable\nscope pre-configured. It will allow opening `http(s)://`,\n`tel:` and `mailto:` links.\n\n#### This default permission set includes:\n\n- `allow-open`" }, { "description": "Enables the execute command without any pre-configured scope.", "type": "string", "const": "shell:allow-execute", "markdownDescription": "Enables the execute command without any pre-configured scope." 
}, { "description": "Enables the kill command without any pre-configured scope.", "type": "string", "const": "shell:allow-kill", "markdownDescription": "Enables the kill command without any pre-configured scope." }, { "description": "Enables the open command without any pre-configured scope.", "type": "string", "const": "shell:allow-open", "markdownDescription": "Enables the open command without any pre-configured scope." }, { "description": "Enables the spawn command without any pre-configured scope.", "type": "string", "const": "shell:allow-spawn", "markdownDescription": "Enables the spawn command without any pre-configured scope." }, { "description": "Enables the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:allow-stdin-write", "markdownDescription": "Enables the stdin_write command without any pre-configured scope." }, { "description": "Denies the execute command without any pre-configured scope.", "type": "string", "const": "shell:deny-execute", "markdownDescription": "Denies the execute command without any pre-configured scope." }, { "description": "Denies the kill command without any pre-configured scope.", "type": "string", "const": "shell:deny-kill", "markdownDescription": "Denies the kill command without any pre-configured scope." }, { "description": "Denies the open command without any pre-configured scope.", "type": "string", "const": "shell:deny-open", "markdownDescription": "Denies the open command without any pre-configured scope." }, { "description": "Denies the spawn command without any pre-configured scope.", "type": "string", "const": "shell:deny-spawn", "markdownDescription": "Denies the spawn command without any pre-configured scope." }, { "description": "Denies the stdin_write command without any pre-configured scope.", "type": "string", "const": "shell:deny-stdin-write", "markdownDescription": "Denies the stdin_write command without any pre-configured scope." 
}, { "description": "This permission set configures what kind of\noperations are available from the window state plugin.\n\n#### Granted Permissions\n\nAll operations are enabled by default.\n\n\n#### This default permission set includes:\n\n- `allow-filename`\n- `allow-restore-state`\n- `allow-save-window-state`", "type": "string", "const": "window-state:default", "markdownDescription": "This permission set configures what kind of\noperations are available from the window state plugin.\n\n#### Granted Permissions\n\nAll operations are enabled by default.\n\n\n#### This default permission set includes:\n\n- `allow-filename`\n- `allow-restore-state`\n- `allow-save-window-state`" }, { "description": "Enables the filename command without any pre-configured scope.", "type": "string", "const": "window-state:allow-filename", "markdownDescription": "Enables the filename command without any pre-configured scope." }, { "description": "Enables the restore_state command without any pre-configured scope.", "type": "string", "const": "window-state:allow-restore-state", "markdownDescription": "Enables the restore_state command without any pre-configured scope." }, { "description": "Enables the save_window_state command without any pre-configured scope.", "type": "string", "const": "window-state:allow-save-window-state", "markdownDescription": "Enables the save_window_state command without any pre-configured scope." }, { "description": "Denies the filename command without any pre-configured scope.", "type": "string", "const": "window-state:deny-filename", "markdownDescription": "Denies the filename command without any pre-configured scope." }, { "description": "Denies the restore_state command without any pre-configured scope.", "type": "string", "const": "window-state:deny-restore-state", "markdownDescription": "Denies the restore_state command without any pre-configured scope." 
}, { "description": "Denies the save_window_state command without any pre-configured scope.", "type": "string", "const": "window-state:deny-save-window-state", "markdownDescription": "Denies the save_window_state command without any pre-configured scope." } ] }, "Value": { "description": "All supported ACL values.", "anyOf": [ { "description": "Represents a null JSON value.", "type": "null" }, { "description": "Represents a [`bool`].", "type": "boolean" }, { "description": "Represents a valid ACL [`Number`].", "allOf": [ { "$ref": "#/definitions/Number" } ] }, { "description": "Represents a [`String`].", "type": "string" }, { "description": "Represents a list of other [`Value`]s.", "type": "array", "items": { "$ref": "#/definitions/Value" } }, { "description": "Represents a map of [`String`] keys to [`Value`]s.", "type": "object", "additionalProperties": { "$ref": "#/definitions/Value" } } ] }, "Number": { "description": "A valid ACL number.", "anyOf": [ { "description": "Represents an [`i64`].", "type": "integer", "format": "int64" }, { "description": "Represents a [`f64`].", "type": "number", "format": "double" } ] }, "Target": { "description": "Platform target.", "oneOf": [ { "description": "MacOS.", "type": "string", "enum": [ "macOS" ] }, { "description": "Windows.", "type": "string", "enum": [ "windows" ] }, { "description": "Linux.", "type": "string", "enum": [ "linux" ] }, { "description": "Android.", "type": "string", "enum": [ "android" ] }, { "description": "iOS.", "type": "string", "enum": [ "iOS" ] } ] }, "ShellScopeEntryAllowedArg": { "description": "A command argument allowed to be executed by the webview API.", "anyOf": [ { "description": "A non-configurable argument that is passed to the command in the order it was specified.", "type": "string" }, { "description": "A variable that is set while calling the command from the webview API.", "type": "object", "required": [ "validator" ], "properties": { "raw": { "description": "Marks the validator as a 
raw regex, meaning the plugin should not make any modification at runtime.\n\nThis means the regex will not match on the entire string by default, which might be exploited if your regex allows unexpected input to be considered valid. When using this option, make sure your regex is correct.", "default": false, "type": "boolean" }, "validator": { "description": "[regex] validator to require passed values to conform to an expected input.\n\nThis will require the argument value passed to this variable to match the `validator` regex before it will be executed.\n\nThe regex string is by default surrounded by `^...$` to match the full string. For example the `https?://\\w+` regex would be registered as `^https?://\\w+$`.\n\n[regex]: ", "type": "string" } }, "additionalProperties": false } ] }, "ShellScopeEntryAllowedArgs": { "description": "A set of command arguments allowed to be executed by the webview API.\n\nA value of `true` will allow any arguments to be passed to the command. `false` will disable all arguments.
A list of [`ShellScopeEntryAllowedArg`] will set those arguments as the only valid arguments to be passed to the attached command configuration.", "anyOf": [ { "description": "Use a simple boolean to allow all or disable all arguments to this command configuration.", "type": "boolean" }, { "description": "A specific set of [`ShellScopeEntryAllowedArg`] that are valid to call for the command configuration.", "type": "array", "items": { "$ref": "#/definitions/ShellScopeEntryAllowedArg" } } ] } } } ================================================ FILE: desktop/src-tauri/icons/android/mipmap-anydpi-v26/ic_launcher.xml ================================================ ================================================ FILE: desktop/src-tauri/icons/android/values/ic_launcher_background.xml ================================================ #fff ================================================ FILE: desktop/src-tauri/src/main.rs ================================================ // Prevents additional console window on Windows in release #![cfg_attr(not(debug_assertions), windows_subsystem = "windows")] use directories::ProjectDirs; use serde::{Deserialize, Serialize}; use std::fs; use std::path::PathBuf; use std::process::Command; use std::sync::{Mutex, RwLock}; use std::io::Write as IoWrite; use std::time::SystemTime; #[cfg(target_os = "macos")] use std::time::Duration; use tauri::image::Image; use tauri::menu::{ CheckMenuItem, Menu, MenuBuilder, MenuItem, PredefinedMenuItem, SubmenuBuilder, HELP_SUBMENU_ID, }; use tauri::tray::{TrayIconBuilder, TrayIconEvent}; #[cfg(target_os = "macos")] use tauri::WebviewWindow; use tauri::Wry; use tauri::{ webview::PageLoadPayload, AppHandle, Manager, Webview, WebviewUrl, WebviewWindowBuilder, }; #[cfg(target_os = "macos")] use tokio::time::sleep; use url::Url; #[cfg(target_os = "macos")] use window_vibrancy::{apply_vibrancy, NSVisualEffectMaterial}; // ============================================================================ // 
Configuration // ============================================================================ const DEFAULT_SERVER_URL: &str = "https://cloud.onyx.app"; const CONFIG_FILE_NAME: &str = "config.json"; #[cfg(target_os = "macos")] const TITLEBAR_SCRIPT: &str = include_str!("../../src/titlebar.js"); const TRAY_ID: &str = "onyx-tray"; const TRAY_ICON_BYTES: &[u8] = include_bytes!("../icons/tray-icon.png"); const TRAY_MENU_OPEN_APP_ID: &str = "tray_open_app"; const TRAY_MENU_OPEN_CHAT_ID: &str = "tray_open_chat"; const TRAY_MENU_SHOW_IN_BAR_ID: &str = "tray_show_in_menu_bar"; const TRAY_MENU_QUIT_ID: &str = "tray_quit"; const MENU_SHOW_MENU_BAR_ID: &str = "show_menu_bar"; const MENU_HIDE_DECORATIONS_ID: &str = "hide_window_decorations"; const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##" (() => { if (window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__) { return; } window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__ = true; function isChatSessionPage() { try { const currentUrl = new URL(window.location.href); return ( currentUrl.pathname.startsWith("/app") && currentUrl.searchParams.has("chatId") ); } catch { return false; } } function getAllowedNavigationUrl(rawUrl) { try { const parsed = new URL(String(rawUrl), window.location.href); const scheme = parsed.protocol.toLowerCase(); if (!["http:", "https:", "mailto:", "tel:"].includes(scheme)) { return null; } return parsed; } catch { return null; } } async function openWithTauri(url) { try { const invoke = window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke; if (typeof invoke !== "function") { return false; } await invoke("open_in_browser", { url }); return true; } catch { return false; } } function handleChatNavigation(rawUrl) { const parsedUrl = getAllowedNavigationUrl(rawUrl); if (!parsedUrl) { return false; } const safeUrl = parsedUrl.toString(); const scheme = parsedUrl.protocol.toLowerCase(); if (scheme === "mailto:" || scheme === "tel:") { void openWithTauri(safeUrl).then((opened) => { if (!opened) { 
window.location.assign(safeUrl); } }); return true; } window.location.assign(safeUrl); return true; } document.addEventListener( "click", (event) => { if (!isChatSessionPage() || event.defaultPrevented) { return; } const element = event.target; if (!(element instanceof Element)) { return; } const anchor = element.closest("a"); if (!(anchor instanceof HTMLAnchorElement)) { return; } const target = (anchor.getAttribute("target") || "").toLowerCase(); if (target !== "_blank") { return; } const href = anchor.getAttribute("href"); if (!href || href.startsWith("#")) { return; } if (!handleChatNavigation(href)) { return; } event.preventDefault(); event.stopPropagation(); }, true ); const nativeWindowOpen = window.open; window.open = function(url, target, features) { const resolvedTarget = typeof target === "string" ? target.toLowerCase() : ""; const shouldNavigateInPlace = resolvedTarget === "" || resolvedTarget === "_blank"; if ( isChatSessionPage() && shouldNavigateInPlace && url != null && String(url).length > 0 ) { if (!handleChatNavigation(url)) { return null; } return null; } if (typeof nativeWindowOpen === "function") { return nativeWindowOpen.call(window, url, target, features); } return null; }; })(); "##; #[cfg(not(target_os = "macos"))] const MENU_KEY_HANDLER_SCRIPT: &str = r#" (() => { if (window.__ONYX_MENU_KEY_HANDLER__) return; window.__ONYX_MENU_KEY_HANDLER__ = true; let altHeld = false; function invoke(cmd) { const fn_ = window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke; if (typeof fn_ === 'function') fn_(cmd); } function releaseAltAndHideMenu() { if (!altHeld) { return; } altHeld = false; invoke('hide_menu_bar_temporary'); } document.addEventListener('keydown', (e) => { if (e.key === 'Alt') { if (!altHeld) { altHeld = true; invoke('show_menu_bar_temporarily'); } return; } if (e.altKey && e.key === 'F1') { e.preventDefault(); e.stopPropagation(); altHeld = false; invoke('toggle_menu_bar'); return; } }, true); 
document.addEventListener('keyup', (e) => { if (e.key === 'Alt' && altHeld) { releaseAltAndHideMenu(); } }, true); window.addEventListener('blur', () => { releaseAltAndHideMenu(); }); document.addEventListener('visibilitychange', () => { if (document.hidden) { releaseAltAndHideMenu(); } }); })(); "#; const CONSOLE_CAPTURE_SCRIPT: &str = r#" (() => { if (window.__ONYX_CONSOLE_CAPTURE__) return; window.__ONYX_CONSOLE_CAPTURE__ = true; const levels = ['log', 'warn', 'error', 'info', 'debug']; const originals = {}; levels.forEach(level => { originals[level] = console[level]; console[level] = function(...args) { originals[level].apply(console, args); try { const invoke = window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke; if (typeof invoke === 'function') { const message = args.map(a => { try { return typeof a === 'string' ? a : JSON.stringify(a); } catch { return String(a); } }).join(' '); invoke('log_from_frontend', { level, message }); } } catch {} }; }); window.addEventListener('error', (event) => { try { const invoke = window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke; if (typeof invoke === 'function') { invoke('log_from_frontend', { level: 'error', message: `[uncaught] ${event.message} at ${event.filename}:${event.lineno}:${event.colno}` }); } } catch {} }); window.addEventListener('unhandledrejection', (event) => { try { const invoke = window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke; if (typeof invoke === 'function') { invoke('log_from_frontend', { level: 'error', message: `[unhandled rejection] ${event.reason}` }); } } catch {} }); })(); "#; const MENU_TOGGLE_DEVTOOLS_ID: &str = "toggle_devtools"; const MENU_OPEN_DEBUG_LOG_ID: &str = "open_debug_log"; #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AppConfig { pub server_url: String, #[serde(default = "default_window_title")] pub window_title: String, #[serde(default = "default_show_menu_bar")] pub show_menu_bar: bool, #[serde(default)] pub 
hide_window_decorations: bool, } fn default_window_title() -> String { "Onyx".to_string() } fn default_show_menu_bar() -> bool { true } impl Default for AppConfig { fn default() -> Self { Self { server_url: DEFAULT_SERVER_URL.to_string(), window_title: default_window_title(), show_menu_bar: true, hide_window_decorations: false, } } } /// Get the config directory path fn get_config_dir() -> Option<PathBuf> { ProjectDirs::from("app", "onyx", "onyx-desktop").map(|dirs| dirs.config_dir().to_path_buf()) } /// Get the full config file path fn get_config_path() -> Option<PathBuf> { get_config_dir().map(|dir| dir.join(CONFIG_FILE_NAME)) } /// Load config from file, falling back to defaults if the file is missing or invalid. /// The returned bool is true only when a config file was successfully loaded. fn load_config() -> (AppConfig, bool) { let config_path = match get_config_path() { Some(path) => path, None => { return (AppConfig::default(), false); } }; if !config_path.exists() { return (AppConfig::default(), false); } match fs::read_to_string(&config_path) { Ok(contents) => match serde_json::from_str(&contents) { Ok(config) => (config, true), Err(_) => (AppConfig::default(), false), }, Err(_) => (AppConfig::default(), false), } } /// Save config to file fn save_config(config: &AppConfig) -> Result<(), String> { let config_dir = get_config_dir().ok_or("Could not determine config directory")?; let config_path = config_dir.join(CONFIG_FILE_NAME); // Ensure config directory exists fs::create_dir_all(&config_dir).map_err(|e| format!("Failed to create config dir: {}", e))?; let json = serde_json::to_string_pretty(config) .map_err(|e| format!("Failed to serialize config: {}", e))?; fs::write(&config_path, json).map_err(|e| format!("Failed to write config: {}", e))?; Ok(()) } // ============================================================================ // Debug Mode // 
============================================================================ fn is_debug_mode() -> bool { std::env::args().any(|arg| arg == "--debug") || std::env::var("ONYX_DEBUG").is_ok() } fn get_debug_log_path() -> Option<PathBuf> {
get_config_dir().map(|dir| dir.join("frontend_debug.log")) } fn init_debug_log_file() -> Option<fs::File> { let log_path = get_debug_log_path()?; if let Some(parent) = log_path.parent() { let _ = fs::create_dir_all(parent); } fs::OpenOptions::new() .create(true) .append(true) .open(&log_path) .ok() } fn format_utc_timestamp() -> String { let now = SystemTime::now() .duration_since(SystemTime::UNIX_EPOCH) .unwrap_or_default(); let total_secs = now.as_secs(); let millis = now.subsec_millis(); let days = total_secs / 86400; let secs_of_day = total_secs % 86400; let hours = secs_of_day / 3600; let mins = (secs_of_day % 3600) / 60; let secs = secs_of_day % 60; // Days since Unix epoch -> Y/M/D via civil calendar arithmetic let z = days as i64 + 719468; let era = z / 146097; let doe = z - era * 146097; let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; let y = yoe + era * 400; let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); let mp = (5 * doy + 2) / 153; let d = doy - (153 * mp + 2) / 5 + 1; let m = if mp < 10 { mp + 3 } else { mp - 9 }; let y = if m <= 2 { y + 1 } else { y }; format!( "{:04}-{:02}-{:02}T{:02}:{:02}:{:02}.{:03}Z", y, m, d, hours, mins, secs, millis ) } fn inject_console_capture(webview: &Webview) { let _ = webview.eval(CONSOLE_CAPTURE_SCRIPT); } fn maybe_open_devtools(app: &AppHandle, window: &tauri::WebviewWindow) { #[cfg(any(debug_assertions, feature = "devtools"))] { let state = app.state::<ConfigState>(); if state.debug_mode { window.open_devtools(); } } #[cfg(not(any(debug_assertions, feature = "devtools")))] { let _ = (app, window); } } // Global config state struct ConfigState { config: RwLock<AppConfig>, config_initialized: RwLock<bool>, app_base_url: RwLock<Option<Url>>, menu_temporarily_visible: RwLock<bool>, debug_mode: bool, debug_log_file: Mutex<Option<fs::File>>, } fn focus_main_window(app: &AppHandle) { if let Some(window) = app.get_webview_window("main") { let _ = window.unminimize(); let _ = 
window.show(); let _ = window.set_focus(); } else { trigger_new_window(app); } } fn trigger_new_chat(app:
&AppHandle) { let state = app.state::<ConfigState>(); let server_url = state.config.read().unwrap().server_url.clone(); if let Some(window) = app.get_webview_window("main") { let url = format!("{}/chat", server_url); let _ = window.eval(&format!("window.location.href = '{}'", url)); } } fn trigger_new_window(app: &AppHandle) { let state = app.state::<ConfigState>(); let server_url = state.config.read().unwrap().server_url.clone(); let handle = app.clone(); tauri::async_runtime::spawn(async move { let window_label = format!("onyx-{}", uuid::Uuid::new_v4()); let builder = WebviewWindowBuilder::new( &handle, &window_label, WebviewUrl::External(server_url.parse().unwrap()), ) .title("Onyx") .inner_size(1200.0, 800.0) .min_inner_size(800.0, 600.0) .transparent(true); #[cfg(target_os = "macos")] let builder = builder .title_bar_style(tauri::TitleBarStyle::Overlay) .hidden_title(true); #[cfg(target_os = "linux")] let builder = builder.background_color(tauri::window::Color(0x1a, 0x1a, 0x2e, 0xff)); if let Ok(window) = builder.build() { #[cfg(target_os = "macos")] { let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None); inject_titlebar(window.clone()); } apply_settings_to_window(&handle, &window); maybe_open_devtools(&handle, &window); let _ = window.set_focus(); } }); } fn open_docs() { let _ = open_in_default_browser("https://docs.onyx.app"); } fn open_settings(app: &AppHandle) { // Navigate main window to the settings page (index.html) with settings flag let state = app.state::<ConfigState>(); let settings_url = state .app_base_url .read() .unwrap() .as_ref() .cloned() .and_then(|mut url| { url.set_query(None); url.set_fragment(Some("settings")); url.set_path("/"); Some(url) }) .or_else(|| Url::parse("tauri://localhost/#settings").ok()); if let Some(window) = app.get_webview_window("main") { if let Some(url) = settings_url { let _ = window.navigate(url); } } } fn same_origin(left: &Url, right: &Url) -> bool { left.scheme() == right.scheme() && 
left.host_str() == right.host_str() &&
left.port_or_known_default() == right.port_or_known_default() } fn is_chat_session_url(url: &Url) -> bool { url.path().starts_with("/app") && url.query_pairs().any(|(key, _)| key == "chatId") } fn should_open_in_external_browser(current_url: &Url, destination_url: &Url) -> bool { if !is_chat_session_url(current_url) { return false; } match destination_url.scheme() { "mailto" | "tel" => true, "http" | "https" => !same_origin(current_url, destination_url), _ => false, } } fn open_in_default_browser(url: &str) -> bool { #[cfg(target_os = "macos")] { return Command::new("open").arg(url).status().is_ok(); } #[cfg(target_os = "linux")] { return Command::new("xdg-open").arg(url).status().is_ok(); } #[cfg(target_os = "windows")] { return Command::new("rundll32") .arg("url.dll,FileProtocolHandler") .arg(url) .status() .is_ok(); } #[allow(unreachable_code)] false } #[tauri::command] fn open_in_browser(url: String) -> Result<(), String> { let parsed_url = Url::parse(&url).map_err(|_| "Invalid URL".to_string())?; match parsed_url.scheme() { "http" | "https" | "mailto" | "tel" => {} _ => return Err("Unsupported URL scheme".to_string()), } if open_in_default_browser(parsed_url.as_str()) { Ok(()) } else { Err("Failed to open URL in default browser".to_string()) } } fn inject_chat_link_intercept(webview: &Webview) { let _ = webview.eval(CHAT_LINK_INTERCEPT_SCRIPT); } fn handle_toggle_devtools(app: &AppHandle) { #[cfg(any(debug_assertions, feature = "devtools"))] { let windows: Vec<_> = app.webview_windows().into_values().collect(); let any_open = windows.iter().any(|w| w.is_devtools_open()); for window in &windows { if any_open { window.close_devtools(); } else { window.open_devtools(); } } } #[cfg(not(any(debug_assertions, feature = "devtools")))] { let _ = app; } } fn handle_open_debug_log() { let log_path = match get_debug_log_path() { Some(p) => p, None => return, }; if !log_path.exists() { eprintln!("[ONYX DEBUG] Log file does not exist yet: {:?}", log_path); return; } let 
url_path = log_path.to_string_lossy().replace('\\', "/"); let _ = open_in_default_browser(&format!( "file:///{}", url_path.trim_start_matches('/') )); } // ============================================================================ // Tauri Commands // ============================================================================ #[tauri::command] fn log_from_frontend(level: String, message: String, state: tauri::State<ConfigState>) { if !state.debug_mode { return; } let timestamp = format_utc_timestamp(); let log_line = format!("[{}] [{}] {}", timestamp, level.to_uppercase(), message); eprintln!("{}", log_line); if let Ok(mut guard) = state.debug_log_file.lock() { if let Some(ref mut file) = *guard { let _ = writeln!(file, "{}", log_line); let _ = file.flush(); } } } /// Get the current server URL #[tauri::command] fn get_server_url(state: tauri::State<ConfigState>) -> String { state.config.read().unwrap().server_url.clone() } #[derive(Serialize)] struct BootstrapState { server_url: String, config_exists: bool, } /// Get the server URL plus whether a config file exists #[tauri::command] fn get_bootstrap_state(state: tauri::State<ConfigState>) -> BootstrapState { let server_url = state.config.read().unwrap().server_url.clone(); let config_initialized = *state.config_initialized.read().unwrap(); let config_exists = config_initialized && get_config_path().map(|path| path.exists()).unwrap_or(false); BootstrapState { server_url, config_exists, } } /// Set a new server URL and save to config #[tauri::command] fn set_server_url(state: tauri::State<ConfigState>, url: String) -> Result<String, String> { // Validate URL if !url.starts_with("http://") && !url.starts_with("https://") { return Err("URL must start with http:// or https://".to_string()); } let mut config = state.config.write().unwrap(); config.server_url = url.trim_end_matches('/').to_string(); save_config(&config)?; *state.config_initialized.write().unwrap() = true; 
Ok(config.server_url.clone()) } /// Get the config file path (so users know where to edit) #[tauri::command] fn
get_config_path_cmd() -> Result<String, String> { get_config_path() .map(|p| p.to_string_lossy().to_string()) .ok_or_else(|| "Could not determine config path".to_string()) } /// Open the config file in the default editor #[tauri::command] fn open_config_file() -> Result<(), String> { let config_path = get_config_path().ok_or("Could not determine config path")?; // Ensure config exists if !config_path.exists() { save_config(&AppConfig::default())?; } #[cfg(target_os = "macos")] { std::process::Command::new("open") .arg("-t") .arg(&config_path) .spawn() .map_err(|e| format!("Failed to open config: {}", e))?; } #[cfg(target_os = "linux")] { std::process::Command::new("xdg-open") .arg(&config_path) .spawn() .map_err(|e| format!("Failed to open config: {}", e))?; } #[cfg(target_os = "windows")] { std::process::Command::new("notepad") .arg(&config_path) .spawn() .map_err(|e| format!("Failed to open config: {}", e))?; } Ok(()) } /// Open the config directory in file manager #[tauri::command] fn open_config_directory() -> Result<(), String> { let config_dir = get_config_dir().ok_or("Could not determine config directory")?; // Ensure directory exists fs::create_dir_all(&config_dir).map_err(|e| format!("Failed to create config dir: {}", e))?; #[cfg(target_os = "macos")] { std::process::Command::new("open") .arg(&config_dir) .spawn() .map_err(|e| format!("Failed to open directory: {}", e))?; } #[cfg(target_os = "linux")] { std::process::Command::new("xdg-open") .arg(&config_dir) .spawn() .map_err(|e| format!("Failed to open directory: {}", e))?; } #[cfg(target_os = "windows")] { std::process::Command::new("explorer") .arg(&config_dir) .spawn() .map_err(|e| format!("Failed to open directory: {}", e))?; } Ok(()) } /// Navigate to a specific path on the configured server #[tauri::command] fn navigate_to(window: tauri::WebviewWindow, state: tauri::State<ConfigState>, path: &str) { let base_url = state.config.read().unwrap().server_url.clone(); let url = 
format!("{}{}", base_url, path); let _ =
window.eval(&format!("window.location.href = '{}'", url)); } /// Reload the current page #[tauri::command] fn reload_page(window: tauri::WebviewWindow) { let _ = window.eval("window.location.reload()"); } /// Go back in history #[tauri::command] fn go_back(window: tauri::WebviewWindow) { let _ = window.eval("window.history.back()"); } /// Go forward in history #[tauri::command] fn go_forward(window: tauri::WebviewWindow) { let _ = window.eval("window.history.forward()"); } /// Open a new window #[tauri::command] async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Result<(), String> { let server_url = state.config.read().unwrap().server_url.clone(); let window_label = format!("onyx-{}", uuid::Uuid::new_v4()); let builder = WebviewWindowBuilder::new( &app, &window_label, WebviewUrl::External( server_url .parse() .map_err(|e| format!("Invalid URL: {}", e))?, ), ) .title("Onyx") .inner_size(1200.0, 800.0) .min_inner_size(800.0, 600.0) .transparent(true); #[cfg(target_os = "macos")] let builder = builder .title_bar_style(tauri::TitleBarStyle::Overlay) .hidden_title(true); #[cfg(target_os = "linux")] let builder = builder.background_color(tauri::window::Color(0x1a, 0x1a, 0x2e, 0xff)); let window = builder.build().map_err(|e| e.to_string())?; #[cfg(target_os = "macos")] { let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None); inject_titlebar(window.clone()); } apply_settings_to_window(&app, &window); maybe_open_devtools(&app, &window); Ok(()) } /// Reset config to defaults #[tauri::command] fn reset_config(state: tauri::State<ConfigState>) -> Result<(), String> { let mut config = state.config.write().unwrap(); *config = AppConfig::default(); save_config(&config)?; *state.config_initialized.write().unwrap() = true; Ok(()) } #[cfg(target_os = "macos")] fn inject_titlebar(window: WebviewWindow) { let script = TITLEBAR_SCRIPT.to_string(); tauri::async_runtime::spawn(async move { // Keep trying for a few 
seconds to survive navigations and
slow loads let delays = [0u64, 200, 600, 1200, 2000, 4000, 6000, 8000, 10000]; for delay in delays { if delay > 0 { sleep(Duration::from_millis(delay)).await; } let _ = window.eval(&script); } }); } /// Start dragging the window #[tauri::command] async fn start_drag_window(window: tauri::Window) -> Result<(), String> { window.start_dragging().map_err(|e| e.to_string()) } // ============================================================================ // Window Settings // ============================================================================ fn find_check_menu_item( app: &AppHandle, id: &str, ) -> Option<CheckMenuItem<tauri::Wry>> { let menu = app.menu()?; for item in menu.items().ok()? { if let Some(submenu) = item.as_submenu() { for sub_item in submenu.items().ok()? { if let Some(check) = sub_item.as_check_menuitem() { if check.id().as_ref() == id { return Some(check.clone()); } } } } } None } fn apply_settings_to_window(app: &AppHandle, window: &tauri::WebviewWindow) { if cfg!(target_os = "macos") { return; } let state = app.state::<ConfigState>(); let config = state.config.read().unwrap(); let temp_visible = *state.menu_temporarily_visible.read().unwrap(); if !config.show_menu_bar && !temp_visible { let _ = window.hide_menu(); } if config.hide_window_decorations { let _ = window.set_decorations(false); } } fn handle_menu_bar_toggle(app: &AppHandle) { if cfg!(target_os = "macos") { return; } let state = app.state::<ConfigState>(); let show = { let mut config = state.config.write().unwrap(); config.show_menu_bar = !config.show_menu_bar; let _ = save_config(&config); config.show_menu_bar }; *state.menu_temporarily_visible.write().unwrap() = false; for (_, window) in app.webview_windows() { if show { let _ = window.show_menu(); } else { let _ = window.hide_menu(); } } } fn handle_decorations_toggle(app: &AppHandle) { if cfg!(target_os = "macos") { return; } let state = app.state::<ConfigState>(); let hide = { let mut config = 
state.config.write().unwrap(); config.hide_window_decorations = !config.hide_window_decorations;
let _ = save_config(&config); config.hide_window_decorations }; for (_, window) in app.webview_windows() { let _ = window.set_decorations(!hide); } } #[tauri::command] fn toggle_menu_bar(app: AppHandle) { if cfg!(target_os = "macos") { return; } handle_menu_bar_toggle(&app); let state = app.state::<ConfigState>(); let checked = state.config.read().unwrap().show_menu_bar; if let Some(check) = find_check_menu_item(&app, MENU_SHOW_MENU_BAR_ID) { let _ = check.set_checked(checked); } } #[tauri::command] fn show_menu_bar_temporarily(app: AppHandle) { if cfg!(target_os = "macos") { return; } let state = app.state::<ConfigState>(); if state.config.read().unwrap().show_menu_bar { return; } let mut temp = state.menu_temporarily_visible.write().unwrap(); if *temp { return; } *temp = true; drop(temp); for (_, window) in app.webview_windows() { let _ = window.show_menu(); } } #[tauri::command] fn hide_menu_bar_temporary(app: AppHandle) { if cfg!(target_os = "macos") { return; } let state = app.state::<ConfigState>(); let mut temp = state.menu_temporarily_visible.write().unwrap(); if !*temp { return; } *temp = false; drop(temp); if state.config.read().unwrap().show_menu_bar { return; } for (_, window) in app.webview_windows() { let _ = window.hide_menu(); } } // ============================================================================ // Menu Setup // ============================================================================ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> { let menu = app.menu().unwrap_or(Menu::default(app)?); let new_chat_item = MenuItem::with_id(app, "new_chat", "New Chat", true, Some("CmdOrCtrl+N"))?; let new_window_item = MenuItem::with_id( app, "new_window", "New Window", true, Some("CmdOrCtrl+Shift+N"), )?; let settings_item = MenuItem::with_id( app, "open_settings", "Settings...", true, Some("CmdOrCtrl+Comma"), )?; let docs_item = MenuItem::with_id(app, "open_docs", "Onyx Documentation", true, None::<&str>)?; if let Some(file_menu) = menu .items()?
.into_iter() .filter_map(|item| item.as_submenu().cloned()) .find(|submenu| submenu.text().ok().as_deref() == Some("File")) { file_menu.insert_items(&[&new_chat_item, &new_window_item, &settings_item], 0)?; } else { let file_menu = SubmenuBuilder::new(app, "File") .items(&[ &new_chat_item, &new_window_item, &settings_item, &PredefinedMenuItem::close_window(app, None)?, ]) .build()?; menu.prepend(&file_menu)?; } #[cfg(not(target_os = "macos"))] { let config = app.state::<ConfigState>(); let config_guard = config.config.read().unwrap(); let show_menu_bar_item = CheckMenuItem::with_id( app, MENU_SHOW_MENU_BAR_ID, "Show Menu Bar", true, config_guard.show_menu_bar, None::<&str>, )?; let hide_decorations_item = CheckMenuItem::with_id( app, MENU_HIDE_DECORATIONS_ID, "Hide Window Decorations", true, config_guard.hide_window_decorations, None::<&str>, )?; drop(config_guard); if let Some(window_menu) = menu .items()? .into_iter() .filter_map(|item| item.as_submenu().cloned()) .find(|submenu| submenu.text().ok().as_deref() == Some("Window")) { window_menu.append(&show_menu_bar_item)?; window_menu.append(&hide_decorations_item)?; } else { let window_menu = SubmenuBuilder::new(app, "Window") .item(&show_menu_bar_item) .item(&hide_decorations_item) .build()?; let items = menu.items()?; let help_idx = items .iter() .position(|item| { item.as_submenu() .and_then(|s| s.text().ok()) .as_deref() == Some("Help") }) .unwrap_or(items.len()); menu.insert(&window_menu, help_idx)?; } } if let Some(help_menu) = menu .get(HELP_SUBMENU_ID) .and_then(|item| item.as_submenu().cloned()) { help_menu.append(&docs_item)?; } else { let help_menu = SubmenuBuilder::with_id(app, HELP_SUBMENU_ID, "Help") .item(&docs_item) .build()?; menu.append(&help_menu)?; } let state = app.state::<ConfigState>(); if state.debug_mode { let toggle_devtools_item = MenuItem::with_id( app, MENU_TOGGLE_DEVTOOLS_ID, "Toggle DevTools", true, Some("F12"), )?; let open_log_item = MenuItem::with_id( app, MENU_OPEN_DEBUG_LOG_ID, "Open Debug Log", 
true,
None::<&str>, )?; let debug_menu = SubmenuBuilder::new(app, "Debug") .item(&toggle_devtools_item) .item(&open_log_item) .build()?; menu.append(&debug_menu)?; } app.set_menu(menu)?; Ok(()) } fn build_tray_menu(app: &AppHandle) -> tauri::Result<Menu<tauri::Wry>> { let open_app = MenuItem::with_id(app, TRAY_MENU_OPEN_APP_ID, "Open Onyx", true, None::<&str>)?; let open_chat = MenuItem::with_id( app, TRAY_MENU_OPEN_CHAT_ID, "Open Chat Window", true, None::<&str>, )?; let show_in_menu_bar = CheckMenuItem::with_id( app, TRAY_MENU_SHOW_IN_BAR_ID, "Show in Menu Bar", true, true, None::<&str>, )?; // Keep it visible/pinned without letting users uncheck (avoids orphaning the tray) let _ = show_in_menu_bar.set_enabled(false); let quit = PredefinedMenuItem::quit(app, Some("Quit Onyx"))?; MenuBuilder::new(app) .item(&open_app) .item(&open_chat) .separator() .item(&show_in_menu_bar) .separator() .item(&quit) .build() } fn handle_tray_menu_event(app: &AppHandle, id: &str) { match id { TRAY_MENU_OPEN_APP_ID => { focus_main_window(app); } TRAY_MENU_OPEN_CHAT_ID => { focus_main_window(app); trigger_new_chat(app); } TRAY_MENU_QUIT_ID => { app.exit(0); } TRAY_MENU_SHOW_IN_BAR_ID => { // No-op for now; the item stays checked/disabled to indicate it's pinned. } _ => {} } } fn setup_tray_icon(app: &AppHandle) -> tauri::Result<()> { let mut builder = TrayIconBuilder::with_id(TRAY_ID).tooltip("Onyx"); let tray_icon = Image::from_bytes(TRAY_ICON_BYTES) .ok() .or_else(|| app.default_window_icon().cloned()); if let Some(icon) = tray_icon { builder = builder.icon(icon); #[cfg(target_os = "macos")] { builder = builder.icon_as_template(true); } } if let Ok(menu) = build_tray_menu(app) { builder = builder.menu(&menu); } builder .on_tray_icon_event(|tray, event| { if let TrayIconEvent::Click { ..
} = event { focus_main_window(tray.app_handle()); } }) .on_menu_event(|app, event| handle_tray_menu_event(app, event.id().as_ref())) .build(app)?; Ok(()) } // ============================================================================ // Main // ============================================================================ fn main() { let (config, config_initialized) = load_config(); let debug_mode = is_debug_mode(); let debug_log_file = if debug_mode { eprintln!("[ONYX DEBUG] Debug mode enabled"); if let Some(path) = get_debug_log_path() { eprintln!("[ONYX DEBUG] Frontend logs: {}", path.display()); } eprintln!("[ONYX DEBUG] DevTools will open automatically"); eprintln!("[ONYX DEBUG] Capturing console.log/warn/error/info/debug from webview"); init_debug_log_file() } else { None }; tauri::Builder::default() .plugin(tauri_plugin_shell::init()) .plugin( tauri::plugin::Builder::<tauri::Wry>::new("chat-external-navigation-handler") .on_navigation(|webview, destination_url| { let Ok(current_url) = webview.url() else { return true; }; if should_open_in_external_browser(&current_url, destination_url) { if !open_in_default_browser(destination_url.as_str()) { eprintln!( "Failed to open external URL in default browser: {}", destination_url ); } return false; } true }) .build(), ) .plugin(tauri_plugin_window_state::Builder::default().build()) .manage(ConfigState { config: RwLock::new(config), config_initialized: RwLock::new(config_initialized), app_base_url: RwLock::new(None), menu_temporarily_visible: RwLock::new(false), debug_mode, debug_log_file: Mutex::new(debug_log_file), }) .invoke_handler(tauri::generate_handler![ get_server_url, get_bootstrap_state, set_server_url, get_config_path_cmd, open_in_browser, open_config_file, open_config_directory, navigate_to, reload_page, go_back, go_forward, new_window, reset_config, start_drag_window, toggle_menu_bar, show_menu_bar_temporarily, hide_menu_bar_temporary, log_from_frontend ]) .on_menu_event(|app, event| match 
event.id().as_ref() {
"open_docs" => open_docs(), "new_chat" => trigger_new_chat(app), "new_window" => trigger_new_window(app), "open_settings" => open_settings(app), "show_menu_bar" => handle_menu_bar_toggle(app), "hide_window_decorations" => handle_decorations_toggle(app), MENU_TOGGLE_DEVTOOLS_ID => handle_toggle_devtools(app), MENU_OPEN_DEBUG_LOG_ID => handle_open_debug_log(), _ => {} }) .setup(move |app| { let app_handle = app.handle(); if let Err(e) = setup_app_menu(&app_handle) { eprintln!("Failed to setup menu: {}", e); } if let Err(e) = setup_tray_icon(&app_handle) { eprintln!("Failed to setup tray icon: {}", e); } // Setup main window with vibrancy effect if let Some(window) = app.get_webview_window("main") { // Apply vibrancy effect for translucent glass look #[cfg(target_os = "macos")] { let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None); } if let Ok(url) = window.url() { let mut base_url = url; base_url.set_query(None); base_url.set_fragment(None); base_url.set_path("/"); *app.state::<ConfigState>().app_base_url.write().unwrap() = Some(base_url); } #[cfg(target_os = "macos")] inject_titlebar(window.clone()); apply_settings_to_window(&app_handle, &window); maybe_open_devtools(&app_handle, &window); let _ = window.set_focus(); } Ok(()) }) .on_page_load(|webview: &Webview, _payload: &PageLoadPayload| { inject_chat_link_intercept(webview); { let app = webview.app_handle(); let state = app.state::<ConfigState>(); if state.debug_mode { inject_console_capture(webview); } } #[cfg(not(target_os = "macos"))] { let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT); let app = webview.app_handle(); let state = app.state::<ConfigState>(); let config = state.config.read().unwrap(); let temp_visible = *state.menu_temporarily_visible.read().unwrap(); let label = webview.label().to_string(); if !config.show_menu_bar && !temp_visible { if let Some(win) = app.get_webview_window(&label) { let _ = win.hide_menu(); } } if config.hide_window_decorations { if let Some(win) = 
app.get_webview_window(&label) { let _ =
win.set_decorations(false); } } } #[cfg(target_os = "macos")] let _ = webview.eval(TITLEBAR_SCRIPT); }) .run(tauri::generate_context!()) .expect("error while running tauri application"); } ================================================ FILE: desktop/src-tauri/tauri.conf.json ================================================ { "$schema": "https://schema.tauri.app/config/2.0.0", "productName": "Onyx", "version": "0.0.0-dev", "identifier": "app.onyx.desktop", "build": { "beforeBuildCommand": "", "beforeDevCommand": "", "frontendDist": "../src" }, "app": { "withGlobalTauri": true, "windows": [ { "title": "Onyx", "label": "main", "url": "index.html", "width": 1200, "height": 800, "minWidth": 800, "minHeight": 600, "resizable": true, "fullscreen": false, "decorations": true, "transparent": true, "backgroundColor": "#1a1a2e", "titleBarStyle": "Overlay", "hiddenTitle": true, "acceptFirstMouse": true, "tabbingIdentifier": "onyx" } ], "security": { "csp": null }, "macOSPrivateApi": true }, "bundle": { "active": true, "targets": "all", "icon": [ "icons/32x32.png", "icons/128x128.png", "icons/128x128@2x.png", "icons/icon.icns", "icons/icon.ico" ], "category": "Productivity", "shortDescription": "Onyx Cloud Desktop App", "longDescription": "A lightweight desktop wrapper for Onyx Cloud - your AI-powered knowledge assistant.", "macOS": { "entitlements": null, "exceptionDomain": "cloud.onyx.app", "minimumSystemVersion": "10.15", "signingIdentity": null, "dmg": { "windowSize": { "width": 660, "height": 400 } } } }, "plugins": { "shell": { "open": true } } } ================================================ FILE: docker-bake.hcl ================================================ group "default" { targets = ["backend", "model-server", "web"] } variable "BACKEND_REPOSITORY" { default = "onyxdotapp/onyx-backend" } variable "WEB_SERVER_REPOSITORY" { default = "onyxdotapp/onyx-web-server" } variable "MODEL_SERVER_REPOSITORY" { default = "onyxdotapp/onyx-model-server" } variable 
"INTEGRATION_REPOSITORY" { default = "onyxdotapp/onyx-integration" } variable "CLI_REPOSITORY" { default = "onyxdotapp/onyx-cli" } variable "TAG" { default = "latest" } target "backend" { context = "backend" dockerfile = "Dockerfile" cache-from = ["type=registry,ref=${BACKEND_REPOSITORY}:latest"] cache-to = ["type=inline"] tags = ["${BACKEND_REPOSITORY}:${TAG}"] } target "web" { context = "web" dockerfile = "Dockerfile" cache-from = ["type=registry,ref=${WEB_SERVER_REPOSITORY}:latest"] cache-to = ["type=inline"] tags = ["${WEB_SERVER_REPOSITORY}:${TAG}"] } target "model-server" { context = "backend" dockerfile = "Dockerfile.model_server" cache-from = ["type=registry,ref=${MODEL_SERVER_REPOSITORY}:latest"] cache-to = ["type=inline"] tags = ["${MODEL_SERVER_REPOSITORY}:${TAG}"] } target "integration" { context = "backend" dockerfile = "tests/integration/Dockerfile" // Provide the base image via build context from the backend target contexts = { base = "target:backend" } tags = ["${INTEGRATION_REPOSITORY}:${TAG}"] } target "cli" { context = "cli" dockerfile = "Dockerfile" cache-from = ["type=registry,ref=${CLI_REPOSITORY}:latest"] cache-to = ["type=inline"] tags = ["${CLI_REPOSITORY}:${TAG}"] } ================================================ FILE: docs/METRICS.md ================================================ # Onyx Prometheus Metrics Reference ## Adding New Metrics All Prometheus metrics live in the `backend/onyx/server/metrics/` package. Follow these steps to add a new metric. ### 1. Choose the right file (or create a new one) | File | Purpose | |------|---------| | `metrics/slow_requests.py` | Slow request counter + callback | | `metrics/postgres_connection_pool.py` | SQLAlchemy connection pool metrics | | `metrics/prometheus_setup.py` | FastAPI instrumentator config (orchestrator) | If your metric is a standalone concern (e.g. cache hit rates, queue depths), create a new file under `metrics/` and keep one metric concept per file. ### 2. 
Define the metric Use `prometheus_client` types directly at module level: ```python # metrics/my_metric.py from prometheus_client import Counter _my_counter = Counter( "onyx_my_counter_total", # Always prefix with onyx_ "Human-readable description", ["label_a", "label_b"], # Keep label cardinality low ) ``` **Naming conventions:** - Prefix all metric names with `onyx_` - Counters: `_total` suffix (e.g. `onyx_api_slow_requests_total`) - Histograms: `_seconds` or `_bytes` suffix for durations/sizes - Gauges: no special suffix **Label cardinality:** Avoid high-cardinality labels (raw user IDs, UUIDs, raw paths). Use route templates like `/api/items/{item_id}` instead of `/api/items/abc-123`. ### 3. Wire it into the instrumentator (if request-scoped) If your metric needs to run on every HTTP request, write a callback and register it in `prometheus_setup.py`: ```python # metrics/my_metric.py from prometheus_fastapi_instrumentator.metrics import Info def my_metric_callback(info: Info) -> None: _my_counter.labels(label_a=info.method, label_b=info.modified_handler).inc() ``` ```python # metrics/prometheus_setup.py from onyx.server.metrics.my_metric import my_metric_callback # Inside setup_prometheus_metrics(): instrumentator.add(my_metric_callback) ``` ### 4. Wire it into setup_prometheus_metrics (if infrastructure-scoped) For metrics that attach to engines, pools, or background systems, add a setup function and call it from `setup_prometheus_metrics()` in `metrics/prometheus_setup.py`: ```python # metrics/my_metric.py def setup_my_metrics(resource: SomeResource) -> None: # Register collectors, attach event listeners, etc. ... ``` ```python # metrics/prometheus_setup.py — inside setup_prometheus_metrics() from onyx.server.metrics.my_metric import setup_my_metrics def setup_prometheus_metrics(app, engines=None) -> None: setup_my_metrics(resource) # Add your call here ... 
``` All metrics initialization is funneled through the single `setup_prometheus_metrics()` call in `onyx/main.py:lifespan()`. Do not add separate setup calls to `main.py`. ### 5. Write tests Add tests in `backend/tests/unit/onyx/server/`. Use `unittest.mock.patch` to mock the prometheus objects — don't increment real global counters in tests. ### 6. Document the metric Add your metric to the reference tables below in this file. Include the metric name, type, labels, and description. ### 7. Update Grafana dashboards After deploying, add panels to the relevant Grafana dashboard: 1. Open Grafana and navigate to the Onyx dashboard (or create a new one) 2. Add a new panel — choose the appropriate visualization: - **Counters** → use `rate()` in a time series panel (e.g. `rate(onyx_my_counter_total[5m])`) - **Histograms** → use `histogram_quantile()` for percentiles, or `_sum/_count` for averages - **Gauges** → display directly as a stat or gauge panel 3. Add meaningful thresholds and alerts where appropriate 4. Group related panels into rows (e.g. "API Performance", "Database Pool") --- ## API Server Metrics These metrics are exposed at `GET /metrics` on the API server. 
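A minimal sketch of the testing pattern from step 5, assuming a hypothetical slow-request callback. To keep it runnable with only the standard library, `metrics.slow_requests` is a placeholder for a module-level `prometheus_client` `Counter`, and the `info` object is a `SimpleNamespace` carrying only the fields the callback reads (`method`, `modified_handler`, `modified_status`, `modified_duration`). These fields are modeled on, but not imported from, the instrumentator's `Info` object.

```python
from types import SimpleNamespace
from unittest.mock import patch

# Hypothetical threshold mirroring SLOW_REQUEST_THRESHOLD_SECONDS (default 1.0).
SLOW_REQUEST_THRESHOLD_SECONDS = 1.0

# Placeholder for a metrics module; in real code this attribute would be a
# module-level prometheus_client Counter, never touched directly by tests.
metrics = SimpleNamespace(slow_requests=None)


def slow_request_callback(info: SimpleNamespace) -> None:
    """Increment the slow-request counter when a request exceeds the threshold."""
    if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
        metrics.slow_requests.labels(
            method=info.method,
            handler=info.modified_handler,  # route template keeps cardinality low
            status=info.modified_status,
        ).inc()


def make_info(duration: float) -> SimpleNamespace:
    """Build a fake request-info object with only the fields the callback reads."""
    return SimpleNamespace(
        method="GET",
        modified_handler="/api/items/{item_id}",
        modified_status="200",
        modified_duration=duration,
    )


def test_slow_request_is_counted() -> None:
    # patch.object swaps in a MagicMock, so no real counter is incremented.
    with patch.object(metrics, "slow_requests") as counter:
        slow_request_callback(make_info(2.5))
        counter.labels.assert_called_once_with(
            method="GET", handler="/api/items/{item_id}", status="200"
        )
        counter.labels.return_value.inc.assert_called_once_with()


def test_fast_request_is_ignored() -> None:
    with patch.object(metrics, "slow_requests") as counter:
        slow_request_callback(make_info(0.2))
        counter.labels.assert_not_called()


if __name__ == "__main__":
    test_slow_request_is_counted()
    test_fast_request_is_ignored()
    print("ok")
```

In the real test suite you would patch the counter by its dotted module path with `unittest.mock.patch`. `patch.object` on a namespace is used here only so the snippet is self-contained; the assertions look the same either way.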
### Built-in (via `prometheus-fastapi-instrumentator`) | Metric | Type | Labels | Description | |--------|------|--------|-------------| | `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count | | `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) | | `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (custom buckets for P95/P99) | | `http_request_size_bytes` | Summary | `handler` | Incoming request content length | | `http_response_size_bytes` | Summary | `handler` | Outgoing response content length | | `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests | ### Custom (via `onyx.server.metrics`) | Metric | Type | Labels | Description | |--------|------|--------|-------------| | `onyx_api_slow_requests_total` | Counter | `method`, `handler`, `status` | Requests exceeding `SLOW_REQUEST_THRESHOLD_SECONDS` (default 1s) | ### Configuration | Env Var | Default | Description | |---------|---------|-------------| | `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting | ### Instrumentator Settings - `should_group_status_codes=False` — Reports exact HTTP status codes (e.g. 401, 403, 500) - `should_instrument_requests_inprogress=True` — Enables the in-progress request gauge - `inprogress_labels=True` — Breaks down in-progress gauge by `method` and `handler` - `excluded_handlers=["/health", "/metrics", "/openapi.json"]` — Excludes noisy endpoints from metrics ## Database Pool Metrics These metrics provide visibility into SQLAlchemy connection pool state across all three engines (`sync`, `async`, `readonly`). Collected via `onyx.server.metrics.postgres_connection_pool`. 
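The pool gauges feed the capacity arithmetic used by the pool-utilization and pool-exhaustion PromQL examples later in this document. A SQLAlchemy `QueuePool` can hand out at most `pool_size + max_overflow` connections, and SQLAlchemy treats `max_overflow=-1` as "no overflow limit". A minimal Python sketch of the same ratios follows. The numbers (`pool_size=20`, `max_overflow=10`) are hypothetical; substitute your actual pool size and `POSTGRES_API_SERVER_POOL_OVERFLOW` value.

```python
def pool_capacity(pool_size: int, max_overflow: int) -> int:
    """Maximum simultaneous connections a SQLAlchemy QueuePool can hand out."""
    if max_overflow < 0:
        # SQLAlchemy treats max_overflow=-1 as "no overflow limit".
        raise ValueError("unbounded overflow has no finite capacity")
    return pool_size + max_overflow


def pool_utilization_pct(checked_out: int, pool_size: int, max_overflow: int) -> float:
    """Same ratio as the PromQL: checked_out / (pool_size + max_overflow) * 100."""
    return checked_out * 100 / pool_capacity(pool_size, max_overflow)


def near_exhaustion(
    checked_out: int, pool_size: int, max_overflow: int, threshold: float = 0.8
) -> bool:
    """Mirror of the alert: checked_out > threshold * (pool_size + max_overflow)."""
    return checked_out > threshold * pool_capacity(pool_size, max_overflow)


if __name__ == "__main__":
    # Hypothetical snapshot: 27 connections checked out of a 20 + 10 pool.
    print(pool_utilization_pct(27, 20, 10))  # 90.0
    print(near_exhaustion(27, 20, 10))  # True
    print(near_exhaustion(20, 20, 10))  # False
```

Dividing by `pool_size` alone would overstate utilization whenever overflow connections are allowed. That is why the example queries add the overflow value to the denominator.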
### Pool State (via custom Prometheus collector — snapshot on each scrape) | Metric | Type | Labels | Description | |--------|------|--------|-------------| | `onyx_db_pool_checked_out` | Gauge | `engine` | Currently checked-out connections | | `onyx_db_pool_checked_in` | Gauge | `engine` | Idle connections available in the pool | | `onyx_db_pool_overflow` | Gauge | `engine` | Current overflow connections beyond `pool_size` | | `onyx_db_pool_size` | Gauge | `engine` | Configured pool size (constant) | ### Pool Lifecycle (via SQLAlchemy pool event listeners) | Metric | Type | Labels | Description | |--------|------|--------|-------------| | `onyx_db_pool_checkout_total` | Counter | `engine` | Total connection checkouts from the pool | | `onyx_db_pool_checkin_total` | Counter | `engine` | Total connection checkins to the pool | | `onyx_db_pool_connections_created_total` | Counter | `engine` | Total new database connections created | | `onyx_db_pool_invalidations_total` | Counter | `engine` | Total connection invalidations | | `onyx_db_pool_checkout_timeout_total` | Counter | `engine` | Total connection checkout timeouts | ### Per-Endpoint Attribution (via pool events + endpoint context middleware) | Metric | Type | Labels | Description | |--------|------|--------|-------------| | `onyx_db_connections_held_by_endpoint` | Gauge | `handler`, `engine` | DB connections currently held, by endpoint | | `onyx_db_connection_hold_seconds` | Histogram | `handler`, `engine` | Duration a DB connection is held by an endpoint | Engine label values: `sync` (main read-write), `async` (async sessions), `readonly` (read-only user). Connections from background tasks (Celery) or boot-time warmup appear as `handler="unknown"`. ## OpenSearch Search Metrics These metrics track OpenSearch search latency and throughput. Collected via `onyx.server.metrics.opensearch_search`. 
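The split between client-side and server-side duration can be illustrated with a small timing wrapper. This is a sketch, not the code in `onyx.server.metrics.opensearch_search`. `search_fn` is a hypothetical callable that returns a parsed OpenSearch response, and the sketch assumes only the standard `took` field, which OpenSearch reports in milliseconds. The `in_progress` mapping stands in for the in-flight gauge, and `"keyword"` is a made-up search type label.

```python
import time
from collections import defaultdict
from typing import Any, Callable

# Hypothetical stand-in for the in-progress gauge, keyed by search type.
in_progress: defaultdict[str, int] = defaultdict(int)


def timed_search(
    search_type: str,
    search_fn: Callable[[], dict[str, Any]],
) -> tuple[dict[str, Any], float, float]:
    """Run a search and return (response, client_seconds, server_seconds)."""
    in_progress[search_type] += 1
    start = time.perf_counter()
    try:
        response = search_fn()
    finally:
        # Decrement even if the search raises, so the gauge never drifts upward.
        in_progress[search_type] -= 1
    client_seconds = time.perf_counter() - start
    # OpenSearch's `took` is server-side execution time in milliseconds.
    server_seconds = response.get("took", 0) / 1000
    return response, client_seconds, server_seconds


def fake_search() -> dict[str, Any]:
    """Simulated search: 50 ms wall time, 20 ms reported by the server."""
    time.sleep(0.05)
    return {"took": 20, "hits": {"hits": []}}


if __name__ == "__main__":
    _, client_s, server_s = timed_search("keyword", fake_search)
    overhead = client_s - server_s
    print(f"client={client_s:.3f}s server={server_s:.3f}s overhead={overhead:.3f}s")
```

The difference between the two durations is the network and serialization overhead. The "OpenSearch network overhead" query compares the corresponding histograms at P50 to surface the same quantity.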
| Metric | Type | Labels | Description | |--------|------|--------|-------------| | `onyx_opensearch_search_client_duration_seconds` | Histogram | `search_type` | Client-side end-to-end latency (network + serialization + server execution) | | `onyx_opensearch_search_server_duration_seconds` | Histogram | `search_type` | Server-side execution time from OpenSearch `took` field | | `onyx_opensearch_search_total` | Counter | `search_type` | Total search requests sent to OpenSearch | | `onyx_opensearch_searches_in_progress` | Gauge | `search_type` | Currently in-flight OpenSearch searches | Search type label values: See `OpenSearchSearchType`. --- ## Example PromQL Queries ### Which endpoints are saturated right now? ```promql # Top 10 endpoints by in-progress requests topk(10, http_requests_inprogress) ``` ### What's the P99 latency per endpoint? ```promql # P99 latency by handler over the last 5 minutes histogram_quantile(0.99, sum by (handler, le) (rate(http_request_duration_seconds_bucket[5m]))) ``` ### Which endpoints have the highest request rate? ```promql # Requests per second by handler, top 10 topk(10, sum by (handler) (rate(http_requests_total[5m]))) ``` ### Which endpoints are returning errors? ```promql # 5xx error rate by handler sum by (handler) (rate(http_requests_total{status=~"5.."}[5m])) ``` ### Slow request hotspots ```promql # Slow requests per minute by handler sum by (handler) (rate(onyx_api_slow_requests_total[5m])) * 60 ``` ### Latency trending up? 
```promql # Compare P50 latency now vs 1 hour ago histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m] offset 1h))) ``` ### Overall request throughput ```promql # Total requests per second across all endpoints sum(rate(http_requests_total[5m])) ``` ### Pool utilization (% of capacity in use) ```promql # Sync pool utilization: checked-out / (pool_size + max_overflow) # NOTE: Replace 10 with your actual POSTGRES_API_SERVER_POOL_OVERFLOW value. onyx_db_pool_checked_out{engine="sync"} / (onyx_db_pool_size{engine="sync"} + 10) * 100 ``` ### Pool approaching exhaustion? ```promql # Alert when checked-out connections exceed 80% of pool capacity # NOTE: Replace 10 with your actual POSTGRES_API_SERVER_POOL_OVERFLOW value. onyx_db_pool_checked_out{engine="sync"} > 0.8 * (onyx_db_pool_size{engine="sync"} + 10) ``` ### Which endpoints are hogging DB connections? ```promql # Top 10 endpoints by connections currently held topk(10, onyx_db_connections_held_by_endpoint{engine="sync"}) ``` ### Which endpoints hold connections the longest? 
```promql # P99 connection hold time by endpoint histogram_quantile(0.99, sum by (handler, le) (rate(onyx_db_connection_hold_seconds_bucket{engine="sync"}[5m]))) ``` ### Connection checkout/checkin rate ```promql # Checkouts per second by engine sum by (engine) (rate(onyx_db_pool_checkout_total[5m])) ``` ### OpenSearch P99 search latency by type ```promql # P99 client-side latency by search type histogram_quantile(0.99, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) ``` ### OpenSearch search throughput ```promql # Searches per second by type sum by (search_type) (rate(onyx_opensearch_search_total[5m])) ``` ### OpenSearch concurrent searches ```promql # Total in-flight searches across all instances sum(onyx_opensearch_searches_in_progress) ``` ### OpenSearch network overhead ```promql # Difference between client and server P50 reveals network/serialization cost. histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m]))) ``` ================================================ FILE: examples/assistants-api/topics_analyzer.py ================================================ import argparse import os import time from datetime import datetime from datetime import timedelta from datetime import timezone from openai import OpenAI ASSISTANT_NAME = "Topic Analyzer" SYSTEM_PROMPT = """ You are a helpful assistant that analyzes topics by searching through available \ documents and providing insights. These available documents come from common \ workplace tools like Slack, emails, Confluence, Google Drive, etc. When analyzing a topic: 1. Search for relevant information using the search tool 2. Synthesize the findings into clear insights 3. Highlight key trends, patterns, or notable developments 4. 
Maintain objectivity and cite sources where relevant """ USER_PROMPT = """ Please analyze and provide insights about this topic: {topic}. IMPORTANT: do not mention things that are not relevant to the specified topic. \ If there is no relevant information, just say "No relevant information found." """ def wait_on_run(client: OpenAI, run, thread): # type: ignore while run.status == "queued" or run.status == "in_progress": run = client.beta.threads.runs.retrieve( thread_id=thread.id, run_id=run.id, ) time.sleep(0.5) return run def show_response(messages) -> None: # type: ignore # Get only the assistant's response text for message in messages.data[::-1]: if message.role == "assistant": for content in message.content: if content.type == "text": print(content.text) break def analyze_topics(topics: list[str]) -> None: openai_api_key = os.environ.get( "OPENAI_API_KEY", "" ) onyx_api_key = os.environ.get( "DANSWER_API_KEY", "" ) client = OpenAI( api_key=openai_api_key, base_url="http://localhost:8080/openai-assistants", default_headers={ "Authorization": f"Bearer {onyx_api_key}", }, ) # Create an assistant if it doesn't exist try: assistants = client.beta.assistants.list(limit=100) # Find the Topic Analyzer assistant if it exists assistant = next((a for a in assistants.data if a.name == ASSISTANT_NAME)) client.beta.assistants.delete(assistant.id) except Exception: pass assistant = client.beta.assistants.create( name=ASSISTANT_NAME, instructions=SYSTEM_PROMPT, tools=[{"type": "SearchTool"}], # type: ignore model="gpt-4o", ) # Process each topic individually for topic in topics: thread = client.beta.threads.create() message = client.beta.threads.messages.create( thread_id=thread.id, role="user", content=USER_PROMPT.format(topic=topic), ) run = client.beta.threads.runs.create( thread_id=thread.id, assistant_id=assistant.id, tools=[ { # type: ignore "type": "SearchTool", "retrieval_details": { "run_search": "always", "filters": { "time_cutoff": str( datetime.now(timezone.utc) - 
timedelta(days=7) ) }, }, } ], ) run = wait_on_run(client, run, thread) messages = client.beta.threads.messages.list( thread_id=thread.id, order="asc", after=message.id ) print(f"\nAnalysis for topic: {topic}") print("-" * 40) show_response(messages) print() # Example usage if __name__ == "__main__": parser = argparse.ArgumentParser(description="Analyze specific topics") parser.add_argument("topics", nargs="+", help="Topics to analyze (one or more)") args = parser.parse_args() analyze_topics(args.topics) ================================================ FILE: examples/widget/.eslintrc.json ================================================ { "extends": "next/core-web-vitals" } ================================================ FILE: examples/widget/.gitignore ================================================ # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. # dependencies /node_modules /.pnp .pnp.js .yarn/install-state.gz # testing /coverage # next.js /.next/ /out/ # production /build # misc .DS_Store *.pem # debug npm-debug.log* yarn-debug.log* yarn-error.log* # local env files .env*.local # vercel .vercel # typescript *.tsbuildinfo next-env.d.ts ================================================ FILE: examples/widget/README.md ================================================ # Onyx Chat Bot Widget Note: The widget requires an Onyx API key, which is a paid (cloud/enterprise) feature. This is a code example showing how you can use Onyx's APIs to build a chat bot widget for a website! The main code to look at is in `src/app/widget/Widget.tsx`. ## Getting Started To get the widget working on your webpage, follow these steps: ### 1. Install Dependencies Ensure you have the necessary dependencies installed. From the `examples/widget` directory, run: ```bash npm i ``` ### 2.
Set Environment Variables Make sure to set the environment variables `NEXT_PUBLIC_API_URL` and `NEXT_PUBLIC_API_KEY` in a `.env` file at the root of your project: ```bash NEXT_PUBLIC_API_URL= NEXT_PUBLIC_API_KEY= ``` ### 3. Run the Development Server Start the development server to see the widget in action. ```bash npm run dev ``` Open [http://localhost:3000](http://localhost:3000) with your browser to see the result. ### 4. Integrate the Widget To integrate the widget into your webpage, use the `ChatWidget` component, which is a named export of `src/app/widget/Widget.tsx`. Here’s an example of how to include it in a page component: ```jsx import { ChatWidget } from "path/to/Widget"; function MyPage() { return (
    <div>
      <h1>My Webpage</h1>
      <ChatWidget />
    </div>
        ); } export default MyPage; ``` ### 5. Deploy Once you are satisfied with the widget, you can build and start the application for production: ```bash npm run build npm run start ``` ### Custom Styling and Configuration If you need to customize the widget, you can modify the `ChatWidget` component in the `examples/widget/src/app/widget/Widget.tsx` file. By following these steps, you should be able to get the chat widget working on your webpage. If you want to get fancier, then take a peek at the Chat implementation within Onyx itself [here](https://github.com/onyx-dot-app/onyx/blob/main/web/src/app/chat/ChatPage.tsx#L82). ================================================ FILE: examples/widget/next.config.mjs ================================================ /** @type {import('next').NextConfig} */ const nextConfig = {}; export default nextConfig; ================================================ FILE: examples/widget/package.json ================================================ { "name": "widget", "version": "0.1.0", "private": true, "scripts": { "dev": "next dev", "build": "next build", "start": "next start", "lint": "next lint" }, "dependencies": { "next": "^16.1.7", "react": "^19", "react-dom": "^19", "react-markdown": "^10.1.0" }, "devDependencies": { "@tailwindcss/postcss": "^4.1.18", "@types/node": "^25", "@types/react": "^19", "@types/react-dom": "^19", "autoprefixer": "^10.4.23", "eslint": "^9", "eslint-config-next": "16.1.2", "postcss": "^8.5.6", "tailwindcss": "^4.1.18", "typescript": "^5" } } ================================================ FILE: examples/widget/postcss.config.mjs ================================================ /** @type {import('postcss-load-config').Config} */ const config = { plugins: { "@tailwindcss/postcss": {}, }, }; export default config; ================================================ FILE: examples/widget/src/app/globals.css ================================================ @import "tailwindcss"; 
================================================ FILE: examples/widget/src/app/layout.tsx ================================================ import type { Metadata } from "next"; import { Inter } from "next/font/google"; import "./globals.css"; const inter = Inter({ subsets: ["latin"] }); export const metadata: Metadata = { title: "Example Onyx Widget", description: "Example Onyx Widget", }; export default function RootLayout({ children, }: Readonly<{ children: React.ReactNode; }>) { return ( {children} ); } ================================================ FILE: examples/widget/src/app/page.tsx ================================================ import { ChatWidget } from "./widget/Widget"; export default function Home() { return (
        ); } ================================================ FILE: examples/widget/src/app/widget/Widget.tsx ================================================ "use client"; import React, { useState } from "react"; import ReactMarkdown from "react-markdown"; const API_URL = process.env.NEXT_PUBLIC_API_URL || "http://localhost:8080"; const API_KEY = process.env.NEXT_PUBLIC_API_KEY || ""; type NonEmptyObject = { [k: string]: any }; const processSingleChunk = <T extends NonEmptyObject>( chunk: string, currPartialChunk: string | null, ): [T | null, string | null] => { const completeChunk = (currPartialChunk || "") + chunk; try { // every complete chunk should be valid JSON const chunkJson = JSON.parse(completeChunk); return [chunkJson, null]; } catch (err) { // if it's not valid JSON, then it's probably an incomplete chunk return [null, completeChunk]; } }; const processRawChunkString = <T extends NonEmptyObject>( rawChunkString: string, previousPartialChunk: string | null, ): [T[], string | null] => { /* This is required because, in practice, we see that nginx does not send over each chunk one at a time even with buffering turned off.
Instead, chunks are sometimes in batches or are sometimes incomplete */ if (!rawChunkString) { return [[], null]; } const chunkSections = rawChunkString .split("\n") .filter((chunk) => chunk.length > 0); let parsedChunkSections: T[] = []; let currPartialChunk = previousPartialChunk; chunkSections.forEach((chunk) => { const [processedChunk, partialChunk] = processSingleChunk<T>( chunk, currPartialChunk, ); if (processedChunk) { parsedChunkSections.push(processedChunk); currPartialChunk = null; } else { currPartialChunk = partialChunk; } }); return [parsedChunkSections, currPartialChunk]; }; async function* handleStream<T extends NonEmptyObject>( streamingResponse: Response, ): AsyncGenerator<T[], void, unknown> { const reader = streamingResponse.body?.getReader(); const decoder = new TextDecoder("utf-8"); let previousPartialChunk: string | null = null; while (true) { const rawChunk = await reader?.read(); if (!rawChunk) { throw new Error("Unable to process chunk"); } const { done, value } = rawChunk; if (done) { break; } const [completedChunks, partialChunk] = processRawChunkString<T>( decoder.decode(value, { stream: true }), previousPartialChunk, ); if (!completedChunks.length && !partialChunk) { break; } previousPartialChunk = partialChunk as string | null; yield await Promise.resolve(completedChunks); } } async function* sendMessage({ message, chatSessionId, parentMessageId, }: { message: string; chatSessionId?: number; parentMessageId?: number; }) { if (!chatSessionId || !parentMessageId) { // Create a new chat session if one doesn't exist const createSessionResponse = await fetch( `${API_URL}/chat/create-chat-session`, { method: "POST", headers: { "Content-Type": "application/json", Authorization: `Bearer ${API_KEY}`, }, body: JSON.stringify({ // or specify an assistant you have defined persona_id: 0, }), }, ); if (!createSessionResponse.ok) { const errorJson = await createSessionResponse.json(); const errorMsg = errorJson.message || errorJson.detail || ""; 
throw Error(`Failed to create chat session - 
${errorMsg}`); } const sessionData = await createSessionResponse.json(); chatSessionId = sessionData.chat_session_id; } const sendMessageResponse = await fetch(`${API_URL}/chat/send-message`, { method: "POST", headers: { "Content-Type": "application/json", Authorization: `Bearer ${API_KEY}`, }, body: JSON.stringify({ chat_session_id: chatSessionId, parent_message_id: parentMessageId || null, message: message, prompt_id: null, search_doc_ids: null, file_descriptors: [], // checkout https://github.com/onyx-dot-app/onyx/blob/main/backend/onyx/search/models.py#L105 for // all available options retrieval_options: { run_search: "always", filters: null, }, query_override: null, }), }); if (!sendMessageResponse.ok) { const errorJson = await sendMessageResponse.json(); const errorMsg = errorJson.message || errorJson.detail || ""; throw Error(`Failed to send message - ${errorMsg}`); } yield* handleStream(sendMessageResponse); } export const ChatWidget = () => { const [messages, setMessages] = useState<{ text: string; isUser: boolean }[]>( [], ); const [inputText, setInputText] = useState(""); const [isLoading, setIsLoading] = useState(false); const handleSubmit = async (e: React.FormEvent) => { e.preventDefault(); if (inputText.trim()) { const initialPrevMessages = messages; setMessages([...initialPrevMessages, { text: inputText, isUser: true }]); setInputText(""); setIsLoading(true); try { const messageGenerator = sendMessage({ message: inputText, chatSessionId: undefined, parentMessageId: undefined, }); let fullResponse = ""; for await (const chunks of messageGenerator) { for (const chunk of chunks) { if ("answer_piece" in chunk) { fullResponse += chunk.answer_piece; setMessages([ ...initialPrevMessages, { text: inputText, isUser: true }, { text: fullResponse, isUser: false }, ]); } } } } catch (error) { console.error("Error sending message:", error); setMessages((prevMessages) => [ ...prevMessages, { text: "An error occurred. 
Please try again.", isUser: false }, ]); } finally { setIsLoading(false); } } }; return (
        Chat Support
        {messages.map((message, index) => (
        {message.text}
        ))} {isLoading && (
        )}
        setInputText(e.target.value)} placeholder="Type a message..." className=" w-full p-2 pr-10 border border-gray-300 rounded-full focus:outline-none focus:ring-2 focus:ring-blue-500 focus:border-transparent " disabled={isLoading} />
        ); }; ================================================ FILE: examples/widget/tailwind.config.ts ================================================ import type { Config } from "tailwindcss"; const config: Config = { content: [ "./src/pages/**/*.{js,ts,jsx,tsx,mdx}", "./src/components/**/*.{js,ts,jsx,tsx,mdx}", "./src/app/**/*.{js,ts,jsx,tsx,mdx}", ], theme: { extend: { backgroundImage: { "gradient-radial": "radial-gradient(var(--tw-gradient-stops))", "gradient-conic": "conic-gradient(from 180deg at 50% 50%, var(--tw-gradient-stops))", }, }, }, plugins: [], }; export default config; ================================================ FILE: examples/widget/tsconfig.json ================================================ { "compilerOptions": { "lib": [ "dom", "dom.iterable", "esnext" ], "allowJs": true, "skipLibCheck": true, "strict": true, "noEmit": true, "esModuleInterop": true, "module": "esnext", "moduleResolution": "bundler", "resolveJsonModule": true, "isolatedModules": true, "jsx": "preserve", "incremental": true, "plugins": [ { "name": "next" } ], "paths": { "@/*": [ "./src/*" ] }, "target": "ES2017" }, "include": [ "next-env.d.ts", "**/*.ts", "**/*.tsx", ".next/types/**/*.ts", ".next/dev/types/**/*.ts" ], "exclude": [ "node_modules" ] } ================================================ FILE: extensions/chrome/LICENSE ================================================ MIT License Copyright (c) 2025 DanswerAI, Inc. Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. 
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. ================================================ FILE: extensions/chrome/README.md ================================================ # Onyx Chrome Extension The Onyx Chrome extension lets you research, create, and automate with LLMs powered by your team's unique knowledge. Just hit Ctrl + O on Mac or Alt + O on Windows to instantly access Onyx in your browser: 💡 Know what your company knows, instantly, with the Onyx sidebar 💬 Chat: Onyx provides a natural language chat interface as the main way of interacting with its features. 🌎 Internal Search: Ask questions and get answers from all your team's knowledge, powered by Onyx's 50+ connectors to the tools your team uses. 🚀 With a simple Ctrl + O on Mac or Alt + O on Windows, instantly summarize information from any work application. ⚡️ Get quick access to the work resources you need. 🆕 The Onyx new tab page puts all of your company’s knowledge at your fingertips. 🤖 Access custom AI Agents for unique use cases, and give them access to tools to take action. Onyx connects with dozens of popular workplace apps like Google Drive, Jira, Confluence, Slack, and more. Use this extension if you have an account created by your team admin.
## Installation For Onyx Cloud users, the extension will be available on the Chrome Web Store (approval is still pending). ## Development - Load the unpacked extension in your browser - Modify files in the `src` directory - Refresh the extension in Chrome ## Contributing Submit issues or pull requests for improvements. ================================================ FILE: extensions/chrome/manifest.json ================================================ { "manifest_version": 3, "name": "Onyx", "version": "1.1", "description": "Onyx lets you research, create, and automate with LLMs powered by your team's unique knowledge", "permissions": [ "sidePanel", "storage", "activeTab", "tabs" ], "host_permissions": ["<all_urls>"], "background": { "service_worker": "service_worker.js", "type": "module" }, "action": { "default_icon": { "16": "public/icon16.png", "48": "public/icon48.png", "128": "public/icon128.png" }, "default_popup": "src/pages/popup.html" }, "icons": { "16": "public/icon16.png", "48": "public/icon48.png", "128": "public/icon128.png" }, "options_page": "src/pages/options.html", "chrome_url_overrides": { "newtab": "src/pages/onyx_home.html" }, "commands": { "toggleNewTabOverride": { "suggested_key": { "default": "Ctrl+Shift+O", "mac": "Command+Shift+O" }, "description": "Toggle Onyx New Tab Override" }, "openSidePanel": { "suggested_key": { "default": "Ctrl+O", "windows": "Alt+O", "mac": "MacCtrl+O" }, "description": "Open Onyx Side Panel" } }, "side_panel": { "default_path": "src/pages/panel.html" }, "omnibox": { "keyword": "onyx" }, "content_scripts": [ { "matches": ["<all_urls>"], "js": ["src/utils/selection-icon.js"], "css": ["src/styles/selection-icon.css"] } ], "web_accessible_resources": [ { "resources": ["public/icon32.png"], "matches": ["<all_urls>"] } ] } ================================================ FILE: extensions/chrome/service_worker.js ================================================ import { DEFAULT_ONYX_DOMAIN, CHROME_SPECIFIC_STORAGE_KEYS, ACTIONS, SIDE_PANEL_PATH, } from "./src/utils/constants.js"; // 
Track side panel state per window const sidePanelOpenState = new Map(); // Open welcome page on first install chrome.runtime.onInstalled.addListener((details) => { if (details.reason === "install") { chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONBOARDING_COMPLETE]: false }, (result) => { if (!result[CHROME_SPECIFIC_STORAGE_KEYS.ONBOARDING_COMPLETE]) { chrome.tabs.create({ url: "src/pages/welcome.html" }); } }, ); } }); async function setupSidePanel() { if (chrome.sidePanel) { try { // Don't auto-open side panel on action click since we have a popup menu await chrome.sidePanel.setPanelBehavior({ openPanelOnActionClick: false, }); } catch (error) { console.error("Error setting up side panel:", error); } } } async function openSidePanel(tabId) { try { await chrome.sidePanel.open({ tabId }); } catch (error) { console.error("Error opening side panel:", error); } } function encodeUserPrompt(text) { return encodeURIComponent(text).replace(/\(/g, "%28").replace(/\)/g, "%29"); } async function sendToOnyx(info, tab) { const selectedText = encodeUserPrompt(info.selectionText); const currentUrl = encodeURIComponent(tab.url); try { const result = await chrome.storage.local.get({ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN, }); const url = `${ result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN] }${SIDE_PANEL_PATH}?user-prompt=${selectedText}`; await openSidePanel(tab.id); chrome.runtime.sendMessage({ action: ACTIONS.OPEN_SIDE_PANEL_WITH_INPUT, url: url, pageUrl: tab.url, }); } catch (error) { console.error("Error sending to Onyx:", error); } } async function toggleNewTabOverride() { try { const result = await chrome.storage.local.get( CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB, ); const newValue = !result[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]; await chrome.storage.local.set({ [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: newValue, }); chrome.notifications.create({ type: "basic", iconUrl: "icon.png", 
title: "Onyx New Tab", message: `New Tab Override ${newValue ? "enabled" : "disabled"}`, }); // Send a message to inform all tabs about the change chrome.tabs.query({}, (tabs) => { tabs.forEach((tab) => { chrome.tabs.sendMessage(tab.id, { action: "newTabOverrideToggled", value: newValue, }); }); }); } catch (error) { console.error("Error toggling new tab override:", error); } } // Note: This listener won't fire when a popup is defined in manifest.json // The popup will show instead. This is kept as a fallback if popup is removed. chrome.action.onClicked.addListener((tab) => { openSidePanel(tab.id); }); chrome.commands.onCommand.addListener(async (command) => { if (command === ACTIONS.SEND_TO_ONYX) { try { const [tab] = await chrome.tabs.query({ active: true, lastFocusedWindow: true, }); if (tab) { const response = await chrome.tabs.sendMessage(tab.id, { action: ACTIONS.GET_SELECTED_TEXT, }); const selectedText = response?.selectedText || ""; sendToOnyx({ selectionText: selectedText }, tab); } } catch (error) { console.error("Error sending to Onyx:", error); } } else if (command === ACTIONS.TOGGLE_NEW_TAB_OVERRIDE) { toggleNewTabOverride(); } else if (command === ACTIONS.CLOSE_SIDE_PANEL) { try { await chrome.sidePanel.hide(); } catch (error) { console.error("Error closing side panel via command:", error); } } else if (command === ACTIONS.OPEN_SIDE_PANEL) { chrome.tabs.query({ active: true, lastFocusedWindow: true }, (tabs) => { if (tabs && tabs.length > 0) { const tab = tabs[0]; const windowId = tab.windowId; const isOpen = sidePanelOpenState.get(windowId) || false; if (isOpen) { chrome.sidePanel.setOptions({ enabled: false }, () => { chrome.sidePanel.setOptions({ enabled: true }); sidePanelOpenState.set(windowId, false); }); } else { chrome.sidePanel.open({ tabId: tab.id }); sidePanelOpenState.set(windowId, true); } } }); return; } else { console.log("Unhandled command:", command); } }); async function sendActiveTabUrlToPanel() { try { const [tab] = await 
chrome.tabs.query({ active: true, lastFocusedWindow: true, }); if (tab?.url) { chrome.runtime.sendMessage({ action: ACTIONS.TAB_URL_UPDATED, url: tab.url, }); } } catch (error) { console.error("[Onyx SW] Error sending tab URL:", error); } } chrome.runtime.onMessage.addListener((request, sender, sendResponse) => { if (request.action === ACTIONS.GET_CURRENT_ONYX_DOMAIN) { chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN }, (result) => { sendResponse({ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN], }); }, ); return true; } if (request.action === ACTIONS.CLOSE_SIDE_PANEL) { closeSidePanel(); chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN }, (result) => { chrome.tabs.create({ url: `${result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]}/auth/login`, active: true, }); }, ); return true; } if (request.action === ACTIONS.OPEN_SIDE_PANEL_WITH_INPUT) { const { selectedText, pageUrl } = request; const tabId = sender.tab?.id; const windowId = sender.tab?.windowId; if (tabId && windowId) { chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN }, (result) => { const encodedText = encodeUserPrompt(selectedText); const onyxDomain = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; const url = `${onyxDomain}${SIDE_PANEL_PATH}?user-prompt=${encodedText}`; chrome.storage.session.set({ pendingInput: { url: url, pageUrl: pageUrl, timestamp: Date.now(), }, }); chrome.sidePanel .open({ windowId }) .then(() => { chrome.runtime.sendMessage({ action: ACTIONS.OPEN_ONYX_WITH_INPUT, url: url, pageUrl: pageUrl, }); }) .catch((error) => { console.error( "[Onyx SW] Error opening side panel with text:", error, ); }); }, ); } else { console.error("[Onyx SW] Missing tabId or windowId"); } return true; } if (request.action === ACTIONS.TAB_READING_ENABLED) { chrome.storage.session.set({ tabReadingEnabled: true }); 
sendActiveTabUrlToPanel(); return false; } if (request.action === ACTIONS.TAB_READING_DISABLED) { chrome.storage.session.set({ tabReadingEnabled: false }); return false; } }); chrome.storage.onChanged.addListener((changes, namespace) => { if ( namespace === "local" && changes[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB] ) { const newValue = changes[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB] .newValue; if (newValue === false) { chrome.runtime.openOptionsPage(); } } }); chrome.windows.onRemoved.addListener((windowId) => { sidePanelOpenState.delete(windowId); }); chrome.omnibox.setDefaultSuggestion({ description: 'Search Onyx for "%s"', }); chrome.omnibox.onInputEntered.addListener(async (text) => { try { const result = await chrome.storage.local.get({ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN, }); const domain = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; const searchUrl = `${domain}/chat?user-prompt=${encodeURIComponent(text)}`; chrome.tabs.update({ url: searchUrl }); } catch (error) { console.error("Error handling omnibox search:", error); } }); chrome.omnibox.onInputChanged.addListener((text, suggest) => { if (text.trim()) { suggest([ { content: text, description: `Search Onyx for "${text}"`, }, ]); } }); chrome.tabs.onActivated.addListener(async (activeInfo) => { const result = await chrome.storage.session.get({ tabReadingEnabled: false }); if (!result.tabReadingEnabled) return; try { const tab = await chrome.tabs.get(activeInfo.tabId); if (tab.url) { chrome.runtime.sendMessage({ action: ACTIONS.TAB_URL_UPDATED, url: tab.url, }); } } catch (error) { console.error("[Onyx SW] Error on tab activated:", error); } }); chrome.tabs.onUpdated.addListener(async (tabId, changeInfo, tab) => { if (!changeInfo.url) return; const result = await chrome.storage.session.get({ tabReadingEnabled: false }); if (!result.tabReadingEnabled) return; try { const [activeTab] = await chrome.tabs.query({ active: true, 
lastFocusedWindow: true, }); if (activeTab?.id === tabId) { chrome.runtime.sendMessage({ action: ACTIONS.TAB_URL_UPDATED, url: changeInfo.url, }); } } catch (error) { console.error("[Onyx SW] Error on tab updated:", error); } }); setupSidePanel(); ================================================ FILE: extensions/chrome/src/pages/onyx_home.html ================================================ Onyx Home
        ================================================ FILE: extensions/chrome/src/pages/onyx_home.js ================================================ import { CHROME_MESSAGE, CHROME_SPECIFIC_STORAGE_KEYS, WEB_MESSAGE, } from "../utils/constants.js"; import { showErrorModal, hideErrorModal, initErrorModal, } from "../utils/error-modal.js"; import { getOnyxDomain } from "../utils/storage.js"; (function () { let mainIframe = document.getElementById("onyx-iframe"); let preloadedIframe = null; const background = document.getElementById("background"); const content = document.getElementById("content"); const DEFAULT_LIGHT_BACKGROUND_IMAGE = "https://images.unsplash.com/photo-1692520883599-d543cfe6d43d?q=80&w=2666&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"; const DEFAULT_DARK_BACKGROUND_IMAGE = "https://images.unsplash.com/photo-1692520883599-d543cfe6d43d?q=80&w=2666&auto=format&fit=crop&ixlib=rb-4.0.3&ixid=M3wxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8fA%3D%3D"; let iframeLoadTimeout; let iframeLoaded = false; initErrorModal(); async function preloadChatInterface() { preloadedIframe = document.createElement("iframe"); const domain = await getOnyxDomain(); preloadedIframe.src = domain + "/chat"; preloadedIframe.style.opacity = "0"; preloadedIframe.style.visibility = "hidden"; preloadedIframe.style.transition = "opacity 0.3s ease-in"; preloadedIframe.style.border = "none"; preloadedIframe.style.width = "100%"; preloadedIframe.style.height = "100%"; preloadedIframe.style.position = "absolute"; preloadedIframe.style.top = "0"; preloadedIframe.style.left = "0"; preloadedIframe.style.zIndex = "1"; content.appendChild(preloadedIframe); } function setIframeSrc(url) { mainIframe.src = url; startIframeLoadTimeout(); iframeLoaded = false; } function startIframeLoadTimeout() { clearTimeout(iframeLoadTimeout); iframeLoadTimeout = setTimeout(() => { if (!iframeLoaded) { try { if ( 
mainIframe.contentWindow.location.pathname.includes("/auth/login") ) { showLoginPage(); } else { showErrorModal(mainIframe.src); } } catch (error) { showErrorModal(mainIframe.src); } } }, 2500); } function showLoginPage() { background.style.opacity = "0"; mainIframe.style.opacity = "1"; mainIframe.style.visibility = "visible"; content.style.opacity = "1"; hideErrorModal(); } function setTheme(theme, customBackgroundImage) { const imageUrl = customBackgroundImage || (theme === "dark" ? DEFAULT_DARK_BACKGROUND_IMAGE : DEFAULT_LIGHT_BACKGROUND_IMAGE); background.style.backgroundImage = `url('${imageUrl}')`; } function fadeInContent() { content.style.transition = "opacity 0.5s ease-in"; mainIframe.style.transition = "opacity 0.5s ease-in"; content.style.opacity = "0"; mainIframe.style.opacity = "0"; mainIframe.style.visibility = "visible"; requestAnimationFrame(() => { content.style.opacity = "1"; mainIframe.style.opacity = "1"; setTimeout(() => { background.style.transition = "opacity 0.3s ease-out"; background.style.opacity = "0"; }, 500); }); } function checkOnyxPreference() { chrome.storage.local.get( [ CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB, CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN, ], (items) => { let useOnyxAsDefaultNewTab = items[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]; if (useOnyxAsDefaultNewTab === undefined) { useOnyxAsDefaultNewTab = !!( localStorage.getItem( CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB, ) === "1" ); chrome.storage.local.set({ [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: useOnyxAsDefaultNewTab, }); } if (!useOnyxAsDefaultNewTab) { chrome.tabs.update({ url: "chrome://new-tab-page", }); return; } setIframeSrc(items[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN] + "/nrf"); }, ); } function loadThemeAndBackground() { chrome.storage.local.get( [ CHROME_SPECIFIC_STORAGE_KEYS.THEME, CHROME_SPECIFIC_STORAGE_KEYS.BACKGROUND_IMAGE, CHROME_SPECIFIC_STORAGE_KEYS.DARK_BG_URL, 
CHROME_SPECIFIC_STORAGE_KEYS.LIGHT_BG_URL, ], function (result) { const theme = result[CHROME_SPECIFIC_STORAGE_KEYS.THEME] || "light"; const customBackgroundImage = result[CHROME_SPECIFIC_STORAGE_KEYS.BACKGROUND_IMAGE]; const darkBgUrl = result[CHROME_SPECIFIC_STORAGE_KEYS.DARK_BG_URL]; const lightBgUrl = result[CHROME_SPECIFIC_STORAGE_KEYS.LIGHT_BG_URL]; let backgroundImage; if (customBackgroundImage) { backgroundImage = customBackgroundImage; } else if (theme === "dark" && darkBgUrl) { backgroundImage = darkBgUrl; } else if (theme === "light" && lightBgUrl) { backgroundImage = lightBgUrl; } setTheme(theme, backgroundImage); checkOnyxPreference(); }, ); } function loadNewPage(newSrc) { if (preloadedIframe && preloadedIframe.contentWindow) { preloadedIframe.contentWindow.postMessage( { type: WEB_MESSAGE.PAGE_CHANGE, href: newSrc }, "*", ); } else { console.error("Preloaded iframe not available"); } } function completePendingPageLoad() { if (preloadedIframe) { preloadedIframe.style.visibility = "visible"; preloadedIframe.style.opacity = "1"; preloadedIframe.style.zIndex = "1"; mainIframe.style.zIndex = "2"; mainIframe.style.opacity = "0"; setTimeout(() => { if (content.contains(mainIframe)) { content.removeChild(mainIframe); } mainIframe = preloadedIframe; mainIframe.id = "onyx-iframe"; mainIframe.style.zIndex = ""; iframeLoaded = true; clearTimeout(iframeLoadTimeout); }, 200); } else { console.warn("No preloaded iframe available"); } } chrome.storage.onChanged.addListener(function (changes, namespace) { if (namespace === "local" && changes.useOnyxAsDefaultNewTab) { checkOnyxPreference(); } }); window.addEventListener("message", function (event) { if (event.data.type === CHROME_MESSAGE.SET_DEFAULT_NEW_TAB) { chrome.storage.local.set({ useOnyxAsDefaultNewTab: event.data.value }); } else if (event.data.type === CHROME_MESSAGE.ONYX_APP_LOADED) { clearTimeout(iframeLoadTimeout); hideErrorModal(); fadeInContent(); iframeLoaded = true; } else if (event.data.type === 
CHROME_MESSAGE.PREFERENCES_UPDATED) { const { theme, backgroundUrl } = event.data.payload; chrome.storage.local.set( { [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: theme, [CHROME_SPECIFIC_STORAGE_KEYS.BACKGROUND_IMAGE]: backgroundUrl, }, () => {}, ); } else if (event.data.type === CHROME_MESSAGE.LOAD_NEW_PAGE) { loadNewPage(event.data.href); } else if (event.data.type === CHROME_MESSAGE.LOAD_NEW_CHAT_PAGE) { completePendingPageLoad(); } }); mainIframe.onload = function () { clearTimeout(iframeLoadTimeout); startIframeLoadTimeout(); }; mainIframe.onerror = function (error) { showErrorModal(mainIframe.src); }; loadThemeAndBackground(); preloadChatInterface(); })(); ================================================ FILE: extensions/chrome/src/pages/options.html ================================================ Onyx - Settings
        Onyx

        Settings

        General
        The root URL for your Onyx instance
        Search Engine
        Type onyx followed by a space in Chrome's address bar, then enter your search query and press Enter
        Searches will be directed to your configured Onyx instance at the Root Domain above

        ================================================ FILE: extensions/chrome/src/pages/options.js ================================================ import { CHROME_SPECIFIC_STORAGE_KEYS, DEFAULT_ONYX_DOMAIN, } from "../utils/constants.js"; document.addEventListener("DOMContentLoaded", function () { const domainInput = document.getElementById("onyxDomain"); const useOnyxAsDefaultToggle = document.getElementById("useOnyxAsDefault"); const statusContainer = document.getElementById("statusContainer"); const statusElement = document.getElementById("status"); const newTabButton = document.getElementById("newTab"); const themeToggle = document.getElementById("themeToggle"); const themeIcon = document.getElementById("themeIcon"); let currentTheme = "dark"; function updateThemeIcon(theme) { if (!themeIcon) return; if (theme === "light") { themeIcon.innerHTML = ` `; } else { themeIcon.innerHTML = ` `; } } function loadStoredValues() { chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN, [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: false, [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: "dark", }, (result) => { if (domainInput) domainInput.value = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; if (useOnyxAsDefaultToggle) useOnyxAsDefaultToggle.checked = result[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]; currentTheme = result[CHROME_SPECIFIC_STORAGE_KEYS.THEME] || "dark"; updateThemeIcon(currentTheme); document.body.className = currentTheme === "light" ? "light-theme" : ""; }, ); } function saveSettings() { const domain = domainInput.value.trim(); const useOnyxAsDefault = useOnyxAsDefaultToggle ? useOnyxAsDefaultToggle.checked : false; chrome.storage.local.set( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: domain, [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: useOnyxAsDefault, [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: currentTheme, }, () => { showStatusMessage( useOnyxAsDefault ? "Settings updated. 
Open a new tab to test it out. Click on the extension icon to bring up Onyx from any page." : "Settings updated.", ); }, ); } function showStatusMessage(message) { if (statusElement) { const useOnyxAsDefault = useOnyxAsDefaultToggle ? useOnyxAsDefaultToggle.checked : false; statusElement.textContent = message || (useOnyxAsDefault ? "Settings updated. Open a new tab to test it out. Click on the extension icon to bring up Onyx from any page." : "Settings updated."); if (newTabButton) { newTabButton.style.display = useOnyxAsDefault ? "block" : "none"; } } if (statusContainer) { statusContainer.classList.add("show"); } setTimeout(hideStatusMessage, 5000); } function hideStatusMessage() { if (statusContainer) { statusContainer.classList.remove("show"); } } function toggleTheme() { currentTheme = currentTheme === "light" ? "dark" : "light"; updateThemeIcon(currentTheme); document.body.className = currentTheme === "light" ? "light-theme" : ""; chrome.storage.local.set({ [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: currentTheme, }); } function openNewTab() { chrome.tabs.create({}); } if (domainInput) { domainInput.addEventListener("input", () => { clearTimeout(domainInput.saveTimeout); domainInput.saveTimeout = setTimeout(saveSettings, 1000); }); } if (useOnyxAsDefaultToggle) { useOnyxAsDefaultToggle.addEventListener("change", saveSettings); } if (themeToggle) { themeToggle.addEventListener("click", toggleTheme); } if (newTabButton) { newTabButton.addEventListener("click", openNewTab); } loadStoredValues(); }); ================================================ FILE: extensions/chrome/src/pages/panel.html ================================================ Onyx Panel
        Loading Onyx...
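The panel script pins both directions of `postMessage` to the origin derived from `iframe.src`. On the receive side it rejects any message whose `source` is not the iframe's window or whose `origin` differs from that expected origin, and it fails closed if `iframe.src` cannot be parsed. The sketch below restates that receive-side check as a pure function so it can run without a browser. `isTrustedPanelMessage` and the plain-object stand-ins for windows and events are hypothetical.

```javascript
// Hypothetical helper (not in the extension) restating the fail-closed check
// that panel.js performs in handleMessage().
function isTrustedPanelMessage(event, iframeSrc, iframeWindow) {
  // Only the embedded Onyx frame may send privileged messages.
  if (event.source !== iframeWindow) return false;
  let expectedOrigin;
  try {
    // new URL() throws on an empty or malformed src; treat that as untrusted.
    expectedOrigin = new URL(iframeSrc).origin;
  } catch {
    return false;
  }
  // A cross-origin page navigated inside the iframe reports a different origin.
  return event.origin === expectedOrigin;
}

// Usage with plain objects standing in for window references and events:
const frame = {};
const src = "https://onyx.example.com/nrf/side-panel";
console.log(isTrustedPanelMessage({ source: frame, origin: "https://onyx.example.com" }, src, frame)); // true
console.log(isTrustedPanelMessage({ source: frame, origin: "https://evil.example" }, src, frame)); // false
console.log(isTrustedPanelMessage({ source: {}, origin: "https://onyx.example.com" }, src, frame)); // false
console.log(isTrustedPanelMessage({ source: frame, origin: "https://onyx.example.com" }, "", frame)); // false
```

Both checks are needed. A `source` check on its own accepts messages from whatever page the iframe has navigated to. An `origin` check on its own accepts messages from any other frame served by the same origin.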
        ================================================ FILE: extensions/chrome/src/pages/panel.js ================================================ import { showErrorModal, showAuthModal } from "../utils/error-modal.js"; import { ACTIONS, CHROME_MESSAGE, WEB_MESSAGE, CHROME_SPECIFIC_STORAGE_KEYS, SIDE_PANEL_PATH, } from "../utils/constants.js"; import { getOnyxDomain } from "../utils/storage.js"; (function () { const iframe = document.getElementById("onyx-panel-iframe"); const loadingScreen = document.getElementById("loading-screen"); let currentUrl = ""; let iframeLoaded = false; let iframeLoadTimeout; let authRequired = false; // Returns the origin of the Onyx app loaded in the iframe. // We derive the origin from iframe.src so postMessage payloads // (including tab URLs) are only delivered to the expected page. // Throws if iframe.src is not a valid URL; this is intentional: // postMessage must never fall back to the unsafe wildcard "*". function getIframeOrigin() { return new URL(iframe.src).origin; } async function checkPendingInput() { try { const result = await chrome.storage.session.get("pendingInput"); if (result.pendingInput) { const { url, pageUrl, timestamp } = result.pendingInput; if (Date.now() - timestamp < 5000) { setIframeSrc(url, pageUrl); await chrome.storage.session.remove("pendingInput"); return true; } await chrome.storage.session.remove("pendingInput"); } } catch (error) { console.error("[Onyx Panel] Error checking pending input:", error); } return false; } async function initializePanel() { loadingScreen.style.display = "flex"; loadingScreen.style.opacity = "1"; iframe.style.opacity = "0"; // Check for pending input first (from selection icon click) const hasPendingInput = await checkPendingInput(); if (!hasPendingInput) { loadOnyxDomain(); } } function setIframeSrc(url, pageUrl) { iframe.src = url; currentUrl = pageUrl; } function sendWebsiteToIframe(pageUrl) { if (iframe.contentWindow && pageUrl !== currentUrl) { iframe.contentWindow.postMessage( { type: WEB_MESSAGE.PAGE_CHANGE, url:
pageUrl, }, getIframeOrigin(), ); currentUrl = pageUrl; } } function startIframeLoadTimeout() { iframeLoadTimeout = setTimeout(() => { if (!iframeLoaded) { if (authRequired) { showAuthModal(); } else { showErrorModal(iframe.src); } } }, 2500); } function handleMessage(event) { // Only trust messages from the Onyx app iframe. // Check both source identity and origin so that a cross-origin page // navigated to inside the iframe cannot send privileged extension // messages (e.g. TAB_READING_ENABLED) after iframe.src changes. // getIframeOrigin() throws if iframe.src is not yet a valid URL — // catching it here fails closed (message is rejected, not processed). if (event.source !== iframe.contentWindow) return; try { if (event.origin !== getIframeOrigin()) return; } catch { return; } if (event.data.type === CHROME_MESSAGE.ONYX_APP_LOADED) { clearTimeout(iframeLoadTimeout); iframeLoaded = true; showIframe(); if (iframe.contentWindow) { iframe.contentWindow.postMessage( { type: "PANEL_READY" }, getIframeOrigin(), ); } } else if (event.data.type === CHROME_MESSAGE.AUTH_REQUIRED) { authRequired = true; } else if (event.data.type === CHROME_MESSAGE.TAB_READING_ENABLED) { chrome.runtime.sendMessage({ action: ACTIONS.TAB_READING_ENABLED }); } else if (event.data.type === CHROME_MESSAGE.TAB_READING_DISABLED) { chrome.runtime.sendMessage({ action: ACTIONS.TAB_READING_DISABLED }); } } function showIframe() { iframe.style.opacity = "1"; loadingScreen.style.opacity = "0"; setTimeout(() => { loadingScreen.style.display = "none"; }, 500); } async function loadOnyxDomain() { const response = await chrome.runtime.sendMessage({ action: ACTIONS.GET_CURRENT_ONYX_DOMAIN, }); if (response && response[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]) { setIframeSrc( response[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN] + SIDE_PANEL_PATH, "", ); } else { console.warn("Onyx domain not found, using default"); const domain = await getOnyxDomain(); setIframeSrc(domain + SIDE_PANEL_PATH, ""); } } 
chrome.runtime.onMessage.addListener((request, sender, sendResponse) => { if (request.action === ACTIONS.OPEN_ONYX_WITH_INPUT) { setIframeSrc(request.url, request.pageUrl); } else if (request.action === ACTIONS.UPDATE_PAGE_URL) { sendWebsiteToIframe(request.pageUrl); } else if (request.action === ACTIONS.TAB_URL_UPDATED) { if (iframe.contentWindow) { iframe.contentWindow.postMessage( { type: CHROME_MESSAGE.TAB_URL_UPDATED, url: request.url }, getIframeOrigin(), ); } } }); window.addEventListener("message", handleMessage); initializePanel(); startIframeLoadTimeout(); })(); ================================================ FILE: extensions/chrome/src/pages/popup.html ================================================ Onyx ================================================ FILE: extensions/chrome/src/pages/popup.js ================================================ import { CHROME_SPECIFIC_STORAGE_KEYS } from "../utils/constants.js"; document.addEventListener("DOMContentLoaded", async function () { const defaultNewTabToggle = document.getElementById("defaultNewTabToggle"); const openSidePanelButton = document.getElementById("openSidePanel"); const openOptionsButton = document.getElementById("openOptions"); async function loadSetting() { const result = await chrome.storage.local.get({ [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: false, }); if (defaultNewTabToggle) { defaultNewTabToggle.checked = result[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]; } } async function toggleSetting() { const currentValue = defaultNewTabToggle.checked; await chrome.storage.local.set({ [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: currentValue, }); } async function openSidePanel() { try { const [tab] = await chrome.tabs.query({ active: true, currentWindow: true, }); if (tab && chrome.sidePanel) { await chrome.sidePanel.open({ tabId: tab.id }); window.close(); } } catch (error) { console.error("Error opening side panel:", error); } } function openOptions() 
{ chrome.runtime.openOptionsPage(); window.close(); } await loadSetting(); if (defaultNewTabToggle) { defaultNewTabToggle.addEventListener("change", toggleSetting); } if (openSidePanelButton) { openSidePanelButton.addEventListener("click", openSidePanel); } if (openOptionsButton) { openOptionsButton.addEventListener("click", openOptions); } }); ================================================ FILE: extensions/chrome/src/pages/welcome.html ================================================ Welcome to Onyx
        Onyx

        Onyx

        Welcome to Onyx

        Enter your Onyx instance URL to get started. This is where your Onyx deployment is hosted.

        Customize Your Experience

        Set Onyx as your new tab page for quick access to your AI assistant.

        Use Onyx as new tab page Open Onyx every time you create a new tab
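The onboarding step asks for the Onyx instance URL. `validateDomain()` in welcome.js accepts any string that `new URL()` can parse, so a value like `localhost:3000` passes because it parses with the scheme `localhost:`. It then yields an unusable iframe `src`. The sketch below is a stricter alternative that also requires an `http:` or `https:` scheme. `isValidOnyxDomain` is hypothetical, and the extension does not currently use this check.

```javascript
// Hypothetical stricter variant of welcome.js's validateDomain(): accept only
// absolute http(s) URLs, since the value later becomes an iframe src.
function isValidOnyxDomain(input) {
  const value = (input || "").trim();
  if (!value) return false;
  let url;
  try {
    url = new URL(value);
  } catch {
    return false; // not parseable as an absolute URL
  }
  // new URL("localhost:3000") parses with protocol "localhost:", so the
  // explicit scheme check is what rejects it.
  return url.protocol === "http:" || url.protocol === "https:";
}

console.log(isValidOnyxDomain("https://onyx.example.com")); // true
console.log(isValidOnyxDomain("http://localhost:3000")); // true
console.log(isValidOnyxDomain("localhost:3000")); // false
console.log(isValidOnyxDomain("not a url")); // false
```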
        ================================================ FILE: extensions/chrome/src/pages/welcome.js ================================================ import { CHROME_SPECIFIC_STORAGE_KEYS, DEFAULT_ONYX_DOMAIN, } from "../utils/constants.js"; document.addEventListener("DOMContentLoaded", function () { const domainInput = document.getElementById("onyxDomain"); const useOnyxAsDefaultToggle = document.getElementById("useOnyxAsDefault"); const continueBtn = document.getElementById("continueBtn"); const backBtn = document.getElementById("backBtn"); const finishBtn = document.getElementById("finishBtn"); const themeToggle = document.getElementById("themeToggle"); const themeIcon = document.getElementById("themeIcon"); const step1 = document.getElementById("step1"); const step2 = document.getElementById("step2"); const stepDots = document.querySelectorAll(".step-dot"); let currentStep = 1; let currentTheme = "dark"; // Initialize theme based on system preference or stored value function initTheme() { chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: null }, (result) => { const storedTheme = result[CHROME_SPECIFIC_STORAGE_KEYS.THEME]; if (storedTheme) { currentTheme = storedTheme; } else { // Check system preference currentTheme = window.matchMedia("(prefers-color-scheme: light)") .matches ? "light" : "dark"; } applyTheme(); }, ); } function applyTheme() { document.body.className = currentTheme === "light" ? "light-theme" : ""; updateThemeIcon(); } function updateThemeIcon() { if (!themeIcon) return; if (currentTheme === "light") { themeIcon.innerHTML = ` `; } else { themeIcon.innerHTML = ` `; } } function toggleTheme() { currentTheme = currentTheme === "light" ? 
"dark" : "light"; applyTheme(); chrome.storage.local.set({ [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: currentTheme, }); } function goToStep(step) { if (step === 1) { step2.classList.remove("active"); setTimeout(() => { step1.classList.add("active"); }, 50); } else if (step === 2) { step1.classList.remove("active"); setTimeout(() => { step2.classList.add("active"); }, 50); } stepDots.forEach((dot) => { const dotStep = parseInt(dot.dataset.step); if (dotStep === step) { dot.classList.add("active"); } else { dot.classList.remove("active"); } }); currentStep = step; } // Validate domain input function validateDomain(domain) { if (!domain) return false; try { new URL(domain); return true; } catch { return false; } } function handleContinue() { const domain = domainInput.value.trim(); if (domain && !validateDomain(domain)) { domainInput.style.borderColor = "rgba(255, 100, 100, 0.5)"; domainInput.focus(); return; } domainInput.style.borderColor = ""; goToStep(2); } function handleBack() { goToStep(1); } function handleFinish() { const domain = domainInput.value.trim() || DEFAULT_ONYX_DOMAIN; const useOnyxAsDefault = useOnyxAsDefaultToggle.checked; chrome.storage.local.set( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: domain, [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: useOnyxAsDefault, [CHROME_SPECIFIC_STORAGE_KEYS.THEME]: currentTheme, [CHROME_SPECIFIC_STORAGE_KEYS.ONBOARDING_COMPLETE]: true, }, () => { // Open a new tab if they enabled the new tab feature, otherwise just close if (useOnyxAsDefault) { chrome.tabs.create({}, () => { window.close(); }); } else { window.close(); } }, ); } // Load any existing values (in case user returns to this page) function loadStoredValues() { chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: "", [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: true, }, (result) => { if (result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]) { domainInput.value = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; } 
useOnyxAsDefaultToggle.checked = result[CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]; }, ); } if (themeToggle) { themeToggle.addEventListener("click", toggleTheme); } if (continueBtn) { continueBtn.addEventListener("click", handleContinue); } if (backBtn) { backBtn.addEventListener("click", handleBack); } if (finishBtn) { finishBtn.addEventListener("click", handleFinish); } // Allow Enter key to proceed if (domainInput) { domainInput.addEventListener("keydown", (e) => { if (e.key === "Enter") { handleContinue(); } }); } initTheme(); loadStoredValues(); }); ================================================ FILE: extensions/chrome/src/styles/selection-icon.css ================================================ #onyx-selection-icon { position: fixed; z-index: 2147483647; width: 32px; height: 32px; border-radius: 50%; background-color: #ffffff; border: 1px solid #e0e0e0; box-shadow: 0 2px 8px rgba(0, 0, 0, 0.15); cursor: pointer; display: flex; align-items: center; justify-content: center; opacity: 0; transform: scale(0.8); transition: opacity 0.15s ease, transform 0.15s ease, box-shadow 0.15s ease; pointer-events: none; } #onyx-selection-icon.visible { opacity: 1; transform: scale(1); pointer-events: auto; } #onyx-selection-icon:hover { box-shadow: 0 4px 12px rgba(0, 0, 0, 0.2); transform: scale(1.1); } #onyx-selection-icon:active { transform: scale(0.95); } #onyx-selection-icon img { width: 20px; height: 20px; pointer-events: none; } ================================================ FILE: extensions/chrome/src/styles/shared.css ================================================ /* Import Hanken Grotesk font */ @import url("https://fonts.googleapis.com/css2?family=Hanken+Grotesk:wght@300;400;500;600;700&display=swap"); :root { --primary-color: #4285f4; --primary-hover-color: #3367d6; --secondary-color: #f1f3f4; --secondary-hover-color: #e8eaed; --text-color: #333; --text-light-color: #666; --background-color: #f1f3f4; --card-background-color: #fff; 
--border-color: #ccc; --font-family: Arial, sans-serif; --font-hanken-grotesk: "Hanken Grotesk", sans-serif; } body { font-family: var(--font-hanken-grotesk); margin: 0; padding: 0; } .container { max-width: 500px; width: 90%; margin: 0 auto; } .card { background-color: var(--card-background-color); padding: 25px; border-radius: 10px; box-shadow: 0 3px 5px rgba(0, 0, 0, 0.1); } h1 { color: var(--text-color); font-size: 24px; font-weight: 600; margin-top: 0; margin-bottom: 20px; } .option-group { margin-bottom: 20px; } label { display: block; margin-bottom: 5px; color: var(--text-light-color); font-weight: 400; font-size: 16px; } input[type="text"] { width: 100%; padding: 8px; border: 1px solid var(--border-color); border-radius: 4px; font-size: 14px; background-color: var(--card-background-color); color: var(--text-color); } .button { width: 100%; padding: 10px 20px; border-radius: 5px; border: none; cursor: pointer; font-size: 16px; font-weight: 500; transition: background-color 0.3s; } .button.primary { background-color: var(--primary-color); color: #fff; } .button.primary:hover { background-color: var(--primary-hover-color); } .button.secondary { background-color: var(--secondary-color); color: var(--text-color); } .button.secondary:hover { background-color: var(--secondary-hover-color); } .status-container { margin-top: 10px; margin-bottom: 15px; } .status-message { margin: 0 0 10px 0; color: var(--text-color); font-weight: 500; text-align: center; font-size: 16px; transition: opacity 0.5s ease-in-out; } kbd { background-color: var(--secondary-color); border: 1px solid var(--border-color); border-radius: 3px; padding: 2px 5px; font-family: monospace; font-weight: 500; color: var(--text-color); } .toggle-label { display: flex; justify-content: space-between; align-items: center; } .toggle-switch { position: relative; display: inline-block; width: 50px; height: 24px; } .toggle-switch input { opacity: 0; width: 0; height: 0; } .slider { position: absolute; cursor: 
pointer; top: 0; left: 0; right: 0; bottom: 0; background-color: var(--secondary-color); transition: 0.4s; border-radius: 24px; } .slider:before { position: absolute; content: ""; height: 20px; width: 20px; left: 2px; bottom: 2px; background-color: white; transition: 0.4s; border-radius: 50%; } input:checked + .slider { background-color: var(--primary-color); } input:checked + .slider:before { transform: translateX(26px); } ================================================ FILE: extensions/chrome/src/utils/constants.js ================================================ export const THEMES = { LIGHT: "light", DARK: "dark", }; export const DEFAULT_ONYX_DOMAIN = "http://localhost:3000"; export const SIDE_PANEL_PATH = "/nrf/side-panel"; export const ACTIONS = { GET_SELECTED_TEXT: "getSelectedText", GET_CURRENT_ONYX_DOMAIN: "getCurrentOnyxDomain", UPDATE_PAGE_URL: "updatePageUrl", SEND_TO_ONYX: "sendToOnyx", OPEN_SIDE_PANEL: "openSidePanel", TOGGLE_NEW_TAB_OVERRIDE: "toggleNewTabOverride", OPEN_SIDE_PANEL_WITH_INPUT: "openSidePanelWithInput", OPEN_ONYX_WITH_INPUT: "openOnyxWithInput", CLOSE_SIDE_PANEL: "closeSidePanel", TAB_URL_UPDATED: "tabUrlUpdated", TAB_READING_ENABLED: "tabReadingEnabled", TAB_READING_DISABLED: "tabReadingDisabled", }; export const CHROME_SPECIFIC_STORAGE_KEYS = { ONYX_DOMAIN: "onyxExtensionDomain", USE_ONYX_AS_DEFAULT_NEW_TAB: "onyxExtensionDefaultNewTab", THEME: "onyxExtensionTheme", BACKGROUND_IMAGE: "onyxExtensionBackgroundImage", DARK_BG_URL: "onyxExtensionDarkBgUrl", LIGHT_BG_URL: "onyxExtensionLightBgUrl", ONBOARDING_COMPLETE: "onyxExtensionOnboardingComplete", }; export const CHROME_MESSAGE = { PREFERENCES_UPDATED: "PREFERENCES_UPDATED", ONYX_APP_LOADED: "ONYX_APP_LOADED", SET_DEFAULT_NEW_TAB: "SET_DEFAULT_NEW_TAB", LOAD_NEW_CHAT_PAGE: "LOAD_NEW_CHAT_PAGE", LOAD_NEW_PAGE: "LOAD_NEW_PAGE", AUTH_REQUIRED: "AUTH_REQUIRED", TAB_READING_ENABLED: "TAB_READING_ENABLED", TAB_READING_DISABLED: "TAB_READING_DISABLED", TAB_URL_UPDATED: "TAB_URL_UPDATED", 
}; export const WEB_MESSAGE = { PAGE_CHANGE: "PAGE_CHANGE", }; ================================================ FILE: extensions/chrome/src/utils/content.js ================================================ let sidePanel = null; function createSidePanel() { sidePanel = document.createElement("div"); sidePanel.id = "onyx-side-panel"; sidePanel.style.cssText = ` position: fixed; top: 0; right: -400px; width: 400px; height: 100%; background-color: white; box-shadow: -2px 0 5px rgba(0,0,0,0.2); transition: right 0.3s ease-in-out; z-index: 9999; `; const iframe = document.createElement("iframe"); iframe.style.cssText = ` width: 100%; height: 100%; border: none; `; chrome.runtime.sendMessage( { action: ACTIONS.GET_CURRENT_ONYX_DOMAIN }, function (response) { iframe.src = response[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; }, ); sidePanel.appendChild(iframe); document.body.appendChild(sidePanel); } ================================================ FILE: extensions/chrome/src/utils/error-modal.js ================================================ import { CHROME_SPECIFIC_STORAGE_KEYS, DEFAULT_ONYX_DOMAIN, ACTIONS, } from "./constants.js"; const errorModalHTML = `
        `; const style = document.createElement("style"); style.textContent = ` :root { --background-900: #0a0a0a; --background-800: #1a1a1a; --text-light-05: rgba(255, 255, 255, 0.95); --text-light-03: rgba(255, 255, 255, 0.6); --white-10: rgba(255, 255, 255, 0.1); --white-15: rgba(255, 255, 255, 0.15); --white-20: rgba(255, 255, 255, 0.2); --white-30: rgba(255, 255, 255, 0.3); } #error-modal { position: fixed; top: 0; left: 0; width: 100%; height: 100%; display: none; align-items: center; justify-content: center; z-index: 2000; font-family: var(--font-hanken-grotesk), 'Hanken Grotesk', sans-serif; } #error-modal .modal-backdrop { position: absolute; top: 0; left: 0; width: 100%; height: 100%; background: rgba(0, 0, 0, 0.7); backdrop-filter: blur(8px); } #error-modal .modal-content { position: relative; background: linear-gradient(to bottom, rgba(10, 10, 10, 0.95), rgba(26, 26, 26, 0.95)); backdrop-filter: blur(24px); border-radius: 16px; border: 1px solid var(--white-10); max-width: 95%; width: 500px; box-shadow: 0 8px 32px rgba(0, 0, 0, 0.4); overflow: hidden; } #error-modal .modal-header { padding: 24px; border-bottom: 1px solid var(--white-10); display: flex; align-items: center; gap: 12px; } #error-modal .modal-icon { width: 40px; height: 40px; border-radius: 12px; background: rgba(255, 87, 87, 0.15); display: flex; align-items: center; justify-content: center; flex-shrink: 0; } #error-modal .modal-icon svg { width: 24px; height: 24px; stroke: #ff5757; } #error-modal .modal-icon.auth-icon { background: rgba(66, 133, 244, 0.15); } #error-modal .modal-icon.auth-icon svg { stroke: #4285f4; } #error-modal h2 { margin: 0; color: var(--text-light-05); font-size: 20px; font-weight: 600; } #error-modal .modal-body { padding: 24px; } #error-modal .modal-description { color: var(--text-light-05); margin: 0 0 20px 0; font-size: 14px; line-height: 1.6; font-weight: 400; } #error-modal .url-display { background: rgba(255, 255, 255, 0.05); border-radius: 8px; padding: 
12px; border: 1px solid var(--white-10); } #error-modal .url-label { display: block; font-size: 12px; color: var(--text-light-03); margin-bottom: 6px; font-weight: 500; text-transform: uppercase; letter-spacing: 0.05em; } #error-modal .url-value { display: block; font-size: 13px; color: var(--text-light-05); word-break: break-all; font-family: monospace; line-height: 1.5; } #error-modal .modal-footer { padding: 0 24px 24px 24px; } #error-modal .button-container { display: flex; flex-direction: column; gap: 10px; margin-bottom: 16px; } #error-modal .button { padding: 12px 20px; border-radius: 8px; border: none; cursor: pointer; font-size: 14px; font-weight: 500; transition: all 0.2s; font-family: var(--font-hanken-grotesk), 'Hanken Grotesk', sans-serif; } #error-modal .button.primary { background: rgba(255, 255, 255, 0.15); color: var(--text-light-05); border: 1px solid var(--white-10); } #error-modal .button.primary:hover { background: rgba(255, 255, 255, 0.2); border-color: var(--white-20); } #error-modal .button.secondary { background: rgba(255, 255, 255, 0.05); color: var(--text-light-05); border: 1px solid var(--white-10); } #error-modal .button.secondary:hover { background: rgba(255, 255, 255, 0.1); border-color: var(--white-15); } #error-modal kbd { background: rgba(255, 255, 255, 0.1); border: 1px solid var(--white-10); border-radius: 4px; padding: 2px 6px; font-family: monospace; font-weight: 500; color: var(--text-light-05); font-size: 11px; } @media (min-width: 768px) { #error-modal .button-container { flex-direction: row; } #error-modal .button { flex: 1; } } `; const authModalHTML = `
        `; let errorModal, attemptedUrlSpan, openOptionsButton, disableOverrideButton; let authModal, openAuthButton; export function initErrorModal() { if (!document.getElementById("error-modal")) { const link = document.createElement("link"); link.rel = "stylesheet"; link.href = "../styles/shared.css"; document.head.appendChild(link); document.body.insertAdjacentHTML("beforeend", errorModalHTML); document.head.appendChild(style); errorModal = document.getElementById("error-modal"); authModal = document.getElementById("error-modal"); attemptedUrlSpan = document.getElementById("attempted-url"); openOptionsButton = document.getElementById("open-options"); disableOverrideButton = document.getElementById("disable-override"); openOptionsButton.addEventListener("click", (e) => { e.preventDefault(); chrome.runtime.openOptionsPage(); }); disableOverrideButton.addEventListener("click", () => { chrome.storage.local.set({ [CHROME_SPECIFIC_STORAGE_KEYS.USE_ONYX_AS_DEFAULT_NEW_TAB]: false }, () => { chrome.tabs.update({ url: "chrome://new-tab-page" }); }); }); } } export function showErrorModal(url) { if (!errorModal) { initErrorModal(); } if (errorModal) { errorModal.style.display = "flex"; errorModal.style.zIndex = "9999"; attemptedUrlSpan.textContent = url; document.body.style.overflow = "hidden"; } } export function hideErrorModal() { if (errorModal) { errorModal.style.display = "none"; document.body.style.overflow = "auto"; } } export function checkModalVisibility() { return errorModal ?
window.getComputedStyle(errorModal).display !== "none" : false; } export function initAuthModal() { if (!document.getElementById("error-modal")) { const link = document.createElement("link"); link.rel = "stylesheet"; link.href = "../styles/shared.css"; document.head.appendChild(link); document.body.insertAdjacentHTML("beforeend", authModalHTML); document.head.appendChild(style); authModal = document.getElementById("error-modal"); openAuthButton = document.getElementById("open-auth"); openAuthButton.addEventListener("click", (e) => { e.preventDefault(); chrome.storage.local.get( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN }, (result) => { const onyxDomain = result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; chrome.runtime.sendMessage( { action: ACTIONS.CLOSE_SIDE_PANEL }, () => { if (chrome.runtime.lastError) { console.error( "Error closing side panel:", chrome.runtime.lastError, ); } chrome.tabs.create( { url: `${onyxDomain}/auth/login`, active: true, }, (_) => { if (chrome.runtime.lastError) { console.error( "Error opening auth tab:", chrome.runtime.lastError, ); } }, ); }, ); }, ); }); } } export function showAuthModal() { if (!authModal) { initAuthModal(); } if (authModal) { authModal.style.display = "flex"; authModal.style.zIndex = "9999"; document.body.style.overflow = "hidden"; } } export function hideAuthModal() { if (authModal) { authModal.style.display = "none"; document.body.style.overflow = "auto"; } } ================================================ FILE: extensions/chrome/src/utils/selection-icon.js ================================================ (function () { const OPEN_SIDE_PANEL_WITH_INPUT = "openSidePanelWithInput"; let selectionIcon = null; let currentSelectedText = ""; function createSelectionIcon() { if (selectionIcon) return; selectionIcon = document.createElement("div"); selectionIcon.id = "onyx-selection-icon"; const img = document.createElement("img"); img.src = chrome.runtime.getURL("public/icon32.png"); img.alt = 
        "Search with Onyx"; selectionIcon.appendChild(img); document.body.appendChild(selectionIcon); selectionIcon.addEventListener("mousedown", handleIconClick); } function showIcon(text) { if (!selectionIcon) { createSelectionIcon(); } currentSelectedText = text; const selection = window.getSelection(); if (!selection.rangeCount) return; const range = selection.getRangeAt(0); const rect = range.getBoundingClientRect(); const iconSize = 32; const offset = 4; let posX = rect.right + offset; let posY = rect.bottom + offset; if (posX + iconSize > window.innerWidth) { posX = rect.left - iconSize - offset; } if (posY + iconSize > window.innerHeight) { posY = rect.top - iconSize - offset; } posX = Math.max( offset, Math.min(posX, window.innerWidth - iconSize - offset), ); posY = Math.max( offset, Math.min(posY, window.innerHeight - iconSize - offset), ); selectionIcon.style.left = `${posX}px`; selectionIcon.style.top = `${posY}px`; selectionIcon.classList.add("visible"); } function hideIcon() { if (selectionIcon) { selectionIcon.classList.remove("visible"); } currentSelectedText = ""; } function handleIconClick(e) { e.preventDefault(); e.stopPropagation(); const textToSend = currentSelectedText; if (textToSend) { chrome.runtime.sendMessage( { action: OPEN_SIDE_PANEL_WITH_INPUT, selectedText: textToSend, pageUrl: window.location.href, }, () => { if (chrome.runtime.lastError) { console.error( "[Onyx] Error sending message:", chrome.runtime.lastError.message, ); } }, ); } hideIcon(); } document.addEventListener("mouseup", (e) => { if ( e.target.id === "onyx-selection-icon" || e.target.closest("#onyx-selection-icon") ) { return; } setTimeout(() => { const selection = window.getSelection(); const selectedText = selection.toString().trim(); if (selectedText) { showIcon(selectedText); } else { hideIcon(); } }, 10); }); document.addEventListener("mousedown", (e) => { if ( e.target.id !== "onyx-selection-icon" &&
        !e.target.closest("#onyx-selection-icon") ) { const selection = window.getSelection(); const selectedText = selection.toString().trim(); if (!selectedText) { hideIcon(); } } }); document.addEventListener( "scroll", () => { hideIcon(); }, true, ); document.addEventListener("selectionchange", () => { const selection = window.getSelection(); const selectedText = selection.toString().trim(); if (!selectedText) { hideIcon(); } }); if (document.readyState === "loading") { document.addEventListener("DOMContentLoaded", createSelectionIcon); } else { createSelectionIcon(); } })(); ================================================ FILE: extensions/chrome/src/utils/storage.js ================================================ import { DEFAULT_ONYX_DOMAIN, CHROME_SPECIFIC_STORAGE_KEYS, } from "./constants.js"; export async function getOnyxDomain() { const result = await chrome.storage.local.get({ [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: DEFAULT_ONYX_DOMAIN, }); return result[CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]; } export function setOnyxDomain(domain, callback) { chrome.storage.local.set( { [CHROME_SPECIFIC_STORAGE_KEYS.ONYX_DOMAIN]: domain }, callback, ); } export function getOnyxDomainSync() { return getOnyxDomain(); } ================================================ FILE: profiling/grafana/dashboards/onyx/opensearch-search-latency.json ================================================ { "annotations": { "list": [ { "builtIn": 1, "datasource": { "type": "grafana", "uid": "-- Grafana --" }, "enable": true, "hide": true, "iconColor": "rgba(0, 211, 255, 1)", "name": "Annotations & Alerts", "type": "dashboard" } ] }, "editable": true, "fiscalYearStartMonth": 0, "graphTooltip": 1, "id": null, "links": [], "liveNow": true, "panels": [ { "title": "Client-Side Search Latency (P50 / P95 / P99)", "description": "End-to-end latency as measured by the Python client, including network round-trip and serialization overhead.", "type": "timeseries",
"gridPos": { "h": 10, "w": 12, "x": 0, "y": 0 }, "id": 1, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisLabel": "seconds", "axisPlacement": "auto", "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "lineInterpolation": "smooth", "lineWidth": 2, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "never", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "dashed" } }, "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "yellow", "value": 0.5 }, { "color": "red", "value": 2.0 } ] }, "unit": "s", "min": 0 }, "overrides": [] }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))", "legendFormat": "P50", "refId": "A" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))", "legendFormat": "P95", "refId": "B" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))", "legendFormat": "P99", "refId": "C" } ] }, { "title": "Server-Side Search Latency (P50 / P95 / P99)", "description": "OpenSearch server-side execution time from the 'took' field in the response. 
Does not include network or client-side overhead.", "type": "timeseries", "gridPos": { "h": 10, "w": 12, "x": 12, "y": 0 }, "id": 2, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisLabel": "seconds", "axisPlacement": "auto", "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "lineInterpolation": "smooth", "lineWidth": 2, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "never", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "dashed" } }, "thresholds": { "mode": "absolute", "steps": [ { "color": "green", "value": null }, { "color": "yellow", "value": 0.5 }, { "color": "red", "value": 2.0 } ] }, "unit": "s", "min": 0 }, "overrides": [] }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))", "legendFormat": "P50", "refId": "A" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.95, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))", "legendFormat": "P95", "refId": "B" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.99, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))", "legendFormat": "P99", "refId": "C" } ] }, { "title": "Client-Side Latency by Search Type (P95)", "description": "P95 client-side latency broken down by search type (hybrid, keyword, semantic, random, doc_id_retrieval).", "type": "timeseries", "gridPos": { "h": 10, "w": 12, "x": 0, "y": 10 }, "id": 3, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { 
"axisBorderShow": false, "axisCenteredZero": false, "axisLabel": "seconds", "axisPlacement": "auto", "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "lineInterpolation": "smooth", "lineWidth": 2, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "never", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "unit": "s", "min": 0 }, "overrides": [] }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.95, sum by (search_type, le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))", "legendFormat": "{{ search_type }}", "refId": "A" } ] }, { "title": "Search Throughput by Type", "description": "Searches per second broken down by search type.", "type": "timeseries", "gridPos": { "h": 10, "w": 12, "x": 12, "y": 10 }, "id": 4, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisLabel": "searches/s", "axisPlacement": "auto", "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "lineInterpolation": "smooth", "lineWidth": 2, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "never", "spanNulls": false, "stacking": { "group": "A", "mode": "normal" }, "thresholdsStyle": { "mode": "off" } }, "unit": "ops", "min": 0 }, "overrides": [] }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "sum by (search_type) (rate(onyx_opensearch_search_total[5m]))", "legendFormat": "{{ search_type }}", "refId": "A" } ] }, { "title": "Concurrent Searches In Progress", "description": "Number of OpenSearch searches currently in flight, broken down by search type. 
Summed across all instances.", "type": "timeseries", "gridPos": { "h": 10, "w": 12, "x": 0, "y": 20 }, "id": 5, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisLabel": "searches", "axisPlacement": "auto", "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "lineInterpolation": "smooth", "lineWidth": 2, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "never", "spanNulls": false, "stacking": { "group": "A", "mode": "normal" }, "thresholdsStyle": { "mode": "off" } }, "min": 0 }, "overrides": [] }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "sum by (search_type) (onyx_opensearch_searches_in_progress)", "legendFormat": "{{ search_type }}", "refId": "A" } ] }, { "title": "Client vs Server Latency Overhead (P50)", "description": "Difference between client-side and server-side P50 latency. 
Reveals network, serialization, and untracked OpenSearch overhead.", "type": "timeseries", "gridPos": { "h": 10, "w": 12, "x": 12, "y": 20 }, "id": 6, "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "fieldConfig": { "defaults": { "color": { "mode": "palette-classic" }, "custom": { "axisBorderShow": false, "axisCenteredZero": false, "axisLabel": "seconds", "axisPlacement": "auto", "drawStyle": "line", "fillOpacity": 0, "gradientMode": "none", "lineInterpolation": "smooth", "lineWidth": 2, "pointSize": 5, "scaleDistribution": { "type": "linear" }, "showPoints": "never", "spanNulls": false, "stacking": { "group": "A", "mode": "none" }, "thresholdsStyle": { "mode": "off" } }, "unit": "s", "min": 0 }, "overrides": [] }, "targets": [ { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m]))) - histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))", "legendFormat": "Client - Server overhead (P50)", "refId": "A" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_client_duration_seconds_bucket[5m])))", "legendFormat": "Client P50", "refId": "B" }, { "datasource": { "type": "prometheus", "uid": "${DS_PROMETHEUS}" }, "expr": "histogram_quantile(0.5, sum by (le) (rate(onyx_opensearch_search_server_duration_seconds_bucket[5m])))", "legendFormat": "Server P50", "refId": "C" } ] } ], "refresh": "5s", "schemaVersion": 37, "style": "dark", "tags": ["onyx", "opensearch", "search", "latency"], "templating": { "list": [ { "current": { "text": "Prometheus", "value": "prometheus" }, "includeAll": false, "name": "DS_PROMETHEUS", "options": [], "query": "prometheus", "refresh": 1, "type": "datasource" } ] }, "time": { "from": "now-60m", "to": "now" }, "timepicker": { "refresh_intervals": ["5s", "10s", "30s", 
"1m"] }, "timezone": "", "title": "Onyx OpenSearch Search Latency", "uid": "onyx-opensearch-search-latency", "version": 0, "weekStart": "" } ================================================ FILE: pyproject.toml ================================================ [build-system] requires = ["setuptools>=61"] build-backend = "setuptools.build_meta" [project] name = "onyx" version = "0.0.0" requires-python = ">=3.11" # Shared dependencies between backend and model_server dependencies = [ "aioboto3==15.1.0", "cohere==5.6.1", "fastapi==0.133.1", "google-genai==1.52.0", "litellm==1.81.6", "openai==2.14.0", "pydantic==2.11.7", "prometheus_client>=0.21.1", "prometheus_fastapi_instrumentator==7.1.0", "retry==0.9.2", # This pulls in py which is in CVE-2022-42969, must remove py from image "sentry-sdk==2.14.0", "uvicorn==0.35.0", "voyageai==0.2.3", "brotli>=1.2.0", "claude-agent-sdk>=0.1.19", "agent-client-protocol>=0.7.1", "discord-py==2.4.0", "kubernetes>=31.0.0", ] [project.optional-dependencies] # Main backend application dependencies backend = [ "aiohttp==3.13.4", "alembic==1.10.4", "asyncpg==0.30.0", "atlassian-python-api==3.41.16", "azure-cognitiveservices-speech==1.38.0", "beautifulsoup4==4.12.3", "boto3==1.39.11", "boto3-stubs[s3]==1.39.11", "celery==5.5.1", "chardet==5.2.0", "chonkie==1.0.10", "dask==2026.1.1", "ddtrace==3.10.0", "discord.py==2.4.0", "distributed==2026.1.1", "fastapi-users==15.0.4", "fastapi-users-db-sqlalchemy==7.0.0", "fastapi-limiter==0.1.6", "fastmcp==3.2.0", "filelock==3.20.3", "google-api-python-client==2.86.0", "google-auth-httplib2==0.1.0", "google-auth-oauthlib==1.0.0", # GPT4All library has issues running on Macs and python:3.11.4-slim-bookworm # will reintroduce this when library version catches up # "gpt4all==2.0.2", "httpcore==1.0.9", "httpx[http2]==0.28.1", "httpx-oauth==0.15.1", "huggingface-hub==0.35.3", "inflection==0.5.1", "jira==3.10.5", "jsonref==1.1.0", "kubernetes==31.0.0", "trafilatura==1.12.2", "langchain-core==1.2.22", 
"lazy_imports==1.0.1", "lxml==5.3.0", "Mako==1.2.4", "markitdown[pdf, docx, pptx, xlsx, xls]==0.1.2", "mcp[cli]==1.26.0", "msal==1.34.0", "msoffcrypto-tool==5.4.2", "Office365-REST-Python-Client==2.6.2", "oauthlib==3.2.2", # NOTE: This is frozen to avoid https://foss.heptapod.net/openpyxl/openpyxl/-/issues/2147 "openpyxl==3.0.10", "opensearch-py==3.0.0", "passlib==1.7.4", "playwright==1.55.0", "psutil==7.1.3", "psycopg2-binary==2.9.9", "puremagic==1.28", "pyairtable==3.0.1", "pycryptodome==3.19.1", "PyGithub==2.5.0", "pympler==1.1", "python-dateutil==2.8.2", "python-gitlab==5.6.0", "python-pptx==0.6.23", "pypandoc_binary==1.16.2", "pypdf==6.9.2", "pytest-mock==3.12.0", "pytest-playwright==0.7.0", "python-docx==1.1.2", "python-dotenv==1.1.1", "python-multipart==0.0.22", "pywikibot==9.0.0", "redis==5.0.8", "requests==2.33.0", "requests-oauthlib==1.3.1", "rfc3986==1.5.0", "simple-salesforce==1.12.6", "slack-sdk==3.20.2", "SQLAlchemy[mypy]==2.0.15", "starlette==0.49.3", "supervisor==4.3.0", "RapidFuzz==3.13.0", "tiktoken==0.7.0", "timeago==1.0.16", "types-openpyxl==3.0.4.7", "unstructured==0.18.27", "unstructured-client==0.42.6", "zulip==0.8.2", "hubspot-api-client==11.1.0", "asana==5.0.8", "dropbox==12.0.2", "shapely==2.0.6", "stripe==10.12.0", "urllib3==2.6.3", "mistune==3.2.0", "sendgrid==6.12.5", "exa_py==1.15.4", "braintrust==0.3.9", "langfuse==3.10.0", "nest_asyncio==1.6.0", "openinference-instrumentation==0.1.42", "opentelemetry-proto>=1.39.0", "python3-saml==1.15.0", "xmlsec==1.3.14", ] # Dev tools dev = [ "black==25.1.0", "celery-types==0.19.0", "faker==40.1.2", "hatchling==1.28.0", "ipykernel==6.29.5", "manygo==0.2.0", "matplotlib==3.10.8", "mypy-extensions==1.0.0", "mypy==1.13.0", "onyx-devtools==0.7.2", "openapi-generator-cli==7.17.0", "pandas-stubs~=2.3.3", "pre-commit==3.2.2", "pytest-alembic==0.12.1", "pytest-asyncio==1.3.0", "pytest-dotenv==0.5.2", "pytest-repeat==0.9.4", "pytest-xdist==3.8.0", "pytest==8.3.5", "release-tag==0.5.2", 
"reorder-python-imports-black==3.14.0", "ruff==0.12.0", "types-beautifulsoup4==4.12.0.3", "types-html5lib==1.1.11.13", "types-oauthlib==3.2.0.9", "types-passlib==1.7.7.20240106", "types-Pillow==10.2.0.20240822", "types-psutil==7.1.3.20251125", "types-psycopg2==2.9.21.10", "types-python-dateutil==2.8.19.13", "types-PyYAML==6.0.12.11", "types-pytz==2023.3.1.1", "types-regex==2023.3.23.1", "types-requests==2.32.0.20250328", "types-retry==0.9.9.3", "types-setuptools==68.0.0.3", "zizmor==1.18.0", ] # Enterprise Edition features ee = [ "posthog==3.7.4", ] # Model server specific dependencies (ML packages) model_server = [ "accelerate==1.6.0", "einops==0.8.1", "numpy==2.4.1", "safetensors==0.5.3", "sentence-transformers==4.0.2", "torch==2.9.1", "transformers==4.53.0", "sentry-sdk[fastapi,celery,starlette]==2.14.0", ] [tool.mypy] plugins = "sqlalchemy.ext.mypy.plugin" mypy_path = "backend" explicit_package_bases = true disallow_untyped_defs = true warn_unused_ignores = true enable_error_code = ["possibly-undefined"] strict_equality = true # Patterns match paths whether mypy is run from backend/ (CI) or repo root (e.g. 
VS Code extension with target ./backend) exclude = [ "(?:^|/)generated/", "(?:^|/)\\.venv/", "(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/skills/", "(?:^|/)onyx/server/features/build/sandbox/kubernetes/docker/templates/", ] [[tool.mypy.overrides]] module = "alembic.versions.*" disable_error_code = ["var-annotated"] [[tool.mypy.overrides]] module = "alembic_tenants.versions.*" disable_error_code = ["var-annotated"] [[tool.mypy.overrides]] module = "generated.*" follow_imports = "silent" ignore_errors = true [[tool.mypy.overrides]] module = "transformers.*" follow_imports = "skip" ignore_errors = true [tool.uv.workspace] members = ["backend", "tools/ods"] [tool.basedpyright] include = ["backend"] exclude = ["backend/generated", "backend/onyx/server/features/build/sandbox/kubernetes/docker/skills/pptx", "backend/onyx/server/features/build/sandbox/kubernetes/docker/templates/venv"] typeCheckingMode = "off" [tool.ruff] line-length = 130 target-version = "py311" [tool.ruff.lint] ignore = [ "E501", # Long lines are handled by Black. 
] select = [ "ARG", "E", "F", "S324", "W", ] [tool.setuptools.packages.find] where = ["backend"] include = ["onyx*", "tests*"] ================================================ FILE: web/.dockerignore ================================================ node_modules .next /tests/ # Explicitly include src/app/build (overrides .gitignore /build pattern) !src/app/build ================================================ FILE: web/.eslintrc.json ================================================ { "extends": "next/core-web-vitals", "plugins": ["unused-imports"], "rules": { "@next/next/no-img-element": "off", "react-hooks/exhaustive-deps": "off", "no-unused-vars": "off", "@typescript-eslint/no-unused-vars": "off", "unused-imports/no-unused-imports": "warn", "unused-imports/no-unused-vars": [ "warn", { "vars": "all", "varsIgnorePattern": "^_", "args": "after-used", "argsIgnorePattern": "^_", "ignoreRestSiblings": true } ] } } ================================================ FILE: web/.gitignore ================================================ # See https://help.github.com/articles/ignoring-files/ for more about ignoring files. .env.sentry-build-plugin # dependencies node_modules /.pnp .pnp.js # testing /coverage # next.js /.next/ /out/ # production /build # misc .DS_Store *.pem # debug npm-debug.log* yarn-debug.log* yarn-error.log* .pnpm-debug.log* # local env files .env*.local # vercel .vercel # typescript *.tsbuildinfo next-env.d.ts # playwright testing temp files /admin*_auth.json /worker*_auth.json /user_auth.json /build-archive.log /test-results /output/ # generated clients ... in particular, the API to the Onyx backend itself! 
/src/lib/generated .jest-cache # storybook storybook-static ================================================ FILE: web/.prettierignore ================================================ **/.git **/.svn **/.hg **/node_modules **/.next **/.vscode ================================================ FILE: web/.prettierrc.json ================================================ { "trailingComma": "es5" } ================================================ FILE: web/.storybook/Introduction.mdx ================================================ import { Meta } from "@storybook/blocks"; # Onyx Storybook A living catalog for browsing, testing, and documenting Onyx UI components in isolation. --- ## What is this? This Storybook contains interactive examples of every reusable UI component in the Onyx frontend. Each component has a dedicated page with: - **Live demos** you can interact with directly - **Controls** to tweak props and see how the component responds - **Auto-generated docs** showing the full props API - **Dark mode toggle** in the toolbar to preview both themes --- ## Navigating Storybook ### Sidebar The left sidebar organizes components by layer: - **opal/core** — Low-level primitives (`Interactive`, `Hoverable`) - **opal/components** — Design system atoms (`Button`, `OpenButton`, `Tag`) - **Layouts** — Structural layouts (`Content`, `ContentAction`, `IllustrationContent`) - **refresh-components** — App-level components (inputs, modals, tables, text, etc.) Click any component to see its stories. Click **Docs** to see the auto-generated props table. ### Controls panel At the bottom of each story, the **Controls** panel lets you change props in real time. Toggle booleans, pick from enums, type in strings — the preview updates instantly. ### Theme toggle Use the paint roller icon in the top toolbar to switch between **light** and **dark** mode. All components use CSS variables that automatically adapt. 
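Conceptually, the toggle swaps a single class on the preview's root element, and the CSS variables keyed to that class handle the rest. The sketch below illustrates that class-swap logic, assuming a theme map shaped like the one passed to `withThemeByClassName` in `.storybook/preview.ts` (`light: ""`, `dark: "dark"`). `applyThemeClass` and `ClassListLike` are hypothetical names for illustration, not the addon's actual implementation.

```typescript
// Hypothetical helper illustrating class-based theming; this is NOT the
// @storybook/addon-themes implementation, only a sketch of the idea.
type ThemeMap = Record<string, string>;

// Mirrors the theme map configured in .storybook/preview.ts.
const themes: ThemeMap = { light: "", dark: "dark" };

// Minimal subset of DOMTokenList, so a real element.classList fits,
// and so does a plain stand-in outside the browser.
interface ClassListLike {
  add(token: string): void;
  remove(token: string): void;
  contains(token: string): boolean;
}

function applyThemeClass(
  classList: ClassListLike,
  themeName: string,
  map: ThemeMap = themes,
): void {
  const target: string | undefined = map[themeName];
  if (target === undefined) {
    throw new Error(`Unknown theme: ${themeName}`);
  }
  // Remove every theme class first so at most one theme is active.
  for (const cls of Object.values(map)) {
    if (cls) classList.remove(cls);
  }
  // "light" maps to "" (no class), so there is nothing to add for it.
  if (target) classList.add(target);
}

// Usage with a Set-backed stand-in for element.classList:
const classes = new Set<string>();
const fakeClassList: ClassListLike = {
  add: (t) => { classes.add(t); },
  remove: (t) => { classes.delete(t); },
  contains: (t) => classes.has(t),
};

applyThemeClass(fakeClassList, "dark");
console.log(fakeClassList.contains("dark")); // true
applyThemeClass(fakeClassList, "light");
console.log(fakeClassList.contains("dark")); // false
```

In a browser the same call would take `document.body.classList`; because light mode is simply "no class", the default styles act as the light theme and the `dark` class only layers overrides on top.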
--- ## Running locally ```bash cd web npm run storybook # dev server on :6006 npm run storybook:build # static build to storybook-static/ ``` --- ## Adding a new story Stories are **co-located** next to their component: ``` lib/opal/src/components/buttons/Button/ ├── components.tsx ← the component ├── Button.stories.tsx ← the story ├── styles.css └── README.md ``` ### Minimal template ```tsx import type { Meta, StoryObj } from "@storybook/react"; import { MyComponent } from "./MyComponent"; const meta: Meta = { title: "opal/components/MyComponent", // sidebar path component: MyComponent, tags: ["autodocs"], // auto-generate docs page }; export default meta; type Story = StoryObj; export const Default: Story = { args: { title: "Hello", }, }; export const WithCustomLayout: Story = { render: () => (
        ), }; ``` ### Conventions - **Title format:** `opal/core/Name`, `opal/components/Name`, `Layouts/Name`, or `refresh-components/Name` - **Tags:** Add `tags: ["autodocs"]` to auto-generate a docs page from props - **Decorators:** If your component needs `TooltipPrimitive.Provider` (anything with tooltips), add it as a decorator - **Layout:** Use `parameters: { layout: "fullscreen" }` for modals/popovers that use portals --- ## Deployment Production builds deploy to [onyx-storybook.vercel.app](https://onyx-storybook.vercel.app) automatically when PRs touching component files merge to `main`. Monitored paths: - `web/lib/opal/**` - `web/src/refresh-components/**` - `web/.storybook/**` ================================================ FILE: web/.storybook/README.md ================================================ # Onyx Storybook Storybook is an isolated development environment for UI components. It renders each component in a standalone "story" outside of the main app, so you can visually verify appearance, interact with props, and catch regressions without navigating through the full application. The Onyx Storybook covers the full component library — from low-level `@opal/core` primitives up through `refresh-components` — giving designers and engineers a shared reference for every visual state. **Production:** [onyx-storybook.vercel.app](https://onyx-storybook.vercel.app) ## Running Locally ```bash cd web npm run storybook # dev server on http://localhost:6006 npm run storybook:build # static build to storybook-static/ ``` The dev server hot-reloads when you edit a component or story file. 
## Writing Stories Stories are **co-located** next to their component source: ``` lib/opal/src/core/interactive/ ├── components.tsx ← the component ├── Interactive.stories.tsx ← the story └── styles.css src/refresh-components/buttons/ ├── Button.tsx └── Button.stories.tsx ``` ### Minimal Template ```tsx import type { Meta, StoryObj } from "@storybook/react"; import { MyComponent } from "./MyComponent"; const meta: Meta = { title: "Category/MyComponent", // sidebar path component: MyComponent, tags: ["autodocs"], // generates a docs page from props }; export default meta; type Story = StoryObj; export const Default: Story = { args: { label: "Hello" }, }; ``` ### Conventions - **Title format:** `Core/Name`, `Components/Name`, `Layouts/Name`, or `refresh-components/category/Name` - **Tags:** Add `tags: ["autodocs"]` to auto-generate a props docs page - **Decorators:** Components that use Radix tooltips need a `TooltipPrimitive.Provider` decorator - **Layout:** Use `parameters: { layout: "fullscreen" }` for modals/popovers that use portals ## Dark Mode Use the theme toggle (paint roller icon) in the Storybook toolbar to switch between light and dark modes. This adds/removes the `dark` class on the preview body, matching the app's `darkMode: "class"` Tailwind config. All color tokens from `colors.css` adapt automatically. ## Deployment The production Storybook is deployed as a static site on Vercel. The build runs `npm run storybook:build`, which outputs to `storybook-static/`, and Vercel serves that directory. Deploys are triggered on merges to `main` when files in `web/lib/opal/`, `web/src/refresh-components/`, or `web/.storybook/` change.
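The deploy trigger amounts to a path filter over the files changed by a merge to `main`. A minimal sketch of that rule follows, assuming the change set is available as a list of repo-relative paths. `shouldDeployStorybook` and `STORYBOOK_PATHS` are hypothetical names; the real filter is configured in CI, not in application code.

```typescript
// Directories whose changes trigger a Storybook redeploy,
// as listed in the Deployment section.
const STORYBOOK_PATHS = [
  "web/lib/opal/",
  "web/src/refresh-components/",
  "web/.storybook/",
] as const;

// Hypothetical helper: true only when a merge to `main` touched
// at least one file under a monitored directory.
function shouldDeployStorybook(
  branch: string,
  changedFiles: readonly string[],
): boolean {
  if (branch !== "main") return false;
  // The trailing "/" on each prefix keeps e.g. "web/lib/opal-old/" from matching.
  return changedFiles.some((file) =>
    STORYBOOK_PATHS.some((prefix) => file.startsWith(prefix)),
  );
}

console.log(shouldDeployStorybook("main", ["web/src/refresh-components/buttons/Button.tsx"])); // true
console.log(shouldDeployStorybook("main", ["backend/onyx/main.py"])); // false
console.log(shouldDeployStorybook("feature/x", ["web/.storybook/main.ts"])); // false
```

Matching on directory prefixes rather than file extensions means style-only or config-only changes (for example `styles.css` or `.storybook/preview.ts`) still trigger a redeploy, which is what a visual catalog needs.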
## Component Layers The sidebar organizes components by their layer in the design system: | Layer | Path | Examples | |-------|------|----------| | **Core** | `lib/opal/src/core/` | Interactive, Hoverable | | **Components** | `lib/opal/src/components/` | Button, OpenButton, Tag | | **Layouts** | `lib/opal/src/layouts/` | Content, ContentAction, IllustrationContent | | **refresh-components** | `src/refresh-components/` | Inputs, tables, modals, text, cards, tiles, etc. | ================================================ FILE: web/.storybook/main.ts ================================================ import type { StorybookConfig } from "@storybook/react-vite"; import path from "path"; const config: StorybookConfig = { stories: [ "./*.mdx", "../lib/opal/src/**/*.stories.@(ts|tsx)", "../src/refresh-components/**/*.stories.@(ts|tsx)", ], addons: ["@storybook/addon-essentials", "@storybook/addon-themes"], framework: { name: "@storybook/react-vite", options: {}, }, staticDirs: ["../public"], docs: { autodocs: "tag", }, typescript: { reactDocgen: "react-docgen-typescript", }, viteFinal: async (config) => { config.resolve = config.resolve ?? {}; config.resolve.alias = { ...config.resolve.alias, "@": path.resolve(__dirname, "../src"), "@opal": path.resolve(__dirname, "../lib/opal/src"), "@public": path.resolve(__dirname, "../public"), // Next.js module stubs for Vite "next/link": path.resolve(__dirname, "mocks/next-link.tsx"), "next/navigation": path.resolve(__dirname, "mocks/next-navigation.tsx"), "next/image": path.resolve(__dirname, "mocks/next-image.tsx"), }; // Process CSS with Tailwind via PostCSS config.css = config.css ?? 
{}; config.css.postcss = path.resolve(__dirname, ".."); return config; }, }; export default config; ================================================ FILE: web/.storybook/mocks/next-image.tsx ================================================ import React from "react"; interface ImageProps { src: string; alt: string; width?: number; height?: number; fill?: boolean; [key: string]: unknown; } function Image({ src, alt, width, height, fill, ...props }: ImageProps) { const fillStyle: React.CSSProperties = fill ? { position: "absolute", inset: 0, width: "100%", height: "100%" } : {}; return ( )} src={src} alt={alt} width={fill ? undefined : width} height={fill ? undefined : height} style={{ ...(props.style as React.CSSProperties), ...fillStyle }} /> ); } export default Image; ================================================ FILE: web/.storybook/mocks/next-link.tsx ================================================ import React from "react"; interface LinkProps { href: string; children: React.ReactNode; [key: string]: unknown; } function Link({ href, children, prefetch: _prefetch, scroll: _scroll, shallow: _shallow, replace: _replace, passHref: _passHref, locale: _locale, legacyBehavior: _legacyBehavior, ...props }: LinkProps) { return ( {children} ); } export default Link; ================================================ FILE: web/.storybook/mocks/next-navigation.tsx ================================================ export function useRouter() { return { push: (_url: string) => {}, replace: (_url: string) => {}, back: () => {}, forward: () => {}, refresh: () => {}, prefetch: (_url: string) => Promise.resolve(), }; } export function usePathname() { return "/"; } export function useSearchParams() { return new URLSearchParams() as ReadonlyURLSearchParams; } export function useParams() { return {}; } export function redirect(_url: string): never { throw new Error("redirect() called in Storybook"); } export function notFound(): never { throw new Error("notFound() called in 
Storybook"); } ================================================ FILE: web/.storybook/preview-head.html ================================================ ================================================ FILE: web/.storybook/preview.ts ================================================ import type { Preview } from "@storybook/react"; import { withThemeByClassName } from "@storybook/addon-themes"; import "../src/app/globals.css"; const preview: Preview = { parameters: { layout: "centered", backgrounds: { disable: true }, controls: { matchers: { color: /(background|color)$/i, date: /Date$/i, }, }, }, decorators: [ withThemeByClassName({ themes: { light: "", dark: "dark", }, defaultTheme: "light", }), ], }; export default preview; ================================================ FILE: web/@types/favicon-fetch.d.ts ================================================ declare module "favicon-fetch" { interface FaviconFetchOptions { uri: string; } function faviconFetch(options: FaviconFetchOptions): string | null; export default faviconFetch; } ================================================ FILE: web/@types/images.d.ts ================================================ declare module "*.png" { const content: string; export default content; } declare module "*.svg" { const content: string; export default content; } declare module "*.jpeg" { const content: string; export default content; } declare module "*.jpg" { const content: string; export default content; } declare module "*.gif" { const content: string; export default content; } declare module "*.webp" { const content: string; export default content; } declare module "*.ico" { const content: string; export default content; } ================================================ FILE: web/AGENTS.md ================================================ # Frontend Standards This file is the single source of truth for frontend coding standards across all Onyx frontend projects (including, but not limited to, `/web`, `/desktop`). 
# Components UI components are spread across several directories while the codebase migrates to Opal: - **`web/lib/opal/src/`** — The Opal design system. Preferred for all new components. - **`web/src/refresh-components/`** — Production components not yet migrated to Opal. - **`web/src/sections/`** — Feature-specific composite components (cards, modals, etc.). - **`web/src/layouts/`** — Page-level layout components (settings pages, etc.). **Do NOT use anything from `web/src/components/`** — this directory contains legacy components that are being phased out. Always prefer Opal first; fall back to `refresh-components` only for components not yet available in Opal. ## Opal Layouts (`lib/opal/src/layouts/`) All layout primitives are imported from `@opal/layouts`. They handle sizing, font selection, icon alignment, and optional inline editing. ```typescript import { Content, ContentAction, IllustrationContent } from "@opal/layouts"; ``` ### Content **Use this for any combination of icon + title + description.** A two-axis layout component that automatically routes to the correct internal layout (`ContentXl`, `ContentLg`, `ContentMd`, `ContentSm`) based on `sizePreset` and `variant`: | sizePreset | variant | Routes to | Layout | |---|---|---|---| | `headline` / `section` | `heading` | `ContentXl` | Icon on top (flex-col) | | `headline` / `section` | `section` | `ContentLg` | Icon inline (flex-row) | | `main-content` / `main-ui` / `secondary` | `section` / `heading` | `ContentMd` | Compact inline | | `main-content` / `main-ui` / `secondary` | `body` | `ContentSm` | Body text layout | ```typescript ``` ### ContentAction **Use this when a Content block needs right-side actions** (buttons, badges, icons, etc.). Wraps `Content` and adds a `rightChildren` slot. 
Accepts all `Content` props plus: - `rightChildren`: `ReactNode` — actions rendered on the right - `paddingVariant`: `SizeVariant` — controls outer padding ```typescript Edit} /> ``` ### IllustrationContent **Use this for empty states, error pages, and informational placeholders.** A vertically-stacked, center-aligned layout that pairs a large illustration (7.5rem x 7.5rem) with a title and optional description. ```typescript import SvgNoResult from "@opal/illustrations/no-result"; ``` Props: - `illustration`: `IconFunctionComponent` — optional, from `@opal/illustrations` - `title`: `string` — required - `description`: `string` — optional ## Settings Page Layout (`src/layouts/settings-layouts.tsx`) **Use this for all admin/settings pages.** Provides a standardized layout with scroll-aware sticky headers, centered content containers, and responsive behavior. ```typescript import SettingsLayouts from "@/layouts/settings-layouts"; function MySettingsPage() { return ( Save} > Settings content here ); } ``` Sub-components: - **`SettingsLayouts.Root`** — Wrapper with centered, scrollable container. Width options: `"sm"` (672px), `"sm-md"` (752px), `"md"` (872px, default), `"lg"` (992px), `"full"` (100%). - **`SettingsLayouts.Header`** — Sticky header with icon, title, description, optional `rightChildren` actions, optional `children` below (e.g., search/filter), optional `backButton`, and optional `separator`. Automatically shows a scroll shadow when scrolled. - **`SettingsLayouts.Body`** — Content container with consistent padding and vertical spacing. ## Cards (`src/sections/cards/`) **When building a card that displays information about a specific entity (agent, document set, file, connector, etc.), add it to `web/src/sections/cards/`.** Each card is a self-contained component focused on a single entity type. Cards typically include entity identification (name, avatar, icon), summary information, and quick actions. 
```typescript import AgentCard from "@/sections/cards/AgentCard"; import DocumentSetCard from "@/sections/cards/DocumentSetCard"; import FileCard from "@/sections/cards/FileCard"; ``` Guidelines: - One card per entity type — keep card-specific logic within the card component. - Cards should be reusable across different pages and contexts. - Use shared components from `@opal/components`, `@opal/layouts`, and `@/refresh-components` inside cards — do not duplicate layout or styling logic. ## Button (`components/buttons/button/`) **Always use the Opal `Button`.** Do not use raw ` // Icon-only button (omit children) ) } // ❌ Bad function ContactForm() { return (